mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-16 20:52:08 -04:00
feat(whisper): honor client cancellation via ggml abort_callback (#9710)
* refactor(transcription): propagate request ctx through ModelTranscription*

  Replaces context.Background() with the HTTP request ctx so client disconnects start cancelling the gRPC call. No backend-side abort wiring yet — that comes in a later commit. Pure plumbing.

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cli): pass ctx to backend.ModelTranscription

  Follow-up to e65d3e1f, which threaded ctx through ModelTranscription but missed the CLI caller. CLI commands have no request-scoped ctx, so context.Background() is correct here.

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(audio): propagate request ctx into TTS, sound-gen, audio-transform

  Same ctx-plumbing pattern applied to the rest of the audio path. CLI callers use context.Background() since there is no request scope; HTTP callers use c.Request().Context().

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(backend): propagate request ctx into biometric, detection, rerank, diarization paths

  Replaces the remaining context.Background() sites in core/backend with the caller's ctx. After this commit, every core/backend/*.go entry point threads the request ctx end-to-end to the gRPC client.

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(grpc): plumb ctx through AIModel.AudioTranscription{,Stream}

  Adds context.Context as the first parameter to the AIModel interface methods that wrap whisper-style transcription. The server-side gRPC handler now forwards the per-RPC ctx (server-streaming uses stream.Context()). Whisper, Voxtral, vibevoice-cpp, and sherpa-onnx accept the parameter; none of them uses it yet — the actual cancellation primitive lands in the next commit, so this is pure plumbing.

  Assisted-by: Claude:claude-sonnet-4-6
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): add abort_callback hook in the C++ bridge

  Installs a std::atomic<int> flag, wires it into whisper_full_params.abort_callback, and exposes a set_abort(int) C symbol so Go can flip the flag from a goroutine watching the request context. transcribe() now distinguishes an abort (return 2) from a real whisper_full failure (return 1).

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): register set_abort symbol in the purego loader

  Adds the Go-side binding for the new C export so the next commit can call CppSetAbort(1) from a watcher goroutine on ctx.Done().

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): honor ctx cancellation and return codes.Canceled

  A watcher goroutine watches ctx.Done() during AudioTranscription and calls CppSetAbort(1) on cancel. whisper_full sees abort_callback return true at the next compute-graph step, returns non-zero, and the bridge returns 2, which AudioTranscription maps to codes.Canceled. Adds an opt-in test (gated on WHISPER_MODEL_PATH / WHISPER_AUDIO_PATH) that asserts cancellation latency under 5s and proves the abort flag resets cleanly so the next transcription succeeds.

  Assisted-by: Claude:claude-sonnet-4-6
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(whisper): join the cancel watcher goroutine before returning

  Follow-up to 85edf9d2. The previous commit used `defer close(done)` and called the watcher "joined synchronously" — but close() only signals, it does not block until the goroutine exits. That left a window where a late CppSetAbort(1) from a cancelled call could land on the next call, after its C-side g_abort reset but before whisper_full() began polling the abort callback, corrupting the second transcription. Switch to a sync.WaitGroup join so wg.Wait() blocks until the watcher has actually returned from its select.

  Assisted-by: Claude:claude-sonnet-4-6
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(whisper): short-circuit pre-cancelled ctx in AudioTranscription

  If ctx is already Done() at entry, return codes.Canceled immediately instead of running the full transcription. The C-side g_abort reset happens at the start of transcribe() and would otherwise overwrite a watcher-set abort flag from an already-cancelled ctx, producing a spurious successful transcription on a request the client has already abandoned.

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(tests/distributed): update testLLM mock for new AudioTranscription signature

  Phase B (93c48e19) added context.Context to AIModel.AudioTranscription but missed the testLLM mock in tests/e2e/distributed. CI golangci-lint caught it: *testLLM did not implement grpc.AIModel because the method signature lacked the ctx parameter, which broke the distributed test suite compilation and cascaded through every backend-build job that runs `go build ./...`.

  Assisted-by: Claude:claude-opus-4-7
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* test(whisper): port cancellation test to Ginkgo/Gomega

  Project policy (.agents/coding-style.md, enforced by golangci-lint forbidigo) is that all Go tests must use Ginkgo v2 + Gomega — no stdlib testing patterns (t.Skip, t.Fatalf, etc.). Convert the cancellation test to a Describe/It block with Skip(...) for env gating and Expect/HaveOccurred for assertions. Same coverage: cancel mid-flight returns codes.Canceled within 5s and a follow-up transcription succeeds, proving the C-side g_abort flag resets cleanly.

  Assisted-by: Claude:claude-opus-4-7
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
@@ -2,6 +2,7 @@ package main

 import (
 	"bytes"
+	"context"
 	"encoding/binary"
 	"fmt"
 	"os"
@@ -998,7 +999,7 @@ func (s *SherpaBackend) loadOnlineASR(opts *pb.ModelOptions) error {

 // Transcription
 // =============================================================

-func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
+func (s *SherpaBackend) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
 	if s.onlineRecognizer != 0 {
 		return s.runOnlineASR(req, nil)
 	}
@@ -1056,6 +1057,7 @@ func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.Transc
 // Closes `results` before returning so the server wrapper's reader
 // goroutine can exit.
 func (s *SherpaBackend) AudioTranscriptionStream(
+	_ context.Context,
 	req *pb.TranscriptRequest,
 	results chan *pb.TranscriptStreamResponse,
 ) error {
@@ -1,6 +1,7 @@
 package main

 import (
+	"context"
 	"os"
 	"path/filepath"
 	"testing"
@@ -79,7 +80,7 @@ var _ = Describe("Sherpa-ONNX", func() {
 	})

 	It("rejects AudioTranscription", func() {
-		_, err := (&SherpaBackend{}).AudioTranscription(&pb.TranscriptRequest{
+		_, err := (&SherpaBackend{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{
 			Dst: "/tmp/nonexistent.wav",
 		})
 		Expect(err).To(HaveOccurred())
@@ -1,6 +1,7 @@
 package main

 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -480,7 +481,7 @@ func (w *byteWriter) Write(p []byte) (int, error) {
 	return len(p), nil
 }

-func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
+func (v *VibevoiceCpp) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
 	if v.asrModel == "" {
 		return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: AudioTranscription requested but no ASR model was loaded")
 	}
@@ -623,9 +624,9 @@ func (v *VibevoiceCpp) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, erro
 // transcription, emit each segment's content as a delta, then close
 // with a final_result whose Text equals the concatenated deltas (the
 // e2e harness asserts those match).
-func (v *VibevoiceCpp) AudioTranscriptionStream(req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
+func (v *VibevoiceCpp) AudioTranscriptionStream(ctx context.Context, req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
 	defer close(results)
-	res, err := v.AudioTranscription(req)
+	res, err := v.AudioTranscription(ctx, req)
 	if err != nil {
 		return err
 	}

@@ -107,7 +107,7 @@ var _ = Describe("VibeVoice-cpp", func() {
 	})

 	It("rejects AudioTranscription without a loaded ASR model", func() {
-		_, err := (&VibevoiceCpp{}).AudioTranscription(&pb.TranscriptRequest{
+		_, err := (&VibevoiceCpp{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{
 			Dst: "/tmp/some.wav",
 		})
 		Expect(err).To(HaveOccurred())
@@ -255,7 +255,7 @@ var _ = Describe("VibeVoice-cpp", func() {

 	It("closes the channel and errors on AudioTranscriptionStream without a loaded model", func() {
 		ch := make(chan *pb.TranscriptStreamResponse, 4)
-		err := (&VibevoiceCpp{}).AudioTranscriptionStream(&pb.TranscriptRequest{
+		err := (&VibevoiceCpp{}).AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{
 			Dst: "/tmp/some.wav",
 		}, ch)
 		Expect(err).To(HaveOccurred())
@@ -1,6 +1,7 @@
 package main

 import (
+	"context"
 	"fmt"
 	"os"
 	"strings"
@@ -27,7 +28,7 @@ func (v *Voxtral) Load(opts *pb.ModelOptions) error {
 	return nil
 }

-func (v *Voxtral) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
+func (v *Voxtral) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
 	dir, err := os.MkdirTemp("", "voxtral")
 	if err != nil {
 		return pb.TranscriptResult{}, err
@@ -1,12 +1,23 @@
 #include "gowhisper.h"
 #include "ggml-backend.h"
 #include "whisper.h"
+#include <atomic>
 #include <vector>

 static struct whisper_vad_context *vctx;
 static struct whisper_context *ctx;
 static std::vector<float> flat_segs;

+static std::atomic<int> g_abort{0};
+
+static bool abort_cb(void * /*user_data*/) {
+  return g_abort.load(std::memory_order_relaxed) != 0;
+}
+
+extern "C" void set_abort(int v) {
+  g_abort.store(v, std::memory_order_relaxed);
+}
+
 static void ggml_log_cb(enum ggml_log_level level, const char *log,
                         void *data) {
   const char *level_str;
@@ -124,10 +135,20 @@ int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
   wparams.tdrz_enable = tdrz;
   wparams.initial_prompt = prompt;

+  // Reset stale abort flag from any prior cancelled call, then install the
+  // ggml abort hook so a subsequent set_abort(1) from Go aborts the next
+  // compute graph step.
+  g_abort.store(0, std::memory_order_relaxed);
+  wparams.abort_callback = abort_cb;
+  wparams.abort_callback_user_data = nullptr;
+
   fprintf(stderr, "info: Enable tdrz: %d\n", tdrz);
   fprintf(stderr, "info: Initial prompt: \"%s\"\n", prompt);

   if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) {
+    if (g_abort.load(std::memory_order_relaxed)) {
+      return 2; // aborted by client
+    }
     fprintf(stderr, "error: transcription failed\n");
     return 1;
   }
@@ -15,4 +15,5 @@ int64_t get_segment_t1(int i);
 int n_tokens(int i);
 int32_t get_token_id(int i, int j);
 bool get_segment_speaker_turn_next(int i);
+void set_abort(int v);
 }
@@ -1,16 +1,20 @@
 package main

 import (
+	"context"
 	"fmt"
 	"os"
 	"path/filepath"
 	"strings"
+	"sync"
 	"unsafe"

 	"github.com/go-audio/wav"
 	"github.com/mudler/LocalAI/pkg/grpc/base"
 	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
 	"github.com/mudler/LocalAI/pkg/utils"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 )

 var (
@@ -24,6 +28,7 @@ var (
 	CppNTokens                   func(i int) int
 	CppGetTokenID                func(i int, j int) int
 	CppGetSegmentSpeakerTurnNext func(i int) bool
+	CppSetAbort                  func(v int)
 )

 type Whisper struct {
@@ -92,7 +97,11 @@ func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
 	}, nil
 }

-func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
+func (w *Whisper) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
+	if err := ctx.Err(); err != nil {
+		return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
+	}
+
 	dir, err := os.MkdirTemp("", "whisper")
 	if err != nil {
 		return pb.TranscriptResult{}, err
@@ -105,14 +114,12 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
 		return pb.TranscriptResult{}, err
 	}

-	// Open samples
 	fh, err := os.Open(convertedPath)
 	if err != nil {
 		return pb.TranscriptResult{}, err
 	}
 	defer fh.Close()

-	// Read samples
 	d := wav.NewDecoder(fh)
 	buf, err := d.FullPCMBuffer()
 	if err != nil {
@@ -120,8 +127,6 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
 	}

 	data := buf.AsFloat32Buffer().Data
-	// whisper.cpp resamples to 16 kHz internally; this matches buf.Format.SampleRate
-	// for the converted file produced by AudioToWav above.
 	var duration float32
 	if buf.Format != nil && buf.Format.SampleRate > 0 {
 		duration = float32(len(data)) / float32(buf.Format.SampleRate)
@@ -129,7 +134,31 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
 	segsLen := uintptr(0xdeadbeef)
 	segsLenPtr := unsafe.Pointer(&segsLen)

-	if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt); ret != 0 {
+	// Watcher: flips the C-side abort flag when ctx is cancelled. The
+	// goroutine is joined synchronously (close(done) signals it to exit,
+	// wg.Wait() blocks until it has) so a late CppSetAbort(1) cannot fire
+	// after the function returns and corrupt the next transcription call.
+	done := make(chan struct{})
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		select {
+		case <-ctx.Done():
+			CppSetAbort(1)
+		case <-done:
+		}
+	}()
+	defer func() {
+		close(done)
+		wg.Wait()
+	}()
+
+	ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
+	if ret == 2 {
+		return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
+	}
+	if ret != 0 {
 		return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
 	}
backend/go/whisper/gowhisper_test.go (new file, 112 lines)
@@ -0,0 +1,112 @@
package main

import (
	"context"
	"os"
	"sync"
	"testing"
	"time"

	"github.com/ebitengine/purego"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

func TestWhisper(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Whisper Backend Suite")
}

var (
	libLoadOnce sync.Once
	libLoadErr  error
)

// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
// bridge without spinning up the gRPC server. Skips the current spec when the
// shared library isn't present (e.g. running before `make backends/whisper`).
func ensureLibLoaded() {
	libLoadOnce.Do(func() {
		libName := os.Getenv("WHISPER_LIBRARY")
		if libName == "" {
			libName = "./libgowhisper-fallback.so"
		}
		if _, err := os.Stat(libName); err != nil {
			libLoadErr = err
			return
		}
		gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
		if err != nil {
			libLoadErr = err
			return
		}
		purego.RegisterLibFunc(&CppLoadModel, gosd, "load_model")
		purego.RegisterLibFunc(&CppTranscribe, gosd, "transcribe")
		purego.RegisterLibFunc(&CppGetSegmentText, gosd, "get_segment_text")
		purego.RegisterLibFunc(&CppGetSegmentStart, gosd, "get_segment_t0")
		purego.RegisterLibFunc(&CppGetSegmentEnd, gosd, "get_segment_t1")
		purego.RegisterLibFunc(&CppNTokens, gosd, "n_tokens")
		purego.RegisterLibFunc(&CppGetTokenID, gosd, "get_token_id")
		purego.RegisterLibFunc(&CppGetSegmentSpeakerTurnNext, gosd, "get_segment_speaker_turn_next")
		purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort")
	})
	if libLoadErr != nil {
		Skip("whisper library not loadable: " + libLoadErr.Error())
	}
}

// fixturesOrSkip returns the model + audio paths or skips the spec if either
// env var is unset. The test never runs in default CI — it requires a real
// whisper model and a long audio file (~3 minutes) on disk.
func fixturesOrSkip() (string, string) {
	modelPath := os.Getenv("WHISPER_MODEL_PATH")
	audioPath := os.Getenv("WHISPER_AUDIO_PATH")
	if modelPath == "" || audioPath == "" {
		Skip("set WHISPER_MODEL_PATH and WHISPER_AUDIO_PATH to run this spec")
	}
	return modelPath, audioPath
}

var _ = Describe("Whisper", func() {
	Context("AudioTranscription cancellation", func() {
		It("returns codes.Canceled and resets the abort flag for the next call", func() {
			modelPath, audioPath := fixturesOrSkip()
			ensureLibLoaded()

			w := &Whisper{}
			Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())

			ctx, cancel := context.WithCancel(context.Background())
			go func() {
				time.Sleep(100 * time.Millisecond)
				cancel()
			}()

			start := time.Now()
			_, err := w.AudioTranscription(ctx, &pb.TranscriptRequest{
				Dst:      audioPath,
				Threads:  4,
				Language: "en",
			})
			elapsed := time.Since(start)

			Expect(err).To(HaveOccurred(), "transcription completed in %s without cancel — try a longer audio file", elapsed)
			st, ok := status.FromError(err)
			Expect(ok).To(BeTrue(), "expected gRPC status error, got %v", err)
			Expect(st.Code()).To(Equal(codes.Canceled), "expected codes.Canceled, got %v", err)
			Expect(elapsed).To(BeNumerically("<", 5*time.Second), "cancellation took %s, expected <5s", elapsed)

			// Subsequent transcription must succeed — proves g_abort reset.
			res, err := w.AudioTranscription(context.Background(), &pb.TranscriptRequest{
				Dst:      audioPath,
				Threads:  4,
				Language: "en",
			})
			Expect(err).ToNot(HaveOccurred(), "post-cancel transcription failed")
			Expect(res.Text).ToNot(BeEmpty(), "post-cancel transcription returned empty text")
		})
	})
})
@@ -41,6 +41,7 @@ func main() {
 		{&CppNTokens, "n_tokens"},
 		{&CppGetTokenID, "get_token_id"},
 		{&CppGetSegmentSpeakerTurnNext, "get_segment_speaker_turn_next"},
+		{&CppSetAbort, "set_abort"},
 	}

 	for _, lf := range libFuncs {
@@ -40,6 +40,7 @@ type AudioTransformOutputs struct {
 // required; `referencePath` is optional (empty => backend zero-fills the
 // reference channel).
 func ModelAudioTransform(
+	ctx context.Context,
 	audioPath, referencePath string,
 	opts AudioTransformOptions,
 	loader *model.ModelLoader,
@@ -81,7 +82,7 @@ func ModelAudioTransform(
 		startTime = time.Now()
 	}

-	res, err := transformModel.AudioTransform(context.Background(), &proto.AudioTransformRequest{
+	res, err := transformModel.AudioTransform(ctx, &proto.AudioTransformRequest{
 		AudioPath:     audioPath,
 		ReferencePath: referencePath,
 		Dst:           dst,
@@ -12,6 +12,7 @@ import (
 )

 func Detection(
+	ctx context.Context,
 	sourceFile string,
 	prompt string,
 	points []float32,
@@ -38,7 +39,7 @@ func Detection(
 		startTime = time.Now()
 	}

-	res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{
+	res, err := detectionModel.Detect(ctx, &proto.DetectOptions{
 		Src:    sourceFile,
 		Prompt: prompt,
 		Points: points,
@@ -63,7 +63,7 @@ func loadDiarizationModel(ml *model.ModelLoader, modelConfig config.ModelConfig,

 // ModelDiarization runs the Diarize RPC against the configured backend
 // and returns a normalized schema.DiarizationResult.
-func ModelDiarization(req DiarizationRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.DiarizationResult, error) {
+func ModelDiarization(ctx context.Context, req DiarizationRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.DiarizationResult, error) {
 	m, err := loadDiarizationModel(ml, modelConfig, appConfig)
 	if err != nil {
 		return nil, err
@@ -74,7 +74,7 @@ func ModelDiarization(req DiarizationRequest, ml *model.ModelLoader, modelConfig
 		threads = uint32(*modelConfig.Threads)
 	}

-	r, err := m.Diarize(context.Background(), req.toProto(threads))
+	r, err := m.Diarize(ctx, req.toProto(threads))
 	if err != nil {
 		return nil, err
 	}
@@ -12,6 +12,7 @@ import (
 )

 func FaceAnalyze(
+	ctx context.Context,
 	img string,
 	actions []string,
 	antiSpoofing bool,
@@ -35,7 +36,7 @@ func FaceAnalyze(
 		startTime = time.Now()
 	}

-	res, err := faceModel.FaceAnalyze(context.Background(), &proto.FaceAnalyzeRequest{
+	res, err := faceModel.FaceAnalyze(ctx, &proto.FaceAnalyzeRequest{
 		Img:          img,
 		Actions:      actions,
 		AntiSpoofing: antiSpoofing,
@@ -14,6 +14,7 @@ import (
 // backend picks the highest-confidence face and returns its
 // L2-normalized embedding.
 func FaceEmbed(
+	ctx context.Context,
 	imgBase64 string,
 	loader *model.ModelLoader,
 	appConfig *config.ApplicationConfig,
@@ -32,7 +33,7 @@ func FaceEmbed(
 	predictOpts := gRPCPredictOpts(modelConfig, loader.ModelPath)
 	predictOpts.Images = []string{imgBase64}

-	res, err := faceModel.Embeddings(context.Background(), predictOpts)
+	res, err := faceModel.Embeddings(ctx, predictOpts)
 	if err != nil {
 		return nil, err
 	}
@@ -12,6 +12,7 @@ import (
 )

 func FaceVerify(
+	ctx context.Context,
 	img1, img2 string,
 	threshold float32,
 	antiSpoofing bool,
@@ -35,7 +36,7 @@ func FaceVerify(
 		startTime = time.Now()
 	}

-	res, err := faceModel.FaceVerify(context.Background(), &proto.FaceVerifyRequest{
+	res, err := faceModel.FaceVerify(ctx, &proto.FaceVerifyRequest{
 		Img1:      img1,
 		Img2:      img2,
 		Threshold: threshold,
@@ -11,7 +11,7 @@ import (
 	model "github.com/mudler/LocalAI/pkg/model"
 )

-func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
+func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
 	opts := ModelOptions(modelConfig, appConfig)
 	rerankModel, err := loader.Load(opts...)
 	if err != nil {
@@ -29,7 +29,7 @@ func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *
 		startTime = time.Now()
 	}

-	res, err := rerankModel.Rerank(context.Background(), request)
+	res, err := rerankModel.Rerank(ctx, request)

 	if appConfig.EnableTracing {
 		errStr := ""
@@ -15,6 +15,7 @@ import (
 )

 func SoundGeneration(
+	ctx context.Context,
 	text string,
 	duration *float32,
 	temperature *float32,
@@ -101,7 +102,7 @@ func SoundGeneration(
 		startTime = time.Now()
 	}

-	res, err := soundGenModel.SoundGeneration(context.Background(), req)
+	res, err := soundGenModel.SoundGeneration(ctx, req)

 	if appConfig.EnableTracing {
 		errStr := ""
@@ -10,6 +10,7 @@ import (
 )

 func TokenMetrics(
+	ctx context.Context,
 	modelFile string,
 	loader *model.ModelLoader,
 	appConfig *config.ApplicationConfig,
@@ -26,7 +27,7 @@ func TokenMetrics(
 		return nil, fmt.Errorf("could not loadmodel model")
 	}

-	res, err := model.GetTokenMetrics(context.Background(), &proto.MetricsRequest{})
+	res, err := model.GetTokenMetrics(ctx, &proto.MetricsRequest{})

 	return res, err
 }
@@ -57,8 +57,8 @@ func loadTranscriptionModel(ml *model.ModelLoader, modelConfig config.ModelConfi
 	return transcriptionModel, nil
 }

-func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
-	return ModelTranscriptionWithOptions(TranscriptionRequest{
+func ModelTranscription(ctx context.Context, audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
+	return ModelTranscriptionWithOptions(ctx, TranscriptionRequest{
 		Audio:     audio,
 		Language:  language,
 		Translate: translate,
@@ -67,7 +67,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
 	}, ml, modelConfig, appConfig)
 }

-func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
+func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
 	transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
 	if err != nil {
 		return nil, err
@@ -82,7 +82,7 @@ func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoad
 		audioSnippet = trace.AudioSnippet(req.Audio)
 	}

-	r, err := transcriptionModel.AudioTranscription(context.Background(), req.toProto(uint32(*modelConfig.Threads)))
+	r, err := transcriptionModel.AudioTranscription(ctx, req.toProto(uint32(*modelConfig.Threads)))
 	if err != nil {
 		if appConfig.EnableTracing {
 			errData := map[string]any{
@@ -149,7 +149,7 @@ type TranscriptionStreamChunk struct {
 // invokes onChunk for each event the backend produces. Backends that don't
 // support real streaming should still emit one terminal event with Final set,
 // which the HTTP layer turns into a single delta + done SSE pair.
-func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
+func ModelTranscriptionStream(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
 	transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
 	if err != nil {
 		return err
@@ -158,7 +158,7 @@ func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, m
 	pbReq := req.toProto(uint32(*modelConfig.Threads))
 	pbReq.Stream = true

-	return transcriptionModel.AudioTranscriptionStream(context.Background(), pbReq, func(chunk *proto.TranscriptStreamResponse) {
+	return transcriptionModel.AudioTranscriptionStream(ctx, pbReq, func(chunk *proto.TranscriptStreamResponse) {
 		if chunk == nil {
 			return
 		}
@@ -187,12 +187,12 @@ func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionR
 		}
 		var words []schema.TranscriptionWord
 		for _, w := range s.Words {
-			var word = schema.TranscriptionWord {
+			var word = schema.TranscriptionWord{
 				Start: time.Duration(w.Start),
 				End:   time.Duration(w.End),
 				Text:  w.Text,
 			}
-			words = append(words, word)
+			words = append(words, word)
+			tr.Words = append(tr.Words, word)
 		}
 		tr.Segments = append(tr.Segments,
@@ -21,6 +21,7 @@ import (
 )

 func ModelTTS(
+	ctx context.Context,
 	text,
 	voice,
 	language string,
@@ -70,7 +71,7 @@ func ModelTTS(
 		startTime = time.Now()
 	}

-	res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
+	res, err := ttsModel.TTS(ctx, &proto.TTSRequest{
 		Text:  text,
 		Model: modelPath,
 		Voice: voice,
@@ -121,6 +122,7 @@ func ModelTTS(
 }

 func ModelTTSStream(
+	ctx context.Context,
 	text,
 	voice,
 	language string,
@@ -172,7 +174,7 @@ func ModelTTSStream(
 	var totalPCMBytes int
 	snippetCapped := false

-	err = ttsModel.TTSStream(context.Background(), &proto.TTSRequest{
+	err = ttsModel.TTSStream(ctx, &proto.TTSRequest{
 		Text:  text,
 		Model: modelPath,
 		Voice: voice,
@@ -12,6 +12,7 @@ import (
 )

 func VoiceAnalyze(
+	ctx context.Context,
 	audio string,
 	actions []string,
 	loader *model.ModelLoader,
@@ -34,7 +35,7 @@ func VoiceAnalyze(
 		startTime = time.Now()
 	}

-	res, err := voiceModel.VoiceAnalyze(context.Background(), &proto.VoiceAnalyzeRequest{
+	res, err := voiceModel.VoiceAnalyze(ctx, &proto.VoiceAnalyzeRequest{
 		Audio:   audio,
 		Actions: actions,
 	})
@@ -16,6 +16,7 @@ import (
 // OpenAI-compatible and text-only), this call takes an audio path and
 // returns the backend's speaker-encoder output.
 func VoiceEmbed(
+	ctx context.Context,
 	audioPath string,
 	loader *model.ModelLoader,
 	appConfig *config.ApplicationConfig,
@@ -37,7 +38,7 @@ func VoiceEmbed(
 		startTime = time.Now()
 	}

-	res, err := voiceModel.VoiceEmbed(context.Background(), &proto.VoiceEmbedRequest{
+	res, err := voiceModel.VoiceEmbed(ctx, &proto.VoiceEmbedRequest{
 		Audio: audioPath,
 	})
@@ -12,6 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func VoiceVerify(
|
||||
ctx context.Context,
|
||||
audio1, audio2 string,
|
||||
threshold float32,
|
||||
antiSpoofing bool,
|
||||
@@ -35,7 +36,7 @@ func VoiceVerify(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := voiceModel.VoiceVerify(context.Background(), &proto.VoiceVerifyRequest{
|
||||
res, err := voiceModel.VoiceVerify(ctx, &proto.VoiceVerifyRequest{
|
||||
Audio1: audio1,
|
||||
Audio2: audio2,
|
||||
Threshold: threshold,
|
||||
|
||||
@@ -97,7 +97,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
|
||||
inputFile = &t.InputFile
|
||||
}
|
||||
|
||||
filePath, _, err := backend.SoundGeneration(text,
|
||||
filePath, _, err := backend.SoundGeneration(context.Background(), text,
|
||||
parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
|
||||
inputFile, parseToInt32Ptr(t.InputFileSampleDivisor),
|
||||
nil, "", "", nil, "", "", "", nil,
|
||||
|
||||
@@ -71,7 +71,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
}()
|
||||
|
||||
tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, t.Prompt, ml, c, opts)
|
||||
tr, err := backend.ModelTranscription(context.Background(), t.Filename, t.Language, t.Translate, t.Diarize, t.Prompt, ml, c, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
|
||||
options.Backend = t.Backend
|
||||
options.Model = t.Model
|
||||
|
||||
filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options)
|
||||
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, ml, opts, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -44,6 +44,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
|
||||
bpm = &b
|
||||
}
|
||||
filePath, _, err := backend.SoundGeneration(
|
||||
c.Request().Context(),
|
||||
input.Text, input.Duration, input.Temperature, input.DoSample,
|
||||
nil, nil,
|
||||
input.Think, input.Caption, input.Lyrics, bpm, input.Keyscale,
|
||||
|
||||
@@ -37,7 +37,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
|
||||
xlog.Debug("elevenlabs TTS request received", "modelName", input.ModelID)
|
||||
|
||||
filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
Documents: input.Documents,
|
||||
}
|
||||
|
||||
results, err := backend.Rerank(request, ml, appConfig, *cfg)
|
||||
results, err := backend.Rerank(c.Request().Context(), request, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func AudioTransformEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
|
||||
}
|
||||
}
|
||||
|
||||
out, _, err := backend.ModelAudioTransform(audioPath, referencePath, backend.AudioTransformOptions{
|
||||
out, _, err := backend.ModelAudioTransform(c.Request().Context(), audioPath, referencePath, backend.AudioTransformOptions{
|
||||
Params: params,
|
||||
}, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
|
||||
@@ -38,7 +38,7 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := backend.Detection(image, input.Prompt, input.Points, input.Boxes, input.Threshold, ml, appConfig, *cfg)
|
||||
res, err := backend.Detection(c.Request().Context(), image, input.Prompt, input.Points, input.Boxes, input.Threshold, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func FaceAnalyzeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ap
|
||||
}
|
||||
|
||||
xlog.Debug("FaceAnalyze", "model", cfg.Name, "backend", cfg.Backend, "actions", input.Actions)
|
||||
res, err := backend.FaceAnalyze(img, input.Actions, input.AntiSpoofing, ml, appConfig, *cfg)
|
||||
res, err := backend.FaceAnalyze(c.Request().Context(), img, input.Actions, input.AntiSpoofing, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func FaceEmbedEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
|
||||
}
|
||||
|
||||
xlog.Debug("FaceEmbed", "model", cfg.Name, "backend", cfg.Backend)
|
||||
vec, err := backend.FaceEmbed(img, ml, appConfig, *cfg)
|
||||
vec, err := backend.FaceEmbed(c.Request().Context(), img, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func FaceIdentifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
|
||||
threshold := cmp.Or(input.Threshold, defaultIdentifyThreshold)
|
||||
|
||||
xlog.Debug("FaceIdentify", "model", cfg.Name, "topK", topK, "threshold", threshold)
|
||||
probe, err := backend.FaceEmbed(img, ml, appConfig, *cfg)
|
||||
probe, err := backend.FaceEmbed(c.Request().Context(), img, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func FaceRegisterEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
|
||||
}
|
||||
|
||||
xlog.Debug("FaceRegister", "model", cfg.Name, "name", input.Name)
|
||||
embedding, err := backend.FaceEmbed(img, ml, appConfig, *cfg)
|
||||
embedding, err := backend.FaceEmbed(c.Request().Context(), img, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func FaceVerifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
}
|
||||
|
||||
xlog.Debug("FaceVerify", "model", cfg.Name, "backend", cfg.Backend)
|
||||
res, err := backend.FaceVerify(img1, img2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg)
|
||||
res, err := backend.FaceVerify(c.Request().Context(), img1, img2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
|
||||
}
|
||||
xlog.Debug("Token Metrics for model", "model", modelFile)
|
||||
|
||||
response, err := backend.TokenMetrics(modelFile, ml, appConfig, *cfg)
|
||||
response, err := backend.TokenMetrics(c.Request().Context(), modelFile, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Stream audio chunks as they're generated
|
||||
err := backend.ModelTTSStream(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
|
||||
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
|
||||
_, writeErr := c.Response().Write(audioChunk)
|
||||
if writeErr != nil {
|
||||
return writeErr
|
||||
@@ -75,7 +75,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
}
|
||||
|
||||
// Non-streaming TTS (existing behavior)
|
||||
filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func VoiceAnalyzeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
|
||||
defer cleanup()
|
||||
|
||||
xlog.Debug("VoiceAnalyze", "model", cfg.Name, "backend", cfg.Backend, "actions", input.Actions)
|
||||
res, err := backend.VoiceAnalyze(audio, input.Actions, ml, appConfig, *cfg)
|
||||
res, err := backend.VoiceAnalyze(c.Request().Context(), audio, input.Actions, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func VoiceEmbedEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
defer cleanup()
|
||||
|
||||
xlog.Debug("VoiceEmbed", "model", cfg.Name, "backend", cfg.Backend)
|
||||
res, err := backend.VoiceEmbed(audio, ml, appConfig, *cfg)
|
||||
res, err := backend.VoiceEmbed(c.Request().Context(), audio, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func VoiceIdentifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
|
||||
threshold := cmp.Or(input.Threshold, defaultVoiceIdentifyThreshold)
|
||||
|
||||
xlog.Debug("VoiceIdentify", "model", cfg.Name, "topK", topK, "threshold", threshold)
|
||||
embed, err := backend.VoiceEmbed(audio, ml, appConfig, *cfg)
|
||||
embed, err := backend.VoiceEmbed(c.Request().Context(), audio, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func VoiceRegisterEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
|
||||
defer cleanup()
|
||||
|
||||
xlog.Debug("VoiceRegister", "model", cfg.Name, "name", input.Name)
|
||||
res, err := backend.VoiceEmbed(audio, ml, appConfig, *cfg)
|
||||
res, err := backend.VoiceEmbed(c.Request().Context(), audio, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func VoiceVerifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ap
|
||||
defer cleanup2()
|
||||
|
||||
xlog.Debug("VoiceVerify", "model", cfg.Name, "backend", cfg.Backend)
|
||||
res, err := backend.VoiceVerify(audio1, audio2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg)
|
||||
res, err := backend.VoiceVerify(c.Request().Context(), audio1, audio2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return mapBackendError(err)
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func DiarizationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ap
|
||||
_ = dstFile.Close()
|
||||
req.Audio = dst
|
||||
|
||||
result, err := backend.ModelDiarization(req, ml, *modelConfig, appConfig)
|
||||
result, err := backend.ModelDiarization(c.Request().Context(), req, ml, *modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (m *transcriptOnlyModel) VAD(ctx context.Context, request *schema.VADReques
|
||||
}
|
||||
|
||||
func (m *transcriptOnlyModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
|
||||
return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
|
||||
return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
|
||||
}
|
||||
|
||||
func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
|
||||
@@ -82,7 +82,7 @@ func (m *wrappedModel) VAD(ctx context.Context, request *schema.VADRequest) (*sc
|
||||
}
|
||||
|
||||
func (m *wrappedModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
|
||||
return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
|
||||
return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
|
||||
@@ -241,7 +241,7 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im
|
||||
}
|
||||
|
||||
func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) {
|
||||
return backend.ModelTTS(text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig)
|
||||
return backend.ModelTTS(ctx, text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) PredictConfig() *config.ModelConfig {
|
||||
|
||||
@@ -126,7 +126,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
return streamTranscription(c, req, ml, *config, appConfig)
|
||||
}
|
||||
|
||||
tr, err := backend.ModelTranscriptionWithOptions(req, ml, *config, appConfig)
|
||||
tr, err := backend.ModelTranscriptionWithOptions(c.Request().Context(), req, ml, *config, appConfig)
|
||||
if err != nil {
|
||||
// Log before returning so the underlying error survives. Echo's
|
||||
// error handler turns this into a 500 with a generic body, which
|
||||
@@ -157,16 +157,16 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
Words: []schema.TranscriptionWordSeconds{},
|
||||
Segments: []schema.TranscriptionSegmentSeconds{},
|
||||
}
|
||||
for _, word := range(tr.Words) {
|
||||
for _, word := range tr.Words {
|
||||
trs.Words = append(trs.Words, schema.TranscriptionWordSeconds{
|
||||
Start: word.Start.Seconds(),
|
||||
End: word.End.Seconds(),
|
||||
Text: word.Text,
|
||||
})
|
||||
}
|
||||
for _, seg := range(tr.Segments) {
|
||||
for _, seg := range tr.Segments {
|
||||
segWords := []schema.TranscriptionWordSeconds{}
|
||||
for _, word := range(seg.Words) {
|
||||
for _, word := range seg.Words {
|
||||
segWords = append(segWords, schema.TranscriptionWordSeconds{
|
||||
Start: word.Start.Seconds(),
|
||||
End: word.End.Seconds(),
|
||||
@@ -174,7 +174,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
})
|
||||
}
|
||||
trs.Segments = append(trs.Segments, schema.TranscriptionSegmentSeconds{
|
||||
Id: seg.Id,
|
||||
Id: seg.Id,
|
||||
Start: seg.Start.Seconds(),
|
||||
End: seg.End.Seconds(),
|
||||
Text: seg.Text,
|
||||
@@ -216,7 +216,7 @@ func streamTranscription(c echo.Context, req backend.TranscriptionRequest, ml *m
|
||||
var assembled strings.Builder
|
||||
var finalResult *schema.TranscriptionResult
|
||||
|
||||
err := backend.ModelTranscriptionStream(req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) {
|
||||
err := backend.ModelTranscriptionStream(c.Request().Context(), req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) {
|
||||
if chunk.Delta != "" {
|
||||
assembled.WriteString(chunk.Delta)
|
||||
_ = writeEvent(map[string]any{
|
||||
|
||||
@@ -3,6 +3,7 @@ package base
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -57,11 +58,11 @@ func (llm *Base) GenerateVideo(*pb.GenerateVideoRequest) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
func (llm *Base) AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error {
|
||||
func (llm *Base) AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
@@ -22,8 +24,8 @@ type AIModel interface {
|
||||
VoiceVerify(*pb.VoiceVerifyRequest) (pb.VoiceVerifyResponse, error)
|
||||
VoiceAnalyze(*pb.VoiceAnalyzeRequest) (pb.VoiceAnalyzeResponse, error)
|
||||
VoiceEmbed(*pb.VoiceEmbedRequest) (pb.VoiceEmbedResponse, error)
|
||||
AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
|
||||
AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error
|
||||
AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error)
|
||||
AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error
|
||||
TTS(*pb.TTSRequest) error
|
||||
TTSStream(*pb.TTSRequest, chan []byte) error
|
||||
SoundGeneration(*pb.SoundGenerationRequest) error
|
||||
|
||||
@@ -218,7 +218,7 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
result, err := s.llm.AudioTranscription(in)
|
||||
result, err := s.llm.AudioTranscription(ctx, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -260,7 +260,7 @@ func (s *server) AudioTranscriptionStream(in *pb.TranscriptRequest, stream pb.Ba
|
||||
done <- true
|
||||
}()
|
||||
|
||||
err := s.llm.AudioTranscriptionStream(in, resultChan)
|
||||
err := s.llm.AudioTranscriptionStream(stream.Context(), in, resultChan)
|
||||
<-done
|
||||
|
||||
return err
|
||||
|
||||
@@ -93,7 +93,7 @@ func (t *testLLM) SoundGeneration(req *pb.SoundGenerationRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testLLM) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
func (t *testLLM) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
t.lastAudioDst = req.Dst
|
||||
return pb.TranscriptResult{Text: "transcribed text"}, nil
|
||||
}
|
||||
|