mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-15 04:08:55 -04:00
fix(crispasr): write piper TTS WAV at the model's native sample rate (#10277)
CrispASR's piper backend returns PCM at the voice's native rate (from the GGUF piper.sample_rate key: 16 kHz for x_low/low, 22.05 kHz for medium/high) and does not resample, but the Go WAV encoder hardcoded 24000 Hz. Every piper voice was therefore written with a wrong header and played back at the wrong pitch/speed. Read piper.sample_rate from the model's GGUF metadata at Load via the vendored gguf-parser-go and use it for the WAV header, falling back to the 24 kHz default for the other CrispASR TTS engines (vibevoice/orpheus/chatterbox/qwen3-tts) that emit 24 kHz and carry no such key. Adds unit specs (minimal crafted GGUFs + WAV-header decode) and an env-gated end-to-end spec (CRISPASR_PIPER_MODEL_PATH). Verified e2e: en_GB-cori-medium synthesizes a 22050 Hz WAV through backend:piper. Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
@@ -37,6 +38,39 @@ var (
|
||||
|
||||
type CrispASR struct {
|
||||
base.SingleThread
|
||||
// sampleRate is the output rate (Hz) of the loaded TTS engine's PCM, used to
|
||||
// write a correct WAV header. Most CrispASR TTS backends emit 24 kHz, but
|
||||
// piper returns its model's native rate (16 kHz for x_low/low voices,
|
||||
// 22.05 kHz for medium/high), so it is read from the GGUF metadata at Load.
|
||||
sampleRate int
|
||||
}
|
||||
|
||||
// defaultTTSSampleRate is the output rate assumed for CrispASR TTS engines that
|
||||
// don't advertise one in GGUF metadata (vibevoice/orpheus/chatterbox/qwen3-tts
|
||||
// all emit 24 kHz). piper is the exception and carries piper.sample_rate.
|
||||
const defaultTTSSampleRate = 24000
|
||||
|
||||
// piperSampleRate reads the piper.sample_rate metadata key from a GGUF model.
|
||||
// CrispASR's piper backend returns PCM at the model's native rate without
|
||||
// resampling, so the WAV header must match it. Returns ok=false for non-piper
|
||||
// models (key absent) or an unreadable file, letting the caller fall back to
|
||||
// defaultTTSSampleRate.
|
||||
func piperSampleRate(modelPath string) (int, bool) {
|
||||
// Only scalar architecture keys are read, so skip the large array metadata
|
||||
// (phoneme map) and mmap the header - same rationale as pkg/vram's reader.
|
||||
f, err := gguf.ParseGGUFFile(modelPath, gguf.UseMMap(), gguf.SkipLargeMetadata())
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
kv, ok := f.Header.MetadataKV.Get("piper.sample_rate")
|
||||
if !ok || kv.ValueType != gguf.GGUFMetadataValueTypeUint32 {
|
||||
return 0, false
|
||||
}
|
||||
rate := int(kv.ValueUint32())
|
||||
if rate <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return rate, true
|
||||
}
|
||||
|
||||
// splitOption splits a "prefix:value" model option into its key and value,
|
||||
@@ -103,6 +137,14 @@ func (w *CrispASR) Load(opts *pb.ModelOptions) error {
|
||||
return fmt.Errorf("Failed to load CrispASR transcription model")
|
||||
}
|
||||
|
||||
// Determine the TTS output sample rate for the WAV header. piper voices
|
||||
// carry their native rate in GGUF metadata and CrispASR does not resample;
|
||||
// every other engine emits the 24 kHz default.
|
||||
w.sampleRate = defaultTTSSampleRate
|
||||
if rate, ok := piperSampleRate(opts.ModelFile); ok {
|
||||
w.sampleRate = rate
|
||||
}
|
||||
|
||||
// Load the companion file (codec/tokenizer/s3gen) after the session is open.
|
||||
// rc==0 means success or "not applicable" for the active backend; only a
|
||||
// negative code is fatal.
|
||||
@@ -390,7 +432,7 @@ func (w *CrispASR) synthesize(text string) ([]float32, error) {
|
||||
}
|
||||
defer CppTTSFree(ptr)
|
||||
src := unsafe.Slice((*float32)(unsafe.Pointer(ptr)), int(n)) //nolint:govet // ptr addresses C-allocated PCM returned across the purego boundary; copied out immediately below, before tts_free.
|
||||
out := make([]float32, int(n)) // copy out of C memory before free
|
||||
out := make([]float32, int(n)) // copy out of C memory before free
|
||||
copy(out, src)
|
||||
return out, nil
|
||||
}
|
||||
@@ -417,7 +459,7 @@ func (w *CrispASR) TTS(req *pb.TTSRequest) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWAV24k(req.Dst, pcm)
|
||||
return writeWAV(req.Dst, pcm, w.sampleRate)
|
||||
}
|
||||
|
||||
// TTSStream is the streaming counterpart to TTS. CrispASR has no progressive
|
||||
@@ -447,7 +489,7 @@ func (w *CrispASR) TTSStream(req *pb.TTSRequest, results chan []byte) error {
|
||||
}
|
||||
defer func() { _ = os.Remove(dst) }()
|
||||
|
||||
if err := writeWAV24k(dst, pcm); err != nil {
|
||||
if err := writeWAV(dst, pcm, w.sampleRate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -459,14 +501,14 @@ func (w *CrispASR) TTSStream(req *pb.TTSRequest, results chan []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeWAV24k writes pcm as a 24000 Hz, mono, 16-bit PCM WAV at dst.
|
||||
func writeWAV24k(dst string, pcm []float32) error {
|
||||
// writeWAV writes pcm as a sampleRate Hz, mono, 16-bit PCM WAV at dst.
|
||||
func writeWAV(dst string, pcm []float32, sampleRate int) error {
|
||||
f, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: create %q: %w", dst, err)
|
||||
}
|
||||
|
||||
enc := wav.NewEncoder(f, 24000, 16, 1, 1)
|
||||
enc := wav.NewEncoder(f, sampleRate, 16, 1, 1)
|
||||
ints := make([]int, len(pcm))
|
||||
for i, s := range pcm {
|
||||
if s > 1 {
|
||||
@@ -477,7 +519,7 @@ func writeWAV24k(dst string, pcm []float32) error {
|
||||
ints[i] = int(s * 32767)
|
||||
}
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: 24000},
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: sampleRate},
|
||||
Data: ints,
|
||||
SourceBitDepth: 16,
|
||||
}
|
||||
|
||||
164
backend/go/crispasr/gocrispasr_samplerate_test.go
Normal file
164
backend/go/crispasr/gocrispasr_samplerate_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// GGUF metadata value type tags (subset) from the GGUF spec.
|
||||
const (
|
||||
ggufTypeUint32 uint32 = 4
|
||||
ggufTypeString uint32 = 8
|
||||
)
|
||||
|
||||
type ggufKV struct {
|
||||
key string
|
||||
vtype uint32
|
||||
val any
|
||||
}
|
||||
|
||||
// writeMinimalGGUF emits a valid, tensor-less GGUF file carrying only the given
|
||||
// metadata key-values. Enough for the header-only parse path piperSampleRate
|
||||
// uses; avoids pulling a real multi-MB voice into the test.
|
||||
func writeMinimalGGUF(path string, kvs []ggufKV) error {
|
||||
var b bytes.Buffer
|
||||
b.WriteString("GGUF") // magic
|
||||
_ = binary.Write(&b, binary.LittleEndian, uint32(3)) // version
|
||||
_ = binary.Write(&b, binary.LittleEndian, uint64(0)) // tensor count
|
||||
_ = binary.Write(&b, binary.LittleEndian, uint64(len(kvs)))
|
||||
for _, kv := range kvs {
|
||||
_ = binary.Write(&b, binary.LittleEndian, uint64(len(kv.key)))
|
||||
b.WriteString(kv.key)
|
||||
_ = binary.Write(&b, binary.LittleEndian, kv.vtype)
|
||||
switch v := kv.val.(type) {
|
||||
case uint32:
|
||||
_ = binary.Write(&b, binary.LittleEndian, v)
|
||||
case string:
|
||||
_ = binary.Write(&b, binary.LittleEndian, uint64(len(v)))
|
||||
b.WriteString(v)
|
||||
}
|
||||
}
|
||||
return os.WriteFile(path, b.Bytes(), 0o644)
|
||||
}
|
||||
|
||||
// wavSampleRate decodes the WAV header at path and returns its sample rate.
|
||||
func wavSampleRate(path string) (int, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
dec := wav.NewDecoder(f)
|
||||
dec.ReadInfo()
|
||||
return int(dec.SampleRate), nil
|
||||
}
|
||||
|
||||
var _ = Describe("piper sample rate", func() {
|
||||
Context("piperSampleRate", func() {
|
||||
It("reads piper.sample_rate from a piper GGUF (medium = 22050)", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "voice.gguf")
|
||||
Expect(writeMinimalGGUF(p, []ggufKV{
|
||||
{key: "general.architecture", vtype: ggufTypeString, val: "piper"},
|
||||
{key: "piper.sample_rate", vtype: ggufTypeUint32, val: uint32(22050)},
|
||||
})).To(Succeed())
|
||||
|
||||
rate, ok := piperSampleRate(p)
|
||||
Expect(ok).To(BeTrue(), "piper.sample_rate should be found")
|
||||
Expect(rate).To(Equal(22050))
|
||||
})
|
||||
|
||||
It("reads the low-quality rate (16000)", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "voice.gguf")
|
||||
Expect(writeMinimalGGUF(p, []ggufKV{
|
||||
{key: "piper.sample_rate", vtype: ggufTypeUint32, val: uint32(16000)},
|
||||
})).To(Succeed())
|
||||
|
||||
rate, ok := piperSampleRate(p)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(rate).To(Equal(16000))
|
||||
})
|
||||
|
||||
It("returns ok=false for a non-piper GGUF (no piper.sample_rate key)", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "other.gguf")
|
||||
Expect(writeMinimalGGUF(p, []ggufKV{
|
||||
{key: "general.architecture", vtype: ggufTypeString, val: "vibevoice"},
|
||||
})).To(Succeed())
|
||||
|
||||
_, ok := piperSampleRate(p)
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
|
||||
It("returns ok=false for an unreadable/non-GGUF file", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "garbage.gguf")
|
||||
Expect(os.WriteFile(p, []byte("not a gguf"), 0o644)).To(Succeed())
|
||||
|
||||
_, ok := piperSampleRate(p)
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
// End-to-end through the built .so. Gated on CRISPASR_PIPER_MODEL_PATH (a
|
||||
// real piper voice GGUF) like the other model-backed specs; never runs in
|
||||
// default CI. Proves CrispASR's piper backend output rate flows into the
|
||||
// WAV header instead of the hardcoded 24 kHz default.
|
||||
Context("piper TTS end-to-end", func() {
|
||||
It("writes the WAV at the model's native piper.sample_rate", func() {
|
||||
model := os.Getenv("CRISPASR_PIPER_MODEL_PATH")
|
||||
if model == "" {
|
||||
Skip("set CRISPASR_PIPER_MODEL_PATH to run the piper e2e spec")
|
||||
}
|
||||
ensureLibLoaded()
|
||||
|
||||
expected, ok := piperSampleRate(model)
|
||||
Expect(ok).To(BeTrue(), "model should carry piper.sample_rate metadata")
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{
|
||||
ModelFile: model,
|
||||
Options: []string{"backend:piper"},
|
||||
Threads: 4,
|
||||
})).To(Succeed())
|
||||
|
||||
dst := filepath.Join(GinkgoT().TempDir(), "piper.wav")
|
||||
Expect(w.TTS(&pb.TTSRequest{Text: "Hello from CrispASR piper.", Dst: dst})).To(Succeed())
|
||||
|
||||
info, err := os.Stat(dst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(info.Size()).To(BeNumerically(">", 1024), "expected a non-trivial WAV")
|
||||
|
||||
rate, err := wavSampleRate(dst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rate).To(Equal(expected),
|
||||
"WAV header rate must equal the model's native piper.sample_rate, not the 24k default")
|
||||
})
|
||||
})
|
||||
|
||||
Context("writeWAV", func() {
|
||||
It("writes the WAV header at the given sample rate (22050 for piper, not the 24k default)", func() {
|
||||
dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
|
||||
pcm := make([]float32, 220) // 10 ms of silence is enough for a header
|
||||
Expect(writeWAV(dst, pcm, 22050)).To(Succeed())
|
||||
|
||||
rate, err := wavSampleRate(dst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rate).To(Equal(22050))
|
||||
})
|
||||
|
||||
It("writes a 16000 Hz header for low-quality piper voices", func() {
|
||||
dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
|
||||
pcm := make([]float32, 160)
|
||||
Expect(writeWAV(dst, pcm, 16000)).To(Succeed())
|
||||
|
||||
rate, err := wavSampleRate(dst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rate).To(Equal(16000))
|
||||
})
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user