mirror of
https://github.com/mudler/LocalAI.git
synced 2026-02-05 20:23:20 -05:00
feat(audio): set audio content type (#8416)
* feat(audio): set audio content type Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore: add tests Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
218d0526cb
commit
697f6aa71c
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/audio"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
@@ -51,7 +52,11 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Attachment(filePath, filepath.Base(filePath))
|
||||
|
||||
filePath, contentType := audio.NormalizeAudioFile(filePath)
|
||||
if contentType != "" {
|
||||
c.Response().Header().Set("Content-Type", contentType)
|
||||
}
|
||||
return c.Attachment(filePath, filepath.Base(filePath))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/audio"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
@@ -39,6 +40,10 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filePath, contentType := audio.NormalizeAudioFile(filePath)
|
||||
if contentType != "" {
|
||||
c.Response().Header().Set("Content-Type", contentType)
|
||||
}
|
||||
return c.Attachment(filePath, filepath.Base(filePath))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,12 +7,11 @@ import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/audio"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech
|
||||
@@ -86,6 +85,10 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
return err
|
||||
}
|
||||
|
||||
filePath, contentType := audio.NormalizeAudioFile(filePath)
|
||||
if contentType != "" {
|
||||
c.Response().Header().Set("Content-Type", contentType)
|
||||
}
|
||||
return c.Attachment(filePath, filepath.Base(filePath))
|
||||
}
|
||||
}
|
||||
|
||||
1
go.mod
1
go.mod
@@ -66,6 +66,7 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/dhowden/tag v0.0.0-20240417053706-3d75831295e8 // indirect
|
||||
github.com/ghodss/yaml v1.0.0 // indirect
|
||||
github.com/labstack/gommon v0.4.2 // indirect
|
||||
github.com/swaggo/files/v2 v2.0.2 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -122,6 +122,8 @@ github.com/decred/dcrd/crypto/blake256 v1.1.0 h1:zPMNGQCm0g4QTY27fOCorQW7EryeQ/U
|
||||
github.com/decred/dcrd/crypto/blake256 v1.1.0/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo=
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc=
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40=
|
||||
github.com/dhowden/tag v0.0.0-20240417053706-3d75831295e8 h1:OtSeLS5y0Uy01jaKK4mA/WVIYtpzVm63vLVAPzJXigg=
|
||||
github.com/dhowden/tag v0.0.0-20240417053706-3d75831295e8/go.mod h1:apkPC/CR3s48O2D7Y++n1XWEpgPNNCjXYga3PPbJe2E=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
|
||||
|
||||
130
pkg/audio/identify.go
Normal file
130
pkg/audio/identify.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package audio
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/dhowden/tag"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// extensionFromFileType returns the file extension for tag.FileType.
|
||||
func extensionFromFileType(ft tag.FileType) string {
|
||||
switch ft {
|
||||
case tag.FLAC:
|
||||
return "flac"
|
||||
case tag.MP3:
|
||||
return "mp3"
|
||||
case tag.OGG:
|
||||
return "ogg"
|
||||
case tag.M4A:
|
||||
return "m4a"
|
||||
case tag.M4B:
|
||||
return "m4b"
|
||||
case tag.M4P:
|
||||
return "m4p"
|
||||
case tag.ALAC:
|
||||
return "m4a"
|
||||
case tag.DSF:
|
||||
return "dsf"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// contentTypeFromFileType returns the MIME type for tag.FileType.
|
||||
func contentTypeFromFileType(ft tag.FileType) string {
|
||||
switch ft {
|
||||
case tag.FLAC:
|
||||
return "audio/flac"
|
||||
case tag.MP3:
|
||||
return "audio/mpeg"
|
||||
case tag.OGG:
|
||||
return "audio/ogg"
|
||||
case tag.M4A, tag.M4B, tag.M4P, tag.ALAC:
|
||||
return "audio/mp4"
|
||||
case tag.DSF:
|
||||
return "audio/dsd"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Identify reads from r and returns the detected audio extension and Content-Type.
|
||||
// It uses github.com/dhowden/tag to identify the format from the stream.
|
||||
// Returns ("", "", err) if the format could not be identified.
|
||||
func Identify(r io.ReadSeeker) (ext string, contentType string, err error) {
|
||||
_, fileType, err := tag.Identify(r)
|
||||
if err != nil || fileType == tag.UnknownFileType {
|
||||
return "", "", err
|
||||
}
|
||||
ext = extensionFromFileType(fileType)
|
||||
contentType = contentTypeFromFileType(fileType)
|
||||
if ext == "" || contentType == "" {
|
||||
return "", "", nil
|
||||
}
|
||||
return ext, contentType, nil
|
||||
}
|
||||
|
||||
// ContentTypeFromExtension returns the MIME type for common audio file extensions.
|
||||
// Use as a fallback when Identify fails or when the file is not openable.
|
||||
func ContentTypeFromExtension(path string) string {
|
||||
ext := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), "."))
|
||||
switch ext {
|
||||
case "flac":
|
||||
return "audio/flac"
|
||||
case "mp3":
|
||||
return "audio/mpeg"
|
||||
case "wav":
|
||||
return "audio/wav"
|
||||
case "ogg":
|
||||
return "audio/ogg"
|
||||
case "m4a", "m4b", "m4p":
|
||||
return "audio/mp4"
|
||||
case "webm":
|
||||
return "audio/webm"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// NormalizeAudioFile opens the file at path, identifies its format with tag.Identify,
|
||||
// and renames the file to have the correct extension if the current one does not match.
|
||||
// It returns the path to use (possibly the renamed file) and the Content-Type to set.
|
||||
// If identification fails, returns (path, ContentTypeFromExtension(path)).
|
||||
func NormalizeAudioFile(path string) (finalPath string, contentType string) {
|
||||
finalPath = path
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
contentType = ContentTypeFromExtension(path)
|
||||
return finalPath, contentType
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
ext, ct, identifyErr := Identify(f)
|
||||
if identifyErr != nil || ext == "" || ct == "" {
|
||||
contentType = ContentTypeFromExtension(path)
|
||||
return finalPath, contentType
|
||||
}
|
||||
contentType = ct
|
||||
|
||||
currentExt := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), "."))
|
||||
if currentExt == ext {
|
||||
return finalPath, contentType
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
base := filepath.Base(path)
|
||||
baseNoExt := strings.TrimSuffix(base, filepath.Ext(base))
|
||||
if baseNoExt == "" {
|
||||
baseNoExt = base
|
||||
}
|
||||
newPath := filepath.Join(dir, baseNoExt+"."+ext)
|
||||
if renameErr := os.Rename(path, newPath); renameErr != nil {
|
||||
xlog.Debug("Could not rename audio file to match type", "from", path, "to", newPath, "error", renameErr)
|
||||
return finalPath, contentType
|
||||
}
|
||||
return newPath, contentType
|
||||
}
|
||||
@@ -120,7 +120,15 @@ func (m *MockBackend) GenerateVideo(ctx context.Context, in *pb.GenerateVideoReq
|
||||
|
||||
func (m *MockBackend) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) {
|
||||
xlog.Debug("TTS called", "text", in.Text)
|
||||
// Return success - actual audio would be in the Result message for real backends
|
||||
dst := in.GetDst()
|
||||
if dst != "" {
|
||||
if err := os.MkdirAll(filepath.Dir(dst), 0750); err != nil {
|
||||
return &pb.Result{Message: err.Error(), Success: false}, nil
|
||||
}
|
||||
if err := writeMinimalWAV(dst); err != nil {
|
||||
return &pb.Result{Message: err.Error(), Success: false}, nil
|
||||
}
|
||||
}
|
||||
return &pb.Result{
|
||||
Message: "TTS audio generated successfully (mocked)",
|
||||
Success: true,
|
||||
|
||||
@@ -75,23 +75,20 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
|
||||
Describe("TTS APIs", func() {
|
||||
Context("TTS", func() {
|
||||
It("should generate mocked audio", func() {
|
||||
req, err := http.NewRequest("POST", apiURL+"/audio/speech", nil)
|
||||
body := `{"model":"mock-model","input":"Hello world","voice":"default"}`
|
||||
req, err := http.NewRequest("POST", apiURL+"/audio/speech", io.NopCloser(strings.NewReader(body)))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
body := `{"model":"mock-model","input":"Hello world","voice":"default"}`
|
||||
req.Body = http.NoBody
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(strings.NewReader(body)), nil
|
||||
}
|
||||
|
||||
// Use direct HTTP client for TTS endpoint
|
||||
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
Expect(resp.StatusCode).To(BeNumerically("<", 500))
|
||||
}
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer resp.Body.Close()
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
Expect(resp.Header.Get("Content-Type")).To(HavePrefix("audio/"), "TTS response should set an audio Content-Type")
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(data)).To(BeNumerically(">", 0), "TTS response body should be non-empty")
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -107,7 +104,11 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
|
||||
resp, err := httpClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer resp.Body.Close()
|
||||
Expect(resp.StatusCode).To(BeNumerically("<", 500))
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
Expect(resp.Header.Get("Content-Type")).To(HavePrefix("audio/"), "sound-generation response should set an audio Content-Type (pkg/audio normalization)")
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(data)).To(BeNumerically(">", 0), "sound-generation response body should be non-empty")
|
||||
})
|
||||
|
||||
It("should generate mocked sound (advanced mode)", func() {
|
||||
@@ -120,7 +121,11 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
|
||||
resp, err := httpClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer resp.Body.Close()
|
||||
Expect(resp.StatusCode).To(BeNumerically("<", 500))
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
Expect(resp.Header.Get("Content-Type")).To(HavePrefix("audio/"), "sound-generation response should set an audio Content-Type (pkg/audio normalization)")
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(data)).To(BeNumerically(">", 0), "sound-generation response body should be non-empty")
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user