From 716dba94b4359f8d5d550d21b82896b473818957 Mon Sep 17 00:00:00 2001 From: Richard Palethorpe Date: Thu, 18 Dec 2025 13:40:45 +0000 Subject: [PATCH] feat(whisper): Add prompt to condition transcription output (#7624) * chore(makefile): Add buildargs for sd and cuda when building backend Signed-off-by: Richard Palethorpe * feat(whisper): Add prompt to condition transcription output Signed-off-by: Richard Palethorpe --------- Signed-off-by: Richard Palethorpe --- Makefile | 2 +- backend/backend.proto | 1 + backend/go/whisper/.gitignore | 4 ++-- backend/go/whisper/gowhisper.cpp | 4 +++- backend/go/whisper/gowhisper.go | 4 ++-- backend/go/whisper/gowhisper.h | 3 ++- core/backend/transcript.go | 3 ++- core/cli/transcript.go | 3 ++- core/http/endpoints/openai/realtime.go | 6 ++++++ core/http/endpoints/openai/transcription.go | 3 ++- 10 files changed, 23 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index 5e36f682d..b2ebd9d76 100644 --- a/Makefile +++ b/Makefile @@ -514,7 +514,7 @@ docker-save-diffusers: backend-images docker save local-ai-backend:diffusers -o backend-images/diffusers.tar docker-build-whisper: - docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:whisper -f backend/Dockerfile.golang --build-arg BACKEND=whisper . + docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) -t local-ai-backend:whisper -f backend/Dockerfile.golang --build-arg BACKEND=whisper . docker-save-whisper: backend-images docker save local-ai-backend:whisper -o backend-images/whisper.tar diff --git a/backend/backend.proto b/backend/backend.proto index 187294236..2ea7f6f10 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -282,6 +282,7 @@ message TranscriptRequest { uint32 threads = 4; bool translate = 5; bool diarize = 6; + string prompt = 7; } message TranscriptResult { diff --git a/backend/go/whisper/.gitignore b/backend/go/whisper/.gitignore index 017e34a10..7c42de3f1 100644 --- a/backend/go/whisper/.gitignore +++ b/backend/go/whisper/.gitignore @@ -3,5 +3,5 @@ sources/ build/ package/ whisper -libgowhisper.so - +*.so +compile_commands.json diff --git a/backend/go/whisper/gowhisper.cpp b/backend/go/whisper/gowhisper.cpp index 210b6b894..f1756d780 100644 --- a/backend/go/whisper/gowhisper.cpp +++ b/backend/go/whisper/gowhisper.cpp @@ -107,7 +107,7 @@ int vad(float pcmf32[], size_t pcmf32_len, float **segs_out, } int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz, - float pcmf32[], size_t pcmf32_len, size_t *segs_out_len) { + float pcmf32[], size_t pcmf32_len, size_t *segs_out_len, char *prompt) { whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); @@ -122,8 +122,10 @@ int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz, wparams.debug_mode = true; wparams.print_progress = true; wparams.tdrz_enable = tdrz; + wparams.initial_prompt = prompt; fprintf(stderr, "info: Enable tdrz: %d\n", tdrz); + fprintf(stderr, "info: Initial prompt: \"%s\"\n", prompt); if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) { fprintf(stderr, "error: transcription failed\n"); diff --git a/backend/go/whisper/gowhisper.go b/backend/go/whisper/gowhisper.go index 0b0d5f6af..047f0ab88 100644 --- a/backend/go/whisper/gowhisper.go +++ b/backend/go/whisper/gowhisper.go @@ -17,7 +17,7 @@ var ( CppLoadModel func(modelPath string) int CppLoadModelVAD func(modelPath string) int CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int - CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer) int + CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer, prompt string) int CppGetSegmentText func(i int) string CppGetSegmentStart func(i int) int64 CppGetSegmentEnd func(i int) int64 @@ -123,7 +123,7 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR segsLen := uintptr(0xdeadbeef) segsLenPtr := unsafe.Pointer(&segsLen) - if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr); ret != 0 { + if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt); ret != 0 { return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe") } diff --git a/backend/go/whisper/gowhisper.h b/backend/go/whisper/gowhisper.h index 1c7f17a17..0e061cf93 100644 --- a/backend/go/whisper/gowhisper.h +++ b/backend/go/whisper/gowhisper.h @@ -7,7 +7,8 @@ int load_model_vad(const char *const model_path); int vad(float pcmf32[], size_t pcmf32_size, float **segs_out, size_t *segs_out_len); int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz, - float pcmf32[], size_t pcmf32_len, size_t *segs_out_len); + float pcmf32[], size_t pcmf32_len, size_t *segs_out_len, + char *prompt); const char *get_segment_text(int i); int64_t get_segment_t0(int i); int64_t get_segment_t1(int i); diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 9781e26fd..66e687813 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -12,7 +12,7 @@ import ( "github.com/mudler/LocalAI/pkg/model" ) -func ModelTranscription(audio, language string, translate bool, diarize bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { +func ModelTranscription(audio, language string, translate bool, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { if modelConfig.Backend == "" { modelConfig.Backend = model.WhisperBackend @@ -35,6 +35,7 @@ func ModelTranscription(audio, language string, translate bool, diarize bool, ml Translate: translate, Diarize: diarize, Threads: uint32(*modelConfig.Threads), + Prompt: prompt, }) if err != nil { return nil, err diff --git a/core/cli/transcript.go b/core/cli/transcript.go index 2beb00944..b15a14d6d 100644 --- a/core/cli/transcript.go +++ b/core/cli/transcript.go @@ -23,6 +23,7 @@ type TranscriptCMD struct { Diarize bool `short:"d" help:"Mark speaker turns"` Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"` ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` + Prompt string `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"` } func (t *TranscriptCMD) Run(ctx *cliContext.Context) error { @@ -57,7 +58,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error { } }() - tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, ml, c, opts) + tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, t.Prompt, ml, c, opts) if err != nil { return err } diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index b9c7d8e53..f5f781caa 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -593,6 +593,11 @@ func updateTransSession(session *Session, update *types.ClientSession, cl *confi session.ModelInterface = m } + if trUpd != nil { + trCur.Language = trUpd.Language + trCur.Prompt = trUpd.Prompt + } + if update.TurnDetection != nil && update.TurnDetection.Type != "" { session.TurnDetection.Type = types.ServerTurnDetectionType(update.TurnDetection.Type) session.TurnDetection.TurnDetectionParams = update.TurnDetection.TurnDetectionParams @@ -790,6 +795,7 @@ func commitUtterance(ctx context.Context, utt []byte, cfg *config.ModelConfig, e Language: session.InputAudioTranscription.Language, Translate: false, Threads: uint32(*cfg.Threads), + Prompt: session.InputAudioTranscription.Prompt, }) if err != nil { sendError(c, "transcription_failed", err.Error(), "", "event_TODO") diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index c5fc9f352..032b455ff 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -37,6 +37,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app } diarize := c.FormValue("diarize") != "false" + prompt := c.FormValue("prompt") // retrieve the file data from the request file, err := c.FormFile("file") @@ -69,7 +70,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app log.Debug().Msgf("Audio file copied to: %+v", dst) - tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, diarize, ml, *config, appConfig) + tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, diarize, prompt, ml, *config, appConfig) if err != nil { return err }