mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-04 23:06:22 -04:00
Compare commits
17 Commits
fix/model-
...
fix/distri
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc33d4f4a5 | ||
|
|
9f41e69bc3 | ||
|
|
ef80a0e825 | ||
|
|
92726f7631 | ||
|
|
994063ba9a | ||
|
|
c1a55cf72d | ||
|
|
96758841d8 | ||
|
|
7a59260621 | ||
|
|
27e63b9a78 | ||
|
|
55c0911c23 | ||
|
|
f6cb6ab6d9 | ||
|
|
9f11b09c6a | ||
|
|
a5c4f822f0 | ||
|
|
fb36c262fe | ||
|
|
0e4e8980e6 | ||
|
|
3a932a9803 | ||
|
|
9d10418593 |
9
Makefile
9
Makefile
@@ -309,13 +309,20 @@ run-e2e-aio: protogen-go
|
||||
@echo 'Running e2e AIO tests'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
|
||||
|
||||
# Distributed architecture e2e (PostgreSQL + NATS via testcontainers).
|
||||
# Includes NatsJWT specs (JWT-enabled NATS). Requires Docker.
|
||||
# VLLMMultinode is excluded here; use test-e2e-vllm-multinode for that.
|
||||
test-e2e-distributed: protogen-go
|
||||
@echo 'Running distributed e2e tests (label Distributed, incl. NatsJWT)'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter='Distributed && !VLLMMultinode' --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e/distributed
|
||||
|
||||
# vLLM multi-node DP smoke (CPU). Builds local-ai:tests and the
|
||||
# cpu-vllm backend from the current working tree, then drives a
|
||||
# head + headless follower via testcontainers-go and asserts a chat
|
||||
# completion. BuildKit caches both images, so re-runs only rebuild
|
||||
# what changed. The test lives under tests/e2e/distributed and is
|
||||
# selected by the VLLMMultinode label so it doesn't run alongside
|
||||
# the other distributed-suite tests by default.
|
||||
# test-e2e-distributed.
|
||||
test-e2e-vllm-multinode: docker-build-e2e extract-backend-vllm protogen-go
|
||||
@echo 'Running e2e vLLM multi-node DP test'
|
||||
LOCALAI_IMAGE=local-ai \
|
||||
|
||||
@@ -537,6 +537,15 @@ message TTSRequest {
|
||||
string dst = 3;
|
||||
string voice = 4;
|
||||
optional string language = 5;
|
||||
// instructions is a free-form, per-request style/voice description (maps to
|
||||
// the OpenAI `instructions` field). Backends that support expressive synthesis
|
||||
// (e.g. Qwen3-TTS CustomVoice/VoiceDesign) prefer this over the static YAML
|
||||
// option when set; backends that don't simply ignore it.
|
||||
optional string instructions = 6;
|
||||
// params carries optional, backend-specific per-request generation parameters
|
||||
// (e.g. Chatterbox exaggeration/cfg_weight/temperature). Values are strings and
|
||||
// coerced by the backend; unset leaves the backend's configured defaults.
|
||||
map<string, string> params = 7;
|
||||
}
|
||||
|
||||
message VADRequest {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
|
||||
# Upstream pin lives below as DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=ba00a8a88c4c5810a3d1fed6b7b8fa2b44b82fdc
|
||||
DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=5dcb71166686799f0d873eab7386234302d05ecf
|
||||
LLAMA_VERSION?=94a220cd6745e6e3f8de62870b66fd5b9bc92700
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
@@ -145,7 +146,7 @@ func resolveAPIKey(envName, filePath string) (string, error) {
|
||||
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cloud-proxy: model not loaded")
|
||||
return nil, grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -175,7 +176,7 @@ func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err
|
||||
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -269,7 +270,7 @@ func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest,
|
||||
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modePassthrough {
|
||||
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# CrispASR version (release tag)
|
||||
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
||||
CRISPASR_VERSION?=05e60432bcb5bc2113f8c395a41e86497c11504a
|
||||
CRISPASR_VERSION?=13d54e110e1538e0f0bc3af0680b9ab246cfb48d
|
||||
SO_TARGET?=libgocrispasr.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# parakeet-cpp backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=9edf17c3ada66e0f881dcff155492867db7ac4cf
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=b11fe5bca78ad8b342dd559a43d76df3984bb447
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# whisper.cpp / ds4 / vibevoice-cpp convention.
|
||||
#
|
||||
@@ -15,7 +15,7 @@
|
||||
# That's what the L0 smoke test uses. The default target below does the
|
||||
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
|
||||
|
||||
PARAKEET_VERSION?=9edf17c3ada66e0f881dcff155492867db7ac4cf
|
||||
PARAKEET_VERSION?=b11fe5bca78ad8b342dd559a43d76df3984bb447
|
||||
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
|
||||
|
||||
GOCMD?=go
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -230,16 +231,25 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
// (L2).
|
||||
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if p.ctxPtr == 0 {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: model not loaded")
|
||||
return pb.TranscriptResult{}, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
|
||||
// Fallback when the batched C-API is unavailable: transcribe directly from
|
||||
// the file path (original behavior, no batching).
|
||||
// Fallback when the batched C-API is unavailable: transcribe from a file
|
||||
// path (original behavior, no batching). The C library's audio loader only
|
||||
// understands 16 kHz mono WAV/PCM, so convert the input first - otherwise
|
||||
// any non-WAV upload (MP3, etc.) fails with "failed to load audio". This
|
||||
// mirrors what every other audio backend (whisper, crispasr) does via
|
||||
// utils.AudioToWav before handing the file to the engine.
|
||||
if p.bat == nil {
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, opts.Dst, 0)
|
||||
converted, cleanup, err := convertToWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer cleanup()
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, converted, 0)
|
||||
if cstr == 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", CppLastError(p.ctxPtr))
|
||||
}
|
||||
@@ -342,7 +352,7 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
defer close(results)
|
||||
|
||||
if p.ctxPtr == 0 {
|
||||
return errors.New("parakeet-cpp: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
@@ -460,17 +470,33 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
// float samples plus the clip duration in seconds. Mirrors the whisper
|
||||
// backend: utils.AudioToWav (ffmpeg) normalises rate/channels, go-audio
|
||||
// decodes the PCM.
|
||||
func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
// convertToWavMono16k converts an arbitrary audio file to a 16 kHz mono WAV in
|
||||
// a fresh temp dir and returns the path together with a cleanup func the caller
|
||||
// must defer. WAV inputs already at 16 kHz/mono/16-bit are passed through by
|
||||
// utils.AudioToWav (hardlink/copy), everything else is transcoded via ffmpeg.
|
||||
// Used by the direct (non-batched) transcription path, which hands a file path
|
||||
// to the C library's WAV-only audio loader.
|
||||
func convertToWavMono16k(path string) (string, func(), error) {
|
||||
dir, err := os.MkdirTemp("", "parakeet")
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return "", func() {}, err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
cleanup := func() { _ = os.RemoveAll(dir) }
|
||||
|
||||
converted := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(path, converted); err != nil {
|
||||
cleanup()
|
||||
return "", func() {}, err
|
||||
}
|
||||
return converted, cleanup, nil
|
||||
}
|
||||
|
||||
func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
converted, cleanup, err := convertToWavMono16k(path)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
fh, err := os.Open(converted)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,11 +3,14 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -70,6 +73,24 @@ func fixturesOrSkip() (string, string) {
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
// writeMono16kWav writes `samples` frames of 16 kHz mono 16-bit silence to
|
||||
// path. The result is already in AudioToWav's target format, so the conversion
|
||||
// helper copies it through without invoking ffmpeg.
|
||||
func writeMono16kWav(path string, samples int) {
|
||||
GinkgoHelper()
|
||||
f, err := os.Create(path)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
enc := wav.NewEncoder(f, 16000, 16, 1, 1)
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: 16000},
|
||||
SourceBitDepth: 16,
|
||||
Data: make([]int, samples),
|
||||
}
|
||||
Expect(enc.Write(buf)).To(Succeed())
|
||||
Expect(enc.Close()).To(Succeed())
|
||||
Expect(f.Close()).To(Succeed())
|
||||
}
|
||||
|
||||
var _ = Describe("ParakeetCpp", func() {
|
||||
Context("AudioTranscription", func() {
|
||||
It("transcribes a WAV via the parakeet C-API", func() {
|
||||
@@ -120,6 +141,39 @@ var _ = Describe("ParakeetCpp", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("convertToWavMono16k", func() {
|
||||
// The non-batched transcription path hands a file path to the C
|
||||
// library's WAV-only audio loader, so it must convert first.
|
||||
// utils.AudioToWav passes an already-16kHz/mono/16-bit WAV through
|
||||
// without ffmpeg, which lets us exercise the helper (and the
|
||||
// regression: the direct path used to skip conversion entirely)
|
||||
// without a model, the C library, or ffmpeg.
|
||||
It("returns a decodable 16kHz mono WAV copy and cleans it up", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
src := filepath.Join(dir, "input.wav")
|
||||
writeMono16kWav(src, 16000) // 1s of silence at 16 kHz
|
||||
|
||||
converted, cleanup, err := convertToWavMono16k(src)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// It must produce a fresh temp file, not return the original path.
|
||||
Expect(converted).ToNot(Equal(src))
|
||||
Expect(converted).To(BeAnExistingFile())
|
||||
|
||||
pcm, _, err := decodeWavMono16k(converted)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(pcm).To(HaveLen(16000), "round-trips the sample count")
|
||||
|
||||
cleanup()
|
||||
Expect(converted).ToNot(BeAnExistingFile(), "cleanup removes the temp dir")
|
||||
})
|
||||
|
||||
It("errors on a non-existent input rather than passing the path through", func() {
|
||||
_, _, err := convertToWavMono16k(filepath.Join(GinkgoT().TempDir(), "missing.mp3"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("streams deltas and a closing FinalResult from a cache-aware model", func() {
|
||||
// Streaming needs a cache-aware streaming model (e.g.
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# qwen3-tts.cpp version
|
||||
QWEN3TTS_REPO?=https://github.com/predict-woo/qwen3-tts.cpp
|
||||
QWEN3TTS_CPP_VERSION?=7a762e2ad4bacc6fdda81d81bf10a09ffb546f29
|
||||
QWEN3TTS_CPP_VERSION?=136e5d36c17083da0321fd96512dc7b263f94a44
|
||||
SO_TARGET?=libgoqwen3ttscpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -21,6 +22,43 @@ type Qwen3TtsCpp struct {
|
||||
threads int
|
||||
}
|
||||
|
||||
// languageNameAliases maps common full language names to the canonical
|
||||
// two-letter code understood by the C++ language_to_id table.
|
||||
var languageNameAliases = map[string]string{
|
||||
"english": "en",
|
||||
"russian": "ru",
|
||||
"chinese": "zh",
|
||||
"japanese": "ja",
|
||||
"korean": "ko",
|
||||
"german": "de",
|
||||
"french": "fr",
|
||||
"spanish": "es",
|
||||
"italian": "it",
|
||||
"portuguese": "pt",
|
||||
}
|
||||
|
||||
// normalizeLanguage coerces a caller-supplied language into the canonical code
|
||||
// the model expects. It lowercases, trims, strips any region/locale suffix
|
||||
// (en-US, en_US, ja.JP -> en/ja), and resolves common full names (english -> en).
|
||||
// An empty input stays empty so the C++ side applies its English default; an
|
||||
// unrecognized value is returned normalized so C++ can log it and default.
|
||||
func normalizeLanguage(lang string) string {
|
||||
lang = strings.ToLower(strings.TrimSpace(lang))
|
||||
if lang == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strip region/locale suffix: keep the segment before the first separator.
|
||||
if i := strings.IndexAny(lang, "-_."); i >= 0 {
|
||||
lang = lang[:i]
|
||||
}
|
||||
|
||||
if code, ok := languageNameAliases[lang]; ok {
|
||||
return code
|
||||
}
|
||||
return lang
|
||||
}
|
||||
|
||||
func (q *Qwen3TtsCpp) Load(opts *pb.ModelOptions) error {
|
||||
// ModelFile is the model directory path (containing GGUF files)
|
||||
modelDir := opts.ModelFile
|
||||
@@ -54,7 +92,7 @@ func (q *Qwen3TtsCpp) TTS(req *pb.TTSRequest) error {
|
||||
dst := req.Dst
|
||||
language := ""
|
||||
if req.Language != nil {
|
||||
language = *req.Language
|
||||
language = normalizeLanguage(*req.Language)
|
||||
}
|
||||
|
||||
// Synthesis parameters with sensible defaults
|
||||
|
||||
53
backend/go/qwen3-tts-cpp/language_test.go
Normal file
53
backend/go/qwen3-tts-cpp/language_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestLanguageNormalization(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "qwen3-tts-cpp language normalization")
|
||||
}
|
||||
|
||||
var _ = Describe("normalizeLanguage", func() {
|
||||
DescribeTable("maps caller input to the canonical model language code",
|
||||
func(input, expected string) {
|
||||
Expect(normalizeLanguage(input)).To(Equal(expected))
|
||||
},
|
||||
// Canonical codes pass through unchanged
|
||||
Entry("canonical en", "en", "en"),
|
||||
Entry("canonical zh", "zh", "zh"),
|
||||
Entry("canonical pt", "pt", "pt"),
|
||||
|
||||
// Case-insensitive
|
||||
Entry("uppercase", "EN", "en"),
|
||||
Entry("mixed case", "Ja", "ja"),
|
||||
|
||||
// Surrounding whitespace
|
||||
Entry("trims whitespace", " en ", "en"),
|
||||
|
||||
// Region/locale stripping
|
||||
Entry("BCP-47 region", "en-US", "en"),
|
||||
Entry("underscore region", "en_US", "en"),
|
||||
Entry("dotted locale", "ja.JP", "ja"),
|
||||
Entry("region + case", "ZH-CN", "zh"),
|
||||
|
||||
// Full-name aliases
|
||||
Entry("english name", "english", "en"),
|
||||
Entry("chinese name cased", "Chinese", "zh"),
|
||||
Entry("japanese name", "japanese", "ja"),
|
||||
Entry("russian name", "russian", "ru"),
|
||||
Entry("portuguese name", "portuguese", "pt"),
|
||||
|
||||
// Empty stays empty (C++ applies the English default)
|
||||
Entry("empty", "", ""),
|
||||
Entry("whitespace only", " ", ""),
|
||||
|
||||
// Unknown values pass through normalized so C++ can log + default
|
||||
Entry("unknown code", "klingon", "klingon"),
|
||||
Entry("unknown with region", "xx-YY", "xx"),
|
||||
)
|
||||
})
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=2d40a8b2adcdf8b5b0ca0535f3bb7801b6ba13e5
|
||||
STABLEDIFFUSION_GGML_VERSION?=1f9ee88e09c258053fa59d5e05e23dfb10fa0b13
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -37,6 +37,20 @@ def is_int(s):
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def coerce_param_value(value):
|
||||
"""Coerce a TTSRequest.params value (string on the wire) to the type the
|
||||
Chatterbox generate() kwargs expect (float/int/bool), matching how static
|
||||
YAML options are coerced at load time. Non-string values pass through."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
if is_float(value):
|
||||
return float(value)
|
||||
if is_int(value):
|
||||
return int(value)
|
||||
if value.lower() in ["true", "false"]:
|
||||
return value.lower() == "true"
|
||||
return value
|
||||
|
||||
def split_text_at_word_boundary(text, max_length=250):
|
||||
"""
|
||||
Split text at word boundaries without truncating words.
|
||||
@@ -191,6 +205,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# add options to kwargs
|
||||
kwargs.update(self.options)
|
||||
|
||||
# Merge per-request params (TTSRequest.params), overriding the static
|
||||
# YAML options. This exposes Chatterbox generation knobs (e.g.
|
||||
# exaggeration, cfg_weight, temperature) per request. Values arrive as
|
||||
# strings on the wire and are coerced to float/int/bool.
|
||||
if hasattr(request, "params") and request.params:
|
||||
for key, value in request.params.items():
|
||||
kwargs[key] = coerce_param_value(value)
|
||||
|
||||
# Check if text exceeds 250 characters
|
||||
# (chatterbox does not support long text)
|
||||
# https://github.com/resemble-ai/chatterbox/issues/60
|
||||
|
||||
@@ -47,6 +47,26 @@ def is_int(s):
|
||||
return False
|
||||
|
||||
|
||||
def coerce_param_value(value):
|
||||
"""Coerce a string param value (from the TTSRequest.params map, which is
|
||||
string-typed on the wire) into the most specific Python type the model
|
||||
generation kwargs expect: bool, int, float, else the original string."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
lowered = value.strip().lower()
|
||||
if lowered in ("true", "false"):
|
||||
return lowered == "true"
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
return value
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
@@ -322,6 +342,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def _effective_instruct(self, request):
|
||||
"""Resolve the instruction/style string for this request, preferring the
|
||||
per-request TTSRequest.instructions value and falling back to the static
|
||||
YAML `instruct` option. Empty string means "no instruction"."""
|
||||
req_instruct = (
|
||||
request.instructions
|
||||
if hasattr(request, "instructions") and request.instructions
|
||||
else ""
|
||||
)
|
||||
if req_instruct:
|
||||
return req_instruct
|
||||
return self.options.get("instruct", "") or ""
|
||||
|
||||
def _detect_mode(self, request):
|
||||
"""Detect which mode to use based on request parameters."""
|
||||
# Priority: VoiceClone > VoiceDesign > CustomVoice
|
||||
@@ -338,8 +371,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if self.audio_path or self.voices:
|
||||
return "VoiceClone"
|
||||
|
||||
# VoiceDesign: instruct option is provided
|
||||
if "instruct" in self.options and self.options["instruct"]:
|
||||
# VoiceDesign: instruct provided per-request or via YAML option
|
||||
if self._effective_instruct(request):
|
||||
return "VoiceDesign"
|
||||
|
||||
# Default to CustomVoice
|
||||
@@ -690,10 +723,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if do_sample is not None:
|
||||
generation_kwargs["do_sample"] = do_sample
|
||||
|
||||
instruct = self.options.get("instruct", "")
|
||||
# Prefer the per-request instruction (TTSRequest.instructions) over the
|
||||
# static YAML `instruct` option. This lets clients set a different style
|
||||
# (CustomVoice emotion) or designed voice (VoiceDesign) per request.
|
||||
instruct = self._effective_instruct(request)
|
||||
if instruct is not None and instruct != "":
|
||||
generation_kwargs["instruct"] = instruct
|
||||
|
||||
# Merge any per-request backend-specific params (TTSRequest.params).
|
||||
# Values arrive as strings on the wire; coerce to int/float/bool so the
|
||||
# model receives the types it expects. These override YAML-derived kwargs.
|
||||
if hasattr(request, "params") and request.params:
|
||||
for key, value in request.params.items():
|
||||
generation_kwargs[key] = coerce_param_value(value)
|
||||
|
||||
# Generate audio based on mode
|
||||
if mode == "VoiceClone":
|
||||
# VoiceClone mode
|
||||
|
||||
@@ -102,7 +102,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)
|
||||
|
||||
// Connect to NATS
|
||||
natsClient, err := messaging.New(cfg.Distributed.NatsURL)
|
||||
natsAuth := cfg.Distributed.NatsAuthConfig()
|
||||
if natsAuth.RequireAuth && (natsAuth.ServiceUserJWT == "" || natsAuth.ServiceUserSeed == "") {
|
||||
return nil, fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED")
|
||||
}
|
||||
natsOpts := cfg.Distributed.NatsMessagingOptions("", "")
|
||||
natsClient, err := messaging.New(cfg.Distributed.NatsURL, natsOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connecting to NATS: %w", err)
|
||||
}
|
||||
|
||||
@@ -123,14 +123,14 @@ var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
|
||||
})
|
||||
|
||||
It("ModelTTS forwards the request context to the SmartRouter", func() {
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", loader, appCfg, modelCfg)
|
||||
_, _, err := backend.ModelTTS(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
})
|
||||
|
||||
It("ModelTTSStream forwards the request context to the SmartRouter", func() {
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
err := backend.ModelTTSStream(reqCtx, "hello", "", "", "", nil, loader, appCfg, modelCfg, func([]byte) error { return nil })
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))
|
||||
stampViaRouterCtx()
|
||||
|
||||
@@ -20,11 +20,32 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
// newTTSRequest assembles the gRPC TTSRequest from the per-request inputs. The
|
||||
// optional instructions string is only attached when non-empty so backends can
|
||||
// distinguish "no per-request instruction" (fall back to YAML) from an explicit
|
||||
// empty one. params is forwarded as-is (nil when unset).
|
||||
func newTTSRequest(text, modelPath, voice, dst, language, instructions string, params map[string]string) *proto.TTSRequest {
|
||||
req := &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Dst: dst,
|
||||
Language: &language,
|
||||
Params: params,
|
||||
}
|
||||
if instructions != "" {
|
||||
req.Instructions = &instructions
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func ModelTTS(
|
||||
ctx context.Context,
|
||||
text,
|
||||
voice,
|
||||
language string,
|
||||
language,
|
||||
instructions string,
|
||||
params map[string]string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
@@ -74,13 +95,9 @@ func ModelTTS(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := ttsModel.TTS(ctx, &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Dst: filePath,
|
||||
Language: &language,
|
||||
})
|
||||
ttsRequest := newTTSRequest(text, modelPath, voice, filePath, language, instructions, params)
|
||||
|
||||
res, err := ttsModel.TTS(ctx, ttsRequest)
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
@@ -128,7 +145,9 @@ func ModelTTSStream(
|
||||
ctx context.Context,
|
||||
text,
|
||||
voice,
|
||||
language string,
|
||||
language,
|
||||
instructions string,
|
||||
params map[string]string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
@@ -177,12 +196,10 @@ func ModelTTSStream(
|
||||
var totalPCMBytes int
|
||||
snippetCapped := false
|
||||
|
||||
err = ttsModel.TTSStream(ctx, &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
Language: &language,
|
||||
}, func(reply *proto.Reply) {
|
||||
// Streaming TTS writes to the HTTP response, not a file, so dst is empty.
|
||||
ttsRequest := newTTSRequest(text, modelPath, voice, "", language, instructions, params)
|
||||
|
||||
err = ttsModel.TTSStream(ctx, ttsRequest, func(reply *proto.Reply) {
|
||||
// First message contains sample rate info
|
||||
if !headerSent && len(reply.Message) > 0 {
|
||||
var info map[string]any
|
||||
|
||||
42
core/backend/tts_test.go
Normal file
42
core/backend/tts_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package backend
|
||||
|
||||
// Specs for the TTSRequest assembly that carries the per-request
|
||||
// instructions/params from the OpenAI `instructions` field (and the LocalAI
|
||||
// `params` extension) through to the gRPC boundary. Before this plumbing the
|
||||
// instruction value was dropped before reaching the backend; these specs pin
|
||||
// that it now survives, and that the empty case stays backward compatible.
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("newTTSRequest", func() {
|
||||
It("attaches the instructions when a per-request value is set", func() {
|
||||
req := newTTSRequest("hi", "/m", "alloy", "/out.wav", "en", "cheerful narrator", nil)
|
||||
Expect(req.Instructions).ToNot(BeNil())
|
||||
Expect(req.GetInstructions()).To(Equal("cheerful narrator"))
|
||||
Expect(req.GetText()).To(Equal("hi"))
|
||||
Expect(req.GetVoice()).To(Equal("alloy"))
|
||||
Expect(req.GetDst()).To(Equal("/out.wav"))
|
||||
Expect(req.GetLanguage()).To(Equal("en"))
|
||||
})
|
||||
|
||||
It("leaves instructions unset when empty so backends fall back to YAML", func() {
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
|
||||
Expect(req.Instructions).To(BeNil())
|
||||
Expect(req.GetInstructions()).To(Equal(""))
|
||||
})
|
||||
|
||||
It("forwards per-request params through to the backend", func() {
|
||||
params := map[string]string{"exaggeration": "0.7", "cfg_weight": "0.3"}
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", params)
|
||||
Expect(req.GetParams()).To(HaveKeyWithValue("exaggeration", "0.7"))
|
||||
Expect(req.GetParams()).To(HaveKeyWithValue("cfg_weight", "0.3"))
|
||||
})
|
||||
|
||||
It("leaves params nil when none are supplied", func() {
|
||||
req := newTTSRequest("hi", "/m", "", "/out.wav", "", "", nil)
|
||||
Expect(req.GetParams()).To(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -52,6 +52,15 @@ type AgentWorkerCMD struct {
|
||||
Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
|
||||
Queue string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`
|
||||
|
||||
NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (defaults to nats_jwt from registration)" group:"distributed"`
|
||||
NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user seed override (defaults to nats_user_seed from registration)" group:"distributed"`
|
||||
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"Fallback NATS service JWT when registration does not mint agent JWT" group:"distributed"`
|
||||
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"Fallback NATS service seed paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
|
||||
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed to connect" group:"distributed"`
|
||||
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
|
||||
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
|
||||
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
|
||||
|
||||
// Timeouts
|
||||
MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
|
||||
}
|
||||
@@ -81,15 +90,30 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
registrationBody["token"] = cmd.RegistrationToken
|
||||
}
|
||||
|
||||
nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
|
||||
// Context cancelled on shutdown — used by registration waits, heartbeat, and
|
||||
// other background goroutines.
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
// Acquire credentials via (re)registration. When the bus requires auth and no
|
||||
// static fallback is configured, wait through admin approval until the
|
||||
// frontend mints credentials rather than starting unauthenticated.
|
||||
credMgr := workerregistry.NewNATSCredentialManager(
|
||||
func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
|
||||
return regClient.RegisterFull(ctx, registrationBody)
|
||||
},
|
||||
cmd.NatsRequireAuth && cmd.NatsJWT == "" && cmd.NatsServiceJWT == "",
|
||||
)
|
||||
res, err := credMgr.Acquire(shutdownCtx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("registration failed: %w", err)
|
||||
}
|
||||
nodeID := res.ID
|
||||
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
|
||||
|
||||
// Use provisioned API token if none was set
|
||||
if cmd.APIToken == "" {
|
||||
cmd.APIToken = apiToken
|
||||
cmd.APIToken = res.APIToken
|
||||
}
|
||||
|
||||
// Start heartbeat
|
||||
@@ -98,14 +122,40 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
|
||||
}
|
||||
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
|
||||
// Context cancelled on shutdown — used by heartbeat and other background goroutines
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })
|
||||
|
||||
// Connect to NATS
|
||||
natsClient, err := messaging.New(cmd.NatsURL)
|
||||
// Resolve NATS credentials with precedence: explicit env override, then
|
||||
// frontend-minted (auto-refreshed before expiry), then service fallback.
|
||||
// Each static source must supply JWT and seed together.
|
||||
natsTLS := messaging.TLSFiles{CA: cmd.NatsTLSCA, Cert: cmd.NatsTLSCert, Key: cmd.NatsTLSKey}
|
||||
var natsOpts []messaging.Option
|
||||
switch {
|
||||
case cmd.NatsJWT != "" || cmd.NatsUserSeed != "":
|
||||
if (cmd.NatsJWT == "") != (cmd.NatsUserSeed == "") {
|
||||
return fmt.Errorf("LOCALAI_NATS_JWT and LOCALAI_NATS_USER_SEED must be set together")
|
||||
}
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsJWT, cmd.NatsUserSeed))
|
||||
case credMgr.HasCredentials():
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWTProvider(credMgr.Provider()))
|
||||
go func() {
|
||||
if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
|
||||
xlog.Error("NATS credential refresh permanently failed; shutting down agent worker", "error", err)
|
||||
shutdownCancel()
|
||||
}
|
||||
}()
|
||||
case cmd.NatsServiceJWT != "" || cmd.NatsServiceSeed != "":
|
||||
if (cmd.NatsServiceJWT == "") != (cmd.NatsServiceSeed == "") {
|
||||
return fmt.Errorf("LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED must be set together")
|
||||
}
|
||||
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsServiceJWT, cmd.NatsServiceSeed))
|
||||
case cmd.NatsRequireAuth:
|
||||
return fmt.Errorf("NATS JWT+seed required: enable frontend minting or set LOCALAI_NATS_* env vars")
|
||||
}
|
||||
if natsTLS.Enabled() {
|
||||
natsOpts = append(natsOpts, messaging.WithTLS(natsTLS))
|
||||
}
|
||||
natsClient, err := messaging.New(cmd.NatsURL, natsOpts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connecting to NATS: %w", err)
|
||||
}
|
||||
@@ -183,17 +233,25 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
|
||||
xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)
|
||||
|
||||
// Wait for shutdown
|
||||
// Wait for an OS signal or an internal fatal condition (e.g. NATS
|
||||
// credentials became unrenewable), so the worker restarts and re-acquires
|
||||
// rather than lingering unable to serve.
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
var runErr error
|
||||
select {
|
||||
case <-sigCh:
|
||||
case <-shutdownCtx.Done():
|
||||
runErr = fmt.Errorf("agent worker shutting down: NATS credentials unavailable")
|
||||
xlog.Error("Internal shutdown requested", "error", runErr)
|
||||
}
|
||||
|
||||
xlog.Info("Shutting down agent worker")
|
||||
shutdownCancel() // stop heartbeat loop immediately
|
||||
dispatcher.Stop()
|
||||
mcpTools.CloseAllMCPSessions()
|
||||
regClient.GracefulDeregister(nodeID)
|
||||
return nil
|
||||
return runErr
|
||||
}
|
||||
|
||||
// handleMCPToolRequest handles a NATS request-reply for MCP tool execution.
|
||||
|
||||
@@ -159,6 +159,14 @@ type RunCMD struct {
|
||||
DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
NatsAccountSeed string `env:"LOCALAI_NATS_ACCOUNT_SEED" help:"NATS account signing seed (SU...) used to mint per-node worker JWTs at registration" group:"distributed"`
|
||||
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"NATS user JWT for the frontend (and agent workers) to publish control-plane messages" group:"distributed"`
|
||||
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"NATS user signing seed (SU...) paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
|
||||
NatsWorkerJWTTTL string `env:"LOCALAI_NATS_WORKER_JWT_TTL" help:"Lifetime of minted per-node NATS JWTs (e.g. 24h, default 24h)" group:"distributed"`
|
||||
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT credentials (service JWT + account seed) when distributed mode is enabled" group:"distributed"`
|
||||
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI); use with tls:// in --nats-url" group:"distributed"`
|
||||
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
|
||||
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
|
||||
Version bool
|
||||
@@ -283,6 +291,34 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.RegistrationToken != "" {
|
||||
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
|
||||
}
|
||||
if r.NatsAccountSeed != "" {
|
||||
opts = append(opts, config.WithNatsAccountSeed(r.NatsAccountSeed))
|
||||
}
|
||||
if r.NatsServiceJWT != "" {
|
||||
opts = append(opts, config.WithNatsServiceJWT(r.NatsServiceJWT))
|
||||
}
|
||||
if r.NatsServiceSeed != "" {
|
||||
opts = append(opts, config.WithNatsServiceSeed(r.NatsServiceSeed))
|
||||
}
|
||||
if r.NatsWorkerJWTTTL != "" {
|
||||
d, err := time.ParseDuration(r.NatsWorkerJWTTTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_NATS_WORKER_JWT_TTL %q: %w", r.NatsWorkerJWTTTL, err)
|
||||
}
|
||||
opts = append(opts, config.WithNatsWorkerJWTTTL(d))
|
||||
}
|
||||
if r.NatsRequireAuth {
|
||||
opts = append(opts, config.EnableNatsRequireAuth)
|
||||
}
|
||||
if r.NatsTLSCA != "" {
|
||||
opts = append(opts, config.WithNatsTLSCA(r.NatsTLSCA))
|
||||
}
|
||||
if r.NatsTLSCert != "" {
|
||||
opts = append(opts, config.WithNatsTLSCert(r.NatsTLSCert))
|
||||
}
|
||||
if r.NatsTLSKey != "" {
|
||||
opts = append(opts, config.WithNatsTLSKey(r.NatsTLSKey))
|
||||
}
|
||||
if r.AutoApproveNodes {
|
||||
opts = append(opts, config.EnableAutoApproveNodes)
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
|
||||
options.Backend = t.Backend
|
||||
options.Model = t.Model
|
||||
|
||||
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, ml, opts, options)
|
||||
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, "", nil, ml, opts, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func (r *VLLMDistributed) Run(ctx *cliContext.Context) error {
|
||||
FrontendURL: r.RegisterTo,
|
||||
RegistrationToken: r.RegistrationToken,
|
||||
}
|
||||
nodeID, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
|
||||
nodeID, _, _, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
|
||||
if regErr != nil {
|
||||
return fmt.Errorf("registering with frontend: %w", regErr)
|
||||
}
|
||||
|
||||
@@ -58,65 +58,77 @@ func (c *RegistrationClient) setAuth(req *http.Request) {
|
||||
|
||||
// RegisterResponse is the JSON body returned by /api/node/register.
|
||||
type RegisterResponse struct {
|
||||
ID string `json:"id"`
|
||||
APIToken string `json:"api_token,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status,omitempty"` // "pending" until an admin approves the node
|
||||
APIToken string `json:"api_token,omitempty"`
|
||||
NatsJWT string `json:"nats_jwt,omitempty"`
|
||||
NatsUserSeed string `json:"nats_user_seed,omitempty"`
|
||||
}
|
||||
|
||||
// Register sends a single registration request and returns the node ID and
|
||||
// (optionally) an auto-provisioned API token.
|
||||
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) {
|
||||
// RegisterFull sends a single registration request and returns the full
|
||||
// response (node ID, approval status, and optional API token / NATS creds).
|
||||
// Re-registration is idempotent: the frontend preserves the node row and mints
|
||||
// a fresh NATS JWT each call, so this doubles as the credential-refresh call.
|
||||
func (c *RegistrationClient) RegisterFull(ctx context.Context, body map[string]any) (*RegisterResponse, error) {
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
url := c.baseURL() + "/api/node/register"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("creating request: %w", err)
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.setAuth(req)
|
||||
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("posting to %s: %w", url, err)
|
||||
return nil, fmt.Errorf("posting to %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
|
||||
return nil, fmt.Errorf("registration failed with status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result RegisterResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", "", fmt.Errorf("decoding response: %w", err)
|
||||
return nil, fmt.Errorf("decoding response: %w", err)
|
||||
}
|
||||
return result.ID, result.APIToken, nil
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Register sends a single registration request and returns the node ID and
|
||||
// optional credentials (API token for agent workers, NATS JWT when configured).
|
||||
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
|
||||
res, err := c.RegisterFull(ctx, body)
|
||||
if err != nil {
|
||||
return "", "", "", "", err
|
||||
}
|
||||
return res.ID, res.APIToken, res.NatsJWT, res.NatsUserSeed, nil
|
||||
}
|
||||
|
||||
// RegisterWithRetry retries registration with exponential backoff.
|
||||
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) {
|
||||
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
|
||||
backoff := 2 * time.Second
|
||||
maxBackoff := 30 * time.Second
|
||||
|
||||
var nodeID, apiToken string
|
||||
var err error
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
nodeID, apiToken, err = c.Register(ctx, body)
|
||||
nodeID, apiToken, natsJWT, natsSeed, err = c.Register(ctx, body)
|
||||
if err == nil {
|
||||
return nodeID, apiToken, nil
|
||||
return nodeID, apiToken, natsJWT, natsSeed, nil
|
||||
}
|
||||
if attempt == maxRetries {
|
||||
return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
|
||||
return "", "", "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", "", ctx.Err()
|
||||
return "", "", "", "", ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
backoff = min(backoff*2, maxBackoff)
|
||||
}
|
||||
return nodeID, apiToken, err
|
||||
return nodeID, apiToken, natsJWT, natsSeed, err
|
||||
}
|
||||
|
||||
// Heartbeat sends a single heartbeat POST with the given body.
|
||||
|
||||
200
core/cli/workerregistry/credentials.go
Normal file
200
core/cli/workerregistry/credentials.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package workerregistry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
const (
	// statusPending is the registration status reported while a node is
	// still awaiting admin approval. It mirrors nodes.StatusPending and is
	// deliberately duplicated here instead of imported, so that this
	// lightweight registration client does not drag in the nodes package
	// (and its gorm/DB dependencies).
	statusPending = "pending"

	// defaultMaxAttempts caps both the number of registrations Acquire
	// performs and the number of consecutive failures RefreshLoop tolerates
	// before giving up. The value is generous enough to sit through a slow
	// admin approval or a short frontend outage, yet finite, so a worker that
	// can never be authorized exits (non-zero, triggering a restart) and makes
	// the problem visible instead of waiting indefinitely.
	defaultMaxAttempts = 100
)
|
||||
|
||||
// RegisterFunc performs one idempotent registration round-trip.
|
||||
type RegisterFunc func(ctx context.Context) (*RegisterResponse, error)
|
||||
|
||||
// NATSCredentialManager acquires NATS credentials at startup — waiting through
|
||||
// admin approval when required — and refreshes them before the minted JWT
|
||||
// expires, by re-registering (which mints a fresh JWT). The live NATS
|
||||
// connection adopts a refreshed JWT on its next reconnect via Provider. Safe
|
||||
// for concurrent use.
|
||||
//
|
||||
// It addresses two failure modes: a worker that needs credentials but registers
|
||||
// while still pending approval (it would otherwise give up and never connect),
|
||||
// and a long-running worker whose 24h JWT expires with no way to renew it.
|
||||
type NATSCredentialManager struct {
|
||||
register RegisterFunc
|
||||
requireCreds bool // block until credentials are present (frontend minting in use)
|
||||
|
||||
// Tunables; defaults set by NewNATSCredentialManager, overridable in tests.
|
||||
initialBackoff time.Duration
|
||||
maxBackoff time.Duration
|
||||
maxAttempts int // bound on Acquire attempts / consecutive refresh failures (<=0 = unlimited)
|
||||
refreshLead float64 // refresh once this fraction of the JWT lifetime has elapsed
|
||||
refreshRetry time.Duration
|
||||
expiryOf func(jwt string) (time.Time, bool)
|
||||
|
||||
mu sync.RWMutex
|
||||
jwt string
|
||||
seed string
|
||||
nodeID string
|
||||
}
|
||||
|
||||
// NewNATSCredentialManager builds a manager over register. When requireCreds is
|
||||
// true, Acquire blocks until the node is approved and credentials are minted.
|
||||
func NewNATSCredentialManager(register RegisterFunc, requireCreds bool) *NATSCredentialManager {
|
||||
return &NATSCredentialManager{
|
||||
register: register,
|
||||
requireCreds: requireCreds,
|
||||
initialBackoff: 2 * time.Second,
|
||||
maxBackoff: 30 * time.Second,
|
||||
maxAttempts: defaultMaxAttempts,
|
||||
refreshLead: 0.75,
|
||||
refreshRetry: 30 * time.Second,
|
||||
expiryOf: jwtExpiry,
|
||||
}
|
||||
}
|
||||
|
||||
// jwtExpiry decodes the expiry of a minted user JWT. ok is false when the token
|
||||
// is empty/undecodable or carries no expiry (e.g. a non-expiring service JWT).
|
||||
func jwtExpiry(token string) (time.Time, bool) {
|
||||
if token == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
uc, err := natsauth.DecodeUserClaims(token)
|
||||
if err != nil || uc.Expires == 0 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return time.Unix(uc.Expires, 0), true
|
||||
}
|
||||
|
||||
func (m *NATSCredentialManager) store(res *RegisterResponse) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.nodeID = res.ID
|
||||
if res.NatsJWT != "" && res.NatsUserSeed != "" {
|
||||
m.jwt, m.seed = res.NatsJWT, res.NatsUserSeed
|
||||
}
|
||||
}
|
||||
|
||||
// Current returns the latest NATS credentials (both empty until acquired).
|
||||
func (m *NATSCredentialManager) Current() (jwt, seed string) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.jwt, m.seed
|
||||
}
|
||||
|
||||
// NodeID returns the node ID from the most recent registration.
|
||||
func (m *NATSCredentialManager) NodeID() string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.nodeID
|
||||
}
|
||||
|
||||
// Provider returns a callback compatible with messaging.WithUserJWTProvider,
|
||||
// supplying the current credentials on each (re)connect.
|
||||
func (m *NATSCredentialManager) Provider() func() (string, string) {
|
||||
return m.Current
|
||||
}
|
||||
|
||||
// HasCredentials reports whether complete NATS credentials have been obtained.
|
||||
func (m *NATSCredentialManager) HasCredentials() bool {
|
||||
jwt, seed := m.Current()
|
||||
return jwt != "" && seed != ""
|
||||
}
|
||||
|
||||
// Acquire registers and, when requireCreds is set, keeps re-registering with
|
||||
// exponential backoff until the node is approved (status != pending) and
|
||||
// credentials are minted. Without requireCreds it returns the first successful
|
||||
// response (the historical one-shot behavior, preserved for anonymous NATS).
|
||||
func (m *NATSCredentialManager) Acquire(ctx context.Context) (*RegisterResponse, error) {
|
||||
backoff := m.initialBackoff
|
||||
var lastReason error
|
||||
for attempt := 1; m.maxAttempts <= 0 || attempt <= m.maxAttempts; attempt++ {
|
||||
res, err := m.register(ctx)
|
||||
switch {
|
||||
case err != nil:
|
||||
lastReason = err
|
||||
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
|
||||
case !m.requireCreds:
|
||||
m.store(res)
|
||||
return res, nil
|
||||
case res.Status == statusPending:
|
||||
lastReason = fmt.Errorf("node %s still pending admin approval", res.ID)
|
||||
xlog.Info("Node pending admin approval; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
|
||||
case res.NatsJWT == "" || res.NatsUserSeed == "":
|
||||
lastReason = fmt.Errorf("node %s approved but NATS credentials not minted", res.ID)
|
||||
xlog.Info("Node approved but NATS credentials not yet minted; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
|
||||
default:
|
||||
m.store(res)
|
||||
return res, nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
backoff = min(backoff*2, m.maxBackoff)
|
||||
}
|
||||
return nil, fmt.Errorf("giving up acquiring NATS credentials after %d attempts: %w", m.maxAttempts, lastReason)
|
||||
}
|
||||
|
||||
// RefreshLoop re-registers to mint a fresh JWT before the current one expires,
|
||||
// updating the credentials returned by Current/Provider so the NATS connection
|
||||
// adopts them on its next reconnect. It returns nil when ctx is cancelled or
|
||||
// when the current credential has no expiry (nothing to refresh), and a non-nil
|
||||
// error after maxAttempts consecutive refresh failures — letting the caller
|
||||
// exit the worker so it restarts and re-acquires (or surfaces the outage)
|
||||
// rather than silently drifting toward an expired, unrenewable JWT.
|
||||
func (m *NATSCredentialManager) RefreshLoop(ctx context.Context) error {
|
||||
failures := 0
|
||||
for {
|
||||
jwt, _ := m.Current()
|
||||
exp, ok := m.expiryOf(jwt)
|
||||
if !ok {
|
||||
xlog.Debug("NATS credential has no expiry; refresh loop exiting")
|
||||
return nil
|
||||
}
|
||||
wait := max(time.Duration(float64(time.Until(exp))*m.refreshLead), 0)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(wait):
|
||||
}
|
||||
|
||||
res, err := m.register(ctx)
|
||||
if err == nil && res.NatsJWT != "" && res.NatsUserSeed != "" {
|
||||
m.store(res)
|
||||
failures = 0
|
||||
xlog.Info("Refreshed NATS credentials", "node", res.ID)
|
||||
continue
|
||||
}
|
||||
failures++
|
||||
if err != nil {
|
||||
xlog.Warn("NATS credential refresh failed; will retry", "attempt", failures, "error", err)
|
||||
} else {
|
||||
xlog.Warn("NATS credential refresh returned no credentials; will retry", "attempt", failures)
|
||||
}
|
||||
if m.maxAttempts > 0 && failures >= m.maxAttempts {
|
||||
return fmt.Errorf("NATS credential refresh failed %d times in a row", failures)
|
||||
}
|
||||
// Back off before retrying so a persistent failure near expiry does not spin.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(m.refreshRetry):
|
||||
}
|
||||
}
|
||||
}
|
||||
198
core/cli/workerregistry/credentials_test.go
Normal file
198
core/cli/workerregistry/credentials_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package workerregistry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/nats-io/nkeys"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestWorkerRegistry(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "WorkerRegistry")
|
||||
}
|
||||
|
||||
// fakeRegister returns a sequence of canned responses/errors, one per call, and
|
||||
// records how many times it was invoked. The last entry repeats once exhausted.
|
||||
type fakeRegister struct {
|
||||
mu sync.Mutex
|
||||
steps []step
|
||||
calls int
|
||||
}
|
||||
|
||||
type step struct {
|
||||
res *RegisterResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeRegister) fn() RegisterFunc {
|
||||
return func(context.Context) (*RegisterResponse, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
i := f.calls
|
||||
f.calls++
|
||||
if i >= len(f.steps) {
|
||||
i = len(f.steps) - 1
|
||||
}
|
||||
return f.steps[i].res, f.steps[i].err
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeRegister) count() int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.calls
|
||||
}
|
||||
|
||||
var _ = Describe("NATSCredentialManager", func() {
|
||||
approved := func(jwt, seed string) *RegisterResponse {
|
||||
return &RegisterResponse{ID: "node-1", Status: "healthy", NatsJWT: jwt, NatsUserSeed: seed}
|
||||
}
|
||||
pending := &RegisterResponse{ID: "node-1", Status: "pending"}
|
||||
|
||||
Describe("Acquire (#4 — wait through admin approval)", func() {
|
||||
It("keeps re-registering until the node is approved and credentials are minted", func() {
|
||||
f := &fakeRegister{steps: []step{
|
||||
{res: pending}, // not approved yet
|
||||
{res: approved("", "")}, // approved but JWT not minted yet
|
||||
{res: approved("jwt-1", "seed-1")}, // finally minted
|
||||
}}
|
||||
m := NewNATSCredentialManager(f.fn(), true /* requireCreds */)
|
||||
m.initialBackoff = time.Millisecond
|
||||
m.maxBackoff = time.Millisecond
|
||||
|
||||
res, err := m.Acquire(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.ID).To(Equal("node-1"))
|
||||
Expect(f.count()).To(Equal(3))
|
||||
|
||||
jwt, seed := m.Current()
|
||||
Expect(jwt).To(Equal("jwt-1"))
|
||||
Expect(seed).To(Equal("seed-1"))
|
||||
Expect(m.HasCredentials()).To(BeTrue())
|
||||
Expect(m.NodeID()).To(Equal("node-1"))
|
||||
})
|
||||
|
||||
It("returns immediately on the first success when credentials are not required (anonymous NATS)", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}}
|
||||
m := NewNATSCredentialManager(f.fn(), false /* requireCreds */)
|
||||
|
||||
res, err := m.Acquire(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.Status).To(Equal("pending"))
|
||||
Expect(f.count()).To(Equal(1))
|
||||
Expect(m.HasCredentials()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("aborts when the context is cancelled while waiting for approval", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.initialBackoff = 10 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err := m.Acquire(ctx)
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
})
|
||||
|
||||
It("gives up after a bounded number of attempts so the worker exits and alerts", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}} // never approved
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.initialBackoff = time.Millisecond
|
||||
m.maxBackoff = time.Millisecond
|
||||
m.maxAttempts = 5
|
||||
|
||||
_, err := m.Acquire(context.Background())
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("after 5 attempts"))
|
||||
Expect(err.Error()).To(ContainSubstring("pending admin approval"))
|
||||
Expect(f.count()).To(Equal(5))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("RefreshLoop (#5 — renew before the JWT expires)", func() {
|
||||
It("re-registers before expiry and updates the credentials served to new connections", func() {
|
||||
f := &fakeRegister{steps: []step{{res: approved("jwt-2", "seed-2")}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.refreshLead = 0.5
|
||||
m.refreshRetry = time.Millisecond
|
||||
// jwt-1 expires soon; jwt-2 is long-lived so the loop then idles.
|
||||
m.expiryOf = func(jwt string) (time.Time, bool) {
|
||||
switch jwt {
|
||||
case "jwt-1":
|
||||
return time.Now().Add(40 * time.Millisecond), true
|
||||
case "jwt-2":
|
||||
return time.Now().Add(time.Hour), true
|
||||
default:
|
||||
return time.Time{}, false
|
||||
}
|
||||
}
|
||||
m.store(approved("jwt-1", "seed-1"))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go func() { _ = m.RefreshLoop(ctx) }()
|
||||
|
||||
Eventually(func() string {
|
||||
jwt, _ := m.Current()
|
||||
return jwt
|
||||
}, "2s", "10ms").Should(Equal("jwt-2"))
|
||||
})
|
||||
|
||||
It("returns an error after the bounded number of consecutive failures so the caller can exit", func() {
|
||||
f := &fakeRegister{steps: []step{{err: context.DeadlineExceeded}}} // refresh always fails
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.refreshLead = 0.5
|
||||
m.refreshRetry = time.Millisecond
|
||||
m.maxAttempts = 3
|
||||
m.expiryOf = func(string) (time.Time, bool) { return time.Now().Add(time.Millisecond), true }
|
||||
m.store(approved("jwt-1", "seed-1"))
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- m.RefreshLoop(context.Background()) }()
|
||||
Eventually(errCh, "2s").Should(Receive(MatchError(ContainSubstring("3 times in a row"))))
|
||||
})
|
||||
|
||||
It("exits promptly when the current credential has no expiry (nothing to refresh)", func() {
|
||||
f := &fakeRegister{steps: []step{{res: approved("x", "y")}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.expiryOf = func(string) (time.Time, bool) { return time.Time{}, false }
|
||||
m.store(approved("static", "seed"))
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() { _ = m.RefreshLoop(context.Background()); close(done) }()
|
||||
Eventually(done, "1s").Should(BeClosed())
|
||||
Expect(f.count()).To(Equal(0)) // never tried to re-register
|
||||
})
|
||||
})
|
||||
|
||||
Describe("jwtExpiry default", func() {
|
||||
It("decodes the expiry of a real minted worker JWT", func() {
|
||||
akp, err := nkeys.CreateAccount()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
seed, err := akp.Seed()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
cfg := natsauth.Config{AccountSeed: string(seed), WorkerJWTTTL: time.Hour}
|
||||
token, _, err := cfg.MintWorkerJWT("node-1", "backend")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
exp, ok := jwtExpiry(token)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(exp).To(BeTemporally("~", time.Now().Add(time.Hour), 2*time.Minute))
|
||||
})
|
||||
|
||||
It("reports no expiry for an empty or undecodable token", func() {
|
||||
_, ok := jwtExpiry("")
|
||||
Expect(ok).To(BeFalse())
|
||||
_, ok = jwtExpiry("not-a-jwt")
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -22,9 +22,11 @@ const (
|
||||
UsecaseRerank = "rerank"
|
||||
UsecaseDetection = "detection"
|
||||
UsecaseVAD = "vad"
|
||||
UsecaseAudioTransform = "audio_transform"
|
||||
UsecaseDiarization = "diarization"
|
||||
UsecaseRealtimeAudio = "realtime_audio"
|
||||
UsecaseAudioTransform = "audio_transform"
|
||||
UsecaseDiarization = "diarization"
|
||||
UsecaseRealtimeAudio = "realtime_audio"
|
||||
UsecaseFaceRecognition = "face_recognition"
|
||||
UsecaseSpeakerRecognition = "speaker_recognition"
|
||||
)
|
||||
|
||||
// GRPCMethod identifies a Backend service RPC from backend.proto.
|
||||
@@ -47,6 +49,11 @@ const (
|
||||
MethodAudioTransform GRPCMethod = "AudioTransform"
|
||||
MethodDiarize GRPCMethod = "Diarize"
|
||||
MethodAudioToAudioStream GRPCMethod = "AudioToAudioStream"
|
||||
MethodFaceVerify GRPCMethod = "FaceVerify"
|
||||
MethodFaceAnalyze GRPCMethod = "FaceAnalyze"
|
||||
MethodVoiceVerify GRPCMethod = "VoiceVerify"
|
||||
MethodVoiceEmbed GRPCMethod = "VoiceEmbed"
|
||||
MethodVoiceAnalyze GRPCMethod = "VoiceAnalyze"
|
||||
)
|
||||
|
||||
// UsecaseInfo describes a single known_usecase value and how it maps
|
||||
@@ -154,6 +161,16 @@ var UsecaseInfoMap = map[string]UsecaseInfo{
|
||||
GRPCMethod: MethodAudioToAudioStream,
|
||||
Description: "Self-contained any-to-any audio model for the Realtime API — accepts microphone audio and emits speech + transcript (+ optional function calls) from a single backend via the AudioToAudioStream RPC.",
|
||||
},
|
||||
UsecaseFaceRecognition: {
|
||||
Flag: FLAG_FACE_RECOGNITION,
|
||||
GRPCMethod: MethodFaceVerify,
|
||||
Description: "Face recognition — verify identity, analyze attributes (age/gender/emotion) via FaceVerify and FaceAnalyze RPCs.",
|
||||
},
|
||||
UsecaseSpeakerRecognition: {
|
||||
Flag: FLAG_SPEAKER_RECOGNITION,
|
||||
GRPCMethod: MethodVoiceVerify,
|
||||
Description: "Speaker recognition — verify identity, embed and analyze voice via VoiceVerify, VoiceEmbed and VoiceAnalyze RPCs.",
|
||||
},
|
||||
}
|
||||
|
||||
// BackendCapability describes which gRPC methods and usecases a backend supports.
|
||||
@@ -471,6 +488,21 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseDetection},
|
||||
Description: "RF-DETR C++ object detection",
|
||||
},
|
||||
|
||||
// --- Face and speaker recognition backends ---
|
||||
"insightface": {
|
||||
GRPCMethods: []GRPCMethod{MethodEmbedding, MethodDetect, MethodFaceVerify, MethodFaceAnalyze},
|
||||
PossibleUsecases: []string{UsecaseEmbeddings, UsecaseDetection, UsecaseFaceRecognition},
|
||||
DefaultUsecases: []string{UsecaseFaceRecognition},
|
||||
AcceptsImages: true,
|
||||
Description: "InsightFace — face detection, embedding, verification and attribute analysis",
|
||||
},
|
||||
"speaker-recognition": {
|
||||
GRPCMethods: []GRPCMethod{MethodVoiceVerify, MethodVoiceEmbed, MethodVoiceAnalyze},
|
||||
PossibleUsecases: []string{UsecaseSpeakerRecognition},
|
||||
DefaultUsecases: []string{UsecaseSpeakerRecognition},
|
||||
Description: "Speaker recognition — voice identity verification and analysis",
|
||||
},
|
||||
"silero-vad": {
|
||||
GRPCMethods: []GRPCMethod{MethodVAD},
|
||||
PossibleUsecases: []string{UsecaseVAD},
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -18,6 +20,16 @@ type DistributedConfig struct {
|
||||
RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration)
|
||||
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
|
||||
|
||||
// NATS JWT auth (optional; see pkg/natsauth and docs/features/distributed-mode.md)
|
||||
NatsAccountSeed string // LOCALAI_NATS_ACCOUNT_SEED — account signing seed to mint per-node worker JWTs
|
||||
NatsServiceJWT string // LOCALAI_NATS_SERVICE_JWT — user JWT for frontends / agent workers
|
||||
NatsServiceSeed string // LOCALAI_NATS_SERVICE_SEED — signing seed paired with service JWT
|
||||
NatsWorkerJWTTTL time.Duration // LOCALAI_NATS_WORKER_JWT_TTL — minted worker JWT lifetime (default 24h)
|
||||
NatsRequireAuth bool // LOCALAI_NATS_REQUIRE_AUTH — fail startup if NATS credentials are missing
|
||||
NatsTLSCA string // LOCALAI_NATS_TLS_CA — PEM file for private CA (server verify)
|
||||
NatsTLSCert string // LOCALAI_NATS_TLS_CERT — client cert for NATS mTLS
|
||||
NatsTLSKey string // LOCALAI_NATS_TLS_KEY — client key paired with NatsTLSCert
|
||||
|
||||
// S3 configuration (used when StorageURL is set)
|
||||
StorageBucket string // --storage-bucket / LOCALAI_STORAGE_BUCKET
|
||||
StorageRegion string // --storage-region / LOCALAI_STORAGE_REGION
|
||||
@@ -80,6 +92,13 @@ func (c DistributedConfig) Validate() error {
|
||||
if c.RegistrationToken == "" {
|
||||
xlog.Warn("distributed mode running without registration token — node endpoints are unprotected")
|
||||
}
|
||||
if err := c.NatsAuthConfig().Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.NatsTLSFiles().Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.NatsAuthConfig().WarnIfInsecure(true)
|
||||
// Check for negative durations
|
||||
for name, d := range map[string]time.Duration{
|
||||
FlagMCPToolTimeout: c.MCPToolTimeout,
|
||||
@@ -123,6 +142,52 @@ func WithRegistrationToken(token string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsAccountSeed(seed string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsAccountSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsServiceJWT(jwt string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsServiceJWT = jwt
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsServiceSeed(seed string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsServiceSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsWorkerJWTTTL(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsWorkerJWTTTL = d
|
||||
}
|
||||
}
|
||||
|
||||
var EnableNatsRequireAuth = func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsRequireAuth = true
|
||||
}
|
||||
|
||||
func WithNatsTLSCA(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSCA = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsTLSCert(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSCert = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsTLSKey(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSKey = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithStorageURL(url string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.StorageURL = url
|
||||
@@ -217,6 +282,44 @@ const (
|
||||
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
|
||||
const DefaultMaxUploadSize int64 = 50 << 30
|
||||
|
||||
// NatsTLSFiles returns NATS TLS/mTLS PEM paths for the messaging client.
|
||||
func (c DistributedConfig) NatsTLSFiles() messaging.TLSFiles {
|
||||
return messaging.TLSFiles{
|
||||
CA: c.NatsTLSCA,
|
||||
Cert: c.NatsTLSCert,
|
||||
Key: c.NatsTLSKey,
|
||||
}
|
||||
}
|
||||
|
||||
// NatsMessagingOptions builds messaging client options (JWT + TLS) for distributed components.
|
||||
// Pass explicit userJWT/userSeed when set (e.g. worker overrides); empty uses service JWT from config.
|
||||
func (c DistributedConfig) NatsMessagingOptions(userJWT, userSeed string) []messaging.Option {
|
||||
var opts []messaging.Option
|
||||
jwt, seed := userJWT, userSeed
|
||||
if jwt == "" && seed == "" {
|
||||
auth := c.NatsAuthConfig()
|
||||
jwt, seed = auth.ServiceUserJWT, auth.ServiceUserSeed
|
||||
}
|
||||
if jwt != "" && seed != "" {
|
||||
opts = append(opts, messaging.WithUserJWT(jwt, seed))
|
||||
}
|
||||
if tls := c.NatsTLSFiles(); tls.Enabled() {
|
||||
opts = append(opts, messaging.WithTLS(tls))
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// NatsAuthConfig builds pkg/natsauth settings from distributed configuration.
|
||||
func (c DistributedConfig) NatsAuthConfig() natsauth.Config {
|
||||
return natsauth.Config{
|
||||
AccountSeed: c.NatsAccountSeed,
|
||||
ServiceUserJWT: c.NatsServiceJWT,
|
||||
ServiceUserSeed: c.NatsServiceSeed,
|
||||
WorkerJWTTTL: c.NatsWorkerJWTTTL,
|
||||
RequireAuth: c.NatsRequireAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// BackendInstallTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout)
|
||||
|
||||
@@ -420,8 +420,9 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
remoteUnloader = d.Router.Unloader()
|
||||
}
|
||||
}
|
||||
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
|
||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
|
||||
natsCfg := distCfg.NatsAuthConfig()
|
||||
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, natsCfg)
|
||||
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken, natsCfg)
|
||||
|
||||
// Distributed SSE routes (job progress + agent events via NATS)
|
||||
if d := application.Distributed(); d != nil {
|
||||
|
||||
@@ -37,7 +37,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
|
||||
xlog.Debug("elevenlabs TTS request received", "modelName", input.ModelID)
|
||||
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, "", nil, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
)
|
||||
|
||||
// nodeError builds a schema.ErrorResponse for node endpoints.
|
||||
@@ -89,7 +90,7 @@ type RegisterNodeRequest struct {
|
||||
// RegisterNodeEndpoint registers a new backend node.
|
||||
// expectedToken is the registration token configured on the frontend (may be empty to disable auth).
|
||||
// autoApprove controls whether new nodes go directly to "healthy" or require admin approval.
|
||||
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
|
||||
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req RegisterNodeRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
@@ -217,13 +218,15 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au
|
||||
}
|
||||
}
|
||||
|
||||
attachNatsJWT(response, node, natsCfg)
|
||||
|
||||
return c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
}
|
||||
|
||||
// ApproveNodeEndpoint approves a pending node, setting its status to healthy.
|
||||
// For agent workers, it also provisions an API key so they can call the inference API.
|
||||
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
|
||||
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
id := c.Param("id")
|
||||
@@ -253,10 +256,26 @@ func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecr
|
||||
}
|
||||
}
|
||||
|
||||
attachNatsJWT(response, node, natsCfg)
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
}
|
||||
|
||||
// attachNatsJWT adds a per-node NATS user JWT to a register/approve response when minting is enabled.
|
||||
func attachNatsJWT(response map[string]any, node *nodes.BackendNode, natsCfg natsauth.Config) {
|
||||
if !natsCfg.CanMintWorkers() || node == nil || node.Status == nodes.StatusPending {
|
||||
return
|
||||
}
|
||||
jwt, seed, err := natsCfg.MintWorkerJWT(node.ID, node.NodeType)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to mint NATS JWT for node", "node", node.Name, "id", node.ID, "error", err)
|
||||
return
|
||||
}
|
||||
response["nats_jwt"] = jwt
|
||||
response["nats_user_seed"] = seed
|
||||
}
|
||||
|
||||
// provisionAgentWorkerKey creates a dedicated user and API key for an agent worker node.
|
||||
// Returns the plaintext API key on success.
|
||||
func provisionAgentWorkerKey(ctx context.Context, authDB *gorm.DB, registry *nodes.NodeRegistry, node *nodes.BackendNode, hmacSecret string) (string, error) {
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/nats-io/nkeys"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -63,7 +65,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
@@ -74,6 +76,29 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
Expect(resp["status"]).To(Equal(nodes.StatusHealthy))
|
||||
})
|
||||
|
||||
It("returns nats_jwt when account seed is configured", func() {
|
||||
akp, err := nkeys.CreateAccount()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
seed, err := akp.Seed()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
e := echo.New()
|
||||
body := `{"name":"worker-nats","address":"10.0.0.2:50051"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
natsCfg := natsauth.Config{AccountSeed: string(seed)}
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsCfg)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp["nats_jwt"]).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns 400 when name is missing", func() {
|
||||
e := echo.New()
|
||||
body := `{"address":"10.0.0.1:50051"}`
|
||||
@@ -82,7 +107,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -102,7 +127,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -121,7 +146,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -140,7 +165,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -159,7 +184,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
|
||||
})
|
||||
@@ -172,7 +197,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := RegisterNodeEndpoint(registry, "", false, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", false, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
@@ -195,7 +220,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body1))
|
||||
req1.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(e.NewContext(req1, rec1))).To(Succeed())
|
||||
Expect(rec1.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Stream audio chunks as they're generated
|
||||
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
|
||||
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, input.Instructions, input.Params, ml, appConfig, *cfg, func(audioChunk []byte) error {
|
||||
_, writeErr := c.Response().Write(audioChunk)
|
||||
if writeErr != nil {
|
||||
return writeErr
|
||||
@@ -75,7 +75,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
}
|
||||
|
||||
// Non-streaming TTS (existing behavior)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, input.Instructions, input.Params, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -313,7 +313,7 @@ func newRealtimeDecisionID() string {
|
||||
}
|
||||
|
||||
func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) {
|
||||
return backend.ModelTTS(ctx, text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig)
|
||||
return backend.ModelTTS(ctx, text, voice, language, "", nil, m.modelLoader, m.appConfig, *m.TTSConfig)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) PredictConfig() *config.ModelConfig {
|
||||
|
||||
@@ -152,15 +152,9 @@ func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIR
|
||||
|
||||
// If a model name was specified, verify it actually exists before proceeding.
|
||||
// Check both configured models and loose model files in the model path.
|
||||
// Skip the check only for HuggingFace-style model IDs ("org/repo") that
|
||||
// backends like diffusers may download on the fly. A name that points at a
|
||||
// concrete weight file (e.g. "local/model.gguf") is NOT such an ID: it must
|
||||
// still be verified, otherwise a wrong name silently falls through to the
|
||||
// gallery autoloader and triggers a surprising download (issue #10162).
|
||||
// CheckIfModelExists resolves relative paths against the models dir, so a
|
||||
// loose weight file addressed by path still passes.
|
||||
isRemoteModelID := strings.Contains(modelName, "/") && !model.HasKnownModelFileExtension(modelName)
|
||||
if modelName != "" && !isRemoteModelID {
|
||||
// Skip the check for HuggingFace model IDs (contain "/") since backends
|
||||
// like diffusers may download these on the fly.
|
||||
if modelName != "" && !strings.Contains(modelName, "/") {
|
||||
exists, existsErr := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, modelName, galleryop.ALWAYS_INCLUDE)
|
||||
if existsErr == nil && !exists {
|
||||
return c.JSON(http.StatusNotFound, schema.ErrorResponse{
|
||||
|
||||
@@ -140,40 +140,6 @@ var _ = Describe("SetModelAndConfig middleware", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("when the model name is a file path to a weight that does not exist", func() {
|
||||
// A name like "local/model.gguf" is the parameters.model weight path, not a
|
||||
// HuggingFace org/repo ID. The slash must not exempt it from the existence
|
||||
// check, otherwise a wrong name silently falls through to the gallery
|
||||
// autoloader and triggers a surprising download (issue #10162).
|
||||
It("returns 404 instead of passing through", func() {
|
||||
rec := postJSON(app, "/v1/chat/completions",
|
||||
`{"model":"local/missing-model.gguf","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusNotFound))
|
||||
|
||||
var resp schema.ErrorResponse
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp.Error).ToNot(BeNil())
|
||||
Expect(resp.Error.Message).To(ContainSubstring("local/missing-model.gguf"))
|
||||
Expect(resp.Error.Message).To(ContainSubstring("not found"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when the model name is a file path to a weight that exists on disk", func() {
|
||||
// The same path, but the loose weight file is actually present in a
|
||||
// subdirectory of the models path: the request must pass through so users
|
||||
// can address a raw weight file by its relative path.
|
||||
It("passes through to the handler", func() {
|
||||
Expect(os.MkdirAll(filepath.Join(modelDir, "local"), 0755)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "local", "present-model.gguf"), []byte("weights"), 0644)).To(Succeed())
|
||||
|
||||
rec := postJSON(app, "/v1/chat/completions",
|
||||
`{"model":"local/present-model.gguf","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when no model is specified", func() {
|
||||
It("passes through without checking", func() {
|
||||
rec := postJSON(app, "/v1/chat/completions",
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -35,7 +36,7 @@ func nodeReadyMiddleware(registry *nodes.NodeRegistry) echo.MiddlewareFunc {
|
||||
// token but do not verify per-node identity. A compromised worker can heartbeat/drain/
|
||||
// deregister other nodes. Future: issue per-node JWT at registration, validate node
|
||||
// identity on subsequent requests (compare :id param with token subject).
|
||||
func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, registrationToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) {
|
||||
func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, registrationToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) {
|
||||
if registry == nil {
|
||||
return
|
||||
}
|
||||
@@ -44,7 +45,7 @@ func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, r
|
||||
tokenAuthMw := nodeTokenAuth(registrationToken)
|
||||
|
||||
node := e.Group("/api/node", readyMw, tokenAuthMw)
|
||||
node.POST("/register", localai.RegisterNodeEndpoint(registry, registrationToken, autoApprove, authDB, hmacSecret))
|
||||
node.POST("/register", localai.RegisterNodeEndpoint(registry, registrationToken, autoApprove, authDB, hmacSecret, natsCfg))
|
||||
node.POST("/:id/heartbeat", localai.HeartbeatEndpoint(registry))
|
||||
node.POST("/:id/drain", localai.DrainNodeEndpoint(registry))
|
||||
node.POST("/:id/resume", localai.ResumeNodeEndpoint(registry))
|
||||
@@ -60,7 +61,7 @@ func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, r
|
||||
// backend install path (POST /:id/backends/install). That handler enqueues a
|
||||
// ManagementOp on the gallery channel rather than blocking on a NATS reply, so
|
||||
// the browser gets HTTP 202 + jobID immediately instead of waiting up to 3 minutes.
|
||||
func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string) {
|
||||
func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string, natsCfg natsauth.Config) {
|
||||
if registry == nil {
|
||||
return
|
||||
}
|
||||
@@ -81,7 +82,7 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
|
||||
admin.DELETE("/:id", localai.DeregisterNodeEndpoint(registry))
|
||||
admin.POST("/:id/drain", localai.DrainNodeEndpoint(registry))
|
||||
admin.POST("/:id/resume", localai.ResumeNodeEndpoint(registry))
|
||||
admin.POST("/:id/approve", localai.ApproveNodeEndpoint(registry, authDB, hmacSecret))
|
||||
admin.POST("/:id/approve", localai.ApproveNodeEndpoint(registry, authDB, hmacSecret, natsCfg))
|
||||
|
||||
// Backend management on workers
|
||||
admin.GET("/:id/backends", localai.ListBackendsOnNodeEndpoint(unloader))
|
||||
|
||||
@@ -60,6 +60,14 @@ type TTSRequest struct {
|
||||
Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
|
||||
Stream bool `json:"stream,omitempty" yaml:"stream,omitempty"` // (optional) enable streaming TTS
|
||||
SampleRate int `json:"sample_rate,omitempty" yaml:"sample_rate,omitempty"` // (optional) desired output sample rate
|
||||
// Instructions is a free-form, per-request style/voice description. It maps to
|
||||
// the OpenAI `instructions` field and is forwarded to the backend so expressive
|
||||
// TTS models (e.g. Qwen3-TTS CustomVoice/VoiceDesign) can vary tone or designed
|
||||
// voice per request instead of only via the static YAML option.
|
||||
Instructions string `json:"instructions,omitempty" yaml:"instructions,omitempty"`
|
||||
// Params carries optional, backend-specific per-request generation parameters
|
||||
// (LocalAI extension, e.g. Chatterbox exaggeration/cfg_weight/temperature).
|
||||
Params map[string]string `json:"params,omitempty" yaml:"params,omitempty"`
|
||||
}
|
||||
|
||||
// @Description VAD request body
|
||||
|
||||
@@ -2,15 +2,22 @@ package messaging
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nkeys"
|
||||
)
|
||||
|
||||
// subscribeConfirmTimeout bounds the server round-trip used to detect whether a
|
||||
// subscription was rejected (e.g. by JWT permissions) before returning to the caller.
|
||||
const subscribeConfirmTimeout = 5 * time.Second
|
||||
|
||||
// Client wraps a NATS connection and provides helpers for pub/sub and queue subscriptions.
|
||||
type Client struct {
|
||||
conn *nats.Conn
|
||||
@@ -18,8 +25,13 @@ type Client struct {
|
||||
}
|
||||
|
||||
// New creates a new NATS client with auto-reconnect.
|
||||
func New(url string) (*Client, error) {
|
||||
nc, err := nats.Connect(url,
|
||||
func New(url string, opts ...Option) (*Client, error) {
|
||||
var cfg connectConfig
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
|
||||
natsOpts := []nats.Option{
|
||||
nats.RetryOnFailedConnect(true),
|
||||
nats.MaxReconnects(-1),
|
||||
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
|
||||
@@ -33,7 +45,60 @@ func New(url string) (*Client, error) {
|
||||
nats.ClosedHandler(func(_ *nats.Conn) {
|
||||
xlog.Info("NATS connection closed")
|
||||
}),
|
||||
)
|
||||
// Surface async errors (notably permission violations) that NATS would
|
||||
// otherwise deliver silently. A subscription the server rejects for a
|
||||
// JWT permission means the worker never receives those messages, so make
|
||||
// it loud rather than letting the feature fail invisibly.
|
||||
nats.ErrorHandler(func(_ *nats.Conn, sub *nats.Subscription, err error) {
|
||||
subject := ""
|
||||
if sub != nil {
|
||||
subject = sub.Subject
|
||||
}
|
||||
if errors.Is(err, nats.ErrPermissionViolation) {
|
||||
xlog.Error("NATS permission violation — check JWT pub/sub allow lists", "subject", subject, "error", err)
|
||||
return
|
||||
}
|
||||
xlog.Warn("NATS async error", "subject", subject, "error", err)
|
||||
}),
|
||||
}
|
||||
switch {
|
||||
case cfg.jwtProvider != nil:
|
||||
// Fetch creds on every (re)connect so a refresh loop can rotate the JWT
|
||||
// before expiry; the server expiring the old JWT triggers a reconnect
|
||||
// that transparently picks up the new one.
|
||||
natsOpts = append(natsOpts, nats.UserJWT(
|
||||
func() (string, error) {
|
||||
jwt, _ := cfg.jwtProvider()
|
||||
if jwt == "" {
|
||||
return "", fmt.Errorf("no NATS user JWT available")
|
||||
}
|
||||
return jwt, nil
|
||||
},
|
||||
func(nonce []byte) ([]byte, error) {
|
||||
_, seed := cfg.jwtProvider()
|
||||
kp, err := nkeys.FromSeed([]byte(seed))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading NATS user seed: %w", err)
|
||||
}
|
||||
defer kp.Wipe()
|
||||
return kp.Sign(nonce)
|
||||
},
|
||||
))
|
||||
case cfg.userJWT != "" && cfg.userSeed != "":
|
||||
natsOpts = append(natsOpts, nats.UserJWTAndSeed(cfg.userJWT, cfg.userSeed))
|
||||
}
|
||||
if cfg.tls.Enabled() {
|
||||
if err := cfg.tls.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsOpts, err := cfg.tls.natsOptions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
natsOpts = append(natsOpts, tlsOpts...)
|
||||
}
|
||||
|
||||
nc, err := nats.Connect(url, natsOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connecting to NATS at %s: %w", sanitize.URL(url), err)
|
||||
}
|
||||
@@ -54,23 +119,67 @@ func (c *Client) Publish(subject string, data any) error {
|
||||
|
||||
// Subscribe creates a subscription on the given subject. All subscribers receive every message.
|
||||
func (c *Client) Subscribe(subject string, handler func([]byte)) (Subscription, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.conn.Subscribe(subject, func(msg *nats.Msg) {
|
||||
handler(msg.Data)
|
||||
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
|
||||
return conn.Subscribe(subject, func(msg *nats.Msg) {
|
||||
handler(msg.Data)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// QueueSubscribe creates a queue subscription. Within the same queue group,
|
||||
// only one subscriber receives each message (load-balanced).
|
||||
func (c *Client) QueueSubscribe(subject, queue string, handler func([]byte)) (Subscription, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
|
||||
handler(msg.Data)
|
||||
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
|
||||
return conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
|
||||
handler(msg.Data)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// confirmSubscription creates a subscription via mk and forces a server
|
||||
// round-trip so that a permissions violation — which NATS otherwise reports
|
||||
// only asynchronously — is returned to the caller synchronously. The server
|
||||
// emits the "-ERR Permissions Violation" for a rejected SUB before the PONG
|
||||
// that satisfies the flush, so by the time FlushTimeout returns the violation
|
||||
// is recorded as the connection's last error. Without this, a worker whose JWT
|
||||
// lacks a subject gets a non-nil subscription that never receives a message,
|
||||
// turning a permission misconfiguration into a silent failure.
|
||||
func (c *Client) confirmSubscription(subject string, mk func(*nats.Conn) (*nats.Subscription, error)) (Subscription, error) {
|
||||
c.mu.RLock()
|
||||
conn := c.conn
|
||||
c.mu.RUnlock()
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("subscribe to %s: nil NATS connection", subject)
|
||||
}
|
||||
|
||||
sub, err := mk(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// A failed flush here means we could not round-trip to the server (not yet
|
||||
// connected, reconnecting, slow link). RetryOnFailedConnect intentionally
|
||||
// buffers subscriptions across that gap, so do NOT fail — keep the
|
||||
// subscription and let it replay on (re)connect; a later permission
|
||||
// violation is still logged by the async error handler in New.
|
||||
if err := conn.FlushTimeout(subscribeConfirmTimeout); err != nil {
|
||||
xlog.Debug("Could not confirm NATS subscription (will replay on connect)", "subject", subject, "error", err)
|
||||
return sub, nil
|
||||
}
|
||||
// Flush succeeded, so any permission violation for this SUB has already been
|
||||
// recorded as the connection's last error (the server emits it before the
|
||||
// PONG). LastError is per-connection; match the exact quoted subject the
|
||||
// server echoes ("Subscription to \"<subject>\"") so a stale violation for
|
||||
// another subject can't be mis-attributed here.
|
||||
if lerr := conn.LastError(); lerr != nil &&
|
||||
errors.Is(lerr, nats.ErrPermissionViolation) &&
|
||||
strings.Contains(lerr.Error(), `Subscription to "`+subject+`"`) {
|
||||
_ = sub.Unsubscribe()
|
||||
return nil, fmt.Errorf("subscription to %s denied by NATS server (check JWT sub allow list): %w", subject, lerr)
|
||||
}
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// Request sends a request and waits for a reply (request-reply pattern).
|
||||
// Returns the raw reply data.
|
||||
func (c *Client) Request(subject string, data []byte, timeout time.Duration) ([]byte, error) {
|
||||
@@ -86,15 +195,15 @@ func (c *Client) Request(subject string, data []byte, timeout time.Duration) ([]
|
||||
// SubscribeReply creates a subscription that supports replying to requests.
|
||||
// The handler receives the raw request data and the reply subject.
|
||||
func (c *Client) SubscribeReply(subject string, handler func(data []byte, reply func([]byte))) (Subscription, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.conn.Subscribe(subject, func(msg *nats.Msg) {
|
||||
handler(msg.Data, func(replyData []byte) {
|
||||
if msg.Reply != "" {
|
||||
if err := msg.Respond(replyData); err != nil {
|
||||
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
|
||||
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
|
||||
return conn.Subscribe(subject, func(msg *nats.Msg) {
|
||||
handler(msg.Data, func(replyData []byte) {
|
||||
if msg.Reply != "" {
|
||||
if err := msg.Respond(replyData); err != nil {
|
||||
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -102,15 +211,15 @@ func (c *Client) SubscribeReply(subject string, handler func(data []byte, reply
|
||||
// QueueSubscribeReply creates a queue subscription that supports replying to requests.
|
||||
// Load-balanced across subscribers in the same queue group, with request-reply support.
|
||||
func (c *Client) QueueSubscribeReply(subject, queue string, handler func(data []byte, reply func([]byte))) (Subscription, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
|
||||
handler(msg.Data, func(replyData []byte) {
|
||||
if msg.Reply != "" {
|
||||
if err := msg.Respond(replyData); err != nil {
|
||||
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
|
||||
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
|
||||
return conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
|
||||
handler(msg.Data, func(replyData []byte) {
|
||||
if msg.Reply != "" {
|
||||
if err := msg.Respond(replyData); err != nil {
|
||||
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
34
core/services/messaging/options.go
Normal file
34
core/services/messaging/options.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package messaging
|
||||
|
||||
// Option configures NATS client connection behavior.
|
||||
type Option func(*connectConfig)
|
||||
|
||||
// CredentialProvider returns the NATS user JWT and signing seed to use for the
|
||||
// next (re)connect. It is consulted on every connection attempt, so a refresh
|
||||
// loop can rotate credentials before they expire and the connection picks them
|
||||
// up automatically when the server expires the old JWT and triggers a reconnect.
|
||||
type CredentialProvider func() (jwt, seed string)
|
||||
|
||||
type connectConfig struct {
|
||||
userJWT string
|
||||
userSeed string
|
||||
jwtProvider CredentialProvider
|
||||
tls TLSFiles
|
||||
}
|
||||
|
||||
// WithUserJWT connects using a static NATS user JWT and signing seed (UserJWTAndSeed).
|
||||
func WithUserJWT(jwt, seed string) Option {
|
||||
return func(c *connectConfig) {
|
||||
c.userJWT = jwt
|
||||
c.userSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
// WithUserJWTProvider connects using credentials fetched from provider on each
|
||||
// (re)connect, enabling JWT rotation without dropping the client. Takes
|
||||
// precedence over WithUserJWT when both are set.
|
||||
func WithUserJWTProvider(provider CredentialProvider) Option {
|
||||
return func(c *connectConfig) {
|
||||
c.jwtProvider = provider
|
||||
}
|
||||
}
|
||||
68
core/services/messaging/tls.go
Normal file
68
core/services/messaging/tls.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package messaging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
// TLSFiles holds PEM paths for NATS TLS / mTLS. Cert and key must be set together.
// Use tls:// in LOCALAI_NATS_URL; CA and client cert paths are optional extras.
type TLSFiles struct {
	CA   string // LOCALAI_NATS_TLS_CA — private CA for server verification
	Cert string // LOCALAI_NATS_TLS_CERT — client certificate (mTLS)
	Key  string // LOCALAI_NATS_TLS_KEY — client private key
}

// Enabled reports whether any TLS file path is configured.
func (f TLSFiles) Enabled() bool {
	return f.CA != "" || f.Cert != "" || f.Key != ""
}

// Validate checks path pairing and that every configured path refers to an
// existing file.
//
// It returns an error when only one of Cert/Key is set, when a configured path
// cannot be stat'ed, or when a configured path is a directory. Unset (empty)
// paths are skipped, so the zero value validates successfully. Each error is
// prefixed with the environment variable name of the offending setting.
func (f TLSFiles) Validate() error {
	if f.Cert != "" && f.Key == "" {
		return fmt.Errorf("LOCALAI_NATS_TLS_KEY is required when LOCALAI_NATS_TLS_CERT is set")
	}
	if f.Key != "" && f.Cert == "" {
		return fmt.Errorf("LOCALAI_NATS_TLS_CERT is required when LOCALAI_NATS_TLS_KEY is set")
	}
	for _, entry := range []struct {
		name, path string
	}{
		{"LOCALAI_NATS_TLS_CA", f.CA},
		{"LOCALAI_NATS_TLS_CERT", f.Cert},
		{"LOCALAI_NATS_TLS_KEY", f.Key},
	} {
		if entry.path == "" {
			continue
		}
		info, err := os.Stat(entry.path)
		if err != nil {
			return fmt.Errorf("%s: %w", entry.name, err)
		}
		// A directory satisfies os.Stat but can never be read as a PEM file;
		// reject it here instead of failing later when the path is loaded.
		if info.IsDir() {
			return fmt.Errorf("%s: %s is a directory, not a PEM file", entry.name, entry.path)
		}
	}
	return nil
}
|
||||
|
||||
// natsOptions builds nats-go TLS options. Call Validate first.
|
||||
func (f TLSFiles) natsOptions() ([]nats.Option, error) {
|
||||
if !f.Enabled() {
|
||||
return nil, nil
|
||||
}
|
||||
opts := []nats.Option{nats.Secure()}
|
||||
if f.CA != "" {
|
||||
opts = append(opts, nats.RootCAs(f.CA))
|
||||
}
|
||||
if f.Cert != "" {
|
||||
opts = append(opts, nats.ClientCert(f.Cert, f.Key))
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// WithTLS configures CA and/or client certificate paths for the NATS connection.
|
||||
func WithTLS(files TLSFiles) Option {
|
||||
return func(c *connectConfig) {
|
||||
c.tls = files
|
||||
}
|
||||
}
|
||||
25
core/services/messaging/tls_test.go
Normal file
25
core/services/messaging/tls_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package messaging_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("TLSFiles", func() {
|
||||
It("requires cert and key together", func() {
|
||||
Expect((messaging.TLSFiles{Cert: "/tmp/c.pem"}).Validate()).To(HaveOccurred())
|
||||
Expect((messaging.TLSFiles{Key: "/tmp/k.pem"}).Validate()).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("validates files exist", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
ca := filepath.Join(dir, "ca.pem")
|
||||
Expect(os.WriteFile(ca, []byte("x"), 0600)).To(Succeed())
|
||||
Expect((messaging.TLSFiles{CA: ca}).Validate()).To(Succeed())
|
||||
})
|
||||
})
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
ggrpc "google.golang.org/grpc"
|
||||
@@ -64,64 +65,95 @@ func (c *InFlightTrackingClient) track(ctx context.Context) func() {
|
||||
}
|
||||
}
|
||||
|
||||
// reconcile self-heals stale routing: when a backend reports that the model is
|
||||
// no longer loaded (the process survived but the model was evicted, while the
|
||||
// registry still lists it as loaded), it drops the replica row so the next
|
||||
// request triggers a fresh load instead of routing back here. Without this the
|
||||
// model stays unreachable until the controller restarts. The original error is
|
||||
// returned unchanged.
|
||||
func (c *InFlightTrackingClient) reconcile(err error) error {
|
||||
if !grpcerrors.IsModelNotLoaded(err) {
|
||||
return err
|
||||
}
|
||||
rmCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if rmErr := c.registry.RemoveNodeModel(rmCtx, c.nodeID, c.modelName, c.replicaIndex); rmErr != nil {
|
||||
xlog.Warn("Failed to drop stale replica after model-not-loaded",
|
||||
"node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex, "error", rmErr)
|
||||
} else {
|
||||
xlog.Warn("Backend reports model not loaded; dropped stale replica so the next request reloads",
|
||||
"node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// --- Tracked inference methods ---
|
||||
|
||||
func (c *InFlightTrackingClient) Predict(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.Reply, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.Predict(ctx, in, opts...)
|
||||
reply, err := c.Backend.Predict(ctx, in, opts...)
|
||||
return reply, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.PredictStream(ctx, in, f, opts...)
|
||||
return c.reconcile(c.Backend.PredictStream(ctx, in, f, opts...))
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.Embeddings(ctx, in, opts...)
|
||||
res, err := c.Backend.Embeddings(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.GenerateImage(ctx, in, opts...)
|
||||
res, err := c.Backend.GenerateImage(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.GenerateVideo(ctx, in, opts...)
|
||||
res, err := c.Backend.GenerateVideo(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) TTS(ctx context.Context, in *pb.TTSRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.TTS(ctx, in, opts...)
|
||||
res, err := c.Backend.TTS(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.TTSStream(ctx, in, f, opts...)
|
||||
return c.reconcile(c.Backend.TTSStream(ctx, in, f, opts...))
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.SoundGeneration(ctx, in, opts...)
|
||||
res, err := c.Backend.SoundGeneration(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.AudioTranscription(ctx, in, opts...)
|
||||
res, err := c.Backend.AudioTranscription(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.AudioTranscriptionStream(ctx, in, f, opts...)
|
||||
return c.reconcile(c.Backend.AudioTranscriptionStream(ctx, in, f, opts...))
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.Detect(ctx, in, opts...)
|
||||
res, err := c.Backend.Detect(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...ggrpc.CallOption) (*pb.RerankResult, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.Rerank(ctx, in, opts...)
|
||||
res, err := c.Backend.Rerank(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
@@ -20,9 +20,17 @@ type fakeInFlightTracker struct {
|
||||
mu sync.Mutex
|
||||
increments int
|
||||
decrements int
|
||||
removed int
|
||||
incrementErr error
|
||||
}
|
||||
|
||||
func (f *fakeInFlightTracker) RemoveNodeModel(_ context.Context, _, _ string, _ int) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.removed++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeInFlightTracker) IncrementInFlight(_ context.Context, _, _ string, _ int) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
@@ -295,4 +303,33 @@ var _ = Describe("InFlightTrackingClient", func() {
|
||||
Expect(tracker.decrements).To(Equal(1))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("stale model reload (self-heal)", func() {
|
||||
It("removes the replica when the backend reports the model is not loaded", func() {
|
||||
backend.predictErr = fmt.Errorf("parakeet-cpp: model not loaded")
|
||||
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(tracker.removed).To(Equal(1))
|
||||
})
|
||||
|
||||
It("keeps the replica on an unrelated error", func() {
|
||||
backend.predictErr = fmt.Errorf("context deadline exceeded")
|
||||
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(tracker.removed).To(Equal(0))
|
||||
})
|
||||
|
||||
It("does not remove on success", func() {
|
||||
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(tracker.removed).To(Equal(0))
|
||||
})
|
||||
|
||||
It("self-heals on a streamed call too", func() {
|
||||
backend.streamErr = fmt.Errorf("whisper: model not loaded")
|
||||
err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(*pb.Reply) {})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(tracker.removed).To(Equal(1))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -78,6 +78,9 @@ type ModelLookup interface {
|
||||
type InFlightTracker interface {
|
||||
IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
|
||||
DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
|
||||
// RemoveNodeModel drops a stale replica row so the next request reloads the
|
||||
// model instead of routing back to a node where it is no longer loaded.
|
||||
RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error
|
||||
}
|
||||
|
||||
// NodeManager is used by HTTP endpoints for node registration and lifecycle.
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -932,13 +933,12 @@ func (r *SmartRouter) stageModelFiles(ctx context.Context, node *BackendNode, op
|
||||
{"AudioPath", &opts.AudioPath},
|
||||
}
|
||||
|
||||
// Count stageable files for progress tracking
|
||||
// Count stageable files for progress tracking. Directory models expand to
|
||||
// the number of files they contain, matching what stageDirectory uploads.
|
||||
totalFiles := 0
|
||||
for _, f := range fields {
|
||||
if *f.val != "" {
|
||||
if _, err := os.Stat(*f.val); err == nil {
|
||||
totalFiles++
|
||||
}
|
||||
totalFiles += countStageableFiles(*f.val)
|
||||
}
|
||||
}
|
||||
for _, adapter := range opts.LoraAdapters {
|
||||
@@ -969,8 +969,33 @@ func (r *SmartRouter) stageModelFiles(ctx context.Context, node *BackendNode, op
|
||||
*f.val = ""
|
||||
continue
|
||||
}
|
||||
fileIdx++
|
||||
localPath := *f.val
|
||||
|
||||
// Directory models (e.g. qwen3-tts-cpp ships its weights and tokenizer
|
||||
// ggufs under one directory) can't be uploaded as a single file — the
|
||||
// stager would open the directory and read its fd, failing with
|
||||
// "is a directory" (EISDIR). Expand the directory and stage each
|
||||
// contained file, then rewrite the field to the remote directory.
|
||||
if fi, statErr := os.Stat(localPath); statErr == nil && fi.IsDir() {
|
||||
remoteDir, dirErr := r.stageDirectory(ctx, node, trackingKey, localPath, keyMapper, &fileIdx, totalFiles)
|
||||
if dirErr != nil {
|
||||
if f.name == "ModelFile" {
|
||||
xlog.Error("Failed to stage model directory for remote node", "node", node.Name, "field", f.name, "path", localPath, "error", dirErr)
|
||||
return nil, fmt.Errorf("staging model file: %w", dirErr)
|
||||
}
|
||||
xlog.Warn("Failed to stage model directory, clearing field", "field", f.name, "path", localPath, "error", dirErr)
|
||||
*f.val = ""
|
||||
continue
|
||||
}
|
||||
*f.val = remoteDir
|
||||
if f.name == "ModelFile" && opts.Model != "" {
|
||||
opts.ModelPath = DeriveRemoteModelPath(remoteDir, opts.Model)
|
||||
xlog.Debug("Derived remote ModelPath", "modelPath", opts.ModelPath)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
fileIdx++
|
||||
key := keyMapper.Key(localPath)
|
||||
|
||||
// Attach progress callback to context for byte-level tracking
|
||||
@@ -1074,6 +1099,77 @@ func (r *SmartRouter) withStagingCallback(ctx context.Context, trackingKey, file
|
||||
})
|
||||
}
|
||||
|
||||
// countStageableFiles reports how many files a model path expands to when it
// is staged: 0 when the path cannot be stat'ed, 1 when it is not a directory,
// and otherwise the number of non-directory entries found by walking the
// directory tree. Errors hit while walking are ignored, so entries that cannot
// be visited are simply not counted.
func countStageableFiles(path string) int {
	info, statErr := os.Stat(path)
	switch {
	case statErr != nil:
		return 0
	case !info.IsDir():
		return 1
	}

	var count int
	visit := func(_ string, entry fs.DirEntry, walkErr error) error {
		// Always return nil so one unreadable entry never aborts the walk.
		if walkErr == nil && !entry.IsDir() {
			count++
		}
		return nil
	}
	_ = filepath.WalkDir(path, visit) // visit never returns an error
	return count
}
|
||||
|
||||
// stageDirectory stages every file under a directory-based model (e.g.
|
||||
// qwen3-tts-cpp, whose weights and tokenizer ggufs live in one directory).
|
||||
// Each file is uploaded individually with a structure-preserving key; the
|
||||
// returned path is the remote directory that contained them, suitable for the
|
||||
// backend's ModelFile/ModelPath. fileIdx is advanced per staged file so the
|
||||
// staging progress tracker stays accurate.
|
||||
func (r *SmartRouter) stageDirectory(ctx context.Context, node *BackendNode, trackingKey, dir string, keyMapper *StagingKeyMapper, fileIdx *int, totalFiles int) (string, error) {
|
||||
var remoteDir string
|
||||
err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
*fileIdx++
|
||||
fileName := filepath.Base(path)
|
||||
stageCtx := r.withStagingCallback(ctx, trackingKey, fileName, *fileIdx, totalFiles)
|
||||
xlog.Info("Staging file", "model", trackingKey, "node", node.Name, "field", "ModelDir", "file", fileName, "fileIndex", *fileIdx, "totalFiles", totalFiles)
|
||||
|
||||
remoteFile, err := r.fileStager.EnsureRemote(stageCtx, node.ID, path, keyMapper.Key(path))
|
||||
if err != nil {
|
||||
return fmt.Errorf("staging %s: %w", path, err)
|
||||
}
|
||||
r.stagingTracker.FileComplete(trackingKey, *fileIdx, totalFiles)
|
||||
|
||||
// Every file under dir shares the same remote parent directory; derive
|
||||
// it from this file's staged path and its path relative to dir.
|
||||
rel, relErr := filepath.Rel(dir, path)
|
||||
if relErr != nil {
|
||||
return relErr
|
||||
}
|
||||
remoteDir = DeriveRemoteModelPath(remoteFile, rel)
|
||||
|
||||
r.stageCompanionFiles(ctx, node, path, keyMapper.Key)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if remoteDir == "" {
|
||||
return "", fmt.Errorf("model directory %s contains no files", dir)
|
||||
}
|
||||
return remoteDir, nil
|
||||
}
|
||||
|
||||
// stageCompanionFiles stages known companion files that exist alongside
|
||||
// localPath. For example, piper TTS implicitly loads ".onnx.json" next to
|
||||
// the ".onnx" model file. Errors are logged but not propagated.
|
||||
|
||||
64
core/services/nodes/router_dirstage_test.go
Normal file
64
core/services/nodes/router_dirstage_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// These tests cover staging of "directory models" — models whose ModelFile is a
|
||||
// directory containing multiple files (e.g. qwen3-tts-cpp ships weights +
|
||||
// tokenizer ggufs under one directory). The HTTP file stager uploads a single
|
||||
// regular file per path, so a directory ModelFile must be expanded into its
|
||||
// constituent files; otherwise the upload reads a directory fd and fails with
|
||||
// "is a directory" (EISDIR) on remote NATS worker nodes.
|
||||
var _ = Describe("stageModelFiles directory models", func() {
|
||||
var (
|
||||
stager *fakeFileStager
|
||||
router *SmartRouter
|
||||
node *BackendNode
|
||||
tmp string
|
||||
modelID = "qwen3-tts-cpp"
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
stager = &fakeFileStager{}
|
||||
router = &SmartRouter{
|
||||
fileStager: stager,
|
||||
stagingTracker: NewStagingTracker(),
|
||||
}
|
||||
node = &BackendNode{ID: "node-1", Name: "node-1", Address: "10.0.0.1:50051"}
|
||||
tmp = GinkgoT().TempDir()
|
||||
})
|
||||
|
||||
It("stages every file inside a directory ModelFile instead of the directory path", func() {
|
||||
modelDir := filepath.Join(tmp, "models", modelID)
|
||||
Expect(os.MkdirAll(modelDir, 0o755)).To(Succeed())
|
||||
weights := filepath.Join(modelDir, "qwen3-tts-0.6b-f16.gguf")
|
||||
tokenizer := filepath.Join(modelDir, "qwen3-tts-tokenizer-f16.gguf")
|
||||
Expect(os.WriteFile(weights, []byte("weights"), 0o644)).To(Succeed())
|
||||
Expect(os.WriteFile(tokenizer, []byte("tokenizer"), 0o644)).To(Succeed())
|
||||
|
||||
opts := &pb.ModelOptions{
|
||||
Model: modelID,
|
||||
ModelFile: modelDir,
|
||||
}
|
||||
|
||||
_, err := router.stageModelFiles(context.Background(), node, opts, "track-key")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
staged := make([]string, 0, len(stager.ensureCalls))
|
||||
for _, c := range stager.ensureCalls {
|
||||
staged = append(staged, c.localPath)
|
||||
}
|
||||
// Each contained file is staged individually; the directory path itself
|
||||
// is never handed to the stager (which would read a directory fd).
|
||||
Expect(staged).To(ConsistOf(weights, tokenizer))
|
||||
Expect(staged).ToNot(ContainElement(modelDir))
|
||||
})
|
||||
})
|
||||
@@ -60,7 +60,13 @@ type Config struct {
|
||||
MaxReplicasPerModel int `env:"LOCALAI_MAX_REPLICAS_PER_MODEL" default:"1" help:"Max replicas of any single model on this worker. Default 1 preserves single-replica behavior; set higher to allow stacking replicas on a fat node." group:"registration"`
|
||||
|
||||
// NATS (required)
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
|
||||
NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (normally from registration nats_jwt)" group:"distributed"`
|
||||
NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user signing seed override (normally from registration nats_user_seed)" group:"distributed"`
|
||||
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed from registration or env" group:"distributed"`
|
||||
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
|
||||
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
|
||||
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
|
||||
|
||||
// S3 storage for distributed file transfer
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3 endpoint URL" group:"distributed"`
|
||||
|
||||
33
core/services/worker/nats_connect.go
Normal file
33
core/services/worker/nats_connect.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
)
|
||||
|
||||
// connectNATS opens a NATS client using JWT+seed from env or registration (env wins).
|
||||
func connectNATS(url, envJWT, envSeed, registerJWT, registerSeed string, requireAuth bool, tls messaging.TLSFiles) (*messaging.Client, error) {
|
||||
// Env credentials take precedence, but only fall back to registration when
|
||||
// the env supplied neither half — otherwise a JWT set without its seed (or
|
||||
// vice-versa) would be silently completed from a different source.
|
||||
jwt, seed := envJWT, envSeed
|
||||
if jwt == "" && seed == "" {
|
||||
jwt, seed = registerJWT, registerSeed
|
||||
}
|
||||
// A JWT without its paired seed (or vice-versa) is a misconfiguration: refuse
|
||||
// rather than silently connecting anonymously, which would look authenticated.
|
||||
if (jwt == "") != (seed == "") {
|
||||
return nil, fmt.Errorf("NATS JWT and seed must be provided together (got JWT set=%t, seed set=%t)", jwt != "", seed != "")
|
||||
}
|
||||
var opts []messaging.Option
|
||||
if jwt != "" && seed != "" {
|
||||
opts = append(opts, messaging.WithUserJWT(jwt, seed))
|
||||
} else if requireAuth {
|
||||
return nil, fmt.Errorf("NATS JWT+seed required: set LOCALAI_NATS_JWT/LOCALAI_NATS_USER_SEED or enable frontend minting")
|
||||
}
|
||||
if tls.Enabled() {
|
||||
opts = append(opts, messaging.WithTLS(tls))
|
||||
}
|
||||
return messaging.New(url, opts...)
|
||||
}
|
||||
29
core/services/worker/nats_connect_test.go
Normal file
29
core/services/worker/nats_connect_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("connectNATS", func() {
|
||||
It("requires JWT when requireAuth is set and no credentials are provided", func() {
|
||||
_, err := connectNATS("nats://127.0.0.1:4222", "", "", "", "", true, messaging.TLSFiles{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("NATS JWT+seed required"))
|
||||
})
|
||||
|
||||
// A JWT supplied without its paired seed (or vice-versa) is an operator
// misconfiguration. connectNATS must refuse it rather than silently dropping
// the unpaired credential and connecting anonymously, which would leave the
// operator believing the link is authenticated when it is not.
|
||||
It("rejects a JWT supplied without a seed instead of connecting anonymously", func() {
|
||||
client, err := connectNATS("nats://127.0.0.1:4222", "jwt-without-seed", "", "", "", false, messaging.TLSFiles{})
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
Expect(err).To(HaveOccurred(),
|
||||
"connectNATS should reject an unpaired JWT rather than silently connecting anonymously")
|
||||
})
|
||||
})
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/cli/workerregistry"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
@@ -67,10 +68,63 @@ func Run(ctx *cliContext.Context, cfg *Config) error {
|
||||
RegistrationToken: cfg.RegistrationToken,
|
||||
}
|
||||
|
||||
// Context cancelled on shutdown — used by registration waits, heartbeat, and
|
||||
// other background goroutines.
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
registrationBody := cfg.registrationBody()
|
||||
nodeID, _, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register with frontend: %w", err)
|
||||
natsTLS := messaging.TLSFiles{CA: cfg.NatsTLSCA, Cert: cfg.NatsTLSCert, Key: cfg.NatsTLSKey}
|
||||
|
||||
// Resolve how to connect to NATS. Static env credentials cannot be re-minted,
|
||||
// so register once and use them directly. Otherwise the credential manager
|
||||
// (re)registers to obtain credentials — waiting through admin approval — and
|
||||
// refreshes them before the minted JWT expires, so the connection survives
|
||||
// expiry via a transparent reconnect.
|
||||
var (
|
||||
nodeID string
|
||||
connectNats func() (*messaging.Client, error)
|
||||
)
|
||||
if cfg.NatsJWT != "" || cfg.NatsUserSeed != "" {
|
||||
nid, _, _, _, regErr := regClient.RegisterWithRetry(shutdownCtx, registrationBody, 10)
|
||||
if regErr != nil {
|
||||
return fmt.Errorf("failed to register with frontend: %w", regErr)
|
||||
}
|
||||
nodeID = nid
|
||||
connectNats = func() (*messaging.Client, error) {
|
||||
return connectNATS(cfg.NatsURL, cfg.NatsJWT, cfg.NatsUserSeed, "", "", cfg.NatsRequireAuth, natsTLS)
|
||||
}
|
||||
} else {
|
||||
credMgr := workerregistry.NewNATSCredentialManager(
|
||||
func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
|
||||
return regClient.RegisterFull(ctx, registrationBody)
|
||||
},
|
||||
cfg.NatsRequireAuth,
|
||||
)
|
||||
res, regErr := credMgr.Acquire(shutdownCtx)
|
||||
if regErr != nil {
|
||||
return fmt.Errorf("failed to register with frontend: %w", regErr)
|
||||
}
|
||||
nodeID = res.ID
|
||||
connectNats = func() (*messaging.Client, error) {
|
||||
var opts []messaging.Option
|
||||
if credMgr.HasCredentials() {
|
||||
opts = append(opts, messaging.WithUserJWTProvider(credMgr.Provider()))
|
||||
}
|
||||
if natsTLS.Enabled() {
|
||||
opts = append(opts, messaging.WithTLS(natsTLS))
|
||||
}
|
||||
client, cerr := messaging.New(cfg.NatsURL, opts...)
|
||||
if cerr == nil && credMgr.HasCredentials() {
|
||||
go func() {
|
||||
if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
|
||||
xlog.Error("NATS credential refresh permanently failed; shutting down worker", "error", err)
|
||||
shutdownCancel()
|
||||
}
|
||||
}()
|
||||
}
|
||||
return client, cerr
|
||||
}
|
||||
}
|
||||
|
||||
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cfg.RegisterTo)
|
||||
@@ -79,9 +133,6 @@ func Run(ctx *cliContext.Context, cfg *Config) error {
|
||||
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cfg.HeartbeatInterval, "error", err)
|
||||
}
|
||||
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
|
||||
// Context cancelled on shutdown — used by heartbeat and other background goroutines
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
// Start HTTP file transfer server
|
||||
httpAddr := cfg.resolveHTTPAddr()
|
||||
@@ -94,7 +145,7 @@ func Run(ctx *cliContext.Context, cfg *Config) error {
|
||||
|
||||
// Connect to NATS
|
||||
xlog.Info("Connecting to NATS", "url", sanitize.URL(cfg.NatsURL))
|
||||
natsClient, err := messaging.New(cfg.NatsURL)
|
||||
natsClient, err := connectNats()
|
||||
if err != nil {
|
||||
nodes.ShutdownFileTransferServer(httpServer)
|
||||
return fmt.Errorf("connecting to NATS: %w", err)
|
||||
@@ -154,12 +205,21 @@ func Run(ctx *cliContext.Context, cfg *Config) error {
|
||||
}
|
||||
|
||||
xlog.Info("Worker ready, waiting for backend.install events")
|
||||
<-sigCh
|
||||
// Exit on an OS signal or on an internal fatal condition (e.g. NATS
|
||||
// credentials became unrenewable), so the worker restarts and re-acquires
|
||||
// rather than lingering unable to serve.
|
||||
var runErr error
|
||||
select {
|
||||
case <-sigCh:
|
||||
case <-shutdownCtx.Done():
|
||||
runErr = fmt.Errorf("worker shutting down: NATS credentials unavailable")
|
||||
xlog.Error("Internal shutdown requested", "error", runErr)
|
||||
}
|
||||
|
||||
xlog.Info("Shutting down worker")
|
||||
shutdownCancel() // stop heartbeat loop immediately
|
||||
regClient.GracefulDeregister(nodeID)
|
||||
supervisor.stopAllBackends()
|
||||
nodes.ShutdownFileTransferServer(httpServer)
|
||||
return nil
|
||||
return runErr
|
||||
}
|
||||
|
||||
@@ -71,6 +71,50 @@ The frontend is a standard LocalAI instance with distributed mode enabled. These
|
||||
| `--backend-upgrade-timeout` | `LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT` | `15m` | Same as the install timeout, applied to backend upgrades (force-reinstall). |
|
||||
| `--expose-node-header` | `LOCALAI_EXPOSE_NODE_HEADER` | `false` | When enabled, inference responses carry an `X-LocalAI-Node` header with the ID of the worker node that served the request. Coverage spans the OpenAI-compatible endpoints (chat completions, completions, embeddings, audio transcriptions, audio speech / TTS, image generations, image inpainting), the Jina rerank endpoint (`/v1/rerank`), the VAD endpoints (`/v1/vad`, `/vad`), and the Anthropic Messages (`/v1/messages`) and Ollama (`/api/chat`, `/api/generate`, `/api/embed`) shims. Useful for debugging, observability and load-balancer attribution. Off by default: the node ID reveals internal cluster topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency for the same model across multiple replicas, the header may reflect a recent routing decision rather than this exact request's. Acceptable for observability and debugging. |
|
||||
|
||||
### NATS JWT authentication (recommended for production)
|
||||
|
||||
By default, NATS connections are anonymous: any client that can reach port `4222` may publish to control-plane subjects such as `nodes.<id>.backend.install`. Enable JWT auth to scope workers to their own node subjects and give the frontend a dedicated service credential.
|
||||
|
||||
| Flag | Env Var | Description |
|
||||
|------|---------|-------------|
|
||||
| `--nats-account-seed` | `LOCALAI_NATS_ACCOUNT_SEED` | Account signing seed (`SU...`). The frontend mints a per-node user JWT at registration (`nats_jwt` in the register response). |
|
||||
| `--nats-service-jwt` | `LOCALAI_NATS_SERVICE_JWT` | User JWT for the frontend (and optional fallback for agent workers) to publish install/upgrade and related subjects. |
|
||||
| `--nats-service-seed` | `LOCALAI_NATS_SERVICE_SEED` | User signing seed (`SU...`) paired with the service JWT. |
|
||||
| `--nats-worker-jwt-ttl` | `LOCALAI_NATS_WORKER_JWT_TTL` | Lifetime of minted worker JWTs (default `24h`). |
|
||||
| `--nats-require-auth` | `LOCALAI_NATS_REQUIRE_AUTH` | Fail startup if JWT credentials are missing when distributed mode is enabled. |
|
||||
|
||||
### NATS TLS / mTLS (optional)
|
||||
|
||||
Use `tls://` in `--nats-url` / `LOCALAI_NATS_URL` for encrypted transport. When the server uses a private CA or requires client certificates, set:
|
||||
|
||||
| Flag | Env Var | Description |
|
||||
|------|---------|-------------|
|
||||
| `--nats-tls-ca` | `LOCALAI_NATS_TLS_CA` | PEM file to verify the NATS server (private CA) |
|
||||
| `--nats-tls-cert` | `LOCALAI_NATS_TLS_CERT` | Client certificate for NATS mTLS |
|
||||
| `--nats-tls-key` | `LOCALAI_NATS_TLS_KEY` | Client private key (required with `--nats-tls-cert`) |
|
||||
|
||||
The same env vars apply to backend workers and `local-ai agent-worker`. If the server cert is already trusted by the OS, `tls://` alone is enough.
|
||||
|
||||
**Worker register response** (when minting is enabled and the node is approved):
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "…",
|
||||
"nats_jwt": "eyJ…",
|
||||
"nats_user_seed": "SU…"
|
||||
}
|
||||
```
|
||||
|
||||
Workers connect with that JWT and seed automatically (shown once; store securely). Override with `LOCALAI_NATS_JWT` / `LOCALAI_NATS_USER_SEED` if needed. Set `LOCALAI_NATS_REQUIRE_AUTH=true` on workers when the bus requires credentials.
|
||||
|
||||
When `LOCALAI_NATS_REQUIRE_AUTH=true` and no static credentials are provided, a worker that registers while still **pending admin approval** keeps re-registering (with backoff) until an admin approves it and the frontend mints its JWT — it does not start unauthenticated. This retry is **bounded**: if the node is never approved (or no credentials are minted) after a large number of attempts, the worker exits non-zero so the failure is visible (a crash-looping or failed worker) rather than hanging silently. Minted worker JWTs are also **refreshed automatically** before they expire (the worker re-registers at ~75% of the JWT lifetime), so long-running workers survive past `LOCALAI_NATS_WORKER_JWT_TTL`; the NATS connection picks up the new JWT on its next reconnect. If refresh fails persistently, the worker exits (to restart and re-acquire) rather than drifting toward an expired, unrenewable JWT. Statically configured (`LOCALAI_NATS_JWT`) and service (`LOCALAI_NATS_SERVICE_JWT`) credentials are used as-is and not refreshed.
|
||||
|
||||
Generate operator/account material with [`scripts/nats-auth-setup.sh`](https://github.com/mudler/LocalAI/blob/master/scripts/nats-auth-setup.sh) (requires [nsc](https://docs.nats.io/running-a-nats-service/configuration/securing_nats/auth_intro/nsc)). Configure the NATS server with account resolver JWTs before enabling `LOCALAI_NATS_REQUIRE_AUTH`.
|
||||
|
||||
{{% notice note %}}
|
||||
`LOCALAI_AUTH` (HTTP users/sessions) and NATS JWTs are separate: end-user API keys do not connect to NATS. HTTP registration still uses `LOCALAI_REGISTRATION_TOKEN`.
|
||||
{{% /notice %}}
|
||||
|
||||
### Optional: S3 Object Storage
|
||||
|
||||
For multi-host deployments where workers don't share a filesystem, S3-compatible storage enables distributed file transfer (model files, configs):
|
||||
@@ -134,6 +178,12 @@ local-ai worker \
|
||||
| `--registration-token` | `LOCALAI_REGISTRATION_TOKEN` | *(empty)* | Token to authenticate with the frontend |
|
||||
| `--heartbeat-interval` | `LOCALAI_HEARTBEAT_INTERVAL` | `10s` | Interval between heartbeat pings |
|
||||
| `--nats-url` | `LOCALAI_NATS_URL` | *(required)* | NATS URL for backend installation and file staging |
|
||||
| `--nats-jwt` | `LOCALAI_NATS_JWT` | *(empty)* | Optional override for the `nats_jwt` returned at registration |
|
||||
| `--nats-user-seed` | `LOCALAI_NATS_USER_SEED` | *(empty)* | Optional override for `nats_user_seed` from registration |
|
||||
| `--nats-require-auth` | `LOCALAI_NATS_REQUIRE_AUTH` | `false` | Require NATS JWT+seed (from registration or env) |
|
||||
| `--nats-tls-ca` | `LOCALAI_NATS_TLS_CA` | *(empty)* | PEM file for NATS server CA |
|
||||
| `--nats-tls-cert` | `LOCALAI_NATS_TLS_CERT` | *(empty)* | Client certificate for NATS mTLS |
|
||||
| `--nats-tls-key` | `LOCALAI_NATS_TLS_KEY` | *(empty)* | Client private key for NATS mTLS |
|
||||
| `--backends-path` | `LOCALAI_BACKENDS_PATH` | `./backends` | Path to backend binaries |
|
||||
| `--models-path` | `LOCALAI_MODELS_PATH` | `./models` | Path to model files |
|
||||
|
||||
|
||||
@@ -296,6 +296,28 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
|
||||
}' | aplay
|
||||
```
|
||||
|
||||
#### Language
|
||||
|
||||
You can hint the synthesis language with the `language` request field:
|
||||
|
||||
```
|
||||
curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
|
||||
"model": "qwen-tts",
|
||||
"input": "Bonjour le monde.",
|
||||
"language": "fr"
|
||||
}' | aplay
|
||||
```
|
||||
|
||||
Supported languages: `en` (English), `zh` (Chinese), `ru` (Russian), `ja` (Japanese), `ko` (Korean), `de` (German), `fr` (French), `es` (Spanish), `it` (Italian), `pt` (Portuguese).
|
||||
|
||||
The value is matched case-insensitively and accepts a few forms for convenience:
|
||||
|
||||
- the two-letter code (`fr`, `FR`)
|
||||
- a locale/region form, whose region is ignored (`fr-FR`, `pt_BR`, `zh-Hans` → `fr`/`pt`/`zh`)
|
||||
- the English full name (`french`, `Portuguese`)
|
||||
|
||||
If the field is omitted or the value isn't one of the supported languages, the backend defaults to English.
|
||||
|
||||
#### Custom Voice Mode
|
||||
|
||||
Qwen3-TTS supports predefined speakers. You can specify a speaker using the `voice` parameter:
|
||||
@@ -337,6 +359,37 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
|
||||
}' | aplay
|
||||
```
|
||||
|
||||
#### Per-request instructions
|
||||
|
||||
Instead of (or in addition to) the static YAML `instruct` option, you can pass an
|
||||
`instructions` string per request. It maps to the OpenAI
|
||||
[`instructions`](https://platform.openai.com/docs/api-reference/audio/createSpeech) field
|
||||
and takes precedence over the YAML option when set, falling back to it when empty. This lets
|
||||
a single model config serve a different emotion (CustomVoice) or a different designed voice
|
||||
(VoiceDesign) on every request - useful for roleplay/narration clients that need many voices:
|
||||
|
||||
```
|
||||
curl http://localhost:8080/v1/audio/speech -H "Content-Type: application/json" -d '{
|
||||
"model": "qwen-tts-design",
|
||||
"input": "Hello world, this is a test.",
|
||||
"instructions": "A calm, low-pitched elderly storyteller with a warm tone."
|
||||
}' | aplay
|
||||
```
|
||||
|
||||
Backends that do not support style/voice instructions simply ignore the field.
|
||||
|
||||
You can also pass backend-specific generation parameters per request via the LocalAI
|
||||
`params` extension (a string-to-string map; values are coerced to the backend's expected
|
||||
types). For example, with the Chatterbox backend:
|
||||
|
||||
```
|
||||
curl http://localhost:8080/v1/audio/speech -H "Content-Type: application/json" -d '{
|
||||
"model": "chatterbox",
|
||||
"input": "Hello world, this is a test.",
|
||||
"params": { "exaggeration": "0.7", "cfg_weight": "0.3", "temperature": "0.8" }
|
||||
}' | aplay
|
||||
```
|
||||
|
||||
#### Voice Clone Mode
|
||||
|
||||
Voice Clone allows you to clone a voice from reference audio. Configure the model with an `AudioPath` and optional `ref_text`:
|
||||
|
||||
@@ -1,4 +1,58 @@
|
||||
---
|
||||
- name: "step-3.7-flash"
|
||||
url: "github:mudler/LocalAI/gallery/virtual.yaml@master"
|
||||
urls:
|
||||
- https://huggingface.co/unsloth/Step-3.7-Flash-GGUF
|
||||
description: |
|
||||
**[ModelPage]**: https://static.stepfun.com/blog/step-3.7-flash/
|
||||
|
||||
## 1. Introduction
|
||||
|
||||
Step 3.7 Flash is a 198B-parameter sparse Mixture-of-Experts (MoE) vision-language model that combines a 196B-parameter language backbone with a 1.8B-parameter vision encoder for native image understanding. Engineered for high-frequency production workloads, it activates approximately 11B parameters per token and delivers a throughput of up to 400 tokens per second. Step 3.7 Flash supports a 256k context window and offers three selectable reasoning levels (low, medium, and high) so developers can easily balance speed, cost, and cognitive depth.
|
||||
|
||||
We built Step 3.7 Flash for developers who need to scale agentic workflows that combine perception, search, and reasoning. It is designed to handle intensive tasks such as parsing massive financial reports in one pass, running multi-step search loops with cross-source verification, or operating concurrent coding agents in high-throughput pipelines.
|
||||
|
||||
## 2. Capabilities & Performance
|
||||
|
||||
### Multimodal Perception and Verification
|
||||
|
||||
...
|
||||
license: "apache-2.0"
|
||||
tags:
|
||||
- llm
|
||||
- gguf
|
||||
icon: https://example.com/photo.jpg
|
||||
overrides:
|
||||
backend: llama-cpp
|
||||
function:
|
||||
automatic_tool_parsing_fallback: true
|
||||
grammar:
|
||||
disable: true
|
||||
known_usecases:
|
||||
- chat
|
||||
mmproj: llama-cpp/mmproj/Step-3.7-Flash-GGUF/mmproj-F32.gguf
|
||||
options:
|
||||
- use_jinja:true
|
||||
parameters:
|
||||
model: llama-cpp/models/Step-3.7-Flash-GGUF/Step-3.7-Flash-UD-Q4_K_M-00001-of-00004.gguf
|
||||
template:
|
||||
use_tokenizer_template: true
|
||||
files:
|
||||
- filename: llama-cpp/models/Step-3.7-Flash-GGUF/Step-3.7-Flash-UD-Q4_K_M-00001-of-00004.gguf
|
||||
sha256: 3ace7518df03a818243c55076e8c5b422961aa3cefe4fa8f120d4456dd2edde7
|
||||
uri: https://huggingface.co/unsloth/Step-3.7-Flash-GGUF/resolve/main/UD-Q4_K_M/Step-3.7-Flash-UD-Q4_K_M-00001-of-00004.gguf
|
||||
- filename: llama-cpp/models/Step-3.7-Flash-GGUF/Step-3.7-Flash-UD-Q4_K_M-00002-of-00004.gguf
|
||||
sha256: 1ff05ea5a4518c488548219ec944aadec6a1a075140a3f81ae258ec51b755a75
|
||||
uri: https://huggingface.co/unsloth/Step-3.7-Flash-GGUF/resolve/main/UD-Q4_K_M/Step-3.7-Flash-UD-Q4_K_M-00002-of-00004.gguf
|
||||
- filename: llama-cpp/models/Step-3.7-Flash-GGUF/Step-3.7-Flash-UD-Q4_K_M-00003-of-00004.gguf
|
||||
sha256: 47c1b36d9e6df9fcd6e05873bdaa101a54b85e56bcd775ce0a199453387c339d
|
||||
uri: https://huggingface.co/unsloth/Step-3.7-Flash-GGUF/resolve/main/UD-Q4_K_M/Step-3.7-Flash-UD-Q4_K_M-00003-of-00004.gguf
|
||||
- filename: llama-cpp/models/Step-3.7-Flash-GGUF/Step-3.7-Flash-UD-Q4_K_M-00004-of-00004.gguf
|
||||
sha256: 1cc54c0a491b63b86ef0ddc631950c2b881ed701de9ffb1903338d3cbf088262
|
||||
uri: https://huggingface.co/unsloth/Step-3.7-Flash-GGUF/resolve/main/UD-Q4_K_M/Step-3.7-Flash-UD-Q4_K_M-00004-of-00004.gguf
|
||||
- filename: llama-cpp/mmproj/Step-3.7-Flash-GGUF/mmproj-F32.gguf
|
||||
sha256: 2fab13dcd32e4b3dc4410297df80f4d82627308e725dedac802940ceca7dff13
|
||||
uri: https://huggingface.co/unsloth/Step-3.7-Flash-GGUF/resolve/main/mmproj-F32.gguf
|
||||
- name: "lfm2.5-8b-a1b"
|
||||
url: "github:mudler/LocalAI/gallery/virtual.yaml@master"
|
||||
urls:
|
||||
@@ -31855,6 +31909,7 @@
|
||||
files:
|
||||
- filename: parakeet-tdt-0.6b-v3-q4_k.gguf
|
||||
uri: huggingface://cstr/parakeet-tdt-0.6b-v3-GGUF/parakeet-tdt-0.6b-v3-q4_k.gguf
|
||||
sha256: 1a60f6e53e5781240dde6e69a47a47a8a71995a3a106517b009225afcc514457
|
||||
- name: parakeet-v2-crispasr
|
||||
url: github:mudler/LocalAI/gallery/virtual.yaml@master
|
||||
urls:
|
||||
@@ -31877,6 +31932,7 @@
|
||||
files:
|
||||
- filename: parakeet-tdt-0.6b-v2-q4_k.gguf
|
||||
uri: huggingface://cstr/parakeet-tdt-0.6b-v2-GGUF/parakeet-tdt-0.6b-v2-q4_k.gguf
|
||||
sha256: f392cee3c2ba81b397b021e151e4588ded7fc985f8115cfaeb405ea42fc518a9
|
||||
- name: parakeet-ja-crispasr
|
||||
url: github:mudler/LocalAI/gallery/virtual.yaml@master
|
||||
urls:
|
||||
@@ -31899,6 +31955,7 @@
|
||||
files:
|
||||
- filename: parakeet-tdt-0.6b-ja.gguf
|
||||
uri: huggingface://cstr/parakeet-tdt-0.6b-ja-GGUF/parakeet-tdt-0.6b-ja.gguf
|
||||
sha256: a9c43116b180b8a2ada2771ac829cf751b9e73adcbe69b7c8379593f9e5da31e
|
||||
- name: parakeet-tdt-1.1b-crispasr
|
||||
url: github:mudler/LocalAI/gallery/virtual.yaml@master
|
||||
urls:
|
||||
@@ -31921,6 +31978,7 @@
|
||||
files:
|
||||
- filename: parakeet-tdt-1.1b-q4_k.gguf
|
||||
uri: huggingface://cstr/parakeet-tdt-1.1b-GGUF/parakeet-tdt-1.1b-q4_k.gguf
|
||||
sha256: db64b442d02430b76e664fa1fd5facc7866d2bdc071d64028dad55772cde252c
|
||||
- name: parakeet-tdt_ctc-110m-crispasr
|
||||
url: github:mudler/LocalAI/gallery/virtual.yaml@master
|
||||
urls:
|
||||
@@ -31943,6 +32001,7 @@
|
||||
files:
|
||||
- filename: parakeet-tdt_ctc-110m-q4_k.gguf
|
||||
uri: huggingface://cstr/parakeet-tdt_ctc-110m-GGUF/parakeet-tdt_ctc-110m-q4_k.gguf
|
||||
sha256: c57f84d0826b6a10172c0b9696da472efb5e4c604987ef0d023214b29f38e929
|
||||
- name: parakeet-tdt_ctc-1.1b-crispasr
|
||||
url: github:mudler/LocalAI/gallery/virtual.yaml@master
|
||||
urls:
|
||||
@@ -31965,6 +32024,7 @@
|
||||
files:
|
||||
- filename: parakeet-tdt_ctc-1.1b-q4_k.gguf
|
||||
uri: huggingface://cstr/parakeet-tdt_ctc-1.1b-GGUF/parakeet-tdt_ctc-1.1b-q4_k.gguf
|
||||
sha256: 52784c0ac7321a6e1d915a96837f6f508fc5bff240b37f5e58dea39feb302edd
|
||||
- name: parakeet-rnnt-0.6b-crispasr
|
||||
url: github:mudler/LocalAI/gallery/virtual.yaml@master
|
||||
urls:
|
||||
@@ -31987,6 +32047,7 @@
|
||||
files:
|
||||
- filename: parakeet-rnnt-0.6b-q4_k.gguf
|
||||
uri: huggingface://cstr/parakeet-rnnt-0.6b-GGUF/parakeet-rnnt-0.6b-q4_k.gguf
|
||||
sha256: 84de2c556e30e87ef1fe5b0ac035b581c233ec017afe517082543b19eba8c73d
|
||||
- name: parakeet-rnnt-1.1b-crispasr
|
||||
url: github:mudler/LocalAI/gallery/virtual.yaml@master
|
||||
urls:
|
||||
@@ -32009,6 +32070,7 @@
|
||||
files:
|
||||
- filename: parakeet-rnnt-1.1b-q4_k.gguf
|
||||
uri: huggingface://cstr/parakeet-rnnt-1.1b-GGUF/parakeet-rnnt-1.1b-q4_k.gguf
|
||||
sha256: 9e6d6e5aba6dbe15853f93ad317b8017fe21df78fd854d334ca0c4144aefce08
|
||||
- name: fastconformer-ctc-crispasr
|
||||
url: github:mudler/LocalAI/gallery/virtual.yaml@master
|
||||
urls:
|
||||
@@ -32031,6 +32093,7 @@
|
||||
files:
|
||||
- filename: stt-en-fastconformer-ctc-large-q4_k.gguf
|
||||
uri: huggingface://cstr/stt-en-fastconformer-ctc-large-GGUF/stt-en-fastconformer-ctc-large-q4_k.gguf
|
||||
sha256: 5529d6762d1799a58b4fb806f766c2ce893f59d4d38d948d1177fcd3bfa28920
|
||||
- name: canary-crispasr
|
||||
url: github:mudler/LocalAI/gallery/virtual.yaml@master
|
||||
urls:
|
||||
@@ -32053,6 +32116,7 @@
|
||||
files:
|
||||
- filename: canary-1b-v2-q4_k.gguf
|
||||
uri: huggingface://cstr/canary-1b-v2-GGUF/canary-1b-v2-q4_k.gguf
|
||||
sha256: 187668f4b7bb7faee0c02de55664c7cb13c792dd54e47da888e05815420e16f1
|
||||
- name: voxtral-crispasr
|
||||
url: github:mudler/LocalAI/gallery/virtual.yaml@master
|
||||
urls:
|
||||
|
||||
3
go.mod
3
go.mod
@@ -41,7 +41,9 @@ require (
|
||||
github.com/mudler/go-processmanager v0.1.1
|
||||
github.com/mudler/memory v0.0.0-20260406210934-424c1ecf2cf8
|
||||
github.com/mudler/xlog v0.0.6
|
||||
github.com/nats-io/jwt/v2 v2.7.4
|
||||
github.com/nats-io/nats.go v1.52.0
|
||||
github.com/nats-io/nkeys v0.4.15
|
||||
github.com/ollama/ollama v0.20.4
|
||||
github.com/onsi/ginkgo/v2 v2.29.0
|
||||
github.com/onsi/gomega v1.41.0
|
||||
@@ -134,7 +136,6 @@ require (
|
||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||
github.com/moby/moby/api v1.54.2 // indirect
|
||||
github.com/moby/moby/client v0.4.1 // indirect
|
||||
github.com/nats-io/nkeys v0.4.15 // indirect
|
||||
github.com/nats-io/nuid v1.0.1 // indirect
|
||||
github.com/oklog/ulid v1.3.1 // indirect
|
||||
github.com/secure-systems-lab/go-securesystemslib v0.9.1 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -1016,6 +1016,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A=
|
||||
github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM=
|
||||
github.com/nats-io/jwt/v2 v2.7.4 h1:jXFuDDxs/GQjGDZGhNgH4tXzSUK6WQi2rsj4xmsNOtI=
|
||||
github.com/nats-io/jwt/v2 v2.7.4/go.mod h1:me11pOkwObtcBNR8AiMrUbtVOUGkqYjMQZ6jnSdVUIA=
|
||||
github.com/nats-io/nats.go v1.52.0 h1:n3avV4VBsCgsdwh71TppsTwtv+QdPs7ntSKM8qJLGsc=
|
||||
github.com/nats-io/nats.go v1.52.0/go.mod h1:26HypzazeOkyO3/mqd1zZd53STJN0EjCYF9Uy2ZOBno=
|
||||
github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4=
|
||||
|
||||
35
pkg/grpc/grpcerrors/errors.go
Normal file
35
pkg/grpc/grpcerrors/errors.go
Normal file
@@ -0,0 +1,35 @@
|
||||
// Package grpcerrors defines well-known error signals shared between backends
|
||||
// (which produce them) and the router (which consumes them). Go error types do
|
||||
// not survive the gRPC boundary, so these conditions are carried as gRPC status
|
||||
// codes and detected primarily via the code; matching the error message is
// kept only as a fallback for backends that do not yet return the typed error.
|
||||
package grpcerrors
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// ModelNotLoaded returns the canonical error a backend returns when it has no
|
||||
// model loaded for the request. It carries codes.FailedPrecondition so callers
|
||||
// can detect it across the gRPC boundary without matching the message string.
|
||||
func ModelNotLoaded(backend string) error {
|
||||
return status.Errorf(codes.FailedPrecondition, "%s: model not loaded", backend)
|
||||
}
|
||||
|
||||
// IsModelNotLoaded reports whether err signals that the backend has no model
|
||||
// loaded. It prefers the typed gRPC status code (FailedPrecondition) and falls
|
||||
// back to the message for backends that have not yet adopted ModelNotLoaded.
|
||||
//
|
||||
// Acting on a false positive is harmless: the only consequence upstream is that
|
||||
// the model is reloaded, which is idempotent.
|
||||
func IsModelNotLoaded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if status.Code(err) == codes.FailedPrecondition {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(strings.ToLower(err.Error()), "model not loaded")
|
||||
}
|
||||
37
pkg/grpc/grpcerrors/errors_test.go
Normal file
37
pkg/grpc/grpcerrors/errors_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package grpcerrors_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestGRPCErrors(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "grpcerrors test suite")
|
||||
}
|
||||
|
||||
var _ = Describe("grpcerrors", func() {
|
||||
DescribeTable("IsModelNotLoaded",
|
||||
func(err error, want bool) {
|
||||
Expect(grpcerrors.IsModelNotLoaded(err)).To(Equal(want))
|
||||
},
|
||||
Entry("nil", nil, false),
|
||||
Entry("typed via constructor", grpcerrors.ModelNotLoaded("parakeet-cpp"), true),
|
||||
Entry("typed code only", status.Error(codes.FailedPrecondition, "anything"), true),
|
||||
Entry("legacy message (Unknown code)", errors.New("parakeet-cpp: model not loaded"), true),
|
||||
Entry("legacy message mixed case", errors.New("Backend: Model Not Loaded"), true),
|
||||
Entry("unrelated error", errors.New("context deadline exceeded"), false),
|
||||
Entry("unrelated grpc code", status.Error(codes.Unavailable, "connection refused"), false),
|
||||
)
|
||||
|
||||
It("ModelNotLoaded carries FailedPrecondition", func() {
|
||||
Expect(status.Code(grpcerrors.ModelNotLoaded("whisper"))).To(Equal(codes.FailedPrecondition))
|
||||
})
|
||||
})
|
||||
@@ -207,28 +207,6 @@ var knownModelsNameSuffixToSkip []string = []string{
|
||||
".tar.gz",
|
||||
}
|
||||
|
||||
// HasKnownModelFileExtension reports whether name ends in a file extension that
|
||||
// LocalAI recognizes as a model weight or asset file (e.g. ".gguf",
|
||||
// ".safetensors", ".json"). It is used to tell a concrete file path such as
|
||||
// "local/model.gguf" apart from a HuggingFace-style repository ID like
|
||||
// "org/repo": only the former carries a recognized suffix. A version-style
|
||||
// suffix such as the ".0" in "stabilityai/stable-diffusion-xl-base-1.0" is not
|
||||
// in the list, so such repo IDs are correctly treated as non-files.
|
||||
func HasKnownModelFileExtension(name string) bool {
|
||||
lower := strings.ToLower(name)
|
||||
for _, suffix := range knownModelsNameSuffixToSkip {
|
||||
// "." is a guard entry consumed by ListFilesInModelPath, not a real
|
||||
// extension; skip it so it doesn't match every dotted name.
|
||||
if suffix == "." {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(lower, strings.ToLower(suffix)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const retryTimeout = time.Duration(2 * time.Minute)
|
||||
|
||||
func (ml *ModelLoader) ListFilesInModelPath() ([]string, error) {
|
||||
|
||||
@@ -58,23 +58,6 @@ var _ = Describe("ModelLoader", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("HasKnownModelFileExtension", func() {
	It("returns true for concrete weight/asset file paths", func() {
		filePaths := []string{
			"local/model.gguf",
			"model.safetensors",
			"foo/bar.GGUF",
			"config.json",
		}
		for _, p := range filePaths {
			Expect(model.HasKnownModelFileExtension(p)).To(BeTrue())
		}
	})

	It("returns false for HuggingFace-style repository IDs", func() {
		repoIDs := []string{
			// An org/repo ID carries no recognized file extension...
			"bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF",
			// ...and a version suffix like ".0" is not a known model extension.
			"stabilityai/stable-diffusion-xl-base-1.0",
			"plain-model-name",
		}
		for _, id := range repoIDs {
			Expect(model.HasKnownModelFileExtension(id)).To(BeFalse())
		}
	})
})
|
||||
|
||||
Context("ListFilesInModelPath", func() {
|
||||
It("should list all valid model files in the model path", func() {
|
||||
os.Create(filepath.Join(modelPath, "test.model"))
|
||||
|
||||
66
pkg/natsauth/config.go
Normal file
66
pkg/natsauth/config.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package natsauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// DefaultWorkerJWTTTL is the lifetime applied to minted worker NATS user JWTs
// when Config.WorkerJWTTTL is not set to a positive value.
const DefaultWorkerJWTTTL = 24 * time.Hour

// Config holds the NATS JWT authentication settings used in distributed mode.
type Config struct {
	// AccountSeed is the NATS account signing seed (SU...) used to mint
	// per-node worker JWTs.
	AccountSeed string
	// ServiceUserJWT is a pre-generated user JWT for frontends and agent
	// workers (broad publish).
	ServiceUserJWT string
	// ServiceUserSeed is the signing seed (SU...) paired with ServiceUserJWT.
	ServiceUserSeed string
	// WorkerJWTTTL is the expiry of minted worker JWTs. Zero or negative
	// values fall back to DefaultWorkerJWTTTL.
	WorkerJWTTTL time.Duration
	// RequireAuth rejects anonymous NATS when true. Validate then requires
	// ServiceUserJWT, ServiceUserSeed and AccountSeed to all be set.
	RequireAuth bool
}

// Enabled reports whether any NATS credential material (an account seed or a
// service user JWT) is configured.
func (c Config) Enabled() bool {
	return !(c.AccountSeed == "" && c.ServiceUserJWT == "")
}

// CanMintWorkers reports whether per-node JWTs can be issued at registration,
// which requires an account seed.
func (c Config) CanMintWorkers() bool {
	return len(c.AccountSeed) > 0
}

// WorkerTTL returns the worker JWT lifetime: WorkerJWTTTL when it is positive,
// DefaultWorkerJWTTTL otherwise.
func (c Config) WorkerTTL() time.Duration {
	if c.WorkerJWTTTL <= 0 {
		return DefaultWorkerJWTTTL
	}
	return c.WorkerJWTTTL
}

// Validate checks that the credential settings are complete when RequireAuth
// is set. It returns nil when RequireAuth is false. Otherwise it returns an
// error if the service JWT or its seed is missing, then if the account seed
// is missing.
func (c Config) Validate() error {
	switch {
	case !c.RequireAuth:
		// Nothing is enforced unless authentication is mandatory.
		return nil
	case c.ServiceUserJWT == "" || c.ServiceUserSeed == "":
		return fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED")
	case c.AccountSeed == "":
		return fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH is set but LOCALAI_NATS_ACCOUNT_SEED is empty")
	default:
		return nil
	}
}
|
||||
|
||||
// WarnIfInsecure logs when distributed NATS is reachable without credentials.
|
||||
func (c Config) WarnIfInsecure(distributed bool) {
|
||||
if !distributed || c.Enabled() {
|
||||
return
|
||||
}
|
||||
xlog.Warn("NATS is used without JWT credentials — any client on the bus can publish backend.install. " +
|
||||
"Set LOCALAI_NATS_ACCOUNT_SEED + LOCALAI_NATS_SERVICE_JWT (see docs/features/distributed-mode.md).")
|
||||
}
|
||||
16
pkg/natsauth/decode.go
Normal file
16
pkg/natsauth/decode.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package natsauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/nats-io/jwt/v2"
|
||||
)
|
||||
|
||||
// DecodeUserClaims decodes a minted worker JWT for tests and diagnostics.
|
||||
func DecodeUserClaims(token string) (*jwt.UserClaims, error) {
|
||||
uc, err := jwt.DecodeUserClaims(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("natsauth: decode user JWT: %w", err)
|
||||
}
|
||||
return uc, nil
|
||||
}
|
||||
59
pkg/natsauth/mint.go
Normal file
59
pkg/natsauth/mint.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package natsauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/jwt/v2"
|
||||
"github.com/nats-io/nkeys"
|
||||
)
|
||||
|
||||
// MintWorkerJWT creates a signed NATS user JWT and user seed scoped to nodeID and nodeType.
|
||||
// The seed is returned once at registration so the worker can sign NATS connections.
|
||||
func (c Config) MintWorkerJWT(nodeID, nodeType string) (userJWT, userSeed string, err error) {
|
||||
if c.AccountSeed == "" {
|
||||
return "", "", fmt.Errorf("natsauth: account seed not configured")
|
||||
}
|
||||
if nodeID == "" {
|
||||
return "", "", fmt.Errorf("natsauth: node ID is required")
|
||||
}
|
||||
|
||||
accountKP, err := nkeys.FromSeed([]byte(c.AccountSeed))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("natsauth: invalid account seed: %w", err)
|
||||
}
|
||||
|
||||
userKP, err := nkeys.CreateUser()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("natsauth: create user key: %w", err)
|
||||
}
|
||||
seedBytes, err := userKP.Seed()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("natsauth: user seed: %w", err)
|
||||
}
|
||||
|
||||
accountPub, err := accountKP.PublicKey()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("natsauth: account public key: %w", err)
|
||||
}
|
||||
userPub, err := userKP.PublicKey()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("natsauth: user public key: %w", err)
|
||||
}
|
||||
|
||||
pubAllow, subAllow := WorkerPermissions(nodeID, nodeType)
|
||||
|
||||
uc := jwt.NewUserClaims(userPub)
|
||||
uc.Name = fmt.Sprintf("localai-%s-%s", nodeType, workerSubjectToken(nodeID))
|
||||
uc.IssuerAccount = accountPub
|
||||
uc.Expires = time.Now().Add(c.WorkerTTL()).Unix()
|
||||
|
||||
uc.Permissions.Pub.Allow = pubAllow
|
||||
uc.Permissions.Sub.Allow = subAllow
|
||||
|
||||
token, err := uc.Encode(accountKP)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("natsauth: encode user JWT: %w", err)
|
||||
}
|
||||
return token, string(seedBytes), nil
|
||||
}
|
||||
60
pkg/natsauth/mint_test.go
Normal file
60
pkg/natsauth/mint_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package natsauth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/nats-io/jwt/v2"
|
||||
"github.com/nats-io/nkeys"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestNatsAuth is the `go test` entry point for this package's Ginkgo specs:
// it registers Ginkgo's Fail as the Gomega failure handler and runs the suite
// under the name "NatsAuth".
func TestNatsAuth(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "NatsAuth")
|
||||
}
|
||||
|
||||
var _ = Describe("MintWorkerJWT", func() {
|
||||
var accountSeed string
|
||||
|
||||
BeforeEach(func() {
|
||||
akp, err := nkeys.CreateAccount()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
seed, err := akp.Seed()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
accountSeed = string(seed)
|
||||
})
|
||||
|
||||
It("mints a JWT with backend worker permissions", func() {
|
||||
cfg := natsauth.Config{AccountSeed: accountSeed, WorkerJWTTTL: time.Hour}
|
||||
token, seed, err := cfg.MintWorkerJWT("550e8400-e29b-41d4-a716-446655440000", "backend")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(token).NotTo(BeEmpty())
|
||||
Expect(seed).NotTo(BeEmpty())
|
||||
|
||||
uc, err := jwt.DecodeUserClaims(token)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(uc.Permissions.Sub.Allow).To(ContainElement("nodes.550e8400-e29b-41d4-a716-446655440000.>"))
|
||||
Expect(uc.Permissions.Pub.Allow).To(ContainElement("nodes.550e8400-e29b-41d4-a716-446655440000.backend.install.*.progress"))
|
||||
})
|
||||
|
||||
It("mints agent permissions without backend install subscribe", func() {
|
||||
cfg := natsauth.Config{AccountSeed: accountSeed}
|
||||
token, _, err := cfg.MintWorkerJWT("node-1", "agent")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
uc, err := jwt.DecodeUserClaims(token)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(uc.Permissions.Sub.Allow).To(ContainElement("agent.execute"))
|
||||
for _, subj := range uc.Permissions.Sub.Allow {
|
||||
Expect(subj).NotTo(ContainSubstring("backend.install"))
|
||||
}
|
||||
})
|
||||
|
||||
It("rejects mint without account seed", func() {
|
||||
_, _, err := (natsauth.Config{}).MintWorkerJWT("id", "backend")
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
49
pkg/natsauth/permissions.go
Normal file
49
pkg/natsauth/permissions.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package natsauth
|
||||
|
||||
import "strings"
|
||||
|
||||
// subjectTokenReplacer rewrites every character that is structurally
// significant in a NATS subject — the "." separator, the "*" and ">"
// wildcards, and space/tab/newline — to "-". It is built once at package
// initialisation instead of on every call; a *strings.Replacer is safe for
// concurrent use by multiple goroutines.
var subjectTokenReplacer = strings.NewReplacer(
	".", "-",
	"*", "-",
	">", "-",
	" ", "-",
	"\t", "-",
	"\n", "-",
)

// workerSubjectToken mirrors messaging.sanitizeSubjectToken without importing
// unexported logic: it turns a node ID into a single subject token by
// replacing each reserved character with "-". An empty nodeID yields "".
func workerSubjectToken(nodeID string) string {
	return subjectTokenReplacer.Replace(nodeID)
}
|
||||
|
||||
// WorkerPermissions returns NATS pub/sub allow lists for a registered node.
|
||||
func WorkerPermissions(nodeID, nodeType string) (pubAllow, subAllow []string) {
|
||||
tok := workerSubjectToken(nodeID)
|
||||
prefix := "nodes." + tok
|
||||
|
||||
switch nodeType {
|
||||
case "agent":
|
||||
// Agent workers consume queue workloads; they must not handle backend.install.
|
||||
// Keep this list in sync with the subscriptions in core/cli/agent_worker.go.
|
||||
subAllow = []string{
|
||||
"agent.execute",
|
||||
"jobs.*.cancel",
|
||||
"jobs.*.progress",
|
||||
"jobs.*.result",
|
||||
"jobs.mcp-ci.new", // MCP CI jobs dispatched to agent workers
|
||||
"mcp.tools.execute",
|
||||
"mcp.discovery",
|
||||
prefix + ".backend.stop", // stop events drive MCP session cleanup
|
||||
"_INBOX.>",
|
||||
}
|
||||
pubAllow = []string{
|
||||
"agent.>",
|
||||
"jobs.>",
|
||||
"_INBOX.>",
|
||||
}
|
||||
default:
|
||||
// Backend worker: lifecycle + file staging on this node only.
|
||||
subAllow = []string{
|
||||
prefix + ".>",
|
||||
"_INBOX.>",
|
||||
}
|
||||
pubAllow = []string{
|
||||
prefix + ".backend.install.*.progress",
|
||||
prefix + ".files.>",
|
||||
"_INBOX.>",
|
||||
}
|
||||
}
|
||||
return pubAllow, subAllow
|
||||
}
|
||||
134
pkg/natsauth/permissions_coverage_test.go
Normal file
134
pkg/natsauth/permissions_coverage_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package natsauth_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// subjectMatches reports whether a NATS subject pattern covers a concrete
// subject, token by token: "*" stands for exactly one token and ">" for one or
// more trailing tokens. The specs use it to check that a wildcarded allow list
// really covers the subjects a component publishes or subscribes to.
func subjectMatches(pattern, subject string) bool {
	want := strings.Split(pattern, ".")
	have := strings.Split(subject, ".")

	for idx := 0; idx < len(want); idx++ {
		tok := want[idx]
		switch {
		case tok == ">":
			// ">" needs at least one subject token left to swallow.
			return len(have) > idx
		case idx >= len(have):
			// Pattern is longer than the subject.
			return false
		case tok != "*" && tok != have[idx]:
			return false
		}
	}
	// Without a trailing ">" the token counts must agree exactly.
	return len(have) == len(want)
}
|
||||
|
||||
func anyAllows(allow []string, subject string) bool {
|
||||
for _, p := range allow {
|
||||
if subjectMatches(p, subject) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var _ = Describe("WorkerPermissions subject coverage", func() {
|
||||
// A node ID containing NATS-reserved characters exercises the (duplicated)
|
||||
// sanitizer in pkg/natsauth against the canonical one in core/services/messaging.
|
||||
// If the two ever diverge, the minted prefix stops matching the real subject
|
||||
// and these assertions fail — guarding the copy noted in the review.
|
||||
const nodeID = "host.a 1*b"
|
||||
|
||||
Context("backend worker", func() {
|
||||
pub, sub := natsauth.WorkerPermissions(nodeID, "backend")
|
||||
|
||||
// Every subject core/services/worker/{lifecycle,file_staging}.go subscribes to.
|
||||
subscribed := []string{
|
||||
messaging.SubjectNodeBackendInstall(nodeID),
|
||||
messaging.SubjectNodeBackendUpgrade(nodeID),
|
||||
messaging.SubjectNodeBackendStop(nodeID),
|
||||
messaging.SubjectNodeBackendDelete(nodeID),
|
||||
messaging.SubjectNodeBackendList(nodeID),
|
||||
messaging.SubjectNodeModelUnload(nodeID),
|
||||
messaging.SubjectNodeModelDelete(nodeID),
|
||||
messaging.SubjectNodeStop(nodeID),
|
||||
messaging.SubjectNodeFilesEnsure(nodeID),
|
||||
messaging.SubjectNodeFilesStage(nodeID),
|
||||
messaging.SubjectNodeFilesTemp(nodeID),
|
||||
messaging.SubjectNodeFilesListDir(nodeID),
|
||||
}
|
||||
for _, subject := range subscribed {
|
||||
It("allows subscribing to "+subject, func() {
|
||||
Expect(anyAllows(sub, subject)).To(BeTrue(),
|
||||
"backend JWT sub allow-list %v does not cover %s", sub, subject)
|
||||
})
|
||||
}
|
||||
|
||||
It("allows publishing backend.install progress", func() {
|
||||
subject := messaging.SubjectNodeBackendInstallProgress(nodeID, "op-123")
|
||||
Expect(anyAllows(pub, subject)).To(BeTrue(),
|
||||
"backend JWT pub allow-list %v does not cover %s", pub, subject)
|
||||
})
|
||||
})
|
||||
|
||||
Context("agent worker", func() {
|
||||
// node_type "agent"; subjects from core/cli/agent_worker.go.
|
||||
pub, sub := natsauth.WorkerPermissions(nodeID, "agent")
|
||||
_ = pub
|
||||
|
||||
subscribed := []string{
|
||||
messaging.SubjectAgentExecute, // dispatcher (default --agent-subject)
|
||||
messaging.SubjectMCPToolExecute, // QueueSubscribeReply
|
||||
messaging.SubjectMCPDiscovery, // QueueSubscribeReply
|
||||
messaging.SubjectMCPCIJobsNew, // QueueSubscribe — jobs.mcp-ci.new
|
||||
messaging.SubjectNodeBackendStop(nodeID), // Subscribe — MCP session cleanup
|
||||
}
|
||||
for _, subject := range subscribed {
|
||||
It("allows subscribing to "+subject, func() {
|
||||
Expect(anyAllows(sub, subject)).To(BeTrue(),
|
||||
"agent JWT sub allow-list %v does not cover %s — the agent worker subscribes to it", sub, subject)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
var allowPubRe = regexp.MustCompile(`--allow-pub "([^"]*)"`)
|
||||
|
||||
var _ = Describe("Documented NATS service-user permissions", func() {
|
||||
// scripts/nats-auth-setup.sh ships the recommended service (frontend) JWT
|
||||
// permissions. They must cover every subject the frontend actually publishes,
|
||||
// or prefix-cache sync (and friends) break once LOCALAI_NATS_REQUIRE_AUTH is on.
|
||||
const scriptPath = "../../scripts/nats-auth-setup.sh"
|
||||
|
||||
// Representative subjects the frontend publishes on the control plane.
|
||||
// prefixcache.* is emitted by prefixcache.Sync in core/application/distributed.go.
|
||||
frontendPublishes := []string{
|
||||
messaging.SubjectPrefixCacheObserve,
|
||||
messaging.SubjectPrefixCacheInvalidate,
|
||||
messaging.SubjectNodeBackendInstall("node-1"),
|
||||
messaging.SubjectGalleryProgress("op-1"),
|
||||
}
|
||||
|
||||
It("cover every subject the frontend publishes", func() {
|
||||
raw, err := os.ReadFile(scriptPath)
|
||||
Expect(err).ToNot(HaveOccurred(), "cannot read %s", scriptPath)
|
||||
m := allowPubRe.FindStringSubmatch(string(raw))
|
||||
Expect(m).To(HaveLen(2), "no --allow-pub list found in %s", scriptPath)
|
||||
allow := strings.Split(m[1], ",")
|
||||
|
||||
for _, subject := range frontendPublishes {
|
||||
Expect(anyAllows(allow, subject)).To(BeTrue(),
|
||||
"service-user --allow-pub %v does not cover %s (frontend publishes it)", allow, subject)
|
||||
}
|
||||
})
|
||||
})
|
||||
49
scripts/nats-auth-setup.sh
Executable file
49
scripts/nats-auth-setup.sh
Executable file
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env bash
# Generate NATS account + service user JWTs for LocalAI distributed mode.
#
# Requires:
#   nsc (https://docs.nats.io/running-a-nats-service/configuration/securing_nats/auth_intro/nsc)
#   jq  (used below to read the account seed from nsc's JSON output)
#
# Usage:
#   ./scripts/nats-auth-setup.sh
#
# Environment overrides:
#   NATS_OPERATOR_NAME  operator name        (default: localai-operator)
#   NATS_ACCOUNT_NAME   account name         (default: localai)
#   NATS_SERVICE_USER   service user name    (default: localai-frontend)
#   NATS_KEYS_DIR       output directory     (default: ./nats-keys)
#
# Outputs operator/account seeds and a service user JWT suitable for:
#   LOCALAI_NATS_ACCOUNT_SEED
#   LOCALAI_NATS_SERVICE_JWT
#
# Note: with LOCALAI_NATS_REQUIRE_AUTH=true LocalAI also requires
# LOCALAI_NATS_SERVICE_SEED (the service user's nkey seed); this script does
# not print it.
#
# Per-node worker JWTs are minted automatically by the frontend at registration
# when LOCALAI_NATS_ACCOUNT_SEED is set.

set -euo pipefail

# Check every external tool before touching any nsc state, so a missing
# dependency cannot leave a half-created operator/account behind.
if ! command -v nsc >/dev/null 2>&1; then
  echo "nsc is required. Install from https://github.com/nats-io/nsc/releases" >&2
  exit 1
fi
if ! command -v jq >/dev/null 2>&1; then
  echo "jq is required to extract the account seed. Install it from your package manager." >&2
  exit 1
fi

OPERATOR="${NATS_OPERATOR_NAME:-localai-operator}"
ACCOUNT="${NATS_ACCOUNT_NAME:-localai}"
SERVICE_USER="${NATS_SERVICE_USER:-localai-frontend}"

nsc add operator -n "$OPERATOR" --generate-signing-key
nsc add account -n "$ACCOUNT"
nsc add user -n "$SERVICE_USER" --account "$ACCOUNT"

# Broad publish for frontend control plane (tighten with custom claims in production).
nsc edit user -n "$SERVICE_USER" --account "$ACCOUNT" \
  --allow-pub "nodes.>,gallery.>,agent.>,jobs.>,mcp.>,cache.>,prefixcache.>,finetune.>" \
  --allow-sub "nodes.>,gallery.>,agent.>,jobs.>,mcp.>,cache.>,prefixcache.>,_INBOX.>"

KEYS_DIR="${NATS_KEYS_DIR:-./nats-keys}"
mkdir -p "$KEYS_DIR"
nsc generate creds -a "$ACCOUNT" -n "$SERVICE_USER" -o "$KEYS_DIR"

ACCOUNT_SEED=$(nsc describe account "$ACCOUNT" -o json | jq -r '.nats.private_key')
SERVICE_JWT=$(cat "$KEYS_DIR/${ACCOUNT}/${SERVICE_USER}.jwt" 2>/dev/null || cat "$KEYS_DIR/${SERVICE_USER}.jwt")

# jq prints the literal string "null" when the key is absent; flag that instead
# of silently handing the operator an unusable seed.
if [[ -z "$ACCOUNT_SEED" || "$ACCOUNT_SEED" == "null" ]]; then
  echo "warning: could not read the account seed from 'nsc describe account'; the LOCALAI_NATS_ACCOUNT_SEED value below is not usable" >&2
fi

echo ""
echo "=== LocalAI NATS auth material ==="
echo "LOCALAI_NATS_ACCOUNT_SEED=${ACCOUNT_SEED}"
echo "LOCALAI_NATS_SERVICE_JWT=${SERVICE_JWT}"
echo ""
echo "Configure the NATS server with the generated operator/account JWTs under $KEYS_DIR"
echo "and set LOCALAI_NATS_REQUIRE_AUTH=true on frontends and workers in production."
echo "LOCALAI_NATS_REQUIRE_AUTH=true additionally requires LOCALAI_NATS_SERVICE_SEED (the service user's seed)."
|
||||
@@ -5897,6 +5897,10 @@ const docTemplate = `{
|
||||
"description": "text input",
|
||||
"type": "string"
|
||||
},
|
||||
"instructions": {
|
||||
"description": "Instructions is a free-form, per-request style/voice description. It maps to\nthe OpenAI ` + "`" + `instructions` + "`" + ` field and is forwarded to the backend so expressive\nTTS models (e.g. Qwen3-TTS CustomVoice/VoiceDesign) can vary tone or designed\nvoice per request instead of only via the static YAML option.",
|
||||
"type": "string"
|
||||
},
|
||||
"language": {
|
||||
"description": "(optional) language to use with TTS model",
|
||||
"type": "string"
|
||||
@@ -5904,6 +5908,13 @@ const docTemplate = `{
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"params": {
|
||||
"description": "Params carries optional, backend-specific per-request generation parameters\n(LocalAI extension, e.g. Chatterbox exaggeration/cfg_weight/temperature).",
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"response_format": {
|
||||
"description": "(optional) output format",
|
||||
"type": "string"
|
||||
|
||||
@@ -5894,6 +5894,10 @@
|
||||
"description": "text input",
|
||||
"type": "string"
|
||||
},
|
||||
"instructions": {
|
||||
"description": "Instructions is a free-form, per-request style/voice description. It maps to\nthe OpenAI `instructions` field and is forwarded to the backend so expressive\nTTS models (e.g. Qwen3-TTS CustomVoice/VoiceDesign) can vary tone or designed\nvoice per request instead of only via the static YAML option.",
|
||||
"type": "string"
|
||||
},
|
||||
"language": {
|
||||
"description": "(optional) language to use with TTS model",
|
||||
"type": "string"
|
||||
@@ -5901,6 +5905,13 @@
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"params": {
|
||||
"description": "Params carries optional, backend-specific per-request generation parameters\n(LocalAI extension, e.g. Chatterbox exaggeration/cfg_weight/temperature).",
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"response_format": {
|
||||
"description": "(optional) output format",
|
||||
"type": "string"
|
||||
|
||||
@@ -1996,11 +1996,25 @@ definitions:
|
||||
input:
|
||||
description: text input
|
||||
type: string
|
||||
instructions:
|
||||
description: |-
|
||||
Instructions is a free-form, per-request style/voice description. It maps to
|
||||
the OpenAI `instructions` field and is forwarded to the backend so expressive
|
||||
TTS models (e.g. Qwen3-TTS CustomVoice/VoiceDesign) can vary tone or designed
|
||||
voice per request instead of only via the static YAML option.
|
||||
type: string
|
||||
language:
|
||||
description: (optional) language to use with TTS model
|
||||
type: string
|
||||
model:
|
||||
type: string
|
||||
params:
|
||||
additionalProperties:
|
||||
type: string
|
||||
description: |-
|
||||
Params carries optional, backend-specific per-request generation parameters
|
||||
(LocalAI extension, e.g. Chatterbox exaggeration/cfg_weight/temperature).
|
||||
type: object
|
||||
response_format:
|
||||
description: (optional) output format
|
||||
type: string
|
||||
|
||||
156
tests/e2e/distributed/nats_jwt_helpers_test.go
Normal file
156
tests/e2e/distributed/nats_jwt_helpers_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package distributed_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/nats-io/jwt/v2"
|
||||
"github.com/nats-io/nkeys"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
tcnats "github.com/testcontainers/testcontainers-go/modules/nats"
|
||||
)
|
||||
|
||||
// JWTTestInfra holds a NATS server configured with JWT auth and minted worker credentials.
|
||||
type JWTTestInfra struct {
|
||||
*TestInfra
|
||||
AccountSeed string
|
||||
NodeID string
|
||||
WorkerJWT string
|
||||
WorkerSeed string
|
||||
}
|
||||
|
||||
// SetupJWTInfra starts NATS with an in-memory JWT resolver and returns worker credentials
|
||||
// minted the same way as node registration (pkg/natsauth).
|
||||
func SetupJWTInfra() *JWTTestInfra {
|
||||
GinkgoHelper()
|
||||
|
||||
infra := &JWTTestInfra{TestInfra: &TestInfra{Ctx: context.Background()}}
|
||||
|
||||
operatorJWT, accountJWT, accountSeed, err := jwtResolverMaterial()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
infra.AccountSeed = accountSeed
|
||||
|
||||
conf := fmt.Sprintf(`listen: 0.0.0.0:4222
|
||||
|
||||
operator: %s
|
||||
|
||||
resolver: MEMORY
|
||||
resolver_preload: {
|
||||
%s: %s
|
||||
}
|
||||
`, operatorJWT, accountPublicKeyFromSeed(accountSeed), accountJWT)
|
||||
|
||||
var natsContainer *tcnats.NATSContainer
|
||||
// Override default testcontainers -js: JetStream fails without a system account in JWT mode.
|
||||
natsContainer, err = tcnats.Run(infra.Ctx, "nats:2-alpine",
|
||||
tcnats.WithConfigFile(bytes.NewBufferString(conf)),
|
||||
testcontainers.WithCmd("-c", "/etc/nats.conf"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
infra.NATSContainer = natsContainer
|
||||
|
||||
infra.NatsURL, err = infra.NATSContainer.ConnectionString(infra.Ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
infra.NodeID = "550e8400-e29b-41d4-a716-446655440000"
|
||||
cfg := natsauth.Config{AccountSeed: infra.AccountSeed, WorkerJWTTTL: time.Hour}
|
||||
infra.WorkerJWT, infra.WorkerSeed, err = cfg.MintWorkerJWT(infra.NodeID, "backend")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
infra.NC, err = messaging.New(infra.NatsURL, messaging.WithUserJWT(infra.WorkerJWT, infra.WorkerSeed))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
DeferCleanup(func() {
|
||||
if infra.NC != nil {
|
||||
infra.NC.Close()
|
||||
}
|
||||
if infra.NATSContainer != nil {
|
||||
_ = infra.NATSContainer.Terminate(context.Background())
|
||||
}
|
||||
})
|
||||
|
||||
return infra
|
||||
}
|
||||
|
||||
// jwtResolverMaterial builds operator + account JWTs for a MEMORY resolver.
|
||||
// Follows the NATS JWT tutorial: self-signed account, then operator re-sign, with the
|
||||
// account identity key listed as a signing key so MintWorkerJWT can use the account seed.
|
||||
func jwtResolverMaterial() (operatorJWT, accountJWT, accountSeed string, err error) {
|
||||
okp, err := nkeys.CreateOperator()
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
opk, err := okp.PublicKey()
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
oc := jwt.NewOperatorClaims(opk)
|
||||
oc.Name = "localai-test-operator"
|
||||
oskp, err := nkeys.CreateOperator()
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
ospk, err := oskp.PublicKey()
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
oc.SigningKeys.Add(ospk)
|
||||
operatorJWT, err = oc.Encode(okp)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
akp, err := nkeys.CreateAccount()
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
seed, err := akp.Seed()
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
accountSeed = string(seed)
|
||||
|
||||
apk, err := akp.PublicKey()
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
ac := jwt.NewAccountClaims(apk)
|
||||
ac.Name = "localai-test-account"
|
||||
ac.SigningKeys.Add(apk)
|
||||
accountJWT, err = ac.Encode(akp)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
ac, err = jwt.DecodeAccountClaims(accountJWT)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
accountJWT, err = ac.Encode(oskp)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
return operatorJWT, accountJWT, accountSeed, nil
|
||||
}
|
||||
|
||||
func accountPublicKeyFromSeed(accountSeed string) string {
|
||||
akp, err := nkeys.FromSeed([]byte(accountSeed))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
pk, err := akp.PublicKey()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return pk
|
||||
}
|
||||
|
||||
// nodeSubjectPrefix builds the "nodes.<token>" subject prefix for a node ID,
// where <token> is the ID with ".", "*", ">", space, tab and newline each
// replaced by "-".
func nodeSubjectPrefix(nodeID string) string {
	sanitizer := strings.NewReplacer(
		".", "-",
		"*", "-",
		">", "-",
		" ", "-",
		"\t", "-",
		"\n", "-",
	)
	return "nodes." + sanitizer.Replace(nodeID)
}
|
||||
99
tests/e2e/distributed/nats_jwt_test.go
Normal file
99
tests/e2e/distributed/nats_jwt_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package distributed_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("NATS JWT Auth", Label("Distributed", "NatsJWT"), func() {
|
||||
var infra *JWTTestInfra
|
||||
|
||||
BeforeEach(func() {
|
||||
infra = SetupJWTInfra()
|
||||
})
|
||||
|
||||
It("connects with a minted backend worker JWT and publishes on allowed subjects", func() {
|
||||
// Backend workers may publish under nodes.<id>.files.> (see pkg/natsauth permissions).
|
||||
subject := nodeSubjectPrefix(infra.NodeID) + ".files.in"
|
||||
Expect(infra.NC.Publish(subject, map[string]string{"path": "/tmp/model"})).To(Succeed())
|
||||
Expect(infra.NC.Conn().FlushTimeout(2 * time.Second)).To(Succeed())
|
||||
Expect(infra.NC.Conn().IsConnected()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("allows backend subscribe on the node prefix", func() {
|
||||
wild := nodeSubjectPrefix(infra.NodeID) + ".>"
|
||||
sub, err := infra.NC.Subscribe(wild, func(_ []byte) {})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer func() { _ = sub.Unsubscribe() }()
|
||||
Expect(infra.NC.Conn().FlushTimeout(2 * time.Second)).To(Succeed())
|
||||
Expect(infra.NC.Conn().IsConnected()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects anonymous publish on the JWT-enabled server", func() {
|
||||
anon, err := messaging.New(infra.NatsURL)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer anon.Close()
|
||||
|
||||
err = anon.Publish("nodes.any.files.x", map[string]string{"x": "1"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(anon.Conn().FlushTimeout(2 * time.Second)).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("denies backend publish to another node's subjects", func() {
|
||||
other := nodeSubjectPrefix("other-node-id") + ".files.stage"
|
||||
Expect(infra.NC.Publish(other, map[string]string{"stage": "nope"})).To(Succeed())
|
||||
Eventually(func() error {
|
||||
_ = infra.NC.Conn().FlushTimeout(500 * time.Millisecond)
|
||||
return infra.NC.Conn().LastError()
|
||||
}, "3s", "50ms").Should(HaveOccurred())
|
||||
})
|
||||
|
||||
It("mints agent JWT without backend.install in claims", func() {
|
||||
cfg := natsauth.Config{AccountSeed: infra.AccountSeed}
|
||||
token, _, err := cfg.MintWorkerJWT("agent-node-1", "agent")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
claims, err := natsauth.DecodeUserClaims(token)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(claims.Permissions.Sub.Allow).To(ContainElement("agent.execute"))
|
||||
for _, subj := range claims.Permissions.Sub.Allow {
|
||||
Expect(subj).NotTo(ContainSubstring("backend.install"))
|
||||
}
|
||||
})
|
||||
|
||||
// Regression guard for the silent permission gaps: decoding the JWT claims
|
||||
// (above) only proves the agent JWT is *restrictive*, not that it is
|
||||
// *sufficient*. Stand a real agent connection up against the enforcing
|
||||
// server and exercise every subscription core/cli/agent_worker.go actually
|
||||
// makes — a denied SUB now surfaces synchronously via confirmSubscription,
|
||||
// so a missing allow rule fails this test instead of silently dropping
|
||||
// backend.stop / MCP-CI deliveries at runtime.
|
||||
It("lets an agent-minted JWT establish all the subscriptions the agent worker uses", func() {
|
||||
const nodeID = "agent-node-subs"
|
||||
cfg := natsauth.Config{AccountSeed: infra.AccountSeed, WorkerJWTTTL: time.Hour}
|
||||
token, seed, err := cfg.MintWorkerJWT(nodeID, "agent")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
nc, err := messaging.New(infra.NatsURL, messaging.WithUserJWT(token, seed))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
DeferCleanup(nc.Close)
|
||||
|
||||
// Mirror core/cli/agent_worker.go exactly.
|
||||
_, err = nc.QueueSubscribeReply(messaging.SubjectMCPToolExecute, messaging.QueueAgentWorkers, func([]byte, func([]byte)) {})
|
||||
Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s", messaging.SubjectMCPToolExecute)
|
||||
|
||||
_, err = nc.QueueSubscribeReply(messaging.SubjectMCPDiscovery, messaging.QueueAgentWorkers, func([]byte, func([]byte)) {})
|
||||
Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s", messaging.SubjectMCPDiscovery)
|
||||
|
||||
_, err = nc.QueueSubscribe(messaging.SubjectMCPCIJobsNew, messaging.QueueWorkers, func([]byte) {})
|
||||
Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s (MCP CI jobs)", messaging.SubjectMCPCIJobsNew)
|
||||
|
||||
_, err = nc.Subscribe(messaging.SubjectNodeBackendStop(nodeID), func([]byte) {})
|
||||
Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s (MCP session cleanup)", messaging.SubjectNodeBackendStop(nodeID))
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user