LocalAI/pkg/grpc/server.go
Ettore Di Giacinto e86ade54a6 feat(api): add /v1/audio/diarization endpoint with sherpa-onnx + vibevoice.cpp (#9654)
* feat(api): add /v1/audio/diarization endpoint with sherpa-onnx + vibevoice.cpp

Closes #1648.

OpenAI-style multipart endpoint that returns "who spoke when". Single
endpoint instead of the issue's three-endpoint sketch (refactor /vad,
/vad/embedding, /diarization) — the typical client wants one call, and
embeddings can land later as a sibling without breaking this surface.

Response shape borrows from Pyannote/Deepgram: segments carry a
normalised SPEAKER_NN id (zero-padded, stable across the response) plus
the raw backend label, optional per-segment text when the backend bundles
ASR, and a speakers summary in verbose_json. response_format also accepts
rttm so consumers can pipe straight into pyannote.metrics / dscore.
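
As an illustration of the SPEAKER_NN normalisation described above, here is a
minimal Go sketch. The Segment type, its field names, and the choice to start
numbering at 00 are assumptions made for this example; they are not the actual
LocalAI types.

```go
package main

import "fmt"

// Segment is a hypothetical stand-in for one diarization segment.
// Field names are assumptions for this sketch only.
type Segment struct {
	Start, End float64 // seconds
	Label      string  // raw backend label, e.g. "5"
	Speaker    string  // normalised id, e.g. "SPEAKER_00"
}

// normaliseSpeakers assigns zero-padded SPEAKER_NN ids in first-seen order,
// so the same raw label always maps to the same id within one response.
// It returns the number of distinct speakers.
func normaliseSpeakers(segs []Segment) int {
	ids := make(map[string]string)
	for i := range segs {
		id, ok := ids[segs[i].Label]
		if !ok {
			id = fmt.Sprintf("SPEAKER_%02d", len(ids))
			ids[segs[i].Label] = id
		}
		segs[i].Speaker = id
	}
	return len(ids)
}

func main() {
	segs := []Segment{
		{Start: 0, End: 1.5, Label: "5"},
		{Start: 1.5, End: 3, Label: "2"},
		{Start: 3, End: 4, Label: "5"},
	}
	n := normaliseSpeakers(segs)
	for _, s := range segs {
		fmt.Printf("%.1f-%.1f %s (raw %q)\n", s.Start, s.End, s.Speaker, s.Label)
	}
	fmt.Println("speakers:", n)
}
```

Keeping the raw label next to the normalised id lets a client correlate
segments with backend logs while still getting stable ids.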

Backends:

* vibevoice-cpp — Diarize() reuses the existing vv_capi_asr pass.
  vibevoice's ASR prompt asks the model to emit
  [{Start,End,Speaker,Content}] natively, so diarization is a by-product
  of the same pass; include_text=true preserves the transcript per
  segment, otherwise we drop it.

* sherpa-onnx — wraps the upstream SherpaOnnxOfflineSpeakerDiarization
  C API (pyannote segmentation + speaker-embedding extractor + fast
  clustering). libsherpa-shim grew config builders, a SetClustering
  wrapper for per-call num_clusters/threshold overrides, and a
  segment_at accessor (purego can't read field arrays out of
  SherpaOnnxOfflineSpeakerDiarizationSegment[] directly).

Plumbing: new Diarize gRPC RPC + DiarizeRequest / DiarizeSegment /
DiarizeResponse messages, threaded through interface.go, base, server,
client, embed. Default Base impl returns unimplemented.

Capability surfaces all updated: FLAG_DIARIZATION usecase,
FeatureAudioDiarization permission (default-on), RouteFeatureRegistry
entries for /v1/audio/diarization and /audio/diarization, audio
instruction-def description widened, CAP_DIARIZATION JS symbol,
swagger regenerated, /api/instructions discovery map updated.

Tests:

* core/backend: speaker-label normalisation (first-seen → SPEAKER_NN,
  per-speaker totals, nil-safety, fallback to backend NumSpeakers when
  no segments).

* core/http/endpoints/openai: RTTM rendering (file-id basename, negative
  duration clamping, fallback id).

* tests/e2e: mock-backend grew a deterministic Diarize that emits
  raw labels "5","2","5" so the e2e suite verifies SPEAKER_NN
  remapping, verbose_json speakers summary + transcript pass-through
  (gated by include_text), RTTM bytes content-type, and rejection of
  unknown response_format. mock-diarize model config registered with
  known_usecases=[FLAG_DIARIZATION] to bypass the backend-name guard.
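
The RTTM cases listed above (file-id basename, negative duration clamping,
fallback id) can be sketched roughly as follows. This is an illustration, not
the endpoint's code: Turn, fileID, renderRTTM and the fallback id "audio" are
hypothetical names chosen for the example. The column layout follows the
standard ten-field RTTM SPEAKER record.

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// Turn is a hypothetical speaker turn, in seconds.
type Turn struct {
	Start, End float64
	Speaker    string
}

// fileID derives the RTTM file id from an uploaded file name: the basename
// without its extension, or a fallback when nothing usable remains.
func fileID(name string) string {
	base := filepath.Base(name)
	id := strings.TrimSuffix(base, filepath.Ext(base))
	if id == "" || id == "." || id == string(filepath.Separator) {
		return "audio" // assumed fallback id
	}
	return id
}

// renderRTTM emits one SPEAKER record per turn, clamping negative
// durations to zero so malformed segments still yield valid RTTM.
func renderRTTM(name string, turns []Turn) string {
	var b strings.Builder
	id := fileID(name)
	for _, t := range turns {
		dur := t.End - t.Start
		if dur < 0 {
			dur = 0
		}
		fmt.Fprintf(&b, "SPEAKER %s 1 %.3f %.3f <NA> <NA> %s <NA> <NA>\n",
			id, t.Start, dur, t.Speaker)
	}
	return b.String()
}

func main() {
	fmt.Print(renderRTTM("/tmp/meeting.mp3", []Turn{
		{Start: 0, End: 1.5, Speaker: "SPEAKER_00"},
		{Start: 2, End: 1, Speaker: "SPEAKER_01"}, // End < Start: clamped
	}))
}
```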

Docs: new features/audio-diarization.md (request/response, RTTM example,
sherpa-onnx + vibevoice setup), cross-link from audio-to-text.md, entry
in whats-new.md.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]

* fix(diarization): correct sherpa-onnx symbol name + lint cleanup

CI failures on #9654:

* sherpa-onnx-grpc-{tts,transcription} and sherpa-onnx-realtime panicked
  at backend startup with `undefined symbol: SherpaOnnxDestroyOfflineSpeakerDiarizationResult`.
  Upstream's actual symbol is SherpaOnnxOfflineSpeakerDiarizationDestroyResult
  (Destroy in the middle, not the prefix); the rest of the diarization
  surface follows the same naming pattern. The mismatched name made
  purego.RegisterLibFunc fail at dlopen time and crashed the gRPC server
  before the BeforeAll could probe Health, taking down every sherpa-onnx
  test job — not just the diarization-related ones.

* golangci-lint flagged 5 errcheck violations on new defer cleanups
  (os.RemoveAll / Close / conn.Close); wrap each in a `defer func() { _ = X() }()`
  closure (matches the pattern other LocalAI files use for new code, since
  pre-existing bare defers are grandfathered in via new-from-merge-base).

* golangci-lint also flagged forbidigo violations: the new
  diarization_test.go files used testing.T-style `t.Errorf` / `t.Fatalf`,
  which are forbidden by the project's coding-style policy
  (.agents/coding-style.md). Convert both files to Ginkgo/Gomega
  Describe/It with Expect(...) — they get picked up by the existing
  TestBackend / TestOpenAI suites, no new suite plumbing needed.

* modernize linter: tightened the diarization segment loop to
  `for i := range int(numSegments)` (Go 1.22+ idiom).

Verified locally: golangci-lint with new-from-merge-base=origin/master
reports 0 issues across all touched packages, and the four mocked
diarization e2e specs in tests/e2e/mock_backend_test.go still pass.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]

* fix(vibevoice-cpp): convert non-WAV input via ffmpeg + raise ASR token budget

Confirmed end-to-end against a real LocalAI instance with vibevoice-asr-q4_k
loaded and the multi-speaker MP3 sample at vibevoice.cpp/samples/2p_argument.mp3:
both /v1/audio/transcriptions and /v1/audio/diarization now succeed and
return correctly attributed speaker turns for the full clip.

Two latent issues surfaced once the diarization endpoint actually exercised
the backend with a non-trivial input (items 1 and 2); item 3 is a defensive
fallback added alongside them:

1. vv_capi_asr only accepts WAV via load_wav_24k_mono. The previous code
   passed the uploaded path straight through, so anything that wasn't
   already a 24 kHz mono s16le WAV failed at the C side with rc=-8 and
   the very unhelpful "vv_capi_asr failed". prepareWavInput shells out
   to ffmpeg ("-ar 24000 -ac 1 -acodec pcm_s16le") in a per-call temp
   dir, matching the rate the model was trained on; both AudioTranscription
   and Diarize now route through it. This is the same shape sherpa-onnx
   uses (utils.AudioToWav), but vibevoice needs 24 kHz rather than 16 kHz
   so we don't reuse that helper.

2. The C ABI's max_new_tokens defaults to 256 when 0 is passed. That's
   fine for a five-second clip but not for anything past ~10 s — vibevoice
   stops mid-JSON, the parse fails, and the caller sees a hard error.
   Pass a much larger budget (16 384 tokens, roughly 9 minutes of speech at
   the model's ~30 tok/s rate); generation stops at EOS so this is a cap
   rather than a target.

3. As a belt-and-braces measure, mirror AudioTranscription's existing
   "fall back to a single segment if the model emits non-JSON text"
   pattern in Diarize, so partial / unusual model output never produces
   a 500. This kept the endpoint usable while diagnosing (1) and (2),
   and is the right behaviour to keep.
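
The single-segment fallback from item 3 can be sketched as below. This is
an illustration under assumptions: the turn struct mirrors the
[{Start,End,Speaker,Content}] shape mentioned earlier, but the JSON value
types (in particular Speaker as a string) and the parseTurns name are guesses
made for the example, not the backend's actual code.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// turn mirrors the [{Start,End,Speaker,Content}] shape; the value
// types are assumptions for this sketch.
type turn struct {
	Start   float64 `json:"Start"`
	End     float64 `json:"End"`
	Speaker string  `json:"Speaker"`
	Content string  `json:"Content"`
}

// parseTurns decodes the model output as a JSON array of turns. If the
// output is not valid JSON (for example because generation stopped
// mid-array) or is an empty array, it falls back to one segment that
// spans the whole clip and carries the raw text, instead of failing.
func parseTurns(raw string, duration float64) []turn {
	var turns []turn
	if err := json.Unmarshal([]byte(raw), &turns); err == nil && len(turns) > 0 {
		return turns
	}
	return []turn{{Start: 0, End: duration, Content: strings.TrimSpace(raw)}}
}

func main() {
	ok := `[{"Start":0,"End":1.2,"Speaker":"0","Content":"hello"}]`
	truncated := `[{"Start":0,"End":1.2,"Speaker":"0","Cont`
	fmt.Printf("%+v\n", parseTurns(ok, 5))
	fmt.Printf("%+v\n", parseTurns(truncated, 5))
}
```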

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]

* fix(vibevoice-cpp): pass valid WAVs through directly so ffmpeg is not required at runtime

Spotted by tests-e2e-backend (1.25.x): the previous fix forced every
incoming audio file through `ffmpeg -ar 24000 ...`, which meant the
backend container — which does not ship ffmpeg — failed even for the
existing happy path where the caller already uploads a WAV. The
container-side error was:

    rpc error: code = Unknown desc = vibevoice-cpp: ffmpeg convert to
    24k mono wav: exec: "ffmpeg": executable file not found in $PATH

Reading vibevoice.cpp's audio_io.cpp, `load_wav_24k_mono` uses drwav and
already accepts any PCM/IEEE-float WAV at any sample rate, downmixes
multi-channel input to mono, and resamples to 24 kHz internally. So the
only inputs that genuinely need an external converter are non-WAV
formats (MP3, OGG, FLAC, ...).

Detect WAVs by RIFF/WAVE magic at bytes 0..3 / 8..11 and pass them
straight through with a no-op cleanup; everything else still goes
through ffmpeg with the same 24 kHz mono s16le target. The result:

* Container builds without ffmpeg keep working for WAV uploads
  (the e2e-backends fixture is jfk.wav at 16 kHz mono s16le).
* MP3 and other non-WAV inputs still get the new ffmpeg conversion
  path so the diarization endpoint stays useful.
* If the caller uploads a non-WAV but ffmpeg isn't on PATH, the
  surfaced error is still descriptive enough to act on.
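
A minimal sketch of the RIFF/WAVE magic check described above. The helper
names isWAVHeader and isWAVFile are hypothetical and not necessarily what the
backend uses; the byte offsets are the ones stated in this message.

```go
package main

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"os"
)

// isWAVHeader reports whether h starts with a RIFF container whose form
// type is WAVE: "RIFF" at bytes 0..3 and "WAVE" at bytes 8..11.
// Bytes 4..7 hold the chunk size and are ignored.
func isWAVHeader(h []byte) bool {
	return len(h) >= 12 &&
		bytes.Equal(h[0:4], []byte("RIFF")) &&
		bytes.Equal(h[8:12], []byte("WAVE"))
}

// isWAVFile reads the first 12 bytes of path and applies isWAVHeader.
// Files shorter than 12 bytes are reported as non-WAV, not as errors.
func isWAVFile(path string) (bool, error) {
	f, err := os.Open(path)
	if err != nil {
		return false, err
	}
	defer func() { _ = f.Close() }()

	var h [12]byte
	if _, err := io.ReadFull(f, h[:]); err != nil {
		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
			return false, nil
		}
		return false, err
	}
	return isWAVHeader(h[:]), nil
}

func main() {
	wav := []byte("RIFF\x24\x00\x00\x00WAVEfmt ")
	mp3 := []byte("ID3\x04\x00\x00\x00\x00\x00\x00\x00\x00")
	fmt.Println(isWAVHeader(wav), isWAVHeader(mp3))
}
```

Checking the magic bytes rather than the file extension means a mislabelled
upload is still routed correctly.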

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]

* fix(ci): make gcc-14 install in Dockerfile.golang best-effort for jammy bases

The LocalVQE PR (bb033b16) made `gcc-14 g++-14` an unconditional apt
install in backend/Dockerfile.golang and pointed update-alternatives at
them. That works on the default `BASE_IMAGE=ubuntu:24.04` (noble has
gcc-14 in main), but every Go backend that builds on
`nvcr.io/nvidia/l4t-jetpack:r36.4.0` — jammy under the hood — now fails
at the apt step:

    E: Unable to locate package gcc-14

This blocked unrelated jobs:
backend-jobs(*-nvidia-l4t-arm64-{stablediffusion-ggml, sam3-cpp, whisper,
acestep-cpp, qwen3-tts-cpp, vibevoice-cpp}). LocalVQE itself is only
matrix-built on ubuntu:24.04 (CPU + Vulkan), so it doesn't actually
need gcc-14 anywhere else.

Make the gcc-14 install conditional on the package being available in
the configured apt repos. On noble: identical behaviour to today (gcc-14
installed, update-alternatives points at it). On jammy: skip the
gcc-14 stanza entirely and let build-essential's default gcc take over,
which is what the other Go backends compile with anyway.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-05-05 15:10:13 +02:00

package grpc
import (
"context"
"crypto/subtle"
"errors"
"fmt"
"io"
"log"
"net"
"os"
"strings"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// server is a gRPC service that exposes an AIModel backend for inference.
// LocalAI uses it to make the backend's functionality callable by clients.
// The service definition is deliberately general and tries to cover every
// operation a backend may offer; which RPCs actually work depends on the
// concrete AIModel implementation.
//
// Among others, the service provides the following methods:
// - Predict: run inference with the given options
// - PredictStream: run inference with the given options and stream the results
//
// server implements pb.BackendServer.
type server struct {
pb.UnimplementedBackendServer
llm AIModel
}
func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, error) {
return newReply("OK"), nil
}
func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
embeds, err := s.llm.Embeddings(in)
if err != nil {
return nil, err
}
return &pb.EmbeddingResult{Embeddings: embeds}, nil
}
func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.Load(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error loading model: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Loading succeeded", Success: true}, nil
}
func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
result, err := s.llm.Predict(in)
return newReply(result), err
}
func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.GenerateImage(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error generating image: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Image generated", Success: true}, nil
}
func (s *server) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.GenerateVideo(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error generating video: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Video generated", Success: true}, nil
}
func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.TTS(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "TTS audio generated", Success: true}, nil
}
func (s *server) TTSStream(in *pb.TTSRequest, stream pb.Backend_TTSStreamServer) error {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
audioChan := make(chan []byte)
done := make(chan bool)
go func() {
for audioChunk := range audioChan {
stream.Send(&pb.Reply{Audio: audioChunk})
}
done <- true
}()
err := s.llm.TTSStream(in, audioChan)
<-done
return err
}
func (s *server) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.SoundGeneration(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Sound Generation audio generated", Success: true}, nil
}
func (s *server) Detect(ctx context.Context, in *pb.DetectOptions) (*pb.DetectResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.Detect(in)
if err != nil {
return nil, err
}
return &res, nil
}
func (s *server) FaceVerify(ctx context.Context, in *pb.FaceVerifyRequest) (*pb.FaceVerifyResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.FaceVerify(in)
if err != nil {
return nil, err
}
return &res, nil
}
func (s *server) FaceAnalyze(ctx context.Context, in *pb.FaceAnalyzeRequest) (*pb.FaceAnalyzeResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.FaceAnalyze(in)
if err != nil {
return nil, err
}
return &res, nil
}
func (s *server) VoiceVerify(ctx context.Context, in *pb.VoiceVerifyRequest) (*pb.VoiceVerifyResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.VoiceVerify(in)
if err != nil {
return nil, err
}
return &res, nil
}
func (s *server) VoiceAnalyze(ctx context.Context, in *pb.VoiceAnalyzeRequest) (*pb.VoiceAnalyzeResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.VoiceAnalyze(in)
if err != nil {
return nil, err
}
return &res, nil
}
func (s *server) VoiceEmbed(ctx context.Context, in *pb.VoiceEmbedRequest) (*pb.VoiceEmbedResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.VoiceEmbed(in)
if err != nil {
return nil, err
}
return &res, nil
}
func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	result, err := s.llm.AudioTranscription(in)
	if err != nil {
		return nil, err
	}
	tresult := &pb.TranscriptResult{}
	// The loop variable is named seg so it does not shadow the receiver s.
	for _, seg := range result.Segments {
		// Convert token ids to the int32 type used by the protobuf message.
		tks := make([]int32, 0, len(seg.Tokens))
		for _, t := range seg.Tokens {
			tks = append(tks, int32(t))
		}
		tresult.Segments = append(tresult.Segments,
			&pb.TranscriptSegment{
				Text:    seg.Text,
				Id:      int32(seg.Id),
				Start:   int64(seg.Start),
				End:     int64(seg.End),
				Tokens:  tks,
				Speaker: seg.Speaker,
			})
	}
	tresult.Text = result.Text
	tresult.Language = result.Language
	tresult.Duration = result.Duration
	return tresult, nil
}
func (s *server) AudioTranscriptionStream(in *pb.TranscriptRequest, stream pb.Backend_AudioTranscriptionStreamServer) error {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
resultChan := make(chan *pb.TranscriptStreamResponse)
done := make(chan bool)
go func() {
for chunk := range resultChan {
stream.Send(chunk)
}
done <- true
}()
err := s.llm.AudioTranscriptionStream(in, resultChan)
<-done
return err
}
func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
resultChan := make(chan string)
done := make(chan bool)
go func() {
for result := range resultChan {
stream.Send(newReply(result))
}
done <- true
}()
err := s.llm.PredictStream(in, resultChan)
<-done
return err
}
func (s *server) TokenizeString(ctx context.Context, in *pb.PredictOptions) (*pb.TokenizationResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.TokenizeString(in)
if err != nil {
return nil, err
}
castTokens := make([]int32, len(res.Tokens))
for i, v := range res.Tokens {
castTokens[i] = int32(v)
}
return &pb.TokenizationResponse{
Length: int32(res.Length),
Tokens: castTokens,
}, err
}
func (s *server) Status(ctx context.Context, in *pb.HealthMessage) (*pb.StatusResponse, error) {
res, err := s.llm.Status()
if err != nil {
return nil, err
}
return &res, nil
}
func (s *server) StoresSet(ctx context.Context, in *pb.StoresSetOptions) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.StoresSet(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error setting entry: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Set key", Success: true}, nil
}
func (s *server) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.StoresDelete(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error deleting entry: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Deleted key", Success: true}, nil
}
func (s *server) StoresGet(ctx context.Context, in *pb.StoresGetOptions) (*pb.StoresGetResult, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.StoresGet(in)
if err != nil {
return nil, err
}
return &res, nil
}
func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.StoresFindResult, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.StoresFind(in)
if err != nil {
return nil, err
}
return &res, nil
}
func (s *server) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.VAD(in)
if err != nil {
return nil, err
}
return &res, nil
}
func (s *server) Diarize(ctx context.Context, in *pb.DiarizeRequest) (*pb.DiarizeResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.Diarize(in)
if err != nil {
return nil, err
}
return &res, nil
}
func (s *server) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.AudioEncode(in)
if err != nil {
return nil, err
}
return res, nil
}
func (s *server) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.AudioDecode(in)
if err != nil {
return nil, err
}
return res, nil
}
func (s *server) AudioTransform(ctx context.Context, in *pb.AudioTransformRequest) (*pb.AudioTransformResult, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.AudioTransform(in)
if err != nil {
return nil, err
}
return res, nil
}
func (s *server) AudioTransformStream(stream pb.Backend_AudioTransformStreamServer) error {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
in := make(chan *pb.AudioTransformFrameRequest, 4)
out := make(chan *pb.AudioTransformFrameResponse, 4)
// Pump incoming frames from the gRPC stream into `in`. EOF closes the
// channel, which signals the backend that the client is done sending.
recvErrCh := make(chan error, 1)
go func() {
defer close(in)
for {
req, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
recvErrCh <- nil
return
}
recvErrCh <- err
return
}
select {
case in <- req:
case <-stream.Context().Done():
recvErrCh <- stream.Context().Err()
return
}
}
}()
// Pump outgoing frames from `out` to the gRPC stream. The backend closes
// `out` on completion.
sendDone := make(chan error, 1)
go func() {
for resp := range out {
if err := stream.Send(resp); err != nil {
sendDone <- err
// Drain `out` so the backend can finish.
for range out {
}
return
}
}
sendDone <- nil
}()
backendErr := s.llm.AudioTransformStream(in, out)
sendErr := <-sendDone
recvErr := <-recvErrCh
if backendErr != nil {
return backendErr
}
if sendErr != nil {
return sendErr
}
return recvErr
}
func (s *server) StartFineTune(ctx context.Context, in *pb.FineTuneRequest) (*pb.FineTuneJobResult, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.StartFineTune(in)
if err != nil {
return &pb.FineTuneJobResult{Success: false, Message: fmt.Sprintf("Error starting fine-tune: %s", err.Error())}, err
}
return res, nil
}
func (s *server) FineTuneProgress(in *pb.FineTuneProgressRequest, stream pb.Backend_FineTuneProgressServer) error {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
updateChan := make(chan *pb.FineTuneProgressUpdate)
done := make(chan bool)
go func() {
for update := range updateChan {
stream.Send(update)
}
done <- true
}()
err := s.llm.FineTuneProgress(in, updateChan)
<-done
return err
}
func (s *server) StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.StopFineTune(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error stopping fine-tune: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Fine-tune stopped", Success: true}, nil
}
func (s *server) ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest) (*pb.ListCheckpointsResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.ListCheckpoints(in)
if err != nil {
return nil, err
}
return res, nil
}
func (s *server) ExportModel(ctx context.Context, in *pb.ExportModelRequest) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.ExportModel(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error exporting model: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Model exported", Success: true}, nil
}
func (s *server) StartQuantization(ctx context.Context, in *pb.QuantizationRequest) (*pb.QuantizationJobResult, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.StartQuantization(in)
if err != nil {
return &pb.QuantizationJobResult{Success: false, Message: fmt.Sprintf("Error starting quantization: %s", err.Error())}, err
}
return res, nil
}
func (s *server) QuantizationProgress(in *pb.QuantizationProgressRequest, stream pb.Backend_QuantizationProgressServer) error {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
updateChan := make(chan *pb.QuantizationProgressUpdate)
done := make(chan bool)
go func() {
for update := range updateChan {
stream.Send(update)
}
done <- true
}()
err := s.llm.QuantizationProgress(in, updateChan)
<-done
return err
}
func (s *server) StopQuantization(ctx context.Context, in *pb.QuantizationStopRequest) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.StopQuantization(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error stopping quantization: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Quantization stopped", Success: true}, nil
}
func (s *server) ModelMetadata(ctx context.Context, in *pb.ModelOptions) (*pb.ModelMetadataResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.ModelMetadata(in)
if err != nil {
return nil, err
}
return res, nil
}
func (s *server) Free(ctx context.Context, in *pb.HealthMessage) (*pb.Result, error) {
if err := s.llm.Free(); err != nil {
return &pb.Result{Success: false, Message: err.Error()}, nil
}
return &pb.Result{Success: true}, nil
}
// NewBackendServer creates a pb.BackendServer.
func NewBackendServer(model AIModel) pb.BackendServer {
return &server{llm: model}
}
// AuthTokenEnvVar is the environment variable used to configure gRPC bearer token auth.
const AuthTokenEnvVar = "LOCALAI_GRPC_AUTH_TOKEN"
// validateToken extracts the bearer token from gRPC metadata and validates it.
func validateToken(ctx context.Context, expected string) error {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return status.Error(codes.Unauthenticated, "missing metadata")
}
values := md.Get("authorization")
if len(values) == 0 {
return status.Error(codes.Unauthenticated, "missing authorization header")
}
raw := values[0]
if !strings.HasPrefix(raw, "Bearer ") {
return status.Error(codes.Unauthenticated, "authorization must use Bearer scheme")
}
token := strings.TrimPrefix(raw, "Bearer ")
if subtle.ConstantTimeCompare([]byte(token), []byte(expected)) != 1 {
return status.Error(codes.Unauthenticated, "invalid token")
}
return nil
}
func tokenUnaryInterceptor(token string) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if err := validateToken(ctx, token); err != nil {
return nil, err
}
return handler(ctx, req)
}
}
func tokenStreamInterceptor(token string) grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if err := validateToken(ss.Context(), token); err != nil {
return err
}
return handler(srv, ss)
}
}
// serverOpts returns the common gRPC server options, including auth interceptors
// when LOCALAI_GRPC_AUTH_TOKEN is set.
func serverOpts() []grpc.ServerOption {
opts := []grpc.ServerOption{
grpc.MaxRecvMsgSize(maxGRPCMessageSize),
grpc.MaxSendMsgSize(maxGRPCMessageSize),
}
if token := os.Getenv(AuthTokenEnvVar); token != "" {
opts = append(opts,
grpc.UnaryInterceptor(tokenUnaryInterceptor(token)),
grpc.StreamInterceptor(tokenStreamInterceptor(token)),
)
log.Printf("gRPC auth enabled via %s", AuthTokenEnvVar)
}
return opts
}
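// StartServer listens on address and serves model over gRPC. It blocks until
// the server stops or fails.
func StartServer(address string, model AIModel) error {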
lis, err := net.Listen("tcp", address)
if err != nil {
return err
}
s := grpc.NewServer(serverOpts()...)
pb.RegisterBackendServer(s, &server{llm: model})
log.Printf("gRPC Server listening at %v", lis.Addr())
if err := s.Serve(lis); err != nil {
return err
}
return nil
}
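// RunServer listens on address and serves model over gRPC. Note that s.Serve
// blocks until the server stops, so the returned stop function only reaches
// the caller after Serve has returned.
func RunServer(address string, model AIModel) (func() error, error) {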
lis, err := net.Listen("tcp", address)
if err != nil {
return nil, err
}
s := grpc.NewServer(serverOpts()...)
pb.RegisterBackendServer(s, &server{llm: model})
log.Printf("gRPC Server listening at %v", lis.Addr())
if err = s.Serve(lis); err != nil {
return func() error {
return lis.Close()
}, err
}
return func() error {
s.GracefulStop()
return nil
}, nil
}