diff --git a/core/http/endpoints/elevenlabs/soundgeneration.go b/core/http/endpoints/elevenlabs/soundgeneration.go index d634bf81d..fa1803649 100644 --- a/core/http/endpoints/elevenlabs/soundgeneration.go +++ b/core/http/endpoints/elevenlabs/soundgeneration.go @@ -8,6 +8,7 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/audio" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) @@ -51,7 +52,11 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader if err != nil { return err } - return c.Attachment(filePath, filepath.Base(filePath)) + filePath, contentType := audio.NormalizeAudioFile(filePath) + if contentType != "" { + c.Response().Header().Set("Content-Type", contentType) + } + return c.Attachment(filePath, filepath.Base(filePath)) } } diff --git a/core/http/endpoints/elevenlabs/tts.go b/core/http/endpoints/elevenlabs/tts.go index 658eb56ba..ff859b04d 100644 --- a/core/http/endpoints/elevenlabs/tts.go +++ b/core/http/endpoints/elevenlabs/tts.go @@ -8,6 +8,7 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/audio" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) @@ -39,6 +40,10 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig if err != nil { return err } + filePath, contentType := audio.NormalizeAudioFile(filePath) + if contentType != "" { + c.Response().Header().Set("Content-Type", contentType) + } return c.Attachment(filePath, filepath.Base(filePath)) } } diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go index 01bc1cd82..4e25cb138 100644 --- a/core/http/endpoints/localai/tts.go +++ b/core/http/endpoints/localai/tts.go @@ -7,12 +7,11 @@ import ( "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" - "github.com/mudler/LocalAI/pkg/model" - "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/xlog" - + "github.com/mudler/LocalAI/pkg/audio" + "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/utils" + "github.com/mudler/xlog" ) // TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech @@ -86,6 +85,10 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig return err } + filePath, contentType := audio.NormalizeAudioFile(filePath) + if contentType != "" { + c.Response().Header().Set("Content-Type", contentType) + } return c.Attachment(filePath, filepath.Base(filePath)) } } diff --git a/go.mod b/go.mod index cd02e3d5f..946c4f6dc 100644 --- a/go.mod +++ b/go.mod @@ -66,6 +66,7 @@ require ( ) require ( + github.com/dhowden/tag v0.0.0-20240417053706-3d75831295e8 // indirect github.com/ghodss/yaml v1.0.0 // indirect github.com/labstack/gommon v0.4.2 // indirect github.com/swaggo/files/v2 v2.0.2 // indirect diff --git a/go.sum b/go.sum index 9a5eafb1b..b59357bbd 100644 --- a/go.sum +++ b/go.sum @@ -122,6 +122,8 @@ github.com/decred/dcrd/crypto/blake256 v1.1.0 h1:zPMNGQCm0g4QTY27fOCorQW7EryeQ/U github.com/decred/dcrd/crypto/blake256 v1.1.0/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= +github.com/dhowden/tag v0.0.0-20240417053706-3d75831295e8 h1:OtSeLS5y0Uy01jaKK4mA/WVIYtpzVm63vLVAPzJXigg= +github.com/dhowden/tag v0.0.0-20240417053706-3d75831295e8/go.mod h1:apkPC/CR3s48O2D7Y++n1XWEpgPNNCjXYga3PPbJe2E= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= diff --git a/pkg/audio/identify.go b/pkg/audio/identify.go new file mode 100644 index 000000000..78a5ed4a5 --- /dev/null +++ b/pkg/audio/identify.go @@ -0,0 +1,130 @@ +package audio + +import ( + "io" + "os" + "path/filepath" + "strings" + + "github.com/dhowden/tag" + "github.com/mudler/xlog" +) + +// extensionFromFileType returns the file extension for tag.FileType. +func extensionFromFileType(ft tag.FileType) string { + switch ft { + case tag.FLAC: + return "flac" + case tag.MP3: + return "mp3" + case tag.OGG: + return "ogg" + case tag.M4A: + return "m4a" + case tag.M4B: + return "m4b" + case tag.M4P: + return "m4p" + case tag.ALAC: + return "m4a" + case tag.DSF: + return "dsf" + default: + return "" + } +} + +// contentTypeFromFileType returns the MIME type for tag.FileType. +func contentTypeFromFileType(ft tag.FileType) string { + switch ft { + case tag.FLAC: + return "audio/flac" + case tag.MP3: + return "audio/mpeg" + case tag.OGG: + return "audio/ogg" + case tag.M4A, tag.M4B, tag.M4P, tag.ALAC: + return "audio/mp4" + case tag.DSF: + return "audio/dsd" + default: + return "" + } +} + +// Identify reads from r and returns the detected audio extension and Content-Type. +// It uses github.com/dhowden/tag to identify the format from the stream. +// Returns ("", "", err) if the format could not be identified. +func Identify(r io.ReadSeeker) (ext string, contentType string, err error) { + _, fileType, err := tag.Identify(r) + if err != nil || fileType == tag.UnknownFileType { + return "", "", err + } + ext = extensionFromFileType(fileType) + contentType = contentTypeFromFileType(fileType) + if ext == "" || contentType == "" { + return "", "", nil + } + return ext, contentType, nil +} + +// ContentTypeFromExtension returns the MIME type for common audio file extensions. +// Use as a fallback when Identify fails or when the file is not openable. +func ContentTypeFromExtension(path string) string { + ext := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), ".")) + switch ext { + case "flac": + return "audio/flac" + case "mp3": + return "audio/mpeg" + case "wav": + return "audio/wav" + case "ogg": + return "audio/ogg" + case "m4a", "m4b", "m4p": + return "audio/mp4" + case "webm": + return "audio/webm" + default: + return "" + } +} + +// NormalizeAudioFile opens the file at path, identifies its format with tag.Identify, +// and renames the file to have the correct extension if the current one does not match. +// It returns the path to use (possibly the renamed file) and the Content-Type to set. +// If identification fails, returns (path, ContentTypeFromExtension(path)). +func NormalizeAudioFile(path string) (finalPath string, contentType string) { + finalPath = path + f, err := os.Open(path) + if err != nil { + contentType = ContentTypeFromExtension(path) + return finalPath, contentType + } + defer f.Close() + + ext, ct, identifyErr := Identify(f) + if identifyErr != nil || ext == "" || ct == "" { + contentType = ContentTypeFromExtension(path) + return finalPath, contentType + } + contentType = ct + + currentExt := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), ".")) + if currentExt == ext { + return finalPath, contentType + } + + dir := filepath.Dir(path) + base := filepath.Base(path) + baseNoExt := strings.TrimSuffix(base, filepath.Ext(base)) + if baseNoExt == "" { + baseNoExt = base + } + newPath := filepath.Join(dir, baseNoExt+"."+ext) + if renameErr := os.Rename(path, newPath); renameErr != nil { + xlog.Debug("Could not rename audio file to match type", "from", path, "to", newPath, "error", renameErr) + return finalPath, contentType + } + return newPath, contentType +} diff --git a/tests/e2e/mock-backend/main.go b/tests/e2e/mock-backend/main.go index 79a71fd5b..e94a7bf42 100644 --- a/tests/e2e/mock-backend/main.go +++ b/tests/e2e/mock-backend/main.go @@ -120,7 +120,15 @@ func (m *MockBackend) GenerateVideo(ctx context.Context, in *pb.GenerateVideoReq func (m *MockBackend) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) { xlog.Debug("TTS called", "text", in.Text) - // Return success - actual audio would be in the Result message for real backends + dst := in.GetDst() + if dst != "" { + if err := os.MkdirAll(filepath.Dir(dst), 0750); err != nil { + return &pb.Result{Message: err.Error(), Success: false}, nil + } + if err := writeMinimalWAV(dst); err != nil { + return &pb.Result{Message: err.Error(), Success: false}, nil + } + } return &pb.Result{ Message: "TTS audio generated successfully (mocked)", Success: true, diff --git a/tests/e2e/mock_backend_test.go b/tests/e2e/mock_backend_test.go index 241fbae2b..6ba121792 100644 --- a/tests/e2e/mock_backend_test.go +++ b/tests/e2e/mock_backend_test.go @@ -75,23 +75,20 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() { Describe("TTS APIs", func() { Context("TTS", func() { It("should generate mocked audio", func() { - req, err := http.NewRequest("POST", apiURL+"/audio/speech", nil) + body := `{"model":"mock-model","input":"Hello world","voice":"default"}` + req, err := http.NewRequest("POST", apiURL+"/audio/speech", io.NopCloser(strings.NewReader(body))) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") - body := `{"model":"mock-model","input":"Hello world","voice":"default"}` - req.Body = http.NoBody - req.GetBody = func() (io.ReadCloser, error) { - return io.NopCloser(strings.NewReader(body)), nil - } - - // Use direct HTTP client for TTS endpoint httpClient := &http.Client{Timeout: 30 * time.Second} resp, err := httpClient.Do(req) - if err == nil { - defer resp.Body.Close() - Expect(resp.StatusCode).To(BeNumerically("<", 500)) - } + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.Header.Get("Content-Type")).To(HavePrefix("audio/"), "TTS response should set an audio Content-Type") + data, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(len(data)).To(BeNumerically(">", 0), "TTS response body should be non-empty") }) }) }) @@ -107,7 +104,11 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() { resp, err := httpClient.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() - Expect(resp.StatusCode).To(BeNumerically("<", 500)) + Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.Header.Get("Content-Type")).To(HavePrefix("audio/"), "sound-generation response should set an audio Content-Type (pkg/audio normalization)") + data, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(len(data)).To(BeNumerically(">", 0), "sound-generation response body should be non-empty") }) It("should generate mocked sound (advanced mode)", func() { @@ -120,7 +121,11 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() { resp, err := httpClient.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() - Expect(resp.StatusCode).To(BeNumerically("<", 500)) + Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.Header.Get("Content-Type")).To(HavePrefix("audio/"), "sound-generation response should set an audio Content-Type (pkg/audio normalization)") + data, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(len(data)).To(BeNumerically(">", 0), "sound-generation response body should be non-empty") }) })