mirror of
https://github.com/mudler/LocalAI.git
synced 2026-07-03 04:46:54 -04:00
Compare commits
65 Commits
worktree-f
...
fix/mlx-to
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4bf73a7e22 | ||
|
|
715d4ed8e5 | ||
|
|
9fcc9c0d43 | ||
|
|
3c67b5b746 | ||
|
|
bea66fd84e | ||
|
|
f7a5dfd5ae | ||
|
|
6bcaf30c14 | ||
|
|
ef15b4bfda | ||
|
|
237bce48e8 | ||
|
|
a4e6e01e4d | ||
|
|
6eea3ef2ac | ||
|
|
ad97bcbbdd | ||
|
|
9d8ff90941 | ||
|
|
29001a88c1 | ||
|
|
b0bfa0852e | ||
|
|
39a93e91cf | ||
|
|
26e0c98967 | ||
|
|
9acca54b25 | ||
|
|
2728e6000e | ||
|
|
006310d746 | ||
|
|
05acdb1778 | ||
|
|
5e68b5700c | ||
|
|
7910018249 | ||
|
|
1a03712a6f | ||
|
|
703ea32de6 | ||
|
|
751db06e35 | ||
|
|
f46c0e9c83 | ||
|
|
0d8adfc59a | ||
|
|
43f2615e19 | ||
|
|
875c539ad5 | ||
|
|
d641ded194 | ||
|
|
40445fff05 | ||
|
|
057dee956a | ||
|
|
4ec39bb776 | ||
|
|
25ecb9f015 | ||
|
|
2be495f9c0 | ||
|
|
02b007a31e | ||
|
|
fd8cebd0b3 | ||
|
|
dd625921ff | ||
|
|
d74f88357e | ||
|
|
dfaec3bd51 | ||
|
|
0e381897b5 | ||
|
|
b1af37257d | ||
|
|
ebefa6dcca | ||
|
|
605348925d | ||
|
|
686ce10b54 | ||
|
|
2cee318fad | ||
|
|
1a4f68ed4a | ||
|
|
28d7397743 | ||
|
|
5d0c43ec6e | ||
|
|
6ab29ec8b9 | ||
|
|
036f950b1b | ||
|
|
5b7b914b4f | ||
|
|
d1cee4c52a | ||
|
|
baaa0fe94f | ||
|
|
c3b5c7c3fa | ||
|
|
bd1ec8f2c2 | ||
|
|
135debf9af | ||
|
|
e8c18ae28e | ||
|
|
c4d302e1ab | ||
|
|
323b57a4bc | ||
|
|
3d2f639213 | ||
|
|
be1ae9338b | ||
|
|
923c47020d | ||
|
|
b7a1dec773 |
@@ -7,8 +7,11 @@
|
||||
# Runs only the checks relevant to what's staged:
|
||||
# - Go files -> make lint + make test-coverage-check
|
||||
# - core/http/react-ui -> make test-ui-coverage-check (Playwright e2e + gate)
|
||||
# A commit touching neither is skipped entirely (docs/YAML/etc. can't change
|
||||
# lint findings, Go coverage, or the UI).
|
||||
# - realtime state machines / specs -> make test-realtime-conformance
|
||||
# (coordinator/**, respcoord/**, turncoord/**, conncoord/**, compactcoord/**, ttscoord/**, or formal-verification/** -- a pure .fizz
|
||||
# spec edit must still re-verify the design, detected separately from Go)
|
||||
# A commit touching none of these is skipped entirely (other docs/YAML can't
|
||||
# change lint findings, Go coverage, the UI, or the realtime conformance gate).
|
||||
#
|
||||
# To bypass for a single commit (e.g. a WIP checkpoint): git commit --no-verify
|
||||
set -eu
|
||||
@@ -20,11 +23,13 @@ staged="$(git diff --cached --name-only --diff-filter=ACMRD)"
|
||||
|
||||
go_changed=0
|
||||
ui_changed=0
|
||||
rt_changed=0
|
||||
if echo "$staged" | grep -qE '\.go$'; then go_changed=1; fi
|
||||
if echo "$staged" | grep -qE '^core/http/react-ui/'; then ui_changed=1; fi
|
||||
if echo "$staged" | grep -qE '^(core/http/endpoints/openai/(coordinator|respcoord|turncoord|conncoord|compactcoord|ttscoord)/|formal-verification/)'; then rt_changed=1; fi
|
||||
|
||||
if [ "$go_changed" -eq 0 ] && [ "$ui_changed" -eq 0 ]; then
|
||||
echo "pre-commit: no Go or React UI changes staged — skipping."
|
||||
if [ "$go_changed" -eq 0 ] && [ "$ui_changed" -eq 0 ] && [ "$rt_changed" -eq 0 ]; then
|
||||
echo "pre-commit: no Go, React UI, or realtime-spec changes staged — skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
@@ -57,4 +62,11 @@ if [ "$ui_changed" -eq 1 ]; then
|
||||
make test-ui-coverage-check
|
||||
fi
|
||||
|
||||
if [ "$rt_changed" -eq 1 ]; then
|
||||
echo "pre-commit ▶ realtime state-machine conformance (make test-realtime-conformance) —"
|
||||
echo " Go transition/rapid tests under -race + FizzBee model check of the"
|
||||
echo " authoritative specs. Fail-closed: needs FizzBee (make install-fizzbee)."
|
||||
make test-realtime-conformance
|
||||
fi
|
||||
|
||||
echo "pre-commit ✓ all relevant checks passed"
|
||||
|
||||
12
.github/workflows/backend_build_darwin.yml
vendored
12
.github/workflows/backend_build_darwin.yml
vendored
@@ -82,7 +82,7 @@ jobs:
|
||||
# as the Linux registry cache.
|
||||
- name: Restore Homebrew cache
|
||||
id: brew-cache
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v6
|
||||
with:
|
||||
path: |
|
||||
~/Library/Caches/Homebrew/downloads
|
||||
@@ -142,7 +142,7 @@ jobs:
|
||||
|
||||
- name: Save Homebrew cache
|
||||
if: github.event_name != 'pull_request' && steps.brew-cache.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@v4
|
||||
uses: actions/cache/save@v6
|
||||
with:
|
||||
path: |
|
||||
~/Library/Caches/Homebrew/downloads
|
||||
@@ -178,7 +178,7 @@ jobs:
|
||||
- name: Restore ccache
|
||||
if: inputs.backend == 'llama-cpp'
|
||||
id: ccache-cache
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v6
|
||||
with:
|
||||
path: ~/Library/Caches/ccache
|
||||
key: ccache-llama-${{ runner.arch }}-${{ steps.llama-version.outputs.version }}-${{ github.run_id }}
|
||||
@@ -211,7 +211,7 @@ jobs:
|
||||
- name: Restore Python wheel cache
|
||||
if: inputs.lang == 'python'
|
||||
id: pyenv-cache
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v6
|
||||
with:
|
||||
path: |
|
||||
~/Library/Caches/pip
|
||||
@@ -256,14 +256,14 @@ jobs:
|
||||
|
||||
- name: Save ccache
|
||||
if: inputs.backend == 'llama-cpp' && github.event_name != 'pull_request'
|
||||
uses: actions/cache/save@v4
|
||||
uses: actions/cache/save@v6
|
||||
with:
|
||||
path: ~/Library/Caches/ccache
|
||||
key: ccache-llama-${{ runner.arch }}-${{ steps.llama-version.outputs.version }}-${{ github.run_id }}
|
||||
|
||||
- name: Save Python wheel cache
|
||||
if: inputs.lang == 'python' && github.event_name != 'pull_request' && steps.pyenv-cache.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@v4
|
||||
uses: actions/cache/save@v6
|
||||
with:
|
||||
path: |
|
||||
~/Library/Caches/pip
|
||||
|
||||
69
.github/workflows/realtime-conformance.yml
vendored
Normal file
69
.github/workflows/realtime-conformance.yml
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
---
|
||||
name: 'realtime-conformance'
|
||||
|
||||
# Verifies the realtime state-machine implementations conform to their formal
|
||||
# designs (docs/design/realtime-state-machines.md, formal-verification/). BOTH
|
||||
# layers are enforced and the gate is fail-closed: the Go conformance layer
|
||||
# (respcoord + turncoord transition/rapid tests under -race) AND the FizzBee model check of
|
||||
# the authoritative specs. FizzBee is pinned + checksum-verified
|
||||
# (formal-verification/fizzbee.sha256), so a failed install fails the job rather
|
||||
# than silently skipping verification.
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'core/http/endpoints/openai/coordinator/**'
|
||||
- 'core/http/endpoints/openai/respcoord/**'
|
||||
- 'core/http/endpoints/openai/turncoord/**'
|
||||
- 'core/http/endpoints/openai/conncoord/**'
|
||||
- 'core/http/endpoints/openai/compactcoord/**'
|
||||
- 'core/http/endpoints/openai/ttscoord/**'
|
||||
- 'formal-verification/**'
|
||||
- 'scripts/realtime-conformance.sh'
|
||||
- 'scripts/install-fizzbee.sh'
|
||||
- '.github/workflows/realtime-conformance.yml'
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths:
|
||||
- 'core/http/endpoints/openai/coordinator/**'
|
||||
- 'core/http/endpoints/openai/respcoord/**'
|
||||
- 'core/http/endpoints/openai/turncoord/**'
|
||||
- 'core/http/endpoints/openai/conncoord/**'
|
||||
- 'core/http/endpoints/openai/compactcoord/**'
|
||||
- 'core/http/endpoints/openai/ttscoord/**'
|
||||
- 'formal-verification/**'
|
||||
- 'scripts/realtime-conformance.sh'
|
||||
|
||||
concurrency:
|
||||
group: realtime-conformance-${{ github.event.pull_request.number || github.sha }}-${{ github.repository }}
|
||||
cancel-in-progress: ${{ github.event_name == 'pull_request' }}
|
||||
|
||||
jobs:
|
||||
conformance:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ['1.26.x']
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v7
|
||||
- name: Setup Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
cache: false
|
||||
- name: Cache FizzBee
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: .tools/fizzbee
|
||||
key: fizzbee-v0.5.2-${{ runner.os }}-${{ hashFiles('formal-verification/fizzbee.sha256') }}
|
||||
- name: Install FizzBee (pinned, checksum-verified)
|
||||
# No `|| true`: a failed/forged download must fail the job, not silently
|
||||
# drop the design verification. install-fizzbee.sh is a no-op if the
|
||||
# cached binary is already present and valid.
|
||||
run: ./scripts/install-fizzbee.sh
|
||||
- name: Run conformance gate (fail-closed)
|
||||
# No skip env: both the Go conformance and the FizzBee model check are
|
||||
# required. The gate auto-detects .tools/fizzbee/fizz.
|
||||
run: make test-realtime-conformance
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -97,3 +97,12 @@ core/http/react-ui/test-results/
|
||||
|
||||
# Local Apple signing material (never commit)
|
||||
.certs/
|
||||
|
||||
# Pinned dev tools (e.g. FizzBee for the realtime-conformance gate)
|
||||
.tools/
|
||||
|
||||
# FizzBee model-check artifacts: the parser emits <spec>.json next to each
|
||||
# .fizz and the checker writes run dirs under out/. Both are regenerated by
|
||||
# the realtime-conformance gate; only the .fizz sources are authoritative.
|
||||
formal-verification/*.json
|
||||
formal-verification/out/
|
||||
|
||||
11
Dockerfile
11
Dockerfile
@@ -171,6 +171,17 @@ RUN if [ "${BUILD_TYPE}" = "hipblas" ]; then \
|
||||
ln -s /opt/rocm-**/lib/llvm/lib/libomp.so /usr/lib/libomp.so \
|
||||
; fi
|
||||
|
||||
# ROCm's bundled libdrm_amdgpu is built with a hardcoded fallback lookup path
|
||||
# for the ASIC ID table (/opt/amdgpu/share/libdrm/amdgpu.ids), which only exists
|
||||
# if AMD's full amdgpu graphics/DKMS stack is installed. This compute-only image
|
||||
# doesn't have it, so hipblas/rocBLAS log "No such file or directory" on every
|
||||
# model load and can fail to identify the GPU. Point it at the equivalent file
|
||||
# Ubuntu's libdrm-common package already ships.
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ -f /usr/share/libdrm/amdgpu.ids ] && [ ! -e /opt/amdgpu/share/libdrm/amdgpu.ids ]; then \
|
||||
mkdir -p /opt/amdgpu/share/libdrm && \
|
||||
ln -s /usr/share/libdrm/amdgpu.ids /opt/amdgpu/share/libdrm/amdgpu.ids \
|
||||
; fi
|
||||
|
||||
RUN expr "${BUILD_TYPE}" = intel && echo "intel" > /run/localai/capability || echo "not intel"
|
||||
|
||||
# Cuda
|
||||
|
||||
23
Makefile
23
Makefile
@@ -405,6 +405,18 @@ test-realtime: build-mock-backend
|
||||
@echo 'Running realtime e2e tests (mock backend)'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime && !real-models" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e
|
||||
|
||||
# Verify the realtime state-machine implementations conform to their formal
|
||||
# designs (Go transition/rapid tests under -race + FizzBee model check of the
|
||||
# authoritative specs). See docs/design/realtime-state-machines.md (Part 6) and
|
||||
# docs/design/specs/README.md.
|
||||
test-realtime-conformance:
|
||||
GOCMD=$(GOCMD) ./scripts/realtime-conformance.sh
|
||||
|
||||
# Install the pinned, checksum-verified FizzBee model checker (into .tools/,
|
||||
# gitignored) used by test-realtime-conformance. Idempotent; no-op if present.
|
||||
install-fizzbee:
|
||||
./scripts/install-fizzbee.sh
|
||||
|
||||
# Container-based real-model realtime testing. Build env vars / pipeline
|
||||
# definition kept here so test-realtime-models-docker can drive a fully wired
|
||||
# pipeline (VAD + STT + LLM + TTS) from inside a containerised runner.
|
||||
@@ -1027,7 +1039,7 @@ test-extra-backend-whisper-transcription: docker-build-whisper
|
||||
## is reachable.
|
||||
test-extra-backend-parakeet-cpp-transcription: docker-build-parakeet-cpp
|
||||
BACKEND_IMAGE=local-ai-backend:parakeet-cpp \
|
||||
BACKEND_TEST_MODEL_URL=https://huggingface.co/mudler/parakeet-cpp-gguf/resolve/main/tdt_ctc-110m-f16.gguf \
|
||||
BACKEND_TEST_MODEL_URL=https://huggingface.co/mudler/parakeet-cpp-gguf/resolve/main/realtime_eou_120m-v1-f16.gguf \
|
||||
BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
$(MAKE) test-extra-backend
|
||||
@@ -1470,8 +1482,13 @@ build-launcher-darwin:
|
||||
mv cmd/launcher/LocalAI.app dist/LocalAI.app
|
||||
bash contrib/macos/sign-and-notarize.sh sign dist/LocalAI.app
|
||||
|
||||
# Wrap the (signed) app into a drag-to-Applications DMG via hdiutil, then sign the DMG.
|
||||
# Notarize + staple the .app itself, then wrap it into a drag-to-Applications
|
||||
# DMG via hdiutil and sign the DMG. The app is stapled BEFORE packaging so the
|
||||
# bundle carries its own ticket and verifies offline (a dmg-only staple leaves
|
||||
# the app relying on an online Gatekeeper check, which fails offline / once the
|
||||
# app is copied out of the dmg). No-op without notary secrets.
|
||||
dmg-launcher-darwin: build-launcher-darwin
|
||||
bash contrib/macos/sign-and-notarize.sh notarize-app dist/LocalAI.app
|
||||
rm -rf dist/dmg dist/LocalAI.dmg
|
||||
mkdir -p dist/dmg
|
||||
cp -R dist/LocalAI.app dist/dmg/LocalAI.app
|
||||
@@ -1483,7 +1500,7 @@ dmg-launcher-darwin: build-launcher-darwin
|
||||
notarize-launcher-darwin: dmg-launcher-darwin
|
||||
bash contrib/macos/sign-and-notarize.sh notarize dist/LocalAI.dmg
|
||||
|
||||
# Single entrypoint for CI: build -> sign app -> dmg -> sign dmg -> notarize -> staple.
|
||||
# Single entrypoint for CI: build -> sign app -> notarize+staple app -> dmg -> sign dmg -> notarize+staple dmg.
|
||||
release-launcher-darwin: notarize-launcher-darwin
|
||||
@echo "dist/LocalAI.dmg is ready"
|
||||
|
||||
|
||||
@@ -18,6 +18,18 @@ service Backend {
|
||||
rpc GenerateVideo(GenerateVideoRequest) returns (Result) {}
|
||||
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
|
||||
rpc AudioTranscriptionStream(TranscriptRequest) returns (stream TranscriptStreamResponse) {}
|
||||
// AudioTranscriptionLive is the bidirectional live-microphone ASR RPC. The
|
||||
// first message MUST carry a Config; subsequent messages carry Audio frames
|
||||
// (mono float PCM at config.sample_rate, 16 kHz default). After a
|
||||
// successful open the backend replies with a single ready ack
|
||||
// (TranscriptLiveResponse{ready:true}); backends or models without
|
||||
// cache-aware streaming support return UNIMPLEMENTED instead. Newly
|
||||
// finalized text streams back as deltas; eou=true marks the model's
|
||||
// end-of-utterance token. One stream spans many utterances (the decoder
|
||||
// resets itself after each EOU). Closing the send side finalizes: the
|
||||
// backend flushes the decoder tail and emits a terminal message carrying
|
||||
// final_result. A second Config mid-stream resets the decode session.
|
||||
rpc AudioTranscriptionLive(stream TranscriptLiveRequest) returns (stream TranscriptLiveResponse) {}
|
||||
rpc TTS(TTSRequest) returns (Result) {}
|
||||
rpc TTSStream(TTSRequest) returns (stream Reply) {}
|
||||
rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
|
||||
@@ -479,6 +491,10 @@ message TranscriptResult {
|
||||
string text = 2;
|
||||
string language = 3;
|
||||
float duration = 4;
|
||||
// True when the decode ended on the model's end-of-utterance special token
|
||||
// (<EOU>/<EOB>, emitted by cache-aware streaming models such as
|
||||
// parakeet_realtime_eou_120m-v1). The marker itself is stripped from text.
|
||||
bool eou = 5;
|
||||
}
|
||||
|
||||
message TranscriptStreamResponse {
|
||||
@@ -486,6 +502,34 @@ message TranscriptStreamResponse {
|
||||
TranscriptResult final_result = 2;
|
||||
}
|
||||
|
||||
// === AudioTranscriptionLive messages =====================================
|
||||
|
||||
message TranscriptLiveRequest {
|
||||
oneof payload {
|
||||
TranscriptLiveConfig config = 1;
|
||||
TranscriptLiveAudio audio = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message TranscriptLiveConfig {
|
||||
string language = 1; // "" => model default
|
||||
int32 sample_rate = 2; // 0 => 16000; backends may reject others
|
||||
map<string, string> params = 3; // backend-specific tuning
|
||||
}
|
||||
|
||||
message TranscriptLiveAudio {
|
||||
repeated float pcm = 1; // mono PCM in [-1,1] at config.sample_rate
|
||||
}
|
||||
|
||||
message TranscriptLiveResponse {
|
||||
bool ready = 1; // open ack: sent once, before any delta
|
||||
string delta = 2; // newly-finalized text since previous response
|
||||
bool eou = 3; // <EOU> fired during this feed (the user yielded the turn)
|
||||
repeated TranscriptWord words = 4; // words finalized by this feed (stream-relative ns)
|
||||
TranscriptResult final_result = 5; // terminal message only, after the send side closes
|
||||
bool eob = 6; // <EOB> fired: a backchannel ("uh-huh") ended — NOT a turn boundary
|
||||
}
|
||||
|
||||
message TranscriptWord {
|
||||
int64 start = 1;
|
||||
int64 end = 2;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=f96eaddba8bed6a9a5e628bbf6a566775c70b49c
|
||||
IK_LLAMA_VERSION?=87fc8701ff4da81a7d2a91ec0695f95eb3066a47
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -101,4 +101,13 @@ if(LLAMA_GRPC_BUILD_TESTS)
|
||||
target_link_libraries(message_content_test PRIVATE ${_LLAMA_COMMON_TARGET})
|
||||
target_compile_features(message_content_test PRIVATE cxx_std_17)
|
||||
add_test(NAME message_content_test COMMAND message_content_test)
|
||||
|
||||
# Parent-death watcher test (parent_watch.h) — standard library only, but
|
||||
# needs a threading runtime for std::thread.
|
||||
find_package(Threads REQUIRED)
|
||||
add_executable(parent_watch_test parent_watch_test.cpp parent_watch.h)
|
||||
target_include_directories(parent_watch_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_link_libraries(parent_watch_test PRIVATE Threads::Threads)
|
||||
target_compile_features(parent_watch_test PRIVATE cxx_std_17)
|
||||
add_test(NAME parent_watch_test COMMAND parent_watch_test)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=0ed235ea2c17a19fc8238668653946721ed136fd
|
||||
LLAMA_VERSION?=fdb1db877c526ec90f668eca1b858da5dba85560
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -75,6 +75,8 @@
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
#include "parent_watch.h" // best-effort parent-death backstop (see header)
|
||||
|
||||
|
||||
using grpc::Server;
|
||||
using grpc::ServerBuilder;
|
||||
@@ -3442,6 +3444,10 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
// Best-effort backstop: self-terminate if the LocalAI process that spawned
|
||||
// us dies without cleaning us up (see parent_watch.h).
|
||||
llama_grpc::start_parent_death_watcher();
|
||||
|
||||
server_context ctx_server;
|
||||
BackendServiceImpl service(ctx_server);
|
||||
|
||||
|
||||
179
backend/cpp/llama-cpp/parent_watch.h
Normal file
179
backend/cpp/llama-cpp/parent_watch.h
Normal file
@@ -0,0 +1,179 @@
|
||||
// Parent-death watcher (best-effort backstop) for the llama.cpp gRPC backend.
|
||||
//
|
||||
// LocalAI spawns this backend as a child process and, on a clean shutdown,
|
||||
// tears it down itself (SIGTERM -> grace -> SIGKILL). That graceful path only
|
||||
// runs when LocalAI receives a catchable signal and lives long enough to run
|
||||
// its handlers. If LocalAI is SIGKILLed (e.g. a supervising process's grace
|
||||
// period elapses first), that teardown never runs and this backend would be
|
||||
// reparented to init and linger, holding VRAM and its listen port.
|
||||
//
|
||||
// The watcher here is a best-effort backstop for exactly that case: it does
|
||||
// NOT replace the graceful teardown, it only covers the "parent vanished
|
||||
// without cleaning up" path. It detects reparenting: when the process that
|
||||
// spawned this backend dies, the kernel reparents us to the nearest sub-reaper
|
||||
// or to init (PID 1), so getppid() stops matching the value captured at
|
||||
// startup. This getppid() approach is portable across Linux/macOS (unlike the
|
||||
// Linux-only PR_SET_PDEATHSIG), which is why it is used here, mirroring the Go
|
||||
// backends' pkg/grpc/parentwatch.go. It is disabled on Windows, which has no
|
||||
// equivalent orphan-reparenting semantics.
|
||||
//
|
||||
// This header is intentionally dependency-free (C++ standard library only) so
|
||||
// it can be exercised by a standalone unit test (parent_watch_test.cpp) without
|
||||
// building the full llama.cpp + gRPC backend.
|
||||
#ifndef LLAMA_GRPC_PARENT_WATCH_H
|
||||
#define LLAMA_GRPC_PARENT_WATCH_H
|
||||
|
||||
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <limits>
#include <string>
#include <thread>
|
||||
|
||||
#if !defined(_WIN32)
|
||||
#include <unistd.h> // getppid(2), _exit(2)
|
||||
#endif
|
||||
|
||||
namespace llama_grpc {
|
||||
|
||||
// Env var names are shared verbatim with the Go and Python backends for
|
||||
// consistency across languages.
|
||||
inline const char *kEnvParentWatch() { return "LOCALAI_BACKEND_PARENT_WATCH"; }
|
||||
inline const char *kEnvParentWatchInterval() { return "LOCALAI_BACKEND_PARENT_WATCH_INTERVAL"; }
|
||||
|
||||
// Default poll interval in milliseconds. Matches the Go side's 2 * time.Second.
|
||||
inline long parent_watch_default_interval_ms() { return 2000; }
|
||||
|
||||
namespace detail {
|
||||
inline std::string trim_lower(const std::string &in, bool lower) {
|
||||
size_t a = in.find_first_not_of(" \t\r\n");
|
||||
size_t b = in.find_last_not_of(" \t\r\n");
|
||||
if (a == std::string::npos) {
|
||||
return "";
|
||||
}
|
||||
std::string s = in.substr(a, b - a + 1);
|
||||
if (lower) {
|
||||
std::transform(s.begin(), s.end(), s.begin(),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
}
|
||||
return s;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// parent_watch_enabled reports whether the watcher should run. Enabled by
|
||||
// default; a falsey value ("false"/"0"/"no"/"off", case-insensitive) disables
|
||||
// it, matching the Go implementation's exact semantics.
|
||||
inline bool parent_watch_enabled() {
|
||||
#if defined(_WIN32)
|
||||
return false;
|
||||
#else
|
||||
const char *v = std::getenv(kEnvParentWatch());
|
||||
if (v == nullptr || v[0] == '\0') {
|
||||
return true;
|
||||
}
|
||||
const std::string s = detail::trim_lower(v, true);
|
||||
return !(s == "false" || s == "0" || s == "no" || s == "off");
|
||||
#endif
|
||||
}
|
||||
|
||||
// parent_watch_interval_ms returns the poll interval in milliseconds. Accepts
|
||||
// Go-style duration strings ("500ms", "2s", "1m") for cross-language parity, or
|
||||
// a bare number interpreted as seconds. Defaults to
|
||||
// parent_watch_default_interval_ms().
|
||||
inline long parent_watch_interval_ms() {
|
||||
const long def = parent_watch_default_interval_ms();
|
||||
const char *v = std::getenv(kEnvParentWatchInterval());
|
||||
if (v == nullptr || v[0] == '\0') {
|
||||
return def;
|
||||
}
|
||||
const std::string s = detail::trim_lower(v, false);
|
||||
if (s.empty()) {
|
||||
return def;
|
||||
}
|
||||
size_t i = 0;
|
||||
while (i < s.size() && (std::isdigit((unsigned char)s[i]) || s[i] == '.')) {
|
||||
i++;
|
||||
}
|
||||
if (i == 0) {
|
||||
return def;
|
||||
}
|
||||
double num = 0.0;
|
||||
try {
|
||||
num = std::stod(s.substr(0, i));
|
||||
} catch (...) {
|
||||
return def;
|
||||
}
|
||||
const std::string unit = s.substr(i);
|
||||
long ms;
|
||||
if (unit == "ms") {
|
||||
ms = (long)num;
|
||||
} else if (unit == "s" || unit.empty()) {
|
||||
ms = (long)(num * 1000.0);
|
||||
} else if (unit == "m") {
|
||||
ms = (long)(num * 60000.0);
|
||||
} else {
|
||||
return def; // unrecognized unit
|
||||
}
|
||||
return ms > 0 ? ms : def;
|
||||
}
|
||||
|
||||
#if !defined(_WIN32)
|
||||
// parent_died reports whether this process has been reparented away from the
|
||||
// parent it had when the watcher started. Reparenting is the standard POSIX
|
||||
// signal that the original parent (here, the LocalAI process that spawned this
|
||||
// backend) has exited: the orphan is handed to the nearest sub-reaper or to
|
||||
// init (PID 1), so getppid() no longer matches the value captured at startup.
|
||||
inline bool parent_died(pid_t orig_ppid) {
|
||||
const pid_t ppid = getppid();
|
||||
return ppid != orig_ppid || ppid == 1;
|
||||
}
|
||||
|
||||
// watch_parent_death polls until parent_died reports the original parent is
|
||||
// gone, then invokes on_death. It blocks, so run it on its own thread.
|
||||
inline void watch_parent_death(pid_t orig_ppid, long interval_ms,
|
||||
const std::function<void()> &on_death) {
|
||||
for (;;) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(interval_ms));
|
||||
if (parent_died(orig_ppid)) {
|
||||
on_death();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// start_parent_death_watcher installs the best-effort safety net described in
|
||||
// the file header on the calling backend process. It is a no-op when disabled,
|
||||
// on Windows, or when the process is already orphaned at startup
|
||||
// (getppid() <= 1). This is a backstop alongside — never a replacement for —
|
||||
// LocalAI's graceful teardown.
|
||||
inline void start_parent_death_watcher() {
|
||||
#if !defined(_WIN32)
|
||||
if (!parent_watch_enabled()) {
|
||||
return;
|
||||
}
|
||||
const pid_t orig_ppid = getppid();
|
||||
// A parent of 1 (or less) at startup means we were already orphaned (or
|
||||
// launched directly under init) — there is no original parent to watch for.
|
||||
if (orig_ppid <= 1) {
|
||||
return;
|
||||
}
|
||||
const long interval_ms = parent_watch_interval_ms();
|
||||
std::thread([orig_ppid, interval_ms]() {
|
||||
watch_parent_death(orig_ppid, interval_ms, [orig_ppid]() {
|
||||
fprintf(stderr,
|
||||
"backend parent process (pid %d) exited without stopping "
|
||||
"this backend; self-terminating to avoid orphaning\n",
|
||||
(int)orig_ppid);
|
||||
fflush(stderr);
|
||||
_exit(1);
|
||||
});
|
||||
}).detach();
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace llama_grpc
|
||||
|
||||
#endif // LLAMA_GRPC_PARENT_WATCH_H
|
||||
197
backend/cpp/llama-cpp/parent_watch_test.cpp
Normal file
197
backend/cpp/llama-cpp/parent_watch_test.cpp
Normal file
@@ -0,0 +1,197 @@
|
||||
// Unit tests for the parent-death watcher (parent_watch.h).
|
||||
//
|
||||
// Build & run standalone (C++ standard library only, no nlohmann/json needed):
|
||||
// g++ -std=c++17 -pthread parent_watch_test.cpp -o t && ./t
|
||||
//
|
||||
// The core test (TestDetectsReparent) builds a genuine two-level process tree
|
||||
// (test -> middle -> grandchild), lets the middle process die, and asserts the
|
||||
// grandchild's watch_parent_death detects the reparenting and self-terminates —
|
||||
// mirroring the Go test in pkg/grpc/parentwatch_test.go, but with fork(2).
|
||||
//
|
||||
// On Windows this file compiles to a no-op success (the watcher is unsupported
|
||||
// there), matching parent_watch.h's platform gating.
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
|
||||
#include "parent_watch.h"
|
||||
|
||||
static int failures = 0;
|
||||
|
||||
static void check(bool ok, const std::string &name) {
|
||||
if (!ok) {
|
||||
failures++;
|
||||
fprintf(stderr, "FAIL: %s\n", name.c_str());
|
||||
} else {
|
||||
fprintf(stderr, "ok: %s\n", name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Env-parsing tests are platform-independent and always run.
|
||||
static void test_env_parsing() {
|
||||
using namespace llama_grpc;
|
||||
|
||||
// Interval: default when unset.
|
||||
unsetenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL");
|
||||
check(parent_watch_interval_ms() == 2000, "interval default 2000ms");
|
||||
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL", "500ms", 1);
|
||||
check(parent_watch_interval_ms() == 500, "interval 500ms");
|
||||
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL", "2s", 1);
|
||||
check(parent_watch_interval_ms() == 2000, "interval 2s");
|
||||
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL", "1m", 1);
|
||||
check(parent_watch_interval_ms() == 60000, "interval 1m");
|
||||
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL", "3", 1); // bare number -> seconds
|
||||
check(parent_watch_interval_ms() == 3000, "interval bare 3 -> 3000ms");
|
||||
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL", "garbage", 1);
|
||||
check(parent_watch_interval_ms() == 2000, "interval garbage -> default");
|
||||
unsetenv("LOCALAI_BACKEND_PARENT_WATCH_INTERVAL");
|
||||
|
||||
#if !defined(_WIN32)
|
||||
// Enabled semantics (POSIX only; always false on Windows).
|
||||
unsetenv("LOCALAI_BACKEND_PARENT_WATCH");
|
||||
check(parent_watch_enabled(), "enabled by default");
|
||||
|
||||
for (const char *falsey : {"false", "0", "no", "off", "OFF", " False "}) {
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH", falsey, 1);
|
||||
check(!parent_watch_enabled(), std::string("disabled by '") + falsey + "'");
|
||||
}
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH", "true", 1);
|
||||
check(parent_watch_enabled(), "enabled by 'true'");
|
||||
setenv("LOCALAI_BACKEND_PARENT_WATCH", "1", 1);
|
||||
check(parent_watch_enabled(), "enabled by '1'");
|
||||
unsetenv("LOCALAI_BACKEND_PARENT_WATCH");
|
||||
#endif
|
||||
}
|
||||
|
||||
#if !defined(_WIN32)
|
||||
|
||||
#include <atomic>
|
||||
#include <ctime>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/wait.h>
|
||||
#include <unistd.h>
|
||||
|
||||
static bool file_exists(const std::string &p) {
|
||||
struct stat st;
|
||||
return ::stat(p.c_str(), &st) == 0;
|
||||
}
|
||||
|
||||
static bool wait_for_file(const std::string &p, int timeout_ms) {
|
||||
int waited = 0;
|
||||
while (waited < timeout_ms) {
|
||||
if (file_exists(p)) {
|
||||
return true;
|
||||
}
|
||||
usleep(20 * 1000);
|
||||
waited += 20;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static void write_file(const std::string &p, const std::string &content) {
|
||||
FILE *f = fopen(p.c_str(), "w");
|
||||
if (f) {
|
||||
fwrite(content.data(), 1, content.size(), f);
|
||||
fclose(f);
|
||||
}
|
||||
}
|
||||
|
||||
// Builds test -> middle -> grandchild via fork(2). The grandchild arms the REAL
|
||||
// watch_parent_death against middle; middle exits, orphaning the grandchild;
|
||||
// the watcher must detect the reparenting and self-terminate.
|
||||
static void test_detects_reparent() {
|
||||
char tmpl[] = "/tmp/parentwatch_test_XXXXXX";
|
||||
char *dir = mkdtemp(tmpl);
|
||||
if (dir == nullptr) {
|
||||
check(false, "mkdtemp");
|
||||
return;
|
||||
}
|
||||
const std::string ready_file = std::string(dir) + "/ready";
|
||||
const std::string exited_file = std::string(dir) + "/exited";
|
||||
|
||||
pid_t middle = fork();
|
||||
if (middle < 0) {
|
||||
check(false, "fork middle");
|
||||
return;
|
||||
}
|
||||
|
||||
if (middle == 0) {
|
||||
// ---- middle process ----
|
||||
pid_t grandchild = fork();
|
||||
if (grandchild < 0) {
|
||||
_exit(4);
|
||||
}
|
||||
if (grandchild == 0) {
|
||||
// ---- grandchild process ----
|
||||
pid_t orig_ppid = getppid(); // == middle
|
||||
std::thread([&]() {
|
||||
llama_grpc::watch_parent_death(orig_ppid, 50 /*ms*/, [&]() {
|
||||
write_file(exited_file, "1");
|
||||
_exit(7);
|
||||
});
|
||||
}).detach();
|
||||
|
||||
// Safety valve: never linger if something goes wrong.
|
||||
std::thread([]() {
|
||||
usleep(30 * 1000 * 1000);
|
||||
_exit(2);
|
||||
}).detach();
|
||||
|
||||
// Signal readiness only after the watcher captured orig_ppid.
|
||||
write_file(ready_file, std::to_string(getpid()));
|
||||
for (;;) {
|
||||
pause();
|
||||
}
|
||||
}
|
||||
// middle: wait until grandchild is ready, then exit to orphan it.
|
||||
if (!wait_for_file(ready_file, 10000)) {
|
||||
_exit(5);
|
||||
}
|
||||
_exit(0);
|
||||
}
|
||||
|
||||
// ---- test (top) process ----
|
||||
int status = 0;
|
||||
waitpid(middle, &status, 0); // reap middle only; grandchild is orphaned
|
||||
|
||||
check(file_exists(ready_file), "grandchild signaled readiness");
|
||||
|
||||
bool detected = wait_for_file(exited_file, 10000);
|
||||
check(detected, "watcher detected parent death and self-terminated");
|
||||
|
||||
// Best-effort cleanup: kill the grandchild if it somehow survived.
|
||||
if (file_exists(ready_file)) {
|
||||
FILE *f = fopen(ready_file.c_str(), "r");
|
||||
if (f) {
|
||||
int pid = 0;
|
||||
if (fscanf(f, "%d", &pid) == 1 && pid > 1) {
|
||||
kill(pid, SIGKILL);
|
||||
}
|
||||
fclose(f);
|
||||
}
|
||||
}
|
||||
unlink(ready_file.c_str());
|
||||
unlink(exited_file.c_str());
|
||||
rmdir(dir);
|
||||
}
|
||||
|
||||
#endif // !_WIN32
|
||||
|
||||
// Runs every parent_watch test and prints the aggregate outcome to stderr.
// Exit status is 0 when no check failed and 1 otherwise. The fork-based
// reparenting test is compiled out on Windows.
int main() {
    test_env_parsing();
#if !defined(_WIN32)
    test_detects_reparent();
#endif
    if (failures != 0) {
        fprintf(stderr, "\n%d parent_watch test(s) failed.\n", failures);
        return 1;
    }
    fprintf(stderr, "\nAll parent_watch tests passed.\n");
    return 0;
}
|
||||
@@ -22,6 +22,10 @@ cp -r grpc-server.cpp llama.cpp/tools/grpc-server/
|
||||
# unit test (compiled only when -DLLAMA_GRPC_BUILD_TESTS=ON).
|
||||
cp -r message_content.h llama.cpp/tools/grpc-server/
|
||||
cp -r message_content_test.cpp llama.cpp/tools/grpc-server/
|
||||
# Parent-death watcher (included by grpc-server.cpp) and its standalone unit
|
||||
# test (run via backend/cpp/run-unit-tests.sh; also buildable under ctest).
|
||||
cp -r parent_watch.h llama.cpp/tools/grpc-server/
|
||||
cp -r parent_watch_test.cpp llama.cpp/tools/grpc-server/
|
||||
cp -rfv llama.cpp/vendor/nlohmann/json.hpp llama.cpp/tools/grpc-server/
|
||||
cp -rfv llama.cpp/vendor/cpp-httplib/httplib.h llama.cpp/tools/grpc-server/
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
# Local development: point at a working checkout instead of cloning, e.g.
|
||||
# make PRIVACY_FILTER_SRC=$HOME/c/privacy-filter.cpp grpc-server
|
||||
|
||||
PRIVACY_FILTER_VERSION?=98f52c5ef2250f207cc6b9a6aef05393a120cb7c
|
||||
PRIVACY_FILTER_VERSION?=735a6c28607ee82afc3a670383f41b55266a3b9a
|
||||
PRIVACY_FILTER_REPO?=https://github.com/localai-org/privacy-filter.cpp
|
||||
PRIVACY_FILTER_SRC?=
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ for test_src in "${tests[@]}"; do
|
||||
name="$(basename "$test_src" .cpp)"
|
||||
bin="$(mktemp -d)/$name"
|
||||
echo "==> $test_src"
|
||||
if ! "$CXX" -std=c++17 -Wall -Wextra \
|
||||
if ! "$CXX" -std=c++17 -Wall -Wextra -pthread \
|
||||
-I"$JSON_INC" -I"$(dirname "$test_src")" \
|
||||
"$test_src" -o "$bin"; then
|
||||
echo "COMPILE FAILED: $test_src" >&2
|
||||
|
||||
@@ -142,19 +142,12 @@ func buildAnthropicRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream boo
|
||||
if req.MaxTokens <= 0 {
|
||||
req.MaxTokens = anthropicDefaultMaxTokens
|
||||
}
|
||||
// Newer Anthropic models 400 when both temperature and top_p are
|
||||
// set ("`temperature` and `top_p` cannot both be specified for
|
||||
// this model. Please use only one.") even though their docs only
|
||||
// "recommend" picking one. The OpenAI-compatible chat UI almost
|
||||
// always sends both with default values, so prefer temperature
|
||||
// and drop top_p when both are present.
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
} else if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
// Do not forward temperature/top_p. Newer Anthropic reasoning models reject
|
||||
// requests that carry temperature ("`temperature` is deprecated for this
|
||||
// model"), and the OpenAI-compatible clients typically send only the
|
||||
// server-side DEFAULT sampling values rather than user intent — dropping
|
||||
// them loses nothing and lets the upstream apply its own defaults.
|
||||
_ = opts
|
||||
|
||||
req.Tools = convertOpenAITools(opts.GetTools())
|
||||
req.ToolChoice = convertOpenAIToolChoice(opts.GetToolChoice())
|
||||
|
||||
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -75,15 +74,16 @@ func TestPredict_Anthropic_BasicMessages(t *testing.T) {
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.MaxTokens).To(Equal(int32(32)))
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
// Anthropic 400s when both temperature and top_p are set; the
|
||||
// translator must prefer temperature and drop top_p.
|
||||
// Newer Anthropic reasoning models reject requests carrying temperature
|
||||
// ("`temperature` is deprecated for this model"); clients typically send
|
||||
// only default sampling values, so the translator forwards neither.
|
||||
g.Expect(captured.Temperature).To(BeNil())
|
||||
g.Expect(captured.TopP).To(BeNil())
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
}
|
||||
|
||||
// When only top_p is set, it should be forwarded.
|
||||
// Sampling parameters are not forwarded at all — the upstream applies its
|
||||
// own defaults (newest models reject explicit temperature/top_p).
|
||||
func TestPredict_Anthropic_TopPOnly(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
@@ -99,11 +99,7 @@ func TestPredict_Anthropic_TopPOnly(t *testing.T) {
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Temperature).To(BeNil())
|
||||
// PredictOptions.TopP is float32 on the wire; the translator widens
|
||||
// to float64 so 0.9 round-trips as 0.8999999761581421… — compare
|
||||
// with a small tolerance rather than exact equality.
|
||||
g.Expect(captured.TopP).NotTo(BeNil())
|
||||
g.Expect(math.Abs(*captured.TopP - 0.9)).To(BeNumerically("<=", 1e-6))
|
||||
g.Expect(captured.TopP).To(BeNil())
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_DefaultsMaxTokens(t *testing.T) {
|
||||
|
||||
@@ -30,7 +30,7 @@ type openAIRequest struct {
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
MaxTokens *int32 `json:"max_tokens,omitempty"`
|
||||
MaxTokens *int32 `json:"max_completion_tokens,omitempty"` // newer OpenAI models reject max_tokens ("use max_completion_tokens instead")
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
@@ -107,14 +107,10 @@ func buildOpenAIRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool)
|
||||
Tools: parseRawJSON(opts.GetTools()),
|
||||
ToolChoice: parseRawJSON(opts.GetToolChoice()),
|
||||
}
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
}
|
||||
if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
// Do not forward temperature/top_p. Newer OpenAI reasoning models reject
|
||||
// temperature as deprecated, and clients typically send only default
|
||||
// sampling values rather than user intent — let the upstream apply its
|
||||
// own defaults.
|
||||
if n := opts.GetTokens(); n > 0 {
|
||||
req.MaxTokens = &n
|
||||
}
|
||||
|
||||
@@ -74,8 +74,9 @@ func TestPredict_OpenAI_BasicChat(t *testing.T) {
|
||||
g.Expect(captured.Messages).To(HaveLen(2))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("system"))
|
||||
g.Expect(captured.Messages[1].Role).To(Equal("user"))
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
// Sampling parameters are not forwarded (newest models reject explicit
|
||||
// temperature); token limit is serialized as max_completion_tokens.
|
||||
g.Expect(captured.Temperature).To(BeNil())
|
||||
g.Expect(captured.MaxTokens).NotTo(BeNil())
|
||||
g.Expect(*captured.MaxTokens).To(Equal(int32(32)))
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# CrispASR version (release tag)
|
||||
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
||||
CRISPASR_VERSION?=6514c9da00b03a2f0f1b49a43fae4f3a01a41844
|
||||
CRISPASR_VERSION?=9a26976a8c8cf5af0afcdd04463cf8ba91e96a54
|
||||
SO_TARGET?=libgocrispasr.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# face-detect backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as FACEDETECT_VERSION?=06914b0... (.github/bump_deps.sh
|
||||
# Upstream pin lives below as FACEDETECT_VERSION?=e22260d5d5490b37b021b7f795079f386d553afd
|
||||
# can find and update it - matches the voice-detect / parakeet.cpp / whisper.cpp
|
||||
# convention).
|
||||
#
|
||||
@@ -14,7 +14,7 @@
|
||||
# The default target below does the proper clone-at-pin + cmake build so CI does
|
||||
# not need a side-checkout.
|
||||
|
||||
FACEDETECT_VERSION?=06914b077d52f90d5421299138e7be6bdd06b5e8
|
||||
FACEDETECT_VERSION?=e22260d5d5490b37b021b7f795079f386d553afd
|
||||
FACEDETECT_REPO?=https://github.com/mudler/face-detect.cpp
|
||||
|
||||
GOCMD?=go
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# parakeet-cpp backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=f469a57270a1cc4554acb15febf60e56619673b9
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=e8acc6172a94e20a952cf1843decace5d771a94b
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# whisper.cpp / ds4 / vibevoice-cpp convention.
|
||||
#
|
||||
@@ -15,7 +15,7 @@
|
||||
# That's what the L0 smoke test uses. The default target below does the
|
||||
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
|
||||
|
||||
PARAKEET_VERSION?=f469a57270a1cc4554acb15febf60e56619673b9
|
||||
PARAKEET_VERSION?=e8acc6172a94e20a952cf1843decace5d771a94b
|
||||
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
|
||||
|
||||
GOCMD?=go
|
||||
|
||||
81
backend/go/parakeet-cpp/boundary.go
Normal file
81
backend/go/parakeet-cpp/boundary.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package main
|
||||
|
||||
// utteranceBoundary is the single definition of a small state machine that was
|
||||
// previously open-coded three times — as a bare `finalEou` bool with an ad-hoc
|
||||
// toggle — in the live feed (live.go), the file-stream text path, and the
|
||||
// file-stream JSON path (goparakeetcpp.go).
|
||||
//
|
||||
// It answers one running question: does the decode currently rest on an
|
||||
// end-of-utterance boundary? That is the value a closing FinalResult reports as
|
||||
// .Eou and the realtime turn detector treats as a commit point.
|
||||
//
|
||||
// parakeet auto-resets its decoder after every <EOU>/<EOB>, so one streaming
|
||||
// session is a sequence of utterances and this is a LATCH, not a monotonic
|
||||
// flag: it closes on an <EOU> and reopens as soon as the next utterance starts.
|
||||
// (Contrast the realtime API's per-turn `eouSeen`, which only ever goes
|
||||
// false->true because each turn gets a fresh stream. Here the stream outlives
|
||||
// the turn, so the boundary status must be able to reopen.)
|
||||
//
|
||||
// The only transitions, over the events one streamFeedResult carries — an
|
||||
// <EOU>, an <EOB> (backchannel), or plain speech output (text and/or words):
|
||||
//
|
||||
// <EOU>
|
||||
// open ───────────► closed
|
||||
// ▲ ▲ │ │ │
|
||||
// │ └─┘ <EOB>|speech │ │ <EOU>
|
||||
// │ (stay open) │ └─┘ (stay closed)
|
||||
// └──────────────────┘
|
||||
// <EOB>|speech
|
||||
//
|
||||
// open = NOT on an utterance boundary: mid-utterance, the last boundary was
|
||||
// a backchannel <EOB>, or the stream just began (the initial state).
|
||||
// closed = the last meaningful event was an <EOU> with no later speech: a real
|
||||
// turn boundary.
|
||||
//
|
||||
// A feed that carries nothing (no eou/eob/text/words — e.g. a finalize flush
|
||||
// that produced no tail) is a no-op and leaves the state unchanged, matching
|
||||
// the legacy "leave finalEou as it was" behaviour.
|
||||
//
|
||||
// The state carries no data, so it is modelled as a two-valued type (a named
|
||||
// bool) rather than an int enum: every inhabitant is legal, so illegal states
|
||||
// are unrepresentable — the payload-free analog of the sealed sum types the
|
||||
// realtime machines use (those need interfaces because their states carry data,
|
||||
// e.g. Active{ID}, where "Active with no ID" is the illegal combination a scalar
|
||||
// cannot even express).
|
||||
type utteranceBoundary bool

// The two latch states. boundaryOpen is deliberately the zero value (false),
// so a declared-but-unassigned utteranceBoundary starts open, the same
// initial condition as the legacy `var finalEou bool`.
const (
	boundaryOpen   = utteranceBoundary(false)
	boundaryClosed = utteranceBoundary(true)
)
|
||||
|
||||
// observe folds one decode increment into the latch and returns the new state.
|
||||
//
|
||||
// <EOU> takes priority when a single feed carries both an <EOU> and speech
|
||||
// (e.g. {"text":"hello","eou":1}): the utterance both produced that text AND
|
||||
// ended, so the decode rests on the boundary. This matches the legacy
|
||||
// eou-checked-first ordering at every call site.
|
||||
func (b utteranceBoundary) observe(r streamFeedResult) utteranceBoundary {
|
||||
switch {
|
||||
case r.Eou:
|
||||
return boundaryClosed
|
||||
case r.Eob || r.Delta != "" || len(r.Words) > 0:
|
||||
return boundaryOpen
|
||||
default:
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
// ended reports whether the decode currently rests on an end-of-utterance
|
||||
// boundary (a real <EOU>, not a backchannel <EOB>). This is what a closing
|
||||
// FinalResult carries as .Eou.
|
||||
func (b utteranceBoundary) ended() bool { return b == boundaryClosed }
|
||||
|
||||
func (b utteranceBoundary) String() string {
|
||||
if b == boundaryClosed {
|
||||
return "closed"
|
||||
}
|
||||
return "open"
|
||||
}
|
||||
92
backend/go/parakeet-cpp/boundary_test.go
Normal file
92
backend/go/parakeet-cpp/boundary_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("utteranceBoundary (decode end-of-utterance latch)", func() {
|
||||
It("starts open: a fresh decode is not on a boundary", func() {
|
||||
var b utteranceBoundary
|
||||
Expect(b).To(Equal(boundaryOpen))
|
||||
Expect(b.ended()).To(BeFalse())
|
||||
})
|
||||
|
||||
DescribeTable("single feed transition from the open state",
|
||||
func(r streamFeedResult, wantEnded bool) {
|
||||
Expect(boundaryOpen.observe(r).ended()).To(Equal(wantEnded))
|
||||
},
|
||||
Entry("<EOU> closes it", streamFeedResult{Eou: true}, true),
|
||||
Entry("<EOU> with text closes it (eou wins)", streamFeedResult{Delta: "hi", Eou: true}, true),
|
||||
Entry("<EOB> stays open (backchannel is not a turn boundary)", streamFeedResult{Eob: true}, false),
|
||||
Entry("plain text stays open", streamFeedResult{Delta: "hello"}, false),
|
||||
Entry("words-only stays open", streamFeedResult{Words: []transcriptWord{{W: "x"}}}, false),
|
||||
Entry("empty feed is a no-op (stays open)", streamFeedResult{}, false),
|
||||
)
|
||||
|
||||
DescribeTable("single feed transition from the closed state",
|
||||
func(r streamFeedResult, wantEnded bool) {
|
||||
Expect(boundaryClosed.observe(r).ended()).To(Equal(wantEnded))
|
||||
},
|
||||
Entry("another <EOU> stays closed", streamFeedResult{Eou: true}, true),
|
||||
Entry("trailing speech reopens it", streamFeedResult{Delta: "and more"}, false),
|
||||
Entry("words reopen it", streamFeedResult{Words: []transcriptWord{{W: "x"}}}, false),
|
||||
Entry("a backchannel <EOB> reopens it", streamFeedResult{Eob: true}, false),
|
||||
Entry("empty feed is a no-op (stays closed)", streamFeedResult{}, true),
|
||||
)
|
||||
|
||||
It("is a latch: <EOU> then trailing speech reopens, then <EOU> closes again", func() {
|
||||
b := boundaryOpen
|
||||
b = b.observe(streamFeedResult{Delta: "turn one", Eou: true})
|
||||
Expect(b.ended()).To(BeTrue())
|
||||
b = b.observe(streamFeedResult{Delta: " and more"})
|
||||
Expect(b.ended()).To(BeFalse(), "trailing speech without an EOU is an open utterance")
|
||||
b = b.observe(streamFeedResult{Eou: true})
|
||||
Expect(b.ended()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("treats a backchannel before a real EOU correctly", func() {
|
||||
b := boundaryOpen
|
||||
b = b.observe(streamFeedResult{Delta: "uh huh", Eob: true})
|
||||
Expect(b.ended()).To(BeFalse(), "a backchannel must not masquerade as a turn boundary")
|
||||
b = b.observe(streamFeedResult{Delta: "done", Eou: true})
|
||||
Expect(b.ended()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches the reference fold over seeded random feed sequences", func() {
|
||||
// The invariant: after any sequence of feeds, ended() is true iff the
|
||||
// last feed that carried ANY event was an <EOU>. <EOU> takes priority
|
||||
// when a feed carries both an EOU and speech; empty feeds are ignored.
|
||||
for seed := uint64(1); seed <= 200; seed++ {
|
||||
rng := rand.New(rand.NewPCG(seed, seed*2654435761))
|
||||
b := boundaryOpen
|
||||
lastWasEou := false // reference: did the last meaningful feed end on EOU?
|
||||
steps := rng.IntN(30)
|
||||
for i := 0; i < steps; i++ {
|
||||
var r streamFeedResult
|
||||
switch rng.IntN(5) {
|
||||
case 0:
|
||||
r = streamFeedResult{Eou: true}
|
||||
case 1:
|
||||
r = streamFeedResult{Eob: true}
|
||||
case 2:
|
||||
r = streamFeedResult{Delta: "w"}
|
||||
case 3:
|
||||
r = streamFeedResult{Delta: "w", Eou: true} // eou + speech, eou wins
|
||||
case 4:
|
||||
r = streamFeedResult{} // empty: no-op
|
||||
}
|
||||
b = b.observe(r)
|
||||
if r.Eou {
|
||||
lastWasEou = true
|
||||
} else if r.Eob || r.Delta != "" || len(r.Words) > 0 {
|
||||
lastWasEou = false
|
||||
}
|
||||
}
|
||||
Expect(b.ended()).To(Equal(lastWasEou),
|
||||
"seed %d: latch disagreed with the reference fold", seed)
|
||||
}
|
||||
})
|
||||
})
|
||||
82
backend/go/parakeet-cpp/driver.go
Normal file
82
backend/go/parakeet-cpp/driver.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// streamFeedResult is one decode increment from a cache-aware streaming session:
|
||||
// the newly-finalized text plus the model's own per-feed boundary tokens
|
||||
// (<EOU>/<EOB>) and word timings. It is the single event type both the live
|
||||
// (bidi) and file (server-stream) paths fold over, hiding the ABI v4 JSON vs
|
||||
// older text-only entry-point split behind one shape.
|
||||
type streamFeedResult struct {
|
||||
Delta string
|
||||
Eou bool
|
||||
Eob bool
|
||||
Words []transcriptWord
|
||||
}
|
||||
|
||||
// feedChunk feeds one PCM chunk to the streaming session (or finalizes it, when
|
||||
// finalize is true) and returns the unified decode increment. It prefers the
|
||||
// ABI v4 JSON entry points (which also carry per-word timestamps) and falls
|
||||
// back to the older text-only entry points against an older libparakeet.so.
|
||||
//
|
||||
// This is the one place the JSON-vs-text choice is made; every consumer works
|
||||
// in terms of streamFeedResult.
|
||||
func (p *ParakeetCpp) feedChunk(stream uintptr, pcm []float32, finalize bool) (streamFeedResult, error) {
|
||||
if CppStreamFeedJSON != nil {
|
||||
doc, err := p.streamFeedDoc(stream, pcm, finalize)
|
||||
if err != nil {
|
||||
return streamFeedResult{}, err
|
||||
}
|
||||
return streamFeedResult{Delta: doc.Text, Eou: doc.Eou != 0, Eob: doc.Eob != 0, Words: doc.Words}, nil
|
||||
}
|
||||
delta, eou, eob, err := p.streamFeedText(stream, pcm, finalize)
|
||||
if err != nil {
|
||||
return streamFeedResult{}, err
|
||||
}
|
||||
return streamFeedResult{Delta: delta, Eou: eou, Eob: eob}, nil
|
||||
}
|
||||
|
||||
// feedSlices feeds pcm through the session in streamChunkSamples slices,
|
||||
// invoking onFeed for each decode increment. It does NOT finalize: callers
|
||||
// decide when the send side is done. The file path finalizes after the whole
|
||||
// file; the live path finalizes only when its request channel closes, never
|
||||
// between audio messages. Slicing keeps each per-call engineMu hold short so
|
||||
// concurrent unary transcription interleaves fairly (the C session buffers
|
||||
// internally).
|
||||
//
|
||||
// If ctx is non-nil it is checked before each slice so a cancelled file
|
||||
// transcription stops promptly; the live path passes nil (it is bounded by its
|
||||
// request channel instead of a ctx).
|
||||
func (p *ParakeetCpp) feedSlices(ctx context.Context, stream uintptr, pcm []float32, onFeed func(streamFeedResult) error) error {
|
||||
for off := 0; off < len(pcm); off += streamChunkSamples {
|
||||
if ctx != nil {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(pcm))
|
||||
res, err := p.feedChunk(stream, pcm[off:end], false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := onFeed(res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// flushTail finalizes the session once and folds the flushed tail (the last
|
||||
// ~2 encoder frames of text, which only appear on finalize) through onFeed.
|
||||
func (p *ParakeetCpp) flushTail(stream uintptr, onFeed func(streamFeedResult) error) error {
|
||||
res, err := p.feedChunk(stream, nil, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return onFeed(res)
|
||||
}
|
||||
@@ -103,12 +103,13 @@ type transcriptJSON struct {
|
||||
// {"text":"...","eou":0,"eob":0,"frame_sec":0.080000,
|
||||
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...]}
|
||||
//
|
||||
// "text" is the newly-finalized text since the last call; "eou" is 1 when an
|
||||
// <EOU> (end of utterance) fired this feed and "eob" is 1 when an <EOB>
|
||||
// (backchannel) fired. ABI v4 conflated the two into "eou"; v5 split them, so
|
||||
// we read both and treat either as an utterance boundary for segmentation.
|
||||
// "words" are the words finalized this call with absolute (stream-relative)
|
||||
// start/end seconds.
|
||||
// "text" is the newly-finalized text since the last call. Under ABI v5 "eou"
|
||||
// is 1 iff an <EOU> fired this feed (the user yielded the turn) and "eob" 1
|
||||
// iff an <EOB> fired (a backchannel like "uh-huh" ended — NOT a turn
|
||||
// boundary). A v4 library has no "eob" field and its "eou" conflates both
|
||||
// tokens: Eob stays 0 and Eou keeps the old any-event meaning. "words" are
|
||||
// the words finalized this call with absolute (stream-relative) start/end
|
||||
// seconds.
|
||||
type streamFeedJSON struct {
|
||||
Text string `json:"text"`
|
||||
Eou int `json:"eou"`
|
||||
@@ -364,7 +365,7 @@ var segmentSeparators = []rune{'.', '?', '!'}
|
||||
// the caller requested word granularity; token ids populate each segment's
|
||||
// Tokens by time-window membership. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gapFrames int) pb.TranscriptResult {
|
||||
text := strings.TrimSpace(doc.Text)
|
||||
text, eou := stripEouMarker(strings.TrimSpace(doc.Text))
|
||||
|
||||
// Frame-unit gap threshold -> seconds (NeMo segment_gap_threshold). 0 = off.
|
||||
gapSeconds := 0.0
|
||||
@@ -383,6 +384,7 @@ func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gap
|
||||
return pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: []*pb.TranscriptSegment{{Id: 0, Text: text}},
|
||||
Eou: eou,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -409,7 +411,25 @@ func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gap
|
||||
}
|
||||
segments = append(segments, seg)
|
||||
}
|
||||
return pb.TranscriptResult{Text: text, Segments: segments}
|
||||
return pb.TranscriptResult{Text: text, Segments: segments, Eou: eou}
|
||||
}
|
||||
|
||||
// stripEouMarker drops one trailing literal <EOU> or <EOB> from offline-decode
// text and reports whether that trailing token was an end-of-UTTERANCE marker.
//
// The realtime EOU model's offline decode leaves the special token in the
// detokenized text (the streaming path strips it and surfaces flags instead).
// Neither marker may reach a user-visible transcript, but only <EOU> may
// confirm the semantic_vad retranscribe cross-check: text ending on <EOB>
// means the last thing heard was a backchannel, not the user yielding the
// turn, so it yields false. When a marker is removed the remainder is
// whitespace-trimmed; text without a trailing marker is returned untouched.
func stripEouMarker(text string) (string, bool) {
	if body, found := strings.CutSuffix(text, "<EOU>"); found {
		return strings.TrimSpace(body), true
	}
	if body, found := strings.CutSuffix(text, "<EOB>"); found {
		return strings.TrimSpace(body), false
	}
	return text, false
}
|
||||
|
||||
// splitWordsIntoSegments groups words into segments exactly as NeMo's
|
||||
@@ -476,41 +496,55 @@ func tokensInWindow(tokens []transcriptToken, start, end float64) []int32 {
|
||||
return ids
|
||||
}
|
||||
|
||||
// streamSegmenter accumulates streaming words into per-utterance segments. EOU
|
||||
// is the model's own utterance boundary; each closed segment takes its start/end
|
||||
// from its first/last accumulated word.
|
||||
// streamSegmenter accumulates streaming decode increments into per-utterance
|
||||
// segments. <EOU>/<EOB> are the model's own utterance boundaries; each closes a
|
||||
// segment. When the feed carries per-word timings (ABI v4 JSON), a closed
|
||||
// segment takes its start/end from its first/last word; against an older
|
||||
// text-only library (no words) it falls back to segmenting the delta text, so
|
||||
// the same assembler serves both paths.
|
||||
type streamSegmenter struct {
|
||||
segs []*pb.TranscriptSegment
|
||||
cur []transcriptWord
|
||||
nextID int32
|
||||
segs []*pb.TranscriptSegment
|
||||
cur []transcriptWord // words for the open segment (ABI v4 JSON path)
|
||||
curText []string // delta text for the open segment (text-only path)
|
||||
nextID int32
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) add(doc streamFeedJSON) {
|
||||
s.cur = append(s.cur, doc.Words...)
|
||||
// Close the segment on either turn signal: <EOU> (end of utterance) or
|
||||
// <EOB> (backchannel). ABI v4 reported both via "eou"; v5 split them, so we
|
||||
// OR them here to keep the v4 segmentation boundaries.
|
||||
if doc.Eou != 0 || doc.Eob != 0 {
|
||||
func (s *streamSegmenter) add(r streamFeedResult) {
|
||||
s.cur = append(s.cur, r.Words...)
|
||||
if len(r.Words) == 0 && r.Delta != "" {
|
||||
// Older libparakeet.so with no per-word timing: segment from the text.
|
||||
s.curText = append(s.curText, r.Delta)
|
||||
}
|
||||
// Both <EOU> and <EOB> reset the decoder, so both close a segment.
|
||||
if r.Eou || r.Eob {
|
||||
s.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) flush() {
|
||||
if len(s.cur) == 0 {
|
||||
return
|
||||
switch {
|
||||
case len(s.cur) > 0:
|
||||
parts := make([]string, len(s.cur))
|
||||
for i, w := range s.cur {
|
||||
parts[i] = w.W
|
||||
}
|
||||
s.segs = append(s.segs, &pb.TranscriptSegment{
|
||||
Id: s.nextID,
|
||||
Start: secondsToNanos(s.cur[0].Start),
|
||||
End: secondsToNanos(s.cur[len(s.cur)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
})
|
||||
s.nextID++
|
||||
case len(s.curText) > 0:
|
||||
// No words this segment: emit a text-only segment (no timestamps),
|
||||
// skipping a purely-whitespace one as the legacy text path did.
|
||||
if t := strings.TrimSpace(strings.Join(s.curText, "")); t != "" {
|
||||
s.segs = append(s.segs, &pb.TranscriptSegment{Id: s.nextID, Text: t})
|
||||
s.nextID++
|
||||
}
|
||||
}
|
||||
parts := make([]string, len(s.cur))
|
||||
for i, w := range s.cur {
|
||||
parts[i] = w.W
|
||||
}
|
||||
s.segs = append(s.segs, &pb.TranscriptSegment{
|
||||
Id: s.nextID,
|
||||
Start: secondsToNanos(s.cur[0].Start),
|
||||
End: secondsToNanos(s.cur[len(s.cur)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
})
|
||||
s.nextID++
|
||||
s.cur = nil
|
||||
s.curText = nil
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) segments() []*pb.TranscriptSegment { return s.segs }
|
||||
@@ -535,18 +569,119 @@ func secondsToNanos(sec float64) int64 {
|
||||
return int64(sec * 1e9)
|
||||
}
|
||||
|
||||
// Per-C-call engine serialization for the streaming paths.
|
||||
//
|
||||
// Every individual C call (begin / feed / finalize / free) takes engineMu and
|
||||
// re-checks ctxPtr under the lock; the lock is NEVER held across a stream's
|
||||
// lifetime. This is safe because each parakeet.cpp call builds its own ggml
|
||||
// graph and all streaming caches live in the session object, not the ctx —
|
||||
// the only ctx-shared mutable state is last_error, which is why it is read
|
||||
// under the same lock as the failing call. Holding the lock per call (rather
|
||||
// than per stream, as this file previously did) keeps a long-lived live
|
||||
// session from starving batched unary transcription and vice versa.
|
||||
//
|
||||
// A stream must not outlive its ctx (C-API contract). Free() takes engineMu
|
||||
// and zeroes ctxPtr, so a racing per-call helper returns ModelNotLoaded
|
||||
// instead of feeding a freed engine; streamFree of an orphaned session only
|
||||
// runs the session destructor, which does not touch the ctx.
|
||||
|
||||
// streamBegin opens a cache-aware streaming session. A 0 stream with nil
|
||||
// error means the loaded model is not a streaming model.
|
||||
func (p *ParakeetCpp) streamBegin(lang string) (uintptr, error) {
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
if p.ctxPtr == 0 {
|
||||
return 0, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if CppStreamBeginLang != nil {
|
||||
return CppStreamBeginLang(p.ctxPtr, lang), nil
|
||||
}
|
||||
return CppStreamBegin(p.ctxPtr), nil
|
||||
}
|
||||
|
||||
func (p *ParakeetCpp) streamFree(stream uintptr) {
|
||||
if stream == 0 {
|
||||
return
|
||||
}
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
CppStreamFree(stream)
|
||||
}
|
||||
|
||||
// streamFeedText runs one text-mode feed (or the finalize flush when
|
||||
// finalize is true) under engineMu, returning the newly-finalized delta and
|
||||
// whether an <EOU>/<EOB> fired during the call.
|
||||
func (p *ParakeetCpp) streamFeedText(stream uintptr, pcm []float32, finalize bool) (delta string, eou, eob bool, err error) {
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
if p.ctxPtr == 0 {
|
||||
return "", false, false, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
var ret uintptr
|
||||
var events int32
|
||||
if finalize {
|
||||
ret = CppStreamFinalize(stream)
|
||||
} else {
|
||||
ret = CppStreamFeed(stream, pcm, int32(len(pcm)), unsafe.Pointer(&events))
|
||||
}
|
||||
if ret == 0 {
|
||||
// last_error is ctx-shared: read it under the same lock as the call.
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return "", false, false, fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
delta = goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
// ABI v5: eou_out is a bitmask (bit 0 = <EOU>, bit 1 = <EOB>). A v4
|
||||
// library sets 0/1 for either token, which the bit-0 test reads as the
|
||||
// old conflated eou — the EOB distinction simply isn't available there.
|
||||
return delta, events&1 != 0, events&2 != 0, nil
|
||||
}
|
||||
|
||||
// streamFeedDoc runs one ABI v4 JSON feed (or finalize) under engineMu and
|
||||
// returns the parsed {text,eou,frame_sec,words} document.
|
||||
func (p *ParakeetCpp) streamFeedDoc(stream uintptr, pcm []float32, finalize bool) (streamFeedJSON, error) {
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
if p.ctxPtr == 0 {
|
||||
return streamFeedJSON{}, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
var ret uintptr
|
||||
if finalize {
|
||||
ret = CppStreamFinalizeJSON(stream)
|
||||
} else {
|
||||
ret = CppStreamFeedJSON(stream, pcm, int32(len(pcm)))
|
||||
}
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return streamFeedJSON{}, fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
raw := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
var doc streamFeedJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return streamFeedJSON{}, fmt.Errorf("parakeet-cpp: decode stream json: %w", err)
|
||||
}
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// AudioTranscriptionStream drives the cache-aware streaming RNN-T over the
|
||||
// audio at opts.Dst: it decodes the file to 16 kHz mono PCM, feeds it in
|
||||
// chunks to parakeet_capi_stream_feed, and emits each newly-finalized text
|
||||
// run as a TranscriptStreamResponse delta. <EOU>/<EOB> events close the
|
||||
// current segment; a closing FinalResult carries the full transcript and the
|
||||
// per-utterance segments.
|
||||
// audio at opts.Dst: it decodes the file to 16 kHz mono PCM, feeds it through
|
||||
// the shared decode driver (feedSlices/flushTail), and emits each
|
||||
// newly-finalized text run as a TranscriptStreamResponse delta. <EOU>/<EOB>
|
||||
// events close the current segment; a closing FinalResult carries the full
|
||||
// transcript, the per-utterance segments, and whether the file ended on an
|
||||
// utterance boundary.
|
||||
//
|
||||
// stream_begin returns 0 for models that are not cache-aware streaming models
|
||||
// (only e.g. nvidia/parakeet_realtime_eou_120m-v1 qualifies). For those we fall
|
||||
// back to a single offline transcription emitted as one delta plus a closing
|
||||
// FinalResult, matching LocalAI's non-streaming streaming contract (and the
|
||||
// whisper backend), so the streaming endpoint works for every model.
|
||||
// (only e.g. nvidia/parakeet_realtime_eou_120m-v1 qualifies). For those this
|
||||
// returns codes.Unimplemented rather than faking a stream from an offline
|
||||
// decode — see the stream==0 branch and grpcerrors.StreamTranscriptionUnsupported.
|
||||
func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
defer close(results)
|
||||
|
||||
@@ -560,185 +695,73 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
var stream uintptr
|
||||
if CppStreamBeginLang != nil {
|
||||
stream = CppStreamBeginLang(p.ctxPtr, opts.GetLanguage())
|
||||
} else {
|
||||
stream = CppStreamBegin(p.ctxPtr)
|
||||
stream, err := p.streamBegin(opts.GetLanguage())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stream == 0 {
|
||||
// Not a cache-aware streaming model: run a normal offline
|
||||
// transcription and emit it as one delta + a closing final result.
|
||||
res, err := p.AudioTranscription(ctx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t := strings.TrimSpace(res.Text); t != "" {
|
||||
results <- &pb.TranscriptStreamResponse{Delta: t}
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{FinalResult: &res}
|
||||
return nil
|
||||
// Not a cache-aware streaming model. Report the missing capability
|
||||
// honestly instead of decoding offline and emitting it as one "delta"
|
||||
// + final: a client that asked for streaming must learn the model
|
||||
// cannot stream, not receive a batch result dressed as a stream (which
|
||||
// is indistinguishable except qualitatively, and silently breaks any
|
||||
// feature that genuinely needs incremental output). Callers wanting a
|
||||
// plain transcript use the unary AudioTranscription path. This mirrors
|
||||
// AudioTranscriptionLive, which already returns Unimplemented here.
|
||||
return grpcerrors.StreamTranscriptionUnsupported("parakeet-cpp",
|
||||
"loaded model is not a cache-aware streaming model")
|
||||
}
|
||||
defer CppStreamFree(stream)
|
||||
// The C engine is a single shared context: a streaming session and a batched
|
||||
// unary dispatch must never touch it at once, so hold engineMu for the whole
|
||||
// stream. This lock is intentionally taken AFTER the non-streaming fallback
|
||||
// above returns: that fallback goes through AudioTranscription -> the batcher
|
||||
// -> runBatch, which itself acquires engineMu, so locking here first would
|
||||
// deadlock. Do not hoist this lock above the fallback.
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
defer p.streamFree(stream)
|
||||
|
||||
data, duration, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ABI v4: when the streaming JSON entry points are present, drive them so the
|
||||
// per-utterance segments carry per-word start/end timestamps. Falls through to
|
||||
// the text-only loop below against an older libparakeet.so. Runs under the
|
||||
// engineMu already held above.
|
||||
if CppStreamFeedJSON != nil {
|
||||
return p.streamJSON(ctx, stream, data, duration, results)
|
||||
}
|
||||
|
||||
// Fold the shared decode driver's per-feed increments into the streamed
|
||||
// deltas and the closing batch result: words/text accumulate into
|
||||
// per-utterance segments (streamSegmenter), and the utterance-boundary
|
||||
// latch (boundary.go) records whether the file ended on an <EOU>. These
|
||||
// are the offline path's concern — the live RPC carries none of them.
|
||||
var (
|
||||
full strings.Builder
|
||||
segText strings.Builder
|
||||
segments []*pb.TranscriptSegment
|
||||
segID int32
|
||||
seg streamSegmenter
|
||||
boundary utteranceBoundary
|
||||
)
|
||||
|
||||
flushSegment := func() {
|
||||
t := strings.TrimSpace(segText.String())
|
||||
segText.Reset()
|
||||
if t == "" {
|
||||
return
|
||||
emit := func(r streamFeedResult) error {
|
||||
if r.Delta != "" {
|
||||
full.WriteString(r.Delta)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: r.Delta}
|
||||
}
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: segID, Text: t})
|
||||
segID++
|
||||
}
|
||||
|
||||
// emitDelta consumes the malloc'd char* returned by feed/finalize: frees
|
||||
// it, accumulates the text, and sends a delta when non-empty. A 0 return
|
||||
// is an error (vs the "" empty-but-non-NULL no-new-text case).
|
||||
emitDelta := func(ret uintptr) error {
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
delta := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
if delta == "" {
|
||||
return nil
|
||||
}
|
||||
full.WriteString(delta)
|
||||
segText.WriteString(delta)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: delta}
|
||||
seg.add(r)
|
||||
boundary = boundary.observe(r)
|
||||
return nil
|
||||
}
|
||||
|
||||
for off := 0; off < len(data); off += streamChunkSamples {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(data))
|
||||
chunk := data[off:end]
|
||||
|
||||
var eou int32
|
||||
ret := CppStreamFeed(stream, chunk, int32(len(chunk)), unsafe.Pointer(&eou))
|
||||
if err := emitDelta(ret); err != nil {
|
||||
return err
|
||||
}
|
||||
if eou != 0 {
|
||||
flushSegment()
|
||||
}
|
||||
}
|
||||
|
||||
// Flush the streaming tail (final encoder chunk).
|
||||
if err := emitDelta(CppStreamFinalize(stream)); err != nil {
|
||||
if err := p.feedSlices(ctx, stream, data, emit); err != nil {
|
||||
return err
|
||||
}
|
||||
flushSegment()
|
||||
|
||||
text := strings.TrimSpace(full.String())
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{
|
||||
FinalResult: &pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: segments,
|
||||
Duration: duration,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamJSON drives the streaming JSON entry points (present since ABI v4): each
|
||||
// feed/finalize returns a {text,eou,eob,frame_sec,words} document. The
|
||||
// newly-finalized text is emitted as a delta (unchanged streaming contract)
|
||||
// while words are accumulated into per-utterance segments (closed on <EOU> or
|
||||
// <EOB>) so the closing FinalResult carries timestamped segments. Runs under
|
||||
// engineMu (already held by the caller).
|
||||
func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []float32,
|
||||
duration float32, results chan *pb.TranscriptStreamResponse) error {
|
||||
var (
|
||||
full strings.Builder
|
||||
seg streamSegmenter
|
||||
)
|
||||
// consume frees the malloc'd char* (a 0 return is an error), parses the JSON,
|
||||
// emits the delta, and routes words through the segmenter.
|
||||
consume := func(ret uintptr) error {
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
raw := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
var doc streamFeedJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return fmt.Errorf("parakeet-cpp: decode stream json: %w", err)
|
||||
}
|
||||
if doc.Text != "" {
|
||||
full.WriteString(doc.Text)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: doc.Text}
|
||||
}
|
||||
seg.add(doc)
|
||||
return nil
|
||||
}
|
||||
|
||||
for off := 0; off < len(data); off += streamChunkSamples {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(data))
|
||||
chunk := data[off:end]
|
||||
if err := consume(CppStreamFeedJSON(stream, chunk, int32(len(chunk)))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := consume(CppStreamFinalizeJSON(stream)); err != nil {
|
||||
if err := p.flushTail(stream, emit); err != nil {
|
||||
return err
|
||||
}
|
||||
seg.flush() // close any trailing utterance that never saw an EOU
|
||||
seg.flush() // close a trailing utterance that never saw an <EOU>
|
||||
|
||||
text := strings.TrimSpace(full.String())
|
||||
// final.Text is the exact concatenation of the streamed deltas (full is
|
||||
// their accumulation), so concat(deltas) == FinalResult.Text holds even
|
||||
// when the model prepends a leading space to the first word (SentencePiece
|
||||
// detokenization). This matches the whisper backend's streaming contract.
|
||||
// The single-segment fallback stays trimmed.
|
||||
fullText := full.String()
|
||||
segments := seg.segments()
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
|
||||
if trimmed := strings.TrimSpace(fullText); len(segments) == 0 && trimmed != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: trimmed})
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{
|
||||
FinalResult: &pb.TranscriptResult{
|
||||
Text: text,
|
||||
Text: fullText,
|
||||
Segments: segments,
|
||||
Duration: duration,
|
||||
Eou: boundary.ended(),
|
||||
},
|
||||
}
|
||||
return nil
|
||||
@@ -803,6 +826,10 @@ func (p *ParakeetCpp) Free() error {
|
||||
close(p.batStop)
|
||||
p.batStop = nil
|
||||
}
|
||||
// engineMu so an in-flight streaming call (which locks per C call and
|
||||
// re-checks ctxPtr under the lock) can never feed into a freed ctx.
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
if p.ctxPtr != 0 {
|
||||
CppFree(p.ctxPtr)
|
||||
p.ctxPtr = 0
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestParakeetCpp(t *testing.T) {
|
||||
@@ -201,6 +203,29 @@ var _ = Describe("ParakeetCpp", func() {
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("returns the typed Unimplemented signal for non-streaming models (no offline fallback)", func() {
|
||||
// stream_begin == 0 means the loaded model is not a cache-aware
|
||||
// streaming model. The backend must surface that, not silently
|
||||
// decode offline and fake a one-shot "stream".
|
||||
savedBegin, savedBeginLang := CppStreamBegin, CppStreamBeginLang
|
||||
defer func() { CppStreamBegin, CppStreamBeginLang = savedBegin, savedBeginLang }()
|
||||
CppStreamBeginLang = nil
|
||||
CppStreamBegin = func(ctx uintptr) uintptr { return 0 }
|
||||
|
||||
p := &ParakeetCpp{ctxPtr: 1}
|
||||
results := make(chan *pb.TranscriptStreamResponse, 8)
|
||||
err := p.AudioTranscriptionStream(context.Background(),
|
||||
&pb.TranscriptRequest{Dst: "ignored.wav"}, results)
|
||||
Expect(status.Code(err)).To(Equal(codes.Unimplemented))
|
||||
|
||||
// Honest signal: nothing was emitted — no faked batch result.
|
||||
var emitted []*pb.TranscriptStreamResponse
|
||||
for r := range results {
|
||||
emitted = append(emitted, r)
|
||||
}
|
||||
Expect(emitted).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("streams deltas and a closing FinalResult from a cache-aware model", func() {
|
||||
// Streaming needs a cache-aware streaming model (e.g.
|
||||
// realtime_eou); the offline test model would fail stream_begin.
|
||||
|
||||
186
backend/go/parakeet-cpp/live.go
Normal file
186
backend/go/parakeet-cpp/live.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// liveSampleRate is the only PCM rate the parakeet C streaming API accepts.
|
||||
const liveSampleRate = 16000
|
||||
|
||||
// AudioTranscriptionLive drives one cache-aware streaming session over audio
|
||||
// fed incrementally by the caller (the realtime API's semantic_vad turn
|
||||
// detection). Contract:
|
||||
//
|
||||
// - the first request must carry a Config; a Config mid-stream resets the
|
||||
// decode session (free + begin) and drops accumulated transcript state;
|
||||
// - a Ready ack is sent right after a successful stream_begin so callers
|
||||
// can degrade synchronously when the model has no streaming support
|
||||
// (LiveTranscriptionUnsupported, codes.Unimplemented);
|
||||
// - every feed that produced output is forwarded as {delta, eou, words};
|
||||
// the <EOU>/<EOB> flag is the model's own utterance boundary and the
|
||||
// decoder auto-resets after it, so one session spans many utterances;
|
||||
// - closing the send side finalizes: the held-back tail chunk is flushed
|
||||
// (the last ~2 encoder frames of words only appear here) and a terminal
|
||||
// FinalResult carries the full transcript Text only. Per-utterance
|
||||
// segments, duration, and the terminal <EOU> flag are NOT produced here —
|
||||
// the realtime core consumes the streamed per-feed tokens and the final
|
||||
// Text; those batch fields are the file path's concern (see
|
||||
// AudioTranscriptionStream).
|
||||
//
|
||||
// Engine access is serialized per C call (streamBegin/streamFeed*/streamFree
|
||||
// take engineMu internally), never for the session lifetime — unary
|
||||
// transcription keeps flowing between feeds.
|
||||
func (p *ParakeetCpp) AudioTranscriptionLive(in <-chan *pb.TranscriptLiveRequest, out chan<- *pb.TranscriptLiveResponse) error {
|
||||
defer close(out)
|
||||
|
||||
if p.ctxPtr == 0 {
|
||||
return grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
|
||||
first, ok := <-in
|
||||
if !ok {
|
||||
return nil // caller closed without sending anything
|
||||
}
|
||||
cfg := first.GetConfig()
|
||||
if cfg == nil {
|
||||
return status.Error(codes.InvalidArgument, "parakeet-cpp: first live message must carry a config")
|
||||
}
|
||||
if err := validateLiveConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stream, err := p.streamBegin(cfg.GetLanguage())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stream == 0 {
|
||||
return grpcerrors.LiveTranscriptionUnsupported("parakeet-cpp",
|
||||
"loaded model is not a cache-aware streaming model")
|
||||
}
|
||||
// stream is reassigned on a mid-stream Config reset; free whatever is
|
||||
// current when the RPC unwinds.
|
||||
defer func() { p.streamFree(stream) }()
|
||||
|
||||
out <- &pb.TranscriptLiveResponse{Ready: true}
|
||||
|
||||
var (
|
||||
full strings.Builder
|
||||
fedSecs float64
|
||||
|
||||
// behindSec accumulates how far decode wall time has fallen behind
|
||||
// the audio it was fed. A live caller feeds in real time, so a
|
||||
// persistent positive backlog means every downstream signal —
|
||||
// including the <EOU> the turn detector waits on — arrives that many
|
||||
// seconds late. Warned once per session; reset by a Config reset.
|
||||
behindSec float64
|
||||
behindWarned bool
|
||||
)
|
||||
|
||||
// emit forwards one decode increment: it streams the per-feed tokens the
|
||||
// realtime turn detector consumes (delta/eou/eob/words) and accumulates the
|
||||
// running transcript for the closing FinalResult. No segmentation or
|
||||
// boundary latch here — the live consumer reads only the streamed tokens
|
||||
// and the final Text; per-utterance segments and the terminal <EOU> flag
|
||||
// are an offline-path concern (see AudioTranscriptionStream / boundary.go).
|
||||
emit := func(r streamFeedResult) error {
|
||||
if r.Delta != "" {
|
||||
full.WriteString(r.Delta)
|
||||
}
|
||||
if r.Delta != "" || r.Eou || r.Eob || len(r.Words) > 0 {
|
||||
out <- &pb.TranscriptLiveResponse{
|
||||
Delta: r.Delta,
|
||||
Eou: r.Eou,
|
||||
Eob: r.Eob,
|
||||
Words: liveWordsToProto(r.Words),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for req := range in {
|
||||
switch payload := req.GetPayload().(type) {
|
||||
case *pb.TranscriptLiveRequest_Config:
|
||||
if err := validateLiveConfig(payload.Config); err != nil {
|
||||
return err
|
||||
}
|
||||
// Reset: a fresh decode session, dropping accumulated state.
|
||||
p.streamFree(stream)
|
||||
stream, err = p.streamBegin(payload.Config.GetLanguage())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stream == 0 {
|
||||
return grpcerrors.LiveTranscriptionUnsupported("parakeet-cpp",
|
||||
"loaded model is not a cache-aware streaming model")
|
||||
}
|
||||
full.Reset()
|
||||
fedSecs = 0
|
||||
case *pb.TranscriptLiveRequest_Audio:
|
||||
pcm := payload.Audio.GetPcm()
|
||||
audioSec := float64(len(pcm)) / liveSampleRate
|
||||
fedSecs += audioSec
|
||||
start := time.Now()
|
||||
// nil ctx: a live session is bounded by this request channel, not a
|
||||
// context — cancellation is the caller closing the stream.
|
||||
if err := p.feedSlices(nil, stream, pcm, emit); err != nil {
|
||||
return err
|
||||
}
|
||||
wallSec := time.Since(start).Seconds()
|
||||
behindSec += wallSec - audioSec
|
||||
if behindSec < 0 {
|
||||
behindSec = 0
|
||||
}
|
||||
xlog.Debug("parakeet-cpp: live feed",
|
||||
"audio_ms", int(audioSec*1000), "wall_ms", int(wallSec*1000),
|
||||
"behind_ms", int(behindSec*1000), "fed_s", fedSecs)
|
||||
if behindSec > 1 && !behindWarned {
|
||||
behindWarned = true
|
||||
xlog.Warn("parakeet-cpp: live decode is falling behind real time; "+
|
||||
"end-of-utterance signals will arrive late",
|
||||
"behind_s", behindSec, "fed_s", fedSecs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send side closed: flush the streaming tail and emit the final transcript.
|
||||
// The live FinalResult carries only Text — the authoritative full-turn
|
||||
// transcript the realtime core commits. Per-utterance segments, duration,
|
||||
// and the terminal <EOU> flag are not produced on the live path.
|
||||
if err := p.flushTail(stream, emit); err != nil {
|
||||
return err
|
||||
}
|
||||
out <- &pb.TranscriptLiveResponse{
|
||||
FinalResult: &pb.TranscriptResult{Text: strings.TrimSpace(full.String())},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateLiveConfig(cfg *pb.TranscriptLiveConfig) error {
|
||||
if sr := cfg.GetSampleRate(); sr != 0 && sr != liveSampleRate {
|
||||
return status.Errorf(codes.InvalidArgument,
|
||||
"parakeet-cpp: unsupported live sample_rate %d (only %d)", sr, liveSampleRate)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func liveWordsToProto(words []transcriptWord) []*pb.TranscriptWord {
|
||||
if len(words) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*pb.TranscriptWord, len(words))
|
||||
for i, w := range words {
|
||||
out[i] = &pb.TranscriptWord{
|
||||
Start: secondsToNanos(w.Start),
|
||||
End: secondsToNanos(w.End),
|
||||
Text: w.W,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
417
backend/go/parakeet-cpp/live_test.go
Normal file
417
backend/go/parakeet-cpp/live_test.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// The live-RPC specs drive AudioTranscriptionLive entirely against stubbed
|
||||
// Cpp* package vars (the same seam batcher_test.go uses), so they run
|
||||
// without libparakeet.so.
|
||||
|
||||
// liveCstrPool fabricates NUL-terminated, C-style strings out of Go memory
// for the stubbed C entry points. Every buffer handed out is retained in bufs
// so it stays reachable for the rest of the spec: goStringFromCPtr reads
// through the raw address, and the garbage collector must not reclaim the
// backing array while a stub's return value is still in flight.
type liveCstrPool struct {
	mu   sync.Mutex // guards bufs
	bufs [][]byte   // every buffer ever handed out, kept alive on purpose
}

// cstr copies s into a fresh NUL-terminated buffer, retains it in the pool,
// and returns the address of its first byte.
func (p *liveCstrPool) cstr(s string) uintptr {
	buf := make([]byte, len(s)+1) // the extra trailing byte stays 0: the NUL
	copy(buf, s)

	p.mu.Lock()
	p.bufs = append(p.bufs, buf)
	p.mu.Unlock()

	return uintptr(unsafe.Pointer(&buf[0]))
}
|
||||
|
||||
// liveStubs swaps every C entry point the live path touches and returns a
|
||||
// restore func for AfterEach.
|
||||
func liveStubs() (restore func()) {
|
||||
savedBegin, savedBeginLang := CppStreamBegin, CppStreamBeginLang
|
||||
savedFeed, savedFeedJSON := CppStreamFeed, CppStreamFeedJSON
|
||||
savedFinalize, savedFinalizeJSON := CppStreamFinalize, CppStreamFinalizeJSON
|
||||
savedFree, savedLastError := CppStreamFree, CppLastError
|
||||
savedFreeString := CppFreeString
|
||||
return func() {
|
||||
CppStreamBegin, CppStreamBeginLang = savedBegin, savedBeginLang
|
||||
CppStreamFeed, CppStreamFeedJSON = savedFeed, savedFeedJSON
|
||||
CppStreamFinalize, CppStreamFinalizeJSON = savedFinalize, savedFinalizeJSON
|
||||
CppStreamFree, CppLastError = savedFree, savedLastError
|
||||
CppFreeString = savedFreeString
|
||||
}
|
||||
}
|
||||
|
||||
// runLive starts the RPC on its own goroutine and returns the request
|
||||
// channel plus a collector for everything the backend emitted.
|
||||
func runLive(p *ParakeetCpp) (chan *pb.TranscriptLiveRequest, chan *pb.TranscriptLiveResponse, chan error) {
|
||||
in := make(chan *pb.TranscriptLiveRequest)
|
||||
out := make(chan *pb.TranscriptLiveResponse, 32)
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- p.AudioTranscriptionLive(in, out) }()
|
||||
return in, out, errCh
|
||||
}
|
||||
|
||||
func liveConfig(lang string) *pb.TranscriptLiveRequest {
|
||||
return &pb.TranscriptLiveRequest{
|
||||
Payload: &pb.TranscriptLiveRequest_Config{Config: &pb.TranscriptLiveConfig{Language: lang}},
|
||||
}
|
||||
}
|
||||
|
||||
func liveAudio(pcm []float32) *pb.TranscriptLiveRequest {
|
||||
return &pb.TranscriptLiveRequest{
|
||||
Payload: &pb.TranscriptLiveRequest_Audio{Audio: &pb.TranscriptLiveAudio{Pcm: pcm}},
|
||||
}
|
||||
}
|
||||
|
||||
func collectLive(out chan *pb.TranscriptLiveResponse) []*pb.TranscriptLiveResponse {
|
||||
var got []*pb.TranscriptLiveResponse
|
||||
for r := range out {
|
||||
got = append(got, r)
|
||||
}
|
||||
return got
|
||||
}
|
||||
|
||||
var _ = Describe("AudioTranscriptionLive (stubbed C API)", func() {
|
||||
var (
|
||||
pool *liveCstrPool
|
||||
restore func()
|
||||
p *ParakeetCpp
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
pool = &liveCstrPool{}
|
||||
restore = liveStubs()
|
||||
p = &ParakeetCpp{ctxPtr: 1}
|
||||
|
||||
CppStreamBeginLang = nil
|
||||
CppStreamBegin = func(ctx uintptr) uintptr { return 7 }
|
||||
CppStreamFree = func(s uintptr) {}
|
||||
CppFreeString = func(s uintptr) {}
|
||||
CppLastError = func(ctx uintptr) string { return "stub error" }
|
||||
CppStreamFeed = nil
|
||||
CppStreamFeedJSON = nil
|
||||
CppStreamFinalize = nil
|
||||
CppStreamFinalizeJSON = nil
|
||||
})
|
||||
|
||||
AfterEach(func() { restore() })
|
||||
|
||||
It("rejects a stream whose first message is not a config", func() {
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveAudio([]float32{0.1})
|
||||
close(in)
|
||||
|
||||
err := <-errCh
|
||||
Expect(status.Code(err)).To(Equal(codes.InvalidArgument))
|
||||
Expect(collectLive(out)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("rejects a non-16k sample rate", func() {
|
||||
in, _, errCh := runLive(p)
|
||||
in <- &pb.TranscriptLiveRequest{
|
||||
Payload: &pb.TranscriptLiveRequest_Config{Config: &pb.TranscriptLiveConfig{SampleRate: 8000}},
|
||||
}
|
||||
close(in)
|
||||
Expect(status.Code(<-errCh)).To(Equal(codes.InvalidArgument))
|
||||
})
|
||||
|
||||
It("returns the typed Unimplemented signal for non-streaming models, before any ack", func() {
|
||||
CppStreamBegin = func(ctx uintptr) uintptr { return 0 }
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
close(in)
|
||||
|
||||
err := <-errCh
|
||||
Expect(grpcerrors.IsLiveTranscriptionUnsupported(err)).To(BeTrue())
|
||||
Expect(collectLive(out)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("streams deltas, eou flags and words on the JSON path and finalizes on close", func() {
|
||||
var freed []uintptr
|
||||
CppStreamFree = func(s uintptr) { freed = append(freed, s) }
|
||||
feeds := 0
|
||||
CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr {
|
||||
feeds++
|
||||
switch feeds {
|
||||
case 1:
|
||||
return pool.cstr(`{"text":"hello ","eou":0,"frame_sec":0.08,` +
|
||||
`"words":[{"w":"hello","start":0.1,"end":0.4,"conf":0.9}]}`)
|
||||
default:
|
||||
return pool.cstr(`{"text":"world","eou":1,"frame_sec":0.08,` +
|
||||
`"words":[{"w":"world","start":0.5,"end":0.8,"conf":0.9}]}`)
|
||||
}
|
||||
}
|
||||
CppStreamFinalizeJSON = func(s uintptr) uintptr {
|
||||
return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("en")
|
||||
in <- liveAudio(make([]float32, 100))
|
||||
in <- liveAudio(make([]float32, 200))
|
||||
close(in)
|
||||
Expect(<-errCh).NotTo(HaveOccurred())
|
||||
|
||||
got := collectLive(out)
|
||||
Expect(got).To(HaveLen(4)) // ready, two deltas, final
|
||||
|
||||
Expect(got[0].Ready).To(BeTrue())
|
||||
|
||||
Expect(got[1].Delta).To(Equal("hello "))
|
||||
Expect(got[1].Eou).To(BeFalse())
|
||||
Expect(got[1].Words).To(HaveLen(1))
|
||||
Expect(got[1].Words[0].Text).To(Equal("hello"))
|
||||
|
||||
Expect(got[2].Delta).To(Equal("world"))
|
||||
Expect(got[2].Eou).To(BeTrue())
|
||||
|
||||
final := got[3].FinalResult
|
||||
Expect(final).NotTo(BeNil())
|
||||
Expect(final.Text).To(Equal("hello world"))
|
||||
// The live FinalResult carries only Text. Per-utterance segments,
|
||||
// duration and the terminal eou flag are an offline-path concern (see
|
||||
// boundary.go / AudioTranscriptionStream); the realtime core reads the
|
||||
// streamed per-feed tokens above plus this Text.
|
||||
Expect(final.Eou).To(BeFalse())
|
||||
Expect(final.Segments).To(BeEmpty())
|
||||
Expect(final.Duration).To(BeZero())
|
||||
|
||||
Expect(freed).To(Equal([]uintptr{7}))
|
||||
})
|
||||
|
||||
It("falls back to the text feed (eou out-param) when the JSON entry points are absent", func() {
|
||||
feeds := 0
|
||||
CppStreamFeed = func(s uintptr, pcm []float32, n int32, eouOut unsafe.Pointer) uintptr {
|
||||
feeds++
|
||||
if feeds == 2 {
|
||||
*(*int32)(eouOut) = 1
|
||||
return pool.cstr("done")
|
||||
}
|
||||
return pool.cstr("first ")
|
||||
}
|
||||
CppStreamFinalize = func(s uintptr) uintptr { return pool.cstr("") }
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
close(in)
|
||||
Expect(<-errCh).NotTo(HaveOccurred())
|
||||
|
||||
got := collectLive(out)
|
||||
Expect(got).To(HaveLen(4))
|
||||
Expect(got[1].Delta).To(Equal("first "))
|
||||
Expect(got[1].Eou).To(BeFalse())
|
||||
Expect(got[2].Delta).To(Equal("done"))
|
||||
Expect(got[2].Eou).To(BeTrue())
|
||||
Expect(got[3].FinalResult.Text).To(Equal("first done"))
|
||||
})
|
||||
|
||||
It("forwards <EOB> as eob — a backchannel, never an eou (ABI v5 JSON)", func() {
|
||||
feeds := 0
|
||||
CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr {
|
||||
feeds++
|
||||
if feeds == 1 {
|
||||
return pool.cstr(`{"text":"uh-huh","eou":0,"eob":1,"frame_sec":0.08,` +
|
||||
`"words":[{"w":"uh-huh","start":0.1,"end":0.3,"conf":0.9}]}`)
|
||||
}
|
||||
return pool.cstr(`{"text":"the turn","eou":1,"eob":0,"frame_sec":0.08,` +
|
||||
`"words":[{"w":"the","start":0.5,"end":0.6,"conf":0.9},{"w":"turn","start":0.6,"end":0.8,"conf":0.9}]}`)
|
||||
}
|
||||
CppStreamFinalizeJSON = func(s uintptr) uintptr {
|
||||
return pool.cstr(`{"text":"","eou":0,"eob":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
close(in)
|
||||
Expect(<-errCh).NotTo(HaveOccurred())
|
||||
|
||||
got := collectLive(out)
|
||||
Expect(got).To(HaveLen(4))
|
||||
Expect(got[1].Eob).To(BeTrue())
|
||||
Expect(got[1].Eou).To(BeFalse(), "a backchannel must not masquerade as a turn boundary")
|
||||
Expect(got[2].Eou).To(BeTrue())
|
||||
})
|
||||
|
||||
It("maps the v5 eou_out bitmask on the text path (bit0 <EOU>, bit1 <EOB>)", func() {
|
||||
feeds := 0
|
||||
CppStreamFeed = func(s uintptr, pcm []float32, n int32, eouOut unsafe.Pointer) uintptr {
|
||||
feeds++
|
||||
if feeds == 1 {
|
||||
*(*int32)(eouOut) = 2 // <EOB> only
|
||||
return pool.cstr("uh-huh")
|
||||
}
|
||||
*(*int32)(eouOut) = 1 // <EOU>
|
||||
return pool.cstr(" done")
|
||||
}
|
||||
CppStreamFinalize = func(s uintptr) uintptr { return pool.cstr("") }
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
close(in)
|
||||
Expect(<-errCh).NotTo(HaveOccurred())
|
||||
|
||||
got := collectLive(out)
|
||||
Expect(got).To(HaveLen(4))
|
||||
Expect(got[1].Eob).To(BeTrue())
|
||||
Expect(got[1].Eou).To(BeFalse())
|
||||
Expect(got[2].Eou).To(BeTrue())
|
||||
Expect(got[2].Eob).To(BeFalse())
|
||||
})
|
||||
|
||||
It("accumulates trailing text after an EOU into the final transcript", func() {
|
||||
feeds := 0
|
||||
CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr {
|
||||
feeds++
|
||||
if feeds == 1 {
|
||||
return pool.cstr(`{"text":"turn one","eou":1,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
return pool.cstr(`{"text":" and more","eou":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
CppStreamFinalizeJSON = func(s uintptr) uintptr {
|
||||
return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
close(in)
|
||||
Expect(<-errCh).NotTo(HaveOccurred())
|
||||
|
||||
got := collectLive(out)
|
||||
final := got[len(got)-1].FinalResult
|
||||
Expect(final.Text).To(Equal("turn one and more"))
|
||||
})
|
||||
|
||||
It("resets the decode session on a mid-stream config", func() {
	const (
		feedDoc  = `{"text":"x","eou":0,"frame_sec":0.08,"words":[]}`
		emptyDoc = `{"text":"","eou":0,"frame_sec":0.08,"words":[]}`
	)
	// Count how many stream sessions the code under test opens and frees.
	opened, released := 0, 0
	CppStreamBegin = func(ctx uintptr) uintptr {
		opened++
		return uintptr(10 + opened)
	}
	CppStreamFree = func(s uintptr) {
		released++
	}
	CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr {
		return pool.cstr(feedDoc)
	}
	CppStreamFinalizeJSON = func(s uintptr) uintptr {
		return pool.cstr(emptyDoc)
	}

	in, out, errCh := runLive(p)
	in <- liveConfig("")
	in <- liveAudio(make([]float32, 10))
	in <- liveConfig("") // second config mid-stream: reset
	in <- liveAudio(make([]float32, 10))
	close(in)
	Expect(<-errCh).NotTo(HaveOccurred())

	responses := collectLive(out)
	last := responses[len(responses)-1]
	Expect(last.FinalResult.Text).To(Equal("x"), "pre-reset transcript dropped")
	Expect(opened).To(Equal(2))
	Expect(released).To(Equal(2), "old session freed on reset, new one on unwind")
})
|
||||
|
||||
It("does not hold engineMu between feeds (unary work interleaves with a live session)", func() {
|
||||
CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr {
|
||||
return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
CppStreamFinalizeJSON = func(s uintptr) uintptr {
|
||||
return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
|
||||
// The session is open and idle between feeds: the engine lock must be
|
||||
// acquirable, which is what lets batched unary transcription proceed
|
||||
// mid-session. Under stream-lifetime locking this probe would block
|
||||
// until the stream ended and the Eventually would time out.
|
||||
locked := make(chan struct{})
|
||||
go func() {
|
||||
p.engineMu.Lock()
|
||||
p.engineMu.Unlock() //nolint:staticcheck // probe: acquire-release proves availability
|
||||
close(locked)
|
||||
}()
|
||||
Eventually(locked, time.Second).Should(BeClosed())
|
||||
|
||||
close(in)
|
||||
Expect(<-errCh).NotTo(HaveOccurred())
|
||||
collectLive(out)
|
||||
})
|
||||
|
||||
It("errors out and reads last_error under the lock when a feed fails", func() {
	// A zero return from the feed stub stands in for a failed feed.
	CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr {
		return 0
	}

	in, out, errCh := runLive(p)
	in <- liveConfig("")
	in <- liveAudio(make([]float32, 10))

	Expect(<-errCh).To(MatchError(ContainSubstring("stub error")))
	responses := collectLive(out)
	Expect(responses).To(HaveLen(1)) // just the ready ack
	close(in)
})
|
||||
})
|
||||
|
||||
var _ = Describe("stripEouMarker", func() {
|
||||
It("strips a trailing <EOU> and reports it", func() {
|
||||
text, eou := stripEouMarker("it is certainly very like the old portrait<EOU>")
|
||||
Expect(text).To(Equal("it is certainly very like the old portrait"))
|
||||
Expect(eou).To(BeTrue())
|
||||
})
|
||||
|
||||
It("strips a trailing <EOB> WITHOUT reporting an utterance end", func() {
|
||||
// A decode ending on a backchannel must not confirm the
|
||||
// retranscribe gate — the user was acknowledging, not yielding.
|
||||
text, eou := stripEouMarker("uh-huh<EOB>")
|
||||
Expect(text).To(Equal("uh-huh"))
|
||||
Expect(eou).To(BeFalse())
|
||||
})
|
||||
|
||||
It("leaves marker-free text alone", func() {
|
||||
text, eou := stripEouMarker("plain transcript")
|
||||
Expect(text).To(Equal("plain transcript"))
|
||||
Expect(eou).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does not strip a marker in the middle of the text", func() {
|
||||
text, eou := stripEouMarker("a<EOU>b")
|
||||
Expect(text).To(Equal("a<EOU>b"))
|
||||
Expect(eou).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("transcriptResultFromDoc EOU handling", func() {
|
||||
It("strips the offline marker from text and sets the result flag", func() {
|
||||
doc := transcriptJSON{Text: "the old portrait<EOU>"}
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Text).To(Equal("the old portrait"))
|
||||
Expect(res.Eou).To(BeTrue())
|
||||
Expect(res.Segments).To(HaveLen(1))
|
||||
Expect(res.Segments[0].Text).To(Equal("the old portrait"))
|
||||
})
|
||||
|
||||
It("reports eou=false for marker-free decodes", func() {
|
||||
doc := transcriptJSON{Text: "no marker here"}
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Text).To(Equal("no marker here"))
|
||||
Expect(res.Eou).To(BeFalse())
|
||||
})
|
||||
})
|
||||
@@ -106,7 +106,7 @@ var _ = Describe("transcriptResultFromDoc (multi-segment)", func() {
|
||||
var _ = Describe("streaming segment assembly", func() {
|
||||
It("closes a segment with start/end from its words on EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hello world", Eou: 1, Words: []transcriptWord{
|
||||
acc.add(streamFeedResult{Delta: "hello world", Eou: true, Words: []transcriptWord{
|
||||
{W: "hello", Start: 0.0, End: 0.4}, {W: "world", Start: 0.4, End: 0.9},
|
||||
}})
|
||||
segs := acc.segments()
|
||||
@@ -118,9 +118,9 @@ var _ = Describe("streaming segment assembly", func() {
|
||||
|
||||
It("buffers words across feeds until EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hi", Eou: 0, Words: []transcriptWord{{W: "hi", Start: 0, End: 0.3}}})
|
||||
acc.add(streamFeedResult{Delta: "hi", Words: []transcriptWord{{W: "hi", Start: 0, End: 0.3}}})
|
||||
Expect(acc.segments()).To(BeEmpty())
|
||||
acc.add(streamFeedJSON{Text: "there", Eou: 1, Words: []transcriptWord{{W: "there", Start: 0.3, End: 0.7}}})
|
||||
acc.add(streamFeedResult{Delta: "there", Eou: true, Words: []transcriptWord{{W: "there", Start: 0.3, End: 0.7}}})
|
||||
Expect(acc.segments()).To(HaveLen(1))
|
||||
Expect(acc.segments()[0].Text).To(Equal("hi there"))
|
||||
})
|
||||
@@ -129,7 +129,7 @@ var _ = Describe("streaming segment assembly", func() {
|
||||
// field; a backchannel must still close the segment as it did in v4.
|
||||
It("closes a segment on EOB (backchannel) too", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "uh huh", Eou: 0, Eob: 1, Words: []transcriptWord{
|
||||
acc.add(streamFeedResult{Delta: "uh huh", Eob: true, Words: []transcriptWord{
|
||||
{W: "uh", Start: 0.0, End: 0.2}, {W: "huh", Start: 0.2, End: 0.5},
|
||||
}})
|
||||
segs := acc.segments()
|
||||
@@ -137,4 +137,18 @@ var _ = Describe("streaming segment assembly", func() {
|
||||
Expect(segs[0].Text).To(Equal("uh huh"))
|
||||
Expect(segs[0].End).To(Equal(secondsToNanos(0.5)))
|
||||
})
|
||||
|
||||
// Older text-only libparakeet.so: no per-word timings, so a segment is cut
|
||||
// from the delta text on each <EOU>/<EOB> (no timestamps), one per utterance.
|
||||
It("falls back to text segments when the feed carries no words", func() {
	// Two word-less feeds, each ending an utterance: one text-only segment
	// per feed is expected, with zero start/end.
	segmenter := &streamSegmenter{}
	for _, delta := range []string{"first turn", "second turn"} {
		segmenter.add(streamFeedResult{Delta: delta, Eou: true})
	}

	closed := segmenter.segments()
	Expect(closed).To(HaveLen(2))
	Expect(closed[0].Text).To(Equal("first turn"))
	Expect(closed[1].Text).To(Equal("second turn"))

	first := closed[0]
	Expect(first.Start).To(Equal(int64(0)), "no per-word timing on the text path")
	Expect(first.End).To(Equal(int64(0)))
})
|
||||
})
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=9956436c925a367daeab097598b1ea1f32d3503f
|
||||
STABLEDIFFUSION_GGML_VERSION?=2574f5936571645f784b77623e1f09bad97d948a
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -798,6 +798,7 @@ void sd_img_gen_params_set_seed(sd_img_gen_params_t *params, int64_t seed) {
|
||||
int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char* ref_images[], int ref_images_count) {
|
||||
|
||||
sd_image_t* results;
|
||||
int num_results_out = 0;
|
||||
|
||||
std::vector<int> skip_layers = {7, 8, 9};
|
||||
|
||||
@@ -994,10 +995,14 @@ int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, cha
|
||||
sd_ctx_params_to_str(&ctx_params),
|
||||
sd_img_gen_params_to_str(p));
|
||||
|
||||
results = generate_image(sd_c, p);
|
||||
bool gen_ok = generate_image(sd_c, p, &results, &num_results_out);
|
||||
|
||||
std::free(p);
|
||||
|
||||
if (!gen_ok || num_results_out == 0) {
|
||||
results = NULL;
|
||||
}
|
||||
|
||||
if (results == NULL) {
|
||||
fprintf (stderr, "NO results\n");
|
||||
if (input_image_buffer) free(input_image_buffer);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# voice-detect backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as VOICEDETECT_VERSION?=3d51077... (.github/bump_deps.sh
|
||||
# Upstream pin lives below as VOICEDETECT_VERSION?=1db1759572c90faef6f3a78c36b5941a096a9f89
|
||||
# can find and update it - matches the parakeet.cpp / whisper.cpp / ds4 convention).
|
||||
#
|
||||
# Local dev shortcut: if you already have an out-of-tree voice-detect.cpp build,
|
||||
@@ -13,7 +13,7 @@
|
||||
# The default target below does the proper clone-at-pin + cmake build so CI does
|
||||
# not need a side-checkout.
|
||||
|
||||
VOICEDETECT_VERSION?=3d510772357538c5182808ac7de2278b84824e24
|
||||
VOICEDETECT_VERSION?=1db1759572c90faef6f3a78c36b5941a096a9f89
|
||||
VOICEDETECT_REPO?=https://github.com/mudler/voice-detect.cpp
|
||||
|
||||
GOCMD?=go
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=0ae02cdb2c7317b50991367c165736ce42ed96ac
|
||||
WHISPER_CPP_VERSION?=6fc7c33b4c3a2cec83e4b65abd5e96a890480375
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -11,6 +11,8 @@ import os
|
||||
|
||||
import grpc
|
||||
|
||||
from parent_watch import start_parent_death_watcher
|
||||
|
||||
|
||||
class _AbortHandler(grpc.RpcMethodHandler):
|
||||
"""A method handler that immediately aborts with UNAUTHENTICATED."""
|
||||
@@ -70,6 +72,13 @@ def get_auth_interceptors(*, aio: bool = False):
|
||||
|
||||
Returns an empty list when LOCALAI_GRPC_AUTH_TOKEN is not set.
|
||||
"""
|
||||
# Arm the best-effort parent-death backstop here: this is the single helper
|
||||
# every LocalAI Python backend invokes exactly once while building its gRPC
|
||||
# server (mirroring how the Go watcher arms in pkg/grpc's shared serve path).
|
||||
# start_parent_death_watcher() is idempotent and a no-op when disabled or on
|
||||
# unsupported platforms — see parent_watch.py.
|
||||
start_parent_death_watcher()
|
||||
|
||||
token = os.environ.get("LOCALAI_GRPC_AUTH_TOKEN", "")
|
||||
if not token:
|
||||
return []
|
||||
|
||||
@@ -20,7 +20,15 @@ def split_reasoning(text, think_start, think_end):
|
||||
Returns ``(reasoning_content, remaining_text)``. When ``think_start`` is
|
||||
empty or not found, returns ``("", text)`` unchanged.
|
||||
"""
|
||||
if not think_start or not text or think_start not in text:
|
||||
if not think_start or not text:
|
||||
return "", text
|
||||
if think_start not in text:
|
||||
# Models like Qwen3.5 open assistant turns already INSIDE thinking, so
|
||||
# the generated text carries only the closing tag. Everything before it
|
||||
# is reasoning that would otherwise leak into the content.
|
||||
if think_end and think_end in text:
|
||||
head, _, tail = text.partition(think_end)
|
||||
return head.strip(), tail.strip()
|
||||
return "", text
|
||||
pattern = re.compile(
|
||||
re.escape(think_start) + r"(.*?)" + re.escape(think_end or ""),
|
||||
|
||||
75
backend/python/common/mlx_utils_test.py
Normal file
75
backend/python/common/mlx_utils_test.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Unit tests for the mlx/mlx-vlm shared helpers (mlx_utils.py).
|
||||
|
||||
Run standalone (Python standard library only, no backend venv needed):
|
||||
python3 -m unittest mlx_utils_test
|
||||
|
||||
These mirror the server-less helper tests in backend/python/mlx/test.py
|
||||
(TestSharedHelpers), but live here so they run on any platform: the mlx
|
||||
test module imports grpc/backend_pb2 at import time and needs the MLX venv,
|
||||
whereas mlx_utils only needs the standard library.
|
||||
"""
|
||||
|
||||
import types
|
||||
import unittest
|
||||
|
||||
from mlx_utils import parse_tool_calls, split_reasoning
|
||||
|
||||
|
||||
class TestSplitReasoning(unittest.TestCase):
    """Exercises split_reasoning() with explicit, implicit and absent think tags."""

    OPEN = "<think>"
    CLOSE = "</think>"

    def test_both_tags(self):
        generated = "<think>step 1\nstep 2</think>The answer is 42."
        reasoning, content = split_reasoning(generated, self.OPEN, self.CLOSE)
        self.assertEqual(reasoning, "step 1\nstep 2")
        self.assertEqual(content, "The answer is 42.")

    def test_implicit_opener_only_closing_tag(self):
        # Qwen3.5 starts the assistant turn already inside the thinking
        # section, so only the closing tag shows up in the output and the
        # text in front of it is all reasoning.
        generated = (
            "The user is asking about the weather.\n</think>\n\nThe weather in Rome is sunny."
        )
        reasoning, content = split_reasoning(generated, self.OPEN, self.CLOSE)
        self.assertEqual(reasoning, "The user is asking about the weather.")
        self.assertEqual(content, "The weather in Rome is sunny.")

    def test_no_tags_at_all(self):
        reasoning, content = split_reasoning("just text", self.OPEN, self.CLOSE)
        self.assertEqual(reasoning, "")
        self.assertEqual(content, "just text")

    def test_empty_think_end_and_no_opener_match(self):
        # Without a think_end to anchor on and with no opener in the text,
        # the input must come back untouched.
        reasoning, content = split_reasoning("no opener here", self.OPEN, "")
        self.assertEqual(reasoning, "")
        self.assertEqual(content, "no opener here")

    def test_empty_text(self):
        reasoning, content = split_reasoning("", self.OPEN, self.CLOSE)
        self.assertEqual(reasoning, "")
        self.assertEqual(content, "")
|
||||
|
||||
|
||||
class TestParseToolCalls(unittest.TestCase):
    """Exercises parse_tool_calls() against a stand-in tool module."""

    def test_with_shim(self):
        def fake_parse(body, tools):
            # Treat the tag body as the location argument of a weather call.
            return {"name": "get_weather", "arguments": {"location": body.strip()}}

        shim = types.SimpleNamespace(
            tool_call_start="<tool_call>",
            tool_call_end="</tool_call>",
            parse_tool_call=fake_parse,
        )

        calls, remaining = parse_tool_calls(
            "Sure: <tool_call>Paris</tool_call>", shim, tools=None
        )

        self.assertEqual(len(calls), 1)
        only_call = calls[0]
        self.assertEqual(only_call["name"], "get_weather")
        self.assertEqual(only_call["arguments"], '{"location": "Paris"}')
        self.assertEqual(only_call["index"], 0)
        self.assertNotIn("<tool_call>", remaining)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
149
backend/python/common/parent_watch.py
Normal file
149
backend/python/common/parent_watch.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Parent-death watcher (best-effort backstop) for LocalAI Python backends.
|
||||
|
||||
LocalAI spawns each backend as a child process and, on a clean shutdown, tears
|
||||
it down itself (SIGTERM -> grace -> SIGKILL). That graceful path only runs when
|
||||
LocalAI receives a catchable signal and lives long enough to run its handlers.
|
||||
If LocalAI is SIGKILLed (e.g. a supervising process's grace period elapses
|
||||
first), that teardown never runs and this backend would be reparented to init
|
||||
and linger, holding GPU/VRAM and its listen port.
|
||||
|
||||
The watcher here is a best-effort backstop for exactly that case: it does NOT
|
||||
replace the graceful teardown, it only covers the "parent vanished without
|
||||
cleaning up" path. It detects reparenting: when the process that spawned this
|
||||
backend dies, the kernel reparents us to the nearest sub-reaper or to init
|
||||
(PID 1), so os.getppid() stops matching the value captured at startup. This
|
||||
getppid() approach is portable across Linux/macOS (unlike the Linux-only
|
||||
PR_SET_PDEATHSIG), which is why it is used here, mirroring the Go backends'
|
||||
pkg/grpc/parentwatch.go and the C++ backends' parent_watch.h. It is disabled on
|
||||
Windows, which has no equivalent orphan-reparenting semantics.
|
||||
|
||||
Env vars (shared verbatim across the Go, C++ and Python backends):
|
||||
LOCALAI_BACKEND_PARENT_WATCH enabled by default; a falsey value
|
||||
("false"/"0"/"no"/"off", case-insensitive)
|
||||
disables it.
|
||||
LOCALAI_BACKEND_PARENT_WATCH_INTERVAL poll interval as a Go-style duration
|
||||
string ("500ms", "2s", "1m") or a bare
|
||||
number of seconds. Defaults to 2s.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
|
||||
ENV_PARENT_WATCH = "LOCALAI_BACKEND_PARENT_WATCH"
|
||||
ENV_PARENT_WATCH_INTERVAL = "LOCALAI_BACKEND_PARENT_WATCH_INTERVAL"
|
||||
|
||||
_DEFAULT_INTERVAL_SECONDS = 2.0
|
||||
|
||||
# Guard so repeated calls (e.g. get_auth_interceptors invoked more than once)
|
||||
# only ever arm a single watcher thread per process.
|
||||
_started = False
|
||||
_started_lock = threading.Lock()
|
||||
|
||||
|
||||
def _enabled():
    """Return True when the parent-death watcher should run in this process.

    Always False on Windows; otherwise True unless the ENV_PARENT_WATCH
    environment variable holds a falsey word ("false"/"0"/"no"/"off",
    compared case-insensitively after trimming whitespace).
    """
    # Orphans on Windows are not reparented to a well-known init PID, so the
    # getppid() heuristic this module relies on has nothing to detect there.
    on_windows = os.name == "nt" or sys.platform.startswith("win")
    if on_windows:
        return False
    setting = os.environ.get(ENV_PARENT_WATCH, "")
    return setting.strip().lower() not in ("false", "0", "no", "off")
|
||||
|
||||
|
||||
def _interval_seconds():
|
||||
"""Return the configured poll interval in seconds, or the default.
|
||||
|
||||
Accepts Go-style duration strings ("500ms", "2s", "1m") for cross-language
|
||||
parity, or a bare number interpreted as seconds.
|
||||
"""
|
||||
raw = os.environ.get(ENV_PARENT_WATCH_INTERVAL, "").strip()
|
||||
if not raw:
|
||||
return _DEFAULT_INTERVAL_SECONDS
|
||||
# Split numeric prefix from unit suffix.
|
||||
i = 0
|
||||
while i < len(raw) and (raw[i].isdigit() or raw[i] == "." or (i == 0 and raw[i] in "+-")):
|
||||
i += 1
|
||||
if i == 0:
|
||||
return _DEFAULT_INTERVAL_SECONDS
|
||||
try:
|
||||
num = float(raw[:i])
|
||||
except ValueError:
|
||||
return _DEFAULT_INTERVAL_SECONDS
|
||||
unit = raw[i:].lower()
|
||||
if unit == "ms":
|
||||
seconds = num / 1000.0
|
||||
elif unit in ("s", ""):
|
||||
seconds = num
|
||||
elif unit == "m":
|
||||
seconds = num * 60.0
|
||||
else:
|
||||
return _DEFAULT_INTERVAL_SECONDS
|
||||
return seconds if seconds > 0 else _DEFAULT_INTERVAL_SECONDS
|
||||
|
||||
|
||||
def _parent_died(orig_ppid):
|
||||
"""Report whether this process has been reparented away from orig_ppid.
|
||||
|
||||
Reparenting is the standard POSIX signal that the original parent (here, the
|
||||
LocalAI process that spawned this backend) has exited: the orphan is handed
|
||||
to the nearest sub-reaper or to init (PID 1), so os.getppid() no longer
|
||||
matches the value captured at startup.
|
||||
"""
|
||||
ppid = os.getppid()
|
||||
return ppid != orig_ppid or ppid == 1
|
||||
|
||||
|
||||
def _watch(orig_ppid, interval, on_death):
    """Sleep-and-poll until the original parent is gone, then invoke on_death.

    Each round sleeps `interval` seconds before checking _parent_died(), so
    the first check happens one interval after the call. The call blocks for
    as long as the parent is alive; run it on a dedicated (daemon) thread.
    """
    import time

    time.sleep(interval)
    while not _parent_died(orig_ppid):
        time.sleep(interval)
    on_death()
|
||||
|
||||
|
||||
def start_parent_death_watcher():
    """Arm the best-effort parent-death backstop described in the module docstring.

    Does nothing when the watcher is disabled (which includes Windows), when
    one has already been started in this process, or when the process is
    already parented to PID 1 or below at call time. It complements LocalAI's
    graceful teardown and never replaces it.
    """
    global _started
    if not _enabled():
        return

    with _started_lock:
        if _started:
            return

        orig_ppid = os.getppid()
        if orig_ppid <= 1:
            # Already orphaned, or launched directly under init: there is no
            # original parent whose disappearance could be observed.
            return

        poll_every = _interval_seconds()
        notice = (
            "backend parent process (pid {}) exited without stopping this "
            "backend; self-terminating to avoid orphaning"
        ).format(orig_ppid)

        def on_death():
            print(notice, file=sys.stderr, flush=True)
            # Hard exit with no cleanup handlers: this is the safety net for
            # when the normal graceful shutdown path is already gone.
            os._exit(1)

        watcher = threading.Thread(
            target=_watch,
            args=(orig_ppid, poll_every, on_death),
            name="parent-death-watcher",
            daemon=True,
        )
        watcher.start()
        _started = True
|
||||
150
backend/python/common/parent_watch_test.py
Normal file
150
backend/python/common/parent_watch_test.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Unit tests for the parent-death watcher (parent_watch.py).
|
||||
|
||||
Run standalone (Python standard library only, no backend venv needed):
|
||||
python3 -m unittest parent_watch_test
|
||||
|
||||
The core test (test_detects_reparent) builds a genuine two-level process tree
|
||||
(test -> middle -> grandchild) with os.fork, lets the middle process die, and
|
||||
asserts the grandchild's parent_watch._watch detects the reparenting and
|
||||
self-terminates — mirroring the Go test in pkg/grpc/parentwatch_test.go and the
|
||||
C++ test in backend/cpp/llama-cpp/parent_watch_test.cpp.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import parent_watch
|
||||
|
||||
|
||||
class TestParentWatchEnvParsing(unittest.TestCase):
    """Checks how parent_watch reads its two environment variables."""

    def setUp(self):
        # Remember and clear both variables so every test starts from "unset".
        self._saved = {}
        for key in (parent_watch.ENV_PARENT_WATCH, parent_watch.ENV_PARENT_WATCH_INTERVAL):
            self._saved[key] = os.environ.pop(key, None)

    def tearDown(self):
        # Put the environment back exactly as setUp found it.
        for key, previous in self._saved.items():
            os.environ.pop(key, None)
            if previous is not None:
                os.environ[key] = previous

    def test_interval_default(self):
        self.assertEqual(parent_watch._interval_seconds(), 2.0)

    def test_interval_units(self):
        expectations = [
            ("500ms", 0.5),
            ("2s", 2.0),
            ("1m", 60.0),
            ("3", 3.0),
            ("0.5s", 0.5),
        ]
        for raw, expected in expectations:
            os.environ[parent_watch.ENV_PARENT_WATCH_INTERVAL] = raw
            self.assertAlmostEqual(parent_watch._interval_seconds(), expected, msg=raw)

    def test_interval_garbage_falls_back(self):
        os.environ[parent_watch.ENV_PARENT_WATCH_INTERVAL] = "garbage"
        self.assertEqual(parent_watch._interval_seconds(), 2.0)

    @unittest.skipIf(os.name == "nt" or sys.platform.startswith("win"), "POSIX only")
    def test_enabled_default(self):
        self.assertTrue(parent_watch._enabled())

    @unittest.skipIf(os.name == "nt" or sys.platform.startswith("win"), "POSIX only")
    def test_disabled_by_falsey(self):
        falsey = ("false", "0", "no", "off", "OFF", " False ")
        for val in falsey:
            os.environ[parent_watch.ENV_PARENT_WATCH] = val
            self.assertFalse(parent_watch._enabled(), msg=val)

    @unittest.skipIf(os.name == "nt" or sys.platform.startswith("win"), "POSIX only")
    def test_enabled_by_truthy(self):
        truthy = ("true", "1", "yes", "on")
        for val in truthy:
            os.environ[parent_watch.ENV_PARENT_WATCH] = val
            self.assertTrue(parent_watch._enabled(), msg=val)
|
||||
|
||||
|
||||
@unittest.skipIf(os.name == "nt" or sys.platform.startswith("win"), "fork/reparent is POSIX only")
class TestParentWatchReparent(unittest.TestCase):
    """End-to-end check that parent_watch._watch notices a real reparenting.

    Builds test -> middle -> grandchild with os.fork, lets middle exit, and
    expects the grandchild's watcher to fire its on_death callback.
    """

    def _wait_for_file(self, path, timeout=10.0):
        """Poll until `path` exists; return False if `timeout` seconds pass first."""
        # Monotonic clock so a wall-clock adjustment cannot shorten or
        # stretch the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if os.path.exists(path):
                return True
            time.sleep(0.02)
        return False

    def _cleanup(self, tmpdir, ready_file, exited_file):
        """Best-effort teardown: stop a surviving grandchild, remove temp files."""
        # Only signal the recorded PID if the grandchild never reported its
        # own exit; once it has, the PID may already belong to someone else.
        if not os.path.exists(exited_file):
            try:
                with open(ready_file) as f:
                    pid = int(f.read().strip())
                if pid > 1:
                    os.kill(pid, 9)
            except (OSError, ValueError):
                pass
        for p in (ready_file, exited_file):
            try:
                os.remove(p)
            except OSError:
                pass
        try:
            os.rmdir(tmpdir)
        except OSError:
            pass

    def _run_grandchild(self, ready_file, exited_file):
        """Grandchild body: arm the REAL watcher against middle, then idle."""
        orig_ppid = os.getppid()

        def on_death():
            with open(exited_file, "w") as f:
                f.write("1")
            os._exit(7)

        threading.Thread(
            target=parent_watch._watch,
            args=(orig_ppid, 0.05, on_death),
            daemon=True,
        ).start()

        # Safety valve: never linger if something goes wrong.
        def bail():
            time.sleep(30)
            os._exit(2)

        threading.Thread(target=bail, daemon=True).start()

        # Signal readiness only after the watcher captured orig_ppid.
        with open(ready_file, "w") as f:
            f.write(str(os.getpid()))
        while True:
            time.sleep(1)

    def _run_middle(self, ready_file, exited_file):
        """Middle body: spawn the grandchild, wait for it, then exit to orphan it."""
        grandchild = os.fork()
        if grandchild == 0:
            self._run_grandchild(ready_file, exited_file)
        # Exit code 5 tells nobody in particular that the grandchild never
        # became ready; the top-level assertion on ready_file reports it.
        if not self._wait_for_file(ready_file):
            os._exit(5)
        os._exit(0)

    def test_detects_reparent(self):
        tmpdir = tempfile.mkdtemp(prefix="parentwatch_test_")
        ready_file = os.path.join(tmpdir, "ready")
        exited_file = os.path.join(tmpdir, "exited")
        # Registered up front so it runs even when an assertion below fails.
        # Forked children leave via os._exit and therefore never run it.
        self.addCleanup(self._cleanup, tmpdir, ready_file, exited_file)

        middle = os.fork()
        if middle == 0:
            # Child side: whatever happens, never fall back into the
            # unittest runner that this process inherited from the fork.
            try:
                self._run_middle(ready_file, exited_file)
            finally:
                os._exit(3)

        # ---- test (top) process ----
        os.waitpid(middle, 0)  # reap middle only; grandchild is orphaned

        self.assertTrue(os.path.exists(ready_file), "grandchild never signaled readiness")
        self.assertTrue(
            self._wait_for_file(exited_file),
            "watcher did not detect parent death within timeout",
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -58,7 +58,18 @@ def messages_to_dicts(proto_messages):
|
||||
d["reasoning_content"] = msg.reasoning_content
|
||||
if msg.tool_calls:
|
||||
try:
|
||||
d["tool_calls"] = json.loads(msg.tool_calls)
|
||||
tool_calls = json.loads(msg.tool_calls)
|
||||
# Chat templates (e.g. Qwen) iterate function.arguments as a
|
||||
# mapping, but the OpenAI wire format carries it as a JSON
|
||||
# string — decode it back so the template's .items() works.
|
||||
for tc in tool_calls:
|
||||
fn = tc.get("function") if isinstance(tc, dict) else None
|
||||
if isinstance(fn, dict) and isinstance(fn.get("arguments"), str):
|
||||
try:
|
||||
fn["arguments"] = json.loads(fn["arguments"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
d["tool_calls"] = tool_calls
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
result.append(d)
|
||||
|
||||
122
backend/python/common/python_utils_test.py
Normal file
122
backend/python/common/python_utils_test.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Unit tests for the shared python backend helpers (python_utils.py).
|
||||
|
||||
Run standalone (Python standard library only, no backend venv needed):
|
||||
python3 -m unittest python_utils_test
|
||||
|
||||
These mirror the server-less helper tests in backend/python/mlx/test.py
|
||||
(TestSharedHelpers), but live here so they run on any platform: the mlx
|
||||
test module imports grpc/backend_pb2 at import time and needs the MLX venv,
|
||||
whereas python_utils has no third-party dependency. Proto Message objects
|
||||
are faked with types.SimpleNamespace (real proto fields default to "").
|
||||
"""
|
||||
|
||||
import json
|
||||
import types
|
||||
import unittest
|
||||
|
||||
from python_utils import messages_to_dicts, parse_options
|
||||
|
||||
|
||||
def _msg(**fields):
|
||||
"""Fake a proto Message: every unset field is the empty string, as protobuf."""
|
||||
defaults = {
|
||||
"role": "",
|
||||
"content": "",
|
||||
"name": "",
|
||||
"tool_call_id": "",
|
||||
"reasoning_content": "",
|
||||
"tool_calls": "",
|
||||
}
|
||||
defaults.update(fields)
|
||||
return types.SimpleNamespace(**defaults)
|
||||
|
||||
|
||||
class TestParseOptions(unittest.TestCase):
    """Checks parse_options() value typing and that colon-less entries are dropped."""

    def test_type_inference(self):
        raw_options = [
            "temperature:0.7",
            "max_tokens:128",
            "trust:true",
            "name:hello",
            "no_colon_skipped",
        ]
        parsed = parse_options(raw_options)

        self.assertEqual(parsed["temperature"], 0.7)
        self.assertEqual(parsed["max_tokens"], 128)
        self.assertIs(parsed["trust"], True)
        self.assertEqual(parsed["name"], "hello")
        self.assertNotIn("no_colon_skipped", parsed)
|
||||
|
||||
|
||||
class TestMessagesToDicts(unittest.TestCase):
    """Checks messages_to_dicts() field mapping and tool_calls decoding."""

    def _convert_assistant(self, calls):
        """Run messages_to_dicts on one assistant message carrying `calls` as JSON."""
        message = _msg(role="assistant", tool_calls=json.dumps(calls))
        return messages_to_dicts([message])

    def test_basic_fields(self):
        user = _msg(role="user", content="hi")
        tool = _msg(role="tool", content="42", tool_call_id="call_1", name="f")

        out = messages_to_dicts([user, tool])

        self.assertEqual(out[0], {"role": "user", "content": "hi"})
        self.assertEqual(out[1]["tool_call_id"], "call_1")
        self.assertEqual(out[1]["name"], "f")

    def test_tool_call_arguments_string_decoded_to_mapping(self):
        # On the OpenAI wire, function.arguments is a JSON *string*; chat
        # templates walk it as a mapping, so a dict has to come back.
        call = {
            "id": "call_1",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": '{"location": "Rome"}',
            },
        }

        out = self._convert_assistant([call])

        args = out[0]["tool_calls"][0]["function"]["arguments"]
        self.assertEqual(args, {"location": "Rome"})
        self.assertEqual(dict(args.items()), {"location": "Rome"})

    def test_tool_call_arguments_already_mapping_is_idempotent(self):
        out = self._convert_assistant([{"function": {"name": "f", "arguments": {"a": 1}}}])
        self.assertEqual(out[0]["tool_calls"][0]["function"]["arguments"], {"a": 1})

    def test_tool_call_arguments_invalid_json_left_as_string(self):
        out = self._convert_assistant([{"function": {"name": "f", "arguments": "not-json"}}])
        self.assertEqual(out[0]["tool_calls"][0]["function"]["arguments"], "not-json")

    def test_tool_call_without_function_key(self):
        out = self._convert_assistant([{"id": "call_1"}])
        self.assertEqual(out[0]["tool_calls"], [{"id": "call_1"}])

    def test_tool_calls_invalid_json_dropped(self):
        broken = _msg(role="assistant", tool_calls="{not json")
        out = messages_to_dicts([broken])
        self.assertNotIn("tool_calls", out[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -13,6 +13,17 @@ fi
|
||||
# fish-speech uses pyrootutils which requires a .project-root marker
|
||||
touch "${backend_dir}/.project-root"
|
||||
|
||||
# On darwin arm64 the transitive `tokenizers` dep compiles its Rust extension
|
||||
# from source (Linux uses prebuilt manylinux wheels, so it never compiles
|
||||
# there). The pinned tokenizers crate that fish-speech's stack resolves to
|
||||
# contains a `&T` -> `&mut T` cast that trips the now-deny-by-default
|
||||
# `invalid_reference_casting` lint in the macOS runner's newer Rust toolchain,
|
||||
# breaking the build (seen in the v4.5.5 release CI fish-speech darwin/metal
|
||||
# job). Allow that lint so the unchanged third-party crate compiles as before.
|
||||
# Append rather than clobber any pre-existing RUSTFLAGS; harmless on Linux
|
||||
# where no Rust compile happens.
|
||||
export RUSTFLAGS="${RUSTFLAGS:-} -A invalid_reference_casting"
|
||||
|
||||
installRequirements
|
||||
|
||||
# Clone fish-speech source (the pip package doesn't include inference modules)
|
||||
|
||||
@@ -3,4 +3,5 @@ protobuf
|
||||
certifi
|
||||
packaging==24.1
|
||||
pip
|
||||
chardet
|
||||
chardet
|
||||
click
|
||||
|
||||
@@ -147,9 +147,25 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
d["reasoning_content"] = msg.reasoning_content
|
||||
if msg.tool_calls:
|
||||
try:
|
||||
d["tool_calls"] = json.loads(msg.tool_calls)
|
||||
tool_calls = json.loads(msg.tool_calls)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
else:
|
||||
# OpenAI wire format carries function.arguments as a
|
||||
# JSON-encoded string, but chat templates (e.g. Qwen3)
|
||||
# iterate over it as a mapping. The vllm backend
|
||||
# already parses arguments before applying the chat
|
||||
# template (PR #10256); mirror that here so the
|
||||
# sglang backend works with the same wire format.
|
||||
if isinstance(tool_calls, list):
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function") if isinstance(tc, dict) else None
|
||||
if isinstance(func, dict) and isinstance(func.get("arguments"), str):
|
||||
try:
|
||||
func["arguments"] = json.loads(func["arguments"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
d["tool_calls"] = tool_calls
|
||||
result.append(d)
|
||||
return result
|
||||
|
||||
|
||||
@@ -748,7 +748,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# When (A) native streaming ran cleanly, per-delta yields above already
|
||||
# delivered everything — do NOT extract again on the full text or we'd
|
||||
# duplicate content/tool_calls into the final chunk.
|
||||
if has_tool_parser and not (native_streaming and not native_streaming_error):
|
||||
# NOTE: `native_streaming` is a capability flag ("streaming parser is
|
||||
# available"), not a state flag ("streaming actually ran"). For
|
||||
# non-streaming requests it is still True but the per-delta loop was
|
||||
# never entered, so we MUST still run extract_tool_calls here. Hence
|
||||
# the explicit `streaming and …` guard on both branches.
|
||||
if has_tool_parser and not (streaming and native_streaming and not native_streaming_error):
|
||||
try:
|
||||
tp = tp_instance
|
||||
if tp is None:
|
||||
@@ -770,7 +775,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
))
|
||||
except Exception as e:
|
||||
print(f"Tool parser error: {e}", file=sys.stderr)
|
||||
elif native_streaming and not native_streaming_error:
|
||||
elif streaming and native_streaming and not native_streaming_error:
|
||||
# Per-delta path already emitted content + tool_calls; the final
|
||||
# chat_delta should carry only metadata (token counts, logprobs).
|
||||
content = ""
|
||||
|
||||
@@ -35,6 +35,21 @@ if [ "x${BUILD_PROFILE}" == "xcpu" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
# AMD ROCm: vLLM ships prebuilt ROCm wheels, but on a DEDICATED index
|
||||
# (https://wheels.vllm.ai/rocm/), NOT PyPI, and ONLY for CPython 3.12. On any
|
||||
# other Python the installer silently falls back to the CUDA-only PyPI wheel,
|
||||
# which is unusable on an AMD GPU (import fails, so the backend never finds the
|
||||
# vllm module). Force Python 3.12 before the venv is created (matches the
|
||||
# intel/l4t13 cp312 bump); the hipblas branch below pulls vllm from the ROCm
|
||||
# wheel index. unsafe-best-match lets uv consult that index and PyPI together.
|
||||
# https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html?device=rocm
|
||||
if [ "x${BUILD_TYPE}" == "xhipblas" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
PY_STANDALONE_TAG="20251120"
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
# cublas13 pulls the vLLM wheel from a per-tag cu130 index (PyPI's vllm wheel
|
||||
# is built against CUDA 12 and won't load on cu130). uv's default per-package
|
||||
# first-match strategy would still pick the PyPI wheel, so allow it to consult
|
||||
@@ -104,7 +119,7 @@ if [ "$(uname -s)" = "Darwin" ]; then
|
||||
# can rewrite it. Darwin therefore follows vllm-metal and can lag the Linux
|
||||
# vllm pin (requirements-cublas13-after.txt, bumped independently against
|
||||
# vllm/vllm) until vllm-metal supports a newer vLLM.
|
||||
VLLM_METAL_VERSION="v0.3.0.dev20260622062346"
|
||||
VLLM_METAL_VERSION="v0.3.0.dev20260701212152"
|
||||
|
||||
# The coupled vLLM source version is whatever this vllm-metal release builds
|
||||
# against -- it declares it in its own installer as `vllm_v=`. Derive it from
|
||||
@@ -194,6 +209,22 @@ elif [ "x${BUILD_TYPE}" == "xintel" ]; then
|
||||
export CMAKE_PREFIX_PATH="$(python -c 'import site; print(site.getsitepackages()[0])'):${CMAKE_PREFIX_PATH:-}"
|
||||
VLLM_TARGET_DEVICE=xpu uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --no-deps .
|
||||
popd
|
||||
# AMD ROCm: install vllm from its dedicated ROCm wheel index instead of the
|
||||
# CUDA-only PyPI wheel. installRequirements brings the base ROCm
|
||||
# torch/transformers (requirements-hipblas.txt), then we pull vllm (plus the
|
||||
# matching ROCm torch, via --upgrade) from wheels.vllm.ai/rocm. This is the
|
||||
# method upstream prescribes for AMD; the Python-3.12 pin is set above.
|
||||
# There is intentionally no requirements-hipblas-after.txt: a bare `vllm`
|
||||
# there would resolve to the CUDA wheel, and installRequirements never loads
|
||||
# a ${BUILD_TYPE}-after file for hipblas anyway (BUILD_TYPE == BUILD_PROFILE).
|
||||
# https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html?device=rocm
|
||||
elif [ "x${BUILD_TYPE}" == "xhipblas" ]; then
|
||||
installRequirements
|
||||
|
||||
# --upgrade reconciles the base ROCm torch to whatever the vllm ROCm wheel
|
||||
# pins; --extra-index-url adds the ROCm wheel repository on top of PyPI.
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} \
|
||||
--extra-index-url https://wheels.vllm.ai/rocm/ --upgrade vllm
|
||||
# FROM_SOURCE=true on a CPU build skips the prebuilt vllm wheel in
|
||||
# requirements-cpu-after.txt and compiles vllm locally against the host's
|
||||
# actual CPU. Not used by default because it takes ~30-40 minutes, but
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
# on a cu130 host. Pull the cu130-flavoured wheel from vLLM's per-tag index
|
||||
# instead — the cublas13 case in install.sh adds --index-strategy=unsafe-best-match
|
||||
# so uv consults this index alongside PyPI.
|
||||
--extra-index-url https://wheels.vllm.ai/0.23.0/cu130
|
||||
--extra-index-url https://wheels.vllm.ai/0.24.0/cu130
|
||||
# VERSION COUPLING: darwin/Apple-Silicon builds use vllm-metal (see install.sh),
|
||||
# which pins this exact vLLM version. Bumping vllm here means coordinating with a
|
||||
# vllm-metal release that supports the new version, or macOS/Metal builds break.
|
||||
vllm==0.23.0
|
||||
vllm==0.24.0
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
vllm
|
||||
@@ -351,6 +351,16 @@ impl Backend for KokorosService {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
type AudioTranscriptionLiveStream =
|
||||
ReceiverStream<Result<backend::TranscriptLiveResponse, Status>>;
|
||||
|
||||
async fn audio_transcription_live(
|
||||
&self,
|
||||
_: Request<tonic::Streaming<backend::TranscriptLiveRequest>>,
|
||||
) -> Result<Response<Self::AudioTranscriptionLiveStream>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn diarize(
|
||||
&self,
|
||||
_: Request<backend::DiarizeRequest>,
|
||||
|
||||
@@ -207,12 +207,20 @@ func (l *Launcher) StartLocalAI() error {
|
||||
}
|
||||
|
||||
// Build command arguments
|
||||
dataPath := l.GetDataPath()
|
||||
args := []string{
|
||||
"run",
|
||||
"--models-path", l.config.ModelsPath,
|
||||
"--backends-path", l.config.BackendsPath,
|
||||
"--address", l.config.Address,
|
||||
"--log-level", l.config.LogLevel,
|
||||
// Keep persistent data and dynamic config under the launcher's data
|
||||
// directory (~/.localai) rather than letting the server resolve them
|
||||
// to ${basepath}/{data,configuration}. ${basepath} expands to the
|
||||
// launcher process's CWD (often the user's home root), which puts
|
||||
// ~/data and ~/configuration outside ~/.localai. See #10610.
|
||||
"--data-path", filepath.Join(dataPath, "data"),
|
||||
"--localai-config-dir", filepath.Join(dataPath, "configuration"),
|
||||
}
|
||||
|
||||
l.localaiCmd = exec.CommandContext(l.ctx, binaryPath, args...)
|
||||
@@ -429,7 +437,7 @@ func (l *Launcher) CheckForUpdates() (bool, string, error) {
|
||||
}
|
||||
|
||||
// DownloadUpdate downloads the latest version
|
||||
func (l *Launcher) DownloadUpdate(version string, progressCallback func(float64)) error {
|
||||
func (l *Launcher) DownloadUpdate(version string, progressCallback func(downloaded, total int64)) error {
|
||||
return l.releaseManager.DownloadRelease(version, progressCallback)
|
||||
}
|
||||
|
||||
@@ -486,7 +494,6 @@ func (l *Launcher) showDownloadLocalAIDialog() {
|
||||
fyne.DoAndWait(func() {
|
||||
// Create a standalone window for the download dialog
|
||||
dialogWindow := l.app.NewWindow("LocalAI Installation Required")
|
||||
dialogWindow.Resize(fyne.NewSize(500, 350))
|
||||
dialogWindow.CenterOnScreen()
|
||||
dialogWindow.SetCloseIntercept(func() {
|
||||
dialogWindow.Close()
|
||||
@@ -548,6 +555,7 @@ func (l *Launcher) showDownloadLocalAIDialog() {
|
||||
)
|
||||
|
||||
dialogWindow.SetContent(content)
|
||||
resizeToContent(dialogWindow, content)
|
||||
dialogWindow.Show()
|
||||
})
|
||||
}
|
||||
@@ -621,88 +629,134 @@ func (l *Launcher) showDownloadError(title, message string) {
|
||||
}
|
||||
|
||||
// showDownloadProgress shows a standalone progress window for downloading LocalAI
|
||||
// after a fresh install (no LocalAI binary present yet).
|
||||
func (l *Launcher) showDownloadProgress(version, title string) {
|
||||
l.showDownloadProgressWindow(version, title, func(win fyne.Window) {
|
||||
dialog.ShowConfirm("Installation Complete",
|
||||
"LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
|
||||
func(bool) {
|
||||
win.Close()
|
||||
l.updateStatus("LocalAI installed successfully")
|
||||
if l.systray != nil {
|
||||
l.systray.recreateMenu()
|
||||
}
|
||||
}, win)
|
||||
})
|
||||
}
|
||||
|
||||
// showDownloadProgressWindow renders the download progress popup shared by every
|
||||
// "download/upgrade LocalAI" entry point. It owns the progress bar, the
|
||||
// human-readable byte readout, resume-aware retry, and content-fit window
|
||||
// sizing so the behaviour stays identical everywhere. onSuccess runs (on the UI
|
||||
// goroutine) once the download verifies, and is responsible for the success
|
||||
// dialog and any follow-up; the window is passed in so it can be parented/closed.
|
||||
func (l *Launcher) showDownloadProgressWindow(version, title string, onSuccess func(win fyne.Window)) {
|
||||
fyne.DoAndWait(func() {
|
||||
// Create progress window
|
||||
progressWindow := l.app.NewWindow("Downloading LocalAI")
|
||||
progressWindow.Resize(fyne.NewSize(400, 250))
|
||||
progressWindow.CenterOnScreen()
|
||||
progressWindow.SetCloseIntercept(func() {
|
||||
progressWindow.Close()
|
||||
})
|
||||
|
||||
// Progress bar
|
||||
progressBar := widget.NewProgressBar()
|
||||
progressBar.SetValue(0)
|
||||
|
||||
// Status label. Truncate with an ellipsis so a long "Download failed:
|
||||
// <url>" message can't stretch the window (and progress bar) to fit the
|
||||
// whole error on one line; the full error is shown in the dialog below.
|
||||
// whole error on one line.
|
||||
statusLabel := widget.NewLabel("Preparing download...")
|
||||
statusLabel.Truncation = fyne.TextTruncateEllipsis
|
||||
|
||||
// Release notes button
|
||||
releaseNotesButton := widget.NewButton("View Release Notes", func() {
|
||||
releaseNotesURL, err := l.githubReleaseNotesURL(version)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
l.app.OpenURL(releaseNotesURL)
|
||||
})
|
||||
|
||||
// Progress container
|
||||
progressContainer := container.NewVBox(
|
||||
// Retry button: hidden until a download fails. GitHub downloads are
|
||||
// flaky, and the underlying download resumes from the partial file, so
|
||||
// a retry continues where it left off rather than starting over.
|
||||
retryButton := widget.NewButton("Retry", nil)
|
||||
retryButton.Importance = widget.HighImportance
|
||||
retryButton.Hide()
|
||||
|
||||
buttonRow := container.NewHBox(releaseNotesButton, retryButton)
|
||||
content := container.NewVBox(
|
||||
widget.NewLabel(title),
|
||||
progressBar,
|
||||
statusLabel,
|
||||
widget.NewSeparator(),
|
||||
releaseNotesButton,
|
||||
buttonRow,
|
||||
)
|
||||
progressWindow.SetContent(content)
|
||||
resizeToContent(progressWindow, content)
|
||||
|
||||
progressWindow.SetContent(progressContainer)
|
||||
progressWindow.Show()
|
||||
var startDownload func()
|
||||
startDownload = func() {
|
||||
retryButton.Hide()
|
||||
progressBar.SetValue(0)
|
||||
statusLabel.SetText("Preparing download...")
|
||||
resizeToContent(progressWindow, content)
|
||||
|
||||
// Start download in background
|
||||
go func() {
|
||||
err := l.DownloadUpdate(version, func(progress float64) {
|
||||
// Update progress bar
|
||||
fyne.Do(func() {
|
||||
progressBar.SetValue(progress)
|
||||
percentage := int(progress * 100)
|
||||
statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
|
||||
go func() {
|
||||
err := l.DownloadUpdate(version, func(downloaded, total int64) {
|
||||
fyne.Do(func() {
|
||||
if total > 0 {
|
||||
progressBar.SetValue(float64(downloaded) / float64(total))
|
||||
statusLabel.SetText(fmt.Sprintf("Downloading… %s / %s", formatBytes(downloaded), formatBytes(total)))
|
||||
} else {
|
||||
statusLabel.SetText(fmt.Sprintf("Downloading… %s", formatBytes(downloaded)))
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Handle completion
|
||||
fyne.Do(func() {
|
||||
if err != nil {
|
||||
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
|
||||
// Show error dialog
|
||||
dialog.ShowError(err, progressWindow)
|
||||
} else {
|
||||
statusLabel.SetText("Download completed successfully!")
|
||||
fyne.Do(func() {
|
||||
if err != nil {
|
||||
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
|
||||
retryButton.Show()
|
||||
resizeToContent(progressWindow, content)
|
||||
return
|
||||
}
|
||||
progressBar.SetValue(1.0)
|
||||
statusLabel.SetText("Download complete")
|
||||
onSuccess(progressWindow)
|
||||
})
|
||||
}()
|
||||
}
|
||||
retryButton.OnTapped = startDownload
|
||||
|
||||
// Show success dialog
|
||||
dialog.ShowConfirm("Installation Complete",
|
||||
"LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
|
||||
func(close bool) {
|
||||
progressWindow.Close()
|
||||
// Update status and refresh systray menu
|
||||
l.updateStatus("LocalAI installed successfully")
|
||||
|
||||
if l.systray != nil {
|
||||
l.systray.recreateMenu()
|
||||
}
|
||||
}, progressWindow)
|
||||
}
|
||||
})
|
||||
}()
|
||||
progressWindow.Show()
|
||||
startDownload()
|
||||
})
|
||||
}
|
||||
|
||||
// resizeToContent sizes a window to fit its content (with a sane minimum width)
|
||||
// so the dialog doesn't show a large blank gap below the last widget.
|
||||
func resizeToContent(w fyne.Window, content fyne.CanvasObject) {
|
||||
size := content.MinSize()
|
||||
if size.Width < 400 {
|
||||
size.Width = 400
|
||||
}
|
||||
w.Resize(size)
|
||||
}
|
||||
|
||||
// formatBytes renders a byte count as a human-readable size using 1024-based
// units with one decimal place (e.g. "12.3 MB"). Counts below 1024, including
// negative ones, are printed as a plain integer followed by " B".
//
// The unit is chosen from the value as it will be printed, not from the raw
// magnitude: a count that would round up to "1024.0" in one unit is promoted to
// the next, so 1048575 bytes reads "1.0 MB" rather than "1024.0 KB".
func formatBytes(b int64) string {
	const (
		unit     = 1024
		suffixes = "KMGTPE"
		// Smallest value that "%.1f" rounds up to 1024.0.
		promoteAt = unit - 0.05
	)
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}

	value := float64(b) / unit
	exp := 0
	// The exp bound is defensive: an int64 tops out at 8.0 EB.
	for exp < len(suffixes)-1 && value >= promoteAt {
		value /= unit
		exp++
	}
	return fmt.Sprintf("%.1f %cB", value, suffixes[exp])
}
|
||||
|
||||
// monitorLogs monitors the output of LocalAI and adds it to the log buffer
|
||||
func (l *Launcher) monitorLogs(reader io.Reader, prefix string) {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -50,6 +51,12 @@ type ReleaseManager struct {
|
||||
ChecksumsPath string
|
||||
// MetadataPath is where version metadata is stored
|
||||
MetadataPath string
|
||||
// BaseDownloadURL is the base URL release assets are downloaded from
|
||||
// (defaults to https://github.com; overridable for testing)
|
||||
BaseDownloadURL string
|
||||
// RetryBackoff is the base wait between download attempts; the Nth retry
|
||||
// waits N*RetryBackoff (defaults to 1s; lowered in tests)
|
||||
RetryBackoff time.Duration
|
||||
// HTTPClient is the HTTP client used for downloads
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
@@ -62,28 +69,94 @@ func NewReleaseManager() *ReleaseManager {
|
||||
metadataPath := filepath.Join(homeDir, ".localai", "metadata")
|
||||
|
||||
return &ReleaseManager{
|
||||
GitHubOwner: "mudler",
|
||||
GitHubRepo: "LocalAI",
|
||||
BinaryPath: binaryPath,
|
||||
CurrentVersion: internal.PrintableVersion(),
|
||||
ChecksumsPath: checksumsPath,
|
||||
MetadataPath: metadataPath,
|
||||
HTTPClient: httpclient.NewWithTimeout(30*time.Second, httpclient.WithFollowRedirects()),
|
||||
GitHubOwner: "mudler",
|
||||
GitHubRepo: "LocalAI",
|
||||
BinaryPath: binaryPath,
|
||||
CurrentVersion: internal.PrintableVersion(),
|
||||
ChecksumsPath: checksumsPath,
|
||||
MetadataPath: metadataPath,
|
||||
BaseDownloadURL: "https://github.com",
|
||||
RetryBackoff: 1 * time.Second,
|
||||
HTTPClient: httpclient.NewWithTimeout(30*time.Second, httpclient.WithFollowRedirects()),
|
||||
}
|
||||
}
|
||||
|
||||
// GetLatestRelease fetches the latest release information from GitHub
|
||||
// GetLatestRelease resolves the latest LocalAI release.
|
||||
//
|
||||
// It first follows the github.com "releases/latest" redirect, which reveals the
|
||||
// latest tag in the final URL and—crucially—is NOT subject to the
|
||||
// 60-requests/hour unauthenticated rate limit of api.github.com. That limit is
|
||||
// per-IP, so on shared/NAT/CGNAT/cloud addresses the API returns 403 almost
|
||||
// immediately (e.g. on a fresh install with no LocalAI present yet). The
|
||||
// redirect avoids that entirely. The richer JSON API is kept only as a fallback.
|
||||
//
|
||||
// Only the version is consumed by callers, so the redirect's tag is sufficient.
|
||||
func (rm *ReleaseManager) GetLatestRelease() (*Release, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", rm.GitHubOwner, rm.GitHubRepo)
|
||||
version, redirectErr := rm.latestVersionFromRedirect()
|
||||
if redirectErr == nil {
|
||||
return &Release{Version: version}, nil
|
||||
}
|
||||
log.Printf("Could not resolve latest version via release redirect (%v); falling back to GitHub API", redirectErr)
|
||||
|
||||
release, apiErr := rm.latestReleaseFromAPI()
|
||||
if apiErr != nil {
|
||||
// Surface both failures so a rate-limited API doesn't mask the (usually
|
||||
// more relevant) redirect error.
|
||||
return nil, fmt.Errorf("failed to fetch latest release: %v (redirect: %v)", apiErr, redirectErr)
|
||||
}
|
||||
return release, nil
|
||||
}
|
||||
|
||||
// latestVersionFromRedirect returns the latest tag by following the github.com
|
||||
// "releases/latest" redirect to ".../releases/tag/<tag>".
|
||||
func (rm *ReleaseManager) latestVersionFromRedirect() (string, error) {
|
||||
url := fmt.Sprintf("%s/%s/%s/releases/latest", rm.BaseDownloadURL, rm.GitHubOwner, rm.GitHubRepo)
|
||||
|
||||
resp, err := rm.HTTPClient.Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("unexpected status %s", resp.Status)
|
||||
}
|
||||
|
||||
// After the redirect is followed, the final request URL is the tag page.
|
||||
version := path.Base(resp.Request.URL.Path)
|
||||
if version == "" || version == "." || version == "latest" {
|
||||
return "", fmt.Errorf("could not determine version from %s", resp.Request.URL.String())
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// latestReleaseFromAPI fetches the latest release JSON from api.github.com. This
|
||||
// is the fallback path; it is rate-limited unless GITHUB_TOKEN is set.
|
||||
func (rm *ReleaseManager) latestReleaseFromAPI() (*Release, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", rm.GitHubOwner, rm.GitHubRepo)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
// An optional token lifts the unauthenticated 60/hour limit to 5000/hour.
|
||||
if token := os.Getenv("GITHUB_TOKEN"); token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
resp, err := rm.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch latest release: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to fetch latest release: status %d", resp.StatusCode)
|
||||
if (resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusTooManyRequests) &&
|
||||
resp.Header.Get("X-RateLimit-Remaining") == "0" {
|
||||
return nil, fmt.Errorf("GitHub API rate limit exceeded (status %d); retry later or set GITHUB_TOKEN to raise the limit", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Parse the JSON response properly
|
||||
@@ -106,7 +179,7 @@ func (rm *ReleaseManager) GetLatestRelease() (*Release, error) {
|
||||
}
|
||||
|
||||
// DownloadRelease downloads a specific version of LocalAI
|
||||
func (rm *ReleaseManager) DownloadRelease(version string, progressCallback func(float64)) error {
|
||||
func (rm *ReleaseManager) DownloadRelease(version string, progressCallback func(downloaded, total int64)) error {
|
||||
// Ensure the binary directory exists
|
||||
if err := os.MkdirAll(rm.BinaryPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create binary directory: %w", err)
|
||||
@@ -117,16 +190,16 @@ func (rm *ReleaseManager) DownloadRelease(version string, progressCallback func(
|
||||
localPath := filepath.Join(rm.BinaryPath, "local-ai")
|
||||
|
||||
// Download the binary
|
||||
downloadURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s",
|
||||
rm.GitHubOwner, rm.GitHubRepo, version, binaryName)
|
||||
downloadURL := fmt.Sprintf("%s/%s/%s/releases/download/%s/%s",
|
||||
rm.BaseDownloadURL, rm.GitHubOwner, rm.GitHubRepo, version, binaryName)
|
||||
|
||||
if err := rm.downloadFile(downloadURL, localPath, progressCallback); err != nil {
|
||||
return fmt.Errorf("failed to download binary: %w", err)
|
||||
}
|
||||
|
||||
// Download and verify checksums
|
||||
checksumURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/LocalAI-%s-checksums.txt",
|
||||
rm.GitHubOwner, rm.GitHubRepo, version, version)
|
||||
checksumURL := fmt.Sprintf("%s/%s/%s/releases/download/%s/LocalAI-%s-checksums.txt",
|
||||
rm.BaseDownloadURL, rm.GitHubOwner, rm.GitHubRepo, version, version)
|
||||
|
||||
checksumPath := filepath.Join(rm.BinaryPath, "checksums.txt")
|
||||
manualChecksumPath := filepath.Join(rm.ChecksumsPath, fmt.Sprintf("checksums-%s.txt", version))
|
||||
@@ -154,6 +227,10 @@ func (rm *ReleaseManager) DownloadRelease(version string, progressCallback func(
|
||||
// Verify the checksum if we have a checksum file
|
||||
if _, err := os.Stat(checksumPath); err == nil {
|
||||
if err := rm.VerifyChecksum(localPath, checksumPath, binaryName); err != nil {
|
||||
// Discard the corrupt binary (and any leftover partial) so the next
|
||||
// retry starts from a clean slate rather than resuming corruption.
|
||||
os.Remove(localPath)
|
||||
os.Remove(localPath + ".part")
|
||||
return fmt.Errorf("checksum verification failed: %w", err)
|
||||
}
|
||||
log.Printf("Checksum verification successful")
|
||||
@@ -196,44 +273,88 @@ func (rm *ReleaseManager) GetBinaryName(version string) string {
|
||||
}
|
||||
|
||||
// downloadFile downloads a file from a URL to a local path with optional progress callback
|
||||
func (rm *ReleaseManager) downloadFile(url, filepath string, progressCallback func(float64)) error {
|
||||
func (rm *ReleaseManager) downloadFile(url, filepath string, progressCallback func(downloaded, total int64)) error {
|
||||
return rm.downloadFileWithRetry(url, filepath, progressCallback, 3)
|
||||
}
|
||||
|
||||
// downloadFileWithRetry downloads a file from a URL with retry logic
|
||||
func (rm *ReleaseManager) downloadFileWithRetry(url, filepath string, progressCallback func(float64), maxRetries int) error {
|
||||
// downloadFileWithRetry downloads a file with retry and HTTP Range resume.
|
||||
//
|
||||
// The body is streamed to "<dest>.part" and only renamed to dest on success, so
|
||||
// a dropped connection leaves a partial file that the next attempt continues via
|
||||
// a "Range: bytes=N-" request instead of restarting from zero. This matters for
|
||||
// GitHub release downloads, which are large and flaky.
|
||||
func (rm *ReleaseManager) downloadFileWithRetry(url, dest string, progressCallback func(downloaded, total int64), maxRetries int) error {
|
||||
partPath := dest + ".part"
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
if attempt > 1 {
|
||||
log.Printf("Retrying download (attempt %d/%d): %s", attempt, maxRetries, url)
|
||||
time.Sleep(time.Duration(attempt) * time.Second)
|
||||
time.Sleep(time.Duration(attempt) * rm.RetryBackoff)
|
||||
}
|
||||
|
||||
resp, err := rm.HTTPClient.Get(url)
|
||||
// Resume from however much we already have on disk.
|
||||
var offset int64
|
||||
if fi, err := os.Stat(partPath); err == nil {
|
||||
offset = fi.Size()
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if offset > 0 {
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", offset))
|
||||
}
|
||||
|
||||
resp, err := rm.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
// Server ignored the Range (or we had nothing): start fresh.
|
||||
offset = 0
|
||||
case http.StatusPartialContent:
|
||||
// Resume: append to the existing partial file.
|
||||
case http.StatusRequestedRangeNotSatisfiable:
|
||||
// Stale or already-complete partial: discard and restart fresh.
|
||||
resp.Body.Close()
|
||||
os.Remove(partPath)
|
||||
lastErr = fmt.Errorf("partial download no longer valid (status %s), restarting", resp.Status)
|
||||
continue
|
||||
default:
|
||||
resp.Body.Close()
|
||||
lastErr = fmt.Errorf("bad status: %s", resp.Status)
|
||||
continue
|
||||
}
|
||||
|
||||
out, err := os.Create(filepath)
|
||||
var out *os.File
|
||||
if offset > 0 {
|
||||
out, err = os.OpenFile(partPath, os.O_WRONLY|os.O_APPEND, 0644)
|
||||
} else {
|
||||
out, err = os.Create(partPath)
|
||||
}
|
||||
if err != nil {
|
||||
resp.Body.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Create a progress reader if callback is provided
|
||||
// On a 206 the Content-Length is the remaining bytes, so the full size
|
||||
// is what we already have plus what's still to come.
|
||||
total := resp.ContentLength
|
||||
if offset > 0 && total > 0 {
|
||||
total += offset
|
||||
}
|
||||
|
||||
var reader io.Reader = resp.Body
|
||||
if progressCallback != nil && resp.ContentLength > 0 {
|
||||
if progressCallback != nil && total > 0 {
|
||||
reader = &progressReader{
|
||||
Reader: resp.Body,
|
||||
Total: resp.ContentLength,
|
||||
Total: total,
|
||||
Current: offset,
|
||||
Callback: progressCallback,
|
||||
}
|
||||
}
|
||||
@@ -243,11 +364,14 @@ func (rm *ReleaseManager) downloadFileWithRetry(url, filepath string, progressCa
|
||||
out.Close()
|
||||
|
||||
if err != nil {
|
||||
// Keep the partial file so the next attempt can resume from it.
|
||||
lastErr = err
|
||||
os.Remove(filepath)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.Rename(partPath, dest); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -322,20 +446,21 @@ func (rm *ReleaseManager) saveVersionMetadata(version string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// progressReader wraps an io.Reader to provide download progress
|
||||
// progressReader wraps an io.Reader to provide download progress as a
|
||||
// (downloaded, total) byte count so callers can render both a progress bar and
|
||||
// a human-readable size.
|
||||
type progressReader struct {
|
||||
io.Reader
|
||||
Total int64
|
||||
Current int64
|
||||
Callback func(float64)
|
||||
Callback func(downloaded, total int64)
|
||||
}
|
||||
|
||||
func (pr *progressReader) Read(p []byte) (int, error) {
|
||||
n, err := pr.Reader.Read(p)
|
||||
pr.Current += int64(n)
|
||||
if pr.Callback != nil {
|
||||
progress := float64(pr.Current) / float64(pr.Total)
|
||||
pr.Callback(progress)
|
||||
pr.Callback(pr.Current, pr.Total)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
package launcher_test
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
@@ -178,4 +186,221 @@ var _ = Describe("ReleaseManager", func() {
|
||||
Expect(err.Error()).To(ContainSubstring("checksum not found"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("DownloadRelease resume and retry", func() {
|
||||
var (
|
||||
version string
|
||||
binaryName string
|
||||
content []byte
|
||||
checksums string
|
||||
finalPath string
|
||||
partPath string
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
version = "v9.9.9"
|
||||
binaryName = rm.GetBinaryName(version)
|
||||
|
||||
// Deterministic, non-trivial content so resume/append bugs surface.
|
||||
content = make([]byte, 4096)
|
||||
for i := range content {
|
||||
content[i] = byte(i % 251)
|
||||
}
|
||||
sum := sha256.Sum256(content)
|
||||
checksums = fmt.Sprintf("%s %s\n", hex.EncodeToString(sum[:]), binaryName)
|
||||
|
||||
finalPath = filepath.Join(tempDir, "local-ai")
|
||||
partPath = finalPath + ".part"
|
||||
|
||||
// Isolate the persistent checksum/metadata dirs to the temp dir so
|
||||
// the test never touches the real ~/.localai and existing checksum
|
||||
// files don't short-circuit the download.
|
||||
rm.ChecksumsPath = filepath.Join(tempDir, "checksums")
|
||||
rm.MetadataPath = filepath.Join(tempDir, "metadata")
|
||||
rm.GitHubOwner = "owner"
|
||||
rm.GitHubRepo = "repo"
|
||||
rm.RetryBackoff = time.Millisecond
|
||||
|
||||
Expect(os.MkdirAll(tempDir, 0755)).To(Succeed())
|
||||
})
|
||||
|
||||
It("resumes from a partial .part file using a Range request", func() {
	// Pre-seed the first resumeOffset bytes as if an earlier download was interrupted.
	const resumeOffset = 1024
	Expect(os.WriteFile(partPath, content[:resumeOffset], 0644)).To(Succeed())

	// mu guards the counters below: they are written from the server's
	// handler goroutines and read from the spec goroutine.
	var (
		mu             sync.Mutex
		sawRange       bool
		binBytesServed int
	)

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if strings.HasSuffix(r.URL.Path, "checksums.txt") {
			_, _ = w.Write([]byte(checksums))
			return
		}
		if rangeHdr := r.Header.Get("Range"); rangeHdr != "" {
			// Only the "bytes=<start>-" form is understood; a parse failure
			// leaves start at 0, which serves the full body as a 206.
			var start int
			_, _ = fmt.Sscanf(rangeHdr, "bytes=%d-", &start)
			mu.Lock()
			sawRange = true
			mu.Unlock()
			w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, len(content)-1, len(content)))
			w.WriteHeader(http.StatusPartialContent)
			n, _ := w.Write(content[start:])
			mu.Lock()
			binBytesServed += n
			mu.Unlock()
			return
		}
		w.WriteHeader(http.StatusOK)
		n, _ := w.Write(content)
		mu.Lock()
		binBytesServed += n
		mu.Unlock()
	}))
	defer srv.Close()
	rm.BaseDownloadURL = srv.URL

	err := rm.DownloadRelease(version, nil)
	Expect(err).ToNot(HaveOccurred())

	got, err := os.ReadFile(finalPath)
	Expect(err).ToNot(HaveOccurred())
	Expect(got).To(Equal(content))

	// Snapshot the handler-side counters under the same mutex that protects
	// their writes, instead of reading them unsynchronised.
	mu.Lock()
	gotRange, served := sawRange, binBytesServed
	mu.Unlock()
	Expect(gotRange).To(BeTrue(), "expected the download to resume with a Range request")
	Expect(served).To(Equal(len(content)-resumeOffset), "expected only the remaining bytes to be served")
	Expect(partPath).ToNot(BeAnExistingFile())
})
|
||||
|
||||
It("starts fresh when the server ignores the Range header (200)", func() {
	// Seed a bogus partial: a 200 reply means it must be discarded, never
	// appended to.
	stale := []byte("garbage-garbage-garbage")
	Expect(os.WriteFile(partPath, stale, 0644)).To(Succeed())

	handler := func(w http.ResponseWriter, r *http.Request) {
		switch {
		case strings.HasSuffix(r.URL.Path, "checksums.txt"):
			_, _ = w.Write([]byte(checksums))
		default:
			// Any Range header is deliberately ignored; the whole body is
			// always returned with a plain 200.
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write(content)
		}
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()
	rm.BaseDownloadURL = srv.URL

	downloadErr := rm.DownloadRelease(version, nil)
	Expect(downloadErr).ToNot(HaveOccurred())

	onDisk, readErr := os.ReadFile(finalPath)
	Expect(readErr).ToNot(HaveOccurred())
	Expect(onDisk).To(Equal(content))
})
|
||||
|
||||
It("restarts the download when the partial is stale (416)", func() {
	// A partial longer than the real payload makes the resume request ask
	// for a Range that starts past the end of the content.
	oversized := make([]byte, len(content)+10)
	Expect(os.WriteFile(partPath, oversized, 0644)).To(Succeed())

	// serveBinary answers binary requests: 200 without Range, 416 when the
	// requested start is out of bounds, 206 with the tail otherwise.
	serveBinary := func(w http.ResponseWriter, r *http.Request) {
		rangeHdr := r.Header.Get("Range")
		if rangeHdr == "" {
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write(content)
			return
		}
		var offset int
		_, _ = fmt.Sscanf(rangeHdr, "bytes=%d-", &offset)
		if offset >= len(content) {
			w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
			return
		}
		w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", offset, len(content)-1, len(content)))
		w.WriteHeader(http.StatusPartialContent)
		_, _ = w.Write(content[offset:])
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if strings.HasSuffix(r.URL.Path, "checksums.txt") {
			_, _ = w.Write([]byte(checksums))
			return
		}
		serveBinary(w, r)
	}))
	defer srv.Close()
	rm.BaseDownloadURL = srv.URL

	downloadErr := rm.DownloadRelease(version, nil)
	Expect(downloadErr).ToNot(HaveOccurred())

	onDisk, readErr := os.ReadFile(finalPath)
	Expect(readErr).ToNot(HaveOccurred())
	Expect(onDisk).To(Equal(content))
})
|
||||
|
||||
It("removes the downloaded file when checksum verification fails", func() {
	// The advertised checksum still belongs to `content`; the server hands
	// out this unrelated body instead, so verification has to fail.
	bad := []byte("this is definitely not the expected binary content")

	handler := func(w http.ResponseWriter, r *http.Request) {
		if !strings.HasSuffix(r.URL.Path, "checksums.txt") {
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write(bad)
			return
		}
		_, _ = w.Write([]byte(checksums))
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()
	rm.BaseDownloadURL = srv.URL

	downloadErr := rm.DownloadRelease(version, nil)
	Expect(downloadErr).To(HaveOccurred())
	Expect(downloadErr.Error()).To(ContainSubstring("checksum"))

	// Neither the final binary nor the partial may survive a failed verification.
	Expect(finalPath).ToNot(BeAnExistingFile())
	Expect(partPath).ToNot(BeAnExistingFile())
})
|
||||
|
||||
It("reports progress as downloaded and total byte counts", func() {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if strings.HasSuffix(r.URL.Path, "checksums.txt") {
			_, _ = w.Write([]byte(checksums))
			return
		}
		// Advertise the size explicitly so the callback can receive a total.
		w.Header().Set("Content-Length", strconv.Itoa(len(content)))
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write(content)
	}))
	defer srv.Close()
	rm.BaseDownloadURL = srv.URL

	// mu guards the last values seen by the progress callback; the callback
	// may run on a goroutine other than the spec's.
	var (
		mu                        sync.Mutex
		lastDownloaded, lastTotal int64
	)
	err := rm.DownloadRelease(version, func(downloaded, total int64) {
		mu.Lock()
		defer mu.Unlock()
		lastDownloaded, lastTotal = downloaded, total
	})
	Expect(err).ToNot(HaveOccurred())

	// Read the recorded values under the same mutex used to write them.
	mu.Lock()
	gotDownloaded, gotTotal := lastDownloaded, lastTotal
	mu.Unlock()
	Expect(gotTotal).To(Equal(int64(len(content))))
	Expect(gotDownloaded).To(Equal(int64(len(content))))
})
|
||||
})
|
||||
|
||||
Describe("GetLatestRelease", func() {
	It("resolves the latest version from the releases/latest redirect", func() {
		// The github.com redirect is expected to be preferred over the
		// rate-limited api.github.com: when the redirect works, the tag is
		// obtained without the API ever being needed.
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			path := r.URL.Path
			if strings.HasSuffix(path, "/releases/latest") {
				http.Redirect(w, r, "/owner/repo/releases/tag/v9.9.9", http.StatusFound)
				return
			}
			// Only the redirect target exists; everything else is a 404.
			status := http.StatusNotFound
			if strings.HasSuffix(path, "/releases/tag/v9.9.9") {
				status = http.StatusOK
			}
			w.WriteHeader(status)
		}))
		defer srv.Close()

		rm.BaseDownloadURL = srv.URL
		rm.GitHubOwner = "owner"
		rm.GitHubRepo = "repo"

		latest, lookupErr := rm.GetLatestRelease()
		Expect(lookupErr).ToNot(HaveOccurred())
		Expect(latest.Version).To(Equal("v9.9.9"))
	})
})
|
||||
})
|
||||
|
||||
@@ -443,84 +443,23 @@ func (sm *SystrayManager) showStartupErrorDialog(err error) {
|
||||
})
|
||||
}
|
||||
|
||||
// showDownloadProgress shows a progress window for downloading updates
|
||||
// showDownloadProgress shows a progress window for downloading updates. The
|
||||
// progress UI (byte readout, resume-aware retry, sizing) is shared with the
|
||||
// other download entry points via the launcher; only the post-success behaviour
|
||||
// (restart prompt + systray refresh) is specific to the update flow.
|
||||
func (sm *SystrayManager) showDownloadProgress(version string) {
|
||||
// Create a new window for download progress
|
||||
progressWindow := sm.app.NewWindow("Downloading LocalAI Update")
|
||||
progressWindow.Resize(fyne.NewSize(400, 250))
|
||||
progressWindow.CenterOnScreen()
|
||||
sm.launcher.showDownloadProgressWindow(version, fmt.Sprintf("Downloading LocalAI version %s", version), func(win fyne.Window) {
|
||||
dialog.ShowConfirm("Update Downloaded",
|
||||
"LocalAI has been updated successfully. Please restart the launcher to use the new version.",
|
||||
func(restart bool) {
|
||||
if restart {
|
||||
sm.app.Quit()
|
||||
}
|
||||
win.Close()
|
||||
}, win)
|
||||
|
||||
// Progress bar
|
||||
progressBar := widget.NewProgressBar()
|
||||
progressBar.SetValue(0)
|
||||
|
||||
// Status label. Truncate with an ellipsis so a long "Download failed:
|
||||
// <url>" message can't stretch the window (and progress bar) to fit the
|
||||
// whole error on one line; the full error is shown in the dialog below.
|
||||
statusLabel := widget.NewLabel("Preparing download...")
|
||||
statusLabel.Truncation = fyne.TextTruncateEllipsis
|
||||
|
||||
// Release notes button
|
||||
releaseNotesButton := widget.NewButton("View Release Notes", func() {
|
||||
releaseNotesURL, err := sm.launcher.githubReleaseNotesURL(version)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
sm.app.OpenURL(releaseNotesURL)
|
||||
sm.hasUpdateAvailable = false
|
||||
sm.latestVersion = ""
|
||||
sm.recreateMenu()
|
||||
})
|
||||
|
||||
// Progress container
|
||||
progressContainer := container.NewVBox(
|
||||
widget.NewLabel(fmt.Sprintf("Downloading LocalAI version %s", version)),
|
||||
progressBar,
|
||||
statusLabel,
|
||||
widget.NewSeparator(),
|
||||
releaseNotesButton,
|
||||
)
|
||||
|
||||
progressWindow.SetContent(progressContainer)
|
||||
progressWindow.Show()
|
||||
|
||||
// Start download in background
|
||||
go func() {
|
||||
err := sm.launcher.DownloadUpdate(version, func(progress float64) {
|
||||
// Update progress bar
|
||||
fyne.Do(func() {
|
||||
progressBar.SetValue(progress)
|
||||
percentage := int(progress * 100)
|
||||
statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
|
||||
})
|
||||
})
|
||||
|
||||
// Handle completion
|
||||
fyne.Do(func() {
|
||||
if err != nil {
|
||||
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
|
||||
// Show error dialog
|
||||
dialog.ShowError(err, progressWindow)
|
||||
} else {
|
||||
statusLabel.SetText("Download completed successfully!")
|
||||
progressBar.SetValue(1.0)
|
||||
|
||||
// Show restart dialog
|
||||
dialog.ShowConfirm("Update Downloaded",
|
||||
"LocalAI has been updated successfully. Please restart the launcher to use the new version.",
|
||||
func(restart bool) {
|
||||
if restart {
|
||||
sm.app.Quit()
|
||||
}
|
||||
progressWindow.Close()
|
||||
}, progressWindow)
|
||||
}
|
||||
})
|
||||
|
||||
// Update systray menu
|
||||
if err == nil {
|
||||
sm.hasUpdateAvailable = false
|
||||
sm.latestVersion = ""
|
||||
sm.recreateMenu()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -490,14 +490,19 @@ func (ui *LauncherUI) downloadUpdate() {
|
||||
ui.UpdateStatus("Downloading update " + version + "...")
|
||||
|
||||
go func() {
|
||||
err := ui.launcher.DownloadUpdate(version, func(progress float64) {
|
||||
// Update progress bar
|
||||
err := ui.launcher.DownloadUpdate(version, func(downloaded, total int64) {
|
||||
fyne.Do(func() {
|
||||
ui.progressBar.SetValue(progress)
|
||||
if total > 0 {
|
||||
ui.progressBar.SetValue(float64(downloaded) / float64(total))
|
||||
}
|
||||
})
|
||||
// Update status with percentage
|
||||
percentage := int(progress * 100)
|
||||
ui.UpdateStatus(fmt.Sprintf("Downloading update %s... %d%%", version, percentage))
|
||||
// The progress bar already shows the percentage, so report the
|
||||
// human-readable size here instead of repeating the percent.
|
||||
if total > 0 {
|
||||
ui.UpdateStatus(fmt.Sprintf("Downloading update %s… %s / %s", version, formatBytes(downloaded), formatBytes(total)))
|
||||
} else {
|
||||
ui.UpdateStatus(fmt.Sprintf("Downloading update %s… %s", version, formatBytes(downloaded)))
|
||||
}
|
||||
})
|
||||
|
||||
fyne.Do(func() {
|
||||
@@ -598,82 +603,6 @@ func (ui *LauncherUI) LoadConfiguration() {
|
||||
log.Printf("UI LoadConfiguration: configuration loaded successfully")
|
||||
}
|
||||
|
||||
// showDownloadProgress shows a progress window for downloading LocalAI
|
||||
func (ui *LauncherUI) showDownloadProgress(version, title string) {
|
||||
fyne.DoAndWait(func() {
|
||||
// Create progress window using the launcher's app
|
||||
progressWindow := ui.launcher.app.NewWindow("Downloading LocalAI")
|
||||
progressWindow.Resize(fyne.NewSize(400, 250))
|
||||
progressWindow.CenterOnScreen()
|
||||
|
||||
// Progress bar
|
||||
progressBar := widget.NewProgressBar()
|
||||
progressBar.SetValue(0)
|
||||
|
||||
// Status label. Truncate with an ellipsis so a long "Download failed:
|
||||
// <url>" message can't stretch the window (and progress bar) to fit the
|
||||
// whole error on one line; the full error is shown in the dialog below.
|
||||
statusLabel := widget.NewLabel("Preparing download...")
|
||||
statusLabel.Truncation = fyne.TextTruncateEllipsis
|
||||
|
||||
// Release notes button
|
||||
releaseNotesButton := widget.NewButton("View Release Notes", func() {
|
||||
releaseNotesURL, err := ui.launcher.githubReleaseNotesURL(version)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ui.launcher.app.OpenURL(releaseNotesURL)
|
||||
})
|
||||
|
||||
// Progress container
|
||||
progressContainer := container.NewVBox(
|
||||
widget.NewLabel(title),
|
||||
progressBar,
|
||||
statusLabel,
|
||||
widget.NewSeparator(),
|
||||
releaseNotesButton,
|
||||
)
|
||||
|
||||
progressWindow.SetContent(progressContainer)
|
||||
progressWindow.Show()
|
||||
|
||||
// Start download in background
|
||||
go func() {
|
||||
err := ui.launcher.DownloadUpdate(version, func(progress float64) {
|
||||
// Update progress bar
|
||||
fyne.Do(func() {
|
||||
progressBar.SetValue(progress)
|
||||
percentage := int(progress * 100)
|
||||
statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
|
||||
})
|
||||
})
|
||||
|
||||
// Handle completion
|
||||
fyne.Do(func() {
|
||||
if err != nil {
|
||||
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
|
||||
// Show error dialog
|
||||
dialog.ShowError(err, progressWindow)
|
||||
} else {
|
||||
statusLabel.SetText("Download completed successfully!")
|
||||
progressBar.SetValue(1.0)
|
||||
|
||||
// Show success dialog
|
||||
dialog.ShowConfirm("Installation Complete",
|
||||
"LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
|
||||
func(close bool) {
|
||||
progressWindow.Close()
|
||||
// Update status
|
||||
ui.UpdateStatus("LocalAI installed successfully")
|
||||
}, progressWindow)
|
||||
}
|
||||
})
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRunningState updates UI based on LocalAI running state
|
||||
func (ui *LauncherUI) UpdateRunningState(isRunning bool) {
|
||||
fyne.Do(func() {
|
||||
|
||||
@@ -71,13 +71,42 @@ cmd_notarize() {
|
||||
echo "[notarize] notarized and stapled $dmg"
|
||||
}
|
||||
|
||||
# Notarize and staple the .app bundle itself. Stapling the dmg alone is not
# enough: an app with no embedded ticket has no local proof of notarization, so
# Gatekeeper falls back to an online check — and the app then fails to launch on
# a machine that is offline / behind a firewall, or once it has been copied out
# of the dmg. Stapling the bundle makes it verify offline. notarytool needs an
# archive for a bundle, so we zip it first.
#
# Arguments:
#   $1  path to the .app bundle
# Environment:
#   MACOS_NOTARY_KEY        base64-encoded .p8 API key; when unset the step is skipped
#   MACOS_NOTARY_KEY_ID     key id passed to notarytool (required once the key is set)
#   MACOS_NOTARY_ISSUER_ID  issuer id passed to notarytool (required once the key is set)
# Returns:
#   0 on success or skip; the failing command's status when the submission fails.
cmd_notarize_app() {
    local app="$1"
    if [ -z "${MACOS_NOTARY_KEY:-}" ]; then
        echo "[notarize] MACOS_NOTARY_KEY unset: skipping notarization of $app"
        return 0
    fi

    # Abort on missing credentials BEFORE any key material is written to disk.
    : "${MACOS_NOTARY_KEY_ID:?}" "${MACOS_NOTARY_ISSUER_ID:?}"

    # Keep the decoded key and the archive inside a private temp directory.
    # "$(mktemp).p8" would leave the file mktemp created behind and write the
    # key to a new sibling path created with the default umask.
    local workdir keyfile zip rc=0
    workdir="$(mktemp -d)"
    keyfile="$workdir/notary-key.p8"
    zip="$workdir/app.zip"

    # Chain with && and capture the status so cleanup below always runs, even
    # when the caller has `set -e` enabled.
    {
        echo "$MACOS_NOTARY_KEY" | base64 --decode > "$keyfile" &&
        ditto -c -k --keepParent "$app" "$zip" &&
        xcrun notarytool submit "$zip" \
            --key "$keyfile" \
            --key-id "$MACOS_NOTARY_KEY_ID" \
            --issuer "$MACOS_NOTARY_ISSUER_ID" \
            --wait
    } || rc=$?

    # Never leave the private key (or the archive) on disk, success or failure.
    rm -rf "$workdir"
    if [ "$rc" -ne 0 ]; then
        echo "[notarize] notarization of $app failed (exit $rc)" >&2
        return "$rc"
    fi

    xcrun stapler staple "$app"
    xcrun stapler validate "$app"
    echo "[notarize] notarized and stapled $app"
}
|
||||
|
||||
main() {
|
||||
local sub="${1:-}"; shift || true
|
||||
case "$sub" in
|
||||
import-cert) cmd_import_cert ;;
|
||||
sign) cmd_sign "$@" ;;
|
||||
notarize) cmd_notarize "$@" ;;
|
||||
*) echo "usage: $0 {import-cert|sign <path>|notarize <dmg>}" >&2; exit 2 ;;
|
||||
import-cert) cmd_import_cert ;;
|
||||
sign) cmd_sign "$@" ;;
|
||||
notarize) cmd_notarize "$@" ;;
|
||||
notarize-app) cmd_notarize_app "$@" ;;
|
||||
*) echo "usage: $0 {import-cert|sign <path>|notarize <dmg>|notarize-app <app>}" >&2; exit 2 ;;
|
||||
esac
|
||||
}
|
||||
|
||||
|
||||
@@ -103,6 +103,11 @@ func newApplication(appConfig *config.ApplicationConfig) *Application {
|
||||
mcpTools.CloseMCPSessions(modelName)
|
||||
})
|
||||
|
||||
// Record a model_load backend trace for every real backend load, so the
|
||||
// Traces UI shows which backend runtime served each model and how long
|
||||
// the load took. Load failures are traced by the modality wrappers.
|
||||
ml.SetLoadObserver(corebackend.ModelLoadTraceObserver(appConfig))
|
||||
|
||||
app := &Application{
|
||||
backendLoader: config.NewModelConfigLoader(appConfig.SystemState.Model.ModelsPath),
|
||||
modelLoader: ml,
|
||||
|
||||
@@ -197,6 +197,7 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
|
||||
envWatchdogBusy := appConfig.WatchDogBusy == startupAppConfig.WatchDogBusy
|
||||
envWatchdogIdleTimeout := appConfig.WatchDogIdleTimeout == startupAppConfig.WatchDogIdleTimeout
|
||||
envWatchdogBusyTimeout := appConfig.WatchDogBusyTimeout == startupAppConfig.WatchDogBusyTimeout
|
||||
envWatchdogInterval := appConfig.WatchDogInterval == startupAppConfig.WatchDogInterval
|
||||
envSingleBackend := appConfig.SingleBackend == startupAppConfig.SingleBackend
|
||||
envMaxActiveBackends := appConfig.MaxActiveBackends == startupAppConfig.MaxActiveBackends
|
||||
envMemoryReclaimerEnabled := appConfig.MemoryReclaimerEnabled == startupAppConfig.MemoryReclaimerEnabled
|
||||
@@ -257,6 +258,14 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
|
||||
xlog.Warn("invalid watchdog busy timeout in runtime_settings.json", "error", err, "timeout", *settings.WatchdogBusyTimeout)
|
||||
}
|
||||
}
|
||||
if settings.WatchdogInterval != nil && !envWatchdogInterval {
|
||||
dur, err := time.ParseDuration(*settings.WatchdogInterval)
|
||||
if err == nil {
|
||||
appConfig.WatchDogInterval = dur
|
||||
} else {
|
||||
xlog.Warn("invalid watchdog interval in runtime_settings.json", "error", err, "interval", *settings.WatchdogInterval)
|
||||
}
|
||||
}
|
||||
// Handle MaxActiveBackends (new) and SingleBackend (deprecated)
|
||||
if settings.MaxActiveBackends != nil && !envMaxActiveBackends {
|
||||
appConfig.MaxActiveBackends = *settings.MaxActiveBackends
|
||||
|
||||
@@ -356,6 +356,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
PrefixConfig: prefixCfg,
|
||||
Pressure: pressure,
|
||||
SharedModels: cfg.Distributed.SharedModels,
|
||||
// Cap how long a cold load may hold the per-model advisory lock: the
|
||||
// configured backend.install deadline plus a margin for file staging and
|
||||
// the remote LoadModel. Derived from the install timeout so raising it
|
||||
// (for slow links pulling multi-GB images) widens the ceiling too,
|
||||
// instead of letting the static default cut a legitimately slow load.
|
||||
ModelLoadCeiling: cfg.Distributed.BackendInstallTimeoutOrDefault() + 10*time.Minute,
|
||||
})
|
||||
|
||||
// Wire staging-progress broadcasting so file-staging shows up on every
|
||||
|
||||
@@ -87,6 +87,31 @@ var _ = Describe("loadRuntimeSettingsFromFile", func() {
|
||||
})
|
||||
})
|
||||
|
||||
// Watchdog check interval (issue #10601). Unlike the idle/busy timeouts
|
||||
// (which default to 0), NewApplicationConfig baseline-defaults the
|
||||
// interval to 500ms. The loader's "apply file value only if still at the
|
||||
// zero default" env-detection therefore never fired for the interval, so
|
||||
// a UI-saved Check Interval silently reverted to 500ms on every restart
|
||||
// while the idle/busy timeouts persisted. These specs construct the
|
||||
// config the same way boot does (NewApplicationConfig) so they observe
|
||||
// the real default the loader sees.
|
||||
Describe("watchdog interval", func() {
|
||||
It("loads a UI-saved watchdog_interval on the next startup", func() {
|
||||
cfg := config.NewApplicationConfig()
|
||||
cfg.DynamicConfigsDir = seedSettings(`{"watchdog_interval": "2s"}`)
|
||||
loadRuntimeSettingsFromFile(cfg)
|
||||
Expect(cfg.WatchDogInterval).To(Equal(2 * time.Second))
|
||||
})
|
||||
|
||||
It("does not override an explicit env/CLI interval", func() {
|
||||
cfg := config.NewApplicationConfig()
|
||||
cfg.DynamicConfigsDir = seedSettings(`{"watchdog_interval": "2s"}`)
|
||||
cfg.WatchDogInterval = 1 * time.Second // simulate SetWatchDogInterval from env
|
||||
loadRuntimeSettingsFromFile(cfg)
|
||||
Expect(cfg.WatchDogInterval).To(Equal(1*time.Second), "env/CLI interval must win over the persisted file value")
|
||||
})
|
||||
})
|
||||
|
||||
// MITM listener address. The file is the only source — no env var
|
||||
// exists — so a regression here means an admin who configured the
|
||||
// listener via /api/settings loses it after a reboot, even though
|
||||
|
||||
@@ -369,7 +369,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
|
||||
for _, backend := range options.ExternalBackends {
|
||||
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", "", options.RequireBackendIntegrity); err != nil {
|
||||
if err := galleryop.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", "", false, options.RequireBackendIntegrity); err != nil {
|
||||
xlog.Error("error installing external backend", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
72
core/backend/model_load_trace_test.go
Normal file
72
core/backend/model_load_trace_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package backend_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// ModelLoadTraceObserver is what makes successful loads visible on the
|
||||
// Traces page: one model_load row per real backend load, carrying the
|
||||
// resolved backend runtime. Failures must NOT be recorded here — the
|
||||
// modality wrappers own those — and the observer must respect the runtime
|
||||
// tracing toggle.
|
||||
var _ = Describe("ModelLoadTraceObserver", func() {
|
||||
var appConfig *config.ApplicationConfig
|
||||
|
||||
successEvent := model.BackendLoadEvent{
|
||||
ModelID: "parakeet-cpp-realtime_eou_120m-v1",
|
||||
ModelName: "realtime_eou_120m.gguf",
|
||||
Backend: "parakeet-cpp",
|
||||
BackendURI: "/backends/intel-sycl-f16-parakeet-cpp-development/run.sh",
|
||||
Duration: 1500 * time.Millisecond,
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
appConfig = &config.ApplicationConfig{
|
||||
EnableTracing: true,
|
||||
TracingMaxItems: 64,
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.ClearBackendTraces()
|
||||
})
|
||||
|
||||
It("records a model_load trace with the backend runtime on success", func() {
|
||||
backend.ModelLoadTraceObserver(appConfig)(successEvent)
|
||||
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
got := trace.GetBackendTraces()[0]
|
||||
Expect(got.Type).To(Equal(trace.BackendTraceModelLoad))
|
||||
Expect(got.Summary).To(Equal("Model loaded"))
|
||||
Expect(got.ModelName).To(Equal("parakeet-cpp-realtime_eou_120m-v1"))
|
||||
Expect(got.Backend).To(Equal("parakeet-cpp"))
|
||||
Expect(got.Duration).To(Equal(1500 * time.Millisecond))
|
||||
Expect(got.Data["backend_runtime"]).To(Equal("/backends/intel-sycl-f16-parakeet-cpp-development/run.sh"))
|
||||
Expect(got.Data["model_file"]).To(Equal("realtime_eou_120m.gguf"))
|
||||
Expect(got.Error).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("skips failed loads — the modality wrappers trace those with request context", func() {
|
||||
failed := successEvent
|
||||
failed.Err = errors.New("grpc service not ready")
|
||||
|
||||
backend.ModelLoadTraceObserver(appConfig)(failed)
|
||||
|
||||
Consistently(trace.GetBackendTraces, "100ms", "20ms").Should(BeEmpty())
|
||||
})
|
||||
|
||||
It("records nothing when tracing is disabled", func() {
|
||||
appConfig.EnableTracing = false
|
||||
|
||||
backend.ModelLoadTraceObserver(appConfig)(successEvent)
|
||||
|
||||
Consistently(trace.GetBackendTraces, "100ms", "20ms").Should(BeEmpty())
|
||||
})
|
||||
})
|
||||
@@ -19,6 +19,39 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// ModelLoadTraceObserver returns the ModelLoader load observer that records
|
||||
// a model_load backend trace for every successful real load (backend process
|
||||
// spawn + LoadModel RPC; cache hits never reach the observer). Failures are
|
||||
// deliberately skipped here: the modality wrappers already record them via
|
||||
// recordModelLoadFailure with request context, and the backend auto-discovery
|
||||
// scan probes several backends before one succeeds — tracing every probe
|
||||
// failure would bury the buffer in noise.
|
||||
//
|
||||
// The traced data includes the resolved backend runtime (the installed
|
||||
// backend's launcher path, which names the variant directory) — that is what
|
||||
// identifies WHICH build served the load. A stale installed backend is
|
||||
// invisible in the model config but obvious here.
|
||||
func ModelLoadTraceObserver(appConfig *config.ApplicationConfig) func(model.BackendLoadEvent) {
|
||||
return func(ev model.BackendLoadEvent) {
|
||||
if ev.Err != nil || !appConfig.EnableTracing {
|
||||
return
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: time.Now(),
|
||||
Duration: ev.Duration,
|
||||
Type: trace.BackendTraceModelLoad,
|
||||
ModelName: ev.ModelID,
|
||||
Backend: ev.Backend,
|
||||
Summary: "Model loaded",
|
||||
Data: map[string]any{
|
||||
"model_file": ev.ModelName,
|
||||
"backend_runtime": ev.BackendURI,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// recordModelLoadFailure records a backend trace when model loading fails.
|
||||
func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, backend string, err error, data map[string]any) {
|
||||
if !appConfig.EnableTracing {
|
||||
|
||||
@@ -181,6 +181,7 @@ func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionR
|
||||
Text: r.Text,
|
||||
Language: r.Language,
|
||||
Duration: float64(r.Duration),
|
||||
Eou: r.Eou,
|
||||
}
|
||||
|
||||
for _, s := range r.Segments {
|
||||
|
||||
297
core/backend/transcript_live.go
Normal file
297
core/backend/transcript_live.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/sound"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// LiveTranscriptionEvent is one streamed event from a live (bidirectional)
|
||||
// transcription session. Delta/Eou/Eob/Words arrive as the user speaks; Final
|
||||
// is set exactly once, on the terminal event after Close flushes the decode
|
||||
// tail. Eou means the model judged the user yielded the turn; Eob means a
|
||||
// backchannel ("uh-huh") ended — callers must NOT treat Eob as a turn
|
||||
// boundary.
|
||||
type LiveTranscriptionEvent struct {
|
||||
Delta string
|
||||
Eou bool
|
||||
Eob bool
|
||||
Words []schema.TranscriptionWord
|
||||
Final *schema.TranscriptionResult
|
||||
}
|
||||
|
||||
// LiveTranscriptionSession is the caller-facing handle on an open live
// transcription stream.
type LiveTranscriptionSession interface {
	// Feed pushes a chunk of 16 kHz mono float PCM into the stream.
	Feed(pcm []float32) error

	// Close signals end-of-audio, waits for the backend's terminal Final
	// event to be delivered, and releases the stream.
	Close() error
}
|
||||
|
||||
// liveCloseDrainTimeout bounds how long Close waits for the backend to flush
|
||||
// the decode tail before force-cancelling the stream. Finalize is one short
|
||||
// engine call; seconds here means the backend is wedged.
|
||||
const liveCloseDrainTimeout = 10 * time.Second
|
||||
|
||||
type liveTranscriptionSession struct {
|
||||
stream grpcPkg.AudioTranscriptionLiveClient
|
||||
cancel context.CancelFunc
|
||||
recvDone chan struct{}
|
||||
recvErr error // written by the recv goroutine before recvDone closes
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
trace *liveTraceState // nil when tracing was disabled at open
|
||||
}
|
||||
|
||||
func (s *liveTranscriptionSession) Feed(pcm []float32) error {
|
||||
s.trace.addPCM(pcm)
|
||||
return s.stream.Send(&proto.TranscriptLiveRequest{
|
||||
Payload: &proto.TranscriptLiveRequest_Audio{Audio: &proto.TranscriptLiveAudio{Pcm: pcm}},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *liveTranscriptionSession) Close() error {
|
||||
s.closeOnce.Do(func() {
|
||||
err := s.stream.CloseSend()
|
||||
select {
|
||||
case <-s.recvDone:
|
||||
case <-time.After(liveCloseDrainTimeout):
|
||||
xlog.Warn("live transcription: backend did not finalize in time; cancelling stream")
|
||||
s.cancel()
|
||||
<-s.recvDone
|
||||
}
|
||||
s.cancel()
|
||||
if err == nil {
|
||||
err = s.recvErr
|
||||
}
|
||||
s.closeErr = err
|
||||
s.trace.record(err)
|
||||
})
|
||||
return s.closeErr
|
||||
}
|
||||
|
||||
// liveSampleRate is the PCM rate of a live transcription session, fixed by
|
||||
// the session config sent in ModelTranscriptionLive.
|
||||
const liveSampleRate = 16000
|
||||
|
||||
// liveTraceState accumulates what the per-turn backend trace needs while a
|
||||
// live session runs: a bounded copy of the fed PCM for the audio snippet,
|
||||
// the decode outputs, and timing. One trace is recorded at Close — the live
|
||||
// path never touches the unary transcription wrapper, so without this a
|
||||
// streaming-only pipeline produced no transcription traces at all. Feed and
|
||||
// the recv goroutine run concurrently; mu guards the accumulators.
|
||||
type liveTraceState struct {
|
||||
appConfig *config.ApplicationConfig
|
||||
modelName string
|
||||
backend string
|
||||
language string
|
||||
started time.Time
|
||||
|
||||
mu sync.Mutex
|
||||
pcm []byte // first trace.MaxSnippetSeconds of fed audio, int16 LE
|
||||
fedSamples int // ALL samples fed, beyond the snippet cap
|
||||
deltaEvents int
|
||||
eouEvents int
|
||||
eobEvents int
|
||||
finalText string
|
||||
}
|
||||
|
||||
func newLiveTraceState(modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, language string) *liveTraceState {
|
||||
if !appConfig.EnableTracing {
|
||||
return nil
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
return &liveTraceState{
|
||||
appConfig: appConfig,
|
||||
modelName: modelConfig.Name,
|
||||
backend: modelConfig.Backend,
|
||||
language: language,
|
||||
started: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *liveTraceState) addPCM(pcm []float32) {
|
||||
if ts == nil {
|
||||
return
|
||||
}
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.fedSamples += len(pcm)
|
||||
maxBytes := trace.MaxSnippetSeconds * liveSampleRate * 2
|
||||
if room := (maxBytes - len(ts.pcm)) / 2; room > 0 {
|
||||
if len(pcm) > room {
|
||||
pcm = pcm[:room]
|
||||
}
|
||||
ts.pcm = append(ts.pcm, sound.Float32sToInt16LEBytes(pcm)...)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *liveTraceState) observe(ev LiveTranscriptionEvent) {
|
||||
if ts == nil {
|
||||
return
|
||||
}
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
if ev.Delta != "" {
|
||||
ts.deltaEvents++
|
||||
}
|
||||
if ev.Eou {
|
||||
ts.eouEvents++
|
||||
}
|
||||
if ev.Eob {
|
||||
ts.eobEvents++
|
||||
}
|
||||
if ev.Final != nil {
|
||||
ts.finalText = ev.Final.Text
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *liveTraceState) record(closeErr error) {
|
||||
if ts == nil || !ts.appConfig.EnableTracing {
|
||||
return
|
||||
}
|
||||
ts.mu.Lock()
|
||||
data := map[string]any{
|
||||
"source": "live_stream",
|
||||
"language": ts.language,
|
||||
"result_text": ts.finalText,
|
||||
"eou_events": ts.eouEvents,
|
||||
"eob_events": ts.eobEvents,
|
||||
"delta_events": ts.deltaEvents,
|
||||
}
|
||||
if snippet := trace.AudioSnippetFromPCM(ts.pcm, liveSampleRate, ts.fedSamples*2, ts.appConfig.TracingMaxBodyBytes); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
summary := "live -> " + ts.finalText
|
||||
ts.mu.Unlock()
|
||||
|
||||
bt := trace.BackendTrace{
|
||||
Timestamp: ts.started,
|
||||
Duration: time.Since(ts.started),
|
||||
Type: trace.BackendTraceTranscription,
|
||||
ModelName: ts.modelName,
|
||||
Backend: ts.backend,
|
||||
Summary: trace.TruncateString(summary, 200),
|
||||
Data: data,
|
||||
}
|
||||
if closeErr != nil {
|
||||
bt.Error = closeErr.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(bt)
|
||||
}
|
||||
|
||||
// ModelTranscriptionLive loads the transcription backend, opens the
|
||||
// bidirectional AudioTranscriptionLive RPC, sends the session config, and
|
||||
// BLOCKS until the backend's ready ack. A grpcerrors.
|
||||
// IsLiveTranscriptionUnsupported error means the backend (or the loaded
|
||||
// model) cannot do live transcription and the caller should degrade to the
|
||||
// unary/file path. After a successful return, onEvent is invoked from a
|
||||
// background goroutine — in order, one event at a time — for every response
|
||||
// the backend streams, ending with the Final event triggered by Close.
|
||||
func ModelTranscriptionLive(ctx context.Context, language string,
|
||||
ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig,
|
||||
onEvent func(LiveTranscriptionEvent)) (LiveTranscriptionSession, error) {
|
||||
|
||||
transcriptionModel, err := loadTranscriptionModel(ctx, ml, modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The derived cancel out-lives this call inside the session: Close uses
|
||||
// it to unwind the stream (and, in embed mode, the server-side recv
|
||||
// pump, which only stops on send-close or context cancellation).
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
stream, err := transcriptionModel.AudioTranscriptionLive(streamCtx)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fail := func(err error) (LiveTranscriptionSession, error) {
|
||||
_ = stream.CloseSend()
|
||||
cancel()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := stream.Send(&proto.TranscriptLiveRequest{
|
||||
Payload: &proto.TranscriptLiveRequest_Config{Config: &proto.TranscriptLiveConfig{
|
||||
Language: language,
|
||||
SampleRate: liveSampleRate,
|
||||
}},
|
||||
}); err != nil {
|
||||
return fail(err)
|
||||
}
|
||||
|
||||
// Ready-ack contract: the backend answers a successful open with a
|
||||
// {ready:true} response before any transcript data; unsupported
|
||||
// backends surface Unimplemented here instead.
|
||||
ack, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fail(err)
|
||||
}
|
||||
if !ack.GetReady() {
|
||||
return fail(fmt.Errorf("live transcription: backend %q broke the ready-ack contract (first response carried data)", modelConfig.Backend))
|
||||
}
|
||||
|
||||
s := &liveTranscriptionSession{
|
||||
stream: stream,
|
||||
cancel: cancel,
|
||||
recvDone: make(chan struct{}),
|
||||
trace: newLiveTraceState(modelConfig, appConfig, language),
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(s.recvDone)
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) && streamCtx.Err() == nil {
|
||||
xlog.Warn("live transcription stream ended unexpectedly", "error", err)
|
||||
s.recvErr = err
|
||||
}
|
||||
return
|
||||
}
|
||||
ev := liveEventFromProto(resp)
|
||||
if ev.Delta == "" && !ev.Eou && !ev.Eob && len(ev.Words) == 0 && ev.Final == nil {
|
||||
continue // duplicate ready ack / keep-alive: nothing to deliver
|
||||
}
|
||||
s.trace.observe(ev)
|
||||
onEvent(ev)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func liveEventFromProto(r *proto.TranscriptLiveResponse) LiveTranscriptionEvent {
|
||||
ev := LiveTranscriptionEvent{
|
||||
Delta: r.GetDelta(),
|
||||
Eou: r.GetEou(),
|
||||
Eob: r.GetEob(),
|
||||
}
|
||||
for _, w := range r.GetWords() {
|
||||
ev.Words = append(ev.Words, schema.TranscriptionWord{
|
||||
Start: time.Duration(w.Start),
|
||||
End: time.Duration(w.End),
|
||||
Text: w.Text,
|
||||
})
|
||||
}
|
||||
if r.GetFinalResult() != nil {
|
||||
ev.Final = transcriptResultFromProto(r.GetFinalResult())
|
||||
}
|
||||
return ev
|
||||
}
|
||||
162
core/backend/transcript_live_internal_test.go
Normal file
162
core/backend/transcript_live_internal_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("liveEventFromProto", func() {
|
||||
It("maps deltas, eou flags and words (ns -> duration)", func() {
|
||||
ev := liveEventFromProto(&proto.TranscriptLiveResponse{
|
||||
Delta: "hello ",
|
||||
Eou: true,
|
||||
Words: []*proto.TranscriptWord{
|
||||
{Start: int64(100 * time.Millisecond), End: int64(400 * time.Millisecond), Text: "hello"},
|
||||
},
|
||||
})
|
||||
Expect(ev.Delta).To(Equal("hello "))
|
||||
Expect(ev.Eou).To(BeTrue())
|
||||
Expect(ev.Words).To(HaveLen(1))
|
||||
Expect(ev.Words[0].Text).To(Equal("hello"))
|
||||
Expect(ev.Words[0].Start).To(Equal(100 * time.Millisecond))
|
||||
Expect(ev.Words[0].End).To(Equal(400 * time.Millisecond))
|
||||
Expect(ev.Final).To(BeNil())
|
||||
})
|
||||
|
||||
It("maps the terminal final result including the eou flag", func() {
|
||||
ev := liveEventFromProto(&proto.TranscriptLiveResponse{
|
||||
FinalResult: &proto.TranscriptResult{
|
||||
Text: "hello world",
|
||||
Duration: 1.5,
|
||||
Eou: true,
|
||||
Segments: []*proto.TranscriptSegment{{Id: 0, Text: "hello world"}},
|
||||
},
|
||||
})
|
||||
Expect(ev.Final).NotTo(BeNil())
|
||||
Expect(ev.Final.Text).To(Equal("hello world"))
|
||||
Expect(ev.Final.Duration).To(BeNumerically("~", 1.5, 1e-6))
|
||||
Expect(ev.Final.Eou).To(BeTrue())
|
||||
Expect(ev.Final.Segments).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("yields an empty event for a bare ready ack (filtered by the recv loop)", func() {
|
||||
ev := liveEventFromProto(&proto.TranscriptLiveResponse{Ready: true})
|
||||
Expect(ev.Delta).To(BeEmpty())
|
||||
Expect(ev.Eou).To(BeFalse())
|
||||
Expect(ev.Words).To(BeEmpty())
|
||||
Expect(ev.Final).To(BeNil())
|
||||
})
|
||||
|
||||
It("maps the eob backchannel flag separately from eou", func() {
|
||||
ev := liveEventFromProto(&proto.TranscriptLiveResponse{Delta: "uh-huh", Eob: true})
|
||||
Expect(ev.Eob).To(BeTrue())
|
||||
Expect(ev.Eou).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
// liveTraceState is what makes streaming-only pipelines visible on the
|
||||
// Traces page: without it a semantic_vad session with retranscribe off
|
||||
// produced no transcription trace at all. One trace per session (= one per
|
||||
// realtime turn), recorded at Close.
|
||||
var _ = Describe("liveTraceState", func() {
|
||||
var appConfig *config.ApplicationConfig
|
||||
|
||||
BeforeEach(func() {
|
||||
appConfig = &config.ApplicationConfig{
|
||||
EnableTracing: true,
|
||||
TracingMaxItems: 64,
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.ClearBackendTraces()
|
||||
})
|
||||
|
||||
modelCfg := func() config.ModelConfig {
|
||||
cfg := config.ModelConfig{Backend: "parakeet-cpp"}
|
||||
cfg.Name = "parakeet-live"
|
||||
return cfg
|
||||
}
|
||||
|
||||
It("is disabled (nil) when tracing is off, and nil receivers are no-ops", func() {
|
||||
appConfig.EnableTracing = false
|
||||
ts := newLiveTraceState(modelCfg(), appConfig, "en")
|
||||
Expect(ts).To(BeNil())
|
||||
|
||||
// The session calls these unconditionally; nil must be safe.
|
||||
ts.addPCM([]float32{0.5})
|
||||
ts.observe(LiveTranscriptionEvent{Eou: true})
|
||||
ts.record(nil)
|
||||
Consistently(trace.GetBackendTraces, "100ms", "20ms").Should(BeEmpty())
|
||||
})
|
||||
|
||||
It("records one transcription trace with text, eou event counts and audio snippet at Close", func() {
|
||||
ts := newLiveTraceState(modelCfg(), appConfig, "en")
|
||||
Expect(ts).NotTo(BeNil())
|
||||
|
||||
// One second of a loud-ish constant tone so the snippet has signal.
|
||||
pcm := make([]float32, liveSampleRate)
|
||||
for i := range pcm {
|
||||
pcm[i] = 0.25
|
||||
}
|
||||
ts.addPCM(pcm)
|
||||
ts.observe(LiveTranscriptionEvent{Delta: "hello "})
|
||||
ts.observe(LiveTranscriptionEvent{Delta: "world", Eou: true})
|
||||
ts.observe(LiveTranscriptionEvent{Final: &schema.TranscriptionResult{Text: "hello world", Eou: true}})
|
||||
|
||||
ts.record(nil)
|
||||
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
got := trace.GetBackendTraces()[0]
|
||||
Expect(got.Type).To(Equal(trace.BackendTraceTranscription))
|
||||
Expect(got.ModelName).To(Equal("parakeet-live"))
|
||||
Expect(got.Backend).To(Equal("parakeet-cpp"))
|
||||
Expect(got.Summary).To(ContainSubstring("hello world"))
|
||||
Expect(got.Data["source"]).To(Equal("live_stream"))
|
||||
Expect(got.Data["result_text"]).To(Equal("hello world"))
|
||||
// The live FinalResult no longer carries a terminal eou flag; the
|
||||
// per-feed eou_events count is what the trace records instead.
|
||||
Expect(got.Data).NotTo(HaveKey("eou"))
|
||||
Expect(got.Data["eou_events"]).To(Equal(1))
|
||||
Expect(got.Data["delta_events"]).To(Equal(2))
|
||||
Expect(got.Data["audio_duration_s"]).To(BeNumerically("~", 1.0, 0.01))
|
||||
Expect(got.Data["audio_wav_base64"]).NotTo(BeEmpty())
|
||||
Expect(got.Error).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("caps the stored snippet but keeps counting the full fed duration", func() {
|
||||
ts := newLiveTraceState(modelCfg(), appConfig, "")
|
||||
|
||||
// Feed past the snippet cap in two chunks (cap + one extra second).
|
||||
ts.addPCM(make([]float32, trace.MaxSnippetSeconds*liveSampleRate))
|
||||
ts.addPCM(make([]float32, liveSampleRate))
|
||||
|
||||
Expect(len(ts.pcm)).To(Equal(trace.MaxSnippetSeconds * liveSampleRate * 2))
|
||||
Expect(ts.fedSamples).To(Equal((trace.MaxSnippetSeconds + 1) * liveSampleRate))
|
||||
|
||||
ts.record(nil)
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
got := trace.GetBackendTraces()[0]
|
||||
Expect(got.Data["audio_duration_s"]).To(BeNumerically("~", float64(trace.MaxSnippetSeconds+1), 0.01))
|
||||
Expect(got.Data["audio_snippet_s"]).To(BeNumerically("~", float64(trace.MaxSnippetSeconds), 0.01))
|
||||
})
|
||||
|
||||
It("clamps out-of-range float samples instead of wrapping", func() {
|
||||
ts := newLiveTraceState(modelCfg(), appConfig, "")
|
||||
ts.addPCM([]float32{2.0, -2.0})
|
||||
Expect(ts.pcm).To(Equal([]byte{0xff, 0x7f, 0x00, 0x80})) // 32767, -32768
|
||||
})
|
||||
|
||||
It("stamps the close error on the trace", func() {
|
||||
ts := newLiveTraceState(modelCfg(), appConfig, "")
|
||||
ts.record(errors.New("stream torn down"))
|
||||
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
Expect(trace.GetBackendTraces()[0].Error).To(Equal("stream torn down"))
|
||||
})
|
||||
})
|
||||
@@ -127,7 +127,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
|
||||
modelLoader := model.NewModelLoader(systemState)
|
||||
err = galleryop.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias, bi.RequireBackendIntegrity)
|
||||
err = galleryop.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias, false, bi.RequireBackendIntegrity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -241,12 +242,19 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||
Context: context.Background(),
|
||||
UploadLimitMB: 15,
|
||||
Debug: true,
|
||||
AgentJobRetentionDays: 30, // Default: 30 days
|
||||
LRUEvictionMaxRetries: 30, // Default: 30 retries
|
||||
LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
|
||||
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
|
||||
TracingMaxItems: 1024,
|
||||
TracingMaxBodyBytes: 64 * 1024, // 64 KiB - caps each request/response body in the trace buffer
|
||||
AgentJobRetentionDays: 30, // Default: 30 days
|
||||
LRUEvictionMaxRetries: 30, // Default: 30 retries
|
||||
LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
|
||||
// WatchDogInterval is intentionally left at the zero value here.
|
||||
// The startup loader applies a persisted runtime_settings.json value
|
||||
// only when the interval is still 0 (its "not set by env var"
|
||||
// heuristic, matching the idle/busy timeouts); a non-zero baseline
|
||||
// default would defeat that and silently revert a UI-saved Check
|
||||
// Interval to the default on every restart (#10601). The effective
|
||||
// 500ms default is supplied at the watchdog layer (DefaultWatchdogInterval)
|
||||
// when the value is still 0.
|
||||
TracingMaxItems: 1024,
|
||||
TracingMaxBodyBytes: 64 * 1024, // 64 KiB - caps each request/response body in the trace buffer
|
||||
AgentPool: AgentPoolConfig{
|
||||
Enabled: true,
|
||||
Timeout: "5m",
|
||||
@@ -1097,7 +1105,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
if o.WatchDogInterval > 0 {
|
||||
watchdogInterval = o.WatchDogInterval.String()
|
||||
} else {
|
||||
watchdogInterval = "2s" // default
|
||||
watchdogInterval = model.DefaultWatchdogInterval.String() // default: 500ms
|
||||
}
|
||||
var lruEvictionRetryInterval string
|
||||
if o.LRUEvictionRetryInterval > 0 {
|
||||
|
||||
@@ -67,6 +67,16 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
|
||||
ApplyMTPDefaults(cfg, n)
|
||||
}
|
||||
|
||||
// Sliding-window-attention models (Gemma 2/3, Cohere2, Llama 4, ...) ship
|
||||
// with a reduced SWA KV cache by default, which cannot reuse a prompt
|
||||
// prefix across requests and so defeats the cross-request prefix cache
|
||||
// (cache_reuse) we enable in serving_defaults.go. Enable the full SWA cache
|
||||
// for these models so the prefix survives; skipped for dense models and
|
||||
// when the user already pinned an SWA cache option.
|
||||
if w, ok := HasSlidingWindowAttention(f); ok {
|
||||
ApplySWAFullDefault(cfg, w)
|
||||
}
|
||||
|
||||
// Thinking support detection is done after model load via DetectThinkingSupportFromBackend
|
||||
|
||||
// template estimations
|
||||
|
||||
@@ -567,6 +567,38 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Advanced: true,
|
||||
Order: 83,
|
||||
},
|
||||
"pipeline.turn_detection.type": {
|
||||
Section: "pipeline",
|
||||
Label: "Turn Detection",
|
||||
Description: "Default turn-detection mode for realtime sessions on this pipeline. server_vad commits after a fixed silence window; semantic_vad lets the transcription model's end-of-utterance token drive a dynamic window (fast commit after the token, long eagerness fallback without it). semantic_vad requires a streaming-EOU transcription model (e.g. parakeet-cpp-realtime_eou_120m-v1) and degrades to silence-only otherwise. Clients can override per session via session.update.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "", Label: "Default (server_vad)"},
|
||||
{Value: "server_vad", Label: "server_vad (silence-based)"},
|
||||
{Value: "semantic_vad", Label: "semantic_vad (end-of-utterance token)"},
|
||||
},
|
||||
Order: 87,
|
||||
},
|
||||
"pipeline.turn_detection.eagerness": {
|
||||
Section: "pipeline",
|
||||
Label: "Eagerness",
|
||||
Description: "semantic_vad fallback silence window used when no end-of-utterance token was seen: low waits 8s, medium/auto 4s, high 2s.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "", Label: "Default (auto)"},
|
||||
{Value: "low", Label: "low (8s)"},
|
||||
{Value: "medium", Label: "medium (4s)"},
|
||||
{Value: "high", Label: "high (2s)"},
|
||||
},
|
||||
Order: 88,
|
||||
},
|
||||
"pipeline.turn_detection.retranscribe": {
|
||||
Section: "pipeline",
|
||||
Label: "Retranscribe on Commit",
|
||||
Description: "Cross-check every semantic_vad commit with an offline decode of the buffered turn: commit only proceeds when the batch decode also ends in the end-of-utterance token, and its transcript is used. Logs a streamed-vs-batch comparison — useful to gauge streaming/batch alignment — at the cost of one extra decode per turn.",
|
||||
Component: "toggle",
|
||||
Order: 89,
|
||||
},
|
||||
|
||||
// --- Functions ---
|
||||
"function.grammar.parallel_calls": {
|
||||
|
||||
@@ -650,6 +650,12 @@ type Pipeline struct {
|
||||
// VoiceRecognition gates the pipeline behind speaker verification. Nil
|
||||
// (block absent) means no gate, preserving existing behavior.
|
||||
VoiceRecognition *PipelineVoiceRecognition `yaml:"voice_recognition,omitempty" json:"voice_recognition,omitempty"`
|
||||
|
||||
// TurnDetection sets the server-side default turn-detection mode for
|
||||
// realtime sessions on this pipeline, so clients need no session.update
|
||||
// to benefit. A client session.update still overrides type and eagerness
|
||||
// per session; retranscribe is server-side only. Unset keeps server_vad.
|
||||
TurnDetection PipelineTurnDetection `yaml:"turn_detection,omitempty" json:"turn_detection,omitempty"`
|
||||
}
|
||||
|
||||
// PipelineCompaction configures summarize-then-drop for a realtime pipeline.
|
||||
@@ -934,6 +940,38 @@ func (v PipelineVoiceRecognition) Validate(registryAvailable bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// @Description PipelineTurnDetection sets realtime turn-detection defaults.
type PipelineTurnDetection struct {
	// Type is the default turn_detection mode for sessions on this pipeline.
	// "server_vad" is silence-based. "semantic_vad" lets the transcription
	// model's end-of-utterance token drive a dynamic silence window; it needs
	// a streaming-EOU transcription model such as
	// parakeet_realtime_eou_120m-v1 and degrades to silence-only otherwise.
	Type string `yaml:"type,omitempty" json:"type,omitempty"`

	// Eagerness is the semantic_vad fallback used when no end-of-utterance
	// token was seen: low waits 8s of silence, medium/auto 4s, high 2s.
	Eagerness string `yaml:"eagerness,omitempty" json:"eagerness,omitempty"`

	// Retranscribe (semantic_vad only) cross-checks every EOU-triggered
	// commit with an offline decode of the buffered turn. The commit only
	// proceeds when the batch decode also ends in the end-of-utterance
	// token, and the batch transcript is the one used. The streamed and
	// batch transcripts are compared in the logs — a diagnostic for
	// streaming/batch alignment — at the cost of one extra decode per turn.
	// A nil pointer means the option was not set.
	Retranscribe *bool `yaml:"retranscribe,omitempty" json:"retranscribe,omitempty"`
}
|
||||
|
||||
// TurnDetectionSemantic reports whether this pipeline defaults sessions to
|
||||
// semantic (EOU-driven) turn detection.
|
||||
func (p Pipeline) TurnDetectionSemantic() bool {
|
||||
return strings.EqualFold(strings.TrimSpace(p.TurnDetection.Type), "semantic_vad")
|
||||
}
|
||||
|
||||
// TurnDetectionRetranscribe reports whether semantic_vad commits should be
|
||||
// cross-checked (and transcribed) by an offline decode of the buffered turn.
|
||||
func (p Pipeline) TurnDetectionRetranscribe() bool {
|
||||
return p.TurnDetection.Retranscribe != nil && *p.TurnDetection.Retranscribe
|
||||
}
|
||||
|
||||
// @Description File configuration for model downloads
|
||||
type File struct {
|
||||
Filename string `yaml:"filename,omitempty" json:"filename,omitempty"`
|
||||
|
||||
61
core/config/pipeline_turn_detection_test.go
Normal file
61
core/config/pipeline_turn_detection_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// pipeline.turn_detection sets the server-side default turn-detection mode
|
||||
// for realtime sessions. Unset keeps server_vad, so existing configs are
|
||||
// unaffected; retranscribe is opt-in.
|
||||
var _ = Describe("Pipeline turn_detection config", func() {
|
||||
It("defaults to non-semantic with retranscribe off when unset", func() {
|
||||
var p Pipeline
|
||||
Expect(p.TurnDetectionSemantic()).To(BeFalse())
|
||||
Expect(p.TurnDetectionRetranscribe()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("parses the nested turn_detection block from YAML", func() {
|
||||
var c ModelConfig
|
||||
err := yaml.Unmarshal([]byte(`
|
||||
name: gpt-realtime
|
||||
pipeline:
|
||||
transcription: parakeet-cpp-realtime_eou_120m-v1
|
||||
turn_detection:
|
||||
type: semantic_vad
|
||||
eagerness: high
|
||||
retranscribe: true
|
||||
`), &c)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c.Pipeline.TurnDetectionSemantic()).To(BeTrue())
|
||||
Expect(c.Pipeline.TurnDetection.Eagerness).To(Equal("high"))
|
||||
Expect(c.Pipeline.TurnDetectionRetranscribe()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("treats server_vad and unknown types as non-semantic", func() {
|
||||
var p Pipeline
|
||||
p.TurnDetection.Type = "server_vad"
|
||||
Expect(p.TurnDetectionSemantic()).To(BeFalse())
|
||||
p.TurnDetection.Type = "something_else"
|
||||
Expect(p.TurnDetectionSemantic()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("matches semantic_vad case-insensitively with surrounding space", func() {
|
||||
var p Pipeline
|
||||
p.TurnDetection.Type = " Semantic_VAD "
|
||||
Expect(p.TurnDetectionSemantic()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("treats an explicit retranscribe false as off", func() {
|
||||
var c ModelConfig
|
||||
err := yaml.Unmarshal([]byte(`
|
||||
pipeline:
|
||||
turn_detection:
|
||||
type: semantic_vad
|
||||
retranscribe: false
|
||||
`), &c)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c.Pipeline.TurnDetectionRetranscribe()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
56
core/config/swa.go
Normal file
56
core/config/swa.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// swaCacheOptionNames lists the backend option keys that control the
|
||||
// sliding-window-attention KV cache. If the user pinned any of these we leave
|
||||
// the SWA cache alone instead of forcing swa_full.
|
||||
var swaCacheOptionNames = []string{"swa_full", "n_swa"}
|
||||
|
||||
// HasSlidingWindowAttention reports whether the parsed GGUF describes a
|
||||
// sliding-window-attention (SWA) model — Gemma 2/3, Cohere2, Llama 4 and the
|
||||
// like. The gguf-parser library normalizes the per-architecture
|
||||
// `<arch>.attention.sliding_window` metadata key into
|
||||
// GGUFArchitecture.AttentionSlidingWindow, applying the same family-specific
|
||||
// rules llama.cpp uses (e.g. Phi-3 carries the key but does not actually run
|
||||
// SWA, and is normalized to 0). A non-zero window means the model interleaves
|
||||
// SWA layers, so the returned size is also the diagnostic value we log.
|
||||
func HasSlidingWindowAttention(f *gguf.GGUFFile) (uint64, bool) {
|
||||
if f == nil {
|
||||
return 0, false
|
||||
}
|
||||
w := f.Architecture().AttentionSlidingWindow
|
||||
return w, w > 0
|
||||
}
|
||||
|
||||
// ApplySWAFullDefault enables the full-size SWA KV cache (swa_full:true) for a
|
||||
// sliding-window model, unless the user already pinned an SWA cache option.
|
||||
//
|
||||
// Why: llama.cpp defaults to a reduced SWA KV cache sized to the sliding window
|
||||
// (memory-light), but that reduced cache cannot preserve a prompt prefix across
|
||||
// requests. So for SWA models the cross-request prefix cache we enable in
|
||||
// serving_defaults.go (cache_reuse) is silently defeated — every turn
|
||||
// reprocesses the entire prompt. Setting swa_full:true makes llama.cpp keep the
|
||||
// full KV cache so the shared prefix is actually reused.
|
||||
//
|
||||
// The tradeoff is memory: the full SWA cache scales with context_size, so this
|
||||
// is gated to models that are genuinely SWA (never applied to dense models,
|
||||
// where it would only waste memory) and never overrides an explicit user
|
||||
// choice. `slidingWindow` is the value read from the GGUF and is used only for
|
||||
// the diagnostic log line.
|
||||
func ApplySWAFullDefault(cfg *ModelConfig, slidingWindow uint64) {
|
||||
if cfg == nil || slidingWindow == 0 {
|
||||
return
|
||||
}
|
||||
if backendOptionSet(cfg.Options, swaCacheOptionNames...) {
|
||||
xlog.Debug("[swa] sliding-window model but an SWA cache option is already set; leaving user choice intact",
|
||||
"name", cfg.Name, "sliding_window", slidingWindow)
|
||||
return
|
||||
}
|
||||
cfg.Options = append(cfg.Options, "swa_full:true")
|
||||
xlog.Debug("[swa] enabling swa_full for sliding-window model so the cross-request prompt-prefix cache survives (reduced SWA cache cannot reuse a prefix across requests)",
|
||||
"name", cfg.Name, "sliding_window", slidingWindow)
|
||||
}
|
||||
120
core/config/swa_test.go
Normal file
120
core/config/swa_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
. "github.com/mudler/LocalAI/core/config"
|
||||
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// ggufWithSlidingWindow fabricates a minimal in-memory GGUF carrying the given
|
||||
// `general.architecture` and `<arch>.attention.sliding_window` so the SWA
|
||||
// detection can be exercised without a real model file. A window of 0 omits the
|
||||
// key, modelling a dense (non-SWA) model.
|
||||
func ggufWithSlidingWindow(arch string, window uint32) *gguf.GGUFFile {
|
||||
kvs := gguf.GGUFMetadataKVs{
|
||||
{
|
||||
Key: "general.architecture",
|
||||
ValueType: gguf.GGUFMetadataValueTypeString,
|
||||
Value: arch,
|
||||
},
|
||||
}
|
||||
if window > 0 {
|
||||
kvs = append(kvs, gguf.GGUFMetadataKV{
|
||||
Key: arch + ".attention.sliding_window",
|
||||
ValueType: gguf.GGUFMetadataValueTypeUint32,
|
||||
Value: window,
|
||||
})
|
||||
}
|
||||
return &gguf.GGUFFile{
|
||||
Header: gguf.GGUFHeader{MetadataKV: kvs},
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("SWA full-cache auto-default", func() {
|
||||
Context("HasSlidingWindowAttention", func() {
|
||||
It("returns false on a nil GGUF file", func() {
|
||||
w, ok := HasSlidingWindowAttention(nil)
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(w).To(BeZero())
|
||||
})
|
||||
|
||||
It("detects a sliding-window model (Gemma 3 style)", func() {
|
||||
w, ok := HasSlidingWindowAttention(ggufWithSlidingWindow("gemma3", 1024))
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(w).To(Equal(uint64(1024)))
|
||||
})
|
||||
|
||||
It("detects Gemma 2 even without an explicit key (family default window)", func() {
|
||||
// gguf-parser applies llama.cpp's family rules: gemma2 defaults the
|
||||
// sliding window to 4096 when the metadata key is absent.
|
||||
w, ok := HasSlidingWindowAttention(ggufWithSlidingWindow("gemma2", 0))
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(w).To(Equal(uint64(4096)))
|
||||
})
|
||||
|
||||
It("reports a dense model as non-SWA", func() {
|
||||
w, ok := HasSlidingWindowAttention(ggufWithSlidingWindow("llama", 0))
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(w).To(BeZero())
|
||||
})
|
||||
|
||||
It("treats Phi-3 as non-SWA even when the key is present", func() {
|
||||
// Phi-3 carries attention.sliding_window but does not actually run
|
||||
// SWA; gguf-parser normalizes it to 0 to match llama.cpp.
|
||||
w, ok := HasSlidingWindowAttention(ggufWithSlidingWindow("phi3", 2048))
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(w).To(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
Context("ApplySWAFullDefault", func() {
|
||||
It("enables swa_full for a sliding-window model when unset", func() {
|
||||
cfg := &ModelConfig{Name: "gemma3"}
|
||||
ApplySWAFullDefault(cfg, 1024)
|
||||
Expect(cfg.Options).To(ContainElement("swa_full:true"))
|
||||
})
|
||||
|
||||
It("is a no-op for a dense model (window 0)", func() {
|
||||
cfg := &ModelConfig{Name: "llama"}
|
||||
ApplySWAFullDefault(cfg, 0)
|
||||
Expect(cfg.Options).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("preserves an explicit swa_full:false", func() {
|
||||
cfg := &ModelConfig{Name: "gemma3", Options: []string{"swa_full:false"}}
|
||||
ApplySWAFullDefault(cfg, 1024)
|
||||
Expect(cfg.Options).To(Equal([]string{"swa_full:false"}))
|
||||
})
|
||||
|
||||
It("preserves an explicit swa_full:true without duplicating it", func() {
|
||||
cfg := &ModelConfig{Name: "gemma3", Options: []string{"swa_full:true"}}
|
||||
ApplySWAFullDefault(cfg, 1024)
|
||||
Expect(cfg.Options).To(Equal([]string{"swa_full:true"}))
|
||||
})
|
||||
|
||||
It("respects the n_swa alias", func() {
|
||||
cfg := &ModelConfig{Name: "gemma3", Options: []string{"n_swa:512"}}
|
||||
ApplySWAFullDefault(cfg, 1024)
|
||||
Expect(cfg.Options).To(Equal([]string{"n_swa:512"}))
|
||||
})
|
||||
|
||||
It("preserves unrelated options already on the config", func() {
|
||||
cfg := &ModelConfig{
|
||||
Name: "gemma3",
|
||||
Options: []string{"use_jinja:true", "cache_reuse:256"},
|
||||
}
|
||||
ApplySWAFullDefault(cfg, 1024)
|
||||
Expect(cfg.Options).To(Equal([]string{
|
||||
"use_jinja:true",
|
||||
"cache_reuse:256",
|
||||
"swa_full:true",
|
||||
}))
|
||||
})
|
||||
|
||||
It("tolerates a nil config", func() {
|
||||
Expect(func() { ApplySWAFullDefault(nil, 1024) }).ToNot(Panic())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -101,7 +101,7 @@ func (i *DiffuserImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
Backend: backend,
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{
|
||||
Model: details.URI,
|
||||
Model: LocalModelPath(details.URI),
|
||||
},
|
||||
},
|
||||
Diffusers: config.Diffusers{
|
||||
|
||||
@@ -4,9 +4,24 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
)
|
||||
|
||||
// LocalModelPath normalizes a model URI for backends that treat the model
|
||||
// field as a HuggingFace repo id or local filesystem path (mlx, mlx-vlm,
|
||||
// vllm, transformers, diffusers). A "file://" import URI is reduced to the
|
||||
// bare path it points at: mlx-lm and vLLM otherwise mis-read the "file://"
|
||||
// scheme as a repo id and fail with "Repo id must be in the form
|
||||
// 'repo_name' or 'namespace/repo_name'" (issue #7461). HuggingFace and HTTP
|
||||
// URIs are returned unchanged so the existing remote-load path is untouched.
|
||||
func LocalModelPath(uri string) string {
|
||||
if path, ok := strings.CutPrefix(uri, downloader.LocalPrefix); ok {
|
||||
return path
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
||||
// HasFile returns true when any file in files has exactly the given basename.
|
||||
// Directory components in file.Path are ignored — a nested
|
||||
// "sub/dir/config.json" is considered a match for name = "config.json".
|
||||
|
||||
@@ -86,4 +86,21 @@ var _ = Describe("importer helpers", func() {
|
||||
Expect(importers.HasGGMLFile(files, "ggml-")).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LocalModelPath", func() {
|
||||
It("strips the file:// scheme from an absolute local path", func() {
|
||||
Expect(importers.LocalModelPath("file:///Users/u/.lmstudio/models/mlx-community/Qwen3-4bit")).
|
||||
To(Equal("/Users/u/.lmstudio/models/mlx-community/Qwen3-4bit"))
|
||||
})
|
||||
It("strips the file:// scheme from a relative local path", func() {
|
||||
Expect(importers.LocalModelPath("file://my-models/nvidia/Qwen3-30B-A3B-FP4")).
|
||||
To(Equal("my-models/nvidia/Qwen3-30B-A3B-FP4"))
|
||||
})
|
||||
It("leaves HuggingFace and HTTP URIs unchanged", func() {
|
||||
Expect(importers.LocalModelPath("https://huggingface.co/mlx-community/test-model")).
|
||||
To(Equal("https://huggingface.co/mlx-community/test-model"))
|
||||
Expect(importers.LocalModelPath("mlx-community/test-model")).
|
||||
To(Equal("mlx-community/test-model"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -22,11 +22,13 @@ var _ = Describe("DiscoverModelConfig", func() {
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.Name).To(Equal("LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
// No name preference + repo-root URI: the name follows the selected
|
||||
// GGUF file, not the repo (issue #10587).
|
||||
Expect(modelConfig.Name).To(Equal("localai-functioncall-qwen2.5-7b-v0.5-q4_k_m"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF/resolve/main/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].SHA256).To(Equal("4e7b7fe1d54b881f1ef90799219dc6cc285d29db24f559c8998d1addb35713d4"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
@@ -38,16 +40,17 @@ var _ = Describe("DiscoverModelConfig", func() {
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
// No name preference: name follows the selected model GGUF (issue #10587).
|
||||
Expect(modelConfig.Name).To(Equal("Qwen3VL-2B-Instruct-Q4_K_M"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/Qwen3-VL-2B-Instruct-GGUF/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/Qwen3-VL-2B-Instruct-GGUF/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/Qwen3VL-2B-Instruct-Q4_K_M/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/Qwen3VL-2B-Instruct-Q4_K_M/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/Qwen3-VL-2B-Instruct-GGUF/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/Qwen3VL-2B-Instruct-Q4_K_M/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/mmproj/Qwen3-VL-2B-Instruct-GGUF/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/mmproj/Qwen3VL-2B-Instruct-Q4_K_M/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
@@ -59,16 +62,17 @@ var _ = Describe("DiscoverModelConfig", func() {
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
// No name preference: name follows the selected Q8_0 model GGUF (issue #10587).
|
||||
Expect(modelConfig.Name).To(Equal("Qwen3VL-2B-Instruct-Q8_0"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/Qwen3-VL-2B-Instruct-GGUF/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/Qwen3-VL-2B-Instruct-GGUF/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/Qwen3VL-2B-Instruct-Q8_0/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/Qwen3VL-2B-Instruct-Q8_0/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/Qwen3-VL-2B-Instruct-GGUF/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/Qwen3VL-2B-Instruct-Q8_0/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/mmproj/Qwen3-VL-2B-Instruct-GGUF/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/mmproj/Qwen3VL-2B-Instruct-Q8_0/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
|
||||
@@ -98,8 +98,13 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
// nameProvided tracks whether the user supplied an explicit model name.
|
||||
// When they didn't, the URI base is only a fallback: for a HuggingFace
|
||||
// repo-root URI (no file component) it would be the repo name, so the HF
|
||||
// branch below re-derives the name from the selected GGUF file instead
|
||||
// (issue #10587).
|
||||
name, nameProvided := preferencesMap["name"].(string)
|
||||
if !nameProvided {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
@@ -227,10 +232,23 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
mmprojGroups := hfapi.GroupShards(mmprojFiles)
|
||||
ggufGroups := hfapi.GroupShards(ggufFiles)
|
||||
|
||||
modelGroup := pickPreferredGroup(ggufGroups, quants)
|
||||
|
||||
// A repo-root URI has no file component, so the URI-base fallback
|
||||
// above produced the repo name. When the user left the name blank,
|
||||
// derive it from the GGUF file actually selected from the listing so
|
||||
// the gallery entry and `model:` directory reflect the model, not the
|
||||
// repository (issue #10587). An explicit name preference always wins.
|
||||
if !nameProvided && modelGroup != nil {
|
||||
name = modelNameFromShardGroup(*modelGroup)
|
||||
modelConfig.Name = name
|
||||
cfg.Name = name
|
||||
}
|
||||
|
||||
// Emit the model group first so cfg.Files[0] is the model — callers
|
||||
// and tests rely on the model file preceding any mmproj companion.
|
||||
if group := pickPreferredGroup(ggufGroups, quants); group != nil {
|
||||
appendShardGroup(&cfg, *group, filepath.Join("llama-cpp", "models", name))
|
||||
if modelGroup != nil {
|
||||
appendShardGroup(&cfg, *modelGroup, filepath.Join("llama-cpp", "models", name))
|
||||
}
|
||||
if group := pickPreferredGroup(mmprojGroups, mmprojQuantsList); group != nil {
|
||||
appendShardGroup(&cfg, *group, filepath.Join("llama-cpp", "mmproj", name))
|
||||
@@ -281,6 +299,20 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// modelNameFromShardGroup derives a human-facing model name from the picked
|
||||
// GGUF group: the logical base filename with its .gguf extension stripped.
|
||||
// ShardGroup.Base is the common prefix for sharded sets (without the
|
||||
// -NNNNN-of-MMMMM suffix) and the sole basename for single-file models, so
|
||||
// this yields a clean name like "model-Q4_K_M" rather than an individual
|
||||
// shard filename or the repo-root URI base.
|
||||
func modelNameFromShardGroup(group hfapi.ShardGroup) string {
|
||||
base := group.Base
|
||||
if ext := filepath.Ext(base); strings.EqualFold(ext, ".gguf") {
|
||||
base = strings.TrimSuffix(base, ext)
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
// pickPreferredGroup walks the preference list in priority order and returns
|
||||
// the first group whose base filename contains any preference. When nothing
|
||||
// matches, the last group wins — this preserves the historical "if the user
|
||||
|
||||
@@ -372,6 +372,62 @@ var _ = Describe("LlamaCPPImporter", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("derives the model name from the selected GGUF when no name is given", func() {
|
||||
// Regression for #10587: a repo-root URI has no file component, so
|
||||
// the URI base ("example-GGUF") is just the repo name. With the
|
||||
// name field left blank, the emitted name and model directory must
|
||||
// follow the GGUF file actually selected, not the repository.
|
||||
details := withHF(`{"quantizations":"Q4_K_M"}`,
|
||||
hfFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf", "aaa"),
|
||||
hfFile("Meta-Llama-3-8B-Instruct.Q3_K_M.gguf", "bbb"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("Meta-Llama-3-8B-Instruct.Q4_K_M"))
|
||||
Expect(modelConfig.Files).To(HaveLen(1), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal(
|
||||
"llama-cpp/models/Meta-Llama-3-8B-Instruct.Q4_K_M/Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("name: Meta-Llama-3-8B-Instruct.Q4_K_M"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring(
|
||||
"model: llama-cpp/models/Meta-Llama-3-8B-Instruct.Q4_K_M/Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"))
|
||||
})
|
||||
|
||||
It("derives a clean name from the shard base for split GGUFs when no name is given", func() {
|
||||
// The selected primary file is shard 1; using its raw basename
|
||||
// would leak the -00001-of-00002 suffix into the name. The shard
|
||||
// base must be used so the name is the logical model.
|
||||
details := withHF(``,
|
||||
hfFile("Qwen3-30B-A3B-Q4_K_M-00001-of-00002.gguf", "p1"),
|
||||
hfFile("Qwen3-30B-A3B-Q4_K_M-00002-of-00002.gguf", "p2"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("Qwen3-30B-A3B-Q4_K_M"))
|
||||
Expect(modelConfig.Files).To(HaveLen(2), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal(
|
||||
"llama-cpp/models/Qwen3-30B-A3B-Q4_K_M/Qwen3-30B-A3B-Q4_K_M-00001-of-00002.gguf"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring(
|
||||
"model: llama-cpp/models/Qwen3-30B-A3B-Q4_K_M/Qwen3-30B-A3B-Q4_K_M-00001-of-00002.gguf"))
|
||||
})
|
||||
|
||||
It("keeps an explicit name over the selected GGUF filename", func() {
|
||||
// Precedence guard: when the user supplies a name it always wins,
|
||||
// even though a GGUF file was selected from the listing.
|
||||
details := withHF(`{"name":"my-custom-name","quantizations":"Q4_K_M"}`,
|
||||
hfFile("model-Q4_K_M.gguf", "aaa"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("my-custom-name"))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/my-custom-name/model-Q4_K_M.gguf"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("quant token boundary matching", func() {
|
||||
|
||||
@@ -87,7 +87,7 @@ func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
Backend: backend,
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{
|
||||
Model: details.URI,
|
||||
Model: LocalModelPath(details.URI),
|
||||
},
|
||||
},
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
|
||||
@@ -198,5 +198,24 @@ var _ = Describe("MLXImporter", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("model"))
|
||||
})
|
||||
|
||||
It("should emit a bare filesystem path for a file:// local import", func() {
|
||||
// Regression for #7461: a model imported from a local directory
|
||||
// (e.g. LM Studio's store) must not carry the file:// scheme into
|
||||
// the model field — mlx-lm rejects it as an invalid repo id.
|
||||
preferences := json.RawMessage(`{"backend": "mlx"}`)
|
||||
details := importers.Details{
|
||||
URI: "file:///Users/u/.lmstudio/models/mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("Qwen3-Coder-30B-A3B-Instruct-4bit"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring(
|
||||
"model: /Users/u/.lmstudio/models/mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit"))
|
||||
Expect(modelConfig.ConfigFile).ToNot(ContainSubstring("model: file://"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -91,7 +91,7 @@ func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, err
|
||||
Backend: backend,
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{
|
||||
Model: details.URI,
|
||||
Model: LocalModelPath(details.URI),
|
||||
},
|
||||
},
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
|
||||
@@ -81,7 +81,7 @@ func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
Backend: backend,
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{
|
||||
Model: details.URI,
|
||||
Model: LocalModelPath(details.URI),
|
||||
},
|
||||
},
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
|
||||
@@ -177,5 +177,22 @@ var _ = Describe("VLLMImporter", func() {
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
|
||||
})
|
||||
|
||||
It("should emit a bare filesystem path for a file:// local import", func() {
|
||||
// Regression for #7461: vLLM rejects a file:// model field as an
|
||||
// invalid repo id, so a locally-imported model must carry the bare
|
||||
// path instead.
|
||||
preferences := json.RawMessage(`{"backend": "vllm"}`)
|
||||
details := Details{
|
||||
URI: "file://my-models/nvidia/Qwen3-30B-A3B-FP4",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: my-models/nvidia/Qwen3-30B-A3B-FP4"))
|
||||
Expect(modelConfig.ConfigFile).ToNot(ContainSubstring("model: file://"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -65,6 +65,10 @@ type BackendEndpointService struct {
|
||||
|
||||
type GalleryBackend struct {
|
||||
ID string `json:"id"`
|
||||
// Force reinstalls the backend even when it is already installed and
|
||||
// runnable. Off by default so apply stays idempotent for supervising
|
||||
// apps that ensure their backend on every boot.
|
||||
Force bool `json:"force"`
|
||||
}
|
||||
|
||||
func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *galleryop.GalleryService, upgradeChecker UpgradeInfoProvider) BackendEndpointService {
|
||||
@@ -103,7 +107,9 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyBackendEndpoint installs a new backend to a LocalAI instance
|
||||
// ApplyBackendEndpoint installs a new backend to a LocalAI instance. The op is
|
||||
// idempotent: an already-installed, runnable backend is left alone unless the
|
||||
// request sets "force": true (explicit reinstall).
|
||||
// @Summary Install backends to LocalAI.
|
||||
// @Tags backends
|
||||
// @Param request body GalleryBackend true "query params"
|
||||
@@ -137,6 +143,7 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint(systemState *system.Syst
|
||||
ID: uuid.String(),
|
||||
GalleryElementName: input.ID,
|
||||
Galleries: mgs.galleries,
|
||||
Force: input.Force,
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
|
||||
|
||||
87
core/http/endpoints/localai/backend_apply_test.go
Normal file
87
core/http/endpoints/localai/backend_apply_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// POST /backends/apply must be idempotent by default: supervising apps call it
|
||||
// on every boot to ensure a backend exists, and forcing a reinstall there
|
||||
// re-downloads the whole artifact each time. Reinstall stays available behind
|
||||
// the explicit force flag.
|
||||
var _ = Describe("POST /backends/apply force plumbing", func() {
|
||||
var (
|
||||
app *echo.Echo
|
||||
gs *galleryop.GalleryService
|
||||
tmpDir string
|
||||
received chan galleryop.ManagementOp[gallery.GalleryBackend, any]
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
app = echo.New()
|
||||
|
||||
var err error
|
||||
tmpDir, err = os.MkdirTemp("", "backends-apply-test-*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
systemState, err := system.GetSystemState(system.WithBackendPath(tmpDir))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
appConfig := &config.ApplicationConfig{SystemState: systemState}
|
||||
|
||||
// The service is deliberately not started: the test reads the op off
|
||||
// the (unbuffered) channel itself.
|
||||
gs = galleryop.NewGalleryService(appConfig, model.NewModelLoader(systemState))
|
||||
svc := CreateBackendEndpointService(nil, systemState, gs, nil)
|
||||
app.POST("/backends/apply", svc.ApplyBackendEndpoint(systemState))
|
||||
|
||||
received = make(chan galleryop.ManagementOp[gallery.GalleryBackend, any], 1)
|
||||
go func() {
|
||||
op := <-gs.BackendGalleryChannel
|
||||
received <- op
|
||||
}()
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(os.RemoveAll(tmpDir)).To(Succeed())
|
||||
})
|
||||
|
||||
apply := func(body string) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest(http.MethodPost, "/backends/apply", strings.NewReader(body))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
return rec
|
||||
}
|
||||
|
||||
It("enqueues a non-forced op by default", func() {
|
||||
rec := apply(`{"id":"llama-cpp"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var op galleryop.ManagementOp[gallery.GalleryBackend, any]
|
||||
Eventually(received).Should(Receive(&op))
|
||||
Expect(op.GalleryElementName).To(Equal("llama-cpp"))
|
||||
Expect(op.Force).To(BeFalse())
|
||||
})
|
||||
|
||||
It("enqueues a forced op when the request sets force", func() {
|
||||
rec := apply(`{"id":"llama-cpp","force":true}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var op galleryop.ManagementOp[gallery.GalleryBackend, any]
|
||||
Eventually(received).Should(Receive(&op))
|
||||
Expect(op.GalleryElementName).To(Equal("llama-cpp"))
|
||||
Expect(op.Force).To(BeTrue())
|
||||
})
|
||||
})
|
||||
@@ -618,6 +618,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
finishReason = FinishReasonToolCalls
|
||||
} else if toolsCalled {
|
||||
finishReason = FinishReasonFunctionCall
|
||||
} else if reachedTokenBudget(finalUsage.Completion, config.Maxtokens) {
|
||||
// Generation stopped because it hit the max_tokens ceiling
|
||||
// rather than a natural stop — report "length" (issue #9716).
|
||||
finishReason = FinishReasonLength
|
||||
}
|
||||
|
||||
// Final delta chunk: empty delta with finish_reason set. Per
|
||||
@@ -984,6 +988,18 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
}
|
||||
|
||||
// If generation hit the max_tokens ceiling, report "length"
|
||||
// instead of a natural "stop" (issue #9716). Mirrors the
|
||||
// streaming path; tool/function finish reasons are untouched.
|
||||
if reachedTokenBudget(tokenUsage.Completion, config.Maxtokens) {
|
||||
for i := range result {
|
||||
if result[i].FinishReason != nil && *result[i].FinishReason == FinishReasonStop {
|
||||
lengthReason := FinishReasonLength
|
||||
result[i].FinishReason = &lengthReason
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No MCP tools to execute (or no MCP tools configured), return response
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
|
||||
149
core/http/endpoints/openai/compactcoord/compactcoord.go
Normal file
149
core/http/endpoints/openai/compactcoord/compactcoord.go
Normal file
@@ -0,0 +1,149 @@
|
||||
// Package compactcoord is the explicit state machine for the realtime API's
|
||||
// conversation-compaction concern (machine "M4" in
|
||||
// docs/design/realtime-state-machines.md).
|
||||
//
|
||||
// In the legacy code this machine is an implicit single-flight guard: a
|
||||
// per-conversation `compacting atomic.Bool` that maybeCompact CAS-flips to start
|
||||
// a background summarize+evict and a deferred Store(false) clears. The intent —
|
||||
// at most one compaction running per conversation at a time, so two goroutines
|
||||
// never summarize and evict the same overflow concurrently (Part 4, invariant
|
||||
// #9) — is correct but implicit in a bare atomic.
|
||||
//
|
||||
// This package makes it explicit:
|
||||
// - a sealed sum type for State (Idle | Running | Terminated) — "two compactions running" is
|
||||
// unrepresentable,
|
||||
// - a total, pure transition function Next(state, event) -> (state, effects),
|
||||
// - a single-writer Coordinator that serializes every transition.
|
||||
//
|
||||
// Unlike respcoord (M3), a Trigger while Running is NOT a supersede: compaction
|
||||
// is idempotent work on the same overflow, so a concurrent trigger is simply
|
||||
// dropped (matching the legacy CAS-fails-so-skip), not queued or restarted.
|
||||
package compactcoord
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/coordinator"
|
||||
)
|
||||
|
||||
// State is the sealed sum type of compaction states. Exhaustively:
|
||||
// Idle | Running | Terminated.
|
||||
type State interface {
|
||||
isState()
|
||||
String() string
|
||||
}
|
||||
|
||||
// Idle: no compaction is running.
|
||||
type Idle struct{}
|
||||
|
||||
// Running: exactly one compaction is in flight.
|
||||
type Running struct{}
|
||||
|
||||
// Terminated: the conversation/session is torn down. Absorbing — no compaction
|
||||
// can start from here, so the M1 (connection) parent's teardown can cancel +
|
||||
// join the in-flight compaction and guarantee none outlives the session (see
|
||||
// formal-verification/session_lifecycle.fizz). This closes the legacy gap where
|
||||
// the fire-and-forget compaction goroutine could outlive the session.
|
||||
type Terminated struct{}
|
||||
|
||||
func (Idle) isState() {}
|
||||
func (Running) isState() {}
|
||||
func (Terminated) isState() {}
|
||||
|
||||
func (Idle) String() string { return "Idle" }
|
||||
func (Running) String() string { return "Running" }
|
||||
func (Terminated) String() string { return "Terminated" }
|
||||
|
||||
// Event is the sealed sum type of inputs. Exhaustively:
|
||||
// Trigger | Finished | Shutdown.
|
||||
type Event interface {
|
||||
isEvent()
|
||||
String() string
|
||||
}
|
||||
|
||||
// Trigger requests a compaction (the live buffer grew past the trigger). It
|
||||
// starts one only when Idle; while Running it is a no-op (single-flight).
|
||||
type Trigger struct{}
|
||||
|
||||
// Finished reports that the running compaction goroutine finished (success, error, or
|
||||
// timeout — it always reports Finished so the flag can never stick).
|
||||
type Finished struct{}
|
||||
|
||||
// Shutdown terminates the coordinator at teardown: the in-flight compaction is
|
||||
// cancelled + joined by the sink, and no compaction can start afterwards.
|
||||
type Shutdown struct{}
|
||||
|
||||
func (Trigger) isEvent() {}
|
||||
func (Finished) isEvent() {}
|
||||
func (Shutdown) isEvent() {}
|
||||
|
||||
func (Trigger) String() string { return "Trigger" }
|
||||
func (Finished) String() string { return "Finished" }
|
||||
func (Shutdown) String() string { return "Shutdown" }
|
||||
|
||||
// Effect is a side effect returned by Next as data. Exhaustively: StartCompaction.
|
||||
type Effect interface {
|
||||
isEffect()
|
||||
String() string
|
||||
}
|
||||
|
||||
// StartCompaction: spawn the background summarize+evict goroutine.
|
||||
type StartCompaction struct{}
|
||||
|
||||
func (StartCompaction) isEffect() {}
|
||||
|
||||
func (StartCompaction) String() string { return "StartCompaction" }
|
||||
|
||||
// Next is the total, pure transition function. For every (state, event) it
|
||||
// returns the next state and the ordered effects. It returns a non-nil error
|
||||
// only for an unknown State/Event implementation. Every in-domain pair is
|
||||
// defined; there are no forbidden transitions, only no-ops.
|
||||
//
|
||||
// Single-flight crux: StartCompaction is emitted only on Idle+Trigger, and a
|
||||
// Trigger while Running is a no-op — so at most one compaction ever runs.
|
||||
func Next(s State, e Event) (State, []Effect, error) {
|
||||
switch s.(type) {
|
||||
case Idle:
|
||||
switch e.(type) {
|
||||
case Trigger:
|
||||
return Running{}, []Effect{StartCompaction{}}, nil
|
||||
case Finished:
|
||||
// No compaction to finish: stale/idempotent no-op.
|
||||
return Idle{}, nil, nil
|
||||
case Shutdown:
|
||||
return Terminated{}, nil, nil
|
||||
}
|
||||
case Running:
|
||||
switch e.(type) {
|
||||
case Trigger:
|
||||
// Already compacting: drop (single-flight).
|
||||
return Running{}, nil, nil
|
||||
case Finished:
|
||||
return Idle{}, nil, nil
|
||||
case Shutdown:
|
||||
// Teardown while compacting: the sink cancels + joins the goroutine,
|
||||
// so its later Finished is absorbed here in Terminated.
|
||||
return Terminated{}, nil, nil
|
||||
}
|
||||
case Terminated:
|
||||
// Absorbing: a Trigger after teardown is rejected (no StartCompaction), so
|
||||
// no compaction outlives the session.
|
||||
switch e.(type) {
|
||||
case Trigger, Finished, Shutdown:
|
||||
return Terminated{}, nil, nil
|
||||
}
|
||||
}
|
||||
return s, nil, fmt.Errorf("compactcoord: unhandled transition %s <- %s", s, e)
|
||||
}
|
||||
|
||||
// EffectSink performs the effects produced by a transition. See coordinator.Sink:
|
||||
// StartCompaction spawns a goroutine, so Perform does not block under the lock.
|
||||
type EffectSink = coordinator.Sink[Effect]
|
||||
|
||||
// Coordinator serializes the compaction transitions. See coordinator.Coordinator.
|
||||
type Coordinator = coordinator.Coordinator[State, Event, Effect]
|
||||
|
||||
// New returns an idle Coordinator that performs effects via sink.
|
||||
func New(sink EffectSink) *Coordinator {
|
||||
return coordinator.New[State, Event, Effect](Idle{}, Next, sink)
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package compactcoord
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestCompactcoord(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "compactcoord (realtime M4) Suite")
|
||||
}
|
||||
202
core/http/endpoints/openai/compactcoord/compactcoord_test.go
Normal file
202
core/http/endpoints/openai/compactcoord/compactcoord_test.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package compactcoord
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// recordingSink captures the ordered stream of effects. Perform is called under
|
||||
// the coordinator lock; the mutex here guards reads from the spec goroutine.
|
||||
type recordingSink struct {
|
||||
mu sync.Mutex
|
||||
log []Effect
|
||||
}
|
||||
|
||||
func (s *recordingSink) Perform(e Effect) {
|
||||
s.mu.Lock()
|
||||
s.log = append(s.log, e)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *recordingSink) count() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return len(s.log)
|
||||
}
|
||||
|
||||
// unknownEvent has the marker and String methods the specs need to pass it to
// Next as an event, but it is not one of the machine's events; it drives the
// error path of Next.
type unknownEvent struct{}

func (unknownEvent) isEvent() {}

func (unknownEvent) String() string { return "unknownEvent" }

// unknownState is the state-side counterpart of unknownEvent: accepted by
// Next's signature, handled by none of its cases.
type unknownState struct{}

func (unknownState) isState() {}

func (unknownState) String() string { return "unknownState" }
|
||||
|
||||
var _ = Describe("compactcoord.Next", func() {
|
||||
DescribeTable("transitions",
|
||||
func(state State, event Event, wantState State, wantEff []Effect) {
|
||||
gotState, gotEff, err := Next(state, event)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(gotState).To(Equal(wantState))
|
||||
Expect(gotEff).To(Equal(wantEff))
|
||||
},
|
||||
Entry("idle+trigger -> running: start",
|
||||
Idle{}, Trigger{}, Running{}, []Effect{StartCompaction{}}),
|
||||
Entry("idle+finished -> idle, no-op (stale)",
|
||||
Idle{}, Finished{}, Idle{}, []Effect(nil)),
|
||||
Entry("running+trigger -> running, no-op (single-flight)",
|
||||
Running{}, Trigger{}, Running{}, []Effect(nil)),
|
||||
Entry("running+finished -> idle",
|
||||
Running{}, Finished{}, Idle{}, []Effect(nil)),
|
||||
Entry("idle+shutdown -> terminated",
|
||||
Idle{}, Shutdown{}, Terminated{}, []Effect(nil)),
|
||||
Entry("running+shutdown -> terminated",
|
||||
Running{}, Shutdown{}, Terminated{}, []Effect(nil)),
|
||||
Entry("terminated+trigger -> terminated, REJECTED",
|
||||
Terminated{}, Trigger{}, Terminated{}, []Effect(nil)),
|
||||
Entry("terminated+finished -> terminated, no-op (stale)",
|
||||
Terminated{}, Finished{}, Terminated{}, []Effect(nil)),
|
||||
Entry("terminated+shutdown -> terminated, idempotent",
|
||||
Terminated{}, Shutdown{}, Terminated{}, []Effect(nil)),
|
||||
)
|
||||
|
||||
It("is total over the defined (state, event) pairs", func() {
|
||||
for _, s := range []State{Idle{}, Running{}, Terminated{}} {
|
||||
for _, e := range []Event{Trigger{}, Finished{}, Shutdown{}} {
|
||||
_, _, err := Next(s, e)
|
||||
Expect(err).NotTo(HaveOccurred(), "Next(%s, %s)", s, e)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("errors on an unknown event type", func() {
|
||||
_, _, err := Next(Idle{}, unknownEvent{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors on an unknown state type", func() {
|
||||
_, _, err := Next(unknownState{}, Trigger{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("compactcoord.Coordinator", func() {
|
||||
// A StartCompaction is only ever produced while Idle (verified by checking the
|
||||
// effect count grows exactly when the model transitions Idle->Running), so at
|
||||
// most one compaction is ever in flight.
|
||||
It("starts at most one compaction at a time over random sequences", func() {
|
||||
seeds := []uint64{1, 2, 3, 42, 1337, 0xC0FFEE}
|
||||
for _, seed := range seeds {
|
||||
r := rand.New(rand.NewPCG(seed, 0xA5A5A5A5))
|
||||
sink := &recordingSink{}
|
||||
c := New(sink)
|
||||
running := false
|
||||
starts := 0
|
||||
|
||||
for range 5000 {
|
||||
if r.IntN(2) == 0 {
|
||||
before := sink.count()
|
||||
Expect(c.Apply(Trigger{})).To(Succeed())
|
||||
if sink.count() > before {
|
||||
// A StartCompaction was produced: must have been Idle.
|
||||
Expect(running).To(BeFalse(), "seed=%d: started while already running", seed)
|
||||
running = true
|
||||
starts++
|
||||
}
|
||||
} else {
|
||||
Expect(c.Apply(Finished{})).To(Succeed())
|
||||
running = false
|
||||
}
|
||||
if running {
|
||||
Expect(c.State()).To(Equal(State(Running{})), "seed=%d", seed)
|
||||
} else {
|
||||
Expect(c.State()).To(Equal(State(Idle{})), "seed=%d", seed)
|
||||
}
|
||||
}
|
||||
Expect(starts).To(BeNumerically(">", 0), "seed=%d: walk should have started at least one", seed)
|
||||
}
|
||||
})
|
||||
|
||||
// Faithful concurrent test: StartCompaction spawns "work" that bumps an active
|
||||
// counter, runs, and reports Finished back to the coordinator (exactly how the
|
||||
// real sink behaves). Single-flight must hold even under many concurrent
|
||||
// Triggers: the active counter never exceeds 1. Run under -race.
|
||||
It("never runs two compactions concurrently", func() {
|
||||
var active, maxActive int32
|
||||
var c *Coordinator
|
||||
var work sync.WaitGroup
|
||||
sink := &spawnSink{onStart: func() {
|
||||
work.Add(1)
|
||||
go func() {
|
||||
defer work.Done()
|
||||
n := atomic.AddInt32(&active, 1)
|
||||
for {
|
||||
m := atomic.LoadInt32(&maxActive)
|
||||
if n <= m || atomic.CompareAndSwapInt32(&maxActive, m, n) {
|
||||
break
|
||||
}
|
||||
}
|
||||
atomic.AddInt32(&active, -1)
|
||||
_ = c.Apply(Finished{})
|
||||
}()
|
||||
}}
|
||||
c = New(sink)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for g := 0; g < 8; g++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range 1000 {
|
||||
_ = c.Apply(Trigger{})
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
work.Wait() // let any in-flight compaction report Finished
|
||||
|
||||
Expect(atomic.LoadInt32(&maxActive)).To(BeNumerically("<=", 1))
|
||||
Expect(c.State()).To(Equal(State(Idle{})))
|
||||
})
|
||||
|
||||
It("terminates on shutdown and rejects later triggers", func() {
|
||||
sink := &recordingSink{}
|
||||
c := New(sink)
|
||||
Expect(c.Apply(Trigger{})).To(Succeed()) // Idle -> Running (StartCompaction)
|
||||
Expect(c.Apply(Shutdown{})).To(Succeed())
|
||||
Expect(c.State()).To(Equal(State(Terminated{})))
|
||||
|
||||
before := sink.count()
|
||||
Expect(c.Apply(Trigger{})).To(Succeed()) // rejected
|
||||
Expect(sink.count()).To(Equal(before), "no StartCompaction after shutdown")
|
||||
Expect(c.Apply(Finished{})).To(Succeed()) // stale, absorbed
|
||||
Expect(c.State()).To(Equal(State(Terminated{})))
|
||||
})
|
||||
})
|
||||
|
||||
// spawnSink invokes onStart for each StartCompaction (called under the coord lock;
|
||||
// onStart must be non-blocking — it spawns the work goroutine).
|
||||
type spawnSink struct{ onStart func() }
|
||||
|
||||
func (s *spawnSink) Perform(e Effect) {
|
||||
if _, ok := e.(StartCompaction); ok {
|
||||
s.onStart()
|
||||
}
|
||||
}
|
||||
|
||||
var _ = DescribeTable("compactcoord stringers",
|
||||
func(got, want string) { Expect(got).To(Equal(want)) },
|
||||
Entry(nil, Idle{}.String(), "Idle"),
|
||||
Entry(nil, Running{}.String(), "Running"),
|
||||
Entry(nil, Terminated{}.String(), "Terminated"),
|
||||
Entry(nil, Trigger{}.String(), "Trigger"),
|
||||
Entry(nil, Finished{}.String(), "Finished"),
|
||||
Entry(nil, Shutdown{}.String(), "Shutdown"),
|
||||
Entry(nil, StartCompaction{}.String(), "StartCompaction"),
|
||||
)
|
||||
164
core/http/endpoints/openai/conncoord/conncoord.go
Normal file
164
core/http/endpoints/openai/conncoord/conncoord.go
Normal file
@@ -0,0 +1,164 @@
|
||||
// Package conncoord is the explicit state machine for the realtime API's
|
||||
// connection lifecycle (machine "M1" in docs/design/realtime-state-machines.md).
|
||||
//
|
||||
// In the legacy code this machine is implicit and fragile. The session handler
|
||||
// keeps a `vadServerStarted` bool plus a `done` channel that is REASSIGNED to a
|
||||
// fresh channel every time turn detection is toggled on (session.update) and
|
||||
// closed both at toggle-off and at teardown (Part 2, failure mode 6). It is
|
||||
// correct today only because one goroutine owns it; "one variable name meaning
|
||||
// different channels over time, closed from two sites guarded by a bool" is a
|
||||
// structural hazard, not an explicit lifecycle. Teardown likewise depends on the
|
||||
// bool to avoid closing an already-closed channel.
|
||||
//
|
||||
// This package makes the lifecycle explicit:
|
||||
// - a sealed sum type for State (Live{VADRunning} | Torn) — illegal states
|
||||
// such as "running after teardown" are unrepresentable,
|
||||
// - a total, pure transition function Next(state, event) -> (state, effects),
|
||||
// - a single-writer Coordinator that serializes every transition.
|
||||
//
|
||||
// The guarantees the spec checks:
|
||||
// - the VAD goroutine's done channel is closed exactly once per start (StopVAD
|
||||
// is emitted only while running, so never a double close / close of nil),
|
||||
// - teardown runs exactly once (Close from Live; any later Close is a no-op),
|
||||
// - nothing is started after teardown (no resurrection / no send-after-close).
|
||||
//
|
||||
// Like turncoord (M2), the connection machine is driven by the single session
|
||||
// goroutine; the Coordinator's lock keeps State() race-free and guards against a
|
||||
// future second writer. The effects are performed by a sink that owns the actual
|
||||
// channels/goroutines (see realtime_conncoord.go).
|
||||
package conncoord
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/coordinator"
|
||||
)
|
||||
|
||||
// State is the sealed sum type of connection states: Live | Torn. The
// unexported marker method restricts implementations to the structs declared
// in this file.
type State interface {
	fmt.Stringer
	isState()
}

// Live means the session is active. VADRunning records whether the
// turn-detection (handleVAD) goroutine is currently running; it is the single
// source of truth that replaces the legacy vadServerStarted bool, so the
// per-run done channel is closed exactly once.
type Live struct{ VADRunning bool }

func (Live) isState() {}

// String renders the state together with its VAD flag, e.g. "Live(vad=true)".
func (l Live) String() string { return fmt.Sprintf("Live(vad=%t)", l.VADRunning) }

// Torn means the session has been torn down. It is terminal: no effect is ever
// produced from here again.
type Torn struct{}

func (Torn) isState() {}

func (Torn) String() string { return "Torn" }
|
||||
|
||||
// Event is the sealed sum type of inputs to the machine: SetVAD | Close.
type Event interface {
	fmt.Stringer
	isEvent()
}

// SetVAD requests that the turn-detection goroutine be running (Active) or
// stopped. It is raised whenever session.update changes whether turn detection
// is active, and it is idempotent: asking for the state the machine is already
// in is a no-op.
type SetVAD struct{ Active bool }

func (SetVAD) isEvent() {}

// String renders the event with the requested flag, e.g. "SetVAD(true)".
func (v SetVAD) String() string { return fmt.Sprintf("SetVAD(%t)", v.Active) }

// Close requests teardown (the transport read loop ended, or the session is
// closing). It is idempotent: only the first Close from Live tears down.
type Close struct{}

func (Close) isEvent() {}

func (Close) String() string { return "Close" }
|
||||
|
||||
// Effect is a side effect that Next returns as data for the caller to perform:
// StartVAD | StopVAD | Teardown.
type Effect interface {
	fmt.Stringer
	isEffect()
}

// StartVAD: create a fresh done channel and spawn the handleVAD goroutine on
// it.
type StartVAD struct{}

func (StartVAD) isEffect() {}

func (StartVAD) String() string { return "StartVAD" }

// StopVAD: close the running VAD goroutine's done channel, signalling it to
// exit.
type StopVAD struct{}

func (StopVAD) isEffect() {}

func (StopVAD) String() string { return "StopVAD" }

// Teardown: the once-only teardown -- stop the remaining input goroutines (opus
// decode, sound window), join them, cancel in-flight responses, and remove the
// session from the registry. Emitted exactly once.
type Teardown struct{}

func (Teardown) isEffect() {}

func (Teardown) String() string { return "Teardown" }
|
||||
|
||||
// Next is the total, pure transition function. For every (state, event) it
|
||||
// returns the next state and the ordered effects to perform. It returns a
|
||||
// non-nil error only for an unknown State/Event implementation. Every in-domain
|
||||
// pair is defined; there are no forbidden transitions, only no-ops.
|
||||
//
|
||||
// The crux: Close moves to Torn, which absorbs every later event with no
|
||||
// effects. So teardown's channel closes happen exactly once even if Close is
|
||||
// raised again (e.g. an error path and the normal return both reaching it), and
|
||||
// no StartVAD can resurrect a torn session.
|
||||
func Next(s State, e Event) (State, []Effect, error) {
|
||||
switch st := s.(type) {
|
||||
case Live:
|
||||
switch ev := e.(type) {
|
||||
case SetVAD:
|
||||
switch {
|
||||
case ev.Active && !st.VADRunning:
|
||||
return Live{VADRunning: true}, []Effect{StartVAD{}}, nil
|
||||
case !ev.Active && st.VADRunning:
|
||||
return Live{VADRunning: false}, []Effect{StopVAD{}}, nil
|
||||
default:
|
||||
// Already in the requested state: idempotent no-op.
|
||||
return Live{VADRunning: st.VADRunning}, nil, nil
|
||||
}
|
||||
case Close:
|
||||
if st.VADRunning {
|
||||
return Torn{}, []Effect{StopVAD{}, Teardown{}}, nil
|
||||
}
|
||||
return Torn{}, []Effect{Teardown{}}, nil
|
||||
}
|
||||
case Torn:
|
||||
switch e.(type) {
|
||||
case SetVAD:
|
||||
// No resurrection: a toggle after teardown is ignored.
|
||||
return Torn{}, nil, nil
|
||||
case Close:
|
||||
// Idempotent: teardown already ran.
|
||||
return Torn{}, nil, nil
|
||||
}
|
||||
}
|
||||
return s, nil, fmt.Errorf("conncoord: unhandled transition %s <- %s", s, e)
|
||||
}
|
||||
|
||||
// EffectSink performs the effects produced by a transition. See coordinator.Sink:
|
||||
// Perform runs under the coordinator lock. The Teardown effect does join
|
||||
// goroutines (which can block) — acceptable here because the connection
|
||||
// coordinator is single-writer and torn down exactly once at the end of the
|
||||
// session goroutine, so no other Apply is contending the lock.
|
||||
type EffectSink = coordinator.Sink[Effect]
|
||||
|
||||
// Coordinator serializes the connection-lifecycle transitions.
|
||||
// See coordinator.Coordinator.
|
||||
type Coordinator = coordinator.Coordinator[State, Event, Effect]
|
||||
|
||||
// New returns a Coordinator in Live{VADRunning:false} that performs effects via
|
||||
// sink.
|
||||
func New(sink EffectSink) *Coordinator {
|
||||
return coordinator.New[State, Event, Effect](Live{VADRunning: false}, Next, sink)
|
||||
}
|
||||
13
core/http/endpoints/openai/conncoord/conncoord_suite_test.go
Normal file
13
core/http/endpoints/openai/conncoord/conncoord_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package conncoord
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestConncoord(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "conncoord (realtime M1) Suite")
|
||||
}
|
||||
212
core/http/endpoints/openai/conncoord/conncoord_test.go
Normal file
212
core/http/endpoints/openai/conncoord/conncoord_test.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package conncoord
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// recordingSink captures the ordered stream of effects so the invariants can be
|
||||
// checked independently of the transition function. Perform is called by
|
||||
// Coordinator.Apply under the coordinator lock; the mutex here only guards reads
|
||||
// from the spec goroutine.
|
||||
type recordingSink struct {
|
||||
mu sync.Mutex
|
||||
log []Effect
|
||||
}
|
||||
|
||||
func (s *recordingSink) Perform(e Effect) {
|
||||
s.mu.Lock()
|
||||
s.log = append(s.log, e)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *recordingSink) snapshot() []Effect {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]Effect, len(s.log))
|
||||
copy(out, s.log)
|
||||
return out
|
||||
}
|
||||
|
||||
// checkLog replays the effect log and asserts the lifecycle safety properties
|
||||
// from docs/design/realtime-state-machines.md, Part 4 (invariants #8, #10 and
|
||||
// failure mode 6):
|
||||
//
|
||||
// (1) the VAD done channel is closed exactly once per start -- StartVAD only
|
||||
// while stopped, StopVAD only while running (no double close / close-of-nil);
|
||||
// (2) teardown runs at most once;
|
||||
// (3) no resurrection -- no StartVAD after Teardown.
|
||||
func checkLog(log []Effect) {
|
||||
running := false
|
||||
torn := false
|
||||
teardowns := 0
|
||||
for i, eff := range log {
|
||||
switch eff.(type) {
|
||||
case StartVAD:
|
||||
Expect(torn).To(BeFalse(), "invariant (3): StartVAD after teardown (effect #%d)\nlog=%v", i, log)
|
||||
Expect(running).To(BeFalse(), "invariant (1): StartVAD while already running (effect #%d)\nlog=%v", i, log)
|
||||
running = true
|
||||
case StopVAD:
|
||||
Expect(running).To(BeTrue(), "invariant (1): StopVAD while not running (effect #%d)\nlog=%v", i, log)
|
||||
running = false
|
||||
case Teardown:
|
||||
Expect(torn).To(BeFalse(), "invariant (2): Teardown twice (effect #%d)\nlog=%v", i, log)
|
||||
torn = true
|
||||
teardowns++
|
||||
}
|
||||
}
|
||||
Expect(teardowns).To(BeNumerically("<=", 1), "invariant (2): teardown ran %d times\nlog=%v", teardowns, log)
|
||||
}
|
||||
|
||||
// unknownEvent satisfies the Event interface but is not one of the machine's
// events; it drives the error path of Next.
type unknownEvent struct{}

func (unknownEvent) isEvent() {}

func (unknownEvent) String() string { return "unknownEvent" }

// unknownState is the state-side counterpart of unknownEvent: it satisfies
// State, but none of Next's cases handle it.
type unknownState struct{}

func (unknownState) isState() {}

func (unknownState) String() string { return "unknownState" }
|
||||
|
||||
var _ = Describe("conncoord.Next", func() {
|
||||
DescribeTable("transitions",
|
||||
func(state State, event Event, wantState State, wantEff []Effect) {
|
||||
gotState, gotEff, err := Next(state, event)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(gotState).To(Equal(wantState))
|
||||
Expect(gotEff).To(Equal(wantEff))
|
||||
},
|
||||
Entry("stopped+setvad(on) -> running: start",
|
||||
Live{VADRunning: false}, SetVAD{Active: true},
|
||||
Live{VADRunning: true}, []Effect{StartVAD{}}),
|
||||
Entry("running+setvad(on) -> running, no-op",
|
||||
Live{VADRunning: true}, SetVAD{Active: true},
|
||||
Live{VADRunning: true}, []Effect(nil)),
|
||||
Entry("stopped+setvad(off) -> stopped, no-op",
|
||||
Live{VADRunning: false}, SetVAD{Active: false},
|
||||
Live{VADRunning: false}, []Effect(nil)),
|
||||
Entry("running+setvad(off) -> stopped: stop",
|
||||
Live{VADRunning: true}, SetVAD{Active: false},
|
||||
Live{VADRunning: false}, []Effect{StopVAD{}}),
|
||||
Entry("stopped+close -> torn: teardown",
|
||||
Live{VADRunning: false}, Close{},
|
||||
Torn{}, []Effect{Teardown{}}),
|
||||
Entry("running+close -> torn: stop + teardown",
|
||||
Live{VADRunning: true}, Close{},
|
||||
Torn{}, []Effect{StopVAD{}, Teardown{}}),
|
||||
Entry("torn+setvad(on) -> torn, no-op (no resurrection)",
|
||||
Torn{}, SetVAD{Active: true},
|
||||
Torn{}, []Effect(nil)),
|
||||
Entry("torn+close -> torn, no-op (idempotent)",
|
||||
Torn{}, Close{},
|
||||
Torn{}, []Effect(nil)),
|
||||
)
|
||||
|
||||
It("is total over the defined (state, event) pairs", func() {
|
||||
states := []State{Live{VADRunning: false}, Live{VADRunning: true}, Torn{}}
|
||||
events := []Event{SetVAD{Active: true}, SetVAD{Active: false}, Close{}}
|
||||
for _, s := range states {
|
||||
for _, e := range events {
|
||||
_, _, err := Next(s, e)
|
||||
Expect(err).NotTo(HaveOccurred(), "Next(%s, %s)", s, e)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("errors on an unknown event type", func() {
|
||||
_, _, err := Next(Live{}, unknownEvent{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors on an unknown state type", func() {
|
||||
_, _, err := Next(unknownState{}, Close{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("conncoord.Coordinator", func() {
|
||||
It("upholds the lifecycle invariants over random event sequences", func() {
|
||||
seeds := []uint64{1, 2, 3, 42, 1337, 0xC0FFEE}
|
||||
for _, seed := range seeds {
|
||||
r := rand.New(rand.NewPCG(seed, 0xA5A5A5A5))
|
||||
sink := &recordingSink{}
|
||||
c := New(sink)
|
||||
running := false
|
||||
torn := false
|
||||
|
||||
for range 5000 {
|
||||
switch r.IntN(3) {
|
||||
case 0:
|
||||
Expect(c.Apply(SetVAD{Active: true})).To(Succeed())
|
||||
if !torn {
|
||||
running = true
|
||||
}
|
||||
case 1:
|
||||
Expect(c.Apply(SetVAD{Active: false})).To(Succeed())
|
||||
if !torn {
|
||||
running = false
|
||||
}
|
||||
case 2:
|
||||
Expect(c.Apply(Close{})).To(Succeed())
|
||||
torn = true
|
||||
running = false
|
||||
}
|
||||
if torn {
|
||||
Expect(c.State()).To(Equal(State(Torn{})), "seed=%d", seed)
|
||||
} else {
|
||||
Expect(c.State()).To(Equal(State(Live{VADRunning: running})), "seed=%d", seed)
|
||||
}
|
||||
}
|
||||
checkLog(sink.snapshot())
|
||||
}
|
||||
})
|
||||
|
||||
It("tears down at most once under concurrent SetVAD/Close from two goroutines", func() {
|
||||
const perGoroutine = 2000
|
||||
sink := &recordingSink{}
|
||||
c := New(sink)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
drive := func(active bool) {
|
||||
defer wg.Done()
|
||||
for i := range perGoroutine {
|
||||
switch i % 3 {
|
||||
case 0:
|
||||
_ = c.Apply(SetVAD{Active: active})
|
||||
case 1:
|
||||
_ = c.Apply(SetVAD{Active: !active})
|
||||
case 2:
|
||||
if i > perGoroutine/2 {
|
||||
_ = c.Apply(Close{})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(2)
|
||||
go drive(true)
|
||||
go drive(false)
|
||||
wg.Wait()
|
||||
_ = c.Apply(Close{})
|
||||
|
||||
checkLog(sink.snapshot())
|
||||
Expect(c.State()).To(Equal(State(Torn{})))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = DescribeTable("conncoord stringers",
|
||||
func(got, want string) { Expect(got).To(Equal(want)) },
|
||||
Entry(nil, Live{VADRunning: true}.String(), "Live(vad=true)"),
|
||||
Entry(nil, Live{VADRunning: false}.String(), "Live(vad=false)"),
|
||||
Entry(nil, Torn{}.String(), "Torn"),
|
||||
|
||||
Entry(nil, SetVAD{Active: true}.String(), "SetVAD(true)"),
|
||||
Entry(nil, Close{}.String(), "Close"),
|
||||
|
||||
Entry(nil, StartVAD{}.String(), "StartVAD"),
|
||||
Entry(nil, StopVAD{}.String(), "StopVAD"),
|
||||
Entry(nil, Teardown{}.String(), "Teardown"),
|
||||
)
|
||||
@@ -5,4 +5,7 @@ const (
|
||||
FinishReasonStop = "stop"
|
||||
FinishReasonToolCalls = "tool_calls"
|
||||
FinishReasonFunctionCall = "function_call"
|
||||
// FinishReasonLength is reported when generation stopped because it
|
||||
// reached the max_tokens budget rather than a natural stop (issue #9716).
|
||||
FinishReasonLength = "length"
|
||||
)
|
||||
|
||||
82
core/http/endpoints/openai/coordinator/coordinator.go
Normal file
82
core/http/endpoints/openai/coordinator/coordinator.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Package coordinator is the shared single-writer state-machine runtime for the
|
||||
// realtime API's explicit coordinators (machines M1–M5 in
|
||||
// docs/design/realtime-state-machines.md).
|
||||
//
|
||||
// Each machine package (respcoord, turncoord, conncoord, compactcoord, ttscoord)
|
||||
// defines its OWN sealed sum types for State/Event/Effect and a total, pure
|
||||
// transition function Next(state, event) -> (state, []effect, error). The
|
||||
// plumbing around that — a single-writer Coordinator that serializes every
|
||||
// transition behind one lock and performs the returned effects in order — is
|
||||
// identical across all five, so it lives here once instead of being copied.
|
||||
//
|
||||
// A machine package wires itself up with three lines:
|
||||
//
|
||||
// type EffectSink = coordinator.Sink[Effect]
|
||||
// type Coordinator = coordinator.Coordinator[State, Event, Effect]
|
||||
// func New(sink EffectSink) *Coordinator { return coordinator.New[State, Event, Effect](Idle{}, Next, sink) }
|
||||
//
|
||||
// The aliases keep each package's public API (Coordinator, New, EffectSink,
|
||||
// Apply, State) unchanged. The single-writer serialization — the load-bearing
|
||||
// concurrency guarantee the FizzBee specs check — is therefore implemented and
|
||||
// reasoned about in exactly one place.
|
||||
package coordinator
|
||||
|
||||
import "sync"
|
||||
|
||||
// TransitionFunc is the shape of a machine's total, pure transition: from the
// current state and an event it produces the next state, the ordered effects
// to perform, and a non-nil error ONLY for an unhandled (programmer-error)
// state/event pair. It must not perform I/O or block; side effects are
// returned as data (F) for the Coordinator to pass on to the Sink.
type TransitionFunc[S, E, F any] func(state S, event E) (S, []F, error)

// Sink carries out the effects a transition produces. Perform is invoked while
// the Coordinator holds its lock, so implementations MUST be non-blocking
// (spawn a goroutine, call a cancel func, or do a non-blocking channel send)
// and MUST NOT call back into the same Coordinator's Apply.
type Sink[F any] interface {
	Perform(F)
}

// Coordinator is the single-writer wrapper around a pure transition function.
// mu serializes every Apply, so several goroutines can drive the machine
// without racing, and a transition's effects are performed in order while the
// lock is still held -- before any subsequent Apply can observe the new state.
type Coordinator[S, E, F any] struct {
	mu    sync.Mutex
	state S
	next  TransitionFunc[S, E, F]
	sink  Sink[F]
}

// New builds a Coordinator that starts in initial, transitions via next, and
// performs effects via sink.
func New[S, E, F any](initial S, next TransitionFunc[S, E, F], sink Sink[F]) *Coordinator[S, E, F] {
	c := new(Coordinator[S, E, F])
	c.state = initial
	c.next = next
	c.sink = sink
	return c
}

// Apply runs exactly one transition under the lock and then performs its
// effects in order. When the transition function reports an error (an
// unhandled state/event pair) the state is left untouched, no effect is
// performed, and the error is returned to the caller rather than swallowed.
func (c *Coordinator[S, E, F]) Apply(e E) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	target, pending, err := c.next(c.state, e)
	if err != nil {
		return err
	}

	// Commit the state first, then run the effects while still holding mu.
	c.state = target
	for i := range pending {
		c.sink.Perform(pending[i])
	}
	return nil
}

// State returns the current state. It takes the lock, so it is safe to call
// concurrently with Apply.
func (c *Coordinator[S, E, F]) State() S {
	c.mu.Lock()
	current := c.state
	c.mu.Unlock()
	return current
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package coordinator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestCoordinator(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "coordinator (shared runtime) Suite")
|
||||
}
|
||||
124
core/http/endpoints/openai/coordinator/coordinator_test.go
Normal file
124
core/http/endpoints/openai/coordinator/coordinator_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package coordinator
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// A toy machine that exercises the generic runtime directly. The five real
// machines reach it through their aliases, but the coverage gate measures this
// package on its own. The machine has two states (off, on) and three events:
// toggle flips the state, burst emits three ordered effects, and boom is the
// unhandled/error path.
type tstate int

const (
	off tstate = iota
	on
)

type tevent int

const (
	toggle tevent = iota
	burst
	boom
)

type teffect string

// tnext is the toy machine's transition function. toggle moves off to on
// (effect "on") and any other state to off (effect "off"); burst keeps the
// state and emits "a", "b", "c"; boom and any unrecognised event keep the state
// and return an error.
func tnext(s tstate, e tevent) (tstate, []teffect, error) {
	if e == toggle {
		if s == off {
			return on, []teffect{"on"}, nil
		}
		return off, []teffect{"off"}, nil
	}
	if e == burst {
		return s, []teffect{"a", "b", "c"}, nil
	}
	if e == boom {
		return s, nil, errors.New("boom: unhandled")
	}
	return s, nil, fmt.Errorf("unknown event %d", int(e))
}
|
||||
|
||||
type recordingSink struct {
|
||||
mu sync.Mutex
|
||||
log []teffect
|
||||
}
|
||||
|
||||
func (s *recordingSink) Perform(e teffect) {
|
||||
s.mu.Lock()
|
||||
s.log = append(s.log, e)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *recordingSink) snapshot() []teffect {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]teffect, len(s.log))
|
||||
copy(out, s.log)
|
||||
return out
|
||||
}
|
||||
|
||||
var _ = Describe("coordinator.Coordinator", func() {
|
||||
It("starts in the initial state", func() {
|
||||
c := New[tstate, tevent, teffect](off, tnext, &recordingSink{})
|
||||
Expect(c.State()).To(Equal(off))
|
||||
})
|
||||
|
||||
It("advances state and performs the transition's effects", func() {
|
||||
sink := &recordingSink{}
|
||||
c := New[tstate, tevent, teffect](off, tnext, sink)
|
||||
|
||||
Expect(c.Apply(toggle)).To(Succeed())
|
||||
Expect(c.State()).To(Equal(on))
|
||||
Expect(c.Apply(toggle)).To(Succeed())
|
||||
Expect(c.State()).To(Equal(off))
|
||||
|
||||
Expect(sink.snapshot()).To(Equal([]teffect{"on", "off"}))
|
||||
})
|
||||
|
||||
It("performs multiple effects in order", func() {
|
||||
sink := &recordingSink{}
|
||||
c := New[tstate, tevent, teffect](off, tnext, sink)
|
||||
Expect(c.Apply(burst)).To(Succeed())
|
||||
Expect(sink.snapshot()).To(Equal([]teffect{"a", "b", "c"}))
|
||||
})
|
||||
|
||||
It("returns the transition error and leaves state unchanged", func() {
|
||||
sink := &recordingSink{}
|
||||
c := New[tstate, tevent, teffect](on, tnext, sink)
|
||||
err := c.Apply(boom)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(c.State()).To(Equal(on), "state unchanged on error")
|
||||
Expect(sink.snapshot()).To(BeEmpty(), "no effects performed on error")
|
||||
})
|
||||
|
||||
It("serializes concurrent Apply from many goroutines (run with -race)", func() {
|
||||
const goroutines = 8
|
||||
const each = 1000
|
||||
sink := &recordingSink{}
|
||||
c := New[tstate, tevent, teffect](off, tnext, sink)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range each {
|
||||
_ = c.Apply(toggle)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// goroutines*each toggles from off; an even total returns to off. The
|
||||
// point is race-freedom + a consistent final state, not the value itself.
|
||||
Expect(c.State()).To(Equal(off))
|
||||
Expect(sink.snapshot()).To(HaveLen(goroutines * each))
|
||||
})
|
||||
})
|
||||
@@ -13,6 +13,14 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// reachedTokenBudget reports whether a generation ended because it consumed
// the configured max_tokens ceiling.
//
// completion is the number of completion tokens produced; maxTokens is the
// configured ceiling, where nil or a value <= 0 means "no limit" and always
// yields false. Otherwise the result is true once completion is at or above
// the ceiling.
//
// Callers use this to skip regeneration retries (a retry would only run into
// the same ceiling again) and to report finish_reason "length" rather than
// "stop" (issue #9716).
func reachedTokenBudget(completion int, maxTokens *int) bool {
	// No ceiling configured at all.
	if maxTokens == nil {
		return false
	}

	// Zero or negative ceilings are treated as "unlimited".
	limit := *maxTokens
	if limit <= 0 {
		return false
	}

	return completion >= limit
}
|
||||
|
||||
func ComputeChoices(
|
||||
req *schema.OpenAIRequest,
|
||||
predInput string,
|
||||
@@ -113,11 +121,21 @@ func ComputeChoices(
|
||||
}
|
||||
prediction = p
|
||||
|
||||
// budgetExhausted is true when the model stopped because it reached
|
||||
// the configured max_tokens ceiling. None of the retry paths below
|
||||
// should fire in that case: regenerating would just hit the same
|
||||
// ceiling again and multiply token consumption (issue #9716). A
|
||||
// thinking model that spends its whole budget on the reasoning block
|
||||
// produces an empty content / reasoning-only response, which would
|
||||
// otherwise look like a failed generation worth retrying. This is a
|
||||
// "length" finish, not an empty one.
|
||||
budgetExhausted := reachedTokenBudget(prediction.Usage.Completion, config.Maxtokens)
|
||||
|
||||
// Built-in: retry on truly empty response (no tokens at all).
|
||||
// However, when the C++ autoparser is active, it clears the raw
|
||||
// message and delivers content via ChatDeltas instead. Do NOT
|
||||
// retry if ChatDeltas contain tool calls or content.
|
||||
if strings.TrimSpace(prediction.Response) == "" && attempt < maxRetries {
|
||||
if strings.TrimSpace(prediction.Response) == "" && attempt < maxRetries && !budgetExhausted {
|
||||
hasChatDeltaData := false
|
||||
for _, d := range prediction.ChatDeltas {
|
||||
if d.Content != "" || len(d.ToolCalls) > 0 {
|
||||
@@ -159,7 +177,7 @@ func ComputeChoices(
|
||||
}
|
||||
}
|
||||
}
|
||||
if shouldRetryFn != nil && !skipCallerRetry && shouldRetryFn(attempt) && attempt < maxRetries {
|
||||
if shouldRetryFn != nil && !skipCallerRetry && !budgetExhausted && shouldRetryFn(attempt) && attempt < maxRetries {
|
||||
// Caller has already reset its state inside shouldRetry
|
||||
result = result[:0]
|
||||
allChatDeltas = nil
|
||||
|
||||
@@ -393,6 +393,73 @@ var _ = Describe("ComputeChoices", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("reachedTokenBudget", func() {
	// limit returns a pointer to its argument, for building *int ceilings.
	limit := func(n int) *int { return &n }

	It("is false when no limit is configured", func() {
		// nil, zero and negative ceilings all mean "no limit".
		for _, unlimited := range []*int{nil, limit(0), limit(-1)} {
			Expect(reachedTokenBudget(1000, unlimited)).To(BeFalse())
		}
	})

	It("is false when generation stopped below the limit", func() {
		ceiling := limit(100)
		Expect(reachedTokenBudget(99, ceiling)).To(BeFalse())
	})

	It("is true when generation reached or exceeded the limit", func() {
		// Exactly at the ceiling, then one past it.
		for _, completion := range []int{100, 101} {
			Expect(reachedTokenBudget(completion, limit(100))).To(BeTrue())
		}
	})
})
|
||||
|
||||
Context("max_tokens budget exhausted on reasoning (issue #9716)", func() {
	// Regression scenario for the streaming retry loop. A thinking model
	// burns its whole max_tokens budget inside the reasoning block; the C++
	// autoparser then clears the raw Response and hands back ChatDeltas that
	// carry reasoning only (no content, no tool calls). Without the budget
	// check, the built-in empty-response retry would regenerate from scratch
	// up to maxRetries times, re-consuming the full budget on each attempt,
	// rather than finishing with finish_reason "length".
	It("should NOT retry when the token budget was exhausted", func() {
		budget := 100
		cfg.Maxtokens = &budget

		inferenceCalls := 0
		backend.ModelInferenceFunc = func(
			ctx context.Context, s string, messages schema.Messages,
			images, videos, audios []string,
			loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
			o *config.ApplicationConfig,
			tokenCallback func(string, backend.TokenUsage) bool,
			tools, toolChoice string,
			logprobs, topLogprobs *int,
			logitBias map[string]float64,
			metadata map[string]string,
		) (func() (backend.LLMResponse, error), error) {
			return func() (backend.LLMResponse, error) {
				inferenceCalls++
				// Simulates the autoparser having emptied Response: only
				// reasoning came back, and the completion count equals the
				// max_tokens budget.
				reasoningOnly := backend.LLMResponse{
					Response:   "",
					ChatDeltas: []*pb.ChatDelta{{ReasoningContent: "thinking..."}},
					Usage:      backend.TokenUsage{Prompt: 5, Completion: budget},
				}
				return reasoningOnly, nil
			}, nil
		}

		// Collects each produced string as a text choice.
		collect := func(s string, c *[]schema.Choice) {
			*c = append(*c, schema.Choice{Text: s})
		}

		_, usage, _, err := ComputeChoices(
			makeReq(), "test", cfg, nil, appCfg, nil,
			collect,
			nil,
		)
		Expect(err).ToNot(HaveOccurred())

		// Having hit the token ceiling, a regeneration would only hit it
		// again and multiply token consumption, so the backend must have
		// been invoked exactly once.
		Expect(inferenceCalls).To(Equal(1), "budget-exhausted generation must not be retried")
		Expect(usage.Completion).To(Equal(budget))
	})
})
|
||||
|
||||
Context("with streaming token callback", func() {
|
||||
It("should call tokenCallback for streaming responses", func() {
|
||||
var streamedTokens []string
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user