diff --git a/backend/go/crispasr/gocrispasr.go b/backend/go/crispasr/gocrispasr.go
index dc21a28fd..5c3528d38 100644
--- a/backend/go/crispasr/gocrispasr.go
+++ b/backend/go/crispasr/gocrispasr.go
@@ -11,6 +11,7 @@ import (
 
 	"github.com/go-audio/audio"
 	"github.com/go-audio/wav"
+	gguf "github.com/gpustack/gguf-parser-go"
 	"github.com/mudler/LocalAI/pkg/grpc/base"
 	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
 	"github.com/mudler/LocalAI/pkg/utils"
@@ -37,6 +38,39 @@ var (
 
 type CrispASR struct {
 	base.SingleThread
+	// sampleRate is the output rate (Hz) of the loaded TTS engine's PCM, used to
+	// write a correct WAV header. Most CrispASR TTS backends emit 24 kHz, but
+	// piper returns its model's native rate (16 kHz for x_low/low voices,
+	// 22.05 kHz for medium/high), so it is read from the GGUF metadata at Load.
+	sampleRate int
+}
+
+// defaultTTSSampleRate is the output rate assumed for CrispASR TTS engines that
+// don't advertise one in GGUF metadata (vibevoice/orpheus/chatterbox/qwen3-tts
+// all emit 24 kHz). piper is the exception and carries piper.sample_rate.
+const defaultTTSSampleRate = 24000
+
+// piperSampleRate reads the piper.sample_rate metadata key from a GGUF model.
+// CrispASR's piper backend returns PCM at the model's native rate without
+// resampling, so the WAV header must match it. Returns ok=false for non-piper
+// models (key absent) or an unreadable file, letting the caller fall back to
+// defaultTTSSampleRate.
+func piperSampleRate(modelPath string) (int, bool) {
+	// Only scalar architecture keys are read, so skip the large array metadata
+	// (phoneme map) and mmap the header - same rationale as pkg/vram's reader.
+	f, err := gguf.ParseGGUFFile(modelPath, gguf.UseMMap(), gguf.SkipLargeMetadata())
+	if err != nil {
+		return 0, false
+	}
+	kv, ok := f.Header.MetadataKV.Get("piper.sample_rate")
+	if !ok || kv.ValueType != gguf.GGUFMetadataValueTypeUint32 {
+		return 0, false
+	}
+	rate := int(kv.ValueUint32())
+	if rate <= 0 {
+		return 0, false
+	}
+	return rate, true
 }
 
 // splitOption splits a "prefix:value" model option into its key and value,
@@ -103,6 +137,14 @@ func (w *CrispASR) Load(opts *pb.ModelOptions) error {
 		return fmt.Errorf("Failed to load CrispASR transcription model")
 	}
 
+	// Determine the TTS output sample rate for the WAV header. piper voices
+	// carry their native rate in GGUF metadata and CrispASR does not resample;
+	// every other engine emits the 24 kHz default.
+	w.sampleRate = defaultTTSSampleRate
+	if rate, ok := piperSampleRate(opts.ModelFile); ok {
+		w.sampleRate = rate
+	}
+
 	// Load the companion file (codec/tokenizer/s3gen) after the session is open.
 	// rc==0 means success or "not applicable" for the active backend; only a
 	// negative code is fatal.
@@ -390,7 +432,7 @@ func (w *CrispASR) synthesize(text string) ([]float32, error) {
 	}
 	defer CppTTSFree(ptr)
 	src := unsafe.Slice((*float32)(unsafe.Pointer(ptr)), int(n)) //nolint:govet // ptr addresses C-allocated PCM returned across the purego boundary; copied out immediately below, before tts_free.
-	out := make([]float32, int(n)) // copy out of C memory before free
+	out := make([]float32, int(n))                               // copy out of C memory before free
 	copy(out, src)
 	return out, nil
 }
@@ -417,7 +459,7 @@ func (w *CrispASR) TTS(req *pb.TTSRequest) error {
 	if err != nil {
 		return err
 	}
-	return writeWAV24k(req.Dst, pcm)
+	return writeWAV(req.Dst, pcm, w.sampleRate)
 }
 
 // TTSStream is the streaming counterpart to TTS. CrispASR has no progressive
@@ -447,7 +489,7 @@ func (w *CrispASR) TTSStream(req *pb.TTSRequest, results chan []byte) error {
 	}
 	defer func() { _ = os.Remove(dst) }()
 
-	if err := writeWAV24k(dst, pcm); err != nil {
+	if err := writeWAV(dst, pcm, w.sampleRate); err != nil {
 		return err
 	}
 
@@ -459,14 +501,14 @@ func (w *CrispASR) TTSStream(req *pb.TTSRequest, results chan []byte) error {
 	return nil
 }
 
-// writeWAV24k writes pcm as a 24000 Hz, mono, 16-bit PCM WAV at dst.
-func writeWAV24k(dst string, pcm []float32) error {
+// writeWAV writes pcm as a sampleRate Hz, mono, 16-bit PCM WAV at dst.
+func writeWAV(dst string, pcm []float32, sampleRate int) error {
 	f, err := os.Create(dst)
 	if err != nil {
 		return fmt.Errorf("crispasr: create %q: %w", dst, err)
 	}
 
-	enc := wav.NewEncoder(f, 24000, 16, 1, 1)
+	enc := wav.NewEncoder(f, sampleRate, 16, 1, 1)
 	ints := make([]int, len(pcm))
 	for i, s := range pcm {
 		if s > 1 {
@@ -477,7 +519,7 @@ func writeWAV24k(dst string, pcm []float32) error {
 		ints[i] = int(s * 32767)
 	}
 	buf := &audio.IntBuffer{
-		Format:         &audio.Format{NumChannels: 1, SampleRate: 24000},
+		Format:         &audio.Format{NumChannels: 1, SampleRate: sampleRate},
 		Data:           ints,
 		SourceBitDepth: 16,
 	}
diff --git a/backend/go/crispasr/gocrispasr_samplerate_test.go b/backend/go/crispasr/gocrispasr_samplerate_test.go
new file mode 100644
index 000000000..6b0cf726b
--- /dev/null
+++ b/backend/go/crispasr/gocrispasr_samplerate_test.go
@@ -0,0 +1,168 @@
+package main
+
+import (
+	"bytes"
+	"encoding/binary"
+	"os"
+	"path/filepath"
+
+	"github.com/go-audio/wav"
+	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+// GGUF metadata value type tags (subset) from the GGUF spec.
+const (
+	ggufTypeUint32 uint32 = 4
+	ggufTypeString uint32 = 8
+)
+
+type ggufKV struct {
+	key   string
+	vtype uint32
+	val   any
+}
+
+// writeMinimalGGUF emits a valid, tensor-less GGUF file carrying only the given
+// metadata key-values. Enough for the header-only parse path piperSampleRate
+// uses; avoids pulling a real multi-MB voice into the test.
+func writeMinimalGGUF(path string, kvs []ggufKV) error {
+	var b bytes.Buffer
+	b.WriteString("GGUF")                                // magic
+	_ = binary.Write(&b, binary.LittleEndian, uint32(3)) // version
+	_ = binary.Write(&b, binary.LittleEndian, uint64(0)) // tensor count
+	_ = binary.Write(&b, binary.LittleEndian, uint64(len(kvs)))
+	for _, kv := range kvs {
+		_ = binary.Write(&b, binary.LittleEndian, uint64(len(kv.key)))
+		b.WriteString(kv.key)
+		_ = binary.Write(&b, binary.LittleEndian, kv.vtype)
+		switch v := kv.val.(type) {
+		case uint32:
+			_ = binary.Write(&b, binary.LittleEndian, v)
+		case string:
+			_ = binary.Write(&b, binary.LittleEndian, uint64(len(v)))
+			b.WriteString(v)
+		}
+	}
+	return os.WriteFile(path, b.Bytes(), 0o644)
+}
+
+// wavSampleRate decodes the WAV header at path and returns its sample rate,
+// or the decoder's error when the header cannot be parsed.
+func wavSampleRate(path string) (int, error) {
+	f, err := os.Open(path)
+	if err != nil {
+		return 0, err
+	}
+	defer func() { _ = f.Close() }()
+	dec := wav.NewDecoder(f)
+	dec.ReadInfo()
+	if err := dec.Err(); err != nil {
+		return 0, err
+	}
+	return int(dec.SampleRate), nil
+}
+
+var _ = Describe("piper sample rate", func() {
+	Context("piperSampleRate", func() {
+		It("reads piper.sample_rate from a piper GGUF (medium = 22050)", func() {
+			p := filepath.Join(GinkgoT().TempDir(), "voice.gguf")
+			Expect(writeMinimalGGUF(p, []ggufKV{
+				{key: "general.architecture", vtype: ggufTypeString, val: "piper"},
+				{key: "piper.sample_rate", vtype: ggufTypeUint32, val: uint32(22050)},
+			})).To(Succeed())
+
+			rate, ok := piperSampleRate(p)
+			Expect(ok).To(BeTrue(), "piper.sample_rate should be found")
+			Expect(rate).To(Equal(22050))
+		})
+
+		It("reads the low-quality rate (16000)", func() {
+			p := filepath.Join(GinkgoT().TempDir(), "voice.gguf")
+			Expect(writeMinimalGGUF(p, []ggufKV{
+				{key: "piper.sample_rate", vtype: ggufTypeUint32, val: uint32(16000)},
+			})).To(Succeed())
+
+			rate, ok := piperSampleRate(p)
+			Expect(ok).To(BeTrue())
+			Expect(rate).To(Equal(16000))
+		})
+
+		It("returns ok=false for a non-piper GGUF (no piper.sample_rate key)", func() {
+			p := filepath.Join(GinkgoT().TempDir(), "other.gguf")
+			Expect(writeMinimalGGUF(p, []ggufKV{
+				{key: "general.architecture", vtype: ggufTypeString, val: "vibevoice"},
+			})).To(Succeed())
+
+			_, ok := piperSampleRate(p)
+			Expect(ok).To(BeFalse())
+		})
+
+		It("returns ok=false for an unreadable/non-GGUF file", func() {
+			p := filepath.Join(GinkgoT().TempDir(), "garbage.gguf")
+			Expect(os.WriteFile(p, []byte("not a gguf"), 0o644)).To(Succeed())
+
+			_, ok := piperSampleRate(p)
+			Expect(ok).To(BeFalse())
+		})
+	})
+
+	// End-to-end through the built .so. Gated on CRISPASR_PIPER_MODEL_PATH (a
+	// real piper voice GGUF) like the other model-backed specs; never runs in
+	// default CI. Proves CrispASR's piper backend output rate flows into the
+	// WAV header instead of the hardcoded 24 kHz default.
+	Context("piper TTS end-to-end", func() {
+		It("writes the WAV at the model's native piper.sample_rate", func() {
+			model := os.Getenv("CRISPASR_PIPER_MODEL_PATH")
+			if model == "" {
+				Skip("set CRISPASR_PIPER_MODEL_PATH to run the piper e2e spec")
+			}
+			ensureLibLoaded()
+
+			expected, ok := piperSampleRate(model)
+			Expect(ok).To(BeTrue(), "model should carry piper.sample_rate metadata")
+
+			w := &CrispASR{}
+			Expect(w.Load(&pb.ModelOptions{
+				ModelFile: model,
+				Options:   []string{"backend:piper"},
+				Threads:   4,
+			})).To(Succeed())
+
+			dst := filepath.Join(GinkgoT().TempDir(), "piper.wav")
+			Expect(w.TTS(&pb.TTSRequest{Text: "Hello from CrispASR piper.", Dst: dst})).To(Succeed())
+
+			info, err := os.Stat(dst)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(info.Size()).To(BeNumerically(">", 1024), "expected a non-trivial WAV")
+
+			rate, err := wavSampleRate(dst)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(rate).To(Equal(expected),
+				"WAV header rate must equal the model's native piper.sample_rate, not the 24k default")
+		})
+	})
+
+	Context("writeWAV", func() {
+		It("writes the WAV header at the given sample rate (22050 for piper, not the 24k default)", func() {
+			dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
+			pcm := make([]float32, 220) // 10 ms of silence is enough for a header
+			Expect(writeWAV(dst, pcm, 22050)).To(Succeed())
+
+			rate, err := wavSampleRate(dst)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(rate).To(Equal(22050))
+		})
+
+		It("writes a 16000 Hz header for low-quality piper voices", func() {
+			dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
+			pcm := make([]float32, 160)
+			Expect(writeWAV(dst, pcm, 16000)).To(Succeed())
+
+			rate, err := wavSampleRate(dst)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(rate).To(Equal(16000))
+		})
+	})
+})