mirror of https://github.com/mudler/LocalAI.git (synced 2026-05-16 12:38:01 -04:00)
feat(api): add /v1/audio/diarization endpoint with sherpa-onnx + vibevoice.cpp (#9654)
* feat(api): add /v1/audio/diarization endpoint with sherpa-onnx + vibevoice.cpp
Closes #1648.
OpenAI-style multipart endpoint that returns "who spoke when". Single
endpoint instead of the issue's three-endpoint sketch (refactor /vad,
/vad/embedding, /diarization) — the typical client wants one call, and
embeddings can land later as a sibling without breaking this surface.
Response shape borrows from Pyannote/Deepgram: segments carry a
normalised SPEAKER_NN id (zero-padded, stable across the response) plus
the raw backend label, optional per-segment text when the backend bundles
ASR, and a speakers summary in verbose_json. response_format also accepts
rttm so consumers can pipe straight into pyannote.metrics / dscore.
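The first-seen SPEAKER_NN remapping described above can be sketched roughly like this (a minimal illustration; the function name and shape are ours, not LocalAI's actual core/backend helper):

```go
package main

import "fmt"

// normalizeSpeakers maps raw backend labels (e.g. "5", "2", "5") to stable,
// zero-padded SPEAKER_NN ids in first-seen order, so the same raw label
// always yields the same normalised id within one response.
func normalizeSpeakers(rawLabels []string) []string {
	seen := map[string]string{}
	out := make([]string, 0, len(rawLabels))
	for _, raw := range rawLabels {
		id, ok := seen[raw]
		if !ok {
			id = fmt.Sprintf("SPEAKER_%02d", len(seen))
			seen[raw] = id
		}
		out = append(out, id)
	}
	return out
}

func main() {
	fmt.Println(normalizeSpeakers([]string{"5", "2", "5"}))
	// prints: [SPEAKER_00 SPEAKER_01 SPEAKER_00]
}
```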
Backends:
* vibevoice-cpp — Diarize() reuses the existing vv_capi_asr pass.
vibevoice's ASR prompt asks the model to emit
[{Start,End,Speaker,Content}] natively, so diarization is a by-product
of the same pass; include_text=true preserves the transcript per
segment, otherwise we drop it.
* sherpa-onnx — wraps the upstream SherpaOnnxOfflineSpeakerDiarization
C API (pyannote segmentation + speaker-embedding extractor + fast
clustering). libsherpa-shim grew config builders, a SetClustering
wrapper for per-call num_clusters/threshold overrides, and a
segment_at accessor (purego can't read field arrays out of
SherpaOnnxOfflineSpeakerDiarizationSegment[] directly).
Plumbing: new Diarize gRPC RPC + DiarizeRequest / DiarizeSegment /
DiarizeResponse messages, threaded through interface.go, base, server,
client, embed. Default Base impl returns unimplemented.
Capability surfaces all updated: FLAG_DIARIZATION usecase,
FeatureAudioDiarization permission (default-on), RouteFeatureRegistry
entries for /v1/audio/diarization and /audio/diarization, audio
instruction-def description widened, CAP_DIARIZATION JS symbol,
swagger regenerated, /api/instructions discovery map updated.
Tests:
* core/backend: speaker-label normalisation (first-seen → SPEAKER_NN,
per-speaker totals, nil-safety, fallback to backend NumSpeakers when
no segments).
* core/http/endpoints/openai: RTTM rendering (file-id basename, negative
duration clamping, fallback id).
* tests/e2e: mock-backend grew a deterministic Diarize that emits
raw labels "5","2","5" so the e2e suite verifies SPEAKER_NN
remapping, verbose_json speakers summary + transcript pass-through
(gated by include_text), RTTM bytes content-type, and rejection of
unknown response_format. mock-diarize model config registered with
known_usecases=[FLAG_DIARIZATION] to bypass the backend-name guard.
Docs: new features/audio-diarization.md (request/response, RTTM example,
sherpa-onnx + vibevoice setup), cross-link from audio-to-text.md, entry
in whats-new.md.
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]
* fix(diarization): correct sherpa-onnx symbol name + lint cleanup
CI failures on #9654:
* sherpa-onnx-grpc-{tts,transcription} and sherpa-onnx-realtime panicked
at backend startup with `undefined symbol: SherpaOnnxDestroyOfflineSpeakerDiarizationResult`.
Upstream's actual symbol is SherpaOnnxOfflineSpeakerDiarizationDestroyResult
(Destroy in the middle, not the prefix); the rest of the diarization
surface follows the same naming pattern. The mismatched name made
purego.RegisterLibFunc fail at dlopen time and crashed the gRPC server
before the BeforeAll could probe Health, taking down every sherpa-onnx
test job — not just the diarization-related ones.
* golangci-lint flagged 5 errcheck violations on new defer cleanups
(os.RemoveAll / Close / conn.Close); wrap each in a `defer func() { _ = X() }()`
closure (matches the pattern other LocalAI files use for new code, since
pre-existing bare defers are grandfathered in via new-from-merge-base).
* golangci-lint also flagged forbidigo violations: the new
diarization_test.go files used testing.T-style `t.Errorf` / `t.Fatalf`,
which are forbidden by the project's coding-style policy
(.agents/coding-style.md). Convert both files to Ginkgo/Gomega
Describe/It with Expect(...) — they get picked up by the existing
TestBackend / TestOpenAI suites, no new suite plumbing needed.
* modernize linter: tightened the diarization segment loop to
`for i := range int(numSegments)` (Go 1.22+ idiom).
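The errcheck-safe cleanup shape from the second bullet looks like this in isolation (a minimal sketch, not a quote of the touched files):

```go
package main

import (
	"fmt"
	"os"
)

// withTempDir shows the pattern: wrap the error-returning cleanup in a
// closure and discard the error explicitly, instead of a bare
// `defer os.RemoveAll(dir)` that errcheck would flag.
func withTempDir() error {
	dir, err := os.MkdirTemp("", "example")
	if err != nil {
		return err
	}
	defer func() { _ = os.RemoveAll(dir) }() // cleanup error deliberately ignored

	fmt.Println("working in", dir)
	return nil
}

func main() {
	if err := withTempDir(); err != nil {
		panic(err)
	}
}
```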
Verified locally: golangci-lint with new-from-merge-base=origin/master
reports 0 issues across all touched packages, and the four mocked
diarization e2e specs in tests/e2e/mock_backend_test.go still pass.
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]
* fix(vibevoice-cpp): convert non-WAV input via ffmpeg + raise ASR token budget
Confirmed end-to-end against a real LocalAI instance with vibevoice-asr-q4_k
loaded and the multi-speaker MP3 sample at vibevoice.cpp/samples/2p_argument.mp3:
both /v1/audio/transcriptions and /v1/audio/diarization now succeed and
return correctly attributed speaker turns for the full clip.
Two latent issues (plus one defensive hardening) surfaced once the diarization
endpoint actually exercised the backend with a non-trivial input:
1. vv_capi_asr only accepts WAV via load_wav_24k_mono. The previous code
passed the uploaded path straight through, so anything that wasn't
already a 24 kHz mono s16le WAV failed at the C side with rc=-8 and
the very unhelpful "vv_capi_asr failed". prepareWavInput shells out
to ffmpeg ("-ar 24000 -ac 1 -acodec pcm_s16le") in a per-call temp
dir, matching the rate the model was trained on; both AudioTranscription
and Diarize now route through it. This is the same shape sherpa-onnx
uses (utils.AudioToWav), but vibevoice needs 24 kHz rather than 16 kHz
so we don't reuse that helper.
2. The C ABI's max_new_tokens defaults to 256 when 0 is passed. That's
fine for a five-second clip but not for anything past ~10 s — vibevoice
stops mid-JSON, the parse fails, and the caller sees a hard error.
Pass a much larger budget (16 384 tokens, roughly nine minutes of speech
at the model's ~30 tok/s rate); generation stops at EOS so this is a cap
rather than a target.
3. As a defensive belt-and-braces measure, mirror AudioTranscription's existing
"fall back to a single segment if the model emits non-JSON text"
pattern in Diarize, so partial or unusual model output never produces
a 500. This kept the endpoint usable while diagnosing (1) and (2),
and is the right behaviour to keep.
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]
* fix(vibevoice-cpp): pass valid WAVs through directly so ffmpeg is not required at runtime
Spotted by tests-e2e-backend (1.25.x): the previous fix forced every
incoming audio file through `ffmpeg -ar 24000 ...`, which meant the
backend container — which does not ship ffmpeg — failed even for the
existing happy path where the caller already uploads a WAV. The
container-side error was:
rpc error: code = Unknown desc = vibevoice-cpp: ffmpeg convert to
24k mono wav: exec: "ffmpeg": executable file not found in $PATH
Reading vibevoice.cpp's audio_io.cpp, `load_wav_24k_mono` uses drwav and
already accepts any PCM/IEEE-float WAV at any sample rate, downmixes
multi-channel input to mono, and resamples to 24 kHz internally. So the
only inputs that genuinely need an external converter are non-WAV
formats (MP3, OGG, FLAC, ...).
Detect WAVs by RIFF/WAVE magic at bytes 0..3 / 8..11 and pass them
straight through with a no-op cleanup; everything else still goes
through ffmpeg with the same 24 kHz mono s16le target. The result:
* Container builds without ffmpeg keep working for WAV uploads
(the e2e-backends fixture is jfk.wav at 16 kHz mono s16le).
* MP3 and other non-WAV inputs still get the new ffmpeg conversion
path so the diarization endpoint stays useful.
* If the caller uploads a non-WAV but ffmpeg isn't on PATH, the
surfaced error is still descriptive enough to act on.
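The magic-byte check described above is small enough to show whole — a sketch under the stated layout (RIFF container: "RIFF" at bytes 0..3, little-endian chunk size at 4..7, "WAVE" at 8..11); the function name is illustrative:

```go
package main

import (
	"bytes"
	"fmt"
)

// isWav reports whether a file header carries the RIFF/WAVE magic:
// "RIFF" at bytes 0..3 and "WAVE" at bytes 8..11 (bytes 4..7 hold the
// chunk size and are not checked).
func isWav(header []byte) bool {
	return len(header) >= 12 &&
		bytes.Equal(header[0:4], []byte("RIFF")) &&
		bytes.Equal(header[8:12], []byte("WAVE"))
}

func main() {
	hdr := append([]byte("RIFF"), 0x24, 0x08, 0x00, 0x00) // size bytes are arbitrary
	hdr = append(hdr, []byte("WAVE")...)
	fmt.Println(isWav(hdr), isWav([]byte("ID3\x03")))
	// prints: true false
}
```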
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]
* fix(ci): make gcc-14 install in Dockerfile.golang best-effort for jammy bases
The LocalVQE PR (bb033b16) made `gcc-14 g++-14` an unconditional apt
install in backend/Dockerfile.golang and pointed update-alternatives at
them. That works on the default `BASE_IMAGE=ubuntu:24.04` (noble has
gcc-14 in main), but every Go backend that builds on
`nvcr.io/nvidia/l4t-jetpack:r36.4.0` — jammy under the hood — now fails
at the apt step:
E: Unable to locate package gcc-14
This blocked unrelated jobs:
backend-jobs(*-nvidia-l4t-arm64-{stablediffusion-ggml, sam3-cpp, whisper,
acestep-cpp, qwen3-tts-cpp, vibevoice-cpp}). LocalVQE itself is only
matrix-built on ubuntu:24.04 (CPU + Vulkan), so it doesn't actually
need gcc-14 anywhere else.
Make the gcc-14 install conditional on the package being available in
the configured apt repos. On noble: identical behaviour to today (gcc-14
installed, update-alternatives points at it). On jammy: skip the
gcc-14 stanza entirely and let build-essential's default gcc take over,
which is what the other Go backends compile with anyway.
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]
---------
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
committed by GitHub
parent 1634eece6b
commit e86ade54a6
@@ -21,20 +21,28 @@ ENV AMDGPU_TARGETS=${AMDGPU_TARGETS}
 ARG APT_MIRROR
 ARG APT_PORTS_MIRROR
 
+# gcc-14 is the default on noble (ubuntu:24.04) but absent from jammy
+# (the L4T jetpack r36.4.0 base). LocalVQE specifically needs it; the
+# other Go backends compile fine with the default gcc shipped via
+# build-essential. So: try gcc-14 from the configured repos, fall back
+# gracefully when it's not available so jammy-based builds don't fail
+# at the apt step.
 RUN --mount=type=bind,source=.docker/apt-mirror.sh,target=/usr/local/sbin/apt-mirror \
     APT_MIRROR="${APT_MIRROR}" APT_PORTS_MIRROR="${APT_PORTS_MIRROR}" sh /usr/local/sbin/apt-mirror && \
     apt-get update && \
     apt-get install -y --no-install-recommends \
         build-essential \
-        gcc-14 g++-14 \
         git ccache \
         ca-certificates \
         make cmake wget libopenblas-dev \
         curl unzip \
         libssl-dev && \
-    update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 \
-        --slave /usr/bin/g++ g++ /usr/bin/g++-14 \
-        --slave /usr/bin/gcov gcov /usr/bin/gcov-14 && \
+    if apt-cache show gcc-14 >/dev/null 2>&1 && apt-cache show g++-14 >/dev/null 2>&1; then \
+        apt-get install -y --no-install-recommends gcc-14 g++-14 && \
+        update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 \
+            --slave /usr/bin/g++ g++ /usr/bin/g++-14 \
+            --slave /usr/bin/gcov gcov /usr/bin/gcov-14; \
+    fi && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*
@@ -41,6 +41,8 @@ service Backend {
 
   rpc VAD(VADRequest) returns (VADResponse) {}
 
+  rpc Diarize(DiarizeRequest) returns (DiarizeResponse) {}
+
   rpc AudioEncode(AudioEncodeRequest) returns (AudioEncodeResult) {}
   rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {}
 
@@ -416,6 +418,43 @@ message VADResponse {
   repeated VADSegment segments = 1;
 }
 
+// --- Speaker diarization messages ---
+//
+// Pure speaker diarization: "who spoke when". Returns time-stamped segments
+// labelled with cluster IDs (the same string for the same speaker across
+// segments). Some backends (e.g. vibevoice.cpp) produce diarization as a
+// by-product of ASR and may also fill in `text` per segment; backends with a
+// dedicated diarization pipeline (e.g. sherpa-onnx pyannote) leave `text`
+// empty and emit only the segmentation.
+
+message DiarizeRequest {
+  string dst = 1; // path to audio file (HTTP layer materialises uploads to a temp file)
+  uint32 threads = 2;
+  string language = 3; // optional; only meaningful for transcription-bundling backends
+  int32 num_speakers = 4; // exact speaker count if known (>0 forces); 0 = auto
+  int32 min_speakers = 5; // hint when auto-detecting; 0 = unset
+  int32 max_speakers = 6; // hint when auto-detecting; 0 = unset
+  float clustering_threshold = 7; // distance threshold when num_speakers unknown; 0 = backend default
+  float min_duration_on = 8; // discard segments shorter than this (seconds); 0 = backend default
+  float min_duration_off = 9; // merge gaps shorter than this (seconds); 0 = backend default
+  bool include_text = 10; // when the backend can emit per-segment transcript for free, ask it to populate `text`
+}
+
+message DiarizeSegment {
+  int32 id = 1;
+  float start = 2; // seconds
+  float end = 3; // seconds
+  string speaker = 4; // backend-emitted speaker label (e.g. "0", "SPEAKER_00")
+  string text = 5; // optional per-segment transcript (empty unless include_text and supported)
+}
+
+message DiarizeResponse {
+  repeated DiarizeSegment segments = 1;
+  int32 num_speakers = 2; // count of distinct speaker labels in `segments`
+  float duration = 3; // total audio duration in seconds (0 if unknown)
+  string language = 4; // optional, when the backend bundles transcription
+}
+
 message SoundGenerationRequest {
   string text = 1;
   string model = 2;
@@ -29,6 +29,12 @@ type SherpaBackend struct {
 	vadWindowSize      int
 	ttsSpeed           float32
 	onlineChunkSamples int
+
+	// Speaker diarization (offline pyannote + embedding extractor + clustering).
+	// diarSampleRate is reported by sherpa at create time; we cache it so
+	// runDiarization can resample only when the input doesn't already match.
+	diarizer       uintptr
+	diarSampleRate int
 }
 
 var onnxProvider = "cpu"
@@ -128,6 +134,25 @@ var (
 
 	// TTS streaming callback trampoline
 	shimTtsGenerateWithCallback func(tts uintptr, text string, sid int32, speed float32, cb uintptr, ud uintptr) uintptr
+
+	// Diarization config + result accessors (see csrc/shim.h).
+	shimDiarizeConfigNew                       func() uintptr
+	shimDiarizeConfigFree                      func(uintptr)
+	shimDiarizeConfigSetSegmentationModel      func(uintptr, string)
+	shimDiarizeConfigSetSegmentationNumThreads func(uintptr, int32)
+	shimDiarizeConfigSetSegmentationProvider   func(uintptr, string)
+	shimDiarizeConfigSetSegmentationDebug      func(uintptr, int32)
+	shimDiarizeConfigSetEmbeddingModel         func(uintptr, string)
+	shimDiarizeConfigSetEmbeddingNumThreads    func(uintptr, int32)
+	shimDiarizeConfigSetEmbeddingProvider      func(uintptr, string)
+	shimDiarizeConfigSetEmbeddingDebug         func(uintptr, int32)
+	shimDiarizeConfigSetClusteringNumClusters  func(uintptr, int32)
+	shimDiarizeConfigSetClusteringThreshold    func(uintptr, float32)
+	shimDiarizeConfigSetMinDurationOn          func(uintptr, float32)
+	shimDiarizeConfigSetMinDurationOff         func(uintptr, float32)
+	shimCreateOfflineSpeakerDiarization        func(uintptr) uintptr
+	shimDiarizeSetClustering                   func(uintptr, int32, float32)
+	shimDiarizeSegmentAt                       func(segs uintptr, i int32, outStart unsafe.Pointer, outEnd unsafe.Pointer, outSpeaker unsafe.Pointer)
 )
 
 // libsherpa-onnx-c-api pass-throughs — called directly from Go via purego.
@@ -172,6 +197,18 @@ var (
 	sherpaOfflineTtsGenerate              func(tts uintptr, text string, sid int32, speed float32) uintptr
 	sherpaDestroyOfflineTtsGeneratedAudio func(audio uintptr)
 	sherpaOfflineTtsSampleRate            func(tts uintptr) int32
+
+	// Offline speaker diarization. Result handle owns the segment-array
+	// pointer returned by ResultSortByStartTime; destroy the segment
+	// array first, then the result, then (at backend Free()) the diarizer.
+	sherpaDestroyOfflineSpeakerDiarization               func(sd uintptr)
+	sherpaOfflineSpeakerDiarizationGetSampleRate         func(sd uintptr) int32
+	sherpaOfflineSpeakerDiarizationProcess               func(sd uintptr, samples unsafe.Pointer, n int32) uintptr
+	sherpaOfflineSpeakerDiarizationResultGetNumSegments  func(result uintptr) int32
+	sherpaOfflineSpeakerDiarizationResultGetNumSpeakers  func(result uintptr) int32
+	sherpaOfflineSpeakerDiarizationResultSortByStartTime func(result uintptr) uintptr
+	sherpaOfflineSpeakerDiarizationDestroySegment        func(segs uintptr)
+	sherpaDestroyOfflineSpeakerDiarizationResult         func(result uintptr)
 )
 
 var (
@@ -292,6 +329,24 @@ func loadSherpaLibsOnce() error {
 		{&shimSpeechSegmentStart, "sherpa_shim_speech_segment_start"},
 		{&shimSpeechSegmentN, "sherpa_shim_speech_segment_n"},
 		{&shimTtsGenerateWithCallback, "sherpa_shim_tts_generate_with_callback"},
+
+		{&shimDiarizeConfigNew, "sherpa_shim_diarize_config_new"},
+		{&shimDiarizeConfigFree, "sherpa_shim_diarize_config_free"},
+		{&shimDiarizeConfigSetSegmentationModel, "sherpa_shim_diarize_config_set_segmentation_model"},
+		{&shimDiarizeConfigSetSegmentationNumThreads, "sherpa_shim_diarize_config_set_segmentation_num_threads"},
+		{&shimDiarizeConfigSetSegmentationProvider, "sherpa_shim_diarize_config_set_segmentation_provider"},
+		{&shimDiarizeConfigSetSegmentationDebug, "sherpa_shim_diarize_config_set_segmentation_debug"},
+		{&shimDiarizeConfigSetEmbeddingModel, "sherpa_shim_diarize_config_set_embedding_model"},
+		{&shimDiarizeConfigSetEmbeddingNumThreads, "sherpa_shim_diarize_config_set_embedding_num_threads"},
+		{&shimDiarizeConfigSetEmbeddingProvider, "sherpa_shim_diarize_config_set_embedding_provider"},
+		{&shimDiarizeConfigSetEmbeddingDebug, "sherpa_shim_diarize_config_set_embedding_debug"},
+		{&shimDiarizeConfigSetClusteringNumClusters, "sherpa_shim_diarize_config_set_clustering_num_clusters"},
+		{&shimDiarizeConfigSetClusteringThreshold, "sherpa_shim_diarize_config_set_clustering_threshold"},
+		{&shimDiarizeConfigSetMinDurationOn, "sherpa_shim_diarize_config_set_min_duration_on"},
+		{&shimDiarizeConfigSetMinDurationOff, "sherpa_shim_diarize_config_set_min_duration_off"},
+		{&shimCreateOfflineSpeakerDiarization, "sherpa_shim_create_offline_speaker_diarization"},
+		{&shimDiarizeSetClustering, "sherpa_shim_diarize_set_clustering"},
+		{&shimDiarizeSegmentAt, "sherpa_shim_diarize_segment_at"},
 	} {
 		purego.RegisterLibFunc(r.ptr, shim, r.name)
 	}
@@ -334,6 +389,15 @@ func loadSherpaLibsOnce() error {
 		{&sherpaOfflineTtsGenerate, "SherpaOnnxOfflineTtsGenerate"},
 		{&sherpaDestroyOfflineTtsGeneratedAudio, "SherpaOnnxDestroyOfflineTtsGeneratedAudio"},
 		{&sherpaOfflineTtsSampleRate, "SherpaOnnxOfflineTtsSampleRate"},
+
+		{&sherpaDestroyOfflineSpeakerDiarization, "SherpaOnnxDestroyOfflineSpeakerDiarization"},
+		{&sherpaOfflineSpeakerDiarizationGetSampleRate, "SherpaOnnxOfflineSpeakerDiarizationGetSampleRate"},
+		{&sherpaOfflineSpeakerDiarizationProcess, "SherpaOnnxOfflineSpeakerDiarizationProcess"},
+		{&sherpaOfflineSpeakerDiarizationResultGetNumSegments, "SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments"},
+		{&sherpaOfflineSpeakerDiarizationResultGetNumSpeakers, "SherpaOnnxOfflineSpeakerDiarizationResultGetNumSpeakers"},
+		{&sherpaOfflineSpeakerDiarizationResultSortByStartTime, "SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime"},
+		{&sherpaOfflineSpeakerDiarizationDestroySegment, "SherpaOnnxOfflineSpeakerDiarizationDestroySegment"},
+		{&sherpaDestroyOfflineSpeakerDiarizationResult, "SherpaOnnxOfflineSpeakerDiarizationDestroyResult"},
 	} {
 		purego.RegisterLibFunc(r.ptr, capi, r.name)
 	}
@@ -383,6 +447,11 @@ func isVADType(t string) bool {
 	return t == "vad"
 }
 
+func isDiarizationType(t string) bool {
+	t = strings.ToLower(t)
+	return t == "diarization" || t == "diarize" || t == "speaker-diarization"
+}
+
 // Model-options prefixes recognised by this backend. Kept as typed
 // constants so the asrFamily / loadWhisperASR / loadGenericASR paths
 // can all speak the same vocabulary.
@@ -423,6 +492,19 @@ const (
 	optionOnlineRule2        = "online.rule2_min_trailing_silence="
 	optionOnlineRule3        = "online.rule3_min_utterance_length="
 	optionOnlineChunkSamples = "online.chunk_samples="
+
+	// Speaker diarization (offline pyannote + speaker-embedding extractor).
+	// `diarize.segmentation_model` overrides the auto-detected pyannote
+	// segmentation .onnx in modelDir; `diarize.embedding_model` does the
+	// same for the speaker-embedding extractor. `diarize.num_clusters`
+	// pins a known speaker count at load time; per-call DiarizeRequest
+	// fields take precedence at process time.
+	optionDiarizeSegmentationModel = "diarize.segmentation_model="
+	optionDiarizeEmbeddingModel    = "diarize.embedding_model="
+	optionDiarizeNumClusters       = "diarize.num_clusters="
+	optionDiarizeThreshold         = "diarize.threshold="
+	optionDiarizeMinDurationOn     = "diarize.min_duration_on="
+	optionDiarizeMinDurationOff    = "diarize.min_duration_off="
 )
 
 func hasOption(opts *pb.ModelOptions, prefix string) bool {
@@ -493,6 +575,9 @@ func (s *SherpaBackend) Load(opts *pb.ModelOptions) error {
 	if isVADType(opts.Type) {
 		return s.loadVAD(opts)
 	}
+	if isDiarizationType(opts.Type) {
+		return s.loadDiarization(opts)
+	}
 	// An explicit `subtype=...` option routes to ASR even when Type is
 	// unset — handy for the e2e-backends harness, which doesn't know
 	// about ModelOptions.Type.
@@ -1247,3 +1332,176 @@ func (s *SherpaBackend) TTSStream(req *pb.TTSRequest, results chan []byte) error
 	}
 	return nil
 }
+
+// =============================================================
+// Speaker diarization (offline)
+// =============================================================
+//
+// Conventions:
+// - opts.ModelFile is the pyannote segmentation .onnx (e.g. model.onnx
+//   under sherpa-onnx-pyannote-segmentation-3-0/). Override with
+//   `diarize.segmentation_model=` if the gallery layout differs.
+// - The speaker-embedding extractor must be provided via
+//   `diarize.embedding_model=`. There's no reliable filename heuristic
+//   we can rely on (3dspeaker, NeMo, WeSpeaker all ship with
+//   model-specific names), so we require it to be explicit.
+// - Both paths are resolved relative to opts.ModelPath if not absolute.
+
+func (s *SherpaBackend) loadDiarization(opts *pb.ModelOptions) error {
+	if s.diarizer != 0 {
+		return nil
+	}
+
+	modelDir := filepath.Dir(opts.ModelFile)
+	segModel := findOptionValue(opts, optionDiarizeSegmentationModel, opts.ModelFile)
+	if segModel != "" && !filepath.IsAbs(segModel) && opts.ModelPath != "" {
+		segModel = filepath.Join(opts.ModelPath, segModel)
+	}
+	if !fileExists(segModel) {
+		return fmt.Errorf("sherpa-onnx diarization: pyannote segmentation model not found at %q (set diarize.segmentation_model=...)", segModel)
+	}
+
+	embModel := findOptionValue(opts, optionDiarizeEmbeddingModel, "")
+	if embModel == "" {
+		return fmt.Errorf("sherpa-onnx diarization: speaker-embedding model is required — pass options: [diarize.embedding_model=<path>] (e.g. 3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx)")
+	}
+	if !filepath.IsAbs(embModel) {
+		base := opts.ModelPath
+		if base == "" {
+			base = modelDir
+		}
+		embModel = filepath.Join(base, embModel)
+	}
+	if !fileExists(embModel) {
+		return fmt.Errorf("sherpa-onnx diarization: speaker-embedding model not found at %q", embModel)
+	}
+
+	threads := int32(1)
+	if opts.Threads != 0 {
+		threads = opts.Threads
+	}
+
+	cfg := shimDiarizeConfigNew()
+	defer shimDiarizeConfigFree(cfg)
+
+	shimDiarizeConfigSetSegmentationModel(cfg, segModel)
+	shimDiarizeConfigSetSegmentationNumThreads(cfg, threads)
+	shimDiarizeConfigSetSegmentationProvider(cfg, onnxProvider)
+	shimDiarizeConfigSetSegmentationDebug(cfg, 0)
+
+	shimDiarizeConfigSetEmbeddingModel(cfg, embModel)
+	shimDiarizeConfigSetEmbeddingNumThreads(cfg, threads)
+	shimDiarizeConfigSetEmbeddingProvider(cfg, onnxProvider)
+	shimDiarizeConfigSetEmbeddingDebug(cfg, 0)
+
+	shimDiarizeConfigSetClusteringNumClusters(cfg, findOptionInt(opts, optionDiarizeNumClusters, -1))
+	shimDiarizeConfigSetClusteringThreshold(cfg, findOptionFloat(opts, optionDiarizeThreshold, 0.5))
+	shimDiarizeConfigSetMinDurationOn(cfg, findOptionFloat(opts, optionDiarizeMinDurationOn, 0.3))
+	shimDiarizeConfigSetMinDurationOff(cfg, findOptionFloat(opts, optionDiarizeMinDurationOff, 0.5))
+
+	sd := shimCreateOfflineSpeakerDiarization(cfg)
+	if sd == 0 {
+		return fmt.Errorf("sherpa-onnx diarization: failed to create diarizer (segmentation=%s embedding=%s)", segModel, embModel)
+	}
+	s.diarizer = sd
+	s.diarSampleRate = int(sherpaOfflineSpeakerDiarizationGetSampleRate(sd))
+	return nil
+}
+
+// applyDiarizeOverrides re-applies clustering knobs onto an existing
+// diarizer when per-call DiarizeRequest fields are set. Both -1/0 sentinels
+// follow sherpa's convention: num_clusters<=0 → use threshold-based
+// clustering, threshold<=0 → keep load-time default.
+func (s *SherpaBackend) applyDiarizeOverrides(req *pb.DiarizeRequest) {
+	num := int32(-1)
+	if req.NumSpeakers > 0 {
+		num = req.NumSpeakers
+	}
+	threshold := float32(0)
+	if req.ClusteringThreshold > 0 {
+		threshold = req.ClusteringThreshold
+	}
+	if num > 0 || threshold > 0 {
+		shimDiarizeSetClustering(s.diarizer, num, threshold)
+	}
+}
+
+func (s *SherpaBackend) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, error) {
+	if s.diarizer == 0 {
+		return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization not loaded (model must be loaded with type=diarization)")
+	}
+	if req.Dst == "" {
+		return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization: DiarizeRequest.dst (audio path) is required")
+	}
+
+	dir, err := os.MkdirTemp("", "sherpa-diarize")
+	if err != nil {
+		return pb.DiarizeResponse{}, fmt.Errorf("failed to create temp dir: %w", err)
+	}
+	defer func() { _ = os.RemoveAll(dir) }()
+
+	wavPath := filepath.Join(dir, "input.wav")
+	if err := utils.AudioToWav(req.Dst, wavPath); err != nil {
+		return pb.DiarizeResponse{}, fmt.Errorf("failed to convert audio to wav: %w", err)
+	}
+
+	wave := sherpaReadWave(wavPath)
+	if wave == 0 {
+		return pb.DiarizeResponse{}, fmt.Errorf("failed to read wav %s", wavPath)
+	}
+	defer sherpaFreeWave(wave)
+
+	sr := int(shimWaveSampleRate(wave))
+	nSamples := shimWaveNumSamples(wave)
+	samples := shimWaveSamples(wave)
+	duration := float32(nSamples) / float32(sr)
+	if sr != s.diarSampleRate {
+		// AudioToWav already targets 16 kHz; pyannote-3.0 also wants 16 kHz, so
+		// this branch should be unreachable. Fail loudly instead of silently
+		// passing mismatched audio to the model.
+		return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization: input sample rate %d Hz does not match model %d Hz", sr, s.diarSampleRate)
+	}
+
+	s.applyDiarizeOverrides(req)
+
+	result := sherpaOfflineSpeakerDiarizationProcess(s.diarizer, samples, nSamples)
+	if result == 0 {
+		return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization: process failed")
+	}
+	defer sherpaDestroyOfflineSpeakerDiarizationResult(result)
+
+	numSegments := sherpaOfflineSpeakerDiarizationResultGetNumSegments(result)
+	numSpeakers := sherpaOfflineSpeakerDiarizationResultGetNumSpeakers(result)
+	if numSegments <= 0 {
+		return pb.DiarizeResponse{
+			Segments:    []*pb.DiarizeSegment{},
+			NumSpeakers: numSpeakers,
+			Duration:    duration,
+		}, nil
+	}
+
+	segs := sherpaOfflineSpeakerDiarizationResultSortByStartTime(result)
+	if segs == 0 {
+		return pb.DiarizeResponse{}, fmt.Errorf("sherpa-onnx diarization: failed to retrieve segments")
+	}
+	defer sherpaOfflineSpeakerDiarizationDestroySegment(segs)
+
+	out := make([]*pb.DiarizeSegment, 0, numSegments)
+	for i := range int(numSegments) {
+		var start, end float32
+		var spk int32
+		shimDiarizeSegmentAt(segs, int32(i),
+			unsafe.Pointer(&start), unsafe.Pointer(&end), unsafe.Pointer(&spk))
+		out = append(out, &pb.DiarizeSegment{
+			Id:      int32(i),
+			Start:   start,
+			End:     end,
+			Speaker: strconv.FormatInt(int64(spk), 10),
+		})
+	}
+	return pb.DiarizeResponse{
+		Segments:    out,
+		NumSpeakers: numSpeakers,
+		Duration:    duration,
+	}, nil
+}
@@ -310,6 +310,87 @@ int32_t sherpa_shim_speech_segment_n(const void *h) {
   return ((const SherpaOnnxSpeechSegment *)h)->n;
 }
 
+// ==================================================================
+// Offline speaker diarization config
+// ==================================================================
+
+void *sherpa_shim_diarize_config_new(void) {
+  return calloc(1, sizeof(SherpaOnnxOfflineSpeakerDiarizationConfig));
+}
+
+void sherpa_shim_diarize_config_free(void *h) {
+  if (!h) return;
+  SherpaOnnxOfflineSpeakerDiarizationConfig *c =
+      (SherpaOnnxOfflineSpeakerDiarizationConfig *)h;
+  free((char *)c->segmentation.pyannote.model);
+  free((char *)c->segmentation.provider);
+  free((char *)c->embedding.model);
+  free((char *)c->embedding.provider);
+  free(c);
+}
+
+void sherpa_shim_diarize_config_set_segmentation_model(void *h, const char *v) {
+  shim_set_str(&((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->segmentation.pyannote.model, v);
+}
+void sherpa_shim_diarize_config_set_segmentation_num_threads(void *h, int32_t v) {
+  ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->segmentation.num_threads = v;
+}
+void sherpa_shim_diarize_config_set_segmentation_provider(void *h, const char *v) {
+  shim_set_str(&((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->segmentation.provider, v);
+}
+void sherpa_shim_diarize_config_set_segmentation_debug(void *h, int32_t v) {
+  ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->segmentation.debug = v;
+}
+void sherpa_shim_diarize_config_set_embedding_model(void *h, const char *v) {
+  shim_set_str(&((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->embedding.model, v);
+}
+void sherpa_shim_diarize_config_set_embedding_num_threads(void *h, int32_t v) {
+  ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->embedding.num_threads = v;
+}
+void sherpa_shim_diarize_config_set_embedding_provider(void *h, const char *v) {
+  shim_set_str(&((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->embedding.provider, v);
+}
+void sherpa_shim_diarize_config_set_embedding_debug(void *h, int32_t v) {
+  ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->embedding.debug = v;
+}
+void sherpa_shim_diarize_config_set_clustering_num_clusters(void *h, int32_t v) {
+  ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->clustering.num_clusters = v;
+}
+void sherpa_shim_diarize_config_set_clustering_threshold(void *h, float v) {
+  ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->clustering.threshold = v;
+}
+void sherpa_shim_diarize_config_set_min_duration_on(void *h, float v) {
+  ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->min_duration_on = v;
+}
+void sherpa_shim_diarize_config_set_min_duration_off(void *h, float v) {
+  ((SherpaOnnxOfflineSpeakerDiarizationConfig *)h)->min_duration_off = v;
+}
+
+void *sherpa_shim_create_offline_speaker_diarization(void *h) {
+  return (void *)SherpaOnnxCreateOfflineSpeakerDiarization(
+      (const SherpaOnnxOfflineSpeakerDiarizationConfig *)h);
+}
+
+void sherpa_shim_diarize_set_clustering(void *sd, int32_t num_clusters, float threshold) {
+  if (!sd) return;
+  SherpaOnnxOfflineSpeakerDiarizationConfig cfg;
+  memset(&cfg, 0, sizeof(cfg));
+  cfg.clustering.num_clusters = num_clusters;
+  cfg.clustering.threshold = threshold;
+  SherpaOnnxOfflineSpeakerDiarizationSetConfig(
+      (const SherpaOnnxOfflineSpeakerDiarization *)sd, &cfg);
+}
+
+void sherpa_shim_diarize_segment_at(const void *segs, int32_t i,
+                                    float *out_start, float *out_end,
+                                    int32_t *out_speaker) {
+  const SherpaOnnxOfflineSpeakerDiarizationSegment *arr =
+      (const SherpaOnnxOfflineSpeakerDiarizationSegment *)segs;
+  if (out_start) *out_start = arr[i].start;
+  if (out_end) *out_end = arr[i].end;
+  if (out_speaker) *out_speaker = arr[i].speaker;
+}
+
 // ==================================================================
 // TTS streaming callback trampoline
 // ==================================================================
||||
|
||||
@@ -109,6 +109,41 @@ const float *sherpa_shim_generated_audio_samples(const void *audio);
int32_t sherpa_shim_speech_segment_start(const void *seg);
int32_t sherpa_shim_speech_segment_n(const void *seg);

// --- Offline speaker diarization config -----------------------------
// Pyannote segmentation + speaker-embedding extractor + fast clustering.
// The upstream config is a struct of nested structs; purego can't read or
// build those across dlopen, so we expose a calloc'd opaque holder plus
// flat setters, then hand it to sherpa via the create wrapper.
void *sherpa_shim_diarize_config_new(void);
void sherpa_shim_diarize_config_free(void *cfg);
void sherpa_shim_diarize_config_set_segmentation_model(void *cfg, const char *path);
void sherpa_shim_diarize_config_set_segmentation_num_threads(void *cfg, int32_t v);
void sherpa_shim_diarize_config_set_segmentation_provider(void *cfg, const char *v);
void sherpa_shim_diarize_config_set_segmentation_debug(void *cfg, int32_t v);
void sherpa_shim_diarize_config_set_embedding_model(void *cfg, const char *path);
void sherpa_shim_diarize_config_set_embedding_num_threads(void *cfg, int32_t v);
void sherpa_shim_diarize_config_set_embedding_provider(void *cfg, const char *v);
void sherpa_shim_diarize_config_set_embedding_debug(void *cfg, int32_t v);
void sherpa_shim_diarize_config_set_clustering_num_clusters(void *cfg, int32_t v);
void sherpa_shim_diarize_config_set_clustering_threshold(void *cfg, float v);
void sherpa_shim_diarize_config_set_min_duration_on(void *cfg, float v);
void sherpa_shim_diarize_config_set_min_duration_off(void *cfg, float v);
void *sherpa_shim_create_offline_speaker_diarization(void *cfg);

// Apply just the clustering knobs onto a loaded diarizer (sherpa
// supports re-clustering after Create), so per-call overrides like
// num_speakers don't require re-loading the heavy ONNX models.
void sherpa_shim_diarize_set_clustering(void *sd, int32_t num_clusters, float threshold);

// Sherpa's ResultSortByStartTime returns a sherpa-allocated array of
// SherpaOnnxOfflineSpeakerDiarizationSegment structs (free with
// SherpaOnnxOfflineSpeakerDiarizationDestroySegment). Purego can't read
// fields out of an array of C structs, so this getter copies one
// segment's fields into the caller-supplied float/int32 cells.
void sherpa_shim_diarize_segment_at(const void *segs, int32_t i,
                                    float *out_start, float *out_end,
                                    int32_t *out_speaker);

// --- TTS streaming callback trampoline -----------------------------
// Replaces the //export sherpaTtsGoCallback + callbacks.c bridge pattern.
// `callback_ptr` is the C-callable function pointer returned by
@@ -3,7 +3,9 @@ package main
import (
	"encoding/json"
	"fmt"
	"io"
	"os"
	"os/exec"
	"path/filepath"
	"strings"

@@ -12,6 +14,84 @@ import (
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

// vv_capi_asr loads audio with load_wav_24k_mono — a 24 kHz mono s16le
// WAV is the format the model was trained on. Inputs already in that
// format pass through; everything else is converted via ffmpeg, which
// is therefore a runtime requirement only when callers upload non-WAV
// (or non-24 kHz mono s16le WAV) audio. Skipping ffmpeg on the happy
// path matters for the e2e-backends test container, which does not
// ship ffmpeg but feeds the backend pre-cooked 24 kHz mono WAVs.
const vibevoiceASRSampleRate = 24000

// prepareWavInput resolves `src` to a 24 kHz mono s16le WAV path that
// vv_capi_asr's load_wav_24k_mono accepts. Returns the resolved path
// plus a cleanup func; both must be honoured by the caller.
//
// Pass-through happens when `src` already has the right WAV format —
// no ffmpeg required. Otherwise we shell out to ffmpeg into a temp
// dir; if ffmpeg isn't on PATH we surface a clear error mentioning the
// underlying format mismatch.
func prepareWavInput(src string) (string, func(), error) {
	if src == "" {
		return "", func() {}, fmt.Errorf("empty audio path")
	}
	if isVibevoiceCompatibleWav(src) {
		return src, func() {}, nil
	}

	dir, err := os.MkdirTemp("", "vibevoice-asr")
	if err != nil {
		return "", func() {}, fmt.Errorf("mkdtemp: %w", err)
	}
	cleanup := func() { _ = os.RemoveAll(dir) }
	wavPath := filepath.Join(dir, "input.wav")

	// -y: overwrite, -ar 24000: target sample rate, -ac 1: mono,
	// -acodec pcm_s16le: signed 16-bit little-endian PCM (load_wav_24k_mono
	// only accepts s16le).
	cmd := exec.Command("ffmpeg",
		"-y", "-i", src,
		"-ar", fmt.Sprintf("%d", vibevoiceASRSampleRate),
		"-ac", "1",
		"-acodec", "pcm_s16le",
		wavPath,
	)
	cmd.Env = []string{}
	if out, err := cmd.CombinedOutput(); err != nil {
		cleanup()
		return "", func() {}, fmt.Errorf("ffmpeg convert to 24k mono wav: %w (output: %s)", err, string(out))
	}
	return wavPath, cleanup, nil
}

// isVibevoiceCompatibleWav returns true when `src` carries the RIFF/WAVE
// magic bytes. vibevoice's load_wav_24k_mono uses drwav under the hood,
// which accepts any PCM/IEEE-float WAV at any sample rate and downmixes
// multi-channel input to mono on its own — so any valid WAV passes
// through to the C side without conversion. Anything else (MP3, OGG,
// FLAC, ...) needs ffmpeg.
func isVibevoiceCompatibleWav(src string) bool {
	f, err := os.Open(src)
	if err != nil {
		return false
	}
	defer func() { _ = f.Close() }()

	// 0..3 = "RIFF", 8..11 = "WAVE".
	var hdr [12]byte
	if _, err := io.ReadFull(f, hdr[:]); err != nil {
		return false
	}
	return string(hdr[0:4]) == "RIFF" && string(hdr[8:12]) == "WAVE"
}

// asrMaxNewTokens caps the ASR generation budget. The C ABI defaults to
// 256 when 0 is passed — far too small for anything past ~10s of speech.
// Vibevoice generates ~30 tokens per second of audio, so 16 384 covers
// roughly 9 minutes of dialogue, well past any normal /v1/audio/diarization
// upload. Going higher costs little since generation stops at EOS.
const asrMaxNewTokens = 16384

// vibevoice.cpp synthesizes 24 kHz mono 16-bit PCM. Hardcoded - the
// model itself is fixed-rate; if the upstream ever changes this we'll
// pick it up via vv_capi_version().
@@ -302,7 +382,13 @@ func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.Transcr
 		return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: TranscriptRequest.dst (audio path) is required")
 	}

-	out, err := v.callASR(req.Dst, 0)
+	wavPath, cleanup, err := prepareWavInput(req.Dst)
+	if err != nil {
+		return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: %w", err)
+	}
+	defer cleanup()
+
+	out, err := v.callASR(wavPath, asrMaxNewTokens)
 	if err != nil {
 		return pb.TranscriptResult{}, err
 	}
@@ -346,6 +432,83 @@ func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.Transcr
	}, nil
}

// Diarize runs vibevoice's ASR and projects the speaker-labelled segment
// list it returns natively. vibevoice.cpp's ASR prompt asks the model to
// emit `[{"Start":..,"End":..,"Speaker":..,"Content":..}]`, so diarization
// is a by-product of the same pass — we reuse callASR and re-shape.
//
// Speaker hints (num_speakers/min/max/threshold) and min_duration_on/off are
// not actionable here: vibevoice's model picks the speaker count itself and
// has no clustering knob. The HTTP layer documents this; we accept the
// fields for API symmetry and ignore them.
func (v *VibevoiceCpp) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, error) {
	if v.asrModel == "" {
		return pb.DiarizeResponse{}, fmt.Errorf("vibevoice-cpp: Diarize requires an ASR model (load options: type=asr)")
	}
	if req.Dst == "" {
		return pb.DiarizeResponse{}, fmt.Errorf("vibevoice-cpp: DiarizeRequest.dst (audio path) is required")
	}

	wavPath, cleanup, err := prepareWavInput(req.Dst)
	if err != nil {
		return pb.DiarizeResponse{}, fmt.Errorf("vibevoice-cpp: %w", err)
	}
	defer cleanup()

	out, err := v.callASR(wavPath, asrMaxNewTokens)
	if err != nil {
		return pb.DiarizeResponse{}, err
	}
	if out == "" {
		return pb.DiarizeResponse{}, nil
	}

	var segs []asrSegment
	if err := json.Unmarshal([]byte(out), &segs); err != nil {
		// Mirror AudioTranscription's fallback: vibevoice's ASR sometimes
		// emits free-form text instead of JSON for short or unusual audio.
		// Surface a single unknown-speaker segment carrying the full text
		// (when include_text is set) so the caller still gets coverage of
		// the whole clip rather than a hard failure.
		fmt.Fprintf(os.Stderr,
			"[vibevoice-cpp] WARNING: vv_capi_asr returned non-JSON for diarization, falling back to single segment: %v\n", err)
		text := strings.TrimSpace(out)
		seg := &pb.DiarizeSegment{Id: 0, Speaker: "0"}
		if req.IncludeText {
			seg.Text = text
		}
		return pb.DiarizeResponse{
			Segments:    []*pb.DiarizeSegment{seg},
			NumSpeakers: 1,
		}, nil
	}

	speakers := make(map[int]struct{})
	segments := make([]*pb.DiarizeSegment, 0, len(segs))
	var duration float32
	for i, s := range segs {
		ds := &pb.DiarizeSegment{
			Id:      int32(i),
			Start:   float32(s.Start),
			End:     float32(s.End),
			Speaker: fmt.Sprintf("%d", s.Speaker),
		}
		if req.IncludeText {
			ds.Text = strings.TrimSpace(s.Content)
		}
		segments = append(segments, ds)
		speakers[s.Speaker] = struct{}{}
		if float32(s.End) > duration {
			duration = float32(s.End)
		}
	}
	return pb.DiarizeResponse{
		Segments:    segments,
		NumSpeakers: int32(len(speakers)),
		Duration:    duration,
	}, nil
}

// AudioTranscriptionStream wraps AudioTranscription so the streaming
// gRPC endpoint (server.go:AudioTranscriptionStream) sees its channel
// close and the client doesn't sit waiting until deadline. vibevoice's
158 core/backend/diarization.go Normal file
@@ -0,0 +1,158 @@
package backend

import (
	"context"
	"fmt"
	"sort"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/schema"

	grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
	"github.com/mudler/LocalAI/pkg/grpc/proto"
	"github.com/mudler/LocalAI/pkg/model"
)

// DiarizationRequest carries the diarization-specific knobs the HTTP
// layer collects. Speaker hints (NumSpeakers / MinSpeakers / MaxSpeakers)
// and clustering knobs are optional — backends ignore the ones they
// don't act on. IncludeText only matters for backends that emit
// per-segment transcripts as a by-product (e.g. vibevoice.cpp).
type DiarizationRequest struct {
	Audio               string
	Language            string
	NumSpeakers         int32
	MinSpeakers         int32
	MaxSpeakers         int32
	ClusteringThreshold float32
	MinDurationOn       float32
	MinDurationOff      float32
	IncludeText         bool
}

func (r *DiarizationRequest) toProto(threads uint32) *proto.DiarizeRequest {
	return &proto.DiarizeRequest{
		Dst:                 r.Audio,
		Threads:             threads,
		Language:            r.Language,
		NumSpeakers:         r.NumSpeakers,
		MinSpeakers:         r.MinSpeakers,
		MaxSpeakers:         r.MaxSpeakers,
		ClusteringThreshold: r.ClusteringThreshold,
		MinDurationOn:       r.MinDurationOn,
		MinDurationOff:      r.MinDurationOff,
		IncludeText:         r.IncludeText,
	}
}

func loadDiarizationModel(ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (grpcPkg.Backend, error) {
	if modelConfig.Backend == "" {
		return nil, fmt.Errorf("diarization: model %q has no backend set; supported backends include vibevoice-cpp and sherpa-onnx", modelConfig.Name)
	}
	opts := ModelOptions(modelConfig, appConfig)
	m, err := ml.Load(opts...)
	if err != nil {
		recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
		return nil, err
	}
	if m == nil {
		return nil, fmt.Errorf("could not load diarization model")
	}
	return m, nil
}

// ModelDiarization runs the Diarize RPC against the configured backend
// and returns a normalized schema.DiarizationResult.
func ModelDiarization(req DiarizationRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.DiarizationResult, error) {
	m, err := loadDiarizationModel(ml, modelConfig, appConfig)
	if err != nil {
		return nil, err
	}

	threads := uint32(0)
	if modelConfig.Threads != nil {
		threads = uint32(*modelConfig.Threads)
	}

	r, err := m.Diarize(context.Background(), req.toProto(threads))
	if err != nil {
		return nil, err
	}
	return diarizationResultFromProto(r), nil
}

// diarizationResultFromProto normalizes backend speaker labels to
// "SPEAKER_NN" — the convention pyannote/RTTM tooling expects — while
// keeping the original backend label available via the Label field. Each
// distinct backend label gets its own normalized id, in first-seen order.
func diarizationResultFromProto(r *proto.DiarizeResponse) *schema.DiarizationResult {
	if r == nil {
		return &schema.DiarizationResult{Segments: []schema.DiarizationSegment{}}
	}

	out := &schema.DiarizationResult{
		Task:     "diarize",
		Duration: float64(r.Duration),
		Language: r.Language,
		Segments: make([]schema.DiarizationSegment, 0, len(r.Segments)),
	}

	type speakerStats struct {
		idx      int
		duration float64
		segments int
	}
	stats := map[string]*speakerStats{}
	order := []string{}

	for i, s := range r.Segments {
		if s == nil {
			continue
		}
		raw := s.Speaker
		if raw == "" {
			raw = "0"
		}
		st, ok := stats[raw]
		if !ok {
			st = &speakerStats{idx: len(order)}
			stats[raw] = st
			order = append(order, raw)
		}
		dur := float64(s.End) - float64(s.Start)
		if dur > 0 {
			st.duration += dur
		}
		st.segments++

		out.Segments = append(out.Segments, schema.DiarizationSegment{
			Id:      i,
			Speaker: fmt.Sprintf("SPEAKER_%02d", st.idx),
			Label:   raw,
			Start:   float64(s.Start),
			End:     float64(s.End),
			Text:    s.Text,
		})
	}

	out.NumSpeakers = len(order)
	if out.NumSpeakers == 0 && r.NumSpeakers > 0 {
		out.NumSpeakers = int(r.NumSpeakers)
	}

	out.Speakers = make([]schema.DiarizationSpeaker, 0, len(order))
	for _, raw := range order {
		st := stats[raw]
		out.Speakers = append(out.Speakers, schema.DiarizationSpeaker{
			Id:                  fmt.Sprintf("SPEAKER_%02d", st.idx),
			Label:               raw,
			TotalSpeechDuration: st.duration,
			SegmentCount:        st.segments,
		})
	}
	sort.SliceStable(out.Speakers, func(i, j int) bool {
		return out.Speakers[i].Id < out.Speakers[j].Id
	})

	return out
}
76 core/backend/diarization_test.go Normal file
@@ -0,0 +1,76 @@
package backend

import (
	"github.com/mudler/LocalAI/pkg/grpc/proto"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("diarizationResultFromProto", func() {
	It("normalises raw backend speaker labels to SPEAKER_NN in first-seen order", func() {
		in := &proto.DiarizeResponse{
			Duration: 10.5,
			Language: "en",
			Segments: []*proto.DiarizeSegment{
				{Start: 0.0, End: 1.0, Speaker: "5", Text: "hi"},
				{Start: 1.0, End: 2.0, Speaker: "2"},
				{Start: 2.0, End: 3.5, Speaker: "5"},
				{Start: 3.5, End: 4.0, Speaker: ""}, // empty → coerced to "0"
			},
		}

		got := diarizationResultFromProto(in)

		Expect(got.Task).To(Equal("diarize"))
		Expect(got.NumSpeakers).To(Equal(3), "expected 3 distinct speakers (5, 2, 0)")
		Expect(got.Duration).To(BeEquivalentTo(10.5))
		Expect(got.Language).To(Equal("en"))
		Expect(got.Segments).To(HaveLen(4))

		// First-seen-order normalisation: "5"→SPEAKER_00, "2"→SPEAKER_01, ""→SPEAKER_02
		want := []struct {
			speaker string
			label   string
		}{
			{"SPEAKER_00", "5"},
			{"SPEAKER_01", "2"},
			{"SPEAKER_00", "5"},
			{"SPEAKER_02", "0"},
		}
		for i, w := range want {
			Expect(got.Segments[i].Speaker).To(Equal(w.speaker), "seg[%d].speaker", i)
			Expect(got.Segments[i].Label).To(Equal(w.label), "seg[%d].label", i)
		}

		// Per-speaker totals reflect cumulative speech duration and segment count.
		Expect(got.Speakers).To(HaveLen(3))
		byID := map[string]float64{}
		countByID := map[string]int{}
		for _, sp := range got.Speakers {
			byID[sp.Id] = sp.TotalSpeechDuration
			countByID[sp.Id] = sp.SegmentCount
		}
		Expect(byID["SPEAKER_00"]).To(BeNumerically("~", 2.5, 0.001), "1.0 + 1.5")
		Expect(byID["SPEAKER_01"]).To(BeNumerically("~", 1.0, 0.001))
		Expect(countByID["SPEAKER_00"]).To(Equal(2))
		Expect(countByID["SPEAKER_01"]).To(Equal(1))
		Expect(countByID["SPEAKER_02"]).To(Equal(1))
	})

	It("returns a non-nil result with a non-nil segments slice for nil input", func() {
		got := diarizationResultFromProto(nil)
		Expect(got).ToNot(BeNil())
		Expect(got.Segments).ToNot(BeNil())
		Expect(got.Segments).To(BeEmpty())
	})

	It("keeps the backend speaker count when no segments are returned", func() {
		// Backend reports a non-zero NumSpeakers but no segments (early stop,
		// silence-only audio after VAD trim). Surface the backend's count.
		in := &proto.DiarizeResponse{NumSpeakers: 2, Duration: 5}
		got := diarizationResultFromProto(in)
		Expect(got.NumSpeakers).To(Equal(2))
		Expect(got.Segments).To(BeEmpty())
	})
})
@@ -634,6 +634,7 @@ const (
	FLAG_FACE_RECOGNITION    ModelConfigUsecase = 0b10000000000000
	FLAG_SPEAKER_RECOGNITION ModelConfigUsecase = 0b100000000000000
	FLAG_AUDIO_TRANSFORM     ModelConfigUsecase = 0b1000000000000000
	FLAG_DIARIZATION         ModelConfigUsecase = 0b10000000000000000

	// Common Subsets
	FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
@@ -660,6 +661,7 @@ func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
		"FLAG_FACE_RECOGNITION":    FLAG_FACE_RECOGNITION,
		"FLAG_SPEAKER_RECOGNITION": FLAG_SPEAKER_RECOGNITION,
		"FLAG_AUDIO_TRANSFORM":     FLAG_AUDIO_TRANSFORM,
		"FLAG_DIARIZATION":         FLAG_DIARIZATION,
	}
}

@@ -824,6 +826,16 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
		}
	}

	if (u & FLAG_DIARIZATION) == FLAG_DIARIZATION {
		// vibevoice-cpp emits speaker-labelled segments natively from its
		// ASR pass; sherpa-onnx pipes pyannote segmentation + speaker
		// embeddings + clustering. Both surface as a Diarize gRPC.
		diarizationBackends := []string{"vibevoice-cpp", "sherpa-onnx"}
		if !slices.Contains(diarizationBackends, c.Backend) {
			return false
		}
	}

	return true
}
@@ -44,6 +44,10 @@ var RouteFeatureRegistry = []RouteFeature{
	{"POST", "/v1/audio/transcriptions", FeatureAudioTranscription},
	{"POST", "/audio/transcriptions", FeatureAudioTranscription},

	// Audio diarization (speaker turns)
	{"POST", "/v1/audio/diarization", FeatureAudioDiarization},
	{"POST", "/audio/diarization", FeatureAudioDiarization},

	// Audio speech / TTS
	{"POST", "/v1/audio/speech", FeatureAudioSpeech},
	{"POST", "/audio/speech", FeatureAudioSpeech},
@@ -163,6 +167,7 @@ func APIFeatureMetas() []FeatureMeta {
	{FeatureImages, "Image Generation", true},
	{FeatureAudioSpeech, "Audio Speech / TTS", true},
	{FeatureAudioTranscription, "Audio Transcription", true},
	{FeatureAudioDiarization, "Audio Diarization", true},
	{FeatureVAD, "Voice Activity Detection", true},
	{FeatureDetection, "Detection", true},
	{FeatureVideo, "Video Generation", true},
@@ -42,6 +42,7 @@ const (
	FeatureImages             = "images"
	FeatureAudioSpeech        = "audio_speech"
	FeatureAudioTranscription = "audio_transcription"
	FeatureAudioDiarization   = "audio_diarization"
	FeatureVAD                = "vad"
	FeatureDetection          = "detection"
	FeatureVideo              = "video"
@@ -66,6 +67,7 @@ var GeneralFeatures = []string{FeatureFineTuning, FeatureQuantization}
// APIFeatures lists API endpoint features (default ON).
var APIFeatures = []string{
	FeatureChat, FeatureImages, FeatureAudioSpeech, FeatureAudioTranscription,
	FeatureAudioDiarization,
	FeatureVAD, FeatureDetection, FeatureVideo, FeatureEmbeddings, FeatureSound,
	FeatureRealtime, FeatureRerank, FeatureTokenize, FeatureMCP, FeatureStores,
	FeatureFaceRecognition, FeatureVoiceRecognition, FeatureAudioTransform,
@@ -32,8 +32,9 @@ var instructionDefs = []instructionDef{
 	},
 	{
 		Name:        "audio",
-		Description: "Text-to-speech, voice activity detection, transcription, and sound generation",
+		Description: "Text-to-speech, voice activity detection, transcription, speaker diarization, and sound generation",
 		Tags:        []string{"audio"},
+		Intro:       "Diarization (/v1/audio/diarization) returns speaker-labelled time segments. Backends with native ASR-diarization (vibevoice-cpp) can also emit per-segment text via include_text=true; backends with a dedicated pipeline (sherpa-onnx + pyannote) emit segmentation only. Response formats: json (default), verbose_json (adds speakers summary + text), rttm (NIST format).",
 	},
 	{
 		Name: "images",
181 core/http/endpoints/openai/diarization.go Normal file
@@ -0,0 +1,181 @@
package openai

import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/core/schema"
	model "github.com/mudler/LocalAI/pkg/model"

	"github.com/mudler/xlog"
)

// DiarizationEndpoint runs offline speaker diarization on an uploaded
// audio file and returns "who spoke when". Backends with a pure
// diarization pipeline (sherpa-onnx + pyannote) emit only segmentation;
// backends that produce diarization as a by-product of ASR (vibevoice.cpp)
// can additionally fill in the per-segment transcript when the caller
// passes `include_text=true`.
//
// Response formats follow transcription's: `json` (default, segments only),
// `verbose_json` (adds speaker summary and per-segment text), and `rttm`
// (NIST RTTM, the standard interchange format used by pyannote/dscore).
//
// @Summary Identify speakers in audio (who spoke when).
// @Tags audio
// @accept multipart/form-data
// @Param model formData string true "model"
// @Param file formData file true "audio file"
// @Param num_speakers formData int false "exact speaker count (>0 forces; 0 = auto)"
// @Param min_speakers formData int false "lower bound when auto-detecting"
// @Param max_speakers formData int false "upper bound when auto-detecting"
// @Param clustering_threshold formData number false "clustering distance threshold when num_speakers is unknown"
// @Param min_duration_on formData number false "discard segments shorter than this (seconds)"
// @Param min_duration_off formData number false "merge gaps shorter than this (seconds)"
// @Param language formData string false "audio language hint (only meaningful for backends that bundle ASR)"
// @Param include_text formData boolean false "include per-segment transcript when the backend supports it"
// @Param response_format formData string false "json (default), verbose_json, or rttm"
// @Success 200 {object} schema.DiarizationResult
// @Router /v1/audio/diarization [post]
func DiarizationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {
		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
		if !ok || input.Model == "" {
			return echo.ErrBadRequest
		}

		modelConfig, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || modelConfig == nil {
			return echo.ErrBadRequest
		}

		req := backend.DiarizationRequest{
			Language:    input.Language,
			IncludeText: parseFormBool(c, "include_text", false),
		}
		req.NumSpeakers = int32(parseFormInt(c, "num_speakers", 0))
		req.MinSpeakers = int32(parseFormInt(c, "min_speakers", 0))
		req.MaxSpeakers = int32(parseFormInt(c, "max_speakers", 0))
		req.ClusteringThreshold = float32(parseFormFloat(c, "clustering_threshold", 0))
		req.MinDurationOn = float32(parseFormFloat(c, "min_duration_on", 0))
		req.MinDurationOff = float32(parseFormFloat(c, "min_duration_off", 0))

		responseFormat := schema.DiarizationResponseFormatType(strings.ToLower(c.FormValue("response_format")))
		if responseFormat == "" {
			responseFormat = schema.DiarizationResponseFormatJson
		}

		file, err := c.FormFile("file")
		if err != nil {
			return err
		}
		f, err := file.Open()
		if err != nil {
			return err
		}
		defer func() { _ = f.Close() }()

		dir, err := os.MkdirTemp("", "diarize")
		if err != nil {
			return err
		}
		defer func() { _ = os.RemoveAll(dir) }()

		dst := filepath.Join(dir, path.Base(file.Filename))
		dstFile, err := os.Create(dst)
		if err != nil {
			return err
		}
		if _, err := io.Copy(dstFile, f); err != nil {
			xlog.Debug("Audio file copying error", "filename", file.Filename, "dst", dst, "error", err)
			_ = dstFile.Close()
			return err
		}
		_ = dstFile.Close()
		req.Audio = dst

		result, err := backend.ModelDiarization(req, ml, *modelConfig, appConfig)
		if err != nil {
			return err
		}

		switch responseFormat {
		case schema.DiarizationResponseFormatRTTM:
			c.Response().Header().Set(echo.HeaderContentType, "text/plain; charset=utf-8")
			return c.String(http.StatusOK, renderRTTM(result, file.Filename))
		case schema.DiarizationResponseFormatJson:
			// Default JSON: drop the heavy per-speaker summary and any
			// optional per-segment text so simple consumers see a tight
			// payload. verbose_json keeps everything.
			result.Speakers = nil
			for i := range result.Segments {
				result.Segments[i].Text = ""
			}
			return c.JSON(http.StatusOK, result)
		case schema.DiarizationResponseFormatJsonVerbose:
			return c.JSON(http.StatusOK, result)
		default:
			return errors.New("invalid response_format (expected: json, verbose_json, rttm)")
		}
	}
}

// renderRTTM emits NIST RTTM rows. Each row:
//
//	SPEAKER <file> 1 <start> <duration> <NA> <NA> <speaker> <NA> <NA>
//
// Field separators are spaces; one row per segment.
func renderRTTM(r *schema.DiarizationResult, sourceFile string) string {
	id := strings.TrimSuffix(filepath.Base(sourceFile), filepath.Ext(sourceFile))
	// filepath.Base("") returns "." — treat both as a missing source name and
	// fall back to a stable placeholder so the RTTM row stays parseable.
	if id == "" || id == "." {
		id = "audio"
	}
	var sb strings.Builder
	for _, seg := range r.Segments {
		dur := seg.End - seg.Start
		if dur < 0 {
			dur = 0
		}
		fmt.Fprintf(&sb, "SPEAKER %s 1 %.3f %.3f <NA> <NA> %s <NA> <NA>\n",
			id, seg.Start, dur, seg.Speaker)
	}
	return sb.String()
}

func parseFormInt(c echo.Context, key string, def int) int {
	if v := c.FormValue(key); v != "" {
		if n, err := strconv.Atoi(v); err == nil {
			return n
		}
	}
	return def
}

func parseFormFloat(c echo.Context, key string, def float64) float64 {
	if v := c.FormValue(key); v != "" {
		if f, err := strconv.ParseFloat(v, 64); err == nil {
			return f
		}
	}
	return def
}

func parseFormBool(c echo.Context, key string, def bool) bool {
	if v := c.FormValue(key); v != "" {
		if b, err := strconv.ParseBool(v); err == nil {
			return b
		}
	}
	return def
}
|
||||
core/http/endpoints/openai/diarization_test.go (new file, 51 lines)
@@ -0,0 +1,51 @@
package openai

import (
	"strings"

	"github.com/mudler/LocalAI/core/schema"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("renderRTTM", func() {
	It("formats segments as NIST RTTM rows", func() {
		r := &schema.DiarizationResult{
			Segments: []schema.DiarizationSegment{
				{Id: 0, Speaker: "SPEAKER_00", Start: 0, End: 2.34},
				{Id: 1, Speaker: "SPEAKER_01", Start: 2.34, End: 4.10},
			},
		}
		out := renderRTTM(r, "/tmp/uploads/meeting.wav")

		lines := strings.Split(strings.TrimSpace(out), "\n")
		Expect(lines).To(HaveLen(2))

		// File ID should be the basename without extension; durations are
		// (end - start) with millisecond precision.
		Expect(lines[0]).To(HavePrefix("SPEAKER meeting 1 "))
		Expect(lines[0]).To(ContainSubstring(" 0.000 2.340 <NA> <NA> SPEAKER_00 <NA> <NA>"))
		Expect(lines[1]).To(ContainSubstring(" 2.340 1.760 <NA> <NA> SPEAKER_01 <NA> <NA>"))
	})

	It("clamps negative duration to zero", func() {
		// Backends shouldn't emit end<start, but if they do (clock skew during a
		// long pipeline), the RTTM duration must stay non-negative.
		r := &schema.DiarizationResult{
			Segments: []schema.DiarizationSegment{
				{Id: 0, Speaker: "SPEAKER_00", Start: 5, End: 4},
			},
		}
		out := renderRTTM(r, "x.wav")
		Expect(out).To(ContainSubstring(" 5.000 0.000 "))
	})

	It("falls back to 'audio' when the source file name is empty", func() {
		r := &schema.DiarizationResult{
			Segments: []schema.DiarizationSegment{{Id: 0, Speaker: "SPEAKER_00", Start: 0, End: 1}},
		}
		out := renderRTTM(r, "")
		Expect(out).To(HavePrefix("SPEAKER audio 1 "))
	})
})
core/http/react-ui/src/utils/capabilities.js (vendored, 1 line)
@@ -14,6 +14,7 @@ export const CAP_TTS = 'FLAG_TTS'
export const CAP_SOUND_GENERATION = 'FLAG_SOUND_GENERATION'
export const CAP_TOKENIZE = 'FLAG_TOKENIZE'
export const CAP_VAD = 'FLAG_VAD'
export const CAP_DIARIZATION = 'FLAG_DIARIZATION'
export const CAP_VIDEO = 'FLAG_VIDEO'
export const CAP_DETECTION = 'FLAG_DETECTION'
export const CAP_FACE_RECOGNITION = 'FLAG_FACE_RECOGNITION'
@@ -272,6 +272,7 @@ func RegisterLocalAIRoutes(router *echo.Echo,
			"completions":      "/v1/completions",
			"embeddings":       "/v1/embeddings",
			"transcription":    "/v1/audio/transcriptions",
			"diarization":      "/v1/audio/diarization",
			"image_generation": "/v1/images/generations",
		},
		"config_management": map[string]string{
@@ -130,6 +130,23 @@ func RegisterOpenAIRoutes(app *echo.Echo,
	app.POST("/v1/audio/transcriptions", audioHandler, audioMiddleware...)
	app.POST("/audio/transcriptions", audioHandler, audioMiddleware...)

	diarizationHandler := openai.DiarizationEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
	diarizationMiddleware := []echo.MiddlewareFunc{
		traceMiddleware,
		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_DIARIZATION)),
		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
		func(next echo.HandlerFunc) echo.HandlerFunc {
			return func(c echo.Context) error {
				if err := re.SetOpenAIRequest(c); err != nil {
					return err
				}
				return next(c)
			}
		},
	}
	app.POST("/v1/audio/diarization", diarizationHandler, diarizationMiddleware...)
	app.POST("/audio/diarization", diarizationHandler, diarizationMiddleware...)

	audioSpeechHandler := localai.TTSEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
	audioSpeechMiddleware := []echo.MiddlewareFunc{
		traceMiddleware,
core/schema/diarization.go (new file, 48 lines)
@@ -0,0 +1,48 @@
package schema

// DiarizationSegment is one continuous span of speech attributed to a
// single speaker. Times are in seconds. Speaker is the normalized label
// (SPEAKER_NN, zero-padded, stable across segments); Label preserves the
// raw backend-emitted identifier for clients that already track their
// own speaker dictionary.
type DiarizationSegment struct {
	Id      int     `json:"id"`
	Speaker string  `json:"speaker"`
	Label   string  `json:"label,omitempty"`
	Start   float64 `json:"start"`
	End     float64 `json:"end"`
	Text    string  `json:"text,omitempty"`
}

// DiarizationSpeaker summarizes one speaker across the whole audio so
// clients can build per-speaker UIs (timeline strips, talk-time charts)
// without re-aggregating the segment list.
type DiarizationSpeaker struct {
	Id                  string  `json:"id"`
	Label               string  `json:"label,omitempty"`
	TotalSpeechDuration float64 `json:"total_speech_duration"`
	SegmentCount        int     `json:"segment_count"`
}

// DiarizationResult is the JSON payload returned by /v1/audio/diarization.
// Speakers and segment text are omitted when empty so the default `json`
// response stays minimal; verbose_json keeps both populated.
type DiarizationResult struct {
	Task        string               `json:"task"`
	Duration    float64              `json:"duration,omitempty"`
	Language    string               `json:"language,omitempty"`
	NumSpeakers int                  `json:"num_speakers"`
	Segments    []DiarizationSegment `json:"segments"`
	Speakers    []DiarizationSpeaker `json:"speakers,omitempty"`
}

// DiarizationResponseFormatType mirrors transcription's response_format
// pattern: json (default, no per-segment text), verbose_json (adds
// speakers summary + text when available), and rttm (NIST RTTM rows).
type DiarizationResponseFormatType string

const (
	DiarizationResponseFormatJson        DiarizationResponseFormatType = "json"
	DiarizationResponseFormatJsonVerbose DiarizationResponseFormatType = "verbose_json"
	DiarizationResponseFormatRTTM        DiarizationResponseFormatType = "rttm"
)
@@ -217,6 +217,9 @@ func (c *fakeBackendClient) GetTokenMetrics(_ context.Context, _ *pb.MetricsRequ
func (c *fakeBackendClient) VAD(_ context.Context, _ *pb.VADRequest, _ ...ggrpc.CallOption) (*pb.VADResponse, error) {
	return nil, nil
}
func (c *fakeBackendClient) Diarize(_ context.Context, _ *pb.DiarizeRequest, _ ...ggrpc.CallOption) (*pb.DiarizeResponse, error) {
	return nil, nil
}
func (c *fakeBackendClient) AudioEncode(_ context.Context, _ *pb.AudioEncodeRequest, _ ...ggrpc.CallOption) (*pb.AudioEncodeResult, error) {
	return nil, nil
}

@@ -156,6 +156,10 @@ func (f *fakeGRPCBackend) VAD(_ context.Context, _ *pb.VADRequest, _ ...ggrpc.Ca
	return &pb.VADResponse{}, nil
}

func (f *fakeGRPCBackend) Diarize(_ context.Context, _ *pb.DiarizeRequest, _ ...ggrpc.CallOption) (*pb.DiarizeResponse, error) {
	return &pb.DiarizeResponse{}, nil
}

func (f *fakeGRPCBackend) AudioEncode(_ context.Context, _ *pb.AudioEncodeRequest, _ ...ggrpc.CallOption) (*pb.AudioEncodeResult, error) {
	return &pb.AudioEncodeResult{}, nil
}
docs/content/features/audio-diarization.md (new file, 152 lines)
@@ -0,0 +1,152 @@
+++
disableToc = false
title = "Speaker Diarization"
weight = 17
url = "/features/audio-diarization/"
+++

Speaker diarization answers the question **"who spoke when?"** — given an audio clip with multiple speakers, it returns time-stamped segments labelled with a stable speaker ID (`SPEAKER_00`, `SPEAKER_01`, …).

LocalAI exposes this through the `/v1/audio/diarization` endpoint, modelled after `/v1/audio/transcriptions`. Two backends are supported today:

- **[sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx)** — pyannote-3.0 segmentation + a speaker-embedding extractor (3D-Speaker, NeMo, WeSpeaker) + fast clustering. Pure diarization — no transcription cost. Recommended when you only need speaker turns.
- **[vibevoice.cpp](https://github.com/microsoft/VibeVoice)** — produces speaker-labelled segments as a by-product of its long-form ASR pass, so you can optionally get a transcript per segment for free.

Because diarization is exposed as a regular OpenAI-compatible endpoint, any HTTP client works. There is no Python dependency on pyannote or NeMo on the consumer side.

## Endpoint

```
POST /v1/audio/diarization
Content-Type: multipart/form-data
```

| Field | Type | Description |
|-------|------|-------------|
| `file` | file (required) | audio file in any format `ffmpeg` accepts |
| `model` | string (required) | name of the diarization-capable model |
| `num_speakers` | int | exact speaker count when known (>0 forces; 0 = auto) |
| `min_speakers` | int | hint when auto-detecting |
| `max_speakers` | int | hint when auto-detecting |
| `clustering_threshold` | float | cosine distance threshold used when `num_speakers` is unknown |
| `min_duration_on` | float | discard segments shorter than this many seconds |
| `min_duration_off` | float | merge gaps shorter than this many seconds |
| `language` | string | only meaningful for backends that bundle ASR (e.g. vibevoice) |
| `include_text` | bool | when the backend can emit per-segment transcript for free, populate it |
| `response_format` | string | `json` (default), `verbose_json`, or `rttm` |

### Response — `json` (default)

Compact payload, no transcription, no per-speaker summary:

```json
{
  "task": "diarize",
  "duration": 12.34,
  "num_speakers": 2,
  "segments": [
    {"id": 0, "speaker": "SPEAKER_00", "label": "0", "start": 0.00, "end": 2.34},
    {"id": 1, "speaker": "SPEAKER_01", "label": "1", "start": 2.34, "end": 4.10}
  ]
}
```

`speaker` is the normalized, zero-padded label clients should display. `label` preserves the raw backend-emitted ID for clients that maintain their own speaker dictionary.

### Response — `verbose_json`

Adds per-speaker totals and (when the backend supports it and `include_text=true`) the per-segment transcript:

```json
{
  "task": "diarize",
  "duration": 12.34,
  "language": "en",
  "num_speakers": 2,
  "segments": [
    {"id": 0, "speaker": "SPEAKER_00", "label": "0", "start": 0.00, "end": 2.34, "text": "Hello, world."},
    {"id": 1, "speaker": "SPEAKER_01", "label": "1", "start": 2.34, "end": 4.10, "text": "How are you?"}
  ],
  "speakers": [
    {"id": "SPEAKER_00", "label": "0", "total_speech_duration": 5.6, "segment_count": 3},
    {"id": "SPEAKER_01", "label": "1", "total_speech_duration": 1.76, "segment_count": 1}
  ]
}
```
### Response — `rttm`

NIST RTTM, the standard interchange format used by `pyannote.metrics` / `dscore`:

```
SPEAKER audio 1 0.000 2.340 <NA> <NA> SPEAKER_00 <NA> <NA>
SPEAKER audio 1 2.340 1.760 <NA> <NA> SPEAKER_01 <NA> <NA>
```

Returned as `Content-Type: text/plain; charset=utf-8`.
## Quick start

```bash
curl http://localhost:8080/v1/audio/diarization \
  -H "Content-Type: multipart/form-data" \
  -F file="@meeting.wav" \
  -F model="pyannote-diarization" \
  -F num_speakers=3
```
## Backend setup — sherpa-onnx (pure diarization)

Sherpa-onnx needs two ONNX models: pyannote segmentation and a speaker-embedding extractor. Place them under your LocalAI models directory and reference them from the YAML:

```yaml
name: pyannote-diarization
backend: sherpa-onnx
type: diarization
parameters:
  model: sherpa-onnx-pyannote-segmentation-3-0/model.onnx
options:
- diarize.embedding_model=3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx
# Optional clustering knobs (per-call DiarizeRequest fields override these):
- diarize.threshold=0.5
- diarize.min_duration_on=0.3
- diarize.min_duration_off=0.5
known_usecases:
- FLAG_DIARIZATION
```

Both `model:` and `diarize.embedding_model=` are resolved relative to the LocalAI models directory.

## Backend setup — vibevoice.cpp (diarization + ASR)

vibevoice.cpp's ASR mode emits `[{Start, End, Speaker, Content}]` natively, so a single pass gives both diarization and transcription:

```yaml
name: vibevoice-diarize
backend: vibevoice-cpp
parameters:
  model: vibevoice-asr.gguf
options:
- type=asr
- tokenizer=vibevoice-tokenizer.gguf
known_usecases:
- FLAG_DIARIZATION
- FLAG_TRANSCRIPT
```

Pass `include_text=true` on the request to populate the `text` field on each diarization segment.

```bash
curl http://localhost:8080/v1/audio/diarization \
  -H "Content-Type: multipart/form-data" \
  -F file="@interview.wav" \
  -F model="vibevoice-diarize" \
  -F include_text=true \
  -F response_format=verbose_json
```

## Notes

- **Speaker identity across files**: speaker IDs (`SPEAKER_00`, `SPEAKER_01`, …) are local to each request. To track the same person across multiple recordings, combine `/v1/audio/diarization` with `/v1/voice/embed` (speaker embedding) and maintain your own embedding store.
- **Hints vs. forces**: `num_speakers` overrides clustering when set; `min_speakers` / `max_speakers` are advisory and only honored by backends that expose a range hint. vibevoice.cpp ignores them — its model picks the count itself.
- **Sample rate**: input is automatically converted to 16 kHz mono via ffmpeg before the backend sees it; sherpa-onnx pyannote-3.0 requires 16 kHz.
@@ -16,6 +16,8 @@ The transcription endpoint allows to convert audio files to text. The endpoint s

The endpoint input supports all the audio formats supported by `ffmpeg`.

> Looking for **"who spoke when"** instead of a flat transcript? See [Speaker Diarization](/features/audio-diarization/) — `/v1/audio/diarization` returns time-stamped speaker segments and supports the `rttm` format used by `pyannote.metrics`.

## Usage

Once LocalAI is started and whisper models are installed, you can use the `/v1/audio/transcriptions` API endpoint.
@@ -14,6 +14,7 @@ You can see the release notes [here](https://github.com/mudler/LocalAI/releases)

- **April 2026**: [Audio Transform](/features/audio-transform/) — generic audio-in / audio-out endpoint with optional reference signal. First implementation: [LocalVQE](https://github.com/localai-org/LocalVQE) C++ backend (joint AEC + noise suppression + dereverberation, DeepVQE-style). Both batch (`POST /audio/transformations`) and bidirectional WebSocket streaming (`/audio/transformations/stream`). Studio "Transform" tab with synchronized waveform players for input / reference / output.
- **April 2026**: [Face recognition backend](/features/face-recognition/) — `insightface`-powered 1:1 verification, 1:N identification, face embedding, face detection, and demographic analysis. Ships both a non-commercial `buffalo_l` model and an Apache 2.0 OpenCV Zoo alternative.
- **May 2026**: [Speaker diarization](/features/audio-diarization/) — new `/v1/audio/diarization` endpoint returning "who spoke when" segments. Backed by `sherpa-onnx` (pyannote-3.0 + speaker embeddings + clustering) for pure diarization, and `vibevoice-cpp` for diarization bundled with long-form ASR. Supports `json` / `verbose_json` / `rttm` response formats.

## 2024 Highlights
@@ -75,6 +75,8 @@ type Backend interface {

	VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error)

	Diarize(ctx context.Context, in *pb.DiarizeRequest, opts ...grpc.CallOption) (*pb.DiarizeResponse, error)

	AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...grpc.CallOption) (*pb.AudioEncodeResult, error)
	AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error)
@@ -101,6 +101,10 @@ func (llm *Base) VoiceEmbed(*pb.VoiceEmbedRequest) (pb.VoiceEmbedResponse, error
	return pb.VoiceEmbedResponse{}, fmt.Errorf("unimplemented")
}

func (llm *Base) Diarize(*pb.DiarizeRequest) (pb.DiarizeResponse, error) {
	return pb.DiarizeResponse{}, fmt.Errorf("unimplemented")
}

func (llm *Base) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
	return pb.TokenizationResponse{}, fmt.Errorf("unimplemented")
}
@@ -562,6 +562,24 @@ func (c *Client) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOp
	return client.VAD(ctx, in, opts...)
}

func (c *Client) Diarize(ctx context.Context, in *pb.DiarizeRequest, opts ...grpc.CallOption) (*pb.DiarizeResponse, error) {
	if !c.parallel {
		c.opMutex.Lock()
		defer c.opMutex.Unlock()
	}
	c.setBusy(true)
	defer c.setBusy(false)
	c.wdMark()
	defer c.wdUnMark()
	conn, err := c.dial()
	if err != nil {
		return nil, err
	}
	defer func() { _ = conn.Close() }()
	client := pb.NewBackendClient(conn)
	return client.Diarize(ctx, in, opts...)
}

func (c *Client) Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error) {
	if !c.parallel {
		c.opMutex.Lock()
@@ -136,6 +136,10 @@ func (e *embedBackend) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.
	return e.s.VAD(ctx, in)
}

func (e *embedBackend) Diarize(ctx context.Context, in *pb.DiarizeRequest, opts ...grpc.CallOption) (*pb.DiarizeResponse, error) {
	return e.s.Diarize(ctx, in)
}

func (e *embedBackend) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...grpc.CallOption) (*pb.AudioEncodeResult, error) {
	return e.s.AudioEncode(ctx, in)
}
@@ -36,6 +36,7 @@ type AIModel interface {
	StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)

	VAD(*pb.VADRequest) (pb.VADResponse, error)
	Diarize(*pb.DiarizeRequest) (pb.DiarizeResponse, error)

	AudioEncode(*pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error)
	AudioDecode(*pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error)
@@ -377,6 +377,18 @@ func (s *server) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, e
	return &res, nil
}

func (s *server) Diarize(ctx context.Context, in *pb.DiarizeRequest) (*pb.DiarizeResponse, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	res, err := s.llm.Diarize(in)
	if err != nil {
		return nil, err
	}
	return &res, nil
}

func (s *server) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error) {
	if s.llm.Locking() {
		s.llm.Lock()
swagger/docs.go (158 lines changed)
@@ -1795,6 +1795,95 @@ const docTemplate = `{
                }
            }
        },
        "/v1/audio/diarization": {
            "post": {
                "consumes": [
                    "multipart/form-data"
                ],
                "tags": [
                    "audio"
                ],
                "summary": "Identify speakers in audio (who spoke when).",
                "parameters": [
                    {
                        "type": "string",
                        "description": "model",
                        "name": "model",
                        "in": "formData",
                        "required": true
                    },
                    {
                        "type": "file",
                        "description": "audio file",
                        "name": "file",
                        "in": "formData",
                        "required": true
                    },
                    {
                        "type": "integer",
                        "description": "exact speaker count (\u003e0 forces; 0 = auto)",
                        "name": "num_speakers",
                        "in": "formData"
                    },
                    {
                        "type": "integer",
                        "description": "lower bound when auto-detecting",
                        "name": "min_speakers",
                        "in": "formData"
                    },
                    {
                        "type": "integer",
                        "description": "upper bound when auto-detecting",
                        "name": "max_speakers",
                        "in": "formData"
                    },
                    {
                        "type": "number",
                        "description": "clustering distance threshold when num_speakers is unknown",
                        "name": "clustering_threshold",
                        "in": "formData"
                    },
                    {
                        "type": "number",
                        "description": "discard segments shorter than this (seconds)",
                        "name": "min_duration_on",
                        "in": "formData"
                    },
                    {
                        "type": "number",
                        "description": "merge gaps shorter than this (seconds)",
                        "name": "min_duration_off",
                        "in": "formData"
                    },
                    {
                        "type": "string",
                        "description": "audio language hint (only meaningful for backends that bundle ASR)",
                        "name": "language",
                        "in": "formData"
                    },
                    {
                        "type": "boolean",
                        "description": "include per-segment transcript when the backend supports it",
                        "name": "include_text",
                        "in": "formData"
                    },
                    {
                        "type": "string",
                        "description": "json (default), verbose_json, or rttm",
                        "name": "response_format",
                        "in": "formData"
                    }
                ],
                "responses": {
                    "200": {
                        "description": "OK",
                        "schema": {
                            "$ref": "#/definitions/schema.DiarizationResult"
                        }
                    }
                }
            }
        },
        "/v1/audio/speech": {
            "post": {
                "consumes": [

@@ -3712,6 +3801,75 @@ const docTemplate = `{
                }
            }
        },
        "schema.DiarizationResult": {
            "type": "object",
            "properties": {
                "duration": {
                    "type": "number"
                },
                "language": {
                    "type": "string"
                },
                "num_speakers": {
                    "type": "integer"
                },
                "segments": {
                    "type": "array",
                    "items": {
                        "$ref": "#/definitions/schema.DiarizationSegment"
                    }
                },
                "speakers": {
                    "type": "array",
                    "items": {
                        "$ref": "#/definitions/schema.DiarizationSpeaker"
                    }
                },
                "task": {
                    "type": "string"
                }
            }
        },
        "schema.DiarizationSegment": {
            "type": "object",
            "properties": {
                "end": {
                    "type": "number"
                },
                "id": {
                    "type": "integer"
                },
                "label": {
                    "type": "string"
                },
                "speaker": {
                    "type": "string"
                },
                "start": {
                    "type": "number"
                },
                "text": {
                    "type": "string"
                }
            }
        },
        "schema.DiarizationSpeaker": {
            "type": "object",
            "properties": {
                "id": {
                    "type": "string"
                },
                "label": {
                    "type": "string"
                },
                "segment_count": {
                    "type": "integer"
                },
                "total_speech_duration": {
                    "type": "number"
                }
            }
        },
        "schema.ElevenLabsSoundGenerationRequest": {
            "type": "object",
            "properties": {
@@ -1792,6 +1792,95 @@
                }
            }
        },
        "/v1/audio/diarization": {
            "post": {
                "consumes": [
                    "multipart/form-data"
                ],
                "tags": [
                    "audio"
                ],
                "summary": "Identify speakers in audio (who spoke when).",
                "parameters": [
                    {
                        "type": "string",
                        "description": "model",
                        "name": "model",
                        "in": "formData",
                        "required": true
                    },
                    {
                        "type": "file",
                        "description": "audio file",
                        "name": "file",
                        "in": "formData",
                        "required": true
                    },
                    {
                        "type": "integer",
                        "description": "exact speaker count (\u003e0 forces; 0 = auto)",
                        "name": "num_speakers",
                        "in": "formData"
                    },
                    {
                        "type": "integer",
                        "description": "lower bound when auto-detecting",
                        "name": "min_speakers",
                        "in": "formData"
                    },
                    {
                        "type": "integer",
                        "description": "upper bound when auto-detecting",
                        "name": "max_speakers",
                        "in": "formData"
                    },
                    {
                        "type": "number",
                        "description": "clustering distance threshold when num_speakers is unknown",
                        "name": "clustering_threshold",
                        "in": "formData"
                    },
                    {
                        "type": "number",
                        "description": "discard segments shorter than this (seconds)",
                        "name": "min_duration_on",
                        "in": "formData"
                    },
                    {
                        "type": "number",
                        "description": "merge gaps shorter than this (seconds)",
                        "name": "min_duration_off",
                        "in": "formData"
                    },
                    {
                        "type": "string",
                        "description": "audio language hint (only meaningful for backends that bundle ASR)",
                        "name": "language",
                        "in": "formData"
                    },
                    {
                        "type": "boolean",
                        "description": "include per-segment transcript when the backend supports it",
                        "name": "include_text",
                        "in": "formData"
                    },
                    {
                        "type": "string",
                        "description": "json (default), verbose_json, or rttm",
                        "name": "response_format",
                        "in": "formData"
                    }
                ],
                "responses": {
                    "200": {
                        "description": "OK",
                        "schema": {
                            "$ref": "#/definitions/schema.DiarizationResult"
                        }
                    }
                }
            }
        },
        "/v1/audio/speech": {
            "post": {
                "consumes": [

@@ -3709,6 +3798,75 @@
                }
            }
        },
        "schema.DiarizationResult": {
            "type": "object",
            "properties": {
                "duration": {
                    "type": "number"
                },
                "language": {
                    "type": "string"
                },
                "num_speakers": {
                    "type": "integer"
                },
                "segments": {
                    "type": "array",
                    "items": {
                        "$ref": "#/definitions/schema.DiarizationSegment"
                    }
                },
                "speakers": {
                    "type": "array",
                    "items": {
                        "$ref": "#/definitions/schema.DiarizationSpeaker"
                    }
                },
                "task": {
                    "type": "string"
                }
            }
        },
        "schema.DiarizationSegment": {
            "type": "object",
            "properties": {
                "end": {
                    "type": "number"
                },
                "id": {
                    "type": "integer"
                },
                "label": {
                    "type": "string"
                },
                "speaker": {
                    "type": "string"
                },
                "start": {
                    "type": "number"
                },
                "text": {
                    "type": "string"
                }
            }
        },
        "schema.DiarizationSpeaker": {
            "type": "object",
            "properties": {
                "id": {
                    "type": "string"
                },
                "label": {
                    "type": "string"
                },
                "segment_count": {
                    "type": "integer"
                },
                "total_speech_duration": {
                    "type": "number"
                }
            }
        },
        "schema.ElevenLabsSoundGenerationRequest": {
            "type": "object",
            "properties": {
@@ -595,6 +595,51 @@ definitions:
        $ref: '#/definitions/schema.Detection'
      type: array
    type: object
  schema.DiarizationResult:
    properties:
      duration:
        type: number
      language:
        type: string
      num_speakers:
        type: integer
      segments:
        items:
          $ref: '#/definitions/schema.DiarizationSegment'
        type: array
      speakers:
        items:
          $ref: '#/definitions/schema.DiarizationSpeaker'
        type: array
      task:
        type: string
    type: object
  schema.DiarizationSegment:
    properties:
      end:
        type: number
      id:
        type: integer
      label:
        type: string
      speaker:
        type: string
      start:
        type: number
      text:
        type: string
    type: object
  schema.DiarizationSpeaker:
    properties:
      id:
        type: string
      label:
        type: string
      segment_count:
        type: integer
      total_speech_duration:
        type: number
    type: object
  schema.ElevenLabsSoundGenerationRequest:
    properties:
      bpm:

@@ -3322,6 +3367,66 @@ paths:
      summary: Generates audio from the input text.
      tags:
      - audio
  /v1/audio/diarization:
    post:
      consumes:
      - multipart/form-data
      parameters:
      - description: model
        in: formData
        name: model
        required: true
        type: string
      - description: audio file
        in: formData
        name: file
        required: true
        type: file
      - description: exact speaker count (>0 forces; 0 = auto)
        in: formData
        name: num_speakers
        type: integer
      - description: lower bound when auto-detecting
        in: formData
        name: min_speakers
        type: integer
      - description: upper bound when auto-detecting
        in: formData
        name: max_speakers
        type: integer
      - description: clustering distance threshold when num_speakers is unknown
        in: formData
        name: clustering_threshold
        type: number
      - description: discard segments shorter than this (seconds)
        in: formData
        name: min_duration_on
        type: number
      - description: merge gaps shorter than this (seconds)
        in: formData
        name: min_duration_off
        type: number
      - description: audio language hint (only meaningful for backends that bundle
          ASR)
        in: formData
        name: language
        type: string
      - description: include per-segment transcript when the backend supports it
        in: formData
        name: include_text
        type: boolean
      - description: json (default), verbose_json, or rttm
        in: formData
        name: response_format
        type: string
      responses:
        "200":
          description: OK
          schema:
            $ref: '#/definitions/schema.DiarizationResult'
      summary: Identify speakers in audio (who spoke when).
      tags:
      - audio
  /v1/audio/speech:
    post:
      consumes:
@@ -169,6 +169,21 @@ var _ = BeforeSuite(func() {
        Expect(os.WriteFile(filepath.Join(modelsPath, name+".yaml"), data, 0644)).To(Succeed())
    }

    // Diarization model — known_usecases bypasses the FLAG_DIARIZATION
    // backend-name guard so the /v1/audio/diarization route can dispatch
    // to the mock backend.
    diarizeCfg := map[string]any{
        "name":           "mock-diarize",
        "backend":        "mock-backend",
        "known_usecases": []string{"FLAG_DIARIZATION"},
        "parameters": map[string]any{
            "model": "mock-diarize.bin",
        },
    }
    diarizeData, err := yaml.Marshal(diarizeCfg)
    Expect(err).ToNot(HaveOccurred())
    Expect(os.WriteFile(filepath.Join(modelsPath, "mock-diarize.yaml"), diarizeData, 0644)).To(Succeed())

    // Pipeline model that wires the component models together.
    pipelineCfg := map[string]any{
        "name": "realtime-pipeline",
@@ -631,6 +631,36 @@ func (m *MockBackend) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADRespon
    }, nil
}

// Diarize returns a deterministic two-speaker layout that exercises the
// HTTP layer's normalisation: raw labels "5" and "2" should become
// SPEAKER_00 and SPEAKER_01 in first-seen order, the SPEAKER_00 totals
// should reflect two segments (1.0s + 1.5s = 2.5s), and IncludeText must
// gate the per-segment Text field.
func (m *MockBackend) Diarize(ctx context.Context, in *pb.DiarizeRequest) (*pb.DiarizeResponse, error) {
    xlog.Debug("Diarize called",
        "dst", in.Dst,
        "num_speakers", in.NumSpeakers,
        "include_text", in.IncludeText)

    seg := func(start, end float32, speaker, text string) *pb.DiarizeSegment {
        out := &pb.DiarizeSegment{Start: start, End: end, Speaker: speaker}
        if in.IncludeText {
            out.Text = text
        }
        return out
    }
    return &pb.DiarizeResponse{
        Segments: []*pb.DiarizeSegment{
            seg(0.0, 1.0, "5", "hello there"),
            seg(1.0, 2.0, "2", "general kenobi"),
            seg(2.0, 3.5, "5", "you are a bold one"),
        },
        NumSpeakers: 2,
        Duration:    3.5,
        Language:    in.Language,
    }, nil
}

func (m *MockBackend) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error) {
    xlog.Debug("AudioEncode called", "pcm_len", len(in.PcmData), "sample_rate", in.SampleRate)
    // Return a single mock Opus frame per 960-sample chunk (20ms at 48kHz).
@@ -1,9 +1,11 @@
package e2e_test

import (
    "bytes"
    "context"
    "encoding/json"
    "io"
    "mime/multipart"
    "net/http"
    "strings"
    "time"
@@ -225,6 +227,124 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
        })
    })

    Describe("Audio Diarization API", func() {
        // Helper: build a multipart/form-data request to /v1/audio/diarization
        // with a tiny stub WAV. The backend ignores the audio payload
        // (it returns a deterministic three-segment layout), so a 4-byte
        // stub is enough to exercise the HTTP layer.
        postDiarize := func(extraFields map[string]string) (*http.Response, []byte) {
            body := &bytes.Buffer{}
            mw := multipart.NewWriter(body)

            Expect(mw.WriteField("model", "mock-diarize")).To(Succeed())
            for k, v := range extraFields {
                Expect(mw.WriteField(k, v)).To(Succeed())
            }

            part, err := mw.CreateFormFile("file", "stub.wav")
            Expect(err).ToNot(HaveOccurred())
            _, err = part.Write([]byte{0, 0, 0, 0})
            Expect(err).ToNot(HaveOccurred())
            Expect(mw.Close()).To(Succeed())

            req, err := http.NewRequest("POST", apiURL+"/audio/diarization", body)
            Expect(err).ToNot(HaveOccurred())
            req.Header.Set("Content-Type", mw.FormDataContentType())

            httpClient := &http.Client{Timeout: 30 * time.Second}
            resp, err := httpClient.Do(req)
            Expect(err).ToNot(HaveOccurred())
            defer func() { _ = resp.Body.Close() }()
            data, err := io.ReadAll(resp.Body)
            Expect(err).ToNot(HaveOccurred())
            return resp, data
        }

        It("normalizes raw backend speaker labels to SPEAKER_NN in first-seen order", func() {
            resp, data := postDiarize(nil)
            Expect(resp.StatusCode).To(Equal(http.StatusOK))

            var got map[string]any
            Expect(json.Unmarshal(data, &got)).To(Succeed())

            Expect(got["task"]).To(Equal("diarize"))
            Expect(got["num_speakers"]).To(BeEquivalentTo(2))
            // json (default) drops the heavy speakers summary
            Expect(got).ToNot(HaveKey("speakers"))

            segs, ok := got["segments"].([]any)
            Expect(ok).To(BeTrue())
            Expect(segs).To(HaveLen(3))

            // Mock emits raw labels "5", "2", "5" — first-seen order maps:
            // 5 → SPEAKER_00, 2 → SPEAKER_01.
            seg0 := segs[0].(map[string]any)
            seg1 := segs[1].(map[string]any)
            seg2 := segs[2].(map[string]any)
            Expect(seg0["speaker"]).To(Equal("SPEAKER_00"))
            Expect(seg0["label"]).To(Equal("5"))
            Expect(seg1["speaker"]).To(Equal("SPEAKER_01"))
            Expect(seg2["speaker"]).To(Equal("SPEAKER_00"))

            // The default json format suppresses per-segment text even when the
            // backend emitted some. Here IncludeText was not set, so the backend
            // already stripped the text — but the HTTP layer gates it as well.
            _, hasText := seg0["text"].(string)
            if hasText {
                Expect(seg0["text"]).To(Equal(""))
            }
        })

        It("verbose_json emits speakers summary and per-segment transcripts when include_text is set", func() {
            resp, data := postDiarize(map[string]string{
                "response_format": "verbose_json",
                "include_text":    "true",
            })
            Expect(resp.StatusCode).To(Equal(http.StatusOK))

            var got map[string]any
            Expect(json.Unmarshal(data, &got)).To(Succeed())

            speakers, ok := got["speakers"].([]any)
            Expect(ok).To(BeTrue(), "verbose_json must include speakers summary")
            Expect(speakers).To(HaveLen(2))

            // SPEAKER_00 should reflect both of its segments (1.0s + 1.5s = 2.5s, 2 segments)
            byID := map[string]map[string]any{}
            for _, sp := range speakers {
                m := sp.(map[string]any)
                byID[m["id"].(string)] = m
            }
            Expect(byID).To(HaveKey("SPEAKER_00"))
            Expect(byID["SPEAKER_00"]["total_speech_duration"]).To(BeNumerically("~", 2.5, 0.001))
            Expect(byID["SPEAKER_00"]["segment_count"]).To(BeEquivalentTo(2))

            segs := got["segments"].([]any)
            Expect(segs[0].(map[string]any)["text"]).To(Equal("hello there"))
            Expect(segs[1].(map[string]any)["text"]).To(Equal("general kenobi"))
        })

        It("rttm response_format returns NIST RTTM rows", func() {
            resp, data := postDiarize(map[string]string{"response_format": "rttm"})
            Expect(resp.StatusCode).To(Equal(http.StatusOK))
            Expect(resp.Header.Get("Content-Type")).To(HavePrefix("text/plain"))

            body := string(data)
            lines := strings.Split(strings.TrimSpace(body), "\n")
            Expect(lines).To(HaveLen(3))
            // "SPEAKER stub 1 0.000 1.000 <NA> <NA> SPEAKER_00 <NA> <NA>"
            Expect(lines[0]).To(HavePrefix("SPEAKER stub 1 "))
            Expect(lines[0]).To(ContainSubstring(" SPEAKER_00 "))
            Expect(lines[1]).To(ContainSubstring(" SPEAKER_01 "))
            Expect(lines[2]).To(ContainSubstring(" SPEAKER_00 "))
        })

        It("rejects unknown response_format with 4xx/5xx", func() {
            resp, _ := postDiarize(map[string]string{"response_format": "csv"})
            Expect(resp.StatusCode).To(BeNumerically(">=", 400))
        })
    })

    Describe("Rerank API", func() {
        It("should return mocked reranking results", func() {
            req, err := http.NewRequest("POST", apiURL+"/rerank", nil)