mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-30 11:26:32 -04:00
Compare commits
42 Commits
dependabot
...
fix/watchd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
347cdcf545 | ||
|
|
0e381897b5 | ||
|
|
b1af37257d | ||
|
|
ebefa6dcca | ||
|
|
605348925d | ||
|
|
686ce10b54 | ||
|
|
2cee318fad | ||
|
|
1a4f68ed4a | ||
|
|
28d7397743 | ||
|
|
5d0c43ec6e | ||
|
|
6ab29ec8b9 | ||
|
|
036f950b1b | ||
|
|
5b7b914b4f | ||
|
|
d1cee4c52a | ||
|
|
baaa0fe94f | ||
|
|
c3b5c7c3fa | ||
|
|
bd1ec8f2c2 | ||
|
|
135debf9af | ||
|
|
e8c18ae28e | ||
|
|
c4d302e1ab | ||
|
|
323b57a4bc | ||
|
|
3d2f639213 | ||
|
|
be1ae9338b | ||
|
|
923c47020d | ||
|
|
b7a1dec773 | ||
|
|
de2ec2f136 | ||
|
|
d3a26f961d | ||
|
|
13b1ae53bc | ||
|
|
e68ca109c5 | ||
|
|
6740e988d2 | ||
|
|
ade9cc9e37 | ||
|
|
471e38e4e7 | ||
|
|
f3d829e2ef | ||
|
|
91885c2c7e | ||
|
|
f1fcafb888 | ||
|
|
fdff114701 | ||
|
|
1154be5eea | ||
|
|
8aba4fdba3 | ||
|
|
d7d7721eae | ||
|
|
c548150f99 | ||
|
|
ec26b86dd4 | ||
|
|
d11b202dd2 |
@@ -7,8 +7,11 @@
|
||||
# Runs only the checks relevant to what's staged:
|
||||
# - Go files -> make lint + make test-coverage-check
|
||||
# - core/http/react-ui -> make test-ui-coverage-check (Playwright e2e + gate)
|
||||
# A commit touching neither is skipped entirely (docs/YAML/etc. can't change
|
||||
# lint findings, Go coverage, or the UI).
|
||||
# - realtime state machines / specs -> make test-realtime-conformance
|
||||
# (respcoord/**, turncoord/**, or formal-verification/** -- a pure .fizz
|
||||
# spec edit must still re-verify the design, detected separately from Go)
|
||||
# A commit touching none of these is skipped entirely (other docs/YAML can't
|
||||
# change lint findings, Go coverage, the UI, or the realtime conformance gate).
|
||||
#
|
||||
# To bypass for a single commit (e.g. a WIP checkpoint): git commit --no-verify
|
||||
set -eu
|
||||
@@ -20,11 +23,13 @@ staged="$(git diff --cached --name-only --diff-filter=ACMRD)"
|
||||
|
||||
go_changed=0
|
||||
ui_changed=0
|
||||
rt_changed=0
|
||||
if echo "$staged" | grep -qE '\.go$'; then go_changed=1; fi
|
||||
if echo "$staged" | grep -qE '^core/http/react-ui/'; then ui_changed=1; fi
|
||||
if echo "$staged" | grep -qE '^(core/http/endpoints/openai/(coordinator|respcoord|turncoord|conncoord|compactcoord|ttscoord)/|formal-verification/)'; then rt_changed=1; fi
|
||||
|
||||
if [ "$go_changed" -eq 0 ] && [ "$ui_changed" -eq 0 ]; then
|
||||
echo "pre-commit: no Go or React UI changes staged — skipping."
|
||||
if [ "$go_changed" -eq 0 ] && [ "$ui_changed" -eq 0 ] && [ "$rt_changed" -eq 0 ]; then
|
||||
echo "pre-commit: no Go, React UI, or realtime-spec changes staged — skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
@@ -57,4 +62,11 @@ if [ "$ui_changed" -eq 1 ]; then
|
||||
make test-ui-coverage-check
|
||||
fi
|
||||
|
||||
if [ "$rt_changed" -eq 1 ]; then
|
||||
echo "pre-commit ▶ realtime state-machine conformance (make test-realtime-conformance) —"
|
||||
echo " Go transition/rapid tests under -race + FizzBee model check of the"
|
||||
echo " authoritative specs. Fail-closed: needs FizzBee (make install-fizzbee)."
|
||||
make test-realtime-conformance
|
||||
fi
|
||||
|
||||
echo "pre-commit ✓ all relevant checks passed"
|
||||
|
||||
307
.github/backend-matrix.yml
vendored
307
.github/backend-matrix.yml
vendored
@@ -3745,6 +3745,302 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# voice-detect
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-voice-detect'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "voice-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-voice-detect'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "voice-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-voice-detect'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "voice-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-voice-detect'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "voice-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-voice-detect'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "voice-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-voice-detect'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "voice-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-voice-detect'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "voice-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-voice-detect'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "voice-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-voice-detect'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "voice-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-voice-detect'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "voice-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-voice-detect'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
runs-on: 'ubuntu-latest'
|
||||
skip-drivers: 'false'
|
||||
backend: "voice-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# face-detect
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-face-detect'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "face-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-face-detect'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "face-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-face-detect'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "face-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-face-detect'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "face-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-face-detect'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "face-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-face-detect'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "face-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-face-detect'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "face-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-face-detect'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "face-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-face-detect'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "face-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-face-detect'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "face-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-face-detect'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
runs-on: 'ubuntu-latest'
|
||||
skip-drivers: 'false'
|
||||
backend: "face-detect"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# acestep-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -4928,6 +5224,14 @@ includeDarwin:
|
||||
tag-suffix: "-metal-darwin-arm64-ced"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "voice-detect"
|
||||
tag-suffix: "-metal-darwin-arm64-voice-detect"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "face-detect"
|
||||
tag-suffix: "-metal-darwin-arm64-face-detect"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "acestep-cpp"
|
||||
tag-suffix: "-metal-darwin-arm64-acestep-cpp"
|
||||
build-type: "metal"
|
||||
@@ -4991,9 +5295,6 @@ includeDarwin:
|
||||
- backend: "qwen-tts"
|
||||
tag-suffix: "-metal-darwin-arm64-qwen-tts"
|
||||
build-type: "mps"
|
||||
- backend: "fish-speech"
|
||||
tag-suffix: "-metal-darwin-arm64-fish-speech"
|
||||
build-type: "mps"
|
||||
- backend: "voxcpm"
|
||||
tag-suffix: "-metal-darwin-arm64-voxcpm"
|
||||
build-type: "mps"
|
||||
|
||||
12
.github/workflows/backend_build_darwin.yml
vendored
12
.github/workflows/backend_build_darwin.yml
vendored
@@ -82,7 +82,7 @@ jobs:
|
||||
# as the Linux registry cache.
|
||||
- name: Restore Homebrew cache
|
||||
id: brew-cache
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v6
|
||||
with:
|
||||
path: |
|
||||
~/Library/Caches/Homebrew/downloads
|
||||
@@ -142,7 +142,7 @@ jobs:
|
||||
|
||||
- name: Save Homebrew cache
|
||||
if: github.event_name != 'pull_request' && steps.brew-cache.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@v4
|
||||
uses: actions/cache/save@v6
|
||||
with:
|
||||
path: |
|
||||
~/Library/Caches/Homebrew/downloads
|
||||
@@ -178,7 +178,7 @@ jobs:
|
||||
- name: Restore ccache
|
||||
if: inputs.backend == 'llama-cpp'
|
||||
id: ccache-cache
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v6
|
||||
with:
|
||||
path: ~/Library/Caches/ccache
|
||||
key: ccache-llama-${{ runner.arch }}-${{ steps.llama-version.outputs.version }}-${{ github.run_id }}
|
||||
@@ -211,7 +211,7 @@ jobs:
|
||||
- name: Restore Python wheel cache
|
||||
if: inputs.lang == 'python'
|
||||
id: pyenv-cache
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v6
|
||||
with:
|
||||
path: |
|
||||
~/Library/Caches/pip
|
||||
@@ -256,14 +256,14 @@ jobs:
|
||||
|
||||
- name: Save ccache
|
||||
if: inputs.backend == 'llama-cpp' && github.event_name != 'pull_request'
|
||||
uses: actions/cache/save@v4
|
||||
uses: actions/cache/save@v6
|
||||
with:
|
||||
path: ~/Library/Caches/ccache
|
||||
key: ccache-llama-${{ runner.arch }}-${{ steps.llama-version.outputs.version }}-${{ github.run_id }}
|
||||
|
||||
- name: Save Python wheel cache
|
||||
if: inputs.lang == 'python' && github.event_name != 'pull_request' && steps.pyenv-cache.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@v4
|
||||
uses: actions/cache/save@v6
|
||||
with:
|
||||
path: |
|
||||
~/Library/Caches/pip
|
||||
|
||||
8
.github/workflows/bump_deps.yaml
vendored
8
.github/workflows/bump_deps.yaml
vendored
@@ -46,6 +46,14 @@ jobs:
|
||||
variable: "CED_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/ced/Makefile"
|
||||
- repository: "mudler/voice-detect.cpp"
|
||||
variable: "VOICEDETECT_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/voice-detect/Makefile"
|
||||
- repository: "mudler/face-detect.cpp"
|
||||
variable: "FACEDETECT_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/face-detect/Makefile"
|
||||
- repository: "mudler/depth-anything.cpp"
|
||||
variable: "DEPTHANYTHING_VERSION"
|
||||
branch: "master"
|
||||
|
||||
69
.github/workflows/realtime-conformance.yml
vendored
Normal file
69
.github/workflows/realtime-conformance.yml
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
---
|
||||
name: 'realtime-conformance'
|
||||
|
||||
# Verifies the realtime state-machine implementations conform to their formal
|
||||
# designs (docs/design/realtime-state-machines.md, formal-verification/). BOTH
|
||||
# layers are enforced and the gate is fail-closed: the Go conformance layer
|
||||
# (respcoord + turncoord transition/rapid tests under -race) AND the FizzBee model check of
|
||||
# the authoritative specs. FizzBee is pinned + checksum-verified
|
||||
# (formal-verification/fizzbee.sha256), so a failed install fails the job rather
|
||||
# than silently skipping verification.
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'core/http/endpoints/openai/coordinator/**'
|
||||
- 'core/http/endpoints/openai/respcoord/**'
|
||||
- 'core/http/endpoints/openai/turncoord/**'
|
||||
- 'core/http/endpoints/openai/conncoord/**'
|
||||
- 'core/http/endpoints/openai/compactcoord/**'
|
||||
- 'core/http/endpoints/openai/ttscoord/**'
|
||||
- 'formal-verification/**'
|
||||
- 'scripts/realtime-conformance.sh'
|
||||
- 'scripts/install-fizzbee.sh'
|
||||
- '.github/workflows/realtime-conformance.yml'
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths:
|
||||
- 'core/http/endpoints/openai/coordinator/**'
|
||||
- 'core/http/endpoints/openai/respcoord/**'
|
||||
- 'core/http/endpoints/openai/turncoord/**'
|
||||
- 'core/http/endpoints/openai/conncoord/**'
|
||||
- 'core/http/endpoints/openai/compactcoord/**'
|
||||
- 'core/http/endpoints/openai/ttscoord/**'
|
||||
- 'formal-verification/**'
|
||||
- 'scripts/realtime-conformance.sh'
|
||||
|
||||
concurrency:
|
||||
group: realtime-conformance-${{ github.event.pull_request.number || github.sha }}-${{ github.repository }}
|
||||
cancel-in-progress: ${{ github.event_name == 'pull_request' }}
|
||||
|
||||
jobs:
|
||||
conformance:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ['1.26.x']
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v7
|
||||
- name: Setup Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
cache: false
|
||||
- name: Cache FizzBee
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: .tools/fizzbee
|
||||
key: fizzbee-v0.5.2-${{ runner.os }}-${{ hashFiles('formal-verification/fizzbee.sha256') }}
|
||||
- name: Install FizzBee (pinned, checksum-verified)
|
||||
# No `|| true`: a failed/forged download must fail the job, not silently
|
||||
# drop the design verification. install-fizzbee.sh is a no-op if the
|
||||
# cached binary is already present and valid.
|
||||
run: ./scripts/install-fizzbee.sh
|
||||
- name: Run conformance gate (fail-closed)
|
||||
# No skip env: both the Go conformance and the FizzBee model check are
|
||||
# required. The gate auto-detects .tools/fizzbee/fizz.
|
||||
run: make test-realtime-conformance
|
||||
6
.github/workflows/test-extra.yml
vendored
6
.github/workflows/test-extra.yml
vendored
@@ -1008,7 +1008,11 @@ jobs:
|
||||
# image + working dir.
|
||||
tests-vibevoice-cpp-grpc-transcription:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.vibevoice-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
# Skip on release tag pushes: the ASR Q4_K model is ~10 GB and cannot be
|
||||
# pulled from HF within the inner `go test -timeout 30m` budget on a CI
|
||||
# runner, so every tag build hung and timed out. Still runs on PRs/branch
|
||||
# pushes that touch vibevoice-cpp so regressions are caught off the release path.
|
||||
if: (needs.detect-changes.outputs.vibevoice-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true') && !startsWith(github.ref, 'refs/tags/')
|
||||
runs-on: bigger-runner
|
||||
timeout-minutes: 150
|
||||
steps:
|
||||
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -97,3 +97,12 @@ core/http/react-ui/test-results/
|
||||
|
||||
# Local Apple signing material (never commit)
|
||||
.certs/
|
||||
|
||||
# Pinned dev tools (e.g. FizzBee for the realtime-conformance gate)
|
||||
.tools/
|
||||
|
||||
# FizzBee model-check artifacts: the parser emits <spec>.json next to each
|
||||
# .fizz and the checker writes run dirs under out/. Both are regenerated by
|
||||
# the realtime-conformance gate; only the .fizz sources are authoritative.
|
||||
formal-verification/*.json
|
||||
formal-verification/out/
|
||||
|
||||
14
Makefile
14
Makefile
@@ -405,6 +405,18 @@ test-realtime: build-mock-backend
|
||||
@echo 'Running realtime e2e tests (mock backend)'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime && !real-models" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e
|
||||
|
||||
# Verify the realtime state-machine implementations conform to their formal
|
||||
# designs (Go transition/rapid tests under -race + FizzBee model check of the
|
||||
# authoritative specs). See docs/design/realtime-state-machines.md (Part 6) and
|
||||
# docs/design/specs/README.md.
|
||||
test-realtime-conformance:
|
||||
GOCMD=$(GOCMD) ./scripts/realtime-conformance.sh
|
||||
|
||||
# Install the pinned, checksum-verified FizzBee model checker (into .tools/,
|
||||
# gitignored) used by test-realtime-conformance. Idempotent; no-op if present.
|
||||
install-fizzbee:
|
||||
./scripts/install-fizzbee.sh
|
||||
|
||||
# Container-based real-model realtime testing. Build env vars / pipeline
|
||||
# definition kept here so test-realtime-models-docker can drive a fully wired
|
||||
# pipeline (VAD + STT + LLM + TTS) from inside a containerised runner.
|
||||
@@ -1027,7 +1039,7 @@ test-extra-backend-whisper-transcription: docker-build-whisper
|
||||
## is reachable.
|
||||
test-extra-backend-parakeet-cpp-transcription: docker-build-parakeet-cpp
|
||||
BACKEND_IMAGE=local-ai-backend:parakeet-cpp \
|
||||
BACKEND_TEST_MODEL_URL=https://huggingface.co/mudler/parakeet-cpp-gguf/resolve/main/tdt_ctc-110m-f16.gguf \
|
||||
BACKEND_TEST_MODEL_URL=https://huggingface.co/mudler/parakeet-cpp-gguf/resolve/main/realtime_eou_120m-v1-f16.gguf \
|
||||
BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
@@ -177,6 +177,7 @@ For more details, see the [Getting Started guide](https://localai.io/basics/gett
|
||||
|
||||
## Latest News
|
||||
|
||||
- **June 2026**: New native biometric backends from the LocalAI team: [voice-detect.cpp](https://github.com/mudler/voice-detect.cpp) for speaker recognition and voice analysis (ECAPA-TDNN, WeSpeaker, ERes2Net, CAM++, wav2vec2 age/gender/emotion) and [face-detect.cpp](https://github.com/mudler/face-detect.cpp) for face detection, recognition, demographics and anti-spoofing (SCRFD/ArcFace, YuNet/SFace). Both are from-scratch C++/ggml engines with no Python or onnxruntime at inference, self-contained GGUF weights, bit-exact parity with the reference, and GPU cuDNN parity, replacing the heavier Python `insightface` and `speaker-recognition` backends ([PR #10441](https://github.com/mudler/LocalAI/pull/10441)).
|
||||
- **June 2026**: New [realtime voice assistant demo](https://github.com/localai-org/localai-realtime-demo) (a tiny Go client for the Realtime API with a full talk-back voice loop and tool calling), plus [streaming of the realtime LLM / TTS / transcription pipeline stages](https://github.com/mudler/LocalAI/pull/10176) and [configurable WebRTC ICE candidates](https://github.com/mudler/LocalAI/pull/10231).
|
||||
- **June 2026**: Big speech push: the [parakeet.cpp](https://github.com/mudler/parakeet.cpp) ASR engine gains [NeMo-faithful segment timestamps](https://github.com/mudler/LocalAI/pull/10207), a [multilingual streaming Nemotron-3.5 model](https://github.com/mudler/LocalAI/pull/10199), [dynamic batching for concurrent transcription](https://github.com/mudler/LocalAI/pull/10112) and [CUDA graphs](https://github.com/mudler/LocalAI/pull/10273); the new [CrispASR backend](https://github.com/mudler/LocalAI/pull/10099) adds multi-architecture ASR + TTS, and [60 Piper TTS voices across 42 languages](https://github.com/mudler/LocalAI/pull/10296) land in the gallery (plus [per-request TTS instructions and params](https://github.com/mudler/LocalAI/pull/10172)).
|
||||
- **June 2026**: New backends and models: [locate-anything.cpp](https://github.com/mudler/LocalAI/pull/10264) for open-vocabulary object detection via ggml, [Ideogram4 image generation](https://github.com/mudler/LocalAI/pull/10201) in stablediffusion-ggml, [llama.cpp video input](https://github.com/mudler/LocalAI/pull/10216), and the [Gemma 4 QAT family with MTP speculative-decoding pairs](https://github.com/mudler/LocalAI/pull/10215). Plus an [interactive CLI chat mode](https://github.com/mudler/LocalAI/pull/10226) and [RAG source citations in agent responses](https://github.com/mudler/LocalAI/pull/10228).
|
||||
|
||||
@@ -137,7 +137,7 @@ RUN <<EOT bash
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} libcudnn9-dev-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@@ -18,6 +18,18 @@ service Backend {
|
||||
rpc GenerateVideo(GenerateVideoRequest) returns (Result) {}
|
||||
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
|
||||
rpc AudioTranscriptionStream(TranscriptRequest) returns (stream TranscriptStreamResponse) {}
|
||||
// AudioTranscriptionLive is the bidirectional live-microphone ASR RPC. The
|
||||
// first message MUST carry a Config; subsequent messages carry Audio frames
|
||||
// (mono float PCM at config.sample_rate, 16 kHz default). After a
|
||||
// successful open the backend replies with a single ready ack
|
||||
// (TranscriptLiveResponse{ready:true}); backends or models without
|
||||
// cache-aware streaming support return UNIMPLEMENTED instead. Newly
|
||||
// finalized text streams back as deltas; eou=true marks the model's
|
||||
// end-of-utterance token. One stream spans many utterances (the decoder
|
||||
// resets itself after each EOU). Closing the send side finalizes: the
|
||||
// backend flushes the decoder tail and emits a terminal message carrying
|
||||
// final_result. A second Config mid-stream resets the decode session.
|
||||
rpc AudioTranscriptionLive(stream TranscriptLiveRequest) returns (stream TranscriptLiveResponse) {}
|
||||
rpc TTS(TTSRequest) returns (Result) {}
|
||||
rpc TTSStream(TTSRequest) returns (stream Reply) {}
|
||||
rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
|
||||
@@ -479,6 +491,10 @@ message TranscriptResult {
|
||||
string text = 2;
|
||||
string language = 3;
|
||||
float duration = 4;
|
||||
// True when the decode ended on the model's end-of-utterance special token
|
||||
// (<EOU>/<EOB>, emitted by cache-aware streaming models such as
|
||||
// parakeet_realtime_eou_120m-v1). The marker itself is stripped from text.
|
||||
bool eou = 5;
|
||||
}
|
||||
|
||||
message TranscriptStreamResponse {
|
||||
@@ -486,6 +502,34 @@ message TranscriptStreamResponse {
|
||||
TranscriptResult final_result = 2;
|
||||
}
|
||||
|
||||
// === AudioTranscriptionLive messages =====================================
|
||||
|
||||
message TranscriptLiveRequest {
|
||||
oneof payload {
|
||||
TranscriptLiveConfig config = 1;
|
||||
TranscriptLiveAudio audio = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message TranscriptLiveConfig {
|
||||
string language = 1; // "" => model default
|
||||
int32 sample_rate = 2; // 0 => 16000; backends may reject others
|
||||
map<string, string> params = 3; // backend-specific tuning
|
||||
}
|
||||
|
||||
message TranscriptLiveAudio {
|
||||
repeated float pcm = 1; // mono PCM in [-1,1] at config.sample_rate
|
||||
}
|
||||
|
||||
message TranscriptLiveResponse {
|
||||
bool ready = 1; // open ack: sent once, before any delta
|
||||
string delta = 2; // newly-finalized text since previous response
|
||||
bool eou = 3; // <EOU> fired during this feed (the user yielded the turn)
|
||||
repeated TranscriptWord words = 4; // words finalized by this feed (stream-relative ns)
|
||||
TranscriptResult final_result = 5; // terminal message only, after the send side closes
|
||||
bool eob = 6; // <EOB> fired: a backchannel ("uh-huh") ended — NOT a turn boundary
|
||||
}
|
||||
|
||||
message TranscriptWord {
|
||||
int64 start = 1;
|
||||
int64 end = 2;
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
## Clip/LLaVA library for multimodal support — built locally from copied sources
|
||||
set(TARGET myclip)
|
||||
add_library(${TARGET} clip.cpp clip.h llava.cpp llava.h)
|
||||
install(TARGETS ${TARGET} LIBRARY)
|
||||
target_include_directories(myclip PUBLIC .)
|
||||
target_include_directories(myclip PUBLIC ../..)
|
||||
target_include_directories(myclip PUBLIC ../../common)
|
||||
target_link_libraries(${TARGET} PRIVATE common ggml llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
if (NOT MSVC)
|
||||
target_compile_options(${TARGET} PRIVATE -Wno-cast-qual)
|
||||
endif()
|
||||
## Multimodal support is provided by the in-tree `mtmd` library target
|
||||
## (examples/mtmd/), which the grpc-server links and includes below. clip/llava
|
||||
## were pruned upstream; the high-level mtmd_* / mtmd_helper_* API is used instead.
|
||||
|
||||
set(TARGET grpc-server)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
@@ -67,12 +58,16 @@ add_library(hw_grpc_proto
|
||||
${hw_proto_hdrs} )
|
||||
|
||||
add_executable(${TARGET} grpc-server.cpp json.hpp)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama myclip ${CMAKE_THREAD_LIBS_INIT} absl::flags hw_grpc_proto
|
||||
# mtmd public headers (mtmd.h / mtmd-helper.h) live in examples/mtmd/.
|
||||
# Linking the mtmd target also propagates this include dir, but we add it
|
||||
# explicitly for clarity.
|
||||
target_include_directories(${TARGET} PRIVATE ../mtmd)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama mtmd ${CMAKE_THREAD_LIBS_INIT} absl::flags hw_grpc_proto
|
||||
absl::flags_parse
|
||||
gRPC::${_REFLECTION}
|
||||
gRPC::${_GRPC_GRPCPP}
|
||||
protobuf::${_PROTOBUF_LIBPROTOBUF})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
if(TARGET BUILD_INFO)
|
||||
add_dependencies(${TARGET} BUILD_INFO)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=b84902d2ad27c34f989f23947200c4b91b1568fd
|
||||
IK_LLAMA_VERSION?=f74a6fb87b315b2c3154166e075360e15021a61d
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -11,8 +11,8 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <getopt.h>
|
||||
#include "clip.h"
|
||||
#include "llava.h"
|
||||
#include "mtmd.h"
|
||||
#include "mtmd-helper.h"
|
||||
#include "log.h"
|
||||
#include "common.h"
|
||||
#include "json.hpp"
|
||||
@@ -45,7 +45,9 @@ using backend::HealthMessage;
|
||||
|
||||
///// LLAMA.CPP server code below
|
||||
|
||||
using json = nlohmann::json;
|
||||
// Match mtmd.h and ik_llama's server/common headers, which all use
|
||||
// nlohmann::ordered_json; a plain nlohmann::json alias collides at global scope.
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
struct server_params
|
||||
{
|
||||
@@ -219,6 +221,11 @@ struct llama_client_slot
|
||||
|
||||
// multimodal
|
||||
std::vector<slot_image> images;
|
||||
// Full prompt with mtmd media markers (mtmd_default_marker()) substituted in
|
||||
// place of the legacy [img-N] tags, covering the text up to and including the
|
||||
// last image. The text after the last image is kept in params.input_suffix and
|
||||
// decoded through the normal token path so the sampling loop is unchanged.
|
||||
std::string mtmd_prompt;
|
||||
|
||||
// stats
|
||||
size_t sent_count = 0;
|
||||
@@ -252,14 +259,14 @@ struct llama_client_slot
|
||||
|
||||
for (slot_image & img : images)
|
||||
{
|
||||
free(img.image_embedding);
|
||||
if (img.img_data) {
|
||||
clip_image_u8_free(img.img_data);
|
||||
if (img.bitmap) {
|
||||
mtmd_bitmap_free(img.bitmap);
|
||||
img.bitmap = nullptr;
|
||||
}
|
||||
img.prefix_prompt = "";
|
||||
}
|
||||
|
||||
images.clear();
|
||||
mtmd_prompt = "";
|
||||
}
|
||||
|
||||
bool has_budget(gpt_params &global_params) {
|
||||
@@ -396,46 +403,13 @@ struct llama_metrics {
|
||||
}
|
||||
};
|
||||
|
||||
struct llava_embd_batch {
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id> seq_id_0;
|
||||
std::vector<llama_seq_id *> seq_ids;
|
||||
std::vector<int8_t> logits;
|
||||
llama_batch batch;
|
||||
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
|
||||
pos .resize(n_tokens);
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_ids .resize(n_tokens + 1);
|
||||
logits .resize(n_tokens);
|
||||
seq_id_0.resize(1);
|
||||
seq_id_0[0] = seq_id;
|
||||
seq_ids [n_tokens] = nullptr;
|
||||
batch = {
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ embd,
|
||||
/*pos =*/ pos.data(),
|
||||
/*n_seq_id =*/ n_seq_id.data(),
|
||||
/*seq_id =*/ seq_ids.data(),
|
||||
/*logits =*/ logits.data(),
|
||||
};
|
||||
for (int i = 0; i < n_tokens; i++) {
|
||||
batch.pos [i] = pos_0 + i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id [i] = seq_id_0.data();
|
||||
batch.logits [i] = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct llama_server_context
|
||||
{
|
||||
llama_model *model = nullptr;
|
||||
llama_context *ctx = nullptr;
|
||||
const llama_vocab * vocab = nullptr;
|
||||
|
||||
clip_ctx *clp_ctx = nullptr;
|
||||
mtmd_context *mctx = nullptr;
|
||||
|
||||
gpt_params params;
|
||||
|
||||
@@ -491,11 +465,6 @@ struct llama_server_context
|
||||
if (!params.mmproj.path.empty()) {
|
||||
multimodal = true;
|
||||
LOG_INFO("Multi Modal Mode Enabled", {});
|
||||
clp_ctx = clip_model_load(params.mmproj.path.c_str(), /*verbosity=*/ 1);
|
||||
if(clp_ctx == nullptr) {
|
||||
LOG_ERR("unable to load clip model: %s", params.mmproj.path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (params.n_ctx < 2048) { // request larger context for the image embedding
|
||||
params.n_ctx = 2048;
|
||||
@@ -512,10 +481,24 @@ struct llama_server_context
|
||||
}
|
||||
|
||||
if (multimodal) {
|
||||
const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
|
||||
const int n_embd_llm = llama_model_n_embd(model);
|
||||
if (n_embd_clip != n_embd_llm) {
|
||||
LOG("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_embd_clip, n_embd_llm);
|
||||
// mtmd_init_from_file requires the already-loaded text model, so it must
|
||||
// run AFTER llama_init_from_gpt_params. It validates the projector
|
||||
// against the model internally and returns nullptr on dim mismatch, so
|
||||
// the explicit clip_n_mmproj_embd check is no longer needed.
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
mparams.use_gpu = params.mmproj_use_gpu;
|
||||
mparams.print_timings = false;
|
||||
mparams.n_threads = params.n_threads_mtmd != -1 ? params.n_threads_mtmd
|
||||
: params.n_threads_batch != -1 ? params.n_threads_batch
|
||||
: params.n_threads;
|
||||
mparams.verbosity = GGML_LOG_LEVEL_INFO;
|
||||
mparams.flash_attn_type = params.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED
|
||||
: LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
||||
mparams.image_min_tokens = params.image_min_tokens;
|
||||
mparams.image_max_tokens = params.image_max_tokens;
|
||||
mctx = mtmd_init_from_file(params.mmproj.path.c_str(), model, mparams);
|
||||
if (mctx == nullptr) {
|
||||
LOG_ERR("unable to load multimodal projector: %s", params.mmproj.path.c_str());
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
return false;
|
||||
@@ -865,8 +848,8 @@ struct llama_server_context
|
||||
|
||||
slot_image img_sl;
|
||||
img_sl.id = img.count("id") != 0 ? img["id"].get<int>() : slot->images.size();
|
||||
img_sl.img_data = clip_image_u8_init();
|
||||
if (!clip_image_load_from_bytes(image_buffer.data(), image_buffer.size(), img_sl.img_data))
|
||||
img_sl.bitmap = mtmd_helper_bitmap_init_from_buf(mctx, image_buffer.data(), image_buffer.size());
|
||||
if (img_sl.bitmap == nullptr)
|
||||
{
|
||||
LOG_ERR("%s: failed to load image, slot_id: %d, img_sl_id: %d",
|
||||
__func__,
|
||||
@@ -879,50 +862,74 @@ struct llama_server_context
|
||||
{"slot_id", slot->id},
|
||||
{"img_sl_id", img_sl.id}
|
||||
});
|
||||
img_sl.request_encode_image = true;
|
||||
slot->images.push_back(img_sl);
|
||||
}
|
||||
// process prompt
|
||||
// example: system prompt [img-102] user [img-103] describe [img-134] -> [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, {id: 134, prefix: ' describe '}]}
|
||||
// Translate the legacy [img-N] tags into mtmd media markers, in
|
||||
// order, and collect the matching bitmaps in marker order so they
|
||||
// line up with the markers passed to mtmd_tokenize(). The text after
|
||||
// the last image stays in input_suffix and is decoded through the
|
||||
// normal token path, so the sampling loop is unchanged.
|
||||
// example: system prompt [img-102] user [img-103] describe [img-134]
|
||||
if (slot->images.size() > 0 && !slot->prompt.is_array())
|
||||
{
|
||||
const std::string marker = mtmd_default_marker();
|
||||
std::string prompt = slot->prompt.get<std::string>();
|
||||
size_t pos = 0, begin_prefix = 0;
|
||||
std::string built_prompt;
|
||||
std::vector<slot_image> ordered;
|
||||
size_t pos = 0, copy_from = 0;
|
||||
std::string pattern = "[img-";
|
||||
while ((pos = prompt.find(pattern, pos)) != std::string::npos) {
|
||||
size_t end_prefix = pos;
|
||||
pos += pattern.length();
|
||||
size_t end_pos = prompt.find(']', pos);
|
||||
if (end_pos != std::string::npos)
|
||||
{
|
||||
std::string image_id = prompt.substr(pos, end_pos - pos);
|
||||
try
|
||||
{
|
||||
int img_id = std::stoi(image_id);
|
||||
bool found = false;
|
||||
for (slot_image &img : slot->images)
|
||||
{
|
||||
if (img.id == img_id) {
|
||||
found = true;
|
||||
img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix);
|
||||
begin_prefix = end_pos + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
LOG("ERROR: Image with id: %i, not found.\n", img_id);
|
||||
slot->images.clear();
|
||||
return false;
|
||||
}
|
||||
} catch (const std::invalid_argument& e) {
|
||||
LOG("Invalid image number id in prompt\n");
|
||||
slot->images.clear();
|
||||
return false;
|
||||
|
||||
auto free_images = [&]() {
|
||||
for (slot_image &img : slot->images) {
|
||||
if (img.bitmap) {
|
||||
mtmd_bitmap_free(img.bitmap);
|
||||
img.bitmap = nullptr;
|
||||
}
|
||||
}
|
||||
slot->images.clear();
|
||||
};
|
||||
|
||||
while ((pos = prompt.find(pattern, pos)) != std::string::npos) {
|
||||
size_t tag_begin = pos;
|
||||
pos += pattern.length();
|
||||
size_t end_pos = prompt.find(']', pos);
|
||||
if (end_pos == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
std::string image_id = prompt.substr(pos, end_pos - pos);
|
||||
try
|
||||
{
|
||||
int img_id = std::stoi(image_id);
|
||||
bool found = false;
|
||||
for (slot_image &img : slot->images)
|
||||
{
|
||||
if (img.id == img_id) {
|
||||
found = true;
|
||||
// text before this tag, then the media marker
|
||||
built_prompt += prompt.substr(copy_from, tag_begin - copy_from);
|
||||
built_prompt += marker;
|
||||
copy_from = end_pos + 1;
|
||||
ordered.push_back(img);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
LOG("ERROR: Image with id: %i, not found.\n", img_id);
|
||||
free_images();
|
||||
return false;
|
||||
}
|
||||
} catch (const std::invalid_argument& e) {
|
||||
LOG("Invalid image number id in prompt\n");
|
||||
free_images();
|
||||
return false;
|
||||
}
|
||||
pos = end_pos + 1;
|
||||
}
|
||||
// bitmaps are consumed in marker order by mtmd_tokenize()
|
||||
slot->images = ordered;
|
||||
slot->mtmd_prompt = built_prompt;
|
||||
slot->prompt = "";
|
||||
slot->params.input_suffix = prompt.substr(begin_prefix);
|
||||
slot->params.input_suffix = prompt.substr(copy_from);
|
||||
slot->params.cache_prompt = false; // multimodal doesn't support cache prompt
|
||||
}
|
||||
}
|
||||
@@ -1176,21 +1183,10 @@ struct llama_server_context
|
||||
|
||||
bool process_images(llama_client_slot &slot) const
|
||||
{
|
||||
for (slot_image &img : slot.images)
|
||||
{
|
||||
if (!img.request_encode_image)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!llava_image_embed_make_with_clip_img(clp_ctx, params.n_threads, img.img_data, &img.image_embedding, &img.image_tokens)) {
|
||||
LOG("Error processing the given image");
|
||||
return false;
|
||||
}
|
||||
|
||||
img.request_encode_image = false;
|
||||
}
|
||||
|
||||
// With the mtmd pipeline, image encoding is no longer eager: the bitmaps
|
||||
// are tokenized and encoded together with the surrounding text inside
|
||||
// ingest_images() via mtmd_tokenize() + mtmd_helper_eval_chunks(). This
|
||||
// just reports whether the slot carries any images to process.
|
||||
return slot.images.size() > 0;
|
||||
}
|
||||
|
||||
@@ -1435,69 +1431,70 @@ struct llama_server_context
|
||||
}
|
||||
}
|
||||
|
||||
// for multiple images processing
|
||||
// Tokenize the multimodal prompt (text interleaved with media markers) together
|
||||
// with the slot's bitmaps, then decode the resulting chunks into the llama
|
||||
// context via the high-level mtmd helper. The helper runs llama_decode() on the
|
||||
// text chunks and mtmd_encode() + llama_decode() on the image chunks, handling
|
||||
// batching and any pre/post decode setup (e.g. non-causal attention for gemma3).
|
||||
// Advances slot.n_past by the number of positions consumed, then leaves the
|
||||
// post-image suffix tokens in `batch` so the normal decode + sampling loop
|
||||
// produces the first generated token.
|
||||
bool ingest_images(llama_client_slot &slot, int n_batch)
|
||||
{
|
||||
int image_idx = 0;
|
||||
|
||||
while (image_idx < (int) slot.images.size())
|
||||
if (mctx == nullptr)
|
||||
{
|
||||
slot_image &img = slot.images[image_idx];
|
||||
LOG("%s : multimodal context is not initialized\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// process prefix prompt
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
|
||||
{
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
};
|
||||
if (llama_decode(ctx, batch_view))
|
||||
{
|
||||
LOG("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// bitmaps stay owned by slot.images (freed on reset()); pass non-owning ptrs
|
||||
std::vector<const mtmd_bitmap *> bitmaps;
|
||||
bitmaps.reserve(slot.images.size());
|
||||
for (const slot_image &img : slot.images)
|
||||
{
|
||||
bitmaps.push_back(img.bitmap);
|
||||
}
|
||||
|
||||
// process image with llm
|
||||
for (int i = 0; i < img.image_tokens; i += n_batch)
|
||||
{
|
||||
int n_eval = img.image_tokens - i;
|
||||
if (n_eval > n_batch)
|
||||
{
|
||||
n_eval = n_batch;
|
||||
}
|
||||
mtmd_input_text inp_txt;
|
||||
inp_txt.text = slot.mtmd_prompt.c_str();
|
||||
inp_txt.add_special = add_bos_token;
|
||||
inp_txt.parse_special = true;
|
||||
|
||||
const int n_embd = llama_model_n_embd(model);
|
||||
float * embd = img.image_embedding + i * n_embd;
|
||||
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, slot.n_past, 0);
|
||||
if (llama_decode(ctx, llava_batch.batch))
|
||||
{
|
||||
LOG("%s : failed to eval image\n", __func__);
|
||||
return false;
|
||||
}
|
||||
slot.n_past += n_eval;
|
||||
}
|
||||
image_idx++;
|
||||
mtmd::input_chunks chunks(mtmd_input_chunks_init());
|
||||
int32_t res = mtmd_tokenize(mctx,
|
||||
chunks.ptr.get(),
|
||||
&inp_txt,
|
||||
bitmaps.data(),
|
||||
bitmaps.size());
|
||||
if (res != 0)
|
||||
{
|
||||
LOG("%s : failed to tokenize multimodal prompt, res = %d\n", __func__, res);
|
||||
return false;
|
||||
}
|
||||
|
||||
common_batch_clear(batch);
|
||||
const llama_pos start_pos = (llama_pos) system_tokens.size() + slot.n_past;
|
||||
llama_pos new_n_past = start_pos;
|
||||
if (mtmd_helper_eval_chunks(mctx,
|
||||
ctx,
|
||||
chunks.ptr.get(),
|
||||
start_pos,
|
||||
slot.id,
|
||||
n_batch,
|
||||
/*logits_last=*/ false,
|
||||
&new_n_past) != 0)
|
||||
{
|
||||
LOG("%s : failed to eval multimodal chunks\n", __func__);
|
||||
return false;
|
||||
}
|
||||
slot.n_past += (int32_t) (new_n_past - start_pos);
|
||||
|
||||
// append prefix of next image
|
||||
const auto json_prompt = (image_idx >= (int) slot.images.size()) ?
|
||||
slot.params.input_suffix : // no more images, then process suffix prompt
|
||||
(json)(slot.images[image_idx].prefix_prompt);
|
||||
|
||||
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
|
||||
for (int i = 0; i < (int) append_tokens.size(); ++i)
|
||||
{
|
||||
common_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
|
||||
slot.n_past += 1;
|
||||
}
|
||||
// queue the post-image suffix text for the normal decode + sampling path
|
||||
common_batch_clear(batch);
|
||||
std::vector<llama_token> suffix_tokens = tokenize(slot.params.input_suffix, false);
|
||||
for (llama_token tok : suffix_tokens)
|
||||
{
|
||||
common_batch_add(batch, tok, system_tokens.size() + slot.n_past, { slot.id }, false);
|
||||
slot.n_past += 1;
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -1884,8 +1881,11 @@ struct llama_server_context
|
||||
|
||||
const bool has_images = process_images(slot);
|
||||
|
||||
// process the prefix of first image
|
||||
std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
|
||||
// For the multimodal path the whole pre-image / inter-image text is
|
||||
// tokenized and decoded inside ingest_images() via mtmd, so no prefix
|
||||
// tokens are queued here; the post-image suffix is appended by
|
||||
// ingest_images() for the normal decode + sampling loop.
|
||||
std::vector<llama_token> prefix_tokens = has_images ? std::vector<llama_token>() : prompt_tokens;
|
||||
|
||||
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
|
||||
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
--- a/examples/llava/clip.cpp
|
||||
+++ b/examples/llava/clip.cpp
|
||||
@@ -2494,7 +2494,7 @@
|
||||
}
|
||||
new_data = work.data();
|
||||
|
||||
- new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, n_elms/cur->ne[0], cur->ne[0], nullptr);
|
||||
+ new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, n_elms/cur->ne[0], cur->ne[0], nullptr, nullptr);
|
||||
} else {
|
||||
new_type = cur->type;
|
||||
new_data = cur->data;
|
||||
@@ -17,28 +17,9 @@ cp -r grpc-server.cpp llama.cpp/examples/grpc-server/
|
||||
cp -r utils.hpp llama.cpp/examples/grpc-server/
|
||||
cp -rfv llama.cpp/vendor/nlohmann/json.hpp llama.cpp/examples/grpc-server/
|
||||
|
||||
## Copy clip/llava files for multimodal support (built as myclip library)
|
||||
cp -rfv llama.cpp/examples/llava/clip.h llama.cpp/examples/grpc-server/clip.h
|
||||
cp -rfv llama.cpp/examples/llava/clip.cpp llama.cpp/examples/grpc-server/clip.cpp
|
||||
cp -rfv llama.cpp/examples/llava/llava.cpp llama.cpp/examples/grpc-server/llava.cpp
|
||||
# Prepend llama.h include to llava.h
|
||||
echo '#include "llama.h"' > llama.cpp/examples/grpc-server/llava.h
|
||||
cat llama.cpp/examples/llava/llava.h >> llama.cpp/examples/grpc-server/llava.h
|
||||
# Copy clip-impl.h if it exists
|
||||
if [ -f llama.cpp/examples/llava/clip-impl.h ]; then
|
||||
cp -rfv llama.cpp/examples/llava/clip-impl.h llama.cpp/examples/grpc-server/clip-impl.h
|
||||
fi
|
||||
# Copy stb_image.h
|
||||
if [ -f llama.cpp/vendor/stb/stb_image.h ]; then
|
||||
cp -rfv llama.cpp/vendor/stb/stb_image.h llama.cpp/examples/grpc-server/stb_image.h
|
||||
elif [ -f llama.cpp/common/stb_image.h ]; then
|
||||
cp -rfv llama.cpp/common/stb_image.h llama.cpp/examples/grpc-server/stb_image.h
|
||||
fi
|
||||
|
||||
## Fix API compatibility in llava.cpp (llama_n_embd -> llama_model_n_embd)
|
||||
if [ -f llama.cpp/examples/grpc-server/llava.cpp ]; then
|
||||
sed -i 's/llama_n_embd(/llama_model_n_embd(/g' llama.cpp/examples/grpc-server/llava.cpp
|
||||
fi
|
||||
## Multimodal support is provided by the `mtmd` library target (examples/mtmd/),
|
||||
## which the grpc-server links and includes directly. No source copy is needed:
|
||||
## clip/llava were pruned upstream and the high-level mtmd_* API is used instead.
|
||||
|
||||
set +e
|
||||
if grep -q "grpc-server" llama.cpp/examples/CMakeLists.txt; then
|
||||
|
||||
@@ -11,9 +11,12 @@
|
||||
|
||||
#include "json.hpp"
|
||||
|
||||
#include "clip.h"
|
||||
#include "mtmd.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
// mtmd.h and ik_llama's entire server/common stack (chat.h, server-common.h,
|
||||
// server-task.h, ...) declare `using json = nlohmann::ordered_json`, so match it
|
||||
// here: a plain `nlohmann::json` alias collides with mtmd.h's at global scope.
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
extern bool server_verbose;
|
||||
|
||||
@@ -111,13 +114,12 @@ struct slot_image
|
||||
{
|
||||
int32_t id;
|
||||
|
||||
bool request_encode_image = false;
|
||||
float * image_embedding = nullptr;
|
||||
int32_t image_tokens = 0;
|
||||
|
||||
clip_image_u8 * img_data;
|
||||
|
||||
std::string prefix_prompt; // before of this image
|
||||
// mtmd bitmap (image/audio) decoded from the request buffer. Owned by the
|
||||
// slot; freed via mtmd_bitmap_free() on reset. The high-level mtmd pipeline
|
||||
// (mtmd_tokenize + mtmd_helper_eval_chunks) consumes these directly, so the
|
||||
// legacy eager-encode fields (embedding/tokens) and per-image prefix prompt
|
||||
// are no longer needed.
|
||||
mtmd_bitmap * bitmap = nullptr;
|
||||
};
|
||||
|
||||
// completion token output with probabilities
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=9d5d882d8cd0f0a9283d87ed5e6fe3ee0d925fb1
|
||||
LLAMA_VERSION?=6f4f53f2b7da54fcdbbecaaa734337c337ad6176
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
@@ -156,11 +156,11 @@ llama-cpp-grpc: llama.cpp
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build purge
|
||||
$(info ${GREEN}I llama-cpp build info:grpc${RESET})
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_RPC=ON -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" TARGET="--target grpc-server --target rpc-server" $(MAKE) VARIANT="llama-cpp-grpc-build" build-llama-cpp-grpc-server
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_RPC=ON -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" TARGET="--target grpc-server --target ggml-rpc-server" $(MAKE) VARIANT="llama-cpp-grpc-build" build-llama-cpp-grpc-server
|
||||
cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build/grpc-server llama-cpp-grpc
|
||||
|
||||
llama-cpp-rpc-server: llama-cpp-grpc
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build/llama.cpp/build/bin/rpc-server llama-cpp-rpc-server
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build/llama.cpp/build/bin/ggml-rpc-server llama-cpp-rpc-server
|
||||
|
||||
llama.cpp:
|
||||
mkdir -p llama.cpp
|
||||
|
||||
@@ -30,6 +30,19 @@
|
||||
#define LOCALAI_HAS_SERVER_SCHEMA 1
|
||||
#include "server-schema.cpp"
|
||||
#endif
|
||||
// server-stream.cpp exists only in llama.cpp after the upstream refactor that
|
||||
// added the SSE stream-resumption layer (stream_session/stream_pipe_producer).
|
||||
// server-context.cpp calls into it (spipe->cleanup(), stream_aware_should_stop,
|
||||
// stream_session_attach_pipe), so its definitions must be part of this
|
||||
// translation unit or the link fails with "undefined reference to
|
||||
// stream_pipe_producer::cleanup()". The file is self-contained (its only
|
||||
// external symbols come from server-common, already pulled in above) and the
|
||||
// http route-handler factories it also defines are unused here but harmless.
|
||||
// __has_include keeps the source compatible with older pins/forks that predate
|
||||
// the split.
|
||||
#if __has_include("server-stream.cpp")
|
||||
#include "server-stream.cpp"
|
||||
#endif
|
||||
#include "server-context.cpp"
|
||||
|
||||
// LocalAI
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
# Local development: point at a working checkout instead of cloning, e.g.
|
||||
# make PRIVACY_FILTER_SRC=$HOME/c/privacy-filter.cpp grpc-server
|
||||
|
||||
PRIVACY_FILTER_VERSION?=98f52c5ef2250f207cc6b9a6aef05393a120cb7c
|
||||
PRIVACY_FILTER_VERSION?=595f59630c69d361b5196f2aba2c71c873d0c13c
|
||||
PRIVACY_FILTER_REPO?=https://github.com/localai-org/privacy-filter.cpp
|
||||
PRIVACY_FILTER_SRC?=
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# CrispASR version (release tag)
|
||||
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
||||
CRISPASR_VERSION?=8f1218141b792b8868861c1af17ba1e361b05dc0
|
||||
CRISPASR_VERSION?=3b93758f9725d400eca82976f895e4cec3f31260
|
||||
SO_TARGET?=libgocrispasr.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
18
backend/go/face-detect/.gitignore
vendored
Normal file
18
backend/go/face-detect/.gitignore
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
# Fetched upstream sources
|
||||
sources/
|
||||
|
||||
# CMake build directories
|
||||
build*/
|
||||
|
||||
# build artifacts staged in-tree by the Makefile (cp from sources/) or
|
||||
# symlinked for local dev; the real sources live in face-detect.cpp upstream.
|
||||
*.so
|
||||
*.so.*
|
||||
facedetect_capi.h
|
||||
compile_commands.json
|
||||
|
||||
# Compiled backend binary
|
||||
face-detect-grpc
|
||||
|
||||
# Packaging output
|
||||
package/
|
||||
110
backend/go/face-detect/Makefile
Normal file
110
backend/go/face-detect/Makefile
Normal file
@@ -0,0 +1,110 @@
|
||||
# face-detect backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as FACEDETECT_VERSION?=e22260d5d5490b37b021b7f795079f386d553afd
|
||||
# can find and update it - matches the voice-detect / parakeet.cpp / whisper.cpp
|
||||
# convention).
|
||||
#
|
||||
# Local dev shortcut: if you already have an out-of-tree face-detect.cpp build,
|
||||
# symlink the .so + header into this directory and skip the clone/cmake steps:
|
||||
#
|
||||
# ln -sf /path/to/face-detect.cpp/build-shared/libfacedetect.so .
|
||||
# ln -sf /path/to/face-detect.cpp/include/facedetect_capi.h .
|
||||
# go build -o face-detect-grpc .
|
||||
#
|
||||
# The default target below does the proper clone-at-pin + cmake build so CI does
|
||||
# not need a side-checkout.
|
||||
|
||||
FACEDETECT_VERSION?=e22260d5d5490b37b021b7f795079f386d553afd
|
||||
FACEDETECT_REPO?=https://github.com/mudler/face-detect.cpp
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
||||
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
# Resolve the target arch. The backend matrix / Docker build pass TARGETARCH
|
||||
# (amd64|arm64); fall back to uname -m (aarch64|x86_64) for a local build.
|
||||
RECON_ARCH?=$(or $(TARGETARCH),$(shell uname -m))
|
||||
|
||||
# Build ggml + the vendored libjpeg-turbo statically into libfacedetect.so (PIC)
|
||||
# so the shared lib is self-contained: dlopen needs no libggml*.so alongside it,
|
||||
# only system libs (libstdc++/libgomp/libc) the runtime image already provides.
|
||||
# The vendored jpeg symbols are hidden via -Wl,--exclude-libs,ALL on the C++
|
||||
# side, so only the facedetect_capi_* surface is exported.
|
||||
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DFACEDETECT_SHARED=ON -DFACEDETECT_BUILD_CLI=OFF -DFACEDETECT_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# face-detect.cpp gates its GGML backends behind FACEDETECT_GGML_* options and
|
||||
# does set(GGML_CUDA ${FACEDETECT_GGML_CUDA} CACHE BOOL "" FORCE), so a bare
|
||||
# -DGGML_CUDA=ON is overwritten back to OFF. Forward the FACEDETECT_GGML_*
|
||||
# options instead. (openblas is not gated, so -DGGML_BLAS passes through.)
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DFACEDETECT_GGML_CUDA=ON
|
||||
# Opt-in cuDNN implicit-GEMM conv path (kills im2col on GPU, SCRFD 2.3x
|
||||
# vs torch-cuDNN parity). Only the arm64 + CUDA 13 image (GB10/Jetson/L4T)
|
||||
# ships libcudnn9 + the -dev headers, so gate cuDNN to that variant.
|
||||
# x86 CUDA images carry no cuDNN -> enabling it there is a link failure.
|
||||
ifeq ($(CUDA_MAJOR_VERSION),13)
|
||||
ifneq (,$(filter arm64 aarch64,$(RECON_ARCH)))
|
||||
CMAKE_ARGS+=-DFACEDETECT_GGML_CUDNN=ON
|
||||
endif
|
||||
endif
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DFACEDETECT_GGML_HIP=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DFACEDETECT_GGML_VULKAN=ON
|
||||
else ifeq ($(BUILD_TYPE),metal)
|
||||
CMAKE_ARGS+=-DFACEDETECT_GGML_METAL=ON
|
||||
endif
|
||||
|
||||
.PHONY: face-detect-grpc package build clean purge test all
|
||||
|
||||
all: face-detect-grpc
|
||||
|
||||
# Clone the upstream face-detect.cpp source at the pinned commit. Directory acts
|
||||
# as the target so make only re-clones when missing. After a FACEDETECT_VERSION
|
||||
# bump, run 'make purge && make' to refetch.
|
||||
sources/face-detect.cpp:
|
||||
mkdir -p sources/face-detect.cpp
|
||||
cd sources/face-detect.cpp && \
|
||||
git init -q && \
|
||||
git remote add origin $(FACEDETECT_REPO) && \
|
||||
git fetch --depth 1 origin $(FACEDETECT_VERSION) && \
|
||||
git checkout FETCH_HEAD && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Build the shared lib + header out-of-tree, then stage them next to the Go
|
||||
# sources so purego.Dlopen("libfacedetect.so") and the cgo-less build both pick
|
||||
# them up.
|
||||
libfacedetect.so: sources/face-detect.cpp
|
||||
cmake -B sources/face-detect.cpp/build-shared -S sources/face-detect.cpp $(CMAKE_ARGS)
|
||||
cmake --build sources/face-detect.cpp/build-shared --config Release -j$(JOBS) --target facedetect
|
||||
cp -fv sources/face-detect.cpp/build-shared/libfacedetect.so* ./ 2>/dev/null || true
|
||||
cp -fv sources/face-detect.cpp/include/facedetect_capi.h ./
|
||||
|
||||
face-detect-grpc: libfacedetect.so main.go gofacedetect.go options.go
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o face-detect-grpc .
|
||||
|
||||
package: face-detect-grpc
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
# Test target. The embed/detect/verify/analyze smoke specs are gated on
|
||||
# FACEDETECT_BACKEND_TEST_MODEL + FACEDETECT_BACKEND_TEST_IMAGE; without them the
|
||||
# heavy specs auto-skip and only the pure-Go parsing specs run.
|
||||
test:
|
||||
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
|
||||
|
||||
clean: purge
|
||||
rm -rf libfacedetect.so* facedetect_capi.h package face-detect-grpc
|
||||
|
||||
purge:
|
||||
rm -rf sources/face-detect.cpp
|
||||
431
backend/go/face-detect/gofacedetect.go
Normal file
431
backend/go/face-detect/gofacedetect.go
Normal file
@@ -0,0 +1,431 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// purego-bound entry points from libfacedetect.so. Names match
|
||||
// facedetect_capi.h exactly so a `nm libfacedetect.so | grep facedetect_capi`
|
||||
// is enough to spot drift.
|
||||
//
|
||||
// The opaque ctx and the malloc'd char*/float* return values are declared as
|
||||
// uintptr so we get the raw pointer back and can release it via the matching
|
||||
// capi free function. purego's native string/[]float32 returns would copy and
|
||||
// forget the original pointer, leaking the C-owned buffer on every call.
|
||||
var (
|
||||
CppAbiVersion func() int32
|
||||
CppLoad func(ggufPath string) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppFreeString func(s uintptr)
|
||||
CppFreeVec func(v uintptr)
|
||||
CppEmbedPath func(ctx uintptr, imagePath string, outVec, outDim unsafe.Pointer) int32
|
||||
CppEmbedRGB func(ctx uintptr, rgb []byte, width, height int32, outVec, outDim unsafe.Pointer) int32
|
||||
CppDetectJSON func(ctx uintptr, imagePath string) uintptr
|
||||
CppVerifyPaths func(ctx uintptr, a, b string, threshold float32, antiSpoof int32, outDistance, outVerified unsafe.Pointer) int32
|
||||
CppAnalyzeJSON func(ctx uintptr, imagePath string) uintptr
|
||||
)
|
||||
|
||||
// FaceDetect implements the face-recognition (biometric) subset of the Backend
|
||||
// gRPC service over libfacedetect.so. The C side keeps a single loaded model
|
||||
// pack plus a per-ctx last-error buffer and is not reentrant, so
|
||||
// base.SingleThread serializes every call.
|
||||
type FaceDetect struct {
|
||||
base.SingleThread
|
||||
opts loadOptions
|
||||
ctxPtr uintptr
|
||||
}
|
||||
|
||||
func (f *FaceDetect) Load(opts *pb.ModelOptions) error {
|
||||
model := opts.ModelFile
|
||||
if model == "" {
|
||||
model = opts.ModelPath
|
||||
}
|
||||
if !filepath.IsAbs(model) && opts.ModelPath != "" {
|
||||
model = filepath.Join(opts.ModelPath, model)
|
||||
}
|
||||
if model == "" {
|
||||
return errors.New("face-detect: ModelFile is required")
|
||||
}
|
||||
|
||||
f.opts = parseOptions(opts.Options)
|
||||
if f.opts.modelName == "" {
|
||||
f.opts.modelName = filepath.Base(model)
|
||||
}
|
||||
|
||||
// Propagate LocalAI's per-model thread budget to the engine. LocalAI spawns
|
||||
// one backend process per model and serves requests concurrently, so the
|
||||
// engine's own min(hardware_concurrency, 8) default can oversubscribe cores.
|
||||
// FACEDETECT_THREADS is read by the engine at backend construction, so it
|
||||
// must be set before the capi load. A non-positive Threads means "unset":
|
||||
// leave the env alone so the engine keeps its sane default.
|
||||
threads := opts.Threads
|
||||
if threads > 0 {
|
||||
if err := os.Setenv("FACEDETECT_THREADS", strconv.Itoa(int(threads))); err != nil {
|
||||
return fmt.Errorf("face-detect: set FACEDETECT_THREADS: %w", err)
|
||||
}
|
||||
xlog.Info("face-detect: applying LocalAI thread budget", "threads", threads)
|
||||
}
|
||||
|
||||
xlog.Info("face-detect: loading model", "model", model,
|
||||
"verify_threshold", f.opts.verifyThreshold, "abi", CppAbiVersion())
|
||||
|
||||
ctx := CppLoad(model)
|
||||
if ctx == 0 {
|
||||
// The last-error buffer lives on the ctx that was never returned, so
|
||||
// surface the path the operator tried to load instead.
|
||||
return fmt.Errorf("face-detect: facedetect_capi_load failed for %q", model)
|
||||
}
|
||||
f.ctxPtr = ctx
|
||||
return nil
|
||||
}
|
||||
|
||||
// Embeddings returns the L2-normalized ArcFace embedding of the primary face in
|
||||
// the supplied image. Mirroring the Python face backend, the image is read from
|
||||
// Images[0] as a base64 payload; materializeImage decodes it to a temp file so
|
||||
// the path-based C-API can run its own decode (cv2.imread parity). The gRPC
|
||||
// server wraps the returned slice in an EmbeddingResult.
|
||||
func (f *FaceDetect) Embeddings(req *pb.PredictOptions) ([]float32, error) {
|
||||
if f.ctxPtr == 0 {
|
||||
return nil, errors.New("face-detect: model not loaded")
|
||||
}
|
||||
if len(req.Images) == 0 || req.Images[0] == "" {
|
||||
return nil, errors.New("face-detect: Embedding requires Images[0] to be a base64 image")
|
||||
}
|
||||
|
||||
path, cleanup, err := materializeImage(req.Images[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
return f.embedPath(path)
|
||||
}
|
||||
|
||||
func (f *FaceDetect) embedPath(path string) ([]float32, error) {
|
||||
var vec uintptr
|
||||
var dim int32
|
||||
rc := CppEmbedPath(f.ctxPtr, path, unsafe.Pointer(&vec), unsafe.Pointer(&dim))
|
||||
if rc != 0 || vec == 0 || dim <= 0 {
|
||||
return nil, f.lastErr("embed", path)
|
||||
}
|
||||
defer CppFreeVec(vec)
|
||||
// Copy out of the C-owned malloc'd buffer before freeing it. The
|
||||
// uintptr->Pointer conversion trips vet's unsafeptr check, which can't tell
|
||||
// a C heap pointer from Go-managed memory; safe here, the GC neither tracks
|
||||
// nor moves this buffer and we copy immediately.
|
||||
src := unsafe.Slice((*float32)(unsafe.Pointer(vec)), int(dim)) //nolint:govet // C-owned malloc'd vector, copied out before free
|
||||
out := make([]float32, int(dim))
|
||||
copy(out, src)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Detect runs SCRFD over the image and returns one Detection per face. The
|
||||
// C-API emits a box as [x1,y1,x2,y2] in pixels; the proto carries x/y plus
|
||||
// width/height, so the corners are converted. The 5 facial landmarks the engine
|
||||
// also returns are dropped: the Detection message has no field for them.
|
||||
func (f *FaceDetect) Detect(req *pb.DetectOptions) (pb.DetectResponse, error) {
|
||||
if f.ctxPtr == 0 {
|
||||
return pb.DetectResponse{}, errors.New("face-detect: model not loaded")
|
||||
}
|
||||
if req.Src == "" {
|
||||
return pb.DetectResponse{}, errors.New("face-detect: src image is required")
|
||||
}
|
||||
|
||||
path, cleanup, err := materializeImage(req.Src)
|
||||
if err != nil {
|
||||
return pb.DetectResponse{}, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
faces, err := f.detectFaces(path)
|
||||
if err != nil {
|
||||
return pb.DetectResponse{}, err
|
||||
}
|
||||
|
||||
dets := make([]*pb.Detection, 0, len(faces))
|
||||
for _, fc := range faces {
|
||||
if req.Threshold > 0 && fc.Score < req.Threshold {
|
||||
continue
|
||||
}
|
||||
x, y, w, h := fc.xywh()
|
||||
dets = append(dets, &pb.Detection{
|
||||
X: x,
|
||||
Y: y,
|
||||
Width: w,
|
||||
Height: h,
|
||||
Confidence: fc.Score,
|
||||
ClassName: "face",
|
||||
})
|
||||
}
|
||||
return pb.DetectResponse{Detections: dets}, nil
|
||||
}
|
||||
|
||||
// FaceVerify embeds the primary face in each image and reports whether they are
|
||||
// the same identity by cosine distance against a threshold. A request threshold
|
||||
// <= 0 falls back to the model-configured default (verify_threshold option,
|
||||
// 0.35 if unset). When anti_spoofing is set, the C-API applies a MiniFASNet
|
||||
// veto internally (verified forced false on a spoof); the per-image liveness
|
||||
// scores are not exposed by the verify entry point, so img*_is_real /
|
||||
// img*_antispoof_score stay at their zero values.
|
||||
func (f *FaceDetect) FaceVerify(req *pb.FaceVerifyRequest) (pb.FaceVerifyResponse, error) {
|
||||
if f.ctxPtr == 0 {
|
||||
return pb.FaceVerifyResponse{}, errors.New("face-detect: model not loaded")
|
||||
}
|
||||
if req.Img1 == "" || req.Img2 == "" {
|
||||
return pb.FaceVerifyResponse{}, errors.New("face-detect: img1 and img2 are required")
|
||||
}
|
||||
|
||||
path1, cleanup1, err := materializeImage(req.Img1)
|
||||
if err != nil {
|
||||
return pb.FaceVerifyResponse{}, err
|
||||
}
|
||||
defer cleanup1()
|
||||
path2, cleanup2, err := materializeImage(req.Img2)
|
||||
if err != nil {
|
||||
return pb.FaceVerifyResponse{}, err
|
||||
}
|
||||
defer cleanup2()
|
||||
|
||||
threshold := req.Threshold
|
||||
if threshold <= 0 {
|
||||
threshold = f.opts.verifyThreshold
|
||||
}
|
||||
|
||||
antiSpoof := int32(0)
|
||||
if req.AntiSpoofing {
|
||||
antiSpoof = 1
|
||||
}
|
||||
|
||||
started := time.Now()
|
||||
var distance float32
|
||||
var verified int32
|
||||
rc := CppVerifyPaths(f.ctxPtr, path1, path2, threshold, antiSpoof,
|
||||
unsafe.Pointer(&distance), unsafe.Pointer(&verified))
|
||||
if rc != 0 {
|
||||
return pb.FaceVerifyResponse{}, f.lastErr("verify", req.Img1[:min(8, len(req.Img1))]+"...")
|
||||
}
|
||||
elapsedMs := float32(time.Since(started).Seconds() * 1000.0)
|
||||
|
||||
// Confidence decays linearly from 100 at distance 0 to 0 at the threshold,
|
||||
// matching the Python face backend's reporting.
|
||||
confidence := float32(0)
|
||||
if threshold > 0 {
|
||||
confidence = float32(math.Max(0, math.Min(100, (1.0-float64(distance)/float64(threshold))*100.0)))
|
||||
}
|
||||
|
||||
return pb.FaceVerifyResponse{
|
||||
Verified: verified != 0,
|
||||
Distance: distance,
|
||||
Threshold: threshold,
|
||||
Confidence: confidence,
|
||||
Model: f.opts.modelName,
|
||||
Img1Area: f.bestArea(path1),
|
||||
Img2Area: f.bestArea(path2),
|
||||
ProcessingTimeMs: elapsedMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FaceAnalyze runs the genderage head on every detected face. The C-API returns
|
||||
// "M"/"F" gender labels and a rounded age; the labels are normalized to the
|
||||
// "Man"/"Woman" values the proto documents.
|
||||
func (f *FaceDetect) FaceAnalyze(req *pb.FaceAnalyzeRequest) (pb.FaceAnalyzeResponse, error) {
|
||||
if f.ctxPtr == 0 {
|
||||
return pb.FaceAnalyzeResponse{}, errors.New("face-detect: model not loaded")
|
||||
}
|
||||
if req.Img == "" {
|
||||
return pb.FaceAnalyzeResponse{}, errors.New("face-detect: img is required")
|
||||
}
|
||||
|
||||
path, cleanup, err := materializeImage(req.Img)
|
||||
if err != nil {
|
||||
return pb.FaceAnalyzeResponse{}, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
ptr := CppAnalyzeJSON(f.ctxPtr, path)
|
||||
if ptr == 0 {
|
||||
return pb.FaceAnalyzeResponse{}, f.lastErr("analyze", path)
|
||||
}
|
||||
defer CppFreeString(ptr)
|
||||
|
||||
faces, err := parseAnalyzeJSON(goStringFromCPtr(ptr))
|
||||
if err != nil {
|
||||
return pb.FaceAnalyzeResponse{}, fmt.Errorf("face-detect: analyze JSON: %w", err)
|
||||
}
|
||||
return pb.FaceAnalyzeResponse{Faces: faces}, nil
|
||||
}
|
||||
|
||||
// faceBox is one entry of the detect/analyze JSON documents the engine emits.
|
||||
type faceBox struct {
|
||||
Score float32 `json:"score"`
|
||||
Box []float32 `json:"box"`
|
||||
Age float32 `json:"age"`
|
||||
Gender string `json:"gender"`
|
||||
}
|
||||
|
||||
// xywh converts the engine's [x1,y1,x2,y2] box into the x/y/width/height the
|
||||
// proto carries. A short or missing box yields zeros.
|
||||
func (b faceBox) xywh() (x, y, w, h float32) {
|
||||
if len(b.Box) < 4 {
|
||||
return 0, 0, 0, 0
|
||||
}
|
||||
return b.Box[0], b.Box[1], b.Box[2] - b.Box[0], b.Box[3] - b.Box[1]
|
||||
}
|
||||
|
||||
type facesJSON struct {
|
||||
Faces []faceBox `json:"faces"`
|
||||
}
|
||||
|
||||
func (f *FaceDetect) detectFaces(path string) ([]faceBox, error) {
|
||||
ptr := CppDetectJSON(f.ctxPtr, path)
|
||||
if ptr == 0 {
|
||||
return nil, f.lastErr("detect", path)
|
||||
}
|
||||
defer CppFreeString(ptr)
|
||||
|
||||
var doc facesJSON
|
||||
if err := json.Unmarshal([]byte(goStringFromCPtr(ptr)), &doc); err != nil {
|
||||
return nil, fmt.Errorf("face-detect: detect JSON: %w", err)
|
||||
}
|
||||
return doc.Faces, nil
|
||||
}
|
||||
|
||||
// bestArea returns the FacialArea of the highest-scoring face in an image, or an
|
||||
// empty area when detection fails or finds nothing. Best-effort: verify already
|
||||
// succeeded, so a missing region must not turn a valid match into an error.
|
||||
func (f *FaceDetect) bestArea(path string) *pb.FacialArea {
|
||||
faces, err := f.detectFaces(path)
|
||||
if err != nil || len(faces) == 0 {
|
||||
return &pb.FacialArea{}
|
||||
}
|
||||
best := faces[0]
|
||||
for _, fc := range faces[1:] {
|
||||
if fc.Score > best.Score {
|
||||
best = fc
|
||||
}
|
||||
}
|
||||
x, y, w, h := best.xywh()
|
||||
return &pb.FacialArea{X: x, Y: y, W: w, H: h}
|
||||
}
|
||||
|
||||
// parseAnalyzeJSON maps the engine's analyze document onto FaceAnalysis entries.
|
||||
// The engine reports gender as "M"/"F"; both the dominant label and the score
|
||||
// map are filled with the "Man"/"Woman" form the proto documents.
|
||||
func parseAnalyzeJSON(doc string) ([]*pb.FaceAnalysis, error) {
|
||||
var parsed facesJSON
|
||||
if err := json.Unmarshal([]byte(doc), &parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make([]*pb.FaceAnalysis, 0, len(parsed.Faces))
|
||||
for _, fc := range parsed.Faces {
|
||||
x, y, w, h := fc.xywh()
|
||||
fa := &pb.FaceAnalysis{
|
||||
Region: &pb.FacialArea{X: x, Y: y, W: w, H: h},
|
||||
FaceConfidence: fc.Score,
|
||||
Age: fc.Age,
|
||||
}
|
||||
if label := normalizeGender(fc.Gender); label != "" {
|
||||
fa.DominantGender = label
|
||||
fa.Gender = map[string]float32{label: 1.0}
|
||||
}
|
||||
out = append(out, fa)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// normalizeGender maps the engine's "M"/"F" code to the "Man"/"Woman" labels the
|
||||
// proto documents. Unknown codes pass through unchanged.
|
||||
func normalizeGender(g string) string {
|
||||
switch strings.ToUpper(strings.TrimSpace(g)) {
|
||||
case "M":
|
||||
return "Man"
|
||||
case "F":
|
||||
return "Woman"
|
||||
case "":
|
||||
return ""
|
||||
default:
|
||||
return g
|
||||
}
|
||||
}
|
||||
|
||||
// materializeImage decodes a base64 image payload into a temp file and returns
|
||||
// its path plus a cleanup func. As a convenience for callers that already pass a
|
||||
// filesystem path (e.g. a test fixture), an existing path is used as-is with a
|
||||
// no-op cleanup. data: URI prefixes are stripped before decoding.
|
||||
func materializeImage(src string) (path string, cleanup func(), err error) {
|
||||
noop := func() {}
|
||||
if src == "" {
|
||||
return "", noop, errors.New("face-detect: empty image input")
|
||||
}
|
||||
if _, statErr := os.Stat(src); statErr == nil {
|
||||
return src, noop, nil
|
||||
}
|
||||
|
||||
payload := src
|
||||
if i := strings.Index(payload, ","); strings.HasPrefix(payload, "data:") && i >= 0 {
|
||||
payload = payload[i+1:]
|
||||
}
|
||||
data, decErr := base64.StdEncoding.DecodeString(strings.TrimSpace(payload))
|
||||
if decErr != nil || len(data) == 0 {
|
||||
return "", noop, errors.New("face-detect: image is neither an existing path nor valid base64")
|
||||
}
|
||||
|
||||
tmp, createErr := os.CreateTemp("", "face-detect-*.img")
|
||||
if createErr != nil {
|
||||
return "", noop, fmt.Errorf("face-detect: create temp image: %w", createErr)
|
||||
}
|
||||
cleanup = func() { _ = os.Remove(tmp.Name()) }
|
||||
if _, wErr := tmp.Write(data); wErr != nil {
|
||||
_ = tmp.Close()
|
||||
cleanup()
|
||||
return "", noop, fmt.Errorf("face-detect: write temp image: %w", wErr)
|
||||
}
|
||||
if cErr := tmp.Close(); cErr != nil {
|
||||
cleanup()
|
||||
return "", noop, fmt.Errorf("face-detect: close temp image: %w", cErr)
|
||||
}
|
||||
return tmp.Name(), cleanup, nil
|
||||
}
|
||||
|
||||
// lastErr wraps the C-API's per-ctx last-error buffer into a Go error.
|
||||
func (f *FaceDetect) lastErr(op, subject string) error {
|
||||
msg := strings.TrimSpace(CppLastError(f.ctxPtr))
|
||||
if msg == "" {
|
||||
msg = "no error detail"
|
||||
}
|
||||
return fmt.Errorf("face-detect: %s failed for %q: %s", op, subject, msg)
|
||||
}
|
||||
|
||||
// goStringFromCPtr copies a NUL-terminated C string into Go memory. cptr is a
|
||||
// malloc'd buffer the caller owns; release it via CppFreeString after the copy.
|
||||
//
|
||||
// The uintptr->Pointer conversion trips vet's unsafeptr check, which can't tell
|
||||
// a C heap pointer from Go-managed memory. Safe here: the GC neither tracks nor
|
||||
// moves the buffer and we dereference it immediately to copy the bytes out.
|
||||
func goStringFromCPtr(cptr uintptr) string {
|
||||
if cptr == 0 {
|
||||
return ""
|
||||
}
|
||||
p := unsafe.Pointer(cptr) //nolint:govet // C-owned malloc'd buffer, not Go-GC memory (see doc above)
|
||||
n := 0
|
||||
for *(*byte)(unsafe.Add(p, n)) != 0 {
|
||||
n++
|
||||
}
|
||||
return string(unsafe.Slice((*byte)(p), n))
|
||||
}
|
||||
230
backend/go/face-detect/gofacedetect_test.go
Normal file
230
backend/go/face-detect/gofacedetect_test.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestFaceDetect(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "face-detect Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the C-API
|
||||
// bridge without spinning up the gRPC server. Records the error (the smoke
|
||||
// specs skip themselves) when libfacedetect.so is not loadable from cwd
|
||||
// (LD_LIBRARY_PATH or a symlink in ./).
|
||||
func ensureLibLoaded() error {
|
||||
libLoadOnce.Do(func() {
|
||||
libName := os.Getenv("FACEDETECT_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libfacedetect.so"
|
||||
}
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&CppAbiVersion, lib, "facedetect_capi_abi_version")
|
||||
purego.RegisterLibFunc(&CppLoad, lib, "facedetect_capi_load")
|
||||
purego.RegisterLibFunc(&CppFree, lib, "facedetect_capi_free")
|
||||
purego.RegisterLibFunc(&CppLastError, lib, "facedetect_capi_last_error")
|
||||
purego.RegisterLibFunc(&CppFreeString, lib, "facedetect_capi_free_string")
|
||||
purego.RegisterLibFunc(&CppFreeVec, lib, "facedetect_capi_free_vec")
|
||||
purego.RegisterLibFunc(&CppEmbedPath, lib, "facedetect_capi_embed_path")
|
||||
purego.RegisterLibFunc(&CppEmbedRGB, lib, "facedetect_capi_embed_rgb")
|
||||
purego.RegisterLibFunc(&CppDetectJSON, lib, "facedetect_capi_detect_path_json")
|
||||
purego.RegisterLibFunc(&CppVerifyPaths, lib, "facedetect_capi_verify_paths")
|
||||
purego.RegisterLibFunc(&CppAnalyzeJSON, lib, "facedetect_capi_analyze_path_json")
|
||||
})
|
||||
return libLoadErr
|
||||
}
|
||||
|
||||
var _ = Describe("parseOptions", func() {
|
||||
It("defaults verify_threshold to 0.35", func() {
|
||||
o := parseOptions(nil)
|
||||
Expect(o.verifyThreshold).To(Equal(float32(0.35)))
|
||||
Expect(o.modelName).To(Equal(""))
|
||||
})
|
||||
|
||||
It("parses verify_threshold, threshold alias and model_name", func() {
|
||||
o := parseOptions([]string{"verify_threshold:0.4", "model_name:buffalo_l", "unknown:x"})
|
||||
Expect(o.verifyThreshold).To(Equal(float32(0.4)))
|
||||
Expect(o.modelName).To(Equal("buffalo_l"))
|
||||
|
||||
o2 := parseOptions([]string{"threshold:0.3"})
|
||||
Expect(o2.verifyThreshold).To(Equal(float32(0.3)))
|
||||
})
|
||||
|
||||
It("ignores non-positive thresholds and keeps the default", func() {
|
||||
o := parseOptions([]string{"verify_threshold:0", "threshold:-1"})
|
||||
Expect(o.verifyThreshold).To(Equal(float32(0.35)))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("normalizeGender", func() {
|
||||
It("maps M/F codes to Man/Woman", func() {
|
||||
Expect(normalizeGender("M")).To(Equal("Man"))
|
||||
Expect(normalizeGender("f")).To(Equal("Woman"))
|
||||
Expect(normalizeGender(" m ")).To(Equal("Man"))
|
||||
})
|
||||
|
||||
It("passes empty and unknown codes through", func() {
|
||||
Expect(normalizeGender("")).To(Equal(""))
|
||||
Expect(normalizeGender("nonbinary")).To(Equal("nonbinary"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("faceBox.xywh", func() {
|
||||
It("converts an [x1,y1,x2,y2] box to x/y/width/height", func() {
|
||||
b := faceBox{Box: []float32{10, 20, 50, 80}}
|
||||
x, y, w, h := b.xywh()
|
||||
Expect(x).To(Equal(float32(10)))
|
||||
Expect(y).To(Equal(float32(20)))
|
||||
Expect(w).To(Equal(float32(40)))
|
||||
Expect(h).To(Equal(float32(60)))
|
||||
})
|
||||
|
||||
It("returns zeros for a short box", func() {
|
||||
x, y, w, h := faceBox{Box: []float32{1, 2}}.xywh()
|
||||
Expect([]float32{x, y, w, h}).To(Equal([]float32{0, 0, 0, 0}))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("parseAnalyzeJSON", func() {
|
||||
It("maps region, age and gender for each face", func() {
|
||||
doc := `{"faces":[
|
||||
{"score":0.997,"box":[10,20,50,80],"age":31,"gender":"M"},
|
||||
{"score":0.81,"box":[0,0,40,40],"age":24,"gender":"F"}]}`
|
||||
faces, err := parseAnalyzeJSON(doc)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(faces).To(HaveLen(2))
|
||||
|
||||
Expect(faces[0].FaceConfidence).To(BeNumerically("~", 0.997, 1e-4))
|
||||
Expect(faces[0].Age).To(BeNumerically("~", 31, 1e-4))
|
||||
Expect(faces[0].DominantGender).To(Equal("Man"))
|
||||
Expect(faces[0].Gender).To(HaveKeyWithValue("Man", float32(1.0)))
|
||||
Expect(faces[0].Region.W).To(Equal(float32(40)))
|
||||
Expect(faces[0].Region.H).To(Equal(float32(60)))
|
||||
|
||||
Expect(faces[1].DominantGender).To(Equal("Woman"))
|
||||
})
|
||||
|
||||
It("tolerates a missing gender field", func() {
|
||||
faces, err := parseAnalyzeJSON(`{"faces":[{"score":0.5,"box":[0,0,10,10],"age":40}]}`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(faces).To(HaveLen(1))
|
||||
Expect(faces[0].DominantGender).To(Equal(""))
|
||||
Expect(faces[0].Gender).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns no faces for an empty document", func() {
|
||||
faces, err := parseAnalyzeJSON(`{"faces":[]}`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(faces).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns an error on malformed JSON", func() {
|
||||
_, err := parseAnalyzeJSON(`{not-json`)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("materializeImage", func() {
|
||||
It("decodes a base64 payload to a temp file", func() {
|
||||
payload := base64.StdEncoding.EncodeToString([]byte("\xff\xd8\xff\xe0fake-jpeg"))
|
||||
path, cleanup, err := materializeImage(payload)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer cleanup()
|
||||
data, rerr := os.ReadFile(path)
|
||||
Expect(rerr).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal([]byte("\xff\xd8\xff\xe0fake-jpeg")))
|
||||
})
|
||||
|
||||
It("strips a data: URI prefix before decoding", func() {
|
||||
payload := "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte("hello"))
|
||||
path, cleanup, err := materializeImage(payload)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer cleanup()
|
||||
data, rerr := os.ReadFile(path)
|
||||
Expect(rerr).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal([]byte("hello")))
|
||||
})
|
||||
|
||||
It("uses an existing path as-is", func() {
|
||||
tmp, err := os.CreateTemp("", "face-detect-fixture-*.bin")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer func() { _ = os.Remove(tmp.Name()) }()
|
||||
Expect(tmp.Close()).To(Succeed())
|
||||
|
||||
path, cleanup, err := materializeImage(tmp.Name())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer cleanup()
|
||||
Expect(path).To(Equal(tmp.Name()))
|
||||
})
|
||||
|
||||
It("errors on input that is neither a path nor base64", func() {
|
||||
_, _, err := materializeImage("not base64!!!")
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
// The specs below exercise the real C-API end to end. They run only when both a
|
||||
// model GGUF and a test image are provided, and skip cleanly otherwise so the
|
||||
// suite stays green without large assets.
|
||||
var _ = Describe("FaceDetect end-to-end", Ordered, func() {
|
||||
var (
|
||||
f *FaceDetect
|
||||
modelPath = os.Getenv("FACEDETECT_BACKEND_TEST_MODEL")
|
||||
imagePath = os.Getenv("FACEDETECT_BACKEND_TEST_IMAGE")
|
||||
)
|
||||
|
||||
BeforeAll(func() {
|
||||
if modelPath == "" || imagePath == "" {
|
||||
Skip("set FACEDETECT_BACKEND_TEST_MODEL and FACEDETECT_BACKEND_TEST_IMAGE to run the e2e specs")
|
||||
}
|
||||
if err := ensureLibLoaded(); err != nil {
|
||||
Skip("libfacedetect.so not loadable: " + err.Error())
|
||||
}
|
||||
f = &FaceDetect{}
|
||||
Expect(f.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
})
|
||||
|
||||
It("embeds the primary face in an image", func() {
|
||||
emb, err := f.Embeddings(&pb.PredictOptions{Images: []string{imagePath}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(emb).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("detects at least one face", func() {
|
||||
resp, err := f.Detect(&pb.DetectOptions{Src: imagePath})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Detections).ToNot(BeEmpty())
|
||||
Expect(resp.Detections[0].ClassName).To(Equal("face"))
|
||||
})
|
||||
|
||||
It("verifies an image against itself as the same identity", func() {
|
||||
resp, err := f.FaceVerify(&pb.FaceVerifyRequest{Img1: imagePath, Img2: imagePath})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Verified).To(BeTrue())
|
||||
Expect(resp.Distance).To(BeNumerically("<=", resp.Threshold))
|
||||
})
|
||||
|
||||
It("analyzes age/gender for each face", func() {
|
||||
resp, err := f.FaceAnalyze(&pb.FaceAnalyzeRequest{Img: imagePath})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Faces).ToNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
65
backend/go/face-detect/main.go
Normal file
65
backend/go/face-detect/main.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package main
|
||||
|
||||
// Started internally by LocalAI - one gRPC server per loaded model.
|
||||
//
|
||||
// Loads libfacedetect.so via purego and registers the flat C-API entry points
|
||||
// declared in facedetect_capi.h. The library name can be overridden with
|
||||
// FACEDETECT_LIBRARY (mirrors the VOICEDETECT_LIBRARY / PARAKEET_LIBRARY
|
||||
// convention in the sibling backends); the default looks for the .so next to
|
||||
// this binary (resolved via LD_LIBRARY_PATH by run.sh).
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("FACEDETECT_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libfacedetect.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("face-detect: dlopen %q: %w", libName, err))
|
||||
}
|
||||
|
||||
// Bound 1:1 to facedetect_capi.h. char*/float* returns are registered as
|
||||
// uintptr so the raw pointer can be freed via the matching capi free fn.
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppAbiVersion, "facedetect_capi_abi_version"},
|
||||
{&CppLoad, "facedetect_capi_load"},
|
||||
{&CppFree, "facedetect_capi_free"},
|
||||
{&CppLastError, "facedetect_capi_last_error"},
|
||||
{&CppFreeString, "facedetect_capi_free_string"},
|
||||
{&CppFreeVec, "facedetect_capi_free_vec"},
|
||||
{&CppEmbedPath, "facedetect_capi_embed_path"},
|
||||
{&CppEmbedRGB, "facedetect_capi_embed_rgb"},
|
||||
{&CppDetectJSON, "facedetect_capi_detect_path_json"},
|
||||
{&CppVerifyPaths, "facedetect_capi_verify_paths"},
|
||||
{&CppAnalyzeJSON, "facedetect_capi_analyze_path_json"},
|
||||
}
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[face-detect] ABI=%d\n", CppAbiVersion())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &FaceDetect{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
47
backend/go/face-detect/options.go
Normal file
47
backend/go/face-detect/options.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// defaultVerifyThreshold is the cosine-distance cutoff used when a request does
|
||||
// not set one. Matches the insightface buffalo_l ArcFace R50 default the Python
|
||||
// face backend ships with so the two implementations agree on verdicts out of
|
||||
// the box.
|
||||
const defaultVerifyThreshold float32 = 0.35
|
||||
|
||||
// loadOptions holds the parsed model-level options for face-detect.
|
||||
type loadOptions struct {
|
||||
verifyThreshold float32
|
||||
modelName string
|
||||
}
|
||||
|
||||
func splitOption(o string) (key, value string, ok bool) {
|
||||
i := strings.Index(o, ":")
|
||||
if i < 0 {
|
||||
return "", "", false
|
||||
}
|
||||
return strings.TrimSpace(o[:i]), strings.TrimSpace(o[i+1:]), true
|
||||
}
|
||||
|
||||
// parseOptions reads the backend "key:value" option slice. Unknown keys are
|
||||
// ignored. Defaults: verify_threshold 0.35, model_name derived from the file.
|
||||
func parseOptions(opts []string) loadOptions {
|
||||
o := loadOptions{verifyThreshold: defaultVerifyThreshold}
|
||||
for _, oo := range opts {
|
||||
key, value, ok := splitOption(oo)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch key {
|
||||
case "verify_threshold", "threshold":
|
||||
if f, err := strconv.ParseFloat(value, 32); err == nil && f > 0 {
|
||||
o.verifyThreshold = float32(f)
|
||||
}
|
||||
case "model_name":
|
||||
o.modelName = value
|
||||
}
|
||||
}
|
||||
return o
|
||||
}
|
||||
68
backend/go/face-detect/package.sh
Normal file
68
backend/go/face-detect/package.sh
Normal file
@@ -0,0 +1,68 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Bundle the face-detect-grpc binary, libfacedetect.so, the core runtime libs
|
||||
# (libc/libstdc++/libgomp + ld.so) and the GPU runtime for the active BUILD_TYPE
|
||||
# so the package is self-contained. Mirrors backend/go/voice-detect/package.sh;
|
||||
# run.sh routes the (CGO_ENABLED=0) binary through lib/ld.so so the packaged libc
|
||||
# is used instead of the host's.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
|
||||
cp -avf "$CURDIR/face-detect-grpc" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
# libfacedetect.so + any soname symlinks. purego.Dlopen resolves it via
|
||||
# LD_LIBRARY_PATH, which run.sh points at lib/.
|
||||
cp -avf "$CURDIR"/libfacedetect.so* "$CURDIR/package/lib/" 2>/dev/null || {
|
||||
echo "ERROR: libfacedetect.so not found in $CURDIR, run 'make' first" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Detect architecture and copy the core runtime libs libfacedetect.so links
|
||||
# against, plus the matching dynamic loader as lib/ld.so.
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 "$CURDIR/package/lib/ld.so"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 "$CURDIR/package/lib/libc.so.6"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 "$CURDIR/package/lib/libgcc_s.so.1"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 "$CURDIR/package/lib/libstdc++.so.6"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 "$CURDIR/package/lib/libm.so.6"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 "$CURDIR/package/lib/libgomp.so.1"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 "$CURDIR/package/lib/libdl.so.2"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 "$CURDIR/package/lib/librt.so.1"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 "$CURDIR/package/lib/libpthread.so.0"
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 "$CURDIR/package/lib/ld.so"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 "$CURDIR/package/lib/libc.so.6"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 "$CURDIR/package/lib/libgcc_s.so.1"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 "$CURDIR/package/lib/libstdc++.so.6"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 "$CURDIR/package/lib/libm.so.6"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 "$CURDIR/package/lib/libgomp.so.1"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 "$CURDIR/package/lib/libdl.so.2"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 "$CURDIR/package/lib/librt.so.1"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 "$CURDIR/package/lib/libpthread.so.0"
|
||||
elif [ "$(uname -s)" = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries (CUDA/ROCm/Intel/Vulkan loader + ICDs + drivers) based on
|
||||
# BUILD_TYPE so the backend can reach the GPU without the runtime base image
|
||||
# shipping those drivers.
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"
|
||||
16
backend/go/face-detect/run.sh
Normal file
16
backend/go/face-detect/run.sh
Normal file
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
|
||||
|
||||
# If a self-contained ld.so was packaged, route through it so the packaged
|
||||
# libc / libstdc++ are used instead of the host's (matches the voice-detect /
|
||||
# whisper / parakeet backends' runtime layout).
|
||||
if [ -f "$CURDIR/lib/ld.so" ]; then
|
||||
echo "Using lib/ld.so"
|
||||
exec "$CURDIR/lib/ld.so" "$CURDIR/face-detect-grpc" "$@"
|
||||
fi
|
||||
|
||||
exec "$CURDIR/face-detect-grpc" "$@"
|
||||
15
backend/go/face-detect/test.sh
Normal file
15
backend/go/face-detect/test.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
cd "$CURDIR"
|
||||
|
||||
echo "Running face-detect backend tests..."
|
||||
|
||||
# The pure-Go parsing specs always run. The embed/detect/verify/analyze smoke
|
||||
# specs run only when a model + image are provided via
|
||||
# FACEDETECT_BACKEND_TEST_MODEL and FACEDETECT_BACKEND_TEST_IMAGE; otherwise they
|
||||
# auto-skip.
|
||||
LD_LIBRARY_PATH="$CURDIR:${LD_LIBRARY_PATH:-}" go test -v -timeout 1200s .
|
||||
|
||||
echo "face-detect tests completed."
|
||||
81
backend/go/parakeet-cpp/boundary.go
Normal file
81
backend/go/parakeet-cpp/boundary.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package main
|
||||
|
||||
// utteranceBoundary is the single definition of a small state machine that was
|
||||
// previously open-coded three times — as a bare `finalEou` bool with an ad-hoc
|
||||
// toggle — in the live feed (live.go), the file-stream text path, and the
|
||||
// file-stream JSON path (goparakeetcpp.go).
|
||||
//
|
||||
// It answers one running question: does the decode currently rest on an
|
||||
// end-of-utterance boundary? That is the value a closing FinalResult reports as
|
||||
// .Eou and the realtime turn detector treats as a commit point.
|
||||
//
|
||||
// parakeet auto-resets its decoder after every <EOU>/<EOB>, so one streaming
|
||||
// session is a sequence of utterances and this is a LATCH, not a monotonic
|
||||
// flag: it closes on an <EOU> and reopens as soon as the next utterance starts.
|
||||
// (Contrast the realtime API's per-turn `eouSeen`, which only ever goes
|
||||
// false->true because each turn gets a fresh stream. Here the stream outlives
|
||||
// the turn, so the boundary status must be able to reopen.)
|
||||
//
|
||||
// The only transitions, over the events one streamFeedResult carries — an
|
||||
// <EOU>, an <EOB> (backchannel), or plain speech output (text and/or words):
|
||||
//
|
||||
// <EOU>
|
||||
// open ───────────► closed
|
||||
// ▲ ▲ │ │ │
|
||||
// │ └─┘ <EOB>|speech │ │ <EOU>
|
||||
// │ (stay open) │ └─┘ (stay closed)
|
||||
// └──────────────────┘
|
||||
// <EOB>|speech
|
||||
//
|
||||
// open = NOT on an utterance boundary: mid-utterance, the last boundary was
|
||||
// a backchannel <EOB>, or the stream just began (the initial state).
|
||||
// closed = the last meaningful event was an <EOU> with no later speech: a real
|
||||
// turn boundary.
|
||||
//
|
||||
// A feed that carries nothing (no eou/eob/text/words — e.g. a finalize flush
|
||||
// that produced no tail) is a no-op and leaves the state unchanged, matching
|
||||
// the legacy "leave finalEou as it was" behaviour.
|
||||
//
|
||||
// The state carries no data, so it is modelled as a two-valued type (a named
|
||||
// bool) rather than an int enum: every inhabitant is legal, so illegal states
|
||||
// are unrepresentable — the payload-free analog of the sealed sum types the
|
||||
// realtime machines use (those need interfaces because their states carry data,
|
||||
// e.g. Active{ID}, where "Active with no ID" is the illegal combination a scalar
|
||||
// cannot even express).
|
||||
type utteranceBoundary bool
|
||||
|
||||
const (
|
||||
// boundaryOpen is the zero value (false), so a fresh decode starts open —
|
||||
// exactly the legacy `var finalEou bool` (false) initial condition.
|
||||
boundaryOpen utteranceBoundary = false
|
||||
boundaryClosed utteranceBoundary = true
|
||||
)
|
||||
|
||||
// observe folds one decode increment into the latch and returns the new state.
|
||||
//
|
||||
// <EOU> takes priority when a single feed carries both an <EOU> and speech
|
||||
// (e.g. {"text":"hello","eou":1}): the utterance both produced that text AND
|
||||
// ended, so the decode rests on the boundary. This matches the legacy
|
||||
// eou-checked-first ordering at every call site.
|
||||
func (b utteranceBoundary) observe(r streamFeedResult) utteranceBoundary {
|
||||
switch {
|
||||
case r.Eou:
|
||||
return boundaryClosed
|
||||
case r.Eob || r.Delta != "" || len(r.Words) > 0:
|
||||
return boundaryOpen
|
||||
default:
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
// ended reports whether the decode currently rests on an end-of-utterance
|
||||
// boundary (a real <EOU>, not a backchannel <EOB>). This is what a closing
|
||||
// FinalResult carries as .Eou.
|
||||
func (b utteranceBoundary) ended() bool { return b == boundaryClosed }
|
||||
|
||||
func (b utteranceBoundary) String() string {
|
||||
if b == boundaryClosed {
|
||||
return "closed"
|
||||
}
|
||||
return "open"
|
||||
}
|
||||
92
backend/go/parakeet-cpp/boundary_test.go
Normal file
92
backend/go/parakeet-cpp/boundary_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("utteranceBoundary (decode end-of-utterance latch)", func() {
|
||||
It("starts open: a fresh decode is not on a boundary", func() {
|
||||
var b utteranceBoundary
|
||||
Expect(b).To(Equal(boundaryOpen))
|
||||
Expect(b.ended()).To(BeFalse())
|
||||
})
|
||||
|
||||
DescribeTable("single feed transition from the open state",
|
||||
func(r streamFeedResult, wantEnded bool) {
|
||||
Expect(boundaryOpen.observe(r).ended()).To(Equal(wantEnded))
|
||||
},
|
||||
Entry("<EOU> closes it", streamFeedResult{Eou: true}, true),
|
||||
Entry("<EOU> with text closes it (eou wins)", streamFeedResult{Delta: "hi", Eou: true}, true),
|
||||
Entry("<EOB> stays open (backchannel is not a turn boundary)", streamFeedResult{Eob: true}, false),
|
||||
Entry("plain text stays open", streamFeedResult{Delta: "hello"}, false),
|
||||
Entry("words-only stays open", streamFeedResult{Words: []transcriptWord{{W: "x"}}}, false),
|
||||
Entry("empty feed is a no-op (stays open)", streamFeedResult{}, false),
|
||||
)
|
||||
|
||||
DescribeTable("single feed transition from the closed state",
|
||||
func(r streamFeedResult, wantEnded bool) {
|
||||
Expect(boundaryClosed.observe(r).ended()).To(Equal(wantEnded))
|
||||
},
|
||||
Entry("another <EOU> stays closed", streamFeedResult{Eou: true}, true),
|
||||
Entry("trailing speech reopens it", streamFeedResult{Delta: "and more"}, false),
|
||||
Entry("words reopen it", streamFeedResult{Words: []transcriptWord{{W: "x"}}}, false),
|
||||
Entry("a backchannel <EOB> reopens it", streamFeedResult{Eob: true}, false),
|
||||
Entry("empty feed is a no-op (stays closed)", streamFeedResult{}, true),
|
||||
)
|
||||
|
||||
It("is a latch: <EOU> then trailing speech reopens, then <EOU> closes again", func() {
|
||||
b := boundaryOpen
|
||||
b = b.observe(streamFeedResult{Delta: "turn one", Eou: true})
|
||||
Expect(b.ended()).To(BeTrue())
|
||||
b = b.observe(streamFeedResult{Delta: " and more"})
|
||||
Expect(b.ended()).To(BeFalse(), "trailing speech without an EOU is an open utterance")
|
||||
b = b.observe(streamFeedResult{Eou: true})
|
||||
Expect(b.ended()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("treats a backchannel before a real EOU correctly", func() {
|
||||
b := boundaryOpen
|
||||
b = b.observe(streamFeedResult{Delta: "uh huh", Eob: true})
|
||||
Expect(b.ended()).To(BeFalse(), "a backchannel must not masquerade as a turn boundary")
|
||||
b = b.observe(streamFeedResult{Delta: "done", Eou: true})
|
||||
Expect(b.ended()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches the reference fold over seeded random feed sequences", func() {
|
||||
// The invariant: after any sequence of feeds, ended() is true iff the
|
||||
// last feed that carried ANY event was an <EOU>. <EOU> takes priority
|
||||
// when a feed carries both an EOU and speech; empty feeds are ignored.
|
||||
for seed := uint64(1); seed <= 200; seed++ {
|
||||
rng := rand.New(rand.NewPCG(seed, seed*2654435761))
|
||||
b := boundaryOpen
|
||||
lastWasEou := false // reference: did the last meaningful feed end on EOU?
|
||||
steps := rng.IntN(30)
|
||||
for i := 0; i < steps; i++ {
|
||||
var r streamFeedResult
|
||||
switch rng.IntN(5) {
|
||||
case 0:
|
||||
r = streamFeedResult{Eou: true}
|
||||
case 1:
|
||||
r = streamFeedResult{Eob: true}
|
||||
case 2:
|
||||
r = streamFeedResult{Delta: "w"}
|
||||
case 3:
|
||||
r = streamFeedResult{Delta: "w", Eou: true} // eou + speech, eou wins
|
||||
case 4:
|
||||
r = streamFeedResult{} // empty: no-op
|
||||
}
|
||||
b = b.observe(r)
|
||||
if r.Eou {
|
||||
lastWasEou = true
|
||||
} else if r.Eob || r.Delta != "" || len(r.Words) > 0 {
|
||||
lastWasEou = false
|
||||
}
|
||||
}
|
||||
Expect(b.ended()).To(Equal(lastWasEou),
|
||||
"seed %d: latch disagreed with the reference fold", seed)
|
||||
}
|
||||
})
|
||||
})
|
||||
82
backend/go/parakeet-cpp/driver.go
Normal file
82
backend/go/parakeet-cpp/driver.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// streamFeedResult is one decode increment from a cache-aware streaming session:
|
||||
// the newly-finalized text plus the model's own per-feed boundary tokens
|
||||
// (<EOU>/<EOB>) and word timings. It is the single event type both the live
|
||||
// (bidi) and file (server-stream) paths fold over, hiding the ABI v4 JSON vs
|
||||
// older text-only entry-point split behind one shape.
|
||||
type streamFeedResult struct {
|
||||
Delta string
|
||||
Eou bool
|
||||
Eob bool
|
||||
Words []transcriptWord
|
||||
}
|
||||
|
||||
// feedChunk feeds one PCM chunk to the streaming session (or finalizes it, when
|
||||
// finalize is true) and returns the unified decode increment. It prefers the
|
||||
// ABI v4 JSON entry points (which also carry per-word timestamps) and falls
|
||||
// back to the older text-only entry points against an older libparakeet.so.
|
||||
//
|
||||
// This is the one place the JSON-vs-text choice is made; every consumer works
|
||||
// in terms of streamFeedResult.
|
||||
func (p *ParakeetCpp) feedChunk(stream uintptr, pcm []float32, finalize bool) (streamFeedResult, error) {
|
||||
if CppStreamFeedJSON != nil {
|
||||
doc, err := p.streamFeedDoc(stream, pcm, finalize)
|
||||
if err != nil {
|
||||
return streamFeedResult{}, err
|
||||
}
|
||||
return streamFeedResult{Delta: doc.Text, Eou: doc.Eou != 0, Eob: doc.Eob != 0, Words: doc.Words}, nil
|
||||
}
|
||||
delta, eou, eob, err := p.streamFeedText(stream, pcm, finalize)
|
||||
if err != nil {
|
||||
return streamFeedResult{}, err
|
||||
}
|
||||
return streamFeedResult{Delta: delta, Eou: eou, Eob: eob}, nil
|
||||
}
|
||||
|
||||
// feedSlices feeds pcm through the session in streamChunkSamples slices,
|
||||
// invoking onFeed for each decode increment. It does NOT finalize: callers
|
||||
// decide when the send side is done. The file path finalizes after the whole
|
||||
// file; the live path finalizes only when its request channel closes, never
|
||||
// between audio messages. Slicing keeps each per-call engineMu hold short so
|
||||
// concurrent unary transcription interleaves fairly (the C session buffers
|
||||
// internally).
|
||||
//
|
||||
// If ctx is non-nil it is checked before each slice so a cancelled file
|
||||
// transcription stops promptly; the live path passes nil (it is bounded by its
|
||||
// request channel instead of a ctx).
|
||||
func (p *ParakeetCpp) feedSlices(ctx context.Context, stream uintptr, pcm []float32, onFeed func(streamFeedResult) error) error {
|
||||
for off := 0; off < len(pcm); off += streamChunkSamples {
|
||||
if ctx != nil {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(pcm))
|
||||
res, err := p.feedChunk(stream, pcm[off:end], false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := onFeed(res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// flushTail finalizes the session once and folds the flushed tail (the last
|
||||
// ~2 encoder frames of text, which only appear on finalize) through onFeed.
|
||||
func (p *ParakeetCpp) flushTail(stream uintptr, onFeed func(streamFeedResult) error) error {
|
||||
res, err := p.feedChunk(stream, nil, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return onFeed(res)
|
||||
}
|
||||
@@ -103,12 +103,13 @@ type transcriptJSON struct {
|
||||
// {"text":"...","eou":0,"eob":0,"frame_sec":0.080000,
|
||||
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...]}
|
||||
//
|
||||
// "text" is the newly-finalized text since the last call; "eou" is 1 when an
|
||||
// <EOU> (end of utterance) fired this feed and "eob" is 1 when an <EOB>
|
||||
// (backchannel) fired. ABI v4 conflated the two into "eou"; v5 split them, so
|
||||
// we read both and treat either as an utterance boundary for segmentation.
|
||||
// "words" are the words finalized this call with absolute (stream-relative)
|
||||
// start/end seconds.
|
||||
// "text" is the newly-finalized text since the last call. Under ABI v5 "eou"
|
||||
// is 1 iff an <EOU> fired this feed (the user yielded the turn) and "eob" 1
|
||||
// iff an <EOB> fired (a backchannel like "uh-huh" ended — NOT a turn
|
||||
// boundary). A v4 library has no "eob" field and its "eou" conflates both
|
||||
// tokens: Eob stays 0 and Eou keeps the old any-event meaning. "words" are
|
||||
// the words finalized this call with absolute (stream-relative) start/end
|
||||
// seconds.
|
||||
type streamFeedJSON struct {
|
||||
Text string `json:"text"`
|
||||
Eou int `json:"eou"`
|
||||
@@ -364,7 +365,7 @@ var segmentSeparators = []rune{'.', '?', '!'}
|
||||
// the caller requested word granularity; token ids populate each segment's
|
||||
// Tokens by time-window membership. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gapFrames int) pb.TranscriptResult {
|
||||
text := strings.TrimSpace(doc.Text)
|
||||
text, eou := stripEouMarker(strings.TrimSpace(doc.Text))
|
||||
|
||||
// Frame-unit gap threshold -> seconds (NeMo segment_gap_threshold). 0 = off.
|
||||
gapSeconds := 0.0
|
||||
@@ -383,6 +384,7 @@ func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gap
|
||||
return pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: []*pb.TranscriptSegment{{Id: 0, Text: text}},
|
||||
Eou: eou,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -409,7 +411,25 @@ func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gap
|
||||
}
|
||||
segments = append(segments, seg)
|
||||
}
|
||||
return pb.TranscriptResult{Text: text, Segments: segments}
|
||||
return pb.TranscriptResult{Text: text, Segments: segments, Eou: eou}
|
||||
}
|
||||
|
||||
// stripEouMarker removes a trailing literal <EOU>/<EOB> from offline-decode
|
||||
// text and reports whether the decode ended on an end-of-UTTERANCE token. The
|
||||
// realtime EOU model's offline decode keeps the special token in the
|
||||
// detokenized text (the streaming path strips it and surfaces it as flags
|
||||
// instead); user-visible transcripts must never carry either marker, but only
|
||||
// <EOU> may confirm the semantic_vad retranscribe cross-check — a decode
|
||||
// ending on <EOB> means the last thing heard was a backchannel, not the user
|
||||
// yielding the turn.
|
||||
func stripEouMarker(text string) (string, bool) {
|
||||
if strings.HasSuffix(text, "<EOU>") {
|
||||
return strings.TrimSpace(strings.TrimSuffix(text, "<EOU>")), true
|
||||
}
|
||||
if strings.HasSuffix(text, "<EOB>") {
|
||||
return strings.TrimSpace(strings.TrimSuffix(text, "<EOB>")), false
|
||||
}
|
||||
return text, false
|
||||
}
|
||||
|
||||
// splitWordsIntoSegments groups words into segments exactly as NeMo's
|
||||
@@ -476,41 +496,55 @@ func tokensInWindow(tokens []transcriptToken, start, end float64) []int32 {
|
||||
return ids
|
||||
}
|
||||
|
||||
// streamSegmenter accumulates streaming words into per-utterance segments. EOU
|
||||
// is the model's own utterance boundary; each closed segment takes its start/end
|
||||
// from its first/last accumulated word.
|
||||
// streamSegmenter accumulates streaming decode increments into per-utterance
|
||||
// segments. <EOU>/<EOB> are the model's own utterance boundaries; each closes a
|
||||
// segment. When the feed carries per-word timings (ABI v4 JSON), a closed
|
||||
// segment takes its start/end from its first/last word; against an older
|
||||
// text-only library (no words) it falls back to segmenting the delta text, so
|
||||
// the same assembler serves both paths.
|
||||
type streamSegmenter struct {
|
||||
segs []*pb.TranscriptSegment
|
||||
cur []transcriptWord
|
||||
nextID int32
|
||||
segs []*pb.TranscriptSegment
|
||||
cur []transcriptWord // words for the open segment (ABI v4 JSON path)
|
||||
curText []string // delta text for the open segment (text-only path)
|
||||
nextID int32
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) add(doc streamFeedJSON) {
|
||||
s.cur = append(s.cur, doc.Words...)
|
||||
// Close the segment on either turn signal: <EOU> (end of utterance) or
|
||||
// <EOB> (backchannel). ABI v4 reported both via "eou"; v5 split them, so we
|
||||
// OR them here to keep the v4 segmentation boundaries.
|
||||
if doc.Eou != 0 || doc.Eob != 0 {
|
||||
func (s *streamSegmenter) add(r streamFeedResult) {
|
||||
s.cur = append(s.cur, r.Words...)
|
||||
if len(r.Words) == 0 && r.Delta != "" {
|
||||
// Older libparakeet.so with no per-word timing: segment from the text.
|
||||
s.curText = append(s.curText, r.Delta)
|
||||
}
|
||||
// Both <EOU> and <EOB> reset the decoder, so both close a segment.
|
||||
if r.Eou || r.Eob {
|
||||
s.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) flush() {
|
||||
if len(s.cur) == 0 {
|
||||
return
|
||||
switch {
|
||||
case len(s.cur) > 0:
|
||||
parts := make([]string, len(s.cur))
|
||||
for i, w := range s.cur {
|
||||
parts[i] = w.W
|
||||
}
|
||||
s.segs = append(s.segs, &pb.TranscriptSegment{
|
||||
Id: s.nextID,
|
||||
Start: secondsToNanos(s.cur[0].Start),
|
||||
End: secondsToNanos(s.cur[len(s.cur)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
})
|
||||
s.nextID++
|
||||
case len(s.curText) > 0:
|
||||
// No words this segment: emit a text-only segment (no timestamps),
|
||||
// skipping a purely-whitespace one as the legacy text path did.
|
||||
if t := strings.TrimSpace(strings.Join(s.curText, "")); t != "" {
|
||||
s.segs = append(s.segs, &pb.TranscriptSegment{Id: s.nextID, Text: t})
|
||||
s.nextID++
|
||||
}
|
||||
}
|
||||
parts := make([]string, len(s.cur))
|
||||
for i, w := range s.cur {
|
||||
parts[i] = w.W
|
||||
}
|
||||
s.segs = append(s.segs, &pb.TranscriptSegment{
|
||||
Id: s.nextID,
|
||||
Start: secondsToNanos(s.cur[0].Start),
|
||||
End: secondsToNanos(s.cur[len(s.cur)-1].End),
|
||||
Text: strings.TrimSpace(strings.Join(parts, " ")),
|
||||
})
|
||||
s.nextID++
|
||||
s.cur = nil
|
||||
s.curText = nil
|
||||
}
|
||||
|
||||
func (s *streamSegmenter) segments() []*pb.TranscriptSegment { return s.segs }
|
||||
@@ -535,18 +569,119 @@ func secondsToNanos(sec float64) int64 {
|
||||
return int64(sec * 1e9)
|
||||
}
|
||||
|
||||
// Per-C-call engine serialization for the streaming paths.
|
||||
//
|
||||
// Every individual C call (begin / feed / finalize / free) takes engineMu and
|
||||
// re-checks ctxPtr under the lock; the lock is NEVER held across a stream's
|
||||
// lifetime. This is safe because each parakeet.cpp call builds its own ggml
|
||||
// graph and all streaming caches live in the session object, not the ctx —
|
||||
// the only ctx-shared mutable state is last_error, which is why it is read
|
||||
// under the same lock as the failing call. Holding the lock per call (rather
|
||||
// than per stream, as this file previously did) keeps a long-lived live
|
||||
// session from starving batched unary transcription and vice versa.
|
||||
//
|
||||
// A stream must not outlive its ctx (C-API contract). Free() takes engineMu
|
||||
// and zeroes ctxPtr, so a racing per-call helper returns ModelNotLoaded
|
||||
// instead of feeding a freed engine; streamFree of an orphaned session only
|
||||
// runs the session destructor, which does not touch the ctx.
|
||||
|
||||
// streamBegin opens a cache-aware streaming session. A 0 stream with nil
|
||||
// error means the loaded model is not a streaming model.
|
||||
func (p *ParakeetCpp) streamBegin(lang string) (uintptr, error) {
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
if p.ctxPtr == 0 {
|
||||
return 0, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if CppStreamBeginLang != nil {
|
||||
return CppStreamBeginLang(p.ctxPtr, lang), nil
|
||||
}
|
||||
return CppStreamBegin(p.ctxPtr), nil
|
||||
}
|
||||
|
||||
func (p *ParakeetCpp) streamFree(stream uintptr) {
|
||||
if stream == 0 {
|
||||
return
|
||||
}
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
CppStreamFree(stream)
|
||||
}
|
||||
|
||||
// streamFeedText runs one text-mode feed (or the finalize flush when
|
||||
// finalize is true) under engineMu, returning the newly-finalized delta and
|
||||
// whether an <EOU>/<EOB> fired during the call.
|
||||
func (p *ParakeetCpp) streamFeedText(stream uintptr, pcm []float32, finalize bool) (delta string, eou, eob bool, err error) {
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
if p.ctxPtr == 0 {
|
||||
return "", false, false, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
var ret uintptr
|
||||
var events int32
|
||||
if finalize {
|
||||
ret = CppStreamFinalize(stream)
|
||||
} else {
|
||||
ret = CppStreamFeed(stream, pcm, int32(len(pcm)), unsafe.Pointer(&events))
|
||||
}
|
||||
if ret == 0 {
|
||||
// last_error is ctx-shared: read it under the same lock as the call.
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return "", false, false, fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
delta = goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
// ABI v5: eou_out is a bitmask (bit 0 = <EOU>, bit 1 = <EOB>). A v4
|
||||
// library sets 0/1 for either token, which the bit-0 test reads as the
|
||||
// old conflated eou — the EOB distinction simply isn't available there.
|
||||
return delta, events&1 != 0, events&2 != 0, nil
|
||||
}
|
||||
|
||||
// streamFeedDoc runs one ABI v4 JSON feed (or finalize) under engineMu and
|
||||
// returns the parsed {text,eou,frame_sec,words} document.
|
||||
func (p *ParakeetCpp) streamFeedDoc(stream uintptr, pcm []float32, finalize bool) (streamFeedJSON, error) {
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
if p.ctxPtr == 0 {
|
||||
return streamFeedJSON{}, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
var ret uintptr
|
||||
if finalize {
|
||||
ret = CppStreamFinalizeJSON(stream)
|
||||
} else {
|
||||
ret = CppStreamFeedJSON(stream, pcm, int32(len(pcm)))
|
||||
}
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return streamFeedJSON{}, fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
raw := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
var doc streamFeedJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return streamFeedJSON{}, fmt.Errorf("parakeet-cpp: decode stream json: %w", err)
|
||||
}
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// AudioTranscriptionStream drives the cache-aware streaming RNN-T over the
|
||||
// audio at opts.Dst: it decodes the file to 16 kHz mono PCM, feeds it in
|
||||
// chunks to parakeet_capi_stream_feed, and emits each newly-finalized text
|
||||
// run as a TranscriptStreamResponse delta. <EOU>/<EOB> events close the
|
||||
// current segment; a closing FinalResult carries the full transcript and the
|
||||
// per-utterance segments.
|
||||
// audio at opts.Dst: it decodes the file to 16 kHz mono PCM, feeds it through
|
||||
// the shared decode driver (feedSlices/flushTail), and emits each
|
||||
// newly-finalized text run as a TranscriptStreamResponse delta. <EOU>/<EOB>
|
||||
// events close the current segment; a closing FinalResult carries the full
|
||||
// transcript, the per-utterance segments, and whether the file ended on an
|
||||
// utterance boundary.
|
||||
//
|
||||
// stream_begin returns 0 for models that are not cache-aware streaming models
|
||||
// (only e.g. nvidia/parakeet_realtime_eou_120m-v1 qualifies). For those we fall
|
||||
// back to a single offline transcription emitted as one delta plus a closing
|
||||
// FinalResult, matching LocalAI's non-streaming streaming contract (and the
|
||||
// whisper backend), so the streaming endpoint works for every model.
|
||||
// (only e.g. nvidia/parakeet_realtime_eou_120m-v1 qualifies). For those this
|
||||
// returns codes.Unimplemented rather than faking a stream from an offline
|
||||
// decode — see the stream==0 branch and grpcerrors.StreamTranscriptionUnsupported.
|
||||
func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
defer close(results)
|
||||
|
||||
@@ -560,185 +695,73 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
var stream uintptr
|
||||
if CppStreamBeginLang != nil {
|
||||
stream = CppStreamBeginLang(p.ctxPtr, opts.GetLanguage())
|
||||
} else {
|
||||
stream = CppStreamBegin(p.ctxPtr)
|
||||
stream, err := p.streamBegin(opts.GetLanguage())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stream == 0 {
|
||||
// Not a cache-aware streaming model: run a normal offline
|
||||
// transcription and emit it as one delta + a closing final result.
|
||||
res, err := p.AudioTranscription(ctx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t := strings.TrimSpace(res.Text); t != "" {
|
||||
results <- &pb.TranscriptStreamResponse{Delta: t}
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{FinalResult: &res}
|
||||
return nil
|
||||
// Not a cache-aware streaming model. Report the missing capability
|
||||
// honestly instead of decoding offline and emitting it as one "delta"
|
||||
// + final: a client that asked for streaming must learn the model
|
||||
// cannot stream, not receive a batch result dressed as a stream (which
|
||||
// is indistinguishable except qualitatively, and silently breaks any
|
||||
// feature that genuinely needs incremental output). Callers wanting a
|
||||
// plain transcript use the unary AudioTranscription path. This mirrors
|
||||
// AudioTranscriptionLive, which already returns Unimplemented here.
|
||||
return grpcerrors.StreamTranscriptionUnsupported("parakeet-cpp",
|
||||
"loaded model is not a cache-aware streaming model")
|
||||
}
|
||||
defer CppStreamFree(stream)
|
||||
// The C engine is a single shared context: a streaming session and a batched
|
||||
// unary dispatch must never touch it at once, so hold engineMu for the whole
|
||||
// stream. This lock is intentionally taken AFTER the non-streaming fallback
|
||||
// above returns: that fallback goes through AudioTranscription -> the batcher
|
||||
// -> runBatch, which itself acquires engineMu, so locking here first would
|
||||
// deadlock. Do not hoist this lock above the fallback.
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
defer p.streamFree(stream)
|
||||
|
||||
data, duration, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ABI v4: when the streaming JSON entry points are present, drive them so the
|
||||
// per-utterance segments carry per-word start/end timestamps. Falls through to
|
||||
// the text-only loop below against an older libparakeet.so. Runs under the
|
||||
// engineMu already held above.
|
||||
if CppStreamFeedJSON != nil {
|
||||
return p.streamJSON(ctx, stream, data, duration, results)
|
||||
}
|
||||
|
||||
// Fold the shared decode driver's per-feed increments into the streamed
|
||||
// deltas and the closing batch result: words/text accumulate into
|
||||
// per-utterance segments (streamSegmenter), and the utterance-boundary
|
||||
// latch (boundary.go) records whether the file ended on an <EOU>. These
|
||||
// are the offline path's concern — the live RPC carries none of them.
|
||||
var (
|
||||
full strings.Builder
|
||||
segText strings.Builder
|
||||
segments []*pb.TranscriptSegment
|
||||
segID int32
|
||||
seg streamSegmenter
|
||||
boundary utteranceBoundary
|
||||
)
|
||||
|
||||
flushSegment := func() {
|
||||
t := strings.TrimSpace(segText.String())
|
||||
segText.Reset()
|
||||
if t == "" {
|
||||
return
|
||||
emit := func(r streamFeedResult) error {
|
||||
if r.Delta != "" {
|
||||
full.WriteString(r.Delta)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: r.Delta}
|
||||
}
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: segID, Text: t})
|
||||
segID++
|
||||
}
|
||||
|
||||
// emitDelta consumes the malloc'd char* returned by feed/finalize: frees
|
||||
// it, accumulates the text, and sends a delta when non-empty. A 0 return
|
||||
// is an error (vs the "" empty-but-non-NULL no-new-text case).
|
||||
emitDelta := func(ret uintptr) error {
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
delta := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
if delta == "" {
|
||||
return nil
|
||||
}
|
||||
full.WriteString(delta)
|
||||
segText.WriteString(delta)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: delta}
|
||||
seg.add(r)
|
||||
boundary = boundary.observe(r)
|
||||
return nil
|
||||
}
|
||||
|
||||
for off := 0; off < len(data); off += streamChunkSamples {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(data))
|
||||
chunk := data[off:end]
|
||||
|
||||
var eou int32
|
||||
ret := CppStreamFeed(stream, chunk, int32(len(chunk)), unsafe.Pointer(&eou))
|
||||
if err := emitDelta(ret); err != nil {
|
||||
return err
|
||||
}
|
||||
if eou != 0 {
|
||||
flushSegment()
|
||||
}
|
||||
}
|
||||
|
||||
// Flush the streaming tail (final encoder chunk).
|
||||
if err := emitDelta(CppStreamFinalize(stream)); err != nil {
|
||||
if err := p.feedSlices(ctx, stream, data, emit); err != nil {
|
||||
return err
|
||||
}
|
||||
flushSegment()
|
||||
|
||||
text := strings.TrimSpace(full.String())
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{
|
||||
FinalResult: &pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: segments,
|
||||
Duration: duration,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamJSON drives the streaming JSON entry points (present since ABI v4): each
|
||||
// feed/finalize returns a {text,eou,eob,frame_sec,words} document. The
|
||||
// newly-finalized text is emitted as a delta (unchanged streaming contract)
|
||||
// while words are accumulated into per-utterance segments (closed on <EOU> or
|
||||
// <EOB>) so the closing FinalResult carries timestamped segments. Runs under
|
||||
// engineMu (already held by the caller).
|
||||
func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []float32,
|
||||
duration float32, results chan *pb.TranscriptStreamResponse) error {
|
||||
var (
|
||||
full strings.Builder
|
||||
seg streamSegmenter
|
||||
)
|
||||
// consume frees the malloc'd char* (a 0 return is an error), parses the JSON,
|
||||
// emits the delta, and routes words through the segmenter.
|
||||
consume := func(ret uintptr) error {
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
raw := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
var doc streamFeedJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return fmt.Errorf("parakeet-cpp: decode stream json: %w", err)
|
||||
}
|
||||
if doc.Text != "" {
|
||||
full.WriteString(doc.Text)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: doc.Text}
|
||||
}
|
||||
seg.add(doc)
|
||||
return nil
|
||||
}
|
||||
|
||||
for off := 0; off < len(data); off += streamChunkSamples {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(data))
|
||||
chunk := data[off:end]
|
||||
if err := consume(CppStreamFeedJSON(stream, chunk, int32(len(chunk)))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := consume(CppStreamFinalizeJSON(stream)); err != nil {
|
||||
if err := p.flushTail(stream, emit); err != nil {
|
||||
return err
|
||||
}
|
||||
seg.flush() // close any trailing utterance that never saw an EOU
|
||||
seg.flush() // close a trailing utterance that never saw an <EOU>
|
||||
|
||||
text := strings.TrimSpace(full.String())
|
||||
// final.Text is the exact concatenation of the streamed deltas (full is
|
||||
// their accumulation), so concat(deltas) == FinalResult.Text holds even
|
||||
// when the model prepends a leading space to the first word (SentencePiece
|
||||
// detokenization). This matches the whisper backend's streaming contract.
|
||||
// The single-segment fallback stays trimmed.
|
||||
fullText := full.String()
|
||||
segments := seg.segments()
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
|
||||
if trimmed := strings.TrimSpace(fullText); len(segments) == 0 && trimmed != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: trimmed})
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{
|
||||
FinalResult: &pb.TranscriptResult{
|
||||
Text: text,
|
||||
Text: fullText,
|
||||
Segments: segments,
|
||||
Duration: duration,
|
||||
Eou: boundary.ended(),
|
||||
},
|
||||
}
|
||||
return nil
|
||||
@@ -803,6 +826,10 @@ func (p *ParakeetCpp) Free() error {
|
||||
close(p.batStop)
|
||||
p.batStop = nil
|
||||
}
|
||||
// engineMu so an in-flight streaming call (which locks per C call and
|
||||
// re-checks ctxPtr under the lock) can never feed into a freed ctx.
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
if p.ctxPtr != 0 {
|
||||
CppFree(p.ctxPtr)
|
||||
p.ctxPtr = 0
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestParakeetCpp(t *testing.T) {
|
||||
@@ -201,6 +203,29 @@ var _ = Describe("ParakeetCpp", func() {
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("returns the typed Unimplemented signal for non-streaming models (no offline fallback)", func() {
|
||||
// stream_begin == 0 means the loaded model is not a cache-aware
|
||||
// streaming model. The backend must surface that, not silently
|
||||
// decode offline and fake a one-shot "stream".
|
||||
savedBegin, savedBeginLang := CppStreamBegin, CppStreamBeginLang
|
||||
defer func() { CppStreamBegin, CppStreamBeginLang = savedBegin, savedBeginLang }()
|
||||
CppStreamBeginLang = nil
|
||||
CppStreamBegin = func(ctx uintptr) uintptr { return 0 }
|
||||
|
||||
p := &ParakeetCpp{ctxPtr: 1}
|
||||
results := make(chan *pb.TranscriptStreamResponse, 8)
|
||||
err := p.AudioTranscriptionStream(context.Background(),
|
||||
&pb.TranscriptRequest{Dst: "ignored.wav"}, results)
|
||||
Expect(status.Code(err)).To(Equal(codes.Unimplemented))
|
||||
|
||||
// Honest signal: nothing was emitted — no faked batch result.
|
||||
var emitted []*pb.TranscriptStreamResponse
|
||||
for r := range results {
|
||||
emitted = append(emitted, r)
|
||||
}
|
||||
Expect(emitted).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("streams deltas and a closing FinalResult from a cache-aware model", func() {
|
||||
// Streaming needs a cache-aware streaming model (e.g.
|
||||
// realtime_eou); the offline test model would fail stream_begin.
|
||||
|
||||
186
backend/go/parakeet-cpp/live.go
Normal file
186
backend/go/parakeet-cpp/live.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// liveSampleRate is the only PCM rate the parakeet C streaming API accepts.
|
||||
const liveSampleRate = 16000
|
||||
|
||||
// AudioTranscriptionLive drives one cache-aware streaming session over audio
|
||||
// fed incrementally by the caller (the realtime API's semantic_vad turn
|
||||
// detection). Contract:
|
||||
//
|
||||
// - the first request must carry a Config; a Config mid-stream resets the
|
||||
// decode session (free + begin) and drops accumulated transcript state;
|
||||
// - a Ready ack is sent right after a successful stream_begin so callers
|
||||
// can degrade synchronously when the model has no streaming support
|
||||
// (LiveTranscriptionUnsupported, codes.Unimplemented);
|
||||
// - every feed that produced output is forwarded as {delta, eou, words};
|
||||
// the <EOU>/<EOB> flag is the model's own utterance boundary and the
|
||||
// decoder auto-resets after it, so one session spans many utterances;
|
||||
// - closing the send side finalizes: the held-back tail chunk is flushed
|
||||
// (the last ~2 encoder frames of words only appear here) and a terminal
|
||||
// FinalResult carries the full transcript Text only. Per-utterance
|
||||
// segments, duration, and the terminal <EOU> flag are NOT produced here —
|
||||
// the realtime core consumes the streamed per-feed tokens and the final
|
||||
// Text; those batch fields are the file path's concern (see
|
||||
// AudioTranscriptionStream).
|
||||
//
|
||||
// Engine access is serialized per C call (streamBegin/streamFeed*/streamFree
|
||||
// take engineMu internally), never for the session lifetime — unary
|
||||
// transcription keeps flowing between feeds.
|
||||
func (p *ParakeetCpp) AudioTranscriptionLive(in <-chan *pb.TranscriptLiveRequest, out chan<- *pb.TranscriptLiveResponse) error {
|
||||
defer close(out)
|
||||
|
||||
if p.ctxPtr == 0 {
|
||||
return grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
|
||||
first, ok := <-in
|
||||
if !ok {
|
||||
return nil // caller closed without sending anything
|
||||
}
|
||||
cfg := first.GetConfig()
|
||||
if cfg == nil {
|
||||
return status.Error(codes.InvalidArgument, "parakeet-cpp: first live message must carry a config")
|
||||
}
|
||||
if err := validateLiveConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stream, err := p.streamBegin(cfg.GetLanguage())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stream == 0 {
|
||||
return grpcerrors.LiveTranscriptionUnsupported("parakeet-cpp",
|
||||
"loaded model is not a cache-aware streaming model")
|
||||
}
|
||||
// stream is reassigned on a mid-stream Config reset; free whatever is
|
||||
// current when the RPC unwinds.
|
||||
defer func() { p.streamFree(stream) }()
|
||||
|
||||
out <- &pb.TranscriptLiveResponse{Ready: true}
|
||||
|
||||
var (
|
||||
full strings.Builder
|
||||
fedSecs float64
|
||||
|
||||
// behindSec accumulates how far decode wall time has fallen behind
|
||||
// the audio it was fed. A live caller feeds in real time, so a
|
||||
// persistent positive backlog means every downstream signal —
|
||||
// including the <EOU> the turn detector waits on — arrives that many
|
||||
// seconds late. Warned once per session; reset by a Config reset.
|
||||
behindSec float64
|
||||
behindWarned bool
|
||||
)
|
||||
|
||||
// emit forwards one decode increment: it streams the per-feed tokens the
|
||||
// realtime turn detector consumes (delta/eou/eob/words) and accumulates the
|
||||
// running transcript for the closing FinalResult. No segmentation or
|
||||
// boundary latch here — the live consumer reads only the streamed tokens
|
||||
// and the final Text; per-utterance segments and the terminal <EOU> flag
|
||||
// are an offline-path concern (see AudioTranscriptionStream / boundary.go).
|
||||
emit := func(r streamFeedResult) error {
|
||||
if r.Delta != "" {
|
||||
full.WriteString(r.Delta)
|
||||
}
|
||||
if r.Delta != "" || r.Eou || r.Eob || len(r.Words) > 0 {
|
||||
out <- &pb.TranscriptLiveResponse{
|
||||
Delta: r.Delta,
|
||||
Eou: r.Eou,
|
||||
Eob: r.Eob,
|
||||
Words: liveWordsToProto(r.Words),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for req := range in {
|
||||
switch payload := req.GetPayload().(type) {
|
||||
case *pb.TranscriptLiveRequest_Config:
|
||||
if err := validateLiveConfig(payload.Config); err != nil {
|
||||
return err
|
||||
}
|
||||
// Reset: a fresh decode session, dropping accumulated state.
|
||||
p.streamFree(stream)
|
||||
stream, err = p.streamBegin(payload.Config.GetLanguage())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stream == 0 {
|
||||
return grpcerrors.LiveTranscriptionUnsupported("parakeet-cpp",
|
||||
"loaded model is not a cache-aware streaming model")
|
||||
}
|
||||
full.Reset()
|
||||
fedSecs = 0
|
||||
case *pb.TranscriptLiveRequest_Audio:
|
||||
pcm := payload.Audio.GetPcm()
|
||||
audioSec := float64(len(pcm)) / liveSampleRate
|
||||
fedSecs += audioSec
|
||||
start := time.Now()
|
||||
// nil ctx: a live session is bounded by this request channel, not a
|
||||
// context — cancellation is the caller closing the stream.
|
||||
if err := p.feedSlices(nil, stream, pcm, emit); err != nil {
|
||||
return err
|
||||
}
|
||||
wallSec := time.Since(start).Seconds()
|
||||
behindSec += wallSec - audioSec
|
||||
if behindSec < 0 {
|
||||
behindSec = 0
|
||||
}
|
||||
xlog.Debug("parakeet-cpp: live feed",
|
||||
"audio_ms", int(audioSec*1000), "wall_ms", int(wallSec*1000),
|
||||
"behind_ms", int(behindSec*1000), "fed_s", fedSecs)
|
||||
if behindSec > 1 && !behindWarned {
|
||||
behindWarned = true
|
||||
xlog.Warn("parakeet-cpp: live decode is falling behind real time; "+
|
||||
"end-of-utterance signals will arrive late",
|
||||
"behind_s", behindSec, "fed_s", fedSecs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send side closed: flush the streaming tail and emit the final transcript.
|
||||
// The live FinalResult carries only Text — the authoritative full-turn
|
||||
// transcript the realtime core commits. Per-utterance segments, duration,
|
||||
// and the terminal <EOU> flag are not produced on the live path.
|
||||
if err := p.flushTail(stream, emit); err != nil {
|
||||
return err
|
||||
}
|
||||
out <- &pb.TranscriptLiveResponse{
|
||||
FinalResult: &pb.TranscriptResult{Text: strings.TrimSpace(full.String())},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateLiveConfig(cfg *pb.TranscriptLiveConfig) error {
|
||||
if sr := cfg.GetSampleRate(); sr != 0 && sr != liveSampleRate {
|
||||
return status.Errorf(codes.InvalidArgument,
|
||||
"parakeet-cpp: unsupported live sample_rate %d (only %d)", sr, liveSampleRate)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func liveWordsToProto(words []transcriptWord) []*pb.TranscriptWord {
|
||||
if len(words) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*pb.TranscriptWord, len(words))
|
||||
for i, w := range words {
|
||||
out[i] = &pb.TranscriptWord{
|
||||
Start: secondsToNanos(w.Start),
|
||||
End: secondsToNanos(w.End),
|
||||
Text: w.W,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
417
backend/go/parakeet-cpp/live_test.go
Normal file
417
backend/go/parakeet-cpp/live_test.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// The live-RPC specs drive AudioTranscriptionLive entirely against stubbed
|
||||
// Cpp* package vars (the same seam batcher_test.go uses), so they run
|
||||
// without libparakeet.so.
|
||||
|
||||
// liveCstrPool hands out NUL-terminated C-style strings backed by Go memory
|
||||
// and keeps them alive for the duration of a spec (goStringFromCPtr reads
|
||||
// through the raw pointer; Go's GC must not collect the backing array while
|
||||
// a stub's return value is in flight).
|
||||
type liveCstrPool struct {
|
||||
mu sync.Mutex
|
||||
bufs [][]byte
|
||||
}
|
||||
|
||||
func (p *liveCstrPool) cstr(s string) uintptr {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
b := append([]byte(s), 0)
|
||||
p.bufs = append(p.bufs, b)
|
||||
return uintptr(unsafe.Pointer(&b[0]))
|
||||
}
|
||||
|
||||
// liveStubs swaps every C entry point the live path touches and returns a
|
||||
// restore func for AfterEach.
|
||||
func liveStubs() (restore func()) {
|
||||
savedBegin, savedBeginLang := CppStreamBegin, CppStreamBeginLang
|
||||
savedFeed, savedFeedJSON := CppStreamFeed, CppStreamFeedJSON
|
||||
savedFinalize, savedFinalizeJSON := CppStreamFinalize, CppStreamFinalizeJSON
|
||||
savedFree, savedLastError := CppStreamFree, CppLastError
|
||||
savedFreeString := CppFreeString
|
||||
return func() {
|
||||
CppStreamBegin, CppStreamBeginLang = savedBegin, savedBeginLang
|
||||
CppStreamFeed, CppStreamFeedJSON = savedFeed, savedFeedJSON
|
||||
CppStreamFinalize, CppStreamFinalizeJSON = savedFinalize, savedFinalizeJSON
|
||||
CppStreamFree, CppLastError = savedFree, savedLastError
|
||||
CppFreeString = savedFreeString
|
||||
}
|
||||
}
|
||||
|
||||
// runLive starts the RPC on its own goroutine and returns the request
|
||||
// channel plus a collector for everything the backend emitted.
|
||||
func runLive(p *ParakeetCpp) (chan *pb.TranscriptLiveRequest, chan *pb.TranscriptLiveResponse, chan error) {
|
||||
in := make(chan *pb.TranscriptLiveRequest)
|
||||
out := make(chan *pb.TranscriptLiveResponse, 32)
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- p.AudioTranscriptionLive(in, out) }()
|
||||
return in, out, errCh
|
||||
}
|
||||
|
||||
func liveConfig(lang string) *pb.TranscriptLiveRequest {
|
||||
return &pb.TranscriptLiveRequest{
|
||||
Payload: &pb.TranscriptLiveRequest_Config{Config: &pb.TranscriptLiveConfig{Language: lang}},
|
||||
}
|
||||
}
|
||||
|
||||
func liveAudio(pcm []float32) *pb.TranscriptLiveRequest {
|
||||
return &pb.TranscriptLiveRequest{
|
||||
Payload: &pb.TranscriptLiveRequest_Audio{Audio: &pb.TranscriptLiveAudio{Pcm: pcm}},
|
||||
}
|
||||
}
|
||||
|
||||
func collectLive(out chan *pb.TranscriptLiveResponse) []*pb.TranscriptLiveResponse {
|
||||
var got []*pb.TranscriptLiveResponse
|
||||
for r := range out {
|
||||
got = append(got, r)
|
||||
}
|
||||
return got
|
||||
}
|
||||
|
||||
var _ = Describe("AudioTranscriptionLive (stubbed C API)", func() {
|
||||
var (
|
||||
pool *liveCstrPool
|
||||
restore func()
|
||||
p *ParakeetCpp
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
pool = &liveCstrPool{}
|
||||
restore = liveStubs()
|
||||
p = &ParakeetCpp{ctxPtr: 1}
|
||||
|
||||
CppStreamBeginLang = nil
|
||||
CppStreamBegin = func(ctx uintptr) uintptr { return 7 }
|
||||
CppStreamFree = func(s uintptr) {}
|
||||
CppFreeString = func(s uintptr) {}
|
||||
CppLastError = func(ctx uintptr) string { return "stub error" }
|
||||
CppStreamFeed = nil
|
||||
CppStreamFeedJSON = nil
|
||||
CppStreamFinalize = nil
|
||||
CppStreamFinalizeJSON = nil
|
||||
})
|
||||
|
||||
AfterEach(func() { restore() })
|
||||
|
||||
It("rejects a stream whose first message is not a config", func() {
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveAudio([]float32{0.1})
|
||||
close(in)
|
||||
|
||||
err := <-errCh
|
||||
Expect(status.Code(err)).To(Equal(codes.InvalidArgument))
|
||||
Expect(collectLive(out)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("rejects a non-16k sample rate", func() {
|
||||
in, _, errCh := runLive(p)
|
||||
in <- &pb.TranscriptLiveRequest{
|
||||
Payload: &pb.TranscriptLiveRequest_Config{Config: &pb.TranscriptLiveConfig{SampleRate: 8000}},
|
||||
}
|
||||
close(in)
|
||||
Expect(status.Code(<-errCh)).To(Equal(codes.InvalidArgument))
|
||||
})
|
||||
|
||||
It("returns the typed Unimplemented signal for non-streaming models, before any ack", func() {
|
||||
CppStreamBegin = func(ctx uintptr) uintptr { return 0 }
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
close(in)
|
||||
|
||||
err := <-errCh
|
||||
Expect(grpcerrors.IsLiveTranscriptionUnsupported(err)).To(BeTrue())
|
||||
Expect(collectLive(out)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("streams deltas, eou flags and words on the JSON path and finalizes on close", func() {
|
||||
var freed []uintptr
|
||||
CppStreamFree = func(s uintptr) { freed = append(freed, s) }
|
||||
feeds := 0
|
||||
CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr {
|
||||
feeds++
|
||||
switch feeds {
|
||||
case 1:
|
||||
return pool.cstr(`{"text":"hello ","eou":0,"frame_sec":0.08,` +
|
||||
`"words":[{"w":"hello","start":0.1,"end":0.4,"conf":0.9}]}`)
|
||||
default:
|
||||
return pool.cstr(`{"text":"world","eou":1,"frame_sec":0.08,` +
|
||||
`"words":[{"w":"world","start":0.5,"end":0.8,"conf":0.9}]}`)
|
||||
}
|
||||
}
|
||||
CppStreamFinalizeJSON = func(s uintptr) uintptr {
|
||||
return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("en")
|
||||
in <- liveAudio(make([]float32, 100))
|
||||
in <- liveAudio(make([]float32, 200))
|
||||
close(in)
|
||||
Expect(<-errCh).NotTo(HaveOccurred())
|
||||
|
||||
got := collectLive(out)
|
||||
Expect(got).To(HaveLen(4)) // ready, two deltas, final
|
||||
|
||||
Expect(got[0].Ready).To(BeTrue())
|
||||
|
||||
Expect(got[1].Delta).To(Equal("hello "))
|
||||
Expect(got[1].Eou).To(BeFalse())
|
||||
Expect(got[1].Words).To(HaveLen(1))
|
||||
Expect(got[1].Words[0].Text).To(Equal("hello"))
|
||||
|
||||
Expect(got[2].Delta).To(Equal("world"))
|
||||
Expect(got[2].Eou).To(BeTrue())
|
||||
|
||||
final := got[3].FinalResult
|
||||
Expect(final).NotTo(BeNil())
|
||||
Expect(final.Text).To(Equal("hello world"))
|
||||
// The live FinalResult carries only Text. Per-utterance segments,
|
||||
// duration and the terminal eou flag are an offline-path concern (see
|
||||
// boundary.go / AudioTranscriptionStream); the realtime core reads the
|
||||
// streamed per-feed tokens above plus this Text.
|
||||
Expect(final.Eou).To(BeFalse())
|
||||
Expect(final.Segments).To(BeEmpty())
|
||||
Expect(final.Duration).To(BeZero())
|
||||
|
||||
Expect(freed).To(Equal([]uintptr{7}))
|
||||
})
|
||||
|
||||
It("falls back to the text feed (eou out-param) when the JSON entry points are absent", func() {
|
||||
feeds := 0
|
||||
CppStreamFeed = func(s uintptr, pcm []float32, n int32, eouOut unsafe.Pointer) uintptr {
|
||||
feeds++
|
||||
if feeds == 2 {
|
||||
*(*int32)(eouOut) = 1
|
||||
return pool.cstr("done")
|
||||
}
|
||||
return pool.cstr("first ")
|
||||
}
|
||||
CppStreamFinalize = func(s uintptr) uintptr { return pool.cstr("") }
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
close(in)
|
||||
Expect(<-errCh).NotTo(HaveOccurred())
|
||||
|
||||
got := collectLive(out)
|
||||
Expect(got).To(HaveLen(4))
|
||||
Expect(got[1].Delta).To(Equal("first "))
|
||||
Expect(got[1].Eou).To(BeFalse())
|
||||
Expect(got[2].Delta).To(Equal("done"))
|
||||
Expect(got[2].Eou).To(BeTrue())
|
||||
Expect(got[3].FinalResult.Text).To(Equal("first done"))
|
||||
})
|
||||
|
||||
It("forwards <EOB> as eob — a backchannel, never an eou (ABI v5 JSON)", func() {
|
||||
feeds := 0
|
||||
CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr {
|
||||
feeds++
|
||||
if feeds == 1 {
|
||||
return pool.cstr(`{"text":"uh-huh","eou":0,"eob":1,"frame_sec":0.08,` +
|
||||
`"words":[{"w":"uh-huh","start":0.1,"end":0.3,"conf":0.9}]}`)
|
||||
}
|
||||
return pool.cstr(`{"text":"the turn","eou":1,"eob":0,"frame_sec":0.08,` +
|
||||
`"words":[{"w":"the","start":0.5,"end":0.6,"conf":0.9},{"w":"turn","start":0.6,"end":0.8,"conf":0.9}]}`)
|
||||
}
|
||||
CppStreamFinalizeJSON = func(s uintptr) uintptr {
|
||||
return pool.cstr(`{"text":"","eou":0,"eob":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
close(in)
|
||||
Expect(<-errCh).NotTo(HaveOccurred())
|
||||
|
||||
got := collectLive(out)
|
||||
Expect(got).To(HaveLen(4))
|
||||
Expect(got[1].Eob).To(BeTrue())
|
||||
Expect(got[1].Eou).To(BeFalse(), "a backchannel must not masquerade as a turn boundary")
|
||||
Expect(got[2].Eou).To(BeTrue())
|
||||
})
|
||||
|
||||
It("maps the v5 eou_out bitmask on the text path (bit0 <EOU>, bit1 <EOB>)", func() {
|
||||
feeds := 0
|
||||
CppStreamFeed = func(s uintptr, pcm []float32, n int32, eouOut unsafe.Pointer) uintptr {
|
||||
feeds++
|
||||
if feeds == 1 {
|
||||
*(*int32)(eouOut) = 2 // <EOB> only
|
||||
return pool.cstr("uh-huh")
|
||||
}
|
||||
*(*int32)(eouOut) = 1 // <EOU>
|
||||
return pool.cstr(" done")
|
||||
}
|
||||
CppStreamFinalize = func(s uintptr) uintptr { return pool.cstr("") }
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
close(in)
|
||||
Expect(<-errCh).NotTo(HaveOccurred())
|
||||
|
||||
got := collectLive(out)
|
||||
Expect(got).To(HaveLen(4))
|
||||
Expect(got[1].Eob).To(BeTrue())
|
||||
Expect(got[1].Eou).To(BeFalse())
|
||||
Expect(got[2].Eou).To(BeTrue())
|
||||
Expect(got[2].Eob).To(BeFalse())
|
||||
})
|
||||
|
||||
It("accumulates trailing text after an EOU into the final transcript", func() {
|
||||
feeds := 0
|
||||
CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr {
|
||||
feeds++
|
||||
if feeds == 1 {
|
||||
return pool.cstr(`{"text":"turn one","eou":1,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
return pool.cstr(`{"text":" and more","eou":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
CppStreamFinalizeJSON = func(s uintptr) uintptr {
|
||||
return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
close(in)
|
||||
Expect(<-errCh).NotTo(HaveOccurred())
|
||||
|
||||
got := collectLive(out)
|
||||
final := got[len(got)-1].FinalResult
|
||||
Expect(final.Text).To(Equal("turn one and more"))
|
||||
})
|
||||
|
||||
It("resets the decode session on a mid-stream config", func() {
|
||||
var begun, freed int
|
||||
CppStreamBegin = func(ctx uintptr) uintptr { begun++; return uintptr(10 + begun) }
|
||||
CppStreamFree = func(s uintptr) { freed++ }
|
||||
CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr {
|
||||
return pool.cstr(`{"text":"x","eou":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
CppStreamFinalizeJSON = func(s uintptr) uintptr {
|
||||
return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
in <- liveConfig("") // reset
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
close(in)
|
||||
Expect(<-errCh).NotTo(HaveOccurred())
|
||||
|
||||
got := collectLive(out)
|
||||
final := got[len(got)-1].FinalResult
|
||||
Expect(final.Text).To(Equal("x"), "pre-reset transcript dropped")
|
||||
Expect(begun).To(Equal(2))
|
||||
Expect(freed).To(Equal(2), "old session freed on reset, new one on unwind")
|
||||
})
|
||||
|
||||
It("does not hold engineMu between feeds (unary work interleaves with a live session)", func() {
|
||||
CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr {
|
||||
return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
CppStreamFinalizeJSON = func(s uintptr) uintptr {
|
||||
return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`)
|
||||
}
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
|
||||
// The session is open and idle between feeds: the engine lock must be
|
||||
// acquirable, which is what lets batched unary transcription proceed
|
||||
// mid-session. Under stream-lifetime locking this probe would block
|
||||
// until the stream ended and the Eventually would time out.
|
||||
locked := make(chan struct{})
|
||||
go func() {
|
||||
p.engineMu.Lock()
|
||||
p.engineMu.Unlock() //nolint:staticcheck // probe: acquire-release proves availability
|
||||
close(locked)
|
||||
}()
|
||||
Eventually(locked, time.Second).Should(BeClosed())
|
||||
|
||||
close(in)
|
||||
Expect(<-errCh).NotTo(HaveOccurred())
|
||||
collectLive(out)
|
||||
})
|
||||
|
||||
It("errors out and reads last_error under the lock when a feed fails", func() {
|
||||
CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr { return 0 }
|
||||
|
||||
in, out, errCh := runLive(p)
|
||||
in <- liveConfig("")
|
||||
in <- liveAudio(make([]float32, 10))
|
||||
|
||||
err := <-errCh
|
||||
Expect(err).To(MatchError(ContainSubstring("stub error")))
|
||||
got := collectLive(out)
|
||||
Expect(got).To(HaveLen(1)) // just the ready ack
|
||||
close(in)
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("stripEouMarker", func() {
|
||||
It("strips a trailing <EOU> and reports it", func() {
|
||||
text, eou := stripEouMarker("it is certainly very like the old portrait<EOU>")
|
||||
Expect(text).To(Equal("it is certainly very like the old portrait"))
|
||||
Expect(eou).To(BeTrue())
|
||||
})
|
||||
|
||||
It("strips a trailing <EOB> WITHOUT reporting an utterance end", func() {
|
||||
// A decode ending on a backchannel must not confirm the
|
||||
// retranscribe gate — the user was acknowledging, not yielding.
|
||||
text, eou := stripEouMarker("uh-huh<EOB>")
|
||||
Expect(text).To(Equal("uh-huh"))
|
||||
Expect(eou).To(BeFalse())
|
||||
})
|
||||
|
||||
It("leaves marker-free text alone", func() {
|
||||
text, eou := stripEouMarker("plain transcript")
|
||||
Expect(text).To(Equal("plain transcript"))
|
||||
Expect(eou).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does not strip a marker in the middle of the text", func() {
|
||||
text, eou := stripEouMarker("a<EOU>b")
|
||||
Expect(text).To(Equal("a<EOU>b"))
|
||||
Expect(eou).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("transcriptResultFromDoc EOU handling", func() {
|
||||
It("strips the offline marker from text and sets the result flag", func() {
|
||||
doc := transcriptJSON{Text: "the old portrait<EOU>"}
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Text).To(Equal("the old portrait"))
|
||||
Expect(res.Eou).To(BeTrue())
|
||||
Expect(res.Segments).To(HaveLen(1))
|
||||
Expect(res.Segments[0].Text).To(Equal("the old portrait"))
|
||||
})
|
||||
|
||||
It("reports eou=false for marker-free decodes", func() {
|
||||
doc := transcriptJSON{Text: "no marker here"}
|
||||
res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0)
|
||||
Expect(res.Text).To(Equal("no marker here"))
|
||||
Expect(res.Eou).To(BeFalse())
|
||||
})
|
||||
})
|
||||
@@ -106,7 +106,7 @@ var _ = Describe("transcriptResultFromDoc (multi-segment)", func() {
|
||||
var _ = Describe("streaming segment assembly", func() {
|
||||
It("closes a segment with start/end from its words on EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hello world", Eou: 1, Words: []transcriptWord{
|
||||
acc.add(streamFeedResult{Delta: "hello world", Eou: true, Words: []transcriptWord{
|
||||
{W: "hello", Start: 0.0, End: 0.4}, {W: "world", Start: 0.4, End: 0.9},
|
||||
}})
|
||||
segs := acc.segments()
|
||||
@@ -118,9 +118,9 @@ var _ = Describe("streaming segment assembly", func() {
|
||||
|
||||
It("buffers words across feeds until EOU", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "hi", Eou: 0, Words: []transcriptWord{{W: "hi", Start: 0, End: 0.3}}})
|
||||
acc.add(streamFeedResult{Delta: "hi", Words: []transcriptWord{{W: "hi", Start: 0, End: 0.3}}})
|
||||
Expect(acc.segments()).To(BeEmpty())
|
||||
acc.add(streamFeedJSON{Text: "there", Eou: 1, Words: []transcriptWord{{W: "there", Start: 0.3, End: 0.7}}})
|
||||
acc.add(streamFeedResult{Delta: "there", Eou: true, Words: []transcriptWord{{W: "there", Start: 0.3, End: 0.7}}})
|
||||
Expect(acc.segments()).To(HaveLen(1))
|
||||
Expect(acc.segments()[0].Text).To(Equal("hi there"))
|
||||
})
|
||||
@@ -129,7 +129,7 @@ var _ = Describe("streaming segment assembly", func() {
|
||||
// field; a backchannel must still close the segment as it did in v4.
|
||||
It("closes a segment on EOB (backchannel) too", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedJSON{Text: "uh huh", Eou: 0, Eob: 1, Words: []transcriptWord{
|
||||
acc.add(streamFeedResult{Delta: "uh huh", Eob: true, Words: []transcriptWord{
|
||||
{W: "uh", Start: 0.0, End: 0.2}, {W: "huh", Start: 0.2, End: 0.5},
|
||||
}})
|
||||
segs := acc.segments()
|
||||
@@ -137,4 +137,18 @@ var _ = Describe("streaming segment assembly", func() {
|
||||
Expect(segs[0].Text).To(Equal("uh huh"))
|
||||
Expect(segs[0].End).To(Equal(secondsToNanos(0.5)))
|
||||
})
|
||||
|
||||
// Older text-only libparakeet.so: no per-word timings, so a segment is cut
|
||||
// from the delta text on each <EOU>/<EOB> (no timestamps), one per utterance.
|
||||
It("falls back to text segments when the feed carries no words", func() {
|
||||
acc := &streamSegmenter{}
|
||||
acc.add(streamFeedResult{Delta: "first turn", Eou: true})
|
||||
acc.add(streamFeedResult{Delta: "second turn", Eou: true})
|
||||
segs := acc.segments()
|
||||
Expect(segs).To(HaveLen(2))
|
||||
Expect(segs[0].Text).To(Equal("first turn"))
|
||||
Expect(segs[1].Text).To(Equal("second turn"))
|
||||
Expect(segs[0].Start).To(Equal(int64(0)), "no per-word timing on the text path")
|
||||
Expect(segs[0].End).To(Equal(int64(0)))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=8caa3f908ae6d4a4bef531e73b9a969f266a3d1f
|
||||
STABLEDIFFUSION_GGML_VERSION?=3b6c9ca97cfcda8e68e719e6670d06379fcbe943
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
18
backend/go/voice-detect/.gitignore
vendored
Normal file
18
backend/go/voice-detect/.gitignore
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
# Fetched upstream sources
|
||||
sources/
|
||||
|
||||
# CMake build directories
|
||||
build*/
|
||||
|
||||
# build artifacts staged in-tree by the Makefile (cp from sources/) or
|
||||
# symlinked for local dev; the real sources live in voice-detect.cpp upstream.
|
||||
*.so
|
||||
*.so.*
|
||||
voicedetect_capi.h
|
||||
compile_commands.json
|
||||
|
||||
# Compiled backend binary
|
||||
voice-detect-grpc
|
||||
|
||||
# Packaging output
|
||||
package/
|
||||
107
backend/go/voice-detect/Makefile
Normal file
107
backend/go/voice-detect/Makefile
Normal file
@@ -0,0 +1,107 @@
|
||||
# voice-detect backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as VOICEDETECT_VERSION?=1db1759572c90faef6f3a78c36b5941a096a9f89
|
||||
# can find and update it - matches the parakeet.cpp / whisper.cpp / ds4 convention).
|
||||
#
|
||||
# Local dev shortcut: if you already have an out-of-tree voice-detect.cpp build,
|
||||
# symlink the .so + header into this directory and skip the clone/cmake steps:
|
||||
#
|
||||
# ln -sf /path/to/voice-detect.cpp/build-shared/libvoicedetect.so .
|
||||
# ln -sf /path/to/voice-detect.cpp/include/voicedetect_capi.h .
|
||||
# go build -o voice-detect-grpc .
|
||||
#
|
||||
# The default target below does the proper clone-at-pin + cmake build so CI does
|
||||
# not need a side-checkout.
|
||||
|
||||
VOICEDETECT_VERSION?=1db1759572c90faef6f3a78c36b5941a096a9f89
|
||||
VOICEDETECT_REPO?=https://github.com/mudler/voice-detect.cpp
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
||||
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
# Resolve the target arch. The backend matrix / Docker build pass TARGETARCH
|
||||
# (amd64|arm64); fall back to uname -m (aarch64|x86_64) for a local build.
|
||||
RECON_ARCH?=$(or $(TARGETARCH),$(shell uname -m))
|
||||
|
||||
# Build ggml statically into libvoicedetect.so (PIC) so the shared lib is
|
||||
# self-contained: dlopen needs no libggml*.so alongside it, only system libs
|
||||
# (libstdc++/libgomp/libc) that the runtime image already provides.
|
||||
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DVOICEDETECT_SHARED=ON -DVOICEDETECT_BUILD_CLI=OFF -DVOICEDETECT_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# voice-detect.cpp gates its GGML backends behind VOICEDETECT_GGML_* options and
|
||||
# does set(GGML_CUDA ${VOICEDETECT_GGML_CUDA} CACHE BOOL "" FORCE), so a bare
|
||||
# -DGGML_CUDA=ON is overwritten back to OFF. Forward the VOICEDETECT_GGML_*
|
||||
# options instead. (openblas is not gated, so -DGGML_BLAS passes through.)
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DVOICEDETECT_GGML_CUDA=ON
|
||||
# Opt-in cuDNN implicit-GEMM conv path (kills im2col on GPU, reaches
|
||||
# torch-cuDNN parity). Only the arm64 + CUDA 13 image (GB10/Jetson/L4T)
|
||||
# ships libcudnn9 + the -dev headers, so gate cuDNN to that variant.
|
||||
# x86 CUDA images carry no cuDNN -> enabling it there is a link failure.
|
||||
ifeq ($(CUDA_MAJOR_VERSION),13)
|
||||
ifneq (,$(filter arm64 aarch64,$(RECON_ARCH)))
|
||||
CMAKE_ARGS+=-DVOICEDETECT_GGML_CUDNN=ON
|
||||
endif
|
||||
endif
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DVOICEDETECT_GGML_HIP=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DVOICEDETECT_GGML_VULKAN=ON
|
||||
else ifeq ($(BUILD_TYPE),metal)
|
||||
CMAKE_ARGS+=-DVOICEDETECT_GGML_METAL=ON
|
||||
endif
|
||||
|
||||
.PHONY: voice-detect-grpc package build clean purge test all
|
||||
|
||||
all: voice-detect-grpc
|
||||
|
||||
# Clone the upstream voice-detect.cpp source at the pinned commit. Directory acts
|
||||
# as the target so make only re-clones when missing. After a VOICEDETECT_VERSION
|
||||
# bump, run 'make purge && make' to refetch.
|
||||
sources/voice-detect.cpp:
|
||||
mkdir -p sources/voice-detect.cpp
|
||||
cd sources/voice-detect.cpp && \
|
||||
git init -q && \
|
||||
git remote add origin $(VOICEDETECT_REPO) && \
|
||||
git fetch --depth 1 origin $(VOICEDETECT_VERSION) && \
|
||||
git checkout FETCH_HEAD && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Build the shared lib + header out-of-tree, then stage them next to the Go
|
||||
# sources so purego.Dlopen("libvoicedetect.so") and the cgo-less build both pick
|
||||
# them up.
|
||||
libvoicedetect.so: sources/voice-detect.cpp
|
||||
cmake -B sources/voice-detect.cpp/build-shared -S sources/voice-detect.cpp $(CMAKE_ARGS)
|
||||
cmake --build sources/voice-detect.cpp/build-shared --config Release -j$(JOBS) --target voicedetect
|
||||
cp -fv sources/voice-detect.cpp/build-shared/libvoicedetect.so* ./ 2>/dev/null || true
|
||||
cp -fv sources/voice-detect.cpp/include/voicedetect_capi.h ./
|
||||
|
||||
voice-detect-grpc: libvoicedetect.so main.go govoicedetect.go options.go
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o voice-detect-grpc .
|
||||
|
||||
package: voice-detect-grpc
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
# Test target. The embed/verify/analyze smoke specs are gated on
|
||||
# VOICEDETECT_BACKEND_TEST_MODEL + VOICEDETECT_BACKEND_TEST_WAV; without them the
|
||||
# heavy specs auto-skip and only the pure-Go parsing specs run.
|
||||
test:
|
||||
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
|
||||
|
||||
clean: purge
|
||||
rm -rf libvoicedetect.so* voicedetect_capi.h package voice-detect-grpc
|
||||
|
||||
purge:
|
||||
rm -rf sources/voice-detect.cpp
|
||||
273
backend/go/voice-detect/govoicedetect.go
Normal file
273
backend/go/voice-detect/govoicedetect.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// purego-bound entry points from libvoicedetect.so. Names match
|
||||
// voicedetect_capi.h exactly so a `nm libvoicedetect.so | grep voicedetect_capi`
|
||||
// is enough to spot drift.
|
||||
//
|
||||
// The opaque ctx and the malloc'd char*/float* return values are declared as
|
||||
// uintptr so we get the raw pointer back and can release it via the matching
|
||||
// capi free function. purego's native string/[]float32 returns would copy and
|
||||
// forget the original pointer, leaking the C-owned buffer on every call.
|
||||
var (
|
||||
CppAbiVersion func() int32
|
||||
CppLoad func(ggufPath string) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppFreeString func(s uintptr)
|
||||
CppFreeVec func(v uintptr)
|
||||
CppEmbedPath func(ctx uintptr, wavPath string, outVec, outDim unsafe.Pointer) int32
|
||||
CppEmbedPCM func(ctx uintptr, pcm []float32, nSamples, sampleRate int32, outVec, outDim unsafe.Pointer) int32
|
||||
CppVerifyPaths func(ctx uintptr, a, b string, threshold float32, outDistance, outVerified unsafe.Pointer) int32
|
||||
CppAnalyzeJSON func(ctx uintptr, wavPath string) uintptr
|
||||
)
|
||||
|
||||
// VoiceDetect implements the speaker-recognition voice subset of the Backend
|
||||
// gRPC service over libvoicedetect.so. The C side keeps a single loaded model
|
||||
// plus a per-ctx last-error buffer and is not reentrant, so base.SingleThread
|
||||
// serializes every call.
|
||||
type VoiceDetect struct {
|
||||
base.SingleThread
|
||||
opts loadOptions
|
||||
ctxPtr uintptr
|
||||
}
|
||||
|
||||
func (v *VoiceDetect) Load(opts *pb.ModelOptions) error {
|
||||
model := opts.ModelFile
|
||||
if model == "" {
|
||||
model = opts.ModelPath
|
||||
}
|
||||
if !filepath.IsAbs(model) && opts.ModelPath != "" {
|
||||
model = filepath.Join(opts.ModelPath, model)
|
||||
}
|
||||
if model == "" {
|
||||
return errors.New("voice-detect: ModelFile is required")
|
||||
}
|
||||
|
||||
v.opts = parseOptions(opts.Options)
|
||||
if v.opts.modelName == "" {
|
||||
v.opts.modelName = filepath.Base(model)
|
||||
}
|
||||
|
||||
// Propagate LocalAI's per-model thread budget to the engine. LocalAI spawns
|
||||
// one backend process per model and serves requests concurrently, so the
|
||||
// engine's own min(hardware_concurrency, 8) default can oversubscribe cores.
|
||||
// VOICEDETECT_THREADS is read by the engine at backend construction, so it
|
||||
// must be set before the capi load. A non-positive Threads means "unset":
|
||||
// leave the env alone so the engine keeps its sane default.
|
||||
threads := opts.Threads
|
||||
if threads > 0 {
|
||||
if err := os.Setenv("VOICEDETECT_THREADS", strconv.Itoa(int(threads))); err != nil {
|
||||
return fmt.Errorf("voice-detect: set VOICEDETECT_THREADS: %w", err)
|
||||
}
|
||||
xlog.Info("voice-detect: applying LocalAI thread budget", "threads", threads)
|
||||
}
|
||||
|
||||
xlog.Info("voice-detect: loading model", "model", model,
|
||||
"verify_threshold", v.opts.verifyThreshold, "abi", CppAbiVersion())
|
||||
|
||||
ctx := CppLoad(model)
|
||||
if ctx == 0 {
|
||||
// The last-error buffer lives on the ctx that was never returned, so
|
||||
// surface the path the operator tried to load instead.
|
||||
return fmt.Errorf("voice-detect: voicedetect_capi_load failed for %q", model)
|
||||
}
|
||||
v.ctxPtr = ctx
|
||||
return nil
|
||||
}
|
||||
|
||||
// VoiceEmbed returns the L2-normalized speaker embedding for an audio clip.
|
||||
// The request carries a filesystem PATH; the HTTP layer materializes
|
||||
// base64/URL/data-URI inputs to a temp file before the gRPC call.
|
||||
func (v *VoiceDetect) VoiceEmbed(req *pb.VoiceEmbedRequest) (pb.VoiceEmbedResponse, error) {
|
||||
if v.ctxPtr == 0 {
|
||||
return pb.VoiceEmbedResponse{}, errors.New("voice-detect: model not loaded")
|
||||
}
|
||||
if req.Audio == "" {
|
||||
return pb.VoiceEmbedResponse{}, errors.New("voice-detect: audio path is required")
|
||||
}
|
||||
emb, err := v.embedPath(req.Audio)
|
||||
if err != nil {
|
||||
return pb.VoiceEmbedResponse{}, err
|
||||
}
|
||||
return pb.VoiceEmbedResponse{Embedding: emb, Model: v.opts.modelName}, nil
|
||||
}
|
||||
|
||||
func (v *VoiceDetect) embedPath(path string) ([]float32, error) {
|
||||
var vec uintptr
|
||||
var dim int32
|
||||
rc := CppEmbedPath(v.ctxPtr, path, unsafe.Pointer(&vec), unsafe.Pointer(&dim))
|
||||
if rc != 0 || vec == 0 || dim <= 0 {
|
||||
return nil, v.lastErr("embed", path)
|
||||
}
|
||||
defer CppFreeVec(vec)
|
||||
// Copy out of the C-owned malloc'd buffer before freeing it. The
|
||||
// uintptr->Pointer conversion trips vet's unsafeptr check, which can't tell
|
||||
// a C heap pointer from Go-managed memory; safe here, the GC neither tracks
|
||||
// nor moves this buffer and we copy immediately.
|
||||
src := unsafe.Slice((*float32)(unsafe.Pointer(vec)), int(dim)) //nolint:govet // C-owned malloc'd vector, copied out before free
|
||||
out := make([]float32, int(dim))
|
||||
copy(out, src)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// VoiceVerify embeds two clips and reports whether they are the same speaker by
|
||||
// cosine distance against a threshold. A request threshold <= 0 falls back to
|
||||
// the model-configured default (verify_threshold option, 0.25 if unset).
|
||||
func (v *VoiceDetect) VoiceVerify(req *pb.VoiceVerifyRequest) (pb.VoiceVerifyResponse, error) {
|
||||
if v.ctxPtr == 0 {
|
||||
return pb.VoiceVerifyResponse{}, errors.New("voice-detect: model not loaded")
|
||||
}
|
||||
if req.Audio1 == "" || req.Audio2 == "" {
|
||||
return pb.VoiceVerifyResponse{}, errors.New("voice-detect: audio1 and audio2 are required")
|
||||
}
|
||||
|
||||
threshold := req.Threshold
|
||||
if threshold <= 0 {
|
||||
threshold = v.opts.verifyThreshold
|
||||
}
|
||||
|
||||
started := time.Now()
|
||||
var distance float32
|
||||
var verified int32
|
||||
rc := CppVerifyPaths(v.ctxPtr, req.Audio1, req.Audio2, threshold,
|
||||
unsafe.Pointer(&distance), unsafe.Pointer(&verified))
|
||||
if rc != 0 {
|
||||
return pb.VoiceVerifyResponse{}, v.lastErr("verify", req.Audio1+","+req.Audio2)
|
||||
}
|
||||
elapsedMs := float32(time.Since(started).Seconds() * 1000.0)
|
||||
|
||||
// Confidence decays linearly from 100 at distance 0 to 0 at the threshold,
|
||||
// matching the Python speaker-recognition backend's reporting.
|
||||
confidence := float32(0)
|
||||
if threshold > 0 {
|
||||
confidence = float32(math.Max(0, math.Min(100, (1.0-float64(distance)/float64(threshold))*100.0)))
|
||||
}
|
||||
|
||||
return pb.VoiceVerifyResponse{
|
||||
Verified: verified != 0,
|
||||
Distance: distance,
|
||||
Threshold: threshold,
|
||||
Confidence: confidence,
|
||||
Model: v.opts.modelName,
|
||||
ProcessingTimeMs: elapsedMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VoiceAnalyze runs the age/gender/emotion heads on a single clip. The C-API
|
||||
// always evaluates every supported head, so the request's actions filter is
|
||||
// advisory and the full analysis is returned as a single segment (the engine
|
||||
// does not produce time-bounded segments).
|
||||
func (v *VoiceDetect) VoiceAnalyze(req *pb.VoiceAnalyzeRequest) (pb.VoiceAnalyzeResponse, error) {
|
||||
if v.ctxPtr == 0 {
|
||||
return pb.VoiceAnalyzeResponse{}, errors.New("voice-detect: model not loaded")
|
||||
}
|
||||
if req.Audio == "" {
|
||||
return pb.VoiceAnalyzeResponse{}, errors.New("voice-detect: audio path is required")
|
||||
}
|
||||
|
||||
ptr := CppAnalyzeJSON(v.ctxPtr, req.Audio)
|
||||
if ptr == 0 {
|
||||
return pb.VoiceAnalyzeResponse{}, v.lastErr("analyze", req.Audio)
|
||||
}
|
||||
defer CppFreeString(ptr)
|
||||
|
||||
seg, err := parseAnalyzeJSON(goStringFromCPtr(ptr))
|
||||
if err != nil {
|
||||
return pb.VoiceAnalyzeResponse{}, fmt.Errorf("voice-detect: analyze JSON for %q: %w", req.Audio, err)
|
||||
}
|
||||
return pb.VoiceAnalyzeResponse{Segments: []*pb.VoiceAnalysis{seg}}, nil
|
||||
}
|
||||
|
||||
// analyzeJSON mirrors the document returned by voicedetect_capi_analyze_path_json:
|
||||
//
|
||||
// {"age":42.0,
|
||||
// "gender":{"label":"female","female":0.88,"male":0.12},
|
||||
// "emotion":{"label":"neutral","scores":{"neutral":0.7, ...}}}
|
||||
//
|
||||
// gender is a mixed object (a "label" string plus per-class float scores), so
|
||||
// it is decoded into raw messages and split in parseAnalyzeJSON.
|
||||
type analyzeJSON struct {
|
||||
Age float32 `json:"age"`
|
||||
Gender map[string]json.RawMessage `json:"gender"`
|
||||
Emotion struct {
|
||||
Label string `json:"label"`
|
||||
Scores map[string]float32 `json:"scores"`
|
||||
} `json:"emotion"`
|
||||
}
|
||||
|
||||
// parseAnalyzeJSON maps the engine's analyze document onto a VoiceAnalysis.
|
||||
// start/end stay 0: the model emits a single whole-utterance result, not
|
||||
// time-bounded segments.
|
||||
func parseAnalyzeJSON(doc string) (*pb.VoiceAnalysis, error) {
|
||||
var a analyzeJSON
|
||||
if err := json.Unmarshal([]byte(doc), &a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
seg := &pb.VoiceAnalysis{
|
||||
Age: a.Age,
|
||||
DominantEmotion: a.Emotion.Label,
|
||||
Emotion: a.Emotion.Scores,
|
||||
}
|
||||
|
||||
if len(a.Gender) > 0 {
|
||||
gender := make(map[string]float32, len(a.Gender))
|
||||
for k, raw := range a.Gender {
|
||||
if k == "label" {
|
||||
_ = json.Unmarshal(raw, &seg.DominantGender)
|
||||
continue
|
||||
}
|
||||
var score float32
|
||||
if err := json.Unmarshal(raw, &score); err == nil {
|
||||
gender[k] = score
|
||||
}
|
||||
}
|
||||
seg.Gender = gender
|
||||
}
|
||||
|
||||
return seg, nil
|
||||
}
|
||||
|
||||
// lastErr wraps the C-API's per-ctx last-error buffer into a Go error.
|
||||
func (v *VoiceDetect) lastErr(op, subject string) error {
|
||||
msg := strings.TrimSpace(CppLastError(v.ctxPtr))
|
||||
if msg == "" {
|
||||
msg = "no error detail"
|
||||
}
|
||||
return fmt.Errorf("voice-detect: %s failed for %q: %s", op, subject, msg)
|
||||
}
|
||||
|
||||
// goStringFromCPtr copies a NUL-terminated C string into Go memory. cptr is a
|
||||
// malloc'd buffer the caller owns; release it via CppFreeString after the copy.
|
||||
//
|
||||
// The uintptr->Pointer conversion trips vet's unsafeptr check, which can't tell
|
||||
// a C heap pointer from Go-managed memory. Safe here: the GC neither tracks nor
|
||||
// moves the buffer and we dereference it immediately to copy the bytes out.
|
||||
func goStringFromCPtr(cptr uintptr) string {
|
||||
if cptr == 0 {
|
||||
return ""
|
||||
}
|
||||
p := unsafe.Pointer(cptr) //nolint:govet // C-owned malloc'd buffer, not Go-GC memory (see doc above)
|
||||
n := 0
|
||||
for *(*byte)(unsafe.Add(p, n)) != 0 {
|
||||
n++
|
||||
}
|
||||
return string(unsafe.Slice((*byte)(p), n))
|
||||
}
|
||||
144
backend/go/voice-detect/govoicedetect_test.go
Normal file
144
backend/go/voice-detect/govoicedetect_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestVoiceDetect(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "voice-detect Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the C-API
|
||||
// bridge without spinning up the gRPC server. Records the error (the smoke
|
||||
// specs skip themselves) when libvoicedetect.so is not loadable from cwd
|
||||
// (LD_LIBRARY_PATH or a symlink in ./).
|
||||
func ensureLibLoaded() error {
|
||||
libLoadOnce.Do(func() {
|
||||
libName := os.Getenv("VOICEDETECT_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libvoicedetect.so"
|
||||
}
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&CppAbiVersion, lib, "voicedetect_capi_abi_version")
|
||||
purego.RegisterLibFunc(&CppLoad, lib, "voicedetect_capi_load")
|
||||
purego.RegisterLibFunc(&CppFree, lib, "voicedetect_capi_free")
|
||||
purego.RegisterLibFunc(&CppLastError, lib, "voicedetect_capi_last_error")
|
||||
purego.RegisterLibFunc(&CppFreeString, lib, "voicedetect_capi_free_string")
|
||||
purego.RegisterLibFunc(&CppFreeVec, lib, "voicedetect_capi_free_vec")
|
||||
purego.RegisterLibFunc(&CppEmbedPath, lib, "voicedetect_capi_embed_path")
|
||||
purego.RegisterLibFunc(&CppEmbedPCM, lib, "voicedetect_capi_embed_pcm")
|
||||
purego.RegisterLibFunc(&CppVerifyPaths, lib, "voicedetect_capi_verify_paths")
|
||||
purego.RegisterLibFunc(&CppAnalyzeJSON, lib, "voicedetect_capi_analyze_path_json")
|
||||
})
|
||||
return libLoadErr
|
||||
}
|
||||
|
||||
var _ = Describe("parseOptions", func() {
|
||||
It("defaults verify_threshold to 0.25", func() {
|
||||
o := parseOptions(nil)
|
||||
Expect(o.verifyThreshold).To(Equal(float32(0.25)))
|
||||
Expect(o.modelName).To(Equal(""))
|
||||
})
|
||||
|
||||
It("parses verify_threshold, threshold alias and model_name", func() {
|
||||
o := parseOptions([]string{"verify_threshold:0.4", "model_name:ecapa", "unknown:x"})
|
||||
Expect(o.verifyThreshold).To(Equal(float32(0.4)))
|
||||
Expect(o.modelName).To(Equal("ecapa"))
|
||||
|
||||
o2 := parseOptions([]string{"threshold:0.3"})
|
||||
Expect(o2.verifyThreshold).To(Equal(float32(0.3)))
|
||||
})
|
||||
|
||||
It("ignores non-positive thresholds and keeps the default", func() {
|
||||
o := parseOptions([]string{"verify_threshold:0", "threshold:-1"})
|
||||
Expect(o.verifyThreshold).To(Equal(float32(0.25)))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("parseAnalyzeJSON", func() {
|
||||
It("maps age, gender label+scores and emotion label+scores", func() {
|
||||
doc := `{"age":42.0,
|
||||
"gender":{"label":"female","female":0.88,"male":0.12},
|
||||
"emotion":{"label":"neutral","scores":{"neutral":0.7,"happy":0.2,"sad":0.1}}}`
|
||||
seg, err := parseAnalyzeJSON(doc)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(seg.Age).To(BeNumerically("~", 42.0, 1e-4))
|
||||
Expect(seg.Start).To(Equal(float32(0)))
|
||||
Expect(seg.End).To(Equal(float32(0)))
|
||||
|
||||
Expect(seg.DominantGender).To(Equal("female"))
|
||||
Expect(seg.Gender).To(HaveKeyWithValue("female", BeNumerically("~", 0.88, 1e-4)))
|
||||
Expect(seg.Gender).To(HaveKeyWithValue("male", BeNumerically("~", 0.12, 1e-4)))
|
||||
// The "label" entry is consumed into DominantGender, not the score map.
|
||||
Expect(seg.Gender).ToNot(HaveKey("label"))
|
||||
|
||||
Expect(seg.DominantEmotion).To(Equal("neutral"))
|
||||
Expect(seg.Emotion).To(HaveKeyWithValue("neutral", BeNumerically("~", 0.7, 1e-4)))
|
||||
Expect(seg.Emotion).To(HaveKeyWithValue("happy", BeNumerically("~", 0.2, 1e-4)))
|
||||
})
|
||||
|
||||
It("tolerates a missing gender block", func() {
|
||||
seg, err := parseAnalyzeJSON(`{"age":30.0,"emotion":{"label":"happy","scores":{"happy":1.0}}}`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(seg.DominantGender).To(Equal(""))
|
||||
Expect(seg.DominantEmotion).To(Equal("happy"))
|
||||
})
|
||||
|
||||
It("returns an error on malformed JSON", func() {
|
||||
_, err := parseAnalyzeJSON(`{not-json`)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
// The specs below exercise the real C-API end to end. They run only when both a
|
||||
// model GGUF and a test WAV are provided, and skip cleanly otherwise so the
|
||||
// suite stays green without large assets.
|
||||
var _ = Describe("VoiceDetect end-to-end", Ordered, func() {
|
||||
var (
|
||||
v *VoiceDetect
|
||||
modelPath = os.Getenv("VOICEDETECT_BACKEND_TEST_MODEL")
|
||||
wavPath = os.Getenv("VOICEDETECT_BACKEND_TEST_WAV")
|
||||
)
|
||||
|
||||
BeforeAll(func() {
|
||||
if modelPath == "" || wavPath == "" {
|
||||
Skip("set VOICEDETECT_BACKEND_TEST_MODEL and VOICEDETECT_BACKEND_TEST_WAV to run the e2e specs")
|
||||
}
|
||||
if err := ensureLibLoaded(); err != nil {
|
||||
Skip("libvoicedetect.so not loadable: " + err.Error())
|
||||
}
|
||||
v = &VoiceDetect{}
|
||||
Expect(v.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
})
|
||||
|
||||
It("embeds an audio clip", func() {
|
||||
resp, err := v.VoiceEmbed(&pb.VoiceEmbedRequest{Audio: wavPath})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Embedding).ToNot(BeEmpty())
|
||||
Expect(resp.Model).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("verifies a clip against itself as the same speaker", func() {
|
||||
resp, err := v.VoiceVerify(&pb.VoiceVerifyRequest{Audio1: wavPath, Audio2: wavPath})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Verified).To(BeTrue())
|
||||
Expect(resp.Distance).To(BeNumerically("<=", resp.Threshold))
|
||||
})
|
||||
})
|
||||
64
backend/go/voice-detect/main.go
Normal file
64
backend/go/voice-detect/main.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package main
|
||||
|
||||
// Started internally by LocalAI - one gRPC server per loaded model.
|
||||
//
|
||||
// Loads libvoicedetect.so via purego and registers the flat C-API entry points
|
||||
// declared in voicedetect_capi.h. The library name can be overridden with
|
||||
// VOICEDETECT_LIBRARY (mirrors the PARAKEET_LIBRARY / OMNIVOICE_LIBRARY
|
||||
// convention in the sibling backends); the default looks for the .so next to
|
||||
// this binary (resolved via LD_LIBRARY_PATH by run.sh).
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("VOICEDETECT_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libvoicedetect.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("voice-detect: dlopen %q: %w", libName, err))
|
||||
}
|
||||
|
||||
// Bound 1:1 to voicedetect_capi.h. char*/float* returns are registered as
|
||||
// uintptr so the raw pointer can be freed via the matching capi free fn.
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppAbiVersion, "voicedetect_capi_abi_version"},
|
||||
{&CppLoad, "voicedetect_capi_load"},
|
||||
{&CppFree, "voicedetect_capi_free"},
|
||||
{&CppLastError, "voicedetect_capi_last_error"},
|
||||
{&CppFreeString, "voicedetect_capi_free_string"},
|
||||
{&CppFreeVec, "voicedetect_capi_free_vec"},
|
||||
{&CppEmbedPath, "voicedetect_capi_embed_path"},
|
||||
{&CppEmbedPCM, "voicedetect_capi_embed_pcm"},
|
||||
{&CppVerifyPaths, "voicedetect_capi_verify_paths"},
|
||||
{&CppAnalyzeJSON, "voicedetect_capi_analyze_path_json"},
|
||||
}
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[voice-detect] ABI=%d\n", CppAbiVersion())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &VoiceDetect{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
46
backend/go/voice-detect/options.go
Normal file
46
backend/go/voice-detect/options.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// defaultVerifyThreshold is the cosine-distance cutoff used when a request does
|
||||
// not set one. Matches the Python speaker-recognition backend's default so the
|
||||
// two implementations agree on verdicts out of the box.
|
||||
const defaultVerifyThreshold float32 = 0.25
|
||||
|
||||
// loadOptions holds the parsed model-level options for voice-detect.
|
||||
type loadOptions struct {
|
||||
verifyThreshold float32
|
||||
modelName string
|
||||
}
|
||||
|
||||
func splitOption(o string) (key, value string, ok bool) {
|
||||
i := strings.Index(o, ":")
|
||||
if i < 0 {
|
||||
return "", "", false
|
||||
}
|
||||
return strings.TrimSpace(o[:i]), strings.TrimSpace(o[i+1:]), true
|
||||
}
|
||||
|
||||
// parseOptions reads the backend "key:value" option slice. Unknown keys are
|
||||
// ignored. Defaults: verify_threshold 0.25, model_name derived from the file.
|
||||
func parseOptions(opts []string) loadOptions {
|
||||
o := loadOptions{verifyThreshold: defaultVerifyThreshold}
|
||||
for _, oo := range opts {
|
||||
key, value, ok := splitOption(oo)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch key {
|
||||
case "verify_threshold", "threshold":
|
||||
if f, err := strconv.ParseFloat(value, 32); err == nil && f > 0 {
|
||||
o.verifyThreshold = float32(f)
|
||||
}
|
||||
case "model_name":
|
||||
o.modelName = value
|
||||
}
|
||||
}
|
||||
return o
|
||||
}
|
||||
68
backend/go/voice-detect/package.sh
Executable file
68
backend/go/voice-detect/package.sh
Executable file
@@ -0,0 +1,68 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Bundle the voice-detect-grpc binary, libvoicedetect.so, the core runtime libs
|
||||
# (libc/libstdc++/libgomp + ld.so) and the GPU runtime for the active BUILD_TYPE
|
||||
# so the package is self-contained. Mirrors backend/go/parakeet-cpp/package.sh;
|
||||
# run.sh routes the (CGO_ENABLED=0) binary through lib/ld.so so the packaged libc
|
||||
# is used instead of the host's.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
|
||||
cp -avf "$CURDIR/voice-detect-grpc" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
# libvoicedetect.so + any soname symlinks. purego.Dlopen resolves it via
|
||||
# LD_LIBRARY_PATH, which run.sh points at lib/.
|
||||
cp -avf "$CURDIR"/libvoicedetect.so* "$CURDIR/package/lib/" 2>/dev/null || {
|
||||
echo "ERROR: libvoicedetect.so not found in $CURDIR, run 'make' first" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Detect architecture and copy the core runtime libs libvoicedetect.so links
|
||||
# against, plus the matching dynamic loader as lib/ld.so.
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 "$CURDIR/package/lib/ld.so"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 "$CURDIR/package/lib/libc.so.6"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 "$CURDIR/package/lib/libgcc_s.so.1"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 "$CURDIR/package/lib/libstdc++.so.6"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 "$CURDIR/package/lib/libm.so.6"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 "$CURDIR/package/lib/libgomp.so.1"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 "$CURDIR/package/lib/libdl.so.2"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 "$CURDIR/package/lib/librt.so.1"
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 "$CURDIR/package/lib/libpthread.so.0"
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 "$CURDIR/package/lib/ld.so"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 "$CURDIR/package/lib/libc.so.6"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 "$CURDIR/package/lib/libgcc_s.so.1"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 "$CURDIR/package/lib/libstdc++.so.6"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 "$CURDIR/package/lib/libm.so.6"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 "$CURDIR/package/lib/libgomp.so.1"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 "$CURDIR/package/lib/libdl.so.2"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 "$CURDIR/package/lib/librt.so.1"
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 "$CURDIR/package/lib/libpthread.so.0"
|
||||
elif [ "$(uname -s)" = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries (CUDA/ROCm/Intel/Vulkan loader + ICDs + drivers) based on
|
||||
# BUILD_TYPE so the backend can reach the GPU without the runtime base image
|
||||
# shipping those drivers.
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"
|
||||
16
backend/go/voice-detect/run.sh
Executable file
16
backend/go/voice-detect/run.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
|
||||
|
||||
# If a self-contained ld.so was packaged, route through it so the packaged
|
||||
# libc / libstdc++ are used instead of the host's (matches the whisper /
|
||||
# parakeet backends' runtime layout).
|
||||
if [ -f "$CURDIR/lib/ld.so" ]; then
|
||||
echo "Using lib/ld.so"
|
||||
exec "$CURDIR/lib/ld.so" "$CURDIR/voice-detect-grpc" "$@"
|
||||
fi
|
||||
|
||||
exec "$CURDIR/voice-detect-grpc" "$@"
|
||||
14
backend/go/voice-detect/test.sh
Executable file
14
backend/go/voice-detect/test.sh
Executable file
@@ -0,0 +1,14 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
cd "$CURDIR"
|
||||
|
||||
echo "Running voice-detect backend tests..."
|
||||
|
||||
# The pure-Go parsing specs always run. The embed/verify/analyze smoke specs run
|
||||
# only when a model + WAV are provided via VOICEDETECT_BACKEND_TEST_MODEL and
|
||||
# VOICEDETECT_BACKEND_TEST_WAV; otherwise they auto-skip.
|
||||
LD_LIBRARY_PATH="$CURDIR:${LD_LIBRARY_PATH:-}" go test -v -timeout 1200s .
|
||||
|
||||
echo "voice-detect tests completed."
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=43d78af5be58f41d6ffbc227d608f104577741ea
|
||||
WHISPER_CPP_VERSION?=0ae02cdb2c7317b50991367c165736ce42ed96ac
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -13,8 +13,14 @@ if [ "$(uname)" != "Darwin" ]; then
|
||||
fi
|
||||
|
||||
if [ "$(uname)" = "Darwin" ]; then
|
||||
# macOS: single dylib variant (Metal or Accelerate)
|
||||
LIBRARY="$CURDIR/libgowhisper-fallback.dylib"
|
||||
# macOS: single fallback variant (Metal/Accelerate). The cmake build emits a
|
||||
# Mach-O named .so, but tolerate .dylib too — pick whichever exists so the Go
|
||||
# loader doesn't panic on a hardcoded name that isn't on disk.
|
||||
if [ -e "$CURDIR/libgowhisper-fallback.dylib" ]; then
|
||||
LIBRARY="$CURDIR/libgowhisper-fallback.dylib"
|
||||
else
|
||||
LIBRARY="$CURDIR/libgowhisper-fallback.so"
|
||||
fi
|
||||
export DYLD_LIBRARY_PATH="$CURDIR"/lib:$DYLD_LIBRARY_PATH
|
||||
else
|
||||
LIBRARY="$CURDIR/libgowhisper-fallback.so"
|
||||
|
||||
@@ -209,6 +209,78 @@
|
||||
nvidia-cuda-12: "cuda12-ced"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-ced"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-ced"
|
||||
- &voicedetect
|
||||
name: "voice-detect"
|
||||
alias: "voice-detect"
|
||||
license: mit
|
||||
icon: https://avatars.githubusercontent.com/u/95302084
|
||||
description: |
|
||||
voice-detect speaker recognition and voice analysis.
|
||||
voice-detect.cpp is a C++/ggml engine that produces L2-normalised
|
||||
speaker embeddings (ECAPA-TDNN, WeSpeaker ResNet34, 3D-Speaker
|
||||
ERes2Net, CAM++) for voice verification and 1:N identification, plus
|
||||
a wav2vec2 age / gender / emotion analysis head. It replaces the
|
||||
Python speaker-recognition backend and is exposed through the Voice*
|
||||
gRPC rpcs and the /v1/voice/* REST endpoints. It runs on CPU, NVIDIA
|
||||
CUDA, AMD ROCm/HIP, Intel SYCL, Vulkan and NVIDIA Jetson (L4T) targets.
|
||||
urls:
|
||||
- https://github.com/mudler/voice-detect.cpp
|
||||
tags:
|
||||
- voice-recognition
|
||||
- speaker-verification
|
||||
- speaker-embedding
|
||||
- CPU
|
||||
- GPU
|
||||
- CUDA
|
||||
- HIP
|
||||
capabilities:
|
||||
default: "cpu-voice-detect"
|
||||
nvidia: "cuda12-voice-detect"
|
||||
intel: "intel-sycl-f16-voice-detect"
|
||||
metal: "metal-voice-detect"
|
||||
amd: "rocm-voice-detect"
|
||||
vulkan: "vulkan-voice-detect"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-voice-detect"
|
||||
nvidia-cuda-13: "cuda13-voice-detect"
|
||||
nvidia-cuda-12: "cuda12-voice-detect"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-voice-detect"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-voice-detect"
|
||||
- &facedetect
|
||||
name: "face-detect"
|
||||
alias: "face-detect"
|
||||
license: mit
|
||||
icon: https://avatars.githubusercontent.com/u/95302084
|
||||
description: |
|
||||
face-detect face detection, embedding, verification and analysis.
|
||||
face-detect.cpp is a C++/ggml engine that runs SCRFD / YuNet face
|
||||
detection and ArcFace / SFace 512-d (or 128-d) L2-normalised face
|
||||
embeddings for verification and 1:N identification, plus a landmark /
|
||||
age / gender analysis head. It replaces the Python insightface backend
|
||||
and is exposed through the Embedding, Detect and Face* gRPC rpcs and
|
||||
the /v1/face/* REST endpoints. It runs on CPU, NVIDIA CUDA, AMD
|
||||
ROCm/HIP, Intel SYCL, Vulkan and NVIDIA Jetson (L4T) targets.
|
||||
urls:
|
||||
- https://github.com/mudler/face-detect.cpp
|
||||
tags:
|
||||
- face-recognition
|
||||
- face-verification
|
||||
- face-embedding
|
||||
- CPU
|
||||
- GPU
|
||||
- CUDA
|
||||
- HIP
|
||||
capabilities:
|
||||
default: "cpu-face-detect"
|
||||
nvidia: "cuda12-face-detect"
|
||||
intel: "intel-sycl-f16-face-detect"
|
||||
metal: "metal-face-detect"
|
||||
amd: "rocm-face-detect"
|
||||
vulkan: "vulkan-face-detect"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-face-detect"
|
||||
nvidia-cuda-13: "cuda13-face-detect"
|
||||
nvidia-cuda-12: "cuda12-face-detect"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-face-detect"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-face-detect"
|
||||
- &voxtral
|
||||
name: "voxtral"
|
||||
alias: "voxtral"
|
||||
@@ -1356,7 +1428,6 @@
|
||||
intel: "intel-fish-speech"
|
||||
amd: "rocm-fish-speech"
|
||||
nvidia-l4t: "nvidia-l4t-fish-speech"
|
||||
metal: "metal-fish-speech"
|
||||
default: "cpu-fish-speech"
|
||||
nvidia-cuda-13: "cuda13-fish-speech"
|
||||
nvidia-cuda-12: "cuda12-fish-speech"
|
||||
@@ -2828,6 +2899,236 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-ced"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-ced
|
||||
## voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "voice-detect-development"
|
||||
capabilities:
|
||||
default: "cpu-voice-detect-development"
|
||||
nvidia: "cuda12-voice-detect-development"
|
||||
intel: "intel-sycl-f16-voice-detect-development"
|
||||
metal: "metal-voice-detect-development"
|
||||
amd: "rocm-voice-detect-development"
|
||||
vulkan: "vulkan-voice-detect-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-voice-detect-development"
|
||||
nvidia-cuda-13: "cuda13-voice-detect-development"
|
||||
nvidia-cuda-12: "cuda12-voice-detect-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-voice-detect-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-voice-detect-development"
|
||||
- !!merge <<: *voicedetect
|
||||
name: "nvidia-l4t-arm64-voice-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "nvidia-l4t-arm64-voice-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "cuda13-nvidia-l4t-arm64-voice-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "cuda13-nvidia-l4t-arm64-voice-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "cpu-voice-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "cpu-voice-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "metal-voice-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "metal-voice-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "cuda12-voice-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "cuda12-voice-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "rocm-voice-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "rocm-voice-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "intel-sycl-f32-voice-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "intel-sycl-f32-voice-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "intel-sycl-f16-voice-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f16-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "intel-sycl-f16-voice-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f16-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "vulkan-voice-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "vulkan-voice-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "cuda13-voice-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-voice-detect
|
||||
- !!merge <<: *voicedetect
|
||||
name: "cuda13-voice-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-voice-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-voice-detect
|
||||
## face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "face-detect-development"
|
||||
capabilities:
|
||||
default: "cpu-face-detect-development"
|
||||
nvidia: "cuda12-face-detect-development"
|
||||
intel: "intel-sycl-f16-face-detect-development"
|
||||
metal: "metal-face-detect-development"
|
||||
amd: "rocm-face-detect-development"
|
||||
vulkan: "vulkan-face-detect-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-face-detect-development"
|
||||
nvidia-cuda-13: "cuda13-face-detect-development"
|
||||
nvidia-cuda-12: "cuda12-face-detect-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-face-detect-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-face-detect-development"
|
||||
- !!merge <<: *facedetect
|
||||
name: "nvidia-l4t-arm64-face-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "nvidia-l4t-arm64-face-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "cuda13-nvidia-l4t-arm64-face-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "cuda13-nvidia-l4t-arm64-face-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "cpu-face-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "cpu-face-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "metal-face-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "metal-face-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "cuda12-face-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "cuda12-face-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "rocm-face-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "rocm-face-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "intel-sycl-f32-face-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "intel-sycl-f32-face-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "intel-sycl-f16-face-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f16-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "intel-sycl-f16-face-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f16-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "vulkan-face-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "vulkan-face-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "cuda13-face-detect"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-face-detect
|
||||
- !!merge <<: *facedetect
|
||||
name: "cuda13-face-detect-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-face-detect"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-face-detect
|
||||
## stablediffusion-ggml
|
||||
- !!merge <<: *stablediffusionggml
|
||||
name: "cpu-stablediffusion-ggml"
|
||||
@@ -4870,7 +5171,6 @@
|
||||
intel: "intel-fish-speech-development"
|
||||
amd: "rocm-fish-speech-development"
|
||||
nvidia-l4t: "nvidia-l4t-fish-speech-development"
|
||||
metal: "metal-fish-speech-development"
|
||||
default: "cpu-fish-speech-development"
|
||||
nvidia-cuda-13: "cuda13-fish-speech-development"
|
||||
nvidia-cuda-12: "cuda12-fish-speech-development"
|
||||
@@ -4946,16 +5246,6 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-fish-speech"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-fish-speech
|
||||
- !!merge <<: *fish-speech
|
||||
name: "metal-fish-speech"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-fish-speech"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-fish-speech
|
||||
- !!merge <<: *fish-speech
|
||||
name: "metal-fish-speech-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-fish-speech"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-fish-speech
|
||||
## faster-qwen3-tts
|
||||
- !!merge <<: *faster-qwen3-tts
|
||||
name: "faster-qwen3-tts-development"
|
||||
|
||||
@@ -13,6 +13,17 @@ fi
|
||||
# fish-speech uses pyrootutils which requires a .project-root marker
|
||||
touch "${backend_dir}/.project-root"
|
||||
|
||||
# On darwin arm64 the transitive `tokenizers` dep compiles its Rust extension
|
||||
# from source (Linux uses prebuilt manylinux wheels, so it never compiles
|
||||
# there). The pinned tokenizers crate that fish-speech's stack resolves to
|
||||
# contains a `&T` -> `&mut T` cast that trips the now-deny-by-default
|
||||
# `invalid_reference_casting` lint in the macOS runner's newer Rust toolchain,
|
||||
# breaking the build (seen in the v4.5.5 release CI fish-speech darwin/metal
|
||||
# job). Allow that lint so the unchanged third-party crate compiles as before.
|
||||
# Append rather than clobber any pre-existing RUSTFLAGS; harmless on Linux
|
||||
# where no Rust compile happens.
|
||||
export RUSTFLAGS="${RUSTFLAGS:-} -A invalid_reference_casting"
|
||||
|
||||
installRequirements
|
||||
|
||||
# Clone fish-speech source (the pip package doesn't include inference modules)
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
torch
|
||||
torchaudio
|
||||
@@ -3,4 +3,5 @@ protobuf
|
||||
certifi
|
||||
packaging==24.1
|
||||
pip
|
||||
chardet
|
||||
chardet
|
||||
click
|
||||
|
||||
@@ -147,9 +147,25 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
d["reasoning_content"] = msg.reasoning_content
|
||||
if msg.tool_calls:
|
||||
try:
|
||||
d["tool_calls"] = json.loads(msg.tool_calls)
|
||||
tool_calls = json.loads(msg.tool_calls)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
else:
|
||||
# OpenAI wire format carries function.arguments as a
|
||||
# JSON-encoded string, but chat templates (e.g. Qwen3)
|
||||
# iterate over it as a mapping. The vllm backend
|
||||
# already parses arguments before applying the chat
|
||||
# template (PR #10256); mirror that here so the
|
||||
# sglang backend works with the same wire format.
|
||||
if isinstance(tool_calls, list):
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function") if isinstance(tc, dict) else None
|
||||
if isinstance(func, dict) and isinstance(func.get("arguments"), str):
|
||||
try:
|
||||
func["arguments"] = json.loads(func["arguments"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
d["tool_calls"] = tool_calls
|
||||
result.append(d)
|
||||
return result
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ if [ "$(uname -s)" = "Darwin" ]; then
|
||||
# can rewrite it. Darwin therefore follows vllm-metal and can lag the Linux
|
||||
# vllm pin (requirements-cublas13-after.txt, bumped independently against
|
||||
# vllm/vllm) until vllm-metal supports a newer vLLM.
|
||||
VLLM_METAL_VERSION="v0.3.0.dev20260622062346"
|
||||
VLLM_METAL_VERSION="v0.3.0.dev20260628073537"
|
||||
|
||||
# The coupled vLLM source version is whatever this vllm-metal release builds
|
||||
# against -- it declares it in its own installer as `vllm_v=`. Derive it from
|
||||
|
||||
@@ -429,7 +429,7 @@ func (l *Launcher) CheckForUpdates() (bool, string, error) {
|
||||
}
|
||||
|
||||
// DownloadUpdate downloads the latest version
|
||||
func (l *Launcher) DownloadUpdate(version string, progressCallback func(float64)) error {
|
||||
func (l *Launcher) DownloadUpdate(version string, progressCallback func(downloaded, total int64)) error {
|
||||
return l.releaseManager.DownloadRelease(version, progressCallback)
|
||||
}
|
||||
|
||||
@@ -486,7 +486,6 @@ func (l *Launcher) showDownloadLocalAIDialog() {
|
||||
fyne.DoAndWait(func() {
|
||||
// Create a standalone window for the download dialog
|
||||
dialogWindow := l.app.NewWindow("LocalAI Installation Required")
|
||||
dialogWindow.Resize(fyne.NewSize(500, 350))
|
||||
dialogWindow.CenterOnScreen()
|
||||
dialogWindow.SetCloseIntercept(func() {
|
||||
dialogWindow.Close()
|
||||
@@ -548,6 +547,7 @@ func (l *Launcher) showDownloadLocalAIDialog() {
|
||||
)
|
||||
|
||||
dialogWindow.SetContent(content)
|
||||
resizeToContent(dialogWindow, content)
|
||||
dialogWindow.Show()
|
||||
})
|
||||
}
|
||||
@@ -621,88 +621,134 @@ func (l *Launcher) showDownloadError(title, message string) {
|
||||
}
|
||||
|
||||
// showDownloadProgress shows a standalone progress window for downloading LocalAI
|
||||
// after a fresh install (no LocalAI binary present yet).
|
||||
func (l *Launcher) showDownloadProgress(version, title string) {
|
||||
l.showDownloadProgressWindow(version, title, func(win fyne.Window) {
|
||||
dialog.ShowConfirm("Installation Complete",
|
||||
"LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
|
||||
func(bool) {
|
||||
win.Close()
|
||||
l.updateStatus("LocalAI installed successfully")
|
||||
if l.systray != nil {
|
||||
l.systray.recreateMenu()
|
||||
}
|
||||
}, win)
|
||||
})
|
||||
}
|
||||
|
||||
// showDownloadProgressWindow renders the download progress popup shared by every
|
||||
// "download/upgrade LocalAI" entry point. It owns the progress bar, the
|
||||
// human-readable byte readout, resume-aware retry, and content-fit window
|
||||
// sizing so the behaviour stays identical everywhere. onSuccess runs (on the UI
|
||||
// goroutine) once the download verifies, and is responsible for the success
|
||||
// dialog and any follow-up; the window is passed in so it can be parented/closed.
|
||||
func (l *Launcher) showDownloadProgressWindow(version, title string, onSuccess func(win fyne.Window)) {
|
||||
fyne.DoAndWait(func() {
|
||||
// Create progress window
|
||||
progressWindow := l.app.NewWindow("Downloading LocalAI")
|
||||
progressWindow.Resize(fyne.NewSize(400, 250))
|
||||
progressWindow.CenterOnScreen()
|
||||
progressWindow.SetCloseIntercept(func() {
|
||||
progressWindow.Close()
|
||||
})
|
||||
|
||||
// Progress bar
|
||||
progressBar := widget.NewProgressBar()
|
||||
progressBar.SetValue(0)
|
||||
|
||||
// Status label. Truncate with an ellipsis so a long "Download failed:
|
||||
// <url>" message can't stretch the window (and progress bar) to fit the
|
||||
// whole error on one line; the full error is shown in the dialog below.
|
||||
// whole error on one line.
|
||||
statusLabel := widget.NewLabel("Preparing download...")
|
||||
statusLabel.Truncation = fyne.TextTruncateEllipsis
|
||||
|
||||
// Release notes button
|
||||
releaseNotesButton := widget.NewButton("View Release Notes", func() {
|
||||
releaseNotesURL, err := l.githubReleaseNotesURL(version)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
l.app.OpenURL(releaseNotesURL)
|
||||
})
|
||||
|
||||
// Progress container
|
||||
progressContainer := container.NewVBox(
|
||||
// Retry button: hidden until a download fails. GitHub downloads are
|
||||
// flaky, and the underlying download resumes from the partial file, so
|
||||
// a retry continues where it left off rather than starting over.
|
||||
retryButton := widget.NewButton("Retry", nil)
|
||||
retryButton.Importance = widget.HighImportance
|
||||
retryButton.Hide()
|
||||
|
||||
buttonRow := container.NewHBox(releaseNotesButton, retryButton)
|
||||
content := container.NewVBox(
|
||||
widget.NewLabel(title),
|
||||
progressBar,
|
||||
statusLabel,
|
||||
widget.NewSeparator(),
|
||||
releaseNotesButton,
|
||||
buttonRow,
|
||||
)
|
||||
progressWindow.SetContent(content)
|
||||
resizeToContent(progressWindow, content)
|
||||
|
||||
progressWindow.SetContent(progressContainer)
|
||||
progressWindow.Show()
|
||||
var startDownload func()
|
||||
startDownload = func() {
|
||||
retryButton.Hide()
|
||||
progressBar.SetValue(0)
|
||||
statusLabel.SetText("Preparing download...")
|
||||
resizeToContent(progressWindow, content)
|
||||
|
||||
// Start download in background
|
||||
go func() {
|
||||
err := l.DownloadUpdate(version, func(progress float64) {
|
||||
// Update progress bar
|
||||
fyne.Do(func() {
|
||||
progressBar.SetValue(progress)
|
||||
percentage := int(progress * 100)
|
||||
statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
|
||||
go func() {
|
||||
err := l.DownloadUpdate(version, func(downloaded, total int64) {
|
||||
fyne.Do(func() {
|
||||
if total > 0 {
|
||||
progressBar.SetValue(float64(downloaded) / float64(total))
|
||||
statusLabel.SetText(fmt.Sprintf("Downloading… %s / %s", formatBytes(downloaded), formatBytes(total)))
|
||||
} else {
|
||||
statusLabel.SetText(fmt.Sprintf("Downloading… %s", formatBytes(downloaded)))
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Handle completion
|
||||
fyne.Do(func() {
|
||||
if err != nil {
|
||||
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
|
||||
// Show error dialog
|
||||
dialog.ShowError(err, progressWindow)
|
||||
} else {
|
||||
statusLabel.SetText("Download completed successfully!")
|
||||
fyne.Do(func() {
|
||||
if err != nil {
|
||||
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
|
||||
retryButton.Show()
|
||||
resizeToContent(progressWindow, content)
|
||||
return
|
||||
}
|
||||
progressBar.SetValue(1.0)
|
||||
statusLabel.SetText("Download complete")
|
||||
onSuccess(progressWindow)
|
||||
})
|
||||
}()
|
||||
}
|
||||
retryButton.OnTapped = startDownload
|
||||
|
||||
// Show success dialog
|
||||
dialog.ShowConfirm("Installation Complete",
|
||||
"LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
|
||||
func(close bool) {
|
||||
progressWindow.Close()
|
||||
// Update status and refresh systray menu
|
||||
l.updateStatus("LocalAI installed successfully")
|
||||
|
||||
if l.systray != nil {
|
||||
l.systray.recreateMenu()
|
||||
}
|
||||
}, progressWindow)
|
||||
}
|
||||
})
|
||||
}()
|
||||
progressWindow.Show()
|
||||
startDownload()
|
||||
})
|
||||
}
|
||||
|
||||
// resizeToContent sizes a window to fit its content (with a sane minimum width)
|
||||
// so the dialog doesn't show a large blank gap below the last widget.
|
||||
func resizeToContent(w fyne.Window, content fyne.CanvasObject) {
|
||||
size := content.MinSize()
|
||||
if size.Width < 400 {
|
||||
size.Width = 400
|
||||
}
|
||||
w.Resize(size)
|
||||
}
|
||||
|
||||
// formatBytes renders a byte count as a human-readable size (e.g. "12.3 MB").
|
||||
func formatBytes(b int64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
// monitorLogs monitors the output of LocalAI and adds it to the log buffer
|
||||
func (l *Launcher) monitorLogs(reader io.Reader, prefix string) {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -50,6 +51,12 @@ type ReleaseManager struct {
|
||||
ChecksumsPath string
|
||||
// MetadataPath is where version metadata is stored
|
||||
MetadataPath string
|
||||
// BaseDownloadURL is the base URL release assets are downloaded from
|
||||
// (defaults to https://github.com; overridable for testing)
|
||||
BaseDownloadURL string
|
||||
// RetryBackoff is the base wait between download attempts; the Nth retry
|
||||
// waits N*RetryBackoff (defaults to 1s; lowered in tests)
|
||||
RetryBackoff time.Duration
|
||||
// HTTPClient is the HTTP client used for downloads
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
@@ -62,28 +69,94 @@ func NewReleaseManager() *ReleaseManager {
|
||||
metadataPath := filepath.Join(homeDir, ".localai", "metadata")
|
||||
|
||||
return &ReleaseManager{
|
||||
GitHubOwner: "mudler",
|
||||
GitHubRepo: "LocalAI",
|
||||
BinaryPath: binaryPath,
|
||||
CurrentVersion: internal.PrintableVersion(),
|
||||
ChecksumsPath: checksumsPath,
|
||||
MetadataPath: metadataPath,
|
||||
HTTPClient: httpclient.NewWithTimeout(30*time.Second, httpclient.WithFollowRedirects()),
|
||||
GitHubOwner: "mudler",
|
||||
GitHubRepo: "LocalAI",
|
||||
BinaryPath: binaryPath,
|
||||
CurrentVersion: internal.PrintableVersion(),
|
||||
ChecksumsPath: checksumsPath,
|
||||
MetadataPath: metadataPath,
|
||||
BaseDownloadURL: "https://github.com",
|
||||
RetryBackoff: 1 * time.Second,
|
||||
HTTPClient: httpclient.NewWithTimeout(30*time.Second, httpclient.WithFollowRedirects()),
|
||||
}
|
||||
}
|
||||
|
||||
// GetLatestRelease fetches the latest release information from GitHub
|
||||
// GetLatestRelease resolves the latest LocalAI release.
|
||||
//
|
||||
// It first follows the github.com "releases/latest" redirect, which reveals the
|
||||
// latest tag in the final URL and—crucially—is NOT subject to the
|
||||
// 60-requests/hour unauthenticated rate limit of api.github.com. That limit is
|
||||
// per-IP, so on shared/NAT/CGNAT/cloud addresses the API returns 403 almost
|
||||
// immediately (e.g. on a fresh install with no LocalAI present yet). The
|
||||
// redirect avoids that entirely. The richer JSON API is kept only as a fallback.
|
||||
//
|
||||
// Only the version is consumed by callers, so the redirect's tag is sufficient.
|
||||
func (rm *ReleaseManager) GetLatestRelease() (*Release, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", rm.GitHubOwner, rm.GitHubRepo)
|
||||
version, redirectErr := rm.latestVersionFromRedirect()
|
||||
if redirectErr == nil {
|
||||
return &Release{Version: version}, nil
|
||||
}
|
||||
log.Printf("Could not resolve latest version via release redirect (%v); falling back to GitHub API", redirectErr)
|
||||
|
||||
release, apiErr := rm.latestReleaseFromAPI()
|
||||
if apiErr != nil {
|
||||
// Surface both failures so a rate-limited API doesn't mask the (usually
|
||||
// more relevant) redirect error.
|
||||
return nil, fmt.Errorf("failed to fetch latest release: %v (redirect: %v)", apiErr, redirectErr)
|
||||
}
|
||||
return release, nil
|
||||
}
|
||||
|
||||
// latestVersionFromRedirect returns the latest tag by following the github.com
|
||||
// "releases/latest" redirect to ".../releases/tag/<tag>".
|
||||
func (rm *ReleaseManager) latestVersionFromRedirect() (string, error) {
|
||||
url := fmt.Sprintf("%s/%s/%s/releases/latest", rm.BaseDownloadURL, rm.GitHubOwner, rm.GitHubRepo)
|
||||
|
||||
resp, err := rm.HTTPClient.Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("unexpected status %s", resp.Status)
|
||||
}
|
||||
|
||||
// After the redirect is followed, the final request URL is the tag page.
|
||||
version := path.Base(resp.Request.URL.Path)
|
||||
if version == "" || version == "." || version == "latest" {
|
||||
return "", fmt.Errorf("could not determine version from %s", resp.Request.URL.String())
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// latestReleaseFromAPI fetches the latest release JSON from api.github.com. This
|
||||
// is the fallback path; it is rate-limited unless GITHUB_TOKEN is set.
|
||||
func (rm *ReleaseManager) latestReleaseFromAPI() (*Release, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", rm.GitHubOwner, rm.GitHubRepo)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
// An optional token lifts the unauthenticated 60/hour limit to 5000/hour.
|
||||
if token := os.Getenv("GITHUB_TOKEN"); token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
resp, err := rm.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch latest release: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to fetch latest release: status %d", resp.StatusCode)
|
||||
if (resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusTooManyRequests) &&
|
||||
resp.Header.Get("X-RateLimit-Remaining") == "0" {
|
||||
return nil, fmt.Errorf("GitHub API rate limit exceeded (status %d); retry later or set GITHUB_TOKEN to raise the limit", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Parse the JSON response properly
|
||||
@@ -106,7 +179,7 @@ func (rm *ReleaseManager) GetLatestRelease() (*Release, error) {
|
||||
}
|
||||
|
||||
// DownloadRelease downloads a specific version of LocalAI
|
||||
func (rm *ReleaseManager) DownloadRelease(version string, progressCallback func(float64)) error {
|
||||
func (rm *ReleaseManager) DownloadRelease(version string, progressCallback func(downloaded, total int64)) error {
|
||||
// Ensure the binary directory exists
|
||||
if err := os.MkdirAll(rm.BinaryPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create binary directory: %w", err)
|
||||
@@ -117,16 +190,16 @@ func (rm *ReleaseManager) DownloadRelease(version string, progressCallback func(
|
||||
localPath := filepath.Join(rm.BinaryPath, "local-ai")
|
||||
|
||||
// Download the binary
|
||||
downloadURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s",
|
||||
rm.GitHubOwner, rm.GitHubRepo, version, binaryName)
|
||||
downloadURL := fmt.Sprintf("%s/%s/%s/releases/download/%s/%s",
|
||||
rm.BaseDownloadURL, rm.GitHubOwner, rm.GitHubRepo, version, binaryName)
|
||||
|
||||
if err := rm.downloadFile(downloadURL, localPath, progressCallback); err != nil {
|
||||
return fmt.Errorf("failed to download binary: %w", err)
|
||||
}
|
||||
|
||||
// Download and verify checksums
|
||||
checksumURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/LocalAI-%s-checksums.txt",
|
||||
rm.GitHubOwner, rm.GitHubRepo, version, version)
|
||||
checksumURL := fmt.Sprintf("%s/%s/%s/releases/download/%s/LocalAI-%s-checksums.txt",
|
||||
rm.BaseDownloadURL, rm.GitHubOwner, rm.GitHubRepo, version, version)
|
||||
|
||||
checksumPath := filepath.Join(rm.BinaryPath, "checksums.txt")
|
||||
manualChecksumPath := filepath.Join(rm.ChecksumsPath, fmt.Sprintf("checksums-%s.txt", version))
|
||||
@@ -154,6 +227,10 @@ func (rm *ReleaseManager) DownloadRelease(version string, progressCallback func(
|
||||
// Verify the checksum if we have a checksum file
|
||||
if _, err := os.Stat(checksumPath); err == nil {
|
||||
if err := rm.VerifyChecksum(localPath, checksumPath, binaryName); err != nil {
|
||||
// Discard the corrupt binary (and any leftover partial) so the next
|
||||
// retry starts from a clean slate rather than resuming corruption.
|
||||
os.Remove(localPath)
|
||||
os.Remove(localPath + ".part")
|
||||
return fmt.Errorf("checksum verification failed: %w", err)
|
||||
}
|
||||
log.Printf("Checksum verification successful")
|
||||
@@ -196,44 +273,88 @@ func (rm *ReleaseManager) GetBinaryName(version string) string {
|
||||
}
|
||||
|
||||
// downloadFile downloads a file from a URL to a local path with optional progress callback
|
||||
func (rm *ReleaseManager) downloadFile(url, filepath string, progressCallback func(float64)) error {
|
||||
func (rm *ReleaseManager) downloadFile(url, filepath string, progressCallback func(downloaded, total int64)) error {
|
||||
return rm.downloadFileWithRetry(url, filepath, progressCallback, 3)
|
||||
}
|
||||
|
||||
// downloadFileWithRetry downloads a file from a URL with retry logic
|
||||
func (rm *ReleaseManager) downloadFileWithRetry(url, filepath string, progressCallback func(float64), maxRetries int) error {
|
||||
// downloadFileWithRetry downloads a file with retry and HTTP Range resume.
|
||||
//
|
||||
// The body is streamed to "<dest>.part" and only renamed to dest on success, so
|
||||
// a dropped connection leaves a partial file that the next attempt continues via
|
||||
// a "Range: bytes=N-" request instead of restarting from zero. This matters for
|
||||
// GitHub release downloads, which are large and flaky.
|
||||
func (rm *ReleaseManager) downloadFileWithRetry(url, dest string, progressCallback func(downloaded, total int64), maxRetries int) error {
|
||||
partPath := dest + ".part"
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
if attempt > 1 {
|
||||
log.Printf("Retrying download (attempt %d/%d): %s", attempt, maxRetries, url)
|
||||
time.Sleep(time.Duration(attempt) * time.Second)
|
||||
time.Sleep(time.Duration(attempt) * rm.RetryBackoff)
|
||||
}
|
||||
|
||||
resp, err := rm.HTTPClient.Get(url)
|
||||
// Resume from however much we already have on disk.
|
||||
var offset int64
|
||||
if fi, err := os.Stat(partPath); err == nil {
|
||||
offset = fi.Size()
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if offset > 0 {
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", offset))
|
||||
}
|
||||
|
||||
resp, err := rm.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
// Server ignored the Range (or we had nothing): start fresh.
|
||||
offset = 0
|
||||
case http.StatusPartialContent:
|
||||
// Resume: append to the existing partial file.
|
||||
case http.StatusRequestedRangeNotSatisfiable:
|
||||
// Stale or already-complete partial: discard and restart fresh.
|
||||
resp.Body.Close()
|
||||
os.Remove(partPath)
|
||||
lastErr = fmt.Errorf("partial download no longer valid (status %s), restarting", resp.Status)
|
||||
continue
|
||||
default:
|
||||
resp.Body.Close()
|
||||
lastErr = fmt.Errorf("bad status: %s", resp.Status)
|
||||
continue
|
||||
}
|
||||
|
||||
out, err := os.Create(filepath)
|
||||
var out *os.File
|
||||
if offset > 0 {
|
||||
out, err = os.OpenFile(partPath, os.O_WRONLY|os.O_APPEND, 0644)
|
||||
} else {
|
||||
out, err = os.Create(partPath)
|
||||
}
|
||||
if err != nil {
|
||||
resp.Body.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Create a progress reader if callback is provided
|
||||
// On a 206 the Content-Length is the remaining bytes, so the full size
|
||||
// is what we already have plus what's still to come.
|
||||
total := resp.ContentLength
|
||||
if offset > 0 && total > 0 {
|
||||
total += offset
|
||||
}
|
||||
|
||||
var reader io.Reader = resp.Body
|
||||
if progressCallback != nil && resp.ContentLength > 0 {
|
||||
if progressCallback != nil && total > 0 {
|
||||
reader = &progressReader{
|
||||
Reader: resp.Body,
|
||||
Total: resp.ContentLength,
|
||||
Total: total,
|
||||
Current: offset,
|
||||
Callback: progressCallback,
|
||||
}
|
||||
}
|
||||
@@ -243,11 +364,14 @@ func (rm *ReleaseManager) downloadFileWithRetry(url, filepath string, progressCa
|
||||
out.Close()
|
||||
|
||||
if err != nil {
|
||||
// Keep the partial file so the next attempt can resume from it.
|
||||
lastErr = err
|
||||
os.Remove(filepath)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.Rename(partPath, dest); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -322,20 +446,21 @@ func (rm *ReleaseManager) saveVersionMetadata(version string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// progressReader wraps an io.Reader to provide download progress
|
||||
// progressReader wraps an io.Reader to provide download progress as a
|
||||
// (downloaded, total) byte count so callers can render both a progress bar and
|
||||
// a human-readable size.
|
||||
type progressReader struct {
|
||||
io.Reader
|
||||
Total int64
|
||||
Current int64
|
||||
Callback func(float64)
|
||||
Callback func(downloaded, total int64)
|
||||
}
|
||||
|
||||
func (pr *progressReader) Read(p []byte) (int, error) {
|
||||
n, err := pr.Reader.Read(p)
|
||||
pr.Current += int64(n)
|
||||
if pr.Callback != nil {
|
||||
progress := float64(pr.Current) / float64(pr.Total)
|
||||
pr.Callback(progress)
|
||||
pr.Callback(pr.Current, pr.Total)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
package launcher_test
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
@@ -178,4 +186,221 @@ var _ = Describe("ReleaseManager", func() {
|
||||
Expect(err.Error()).To(ContainSubstring("checksum not found"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("DownloadRelease resume and retry", func() {
|
||||
var (
|
||||
version string
|
||||
binaryName string
|
||||
content []byte
|
||||
checksums string
|
||||
finalPath string
|
||||
partPath string
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
version = "v9.9.9"
|
||||
binaryName = rm.GetBinaryName(version)
|
||||
|
||||
// Deterministic, non-trivial content so resume/append bugs surface.
|
||||
content = make([]byte, 4096)
|
||||
for i := range content {
|
||||
content[i] = byte(i % 251)
|
||||
}
|
||||
sum := sha256.Sum256(content)
|
||||
checksums = fmt.Sprintf("%s %s\n", hex.EncodeToString(sum[:]), binaryName)
|
||||
|
||||
finalPath = filepath.Join(tempDir, "local-ai")
|
||||
partPath = finalPath + ".part"
|
||||
|
||||
// Isolate the persistent checksum/metadata dirs to the temp dir so
|
||||
// the test never touches the real ~/.localai and existing checksum
|
||||
// files don't short-circuit the download.
|
||||
rm.ChecksumsPath = filepath.Join(tempDir, "checksums")
|
||||
rm.MetadataPath = filepath.Join(tempDir, "metadata")
|
||||
rm.GitHubOwner = "owner"
|
||||
rm.GitHubRepo = "repo"
|
||||
rm.RetryBackoff = time.Millisecond
|
||||
|
||||
Expect(os.MkdirAll(tempDir, 0755)).To(Succeed())
|
||||
})
|
||||
|
||||
It("resumes from a partial .part file using a Range request", func() {
|
||||
Expect(os.WriteFile(partPath, content[:1024], 0644)).To(Succeed())
|
||||
|
||||
var mu sync.Mutex
|
||||
sawRange := false
|
||||
binBytesServed := 0
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "checksums.txt") {
|
||||
_, _ = w.Write([]byte(checksums))
|
||||
return
|
||||
}
|
||||
if rangeHdr := r.Header.Get("Range"); rangeHdr != "" {
|
||||
var start int
|
||||
_, _ = fmt.Sscanf(rangeHdr, "bytes=%d-", &start)
|
||||
mu.Lock()
|
||||
sawRange = true
|
||||
mu.Unlock()
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, len(content)-1, len(content)))
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
n, _ := w.Write(content[start:])
|
||||
mu.Lock()
|
||||
binBytesServed += n
|
||||
mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
n, _ := w.Write(content)
|
||||
mu.Lock()
|
||||
binBytesServed += n
|
||||
mu.Unlock()
|
||||
}))
|
||||
defer srv.Close()
|
||||
rm.BaseDownloadURL = srv.URL
|
||||
|
||||
err := rm.DownloadRelease(version, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
got, err := os.ReadFile(finalPath)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(got).To(Equal(content))
|
||||
Expect(sawRange).To(BeTrue(), "expected the download to resume with a Range request")
|
||||
Expect(binBytesServed).To(Equal(len(content)-1024), "expected only the remaining bytes to be served")
|
||||
Expect(partPath).ToNot(BeAnExistingFile())
|
||||
})
|
||||
|
||||
It("starts fresh when the server ignores the Range header (200)", func() {
|
||||
// A stale/garbage partial that must NOT be appended to.
|
||||
Expect(os.WriteFile(partPath, []byte("garbage-garbage-garbage"), 0644)).To(Succeed())
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "checksums.txt") {
|
||||
_, _ = w.Write([]byte(checksums))
|
||||
return
|
||||
}
|
||||
// Ignore any Range and always serve the full body.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(content)
|
||||
}))
|
||||
defer srv.Close()
|
||||
rm.BaseDownloadURL = srv.URL
|
||||
|
||||
err := rm.DownloadRelease(version, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
got, err := os.ReadFile(finalPath)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(got).To(Equal(content))
|
||||
})
|
||||
|
||||
It("restarts the download when the partial is stale (416)", func() {
|
||||
// Oversized partial -> requested Range start is beyond the content.
|
||||
Expect(os.WriteFile(partPath, make([]byte, len(content)+10), 0644)).To(Succeed())
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "checksums.txt") {
|
||||
_, _ = w.Write([]byte(checksums))
|
||||
return
|
||||
}
|
||||
if rangeHdr := r.Header.Get("Range"); rangeHdr != "" {
|
||||
var start int
|
||||
_, _ = fmt.Sscanf(rangeHdr, "bytes=%d-", &start)
|
||||
if start >= len(content) {
|
||||
w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, len(content)-1, len(content)))
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
_, _ = w.Write(content[start:])
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(content)
|
||||
}))
|
||||
defer srv.Close()
|
||||
rm.BaseDownloadURL = srv.URL
|
||||
|
||||
err := rm.DownloadRelease(version, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
got, err := os.ReadFile(finalPath)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(got).To(Equal(content))
|
||||
})
|
||||
|
||||
It("removes the downloaded file when checksum verification fails", func() {
|
||||
bad := []byte("this is definitely not the expected binary content")
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "checksums.txt") {
|
||||
// Checksums are for `content`, but we serve `bad`.
|
||||
_, _ = w.Write([]byte(checksums))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(bad)
|
||||
}))
|
||||
defer srv.Close()
|
||||
rm.BaseDownloadURL = srv.URL
|
||||
|
||||
err := rm.DownloadRelease(version, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("checksum"))
|
||||
Expect(finalPath).ToNot(BeAnExistingFile())
|
||||
Expect(partPath).ToNot(BeAnExistingFile())
|
||||
})
|
||||
|
||||
It("reports progress as downloaded and total byte counts", func() {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "checksums.txt") {
|
||||
_, _ = w.Write([]byte(checksums))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(content)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(content)
|
||||
}))
|
||||
defer srv.Close()
|
||||
rm.BaseDownloadURL = srv.URL
|
||||
|
||||
var mu sync.Mutex
|
||||
var lastDownloaded, lastTotal int64
|
||||
err := rm.DownloadRelease(version, func(downloaded, total int64) {
|
||||
mu.Lock()
|
||||
lastDownloaded = downloaded
|
||||
lastTotal = total
|
||||
mu.Unlock()
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(lastTotal).To(Equal(int64(len(content))))
|
||||
Expect(lastDownloaded).To(Equal(int64(len(content))))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GetLatestRelease", func() {
|
||||
It("resolves the latest version from the releases/latest redirect", func() {
|
||||
// The github.com redirect path must be preferred over the
|
||||
// rate-limited api.github.com, so a working redirect yields the tag
|
||||
// without ever needing the API.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.HasSuffix(r.URL.Path, "/releases/latest"):
|
||||
http.Redirect(w, r, "/owner/repo/releases/tag/v9.9.9", http.StatusFound)
|
||||
case strings.HasSuffix(r.URL.Path, "/releases/tag/v9.9.9"):
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
rm.BaseDownloadURL = srv.URL
|
||||
rm.GitHubOwner = "owner"
|
||||
rm.GitHubRepo = "repo"
|
||||
|
||||
release, err := rm.GetLatestRelease()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(release.Version).To(Equal("v9.9.9"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -443,84 +443,23 @@ func (sm *SystrayManager) showStartupErrorDialog(err error) {
|
||||
})
|
||||
}
|
||||
|
||||
// showDownloadProgress shows a progress window for downloading updates
|
||||
// showDownloadProgress shows a progress window for downloading updates. The
|
||||
// progress UI (byte readout, resume-aware retry, sizing) is shared with the
|
||||
// other download entry points via the launcher; only the post-success behaviour
|
||||
// (restart prompt + systray refresh) is specific to the update flow.
|
||||
func (sm *SystrayManager) showDownloadProgress(version string) {
|
||||
// Create a new window for download progress
|
||||
progressWindow := sm.app.NewWindow("Downloading LocalAI Update")
|
||||
progressWindow.Resize(fyne.NewSize(400, 250))
|
||||
progressWindow.CenterOnScreen()
|
||||
sm.launcher.showDownloadProgressWindow(version, fmt.Sprintf("Downloading LocalAI version %s", version), func(win fyne.Window) {
|
||||
dialog.ShowConfirm("Update Downloaded",
|
||||
"LocalAI has been updated successfully. Please restart the launcher to use the new version.",
|
||||
func(restart bool) {
|
||||
if restart {
|
||||
sm.app.Quit()
|
||||
}
|
||||
win.Close()
|
||||
}, win)
|
||||
|
||||
// Progress bar
|
||||
progressBar := widget.NewProgressBar()
|
||||
progressBar.SetValue(0)
|
||||
|
||||
// Status label. Truncate with an ellipsis so a long "Download failed:
|
||||
// <url>" message can't stretch the window (and progress bar) to fit the
|
||||
// whole error on one line; the full error is shown in the dialog below.
|
||||
statusLabel := widget.NewLabel("Preparing download...")
|
||||
statusLabel.Truncation = fyne.TextTruncateEllipsis
|
||||
|
||||
// Release notes button
|
||||
releaseNotesButton := widget.NewButton("View Release Notes", func() {
|
||||
releaseNotesURL, err := sm.launcher.githubReleaseNotesURL(version)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
sm.app.OpenURL(releaseNotesURL)
|
||||
sm.hasUpdateAvailable = false
|
||||
sm.latestVersion = ""
|
||||
sm.recreateMenu()
|
||||
})
|
||||
|
||||
// Progress container
|
||||
progressContainer := container.NewVBox(
|
||||
widget.NewLabel(fmt.Sprintf("Downloading LocalAI version %s", version)),
|
||||
progressBar,
|
||||
statusLabel,
|
||||
widget.NewSeparator(),
|
||||
releaseNotesButton,
|
||||
)
|
||||
|
||||
progressWindow.SetContent(progressContainer)
|
||||
progressWindow.Show()
|
||||
|
||||
// Start download in background
|
||||
go func() {
|
||||
err := sm.launcher.DownloadUpdate(version, func(progress float64) {
|
||||
// Update progress bar
|
||||
fyne.Do(func() {
|
||||
progressBar.SetValue(progress)
|
||||
percentage := int(progress * 100)
|
||||
statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
|
||||
})
|
||||
})
|
||||
|
||||
// Handle completion
|
||||
fyne.Do(func() {
|
||||
if err != nil {
|
||||
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
|
||||
// Show error dialog
|
||||
dialog.ShowError(err, progressWindow)
|
||||
} else {
|
||||
statusLabel.SetText("Download completed successfully!")
|
||||
progressBar.SetValue(1.0)
|
||||
|
||||
// Show restart dialog
|
||||
dialog.ShowConfirm("Update Downloaded",
|
||||
"LocalAI has been updated successfully. Please restart the launcher to use the new version.",
|
||||
func(restart bool) {
|
||||
if restart {
|
||||
sm.app.Quit()
|
||||
}
|
||||
progressWindow.Close()
|
||||
}, progressWindow)
|
||||
}
|
||||
})
|
||||
|
||||
// Update systray menu
|
||||
if err == nil {
|
||||
sm.hasUpdateAvailable = false
|
||||
sm.latestVersion = ""
|
||||
sm.recreateMenu()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -490,14 +490,19 @@ func (ui *LauncherUI) downloadUpdate() {
|
||||
ui.UpdateStatus("Downloading update " + version + "...")
|
||||
|
||||
go func() {
|
||||
err := ui.launcher.DownloadUpdate(version, func(progress float64) {
|
||||
// Update progress bar
|
||||
err := ui.launcher.DownloadUpdate(version, func(downloaded, total int64) {
|
||||
fyne.Do(func() {
|
||||
ui.progressBar.SetValue(progress)
|
||||
if total > 0 {
|
||||
ui.progressBar.SetValue(float64(downloaded) / float64(total))
|
||||
}
|
||||
})
|
||||
// Update status with percentage
|
||||
percentage := int(progress * 100)
|
||||
ui.UpdateStatus(fmt.Sprintf("Downloading update %s... %d%%", version, percentage))
|
||||
// The progress bar already shows the percentage, so report the
|
||||
// human-readable size here instead of repeating the percent.
|
||||
if total > 0 {
|
||||
ui.UpdateStatus(fmt.Sprintf("Downloading update %s… %s / %s", version, formatBytes(downloaded), formatBytes(total)))
|
||||
} else {
|
||||
ui.UpdateStatus(fmt.Sprintf("Downloading update %s… %s", version, formatBytes(downloaded)))
|
||||
}
|
||||
})
|
||||
|
||||
fyne.Do(func() {
|
||||
@@ -598,82 +603,6 @@ func (ui *LauncherUI) LoadConfiguration() {
|
||||
log.Printf("UI LoadConfiguration: configuration loaded successfully")
|
||||
}
|
||||
|
||||
// showDownloadProgress shows a progress window for downloading LocalAI
|
||||
func (ui *LauncherUI) showDownloadProgress(version, title string) {
|
||||
fyne.DoAndWait(func() {
|
||||
// Create progress window using the launcher's app
|
||||
progressWindow := ui.launcher.app.NewWindow("Downloading LocalAI")
|
||||
progressWindow.Resize(fyne.NewSize(400, 250))
|
||||
progressWindow.CenterOnScreen()
|
||||
|
||||
// Progress bar
|
||||
progressBar := widget.NewProgressBar()
|
||||
progressBar.SetValue(0)
|
||||
|
||||
// Status label. Truncate with an ellipsis so a long "Download failed:
|
||||
// <url>" message can't stretch the window (and progress bar) to fit the
|
||||
// whole error on one line; the full error is shown in the dialog below.
|
||||
statusLabel := widget.NewLabel("Preparing download...")
|
||||
statusLabel.Truncation = fyne.TextTruncateEllipsis
|
||||
|
||||
// Release notes button
|
||||
releaseNotesButton := widget.NewButton("View Release Notes", func() {
|
||||
releaseNotesURL, err := ui.launcher.githubReleaseNotesURL(version)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ui.launcher.app.OpenURL(releaseNotesURL)
|
||||
})
|
||||
|
||||
// Progress container
|
||||
progressContainer := container.NewVBox(
|
||||
widget.NewLabel(title),
|
||||
progressBar,
|
||||
statusLabel,
|
||||
widget.NewSeparator(),
|
||||
releaseNotesButton,
|
||||
)
|
||||
|
||||
progressWindow.SetContent(progressContainer)
|
||||
progressWindow.Show()
|
||||
|
||||
// Start download in background
|
||||
go func() {
|
||||
err := ui.launcher.DownloadUpdate(version, func(progress float64) {
|
||||
// Update progress bar
|
||||
fyne.Do(func() {
|
||||
progressBar.SetValue(progress)
|
||||
percentage := int(progress * 100)
|
||||
statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
|
||||
})
|
||||
})
|
||||
|
||||
// Handle completion
|
||||
fyne.Do(func() {
|
||||
if err != nil {
|
||||
statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
|
||||
// Show error dialog
|
||||
dialog.ShowError(err, progressWindow)
|
||||
} else {
|
||||
statusLabel.SetText("Download completed successfully!")
|
||||
progressBar.SetValue(1.0)
|
||||
|
||||
// Show success dialog
|
||||
dialog.ShowConfirm("Installation Complete",
|
||||
"LocalAI has been downloaded and installed successfully. You can now start LocalAI from the launcher.",
|
||||
func(close bool) {
|
||||
progressWindow.Close()
|
||||
// Update status
|
||||
ui.UpdateStatus("LocalAI installed successfully")
|
||||
}, progressWindow)
|
||||
}
|
||||
})
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRunningState updates UI based on LocalAI running state
|
||||
func (ui *LauncherUI) UpdateRunningState(isRunning bool) {
|
||||
fyne.Do(func() {
|
||||
|
||||
@@ -37,6 +37,8 @@ func (a *Application) RestartAgentJobService() error {
|
||||
if d.JobStore != nil {
|
||||
agentJobService.SetDistributedJobStore(d.JobStore)
|
||||
}
|
||||
// Keep agent tasks consistent across replicas (same client the dispatcher uses).
|
||||
agentJobService.SetTaskSyncNATS(d.Nats)
|
||||
}
|
||||
|
||||
// Start the service
|
||||
|
||||
@@ -103,6 +103,11 @@ func newApplication(appConfig *config.ApplicationConfig) *Application {
|
||||
mcpTools.CloseMCPSessions(modelName)
|
||||
})
|
||||
|
||||
// Record a model_load backend trace for every real backend load, so the
|
||||
// Traces UI shows which backend runtime served each model and how long
|
||||
// the load took. Load failures are traced by the modality wrappers.
|
||||
ml.SetLoadObserver(corebackend.ModelLoadTraceObserver(appConfig))
|
||||
|
||||
app := &Application{
|
||||
backendLoader: config.NewModelConfigLoader(appConfig.SystemState.Model.ModelsPath),
|
||||
modelLoader: ml,
|
||||
@@ -604,6 +609,10 @@ func (a *Application) StartAgentPool() {
|
||||
usm.SetJobDBStore(s)
|
||||
}
|
||||
}
|
||||
// Keep per-user agent tasks consistent across replicas (nil in standalone).
|
||||
if d := a.Distributed(); d != nil {
|
||||
usm.SetJobSyncNATS(d.Nats)
|
||||
}
|
||||
aps.SetUserServicesManager(usm)
|
||||
|
||||
a.agentPoolService.Store(aps)
|
||||
|
||||
@@ -197,6 +197,7 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
|
||||
envWatchdogBusy := appConfig.WatchDogBusy == startupAppConfig.WatchDogBusy
|
||||
envWatchdogIdleTimeout := appConfig.WatchDogIdleTimeout == startupAppConfig.WatchDogIdleTimeout
|
||||
envWatchdogBusyTimeout := appConfig.WatchDogBusyTimeout == startupAppConfig.WatchDogBusyTimeout
|
||||
envWatchdogInterval := appConfig.WatchDogInterval == startupAppConfig.WatchDogInterval
|
||||
envSingleBackend := appConfig.SingleBackend == startupAppConfig.SingleBackend
|
||||
envMaxActiveBackends := appConfig.MaxActiveBackends == startupAppConfig.MaxActiveBackends
|
||||
envMemoryReclaimerEnabled := appConfig.MemoryReclaimerEnabled == startupAppConfig.MemoryReclaimerEnabled
|
||||
@@ -257,6 +258,14 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
|
||||
xlog.Warn("invalid watchdog busy timeout in runtime_settings.json", "error", err, "timeout", *settings.WatchdogBusyTimeout)
|
||||
}
|
||||
}
|
||||
if settings.WatchdogInterval != nil && !envWatchdogInterval {
|
||||
dur, err := time.ParseDuration(*settings.WatchdogInterval)
|
||||
if err == nil {
|
||||
appConfig.WatchDogInterval = dur
|
||||
} else {
|
||||
xlog.Warn("invalid watchdog interval in runtime_settings.json", "error", err, "interval", *settings.WatchdogInterval)
|
||||
}
|
||||
}
|
||||
// Handle MaxActiveBackends (new) and SingleBackend (deprecated)
|
||||
if settings.MaxActiveBackends != nil && !envMaxActiveBackends {
|
||||
appConfig.MaxActiveBackends = *settings.MaxActiveBackends
|
||||
|
||||
@@ -355,6 +355,7 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
PrefixProvider: prefixProvider,
|
||||
PrefixConfig: prefixCfg,
|
||||
Pressure: pressure,
|
||||
SharedModels: cfg.Distributed.SharedModels,
|
||||
})
|
||||
|
||||
// Wire staging-progress broadcasting so file-staging shows up on every
|
||||
|
||||
@@ -87,6 +87,31 @@ var _ = Describe("loadRuntimeSettingsFromFile", func() {
|
||||
})
|
||||
})
|
||||
|
||||
// Watchdog check interval (issue #10601). Unlike the idle/busy timeouts
|
||||
// (which default to 0), NewApplicationConfig baseline-defaults the
|
||||
// interval to 500ms. The loader's "apply file value only if still at the
|
||||
// zero default" env-detection therefore never fired for the interval, so
|
||||
// a UI-saved Check Interval silently reverted to 500ms on every restart
|
||||
// while the idle/busy timeouts persisted. These specs construct the
|
||||
// config the same way boot does (NewApplicationConfig) so they observe
|
||||
// the real default the loader sees.
|
||||
Describe("watchdog interval", func() {
|
||||
It("loads a UI-saved watchdog_interval on the next startup", func() {
|
||||
cfg := config.NewApplicationConfig()
|
||||
cfg.DynamicConfigsDir = seedSettings(`{"watchdog_interval": "2s"}`)
|
||||
loadRuntimeSettingsFromFile(cfg)
|
||||
Expect(cfg.WatchDogInterval).To(Equal(2 * time.Second))
|
||||
})
|
||||
|
||||
It("does not override an explicit env/CLI interval", func() {
|
||||
cfg := config.NewApplicationConfig()
|
||||
cfg.DynamicConfigsDir = seedSettings(`{"watchdog_interval": "2s"}`)
|
||||
cfg.WatchDogInterval = 1 * time.Second // simulate SetWatchDogInterval from env
|
||||
loadRuntimeSettingsFromFile(cfg)
|
||||
Expect(cfg.WatchDogInterval).To(Equal(1*time.Second), "env/CLI interval must win over the persisted file value")
|
||||
})
|
||||
})
|
||||
|
||||
// MITM listener address. The file is the only source — no env var
|
||||
// exists — so a regression here means an admin who configured the
|
||||
// listener via /api/settings loses it after a reboot, even though
|
||||
|
||||
@@ -280,6 +280,9 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
if application.agentJobService != nil {
|
||||
application.agentJobService.SetDistributedBackends(distSvc.Dispatcher)
|
||||
application.agentJobService.SetDistributedJobStore(distSvc.JobStore)
|
||||
// Keep agent tasks consistent across replicas (jobs already sync via the
|
||||
// dispatcher + DB read-through). Same NATS client the dispatcher uses.
|
||||
application.agentJobService.SetTaskSyncNATS(distSvc.Nats)
|
||||
}
|
||||
// Wire skill store into AgentPoolService (wired at pool start time via closure)
|
||||
// The actual wiring happens in StartAgentPool since the pool doesn't exist yet.
|
||||
|
||||
72
core/backend/model_load_trace_test.go
Normal file
72
core/backend/model_load_trace_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package backend_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// ModelLoadTraceObserver is what makes successful loads visible on the
|
||||
// Traces page: one model_load row per real backend load, carrying the
|
||||
// resolved backend runtime. Failures must NOT be recorded here — the
|
||||
// modality wrappers own those — and the observer must respect the runtime
|
||||
// tracing toggle.
|
||||
var _ = Describe("ModelLoadTraceObserver", func() {
|
||||
var appConfig *config.ApplicationConfig
|
||||
|
||||
successEvent := model.BackendLoadEvent{
|
||||
ModelID: "parakeet-cpp-realtime_eou_120m-v1",
|
||||
ModelName: "realtime_eou_120m.gguf",
|
||||
Backend: "parakeet-cpp",
|
||||
BackendURI: "/backends/intel-sycl-f16-parakeet-cpp-development/run.sh",
|
||||
Duration: 1500 * time.Millisecond,
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
appConfig = &config.ApplicationConfig{
|
||||
EnableTracing: true,
|
||||
TracingMaxItems: 64,
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.ClearBackendTraces()
|
||||
})
|
||||
|
||||
It("records a model_load trace with the backend runtime on success", func() {
|
||||
backend.ModelLoadTraceObserver(appConfig)(successEvent)
|
||||
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
got := trace.GetBackendTraces()[0]
|
||||
Expect(got.Type).To(Equal(trace.BackendTraceModelLoad))
|
||||
Expect(got.Summary).To(Equal("Model loaded"))
|
||||
Expect(got.ModelName).To(Equal("parakeet-cpp-realtime_eou_120m-v1"))
|
||||
Expect(got.Backend).To(Equal("parakeet-cpp"))
|
||||
Expect(got.Duration).To(Equal(1500 * time.Millisecond))
|
||||
Expect(got.Data["backend_runtime"]).To(Equal("/backends/intel-sycl-f16-parakeet-cpp-development/run.sh"))
|
||||
Expect(got.Data["model_file"]).To(Equal("realtime_eou_120m.gguf"))
|
||||
Expect(got.Error).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("skips failed loads — the modality wrappers trace those with request context", func() {
|
||||
failed := successEvent
|
||||
failed.Err = errors.New("grpc service not ready")
|
||||
|
||||
backend.ModelLoadTraceObserver(appConfig)(failed)
|
||||
|
||||
Consistently(trace.GetBackendTraces, "100ms", "20ms").Should(BeEmpty())
|
||||
})
|
||||
|
||||
It("records nothing when tracing is disabled", func() {
|
||||
appConfig.EnableTracing = false
|
||||
|
||||
backend.ModelLoadTraceObserver(appConfig)(successEvent)
|
||||
|
||||
Consistently(trace.GetBackendTraces, "100ms", "20ms").Should(BeEmpty())
|
||||
})
|
||||
})
|
||||
@@ -19,6 +19,39 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// ModelLoadTraceObserver returns the ModelLoader load observer that records
|
||||
// a model_load backend trace for every successful real load (backend process
|
||||
// spawn + LoadModel RPC; cache hits never reach the observer). Failures are
|
||||
// deliberately skipped here: the modality wrappers already record them via
|
||||
// recordModelLoadFailure with request context, and the backend auto-discovery
|
||||
// scan probes several backends before one succeeds — tracing every probe
|
||||
// failure would bury the buffer in noise.
|
||||
//
|
||||
// The traced data includes the resolved backend runtime (the installed
|
||||
// backend's launcher path, which names the variant directory) — that is what
|
||||
// identifies WHICH build served the load. A stale installed backend is
|
||||
// invisible in the model config but obvious here.
|
||||
func ModelLoadTraceObserver(appConfig *config.ApplicationConfig) func(model.BackendLoadEvent) {
|
||||
return func(ev model.BackendLoadEvent) {
|
||||
if ev.Err != nil || !appConfig.EnableTracing {
|
||||
return
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: time.Now(),
|
||||
Duration: ev.Duration,
|
||||
Type: trace.BackendTraceModelLoad,
|
||||
ModelName: ev.ModelID,
|
||||
Backend: ev.Backend,
|
||||
Summary: "Model loaded",
|
||||
Data: map[string]any{
|
||||
"model_file": ev.ModelName,
|
||||
"backend_runtime": ev.BackendURI,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// recordModelLoadFailure records a backend trace when model loading fails.
|
||||
func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, backend string, err error, data map[string]any) {
|
||||
if !appConfig.EnableTracing {
|
||||
|
||||
@@ -181,6 +181,7 @@ func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionR
|
||||
Text: r.Text,
|
||||
Language: r.Language,
|
||||
Duration: float64(r.Duration),
|
||||
Eou: r.Eou,
|
||||
}
|
||||
|
||||
for _, s := range r.Segments {
|
||||
|
||||
297
core/backend/transcript_live.go
Normal file
297
core/backend/transcript_live.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/sound"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// LiveTranscriptionEvent is one streamed event from a live (bidirectional)
|
||||
// transcription session. Delta/Eou/Eob/Words arrive as the user speaks; Final
|
||||
// is set exactly once, on the terminal event after Close flushes the decode
|
||||
// tail. Eou means the model judged the user yielded the turn; Eob means a
|
||||
// backchannel ("uh-huh") ended — callers must NOT treat Eob as a turn
|
||||
// boundary.
|
||||
type LiveTranscriptionEvent struct {
|
||||
Delta string
|
||||
Eou bool
|
||||
Eob bool
|
||||
Words []schema.TranscriptionWord
|
||||
Final *schema.TranscriptionResult
|
||||
}
|
||||
|
||||
// LiveTranscriptionSession is a handle on an open live transcription stream.
|
||||
// Feed pushes 16 kHz mono float PCM; Close signals end-of-audio, waits for
|
||||
// the backend's terminal Final event to be delivered, and releases the
|
||||
// stream.
|
||||
type LiveTranscriptionSession interface {
|
||||
Feed(pcm []float32) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// liveCloseDrainTimeout bounds how long Close waits for the backend to flush
|
||||
// the decode tail before force-cancelling the stream. Finalize is one short
|
||||
// engine call; seconds here means the backend is wedged.
|
||||
const liveCloseDrainTimeout = 10 * time.Second
|
||||
|
||||
type liveTranscriptionSession struct {
|
||||
stream grpcPkg.AudioTranscriptionLiveClient
|
||||
cancel context.CancelFunc
|
||||
recvDone chan struct{}
|
||||
recvErr error // written by the recv goroutine before recvDone closes
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
trace *liveTraceState // nil when tracing was disabled at open
|
||||
}
|
||||
|
||||
func (s *liveTranscriptionSession) Feed(pcm []float32) error {
|
||||
s.trace.addPCM(pcm)
|
||||
return s.stream.Send(&proto.TranscriptLiveRequest{
|
||||
Payload: &proto.TranscriptLiveRequest_Audio{Audio: &proto.TranscriptLiveAudio{Pcm: pcm}},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *liveTranscriptionSession) Close() error {
|
||||
s.closeOnce.Do(func() {
|
||||
err := s.stream.CloseSend()
|
||||
select {
|
||||
case <-s.recvDone:
|
||||
case <-time.After(liveCloseDrainTimeout):
|
||||
xlog.Warn("live transcription: backend did not finalize in time; cancelling stream")
|
||||
s.cancel()
|
||||
<-s.recvDone
|
||||
}
|
||||
s.cancel()
|
||||
if err == nil {
|
||||
err = s.recvErr
|
||||
}
|
||||
s.closeErr = err
|
||||
s.trace.record(err)
|
||||
})
|
||||
return s.closeErr
|
||||
}
|
||||
|
||||
// liveSampleRate is the PCM rate of a live transcription session, fixed by
|
||||
// the session config sent in ModelTranscriptionLive.
|
||||
const liveSampleRate = 16000
|
||||
|
||||
// liveTraceState accumulates what the per-turn backend trace needs while a
|
||||
// live session runs: a bounded copy of the fed PCM for the audio snippet,
|
||||
// the decode outputs, and timing. One trace is recorded at Close — the live
|
||||
// path never touches the unary transcription wrapper, so without this a
|
||||
// streaming-only pipeline produced no transcription traces at all. Feed and
|
||||
// the recv goroutine run concurrently; mu guards the accumulators.
|
||||
type liveTraceState struct {
|
||||
appConfig *config.ApplicationConfig
|
||||
modelName string
|
||||
backend string
|
||||
language string
|
||||
started time.Time
|
||||
|
||||
mu sync.Mutex
|
||||
pcm []byte // first trace.MaxSnippetSeconds of fed audio, int16 LE
|
||||
fedSamples int // ALL samples fed, beyond the snippet cap
|
||||
deltaEvents int
|
||||
eouEvents int
|
||||
eobEvents int
|
||||
finalText string
|
||||
}
|
||||
|
||||
func newLiveTraceState(modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, language string) *liveTraceState {
|
||||
if !appConfig.EnableTracing {
|
||||
return nil
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
return &liveTraceState{
|
||||
appConfig: appConfig,
|
||||
modelName: modelConfig.Name,
|
||||
backend: modelConfig.Backend,
|
||||
language: language,
|
||||
started: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *liveTraceState) addPCM(pcm []float32) {
|
||||
if ts == nil {
|
||||
return
|
||||
}
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.fedSamples += len(pcm)
|
||||
maxBytes := trace.MaxSnippetSeconds * liveSampleRate * 2
|
||||
if room := (maxBytes - len(ts.pcm)) / 2; room > 0 {
|
||||
if len(pcm) > room {
|
||||
pcm = pcm[:room]
|
||||
}
|
||||
ts.pcm = append(ts.pcm, sound.Float32sToInt16LEBytes(pcm)...)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *liveTraceState) observe(ev LiveTranscriptionEvent) {
|
||||
if ts == nil {
|
||||
return
|
||||
}
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
if ev.Delta != "" {
|
||||
ts.deltaEvents++
|
||||
}
|
||||
if ev.Eou {
|
||||
ts.eouEvents++
|
||||
}
|
||||
if ev.Eob {
|
||||
ts.eobEvents++
|
||||
}
|
||||
if ev.Final != nil {
|
||||
ts.finalText = ev.Final.Text
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *liveTraceState) record(closeErr error) {
|
||||
if ts == nil || !ts.appConfig.EnableTracing {
|
||||
return
|
||||
}
|
||||
ts.mu.Lock()
|
||||
data := map[string]any{
|
||||
"source": "live_stream",
|
||||
"language": ts.language,
|
||||
"result_text": ts.finalText,
|
||||
"eou_events": ts.eouEvents,
|
||||
"eob_events": ts.eobEvents,
|
||||
"delta_events": ts.deltaEvents,
|
||||
}
|
||||
if snippet := trace.AudioSnippetFromPCM(ts.pcm, liveSampleRate, ts.fedSamples*2, ts.appConfig.TracingMaxBodyBytes); snippet != nil {
|
||||
maps.Copy(data, snippet)
|
||||
}
|
||||
summary := "live -> " + ts.finalText
|
||||
ts.mu.Unlock()
|
||||
|
||||
bt := trace.BackendTrace{
|
||||
Timestamp: ts.started,
|
||||
Duration: time.Since(ts.started),
|
||||
Type: trace.BackendTraceTranscription,
|
||||
ModelName: ts.modelName,
|
||||
Backend: ts.backend,
|
||||
Summary: trace.TruncateString(summary, 200),
|
||||
Data: data,
|
||||
}
|
||||
if closeErr != nil {
|
||||
bt.Error = closeErr.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(bt)
|
||||
}
|
||||
|
||||
// ModelTranscriptionLive loads the transcription backend, opens the
|
||||
// bidirectional AudioTranscriptionLive RPC, sends the session config, and
|
||||
// BLOCKS until the backend's ready ack. A grpcerrors.
|
||||
// IsLiveTranscriptionUnsupported error means the backend (or the loaded
|
||||
// model) cannot do live transcription and the caller should degrade to the
|
||||
// unary/file path. After a successful return, onEvent is invoked from a
|
||||
// background goroutine — in order, one event at a time — for every response
|
||||
// the backend streams, ending with the Final event triggered by Close.
|
||||
func ModelTranscriptionLive(ctx context.Context, language string,
|
||||
ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig,
|
||||
onEvent func(LiveTranscriptionEvent)) (LiveTranscriptionSession, error) {
|
||||
|
||||
transcriptionModel, err := loadTranscriptionModel(ctx, ml, modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The derived cancel out-lives this call inside the session: Close uses
|
||||
// it to unwind the stream (and, in embed mode, the server-side recv
|
||||
// pump, which only stops on send-close or context cancellation).
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
stream, err := transcriptionModel.AudioTranscriptionLive(streamCtx)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fail := func(err error) (LiveTranscriptionSession, error) {
|
||||
_ = stream.CloseSend()
|
||||
cancel()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := stream.Send(&proto.TranscriptLiveRequest{
|
||||
Payload: &proto.TranscriptLiveRequest_Config{Config: &proto.TranscriptLiveConfig{
|
||||
Language: language,
|
||||
SampleRate: liveSampleRate,
|
||||
}},
|
||||
}); err != nil {
|
||||
return fail(err)
|
||||
}
|
||||
|
||||
// Ready-ack contract: the backend answers a successful open with a
|
||||
// {ready:true} response before any transcript data; unsupported
|
||||
// backends surface Unimplemented here instead.
|
||||
ack, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fail(err)
|
||||
}
|
||||
if !ack.GetReady() {
|
||||
return fail(fmt.Errorf("live transcription: backend %q broke the ready-ack contract (first response carried data)", modelConfig.Backend))
|
||||
}
|
||||
|
||||
s := &liveTranscriptionSession{
|
||||
stream: stream,
|
||||
cancel: cancel,
|
||||
recvDone: make(chan struct{}),
|
||||
trace: newLiveTraceState(modelConfig, appConfig, language),
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(s.recvDone)
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) && streamCtx.Err() == nil {
|
||||
xlog.Warn("live transcription stream ended unexpectedly", "error", err)
|
||||
s.recvErr = err
|
||||
}
|
||||
return
|
||||
}
|
||||
ev := liveEventFromProto(resp)
|
||||
if ev.Delta == "" && !ev.Eou && !ev.Eob && len(ev.Words) == 0 && ev.Final == nil {
|
||||
continue // duplicate ready ack / keep-alive: nothing to deliver
|
||||
}
|
||||
s.trace.observe(ev)
|
||||
onEvent(ev)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func liveEventFromProto(r *proto.TranscriptLiveResponse) LiveTranscriptionEvent {
|
||||
ev := LiveTranscriptionEvent{
|
||||
Delta: r.GetDelta(),
|
||||
Eou: r.GetEou(),
|
||||
Eob: r.GetEob(),
|
||||
}
|
||||
for _, w := range r.GetWords() {
|
||||
ev.Words = append(ev.Words, schema.TranscriptionWord{
|
||||
Start: time.Duration(w.Start),
|
||||
End: time.Duration(w.End),
|
||||
Text: w.Text,
|
||||
})
|
||||
}
|
||||
if r.GetFinalResult() != nil {
|
||||
ev.Final = transcriptResultFromProto(r.GetFinalResult())
|
||||
}
|
||||
return ev
|
||||
}
|
||||
162
core/backend/transcript_live_internal_test.go
Normal file
162
core/backend/transcript_live_internal_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("liveEventFromProto", func() {
|
||||
It("maps deltas, eou flags and words (ns -> duration)", func() {
|
||||
ev := liveEventFromProto(&proto.TranscriptLiveResponse{
|
||||
Delta: "hello ",
|
||||
Eou: true,
|
||||
Words: []*proto.TranscriptWord{
|
||||
{Start: int64(100 * time.Millisecond), End: int64(400 * time.Millisecond), Text: "hello"},
|
||||
},
|
||||
})
|
||||
Expect(ev.Delta).To(Equal("hello "))
|
||||
Expect(ev.Eou).To(BeTrue())
|
||||
Expect(ev.Words).To(HaveLen(1))
|
||||
Expect(ev.Words[0].Text).To(Equal("hello"))
|
||||
Expect(ev.Words[0].Start).To(Equal(100 * time.Millisecond))
|
||||
Expect(ev.Words[0].End).To(Equal(400 * time.Millisecond))
|
||||
Expect(ev.Final).To(BeNil())
|
||||
})
|
||||
|
||||
It("maps the terminal final result including the eou flag", func() {
|
||||
ev := liveEventFromProto(&proto.TranscriptLiveResponse{
|
||||
FinalResult: &proto.TranscriptResult{
|
||||
Text: "hello world",
|
||||
Duration: 1.5,
|
||||
Eou: true,
|
||||
Segments: []*proto.TranscriptSegment{{Id: 0, Text: "hello world"}},
|
||||
},
|
||||
})
|
||||
Expect(ev.Final).NotTo(BeNil())
|
||||
Expect(ev.Final.Text).To(Equal("hello world"))
|
||||
Expect(ev.Final.Duration).To(BeNumerically("~", 1.5, 1e-6))
|
||||
Expect(ev.Final.Eou).To(BeTrue())
|
||||
Expect(ev.Final.Segments).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("yields an empty event for a bare ready ack (filtered by the recv loop)", func() {
|
||||
ev := liveEventFromProto(&proto.TranscriptLiveResponse{Ready: true})
|
||||
Expect(ev.Delta).To(BeEmpty())
|
||||
Expect(ev.Eou).To(BeFalse())
|
||||
Expect(ev.Words).To(BeEmpty())
|
||||
Expect(ev.Final).To(BeNil())
|
||||
})
|
||||
|
||||
It("maps the eob backchannel flag separately from eou", func() {
|
||||
ev := liveEventFromProto(&proto.TranscriptLiveResponse{Delta: "uh-huh", Eob: true})
|
||||
Expect(ev.Eob).To(BeTrue())
|
||||
Expect(ev.Eou).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
// liveTraceState is what makes streaming-only pipelines visible on the
|
||||
// Traces page: without it a semantic_vad session with retranscribe off
|
||||
// produced no transcription trace at all. One trace per session (= one per
|
||||
// realtime turn), recorded at Close.
|
||||
var _ = Describe("liveTraceState", func() {
|
||||
var appConfig *config.ApplicationConfig
|
||||
|
||||
BeforeEach(func() {
|
||||
appConfig = &config.ApplicationConfig{
|
||||
EnableTracing: true,
|
||||
TracingMaxItems: 64,
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
trace.ClearBackendTraces()
|
||||
})
|
||||
|
||||
modelCfg := func() config.ModelConfig {
|
||||
cfg := config.ModelConfig{Backend: "parakeet-cpp"}
|
||||
cfg.Name = "parakeet-live"
|
||||
return cfg
|
||||
}
|
||||
|
||||
It("is disabled (nil) when tracing is off, and nil receivers are no-ops", func() {
|
||||
appConfig.EnableTracing = false
|
||||
ts := newLiveTraceState(modelCfg(), appConfig, "en")
|
||||
Expect(ts).To(BeNil())
|
||||
|
||||
// The session calls these unconditionally; nil must be safe.
|
||||
ts.addPCM([]float32{0.5})
|
||||
ts.observe(LiveTranscriptionEvent{Eou: true})
|
||||
ts.record(nil)
|
||||
Consistently(trace.GetBackendTraces, "100ms", "20ms").Should(BeEmpty())
|
||||
})
|
||||
|
||||
It("records one transcription trace with text, eou event counts and audio snippet at Close", func() {
|
||||
ts := newLiveTraceState(modelCfg(), appConfig, "en")
|
||||
Expect(ts).NotTo(BeNil())
|
||||
|
||||
// One second of a loud-ish constant tone so the snippet has signal.
|
||||
pcm := make([]float32, liveSampleRate)
|
||||
for i := range pcm {
|
||||
pcm[i] = 0.25
|
||||
}
|
||||
ts.addPCM(pcm)
|
||||
ts.observe(LiveTranscriptionEvent{Delta: "hello "})
|
||||
ts.observe(LiveTranscriptionEvent{Delta: "world", Eou: true})
|
||||
ts.observe(LiveTranscriptionEvent{Final: &schema.TranscriptionResult{Text: "hello world", Eou: true}})
|
||||
|
||||
ts.record(nil)
|
||||
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
got := trace.GetBackendTraces()[0]
|
||||
Expect(got.Type).To(Equal(trace.BackendTraceTranscription))
|
||||
Expect(got.ModelName).To(Equal("parakeet-live"))
|
||||
Expect(got.Backend).To(Equal("parakeet-cpp"))
|
||||
Expect(got.Summary).To(ContainSubstring("hello world"))
|
||||
Expect(got.Data["source"]).To(Equal("live_stream"))
|
||||
Expect(got.Data["result_text"]).To(Equal("hello world"))
|
||||
// The live FinalResult no longer carries a terminal eou flag; the
|
||||
// per-feed eou_events count is what the trace records instead.
|
||||
Expect(got.Data).NotTo(HaveKey("eou"))
|
||||
Expect(got.Data["eou_events"]).To(Equal(1))
|
||||
Expect(got.Data["delta_events"]).To(Equal(2))
|
||||
Expect(got.Data["audio_duration_s"]).To(BeNumerically("~", 1.0, 0.01))
|
||||
Expect(got.Data["audio_wav_base64"]).NotTo(BeEmpty())
|
||||
Expect(got.Error).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("caps the stored snippet but keeps counting the full fed duration", func() {
|
||||
ts := newLiveTraceState(modelCfg(), appConfig, "")
|
||||
|
||||
// Feed past the snippet cap in two chunks (cap + one extra second).
|
||||
ts.addPCM(make([]float32, trace.MaxSnippetSeconds*liveSampleRate))
|
||||
ts.addPCM(make([]float32, liveSampleRate))
|
||||
|
||||
Expect(len(ts.pcm)).To(Equal(trace.MaxSnippetSeconds * liveSampleRate * 2))
|
||||
Expect(ts.fedSamples).To(Equal((trace.MaxSnippetSeconds + 1) * liveSampleRate))
|
||||
|
||||
ts.record(nil)
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
got := trace.GetBackendTraces()[0]
|
||||
Expect(got.Data["audio_duration_s"]).To(BeNumerically("~", float64(trace.MaxSnippetSeconds+1), 0.01))
|
||||
Expect(got.Data["audio_snippet_s"]).To(BeNumerically("~", float64(trace.MaxSnippetSeconds), 0.01))
|
||||
})
|
||||
|
||||
It("clamps out-of-range float samples instead of wrapping", func() {
|
||||
ts := newLiveTraceState(modelCfg(), appConfig, "")
|
||||
ts.addPCM([]float32{2.0, -2.0})
|
||||
Expect(ts.pcm).To(Equal([]byte{0xff, 0x7f, 0x00, 0x80})) // 32767, -32768
|
||||
})
|
||||
|
||||
It("stamps the close error on the trace", func() {
|
||||
ts := newLiveTraceState(modelCfg(), appConfig, "")
|
||||
ts.record(errors.New("stream torn down"))
|
||||
|
||||
Eventually(trace.GetBackendTraces).Should(HaveLen(1))
|
||||
Expect(trace.GetBackendTraces()[0].Error).To(Equal("stream torn down"))
|
||||
})
|
||||
})
|
||||
@@ -160,6 +160,7 @@ type RunCMD struct {
|
||||
RegistrationRequireAuth bool `env:"LOCALAI_REGISTRATION_REQUIRE_AUTH" default:"false" help:"Fail startup when distributed mode is enabled but LOCALAI_REGISTRATION_TOKEN is empty (node endpoints and worker file-transfer server would otherwise be unauthenticated)" group:"distributed"`
|
||||
DistributedRequireAuth bool `env:"LOCALAI_DISTRIBUTED_REQUIRE_AUTH" default:"false" help:"Umbrella switch: require BOTH NATS JWT credentials and a registration token when distributed mode is enabled (implies --nats-require-auth and --registration-require-auth)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
DistributedSharedModels bool `env:"LOCALAI_DISTRIBUTED_SHARED_MODELS" default:"false" help:"Assert that every node mounts the SAME models directory at the SAME path (shared volume). When true, the router skips staging model files to workers and loads them directly from the shared path, avoiding re-downloads." group:"distributed"`
|
||||
DistributedPrefixCache bool `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE" default:"true" help:"Enable prefix-cache-aware routing in distributed mode (default true). When false, routing falls back to round-robin." group:"distributed"`
|
||||
DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
@@ -310,6 +311,9 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.DistributedRequireAuth {
|
||||
opts = append(opts, config.EnableDistributedRequireAuth)
|
||||
}
|
||||
if r.DistributedSharedModels {
|
||||
opts = append(opts, config.EnableDistributedSharedModels)
|
||||
}
|
||||
if r.NatsAccountSeed != "" {
|
||||
opts = append(opts, config.WithNatsAccountSeed(r.NatsAccountSeed))
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -241,12 +242,19 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||
Context: context.Background(),
|
||||
UploadLimitMB: 15,
|
||||
Debug: true,
|
||||
AgentJobRetentionDays: 30, // Default: 30 days
|
||||
LRUEvictionMaxRetries: 30, // Default: 30 retries
|
||||
LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
|
||||
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
|
||||
TracingMaxItems: 1024,
|
||||
TracingMaxBodyBytes: 64 * 1024, // 64 KiB - caps each request/response body in the trace buffer
|
||||
AgentJobRetentionDays: 30, // Default: 30 days
|
||||
LRUEvictionMaxRetries: 30, // Default: 30 retries
|
||||
LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
|
||||
// WatchDogInterval is intentionally left at the zero value here.
|
||||
// The startup loader applies a persisted runtime_settings.json value
|
||||
// only when the interval is still 0 (its "not set by env var"
|
||||
// heuristic, matching the idle/busy timeouts); a non-zero baseline
|
||||
// default would defeat that and silently revert a UI-saved Check
|
||||
// Interval to the default on every restart (#10601). The effective
|
||||
// 500ms default is supplied at the watchdog layer (DefaultWatchdogInterval)
|
||||
// when the value is still 0.
|
||||
TracingMaxItems: 1024,
|
||||
TracingMaxBodyBytes: 64 * 1024, // 64 KiB - caps each request/response body in the trace buffer
|
||||
AgentPool: AgentPoolConfig{
|
||||
Enabled: true,
|
||||
Timeout: "5m",
|
||||
@@ -1097,7 +1105,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
if o.WatchDogInterval > 0 {
|
||||
watchdogInterval = o.WatchDogInterval.String()
|
||||
} else {
|
||||
watchdogInterval = "2s" // default
|
||||
watchdogInterval = model.DefaultWatchdogInterval.String() // default: 500ms
|
||||
}
|
||||
var lruEvictionRetryInterval string
|
||||
if o.LRUEvictionRetryInterval > 0 {
|
||||
|
||||
@@ -542,6 +542,19 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
DefaultUsecases: []string{UsecaseSpeakerRecognition},
|
||||
Description: "Speaker recognition — voice identity verification and analysis",
|
||||
},
|
||||
"voice-detect": {
|
||||
GRPCMethods: []GRPCMethod{MethodVoiceVerify, MethodVoiceEmbed, MethodVoiceAnalyze},
|
||||
PossibleUsecases: []string{UsecaseSpeakerRecognition},
|
||||
DefaultUsecases: []string{UsecaseSpeakerRecognition},
|
||||
Description: "voice-detect.cpp: C++/ggml speaker embedding, verification and voice analysis (age/gender/emotion)",
|
||||
},
|
||||
"face-detect": {
|
||||
GRPCMethods: []GRPCMethod{MethodEmbedding, MethodDetect, MethodFaceVerify, MethodFaceAnalyze},
|
||||
PossibleUsecases: []string{UsecaseEmbeddings, UsecaseDetection, UsecaseFaceRecognition},
|
||||
DefaultUsecases: []string{UsecaseFaceRecognition},
|
||||
AcceptsImages: true,
|
||||
Description: "face-detect.cpp: C++/ggml face detection, embedding, verification and attribute analysis",
|
||||
},
|
||||
"silero-vad": {
|
||||
GRPCMethods: []GRPCMethod{MethodVAD},
|
||||
PossibleUsecases: []string{UsecaseVAD},
|
||||
|
||||
@@ -12,14 +12,12 @@ package config
|
||||
// these; config never imports backend.
|
||||
const (
|
||||
// DefaultContextSize is the fallback context window when none is configured
|
||||
// or estimable from the model.
|
||||
// or estimable from the model. It is also the fallback for a GGUF whose
|
||||
// metadata yields no usable estimate or that the parser cannot read at all
|
||||
// (e.g. a quant type it does not know, such as NVFP4): a model-agnostic
|
||||
// safe default beats a tiny, surprising window that truncates real prompts.
|
||||
DefaultContextSize = 4096
|
||||
|
||||
// GGUFFallbackContextSize is the context window for a GGUF model whose
|
||||
// metadata yields no usable estimate (see guessGGUFFromFile). Deliberately
|
||||
// smaller than DefaultContextSize to stay conservative on memory there.
|
||||
GGUFFallbackContextSize = 1024
|
||||
|
||||
// DefaultNGPULayers means "offload all layers"; the backend (fit_params)
|
||||
// clamps to what actually fits in device memory.
|
||||
DefaultNGPULayers = 99999999
|
||||
|
||||
@@ -31,6 +31,14 @@ type DistributedConfig struct {
|
||||
// available to enforce just one layer.
|
||||
RequireAuth bool // LOCALAI_DISTRIBUTED_REQUIRE_AUTH
|
||||
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
|
||||
// SharedModels asserts that every node (frontend and workers) mounts the
|
||||
// SAME models directory at the SAME path (e.g. a shared volume, as in
|
||||
// docker-compose.distributed.yaml). When true, the router skips staging
|
||||
// model files to workers entirely: the frontend's absolute model paths are
|
||||
// already valid on the worker, so re-uploading them into a per-model
|
||||
// subdirectory only re-downloads what is already present (#10556). Default
|
||||
// false preserves the historical per-node staging behavior.
|
||||
SharedModels bool // --distributed-shared-models / LOCALAI_DISTRIBUTED_SHARED_MODELS
|
||||
|
||||
// NATS JWT auth (optional; see pkg/natsauth and docs/features/distributed-mode.md)
|
||||
NatsAccountSeed string // LOCALAI_NATS_ACCOUNT_SEED — account signing seed to mint per-node worker JWTs
|
||||
@@ -282,6 +290,13 @@ var EnableAutoApproveNodes = func(o *ApplicationConfig) {
|
||||
o.Distributed.AutoApproveNodes = true
|
||||
}
|
||||
|
||||
// EnableDistributedSharedModels marks the cluster as sharing one models
|
||||
// directory across all nodes, so the router skips staging model files to
|
||||
// workers (see DistributedConfig.SharedModels).
|
||||
var EnableDistributedSharedModels = func(o *ApplicationConfig) {
|
||||
o.Distributed.SharedModels = true
|
||||
}
|
||||
|
||||
// DisablePrefixCache turns off prefix-cache-aware routing (falls back to
|
||||
// round-robin). Prefix-cache routing is enabled by default in distributed mode.
|
||||
var DisablePrefixCache = func(o *ApplicationConfig) {
|
||||
|
||||
@@ -33,7 +33,7 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
|
||||
cSize := int(ctxSize)
|
||||
cfg.ContextSize = &cSize
|
||||
} else {
|
||||
defaultCtx = GGUFFallbackContextSize
|
||||
defaultCtx = DefaultContextSize
|
||||
cfg.ContextSize = &defaultCtx
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ func llamaCppDefaults(cfg *ModelConfig, modelPath string) {
|
||||
// Default context size if not set, regardless of whether GGUF parsing succeeds
|
||||
defer func() {
|
||||
if cfg.ContextSize == nil {
|
||||
ctx := GGUFFallbackContextSize
|
||||
ctx := DefaultContextSize
|
||||
cfg.ContextSize = &ctx
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -248,7 +248,11 @@ var _ = Describe("Backend hooks and parser defaults", func() {
|
||||
}
|
||||
cfg.SetDefaults(ModelPath(dir))
|
||||
|
||||
// An unreadable/unparseable GGUF (e.g. a quant type the parser does
|
||||
// not know, such as NVFP4) yields no estimate, so the hook must fall
|
||||
// back to DefaultContextSize rather than a tiny, surprising value.
|
||||
Expect(cfg.ContextSize).NotTo(BeNil())
|
||||
Expect(*cfg.ContextSize).To(Equal(DefaultContextSize))
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -567,6 +567,38 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Advanced: true,
|
||||
Order: 83,
|
||||
},
|
||||
"pipeline.turn_detection.type": {
|
||||
Section: "pipeline",
|
||||
Label: "Turn Detection",
|
||||
Description: "Default turn-detection mode for realtime sessions on this pipeline. server_vad commits after a fixed silence window; semantic_vad lets the transcription model's end-of-utterance token drive a dynamic window (fast commit after the token, long eagerness fallback without it). semantic_vad requires a streaming-EOU transcription model (e.g. parakeet-cpp-realtime_eou_120m-v1) and degrades to silence-only otherwise. Clients can override per session via session.update.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "", Label: "Default (server_vad)"},
|
||||
{Value: "server_vad", Label: "server_vad (silence-based)"},
|
||||
{Value: "semantic_vad", Label: "semantic_vad (end-of-utterance token)"},
|
||||
},
|
||||
Order: 87,
|
||||
},
|
||||
"pipeline.turn_detection.eagerness": {
|
||||
Section: "pipeline",
|
||||
Label: "Eagerness",
|
||||
Description: "semantic_vad fallback silence window used when no end-of-utterance token was seen: low waits 8s, medium/auto 4s, high 2s.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "", Label: "Default (auto)"},
|
||||
{Value: "low", Label: "low (8s)"},
|
||||
{Value: "medium", Label: "medium (4s)"},
|
||||
{Value: "high", Label: "high (2s)"},
|
||||
},
|
||||
Order: 88,
|
||||
},
|
||||
"pipeline.turn_detection.retranscribe": {
|
||||
Section: "pipeline",
|
||||
Label: "Retranscribe on Commit",
|
||||
Description: "Cross-check every semantic_vad commit with an offline decode of the buffered turn: commit only proceeds when the batch decode also ends in the end-of-utterance token, and its transcript is used. Logs a streamed-vs-batch comparison — useful to gauge streaming/batch alignment — at the cost of one extra decode per turn.",
|
||||
Component: "toggle",
|
||||
Order: 89,
|
||||
},
|
||||
|
||||
// --- Functions ---
|
||||
"function.grammar.parallel_calls": {
|
||||
|
||||
@@ -650,6 +650,12 @@ type Pipeline struct {
|
||||
// VoiceRecognition gates the pipeline behind speaker verification. Nil
|
||||
// (block absent) means no gate, preserving existing behavior.
|
||||
VoiceRecognition *PipelineVoiceRecognition `yaml:"voice_recognition,omitempty" json:"voice_recognition,omitempty"`
|
||||
|
||||
// TurnDetection sets the server-side default turn-detection mode for
|
||||
// realtime sessions on this pipeline, so clients need no session.update
|
||||
// to benefit. A client session.update still overrides type and eagerness
|
||||
// per session; retranscribe is server-side only. Unset keeps server_vad.
|
||||
TurnDetection PipelineTurnDetection `yaml:"turn_detection,omitempty" json:"turn_detection,omitempty"`
|
||||
}
|
||||
|
||||
// PipelineCompaction configures summarize-then-drop for a realtime pipeline.
|
||||
@@ -934,6 +940,38 @@ func (v PipelineVoiceRecognition) Validate(registryAvailable bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// @Description PipelineTurnDetection sets realtime turn-detection defaults.
|
||||
type PipelineTurnDetection struct {
|
||||
// Type selects the default turn_detection mode for sessions on this
|
||||
// pipeline: "server_vad" (silence-based) or "semantic_vad" (the
|
||||
// transcription model's end-of-utterance token drives a dynamic silence
|
||||
// window; needs a streaming-EOU transcription model such as
|
||||
// parakeet_realtime_eou_120m-v1, degrades to silence-only otherwise).
|
||||
Type string `yaml:"type,omitempty" json:"type,omitempty"`
|
||||
// Eagerness is the semantic_vad fallback when no end-of-utterance token
|
||||
// was seen: low waits 8s of silence, medium/auto 4s, high 2s.
|
||||
Eagerness string `yaml:"eagerness,omitempty" json:"eagerness,omitempty"`
|
||||
// Retranscribe (semantic_vad only) cross-checks every EOU-triggered
|
||||
// commit with an offline decode of the buffered turn: the commit only
|
||||
// proceeds when the batch decode also ends in the end-of-utterance token,
|
||||
// and its transcript is the one used. The streamed and batch transcripts
|
||||
// are compared in the logs — a diagnostic for streaming/batch alignment
|
||||
// at the cost of one extra decode per turn.
|
||||
Retranscribe *bool `yaml:"retranscribe,omitempty" json:"retranscribe,omitempty"`
|
||||
}
|
||||
|
||||
// TurnDetectionSemantic reports whether this pipeline defaults sessions to
|
||||
// semantic (EOU-driven) turn detection.
|
||||
func (p Pipeline) TurnDetectionSemantic() bool {
|
||||
return strings.EqualFold(strings.TrimSpace(p.TurnDetection.Type), "semantic_vad")
|
||||
}
|
||||
|
||||
// TurnDetectionRetranscribe reports whether semantic_vad commits should be
|
||||
// cross-checked (and transcribed) by an offline decode of the buffered turn.
|
||||
func (p Pipeline) TurnDetectionRetranscribe() bool {
|
||||
return p.TurnDetection.Retranscribe != nil && *p.TurnDetection.Retranscribe
|
||||
}
|
||||
|
||||
// @Description File configuration for model downloads
|
||||
type File struct {
|
||||
Filename string `yaml:"filename,omitempty" json:"filename,omitempty"`
|
||||
|
||||
61
core/config/pipeline_turn_detection_test.go
Normal file
61
core/config/pipeline_turn_detection_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// pipeline.turn_detection sets the server-side default turn-detection mode
|
||||
// for realtime sessions. Unset keeps server_vad, so existing configs are
|
||||
// unaffected; retranscribe is opt-in.
|
||||
var _ = Describe("Pipeline turn_detection config", func() {
|
||||
It("defaults to non-semantic with retranscribe off when unset", func() {
|
||||
var p Pipeline
|
||||
Expect(p.TurnDetectionSemantic()).To(BeFalse())
|
||||
Expect(p.TurnDetectionRetranscribe()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("parses the nested turn_detection block from YAML", func() {
|
||||
var c ModelConfig
|
||||
err := yaml.Unmarshal([]byte(`
|
||||
name: gpt-realtime
|
||||
pipeline:
|
||||
transcription: parakeet-cpp-realtime_eou_120m-v1
|
||||
turn_detection:
|
||||
type: semantic_vad
|
||||
eagerness: high
|
||||
retranscribe: true
|
||||
`), &c)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c.Pipeline.TurnDetectionSemantic()).To(BeTrue())
|
||||
Expect(c.Pipeline.TurnDetection.Eagerness).To(Equal("high"))
|
||||
Expect(c.Pipeline.TurnDetectionRetranscribe()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("treats server_vad and unknown types as non-semantic", func() {
|
||||
var p Pipeline
|
||||
p.TurnDetection.Type = "server_vad"
|
||||
Expect(p.TurnDetectionSemantic()).To(BeFalse())
|
||||
p.TurnDetection.Type = "something_else"
|
||||
Expect(p.TurnDetectionSemantic()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("matches semantic_vad case-insensitively with surrounding space", func() {
|
||||
var p Pipeline
|
||||
p.TurnDetection.Type = " Semantic_VAD "
|
||||
Expect(p.TurnDetectionSemantic()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("treats an explicit retranscribe false as off", func() {
|
||||
var c ModelConfig
|
||||
err := yaml.Unmarshal([]byte(`
|
||||
pipeline:
|
||||
turn_detection:
|
||||
type: semantic_vad
|
||||
retranscribe: false
|
||||
`), &c)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c.Pipeline.TurnDetectionRetranscribe()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
@@ -22,11 +22,13 @@ var _ = Describe("DiscoverModelConfig", func() {
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.Name).To(Equal("LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
// No name preference + repo-root URI: the name follows the selected
|
||||
// GGUF file, not the repo (issue #10587).
|
||||
Expect(modelConfig.Name).To(Equal("localai-functioncall-qwen2.5-7b-v0.5-q4_k_m"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF/resolve/main/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].SHA256).To(Equal("4e7b7fe1d54b881f1ef90799219dc6cc285d29db24f559c8998d1addb35713d4"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
@@ -38,16 +40,17 @@ var _ = Describe("DiscoverModelConfig", func() {
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
// No name preference: name follows the selected model GGUF (issue #10587).
|
||||
Expect(modelConfig.Name).To(Equal("Qwen3VL-2B-Instruct-Q4_K_M"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/Qwen3-VL-2B-Instruct-GGUF/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/Qwen3-VL-2B-Instruct-GGUF/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/Qwen3VL-2B-Instruct-Q4_K_M/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/Qwen3VL-2B-Instruct-Q4_K_M/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/Qwen3-VL-2B-Instruct-GGUF/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/Qwen3VL-2B-Instruct-Q4_K_M/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/mmproj/Qwen3-VL-2B-Instruct-GGUF/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/mmproj/Qwen3VL-2B-Instruct-Q4_K_M/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
@@ -59,16 +62,17 @@ var _ = Describe("DiscoverModelConfig", func() {
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
// No name preference: name follows the selected Q8_0 model GGUF (issue #10587).
|
||||
Expect(modelConfig.Name).To(Equal("Qwen3VL-2B-Instruct-Q8_0"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/Qwen3-VL-2B-Instruct-GGUF/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/Qwen3-VL-2B-Instruct-GGUF/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/Qwen3VL-2B-Instruct-Q8_0/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/Qwen3VL-2B-Instruct-Q8_0/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/Qwen3-VL-2B-Instruct-GGUF/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/Qwen3VL-2B-Instruct-Q8_0/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/mmproj/Qwen3-VL-2B-Instruct-GGUF/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/mmproj/Qwen3VL-2B-Instruct-Q8_0/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
|
||||
@@ -25,8 +25,8 @@ var (
|
||||
|
||||
type LlamaCPPImporter struct{}
|
||||
|
||||
func (i *LlamaCPPImporter) Name() string { return "llama-cpp" }
|
||||
func (i *LlamaCPPImporter) Modality() string { return "text" }
|
||||
func (i *LlamaCPPImporter) Name() string { return "llama-cpp" }
|
||||
func (i *LlamaCPPImporter) Modality() string { return "text" }
|
||||
func (i *LlamaCPPImporter) AutoDetects() bool { return true }
|
||||
|
||||
// AdditionalBackends advertises drop-in replacements that share the
|
||||
@@ -98,8 +98,13 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
// nameProvided tracks whether the user supplied an explicit model name.
|
||||
// When they didn't, the URI base is only a fallback: for a HuggingFace
|
||||
// repo-root URI (no file component) it would be the repo name, so the HF
|
||||
// branch below re-derives the name from the selected GGUF file instead
|
||||
// (issue #10587).
|
||||
name, nameProvided := preferencesMap["name"].(string)
|
||||
if !nameProvided {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
@@ -227,10 +232,23 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
mmprojGroups := hfapi.GroupShards(mmprojFiles)
|
||||
ggufGroups := hfapi.GroupShards(ggufFiles)
|
||||
|
||||
modelGroup := pickPreferredGroup(ggufGroups, quants)
|
||||
|
||||
// A repo-root URI has no file component, so the URI-base fallback
|
||||
// above produced the repo name. When the user left the name blank,
|
||||
// derive it from the GGUF file actually selected from the listing so
|
||||
// the gallery entry and `model:` directory reflect the model, not the
|
||||
// repository (issue #10587). An explicit name preference always wins.
|
||||
if !nameProvided && modelGroup != nil {
|
||||
name = modelNameFromShardGroup(*modelGroup)
|
||||
modelConfig.Name = name
|
||||
cfg.Name = name
|
||||
}
|
||||
|
||||
// Emit the model group first so cfg.Files[0] is the model — callers
|
||||
// and tests rely on the model file preceding any mmproj companion.
|
||||
if group := pickPreferredGroup(ggufGroups, quants); group != nil {
|
||||
appendShardGroup(&cfg, *group, filepath.Join("llama-cpp", "models", name))
|
||||
if modelGroup != nil {
|
||||
appendShardGroup(&cfg, *modelGroup, filepath.Join("llama-cpp", "models", name))
|
||||
}
|
||||
if group := pickPreferredGroup(mmprojGroups, mmprojQuantsList); group != nil {
|
||||
appendShardGroup(&cfg, *group, filepath.Join("llama-cpp", "mmproj", name))
|
||||
@@ -281,6 +299,20 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// modelNameFromShardGroup derives a human-facing model name from the picked
|
||||
// GGUF group: the logical base filename with its .gguf extension stripped.
|
||||
// ShardGroup.Base is the common prefix for sharded sets (without the
|
||||
// -NNNNN-of-MMMMM suffix) and the sole basename for single-file models, so
|
||||
// this yields a clean name like "model-Q4_K_M" rather than an individual
|
||||
// shard filename or the repo-root URI base.
|
||||
func modelNameFromShardGroup(group hfapi.ShardGroup) string {
|
||||
base := group.Base
|
||||
if ext := filepath.Ext(base); strings.EqualFold(ext, ".gguf") {
|
||||
base = strings.TrimSuffix(base, ext)
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
// pickPreferredGroup walks the preference list in priority order and returns
|
||||
// the first group whose base filename contains any preference. When nothing
|
||||
// matches, the last group wins — this preserves the historical "if the user
|
||||
@@ -293,7 +325,7 @@ func pickPreferredGroup(groups []hfapi.ShardGroup, prefs []string) *hfapi.ShardG
|
||||
for _, pref := range prefs {
|
||||
lower := strings.ToLower(pref)
|
||||
for i := range groups {
|
||||
if strings.Contains(strings.ToLower(groups[i].Base), lower) {
|
||||
if quantTokenMatches(strings.ToLower(groups[i].Base), lower) {
|
||||
return &groups[i]
|
||||
}
|
||||
}
|
||||
@@ -301,6 +333,39 @@ func pickPreferredGroup(groups []hfapi.ShardGroup, prefs []string) *hfapi.ShardG
|
||||
return &groups[len(groups)-1]
|
||||
}
|
||||
|
||||
// quantTokenMatches reports whether pref appears in base as a whole token
|
||||
// rather than as a substring of a larger alphanumeric run. Both arguments
|
||||
// must already be lowercased.
|
||||
//
|
||||
// A plain strings.Contains is wrong here: `f16` is a substring of `bf16`, so
|
||||
// asking for the `F16` quant used to wrongly select a `BF16` file (#10559).
|
||||
// Only the OUTER edges of the matched preference must hit a boundary — a
|
||||
// non-alphanumeric char (or the start/end of base). Separators inside the
|
||||
// preference itself (e.g. `ud-q4_k_xl`) are intentionally left untouched.
|
||||
func quantTokenMatches(base, pref string) bool {
|
||||
if pref == "" {
|
||||
return false
|
||||
}
|
||||
for start := strings.Index(base, pref); start != -1; {
|
||||
end := start + len(pref)
|
||||
leftOK := start == 0 || !isAlphaNum(base[start-1])
|
||||
rightOK := end == len(base) || !isAlphaNum(base[end])
|
||||
if leftOK && rightOK {
|
||||
return true
|
||||
}
|
||||
next := strings.Index(base[start+1:], pref)
|
||||
if next == -1 {
|
||||
break
|
||||
}
|
||||
start += next + 1
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isAlphaNum(b byte) bool {
|
||||
return (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9')
|
||||
}
|
||||
|
||||
// maybeApplyMTPDefaults parses the picked GGUF header (range-fetched over
|
||||
// HTTP for HF/URL imports) and, if the file declares a Multi-Token Prediction
|
||||
// head, appends the auto-MTP option keys to modelConfig.Options. Failures
|
||||
|
||||
@@ -372,6 +372,160 @@ var _ = Describe("LlamaCPPImporter", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("derives the model name from the selected GGUF when no name is given", func() {
|
||||
// Regression for #10587: a repo-root URI has no file component, so
|
||||
// the URI base ("example-GGUF") is just the repo name. With the
|
||||
// name field left blank, the emitted name and model directory must
|
||||
// follow the GGUF file actually selected, not the repository.
|
||||
details := withHF(`{"quantizations":"Q4_K_M"}`,
|
||||
hfFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf", "aaa"),
|
||||
hfFile("Meta-Llama-3-8B-Instruct.Q3_K_M.gguf", "bbb"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("Meta-Llama-3-8B-Instruct.Q4_K_M"))
|
||||
Expect(modelConfig.Files).To(HaveLen(1), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal(
|
||||
"llama-cpp/models/Meta-Llama-3-8B-Instruct.Q4_K_M/Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("name: Meta-Llama-3-8B-Instruct.Q4_K_M"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring(
|
||||
"model: llama-cpp/models/Meta-Llama-3-8B-Instruct.Q4_K_M/Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"))
|
||||
})
|
||||
|
||||
It("derives a clean name from the shard base for split GGUFs when no name is given", func() {
|
||||
// The selected primary file is shard 1; using its raw basename
|
||||
// would leak the -00001-of-00002 suffix into the name. The shard
|
||||
// base must be used so the name is the logical model.
|
||||
details := withHF(``,
|
||||
hfFile("Qwen3-30B-A3B-Q4_K_M-00001-of-00002.gguf", "p1"),
|
||||
hfFile("Qwen3-30B-A3B-Q4_K_M-00002-of-00002.gguf", "p2"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("Qwen3-30B-A3B-Q4_K_M"))
|
||||
Expect(modelConfig.Files).To(HaveLen(2), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal(
|
||||
"llama-cpp/models/Qwen3-30B-A3B-Q4_K_M/Qwen3-30B-A3B-Q4_K_M-00001-of-00002.gguf"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring(
|
||||
"model: llama-cpp/models/Qwen3-30B-A3B-Q4_K_M/Qwen3-30B-A3B-Q4_K_M-00001-of-00002.gguf"))
|
||||
})
|
||||
|
||||
It("keeps an explicit name over the selected GGUF filename", func() {
|
||||
// Precedence guard: when the user supplies a name it always wins,
|
||||
// even though a GGUF file was selected from the listing.
|
||||
details := withHF(`{"name":"my-custom-name","quantizations":"Q4_K_M"}`,
|
||||
hfFile("model-Q4_K_M.gguf", "aaa"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("my-custom-name"))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/my-custom-name/model-Q4_K_M.gguf"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("quant token boundary matching", func() {
|
||||
// Regression for #10559: the quant preference must match as a whole
|
||||
// token, not as a substring. Asking for `F16` used to select a
|
||||
// `BF16` mmproj because strings.Contains("...bf16.gguf", "f16") is
|
||||
// true — the leading `b` was ignored.
|
||||
|
||||
const repoBase = "https://huggingface.co/acme/example-GGUF/resolve/main/"
|
||||
|
||||
hfFile := func(path, sha string) hfapi.ModelFile {
|
||||
return hfapi.ModelFile{
|
||||
Path: path,
|
||||
SHA256: sha,
|
||||
URL: repoBase + path,
|
||||
}
|
||||
}
|
||||
|
||||
withHF := func(preferences string, files ...hfapi.ModelFile) Details {
|
||||
d := Details{
|
||||
URI: "https://huggingface.co/acme/example-GGUF",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "acme/example-GGUF",
|
||||
Files: files,
|
||||
},
|
||||
}
|
||||
if preferences != "" {
|
||||
d.Preferences = json.RawMessage(preferences)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
It("selects the F16 mmproj over BF16 (BF16 listed first)", func() {
|
||||
details := withHF(`{"name":"VL","mmproj_quantizations":"F16"}`,
|
||||
hfFile("model-Q4_K_M.gguf", "model"),
|
||||
hfFile("mmproj-x-BF16.gguf", "bf16"),
|
||||
hfFile("mmproj-x-F16.gguf", "f16"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/VL/mmproj-x-F16.gguf"), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).ToNot(ContainSubstring("BF16"), fmt.Sprintf("%+v", modelConfig))
|
||||
})
|
||||
|
||||
It("selects the F16 mmproj over BF16 (F16 listed first)", func() {
|
||||
details := withHF(`{"name":"VL","mmproj_quantizations":"F16"}`,
|
||||
hfFile("model-Q4_K_M.gguf", "model"),
|
||||
hfFile("mmproj-x-F16.gguf", "f16"),
|
||||
hfFile("mmproj-x-BF16.gguf", "bf16"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/VL/mmproj-x-F16.gguf"), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).ToNot(ContainSubstring("BF16"), fmt.Sprintf("%+v", modelConfig))
|
||||
})
|
||||
|
||||
It("selects BF16 when BF16 is the requested mmproj quant", func() {
|
||||
details := withHF(`{"name":"VL","mmproj_quantizations":"BF16"}`,
|
||||
hfFile("model-Q4_K_M.gguf", "model"),
|
||||
hfFile("mmproj-x-F16.gguf", "f16"),
|
||||
hfFile("mmproj-x-BF16.gguf", "bf16"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/VL/mmproj-x-BF16.gguf"), fmt.Sprintf("%+v", modelConfig))
|
||||
})
|
||||
|
||||
It("still matches a normal model quant with internal separators", func() {
|
||||
// ud-q4_k_xl contains `-`/`_` internally; only the outer edges
|
||||
// must hit a token boundary.
|
||||
details := withHF(`{"name":"M","quantizations":"ud-q4_k_xl"}`,
|
||||
hfFile("model-UD-Q4_K_XL.gguf", "xl"),
|
||||
hfFile("model-Q3_K_M.gguf", "q3"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/M/model-UD-Q4_K_XL.gguf"), fmt.Sprintf("%+v", modelConfig))
|
||||
})
|
||||
|
||||
It("falls back to the last group when no preference matches", func() {
|
||||
details := withHF(`{"name":"M","quantizations":"Q2_K"}`,
|
||||
hfFile("model-Q8_0.gguf", "q8"),
|
||||
hfFile("model-Q3_K_M.gguf", "q3"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/M/model-Q3_K_M.gguf"), fmt.Sprintf("%+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("AdditionalBackends", func() {
|
||||
|
||||
@@ -23,8 +23,10 @@ import (
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/distributed"
|
||||
"github.com/mudler/LocalAI/core/services/finetune"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/quantization"
|
||||
|
||||
@@ -400,25 +402,45 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
|
||||
// Fine-tuning routes
|
||||
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
|
||||
// In distributed mode pass the shared NATS client + PostgreSQL store so
|
||||
// fine-tune jobs stay consistent across replicas (the SyncedMap broadcasts
|
||||
// mutations and hydrates from the DB); standalone passes nil for both.
|
||||
var ftNats messaging.MessagingClient
|
||||
var ftStore *distributed.FineTuneStore
|
||||
if d := application.Distributed(); d != nil {
|
||||
ftNats = d.Nats
|
||||
if d.DistStores != nil && d.DistStores.FineTune != nil {
|
||||
ftStore = d.DistStores.FineTune
|
||||
}
|
||||
}
|
||||
ftService := finetune.NewFineTuneService(
|
||||
application.ApplicationConfig(),
|
||||
application.ModelLoader(),
|
||||
application.ModelConfigLoader(),
|
||||
ftNats,
|
||||
ftStore,
|
||||
)
|
||||
if d := application.Distributed(); d != nil {
|
||||
ftService.SetNATSClient(d.Nats)
|
||||
if d.DistStores != nil && d.DistStores.FineTune != nil {
|
||||
ftService.SetFineTuneStore(d.DistStores.FineTune)
|
||||
}
|
||||
}
|
||||
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
|
||||
|
||||
// Quantization routes
|
||||
quantizationMw := auth.RequireFeature(application.AuthDB(), auth.FeatureQuantization)
|
||||
// In distributed mode pass the shared NATS client + PostgreSQL store so
|
||||
// quantization jobs stay consistent across replicas (the SyncedMap broadcasts
|
||||
// mutations and hydrates from the DB); standalone passes nil for both.
|
||||
var quantNats messaging.MessagingClient
|
||||
var quantStore *distributed.QuantStore
|
||||
if d := application.Distributed(); d != nil {
|
||||
quantNats = d.Nats
|
||||
if d.DistStores != nil && d.DistStores.Quant != nil {
|
||||
quantStore = d.DistStores.Quant
|
||||
}
|
||||
}
|
||||
qService := quantization.NewQuantizationService(
|
||||
application.ApplicationConfig(),
|
||||
application.ModelLoader(),
|
||||
application.ModelConfigLoader(),
|
||||
quantNats,
|
||||
quantStore,
|
||||
)
|
||||
routes.RegisterQuantizationRoutes(e, qService, application.ApplicationConfig(), quantizationMw)
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
@@ -550,12 +551,23 @@ func DeleteBackendOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.HandlerF
|
||||
}
|
||||
|
||||
// ListBackendsOnNodeEndpoint lists installed backends on a worker node via NATS.
|
||||
func ListBackendsOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.HandlerFunc {
|
||||
func ListBackendsOnNodeEndpoint(unloader nodes.NodeCommandSender, registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
nodeID := c.Param("id")
|
||||
// Agent-type workers don't run backends and never subscribe to the
|
||||
// nodes.<id>.backend.list NATS subject, so the request would hang
|
||||
// until timeout with "no responders". Their backend list is simply
|
||||
// empty. Mirror the aggregate-list guard in managers_distributed.go
|
||||
// (skip nodes whose NodeType is set and not "backend") so the
|
||||
// single-node and cluster-wide views stay consistent.
|
||||
if node, err := registry.Get(c.Request().Context(), nodeID); err == nil {
|
||||
if node.NodeType != "" && node.NodeType != nodes.NodeTypeBackend {
|
||||
return c.JSON(http.StatusOK, []messaging.NodeBackendInfo{})
|
||||
}
|
||||
}
|
||||
if unloader == nil {
|
||||
return c.JSON(http.StatusServiceUnavailable, nodeError(http.StatusServiceUnavailable, "NATS not configured"))
|
||||
}
|
||||
nodeID := c.Param("id")
|
||||
reply, err := unloader.ListBackends(nodeID)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to list backends on node", "node", nodeID, "error", err)
|
||||
|
||||
103
core/http/endpoints/localai/nodes_backends_list_test.go
Normal file
103
core/http/endpoints/localai/nodes_backends_list_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// stubNodeCommandSender records whether ListBackends was invoked so the test can
|
||||
// assert the endpoint short-circuits (no NATS request) for agent-type nodes.
|
||||
type stubNodeCommandSender struct {
|
||||
listBackendsCalled bool
|
||||
}
|
||||
|
||||
func (s *stubNodeCommandSender) InstallBackend(_, _, _, _, _, _, _ string, _ int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendInstallReply, error) {
|
||||
return &messaging.BackendInstallReply{}, nil
|
||||
}
|
||||
|
||||
func (s *stubNodeCommandSender) UpgradeBackend(_, _, _, _, _, _ string, _ int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendUpgradeReply, error) {
|
||||
return &messaging.BackendUpgradeReply{}, nil
|
||||
}
|
||||
|
||||
func (s *stubNodeCommandSender) DeleteBackend(_, _ string) (*messaging.BackendDeleteReply, error) {
|
||||
return &messaging.BackendDeleteReply{Success: true}, nil
|
||||
}
|
||||
|
||||
func (s *stubNodeCommandSender) ListBackends(_ string) (*messaging.BackendListReply, error) {
|
||||
s.listBackendsCalled = true
|
||||
return &messaging.BackendListReply{Backends: []messaging.NodeBackendInfo{{Name: "llama-cpp"}}}, nil
|
||||
}
|
||||
|
||||
func (s *stubNodeCommandSender) StopBackend(_, _ string) error { return nil }
|
||||
|
||||
func (s *stubNodeCommandSender) UnloadModelOnNode(_, _ string) error { return nil }
|
||||
|
||||
var _ = Describe("ListBackendsOnNodeEndpoint", func() {
|
||||
var registry *nodes.NodeRegistry
|
||||
|
||||
BeforeEach(func() {
|
||||
db := testutil.SetupTestDB()
|
||||
var err error
|
||||
registry, err = nodes.NewNodeRegistry(db)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
callEndpoint := func(unloader nodes.NodeCommandSender, nodeID string) *httptest.ResponseRecorder {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.SetParamNames("id")
|
||||
c.SetParamValues(nodeID)
|
||||
handler := ListBackendsOnNodeEndpoint(unloader, registry)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
return rec
|
||||
}
|
||||
|
||||
It("returns an empty list for an agent node without issuing a NATS request", func() {
|
||||
ctx := context.Background()
|
||||
node := &nodes.BackendNode{Name: "agent-1", NodeType: nodes.NodeTypeAgent}
|
||||
Expect(registry.Register(ctx, node, true)).To(Succeed())
|
||||
|
||||
stub := &stubNodeCommandSender{}
|
||||
rec := callEndpoint(stub, node.ID)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(stub.listBackendsCalled).To(BeFalse(),
|
||||
"agent workers don't subscribe to backend.list; the endpoint must not issue the doomed NATS request")
|
||||
|
||||
var list []messaging.NodeBackendInfo
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &list)).To(Succeed())
|
||||
Expect(list).To(BeEmpty())
|
||||
// Must be `[]`, not `null`, so the UI can render it.
|
||||
Expect(rec.Body.String()).To(ContainSubstring("[]"))
|
||||
})
|
||||
|
||||
It("consults the unloader (NATS) for a backend node", func() {
|
||||
ctx := context.Background()
|
||||
node := &nodes.BackendNode{Name: "backend-1", NodeType: nodes.NodeTypeBackend, Address: "10.0.0.1:50051"}
|
||||
Expect(registry.Register(ctx, node, true)).To(Succeed())
|
||||
|
||||
stub := &stubNodeCommandSender{}
|
||||
rec := callEndpoint(stub, node.ID)
|
||||
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(stub.listBackendsCalled).To(BeTrue(),
|
||||
"backend nodes must still be queried over NATS")
|
||||
|
||||
var list []messaging.NodeBackendInfo
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &list)).To(Succeed())
|
||||
Expect(list).To(HaveLen(1))
|
||||
Expect(list[0].Name).To(Equal("llama-cpp"))
|
||||
})
|
||||
})
|
||||
@@ -618,6 +618,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
finishReason = FinishReasonToolCalls
|
||||
} else if toolsCalled {
|
||||
finishReason = FinishReasonFunctionCall
|
||||
} else if reachedTokenBudget(finalUsage.Completion, config.Maxtokens) {
|
||||
// Generation stopped because it hit the max_tokens ceiling
|
||||
// rather than a natural stop — report "length" (issue #9716).
|
||||
finishReason = FinishReasonLength
|
||||
}
|
||||
|
||||
// Final delta chunk: empty delta with finish_reason set. Per
|
||||
@@ -984,6 +988,18 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
}
|
||||
|
||||
// If generation hit the max_tokens ceiling, report "length"
|
||||
// instead of a natural "stop" (issue #9716). Mirrors the
|
||||
// streaming path; tool/function finish reasons are untouched.
|
||||
if reachedTokenBudget(tokenUsage.Completion, config.Maxtokens) {
|
||||
for i := range result {
|
||||
if result[i].FinishReason != nil && *result[i].FinishReason == FinishReasonStop {
|
||||
lengthReason := FinishReasonLength
|
||||
result[i].FinishReason = &lengthReason
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No MCP tools to execute (or no MCP tools configured), return response
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
|
||||
149
core/http/endpoints/openai/compactcoord/compactcoord.go
Normal file
149
core/http/endpoints/openai/compactcoord/compactcoord.go
Normal file
@@ -0,0 +1,149 @@
|
||||
// Package compactcoord is the explicit state machine for the realtime API's
|
||||
// conversation-compaction concern (machine "M4" in
|
||||
// docs/design/realtime-state-machines.md).
|
||||
//
|
||||
// In the legacy code this machine is an implicit single-flight guard: a
|
||||
// per-conversation `compacting atomic.Bool` that maybeCompact CAS-flips to start
|
||||
// a background summarize+evict and a deferred Store(false) clears. The intent —
|
||||
// at most one compaction running per conversation at a time, so two goroutines
|
||||
// never summarize and evict the same overflow concurrently (Part 4, invariant
|
||||
// #9) — is correct but implicit in a bare atomic.
|
||||
//
|
||||
// This package makes it explicit:
|
||||
// - a sealed sum type for State (Idle | Running) — "two compactions running" is
|
||||
// unrepresentable,
|
||||
// - a total, pure transition function Next(state, event) -> (state, effects),
|
||||
// - a single-writer Coordinator that serializes every transition.
|
||||
//
|
||||
// Unlike respcoord (M3), a Trigger while Running is NOT a supersede: compaction
|
||||
// is idempotent work on the same overflow, so a concurrent trigger is simply
|
||||
// dropped (matching the legacy CAS-fails-so-skip), not queued or restarted.
|
||||
package compactcoord
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/coordinator"
|
||||
)
|
||||
|
||||
// State is the sealed sum type of compaction states. Exhaustively:
|
||||
// Idle | Running | Terminated.
|
||||
type State interface {
|
||||
isState()
|
||||
String() string
|
||||
}
|
||||
|
||||
// Idle: no compaction is running.
|
||||
type Idle struct{}
|
||||
|
||||
// Running: exactly one compaction is in flight.
|
||||
type Running struct{}
|
||||
|
||||
// Terminated: the conversation/session is torn down. Absorbing — no compaction
|
||||
// can start from here, so the M1 (connection) parent's teardown can cancel +
|
||||
// join the in-flight compaction and guarantee none outlives the session (see
|
||||
// formal-verification/session_lifecycle.fizz). This closes the legacy gap where
|
||||
// the fire-and-forget compaction goroutine could outlive the session.
|
||||
type Terminated struct{}
|
||||
|
||||
func (Idle) isState() {}
|
||||
func (Running) isState() {}
|
||||
func (Terminated) isState() {}
|
||||
|
||||
func (Idle) String() string { return "Idle" }
|
||||
func (Running) String() string { return "Running" }
|
||||
func (Terminated) String() string { return "Terminated" }
|
||||
|
||||
// Event is the sealed sum type of inputs. Exhaustively:
|
||||
// Trigger | Finished | Shutdown.
|
||||
type Event interface {
|
||||
isEvent()
|
||||
String() string
|
||||
}
|
||||
|
||||
// Trigger requests a compaction (the live buffer grew past the trigger). It
|
||||
// starts one only when Idle; while Running it is a no-op (single-flight).
|
||||
type Trigger struct{}
|
||||
|
||||
// Finished reports that the running compaction goroutine finished (success, error, or
|
||||
// timeout — it always reports Finished so the flag can never stick).
|
||||
type Finished struct{}
|
||||
|
||||
// Shutdown terminates the coordinator at teardown: the in-flight compaction is
|
||||
// cancelled + joined by the sink, and no compaction can start afterwards.
|
||||
type Shutdown struct{}
|
||||
|
||||
func (Trigger) isEvent() {}
|
||||
func (Finished) isEvent() {}
|
||||
func (Shutdown) isEvent() {}
|
||||
|
||||
func (Trigger) String() string { return "Trigger" }
|
||||
func (Finished) String() string { return "Finished" }
|
||||
func (Shutdown) String() string { return "Shutdown" }
|
||||
|
||||
// Effect is a side effect returned by Next as data. Exhaustively: StartCompaction.
|
||||
type Effect interface {
|
||||
isEffect()
|
||||
String() string
|
||||
}
|
||||
|
||||
// StartCompaction: spawn the background summarize+evict goroutine.
|
||||
type StartCompaction struct{}
|
||||
|
||||
func (StartCompaction) isEffect() {}
|
||||
|
||||
func (StartCompaction) String() string { return "StartCompaction" }
|
||||
|
||||
// Next is the total, pure transition function. For every (state, event) it
|
||||
// returns the next state and the ordered effects. It returns a non-nil error
|
||||
// only for an unknown State/Event implementation. Every in-domain pair is
|
||||
// defined; there are no forbidden transitions, only no-ops.
|
||||
//
|
||||
// Single-flight crux: StartCompaction is emitted only on Idle+Trigger, and a
|
||||
// Trigger while Running is a no-op — so at most one compaction ever runs.
|
||||
func Next(s State, e Event) (State, []Effect, error) {
|
||||
switch s.(type) {
|
||||
case Idle:
|
||||
switch e.(type) {
|
||||
case Trigger:
|
||||
return Running{}, []Effect{StartCompaction{}}, nil
|
||||
case Finished:
|
||||
// No compaction to finish: stale/idempotent no-op.
|
||||
return Idle{}, nil, nil
|
||||
case Shutdown:
|
||||
return Terminated{}, nil, nil
|
||||
}
|
||||
case Running:
|
||||
switch e.(type) {
|
||||
case Trigger:
|
||||
// Already compacting: drop (single-flight).
|
||||
return Running{}, nil, nil
|
||||
case Finished:
|
||||
return Idle{}, nil, nil
|
||||
case Shutdown:
|
||||
// Teardown while compacting: the sink cancels + joins the goroutine,
|
||||
// so its later Finished is absorbed here in Terminated.
|
||||
return Terminated{}, nil, nil
|
||||
}
|
||||
case Terminated:
|
||||
// Absorbing: a Trigger after teardown is rejected (no StartCompaction), so
|
||||
// no compaction outlives the session.
|
||||
switch e.(type) {
|
||||
case Trigger, Finished, Shutdown:
|
||||
return Terminated{}, nil, nil
|
||||
}
|
||||
}
|
||||
return s, nil, fmt.Errorf("compactcoord: unhandled transition %s <- %s", s, e)
|
||||
}
|
||||
|
||||
// EffectSink performs the effects produced by a transition. See coordinator.Sink:
|
||||
// StartCompaction spawns a goroutine, so Perform does not block under the lock.
|
||||
type EffectSink = coordinator.Sink[Effect]
|
||||
|
||||
// Coordinator serializes the compaction transitions. See coordinator.Coordinator.
|
||||
type Coordinator = coordinator.Coordinator[State, Event, Effect]
|
||||
|
||||
// New returns an idle Coordinator that performs effects via sink.
|
||||
func New(sink EffectSink) *Coordinator {
|
||||
return coordinator.New[State, Event, Effect](Idle{}, Next, sink)
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package compactcoord
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestCompactcoord(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "compactcoord (realtime M4) Suite")
|
||||
}
|
||||
202
core/http/endpoints/openai/compactcoord/compactcoord_test.go
Normal file
202
core/http/endpoints/openai/compactcoord/compactcoord_test.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package compactcoord
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// recordingSink captures the ordered stream of effects. Perform is called under
|
||||
// the coordinator lock; the mutex here guards reads from the spec goroutine.
|
||||
type recordingSink struct {
|
||||
mu sync.Mutex
|
||||
log []Effect
|
||||
}
|
||||
|
||||
func (s *recordingSink) Perform(e Effect) {
|
||||
s.mu.Lock()
|
||||
s.log = append(s.log, e)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *recordingSink) count() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return len(s.log)
|
||||
}
|
||||
|
||||
type unknownEvent struct{}
|
||||
|
||||
func (unknownEvent) isEvent() {}
|
||||
func (unknownEvent) String() string { return "unknownEvent" }
|
||||
|
||||
type unknownState struct{}
|
||||
|
||||
func (unknownState) isState() {}
|
||||
func (unknownState) String() string { return "unknownState" }
|
||||
|
||||
var _ = Describe("compactcoord.Next", func() {
|
||||
DescribeTable("transitions",
|
||||
func(state State, event Event, wantState State, wantEff []Effect) {
|
||||
gotState, gotEff, err := Next(state, event)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(gotState).To(Equal(wantState))
|
||||
Expect(gotEff).To(Equal(wantEff))
|
||||
},
|
||||
Entry("idle+trigger -> running: start",
|
||||
Idle{}, Trigger{}, Running{}, []Effect{StartCompaction{}}),
|
||||
Entry("idle+finished -> idle, no-op (stale)",
|
||||
Idle{}, Finished{}, Idle{}, []Effect(nil)),
|
||||
Entry("running+trigger -> running, no-op (single-flight)",
|
||||
Running{}, Trigger{}, Running{}, []Effect(nil)),
|
||||
Entry("running+finished -> idle",
|
||||
Running{}, Finished{}, Idle{}, []Effect(nil)),
|
||||
Entry("idle+shutdown -> terminated",
|
||||
Idle{}, Shutdown{}, Terminated{}, []Effect(nil)),
|
||||
Entry("running+shutdown -> terminated",
|
||||
Running{}, Shutdown{}, Terminated{}, []Effect(nil)),
|
||||
Entry("terminated+trigger -> terminated, REJECTED",
|
||||
Terminated{}, Trigger{}, Terminated{}, []Effect(nil)),
|
||||
Entry("terminated+finished -> terminated, no-op (stale)",
|
||||
Terminated{}, Finished{}, Terminated{}, []Effect(nil)),
|
||||
Entry("terminated+shutdown -> terminated, idempotent",
|
||||
Terminated{}, Shutdown{}, Terminated{}, []Effect(nil)),
|
||||
)
|
||||
|
||||
It("is total over the defined (state, event) pairs", func() {
|
||||
for _, s := range []State{Idle{}, Running{}, Terminated{}} {
|
||||
for _, e := range []Event{Trigger{}, Finished{}, Shutdown{}} {
|
||||
_, _, err := Next(s, e)
|
||||
Expect(err).NotTo(HaveOccurred(), "Next(%s, %s)", s, e)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("errors on an unknown event type", func() {
|
||||
_, _, err := Next(Idle{}, unknownEvent{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors on an unknown state type", func() {
|
||||
_, _, err := Next(unknownState{}, Trigger{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("compactcoord.Coordinator", func() {
|
||||
// A StartCompaction is only ever produced while Idle (verified by checking the
|
||||
// effect count grows exactly when the model transitions Idle->Running), so at
|
||||
// most one compaction is ever in flight.
|
||||
It("starts at most one compaction at a time over random sequences", func() {
|
||||
seeds := []uint64{1, 2, 3, 42, 1337, 0xC0FFEE}
|
||||
for _, seed := range seeds {
|
||||
r := rand.New(rand.NewPCG(seed, 0xA5A5A5A5))
|
||||
sink := &recordingSink{}
|
||||
c := New(sink)
|
||||
running := false
|
||||
starts := 0
|
||||
|
||||
for range 5000 {
|
||||
if r.IntN(2) == 0 {
|
||||
before := sink.count()
|
||||
Expect(c.Apply(Trigger{})).To(Succeed())
|
||||
if sink.count() > before {
|
||||
// A StartCompaction was produced: must have been Idle.
|
||||
Expect(running).To(BeFalse(), "seed=%d: started while already running", seed)
|
||||
running = true
|
||||
starts++
|
||||
}
|
||||
} else {
|
||||
Expect(c.Apply(Finished{})).To(Succeed())
|
||||
running = false
|
||||
}
|
||||
if running {
|
||||
Expect(c.State()).To(Equal(State(Running{})), "seed=%d", seed)
|
||||
} else {
|
||||
Expect(c.State()).To(Equal(State(Idle{})), "seed=%d", seed)
|
||||
}
|
||||
}
|
||||
Expect(starts).To(BeNumerically(">", 0), "seed=%d: walk should have started at least one", seed)
|
||||
}
|
||||
})
|
||||
|
||||
// Faithful concurrent test: StartCompaction spawns "work" that bumps an active
|
||||
// counter, runs, and reports Finished back to the coordinator (exactly how the
|
||||
// real sink behaves). Single-flight must hold even under many concurrent
|
||||
// Triggers: the active counter never exceeds 1. Run under -race.
|
||||
It("never runs two compactions concurrently", func() {
|
||||
var active, maxActive int32
|
||||
var c *Coordinator
|
||||
var work sync.WaitGroup
|
||||
sink := &spawnSink{onStart: func() {
|
||||
work.Add(1)
|
||||
go func() {
|
||||
defer work.Done()
|
||||
n := atomic.AddInt32(&active, 1)
|
||||
for {
|
||||
m := atomic.LoadInt32(&maxActive)
|
||||
if n <= m || atomic.CompareAndSwapInt32(&maxActive, m, n) {
|
||||
break
|
||||
}
|
||||
}
|
||||
atomic.AddInt32(&active, -1)
|
||||
_ = c.Apply(Finished{})
|
||||
}()
|
||||
}}
|
||||
c = New(sink)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for g := 0; g < 8; g++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range 1000 {
|
||||
_ = c.Apply(Trigger{})
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
work.Wait() // let any in-flight compaction report Finished
|
||||
|
||||
Expect(atomic.LoadInt32(&maxActive)).To(BeNumerically("<=", 1))
|
||||
Expect(c.State()).To(Equal(State(Idle{})))
|
||||
})
|
||||
|
||||
It("terminates on shutdown and rejects later triggers", func() {
|
||||
sink := &recordingSink{}
|
||||
c := New(sink)
|
||||
Expect(c.Apply(Trigger{})).To(Succeed()) // Idle -> Running (StartCompaction)
|
||||
Expect(c.Apply(Shutdown{})).To(Succeed())
|
||||
Expect(c.State()).To(Equal(State(Terminated{})))
|
||||
|
||||
before := sink.count()
|
||||
Expect(c.Apply(Trigger{})).To(Succeed()) // rejected
|
||||
Expect(sink.count()).To(Equal(before), "no StartCompaction after shutdown")
|
||||
Expect(c.Apply(Finished{})).To(Succeed()) // stale, absorbed
|
||||
Expect(c.State()).To(Equal(State(Terminated{})))
|
||||
})
|
||||
})
|
||||
|
||||
// spawnSink invokes onStart for each StartCompaction (called under the coord lock;
|
||||
// onStart must be non-blocking — it spawns the work goroutine).
|
||||
type spawnSink struct{ onStart func() }
|
||||
|
||||
func (s *spawnSink) Perform(e Effect) {
|
||||
if _, ok := e.(StartCompaction); ok {
|
||||
s.onStart()
|
||||
}
|
||||
}
|
||||
|
||||
var _ = DescribeTable("compactcoord stringers",
|
||||
func(got, want string) { Expect(got).To(Equal(want)) },
|
||||
Entry(nil, Idle{}.String(), "Idle"),
|
||||
Entry(nil, Running{}.String(), "Running"),
|
||||
Entry(nil, Terminated{}.String(), "Terminated"),
|
||||
Entry(nil, Trigger{}.String(), "Trigger"),
|
||||
Entry(nil, Finished{}.String(), "Finished"),
|
||||
Entry(nil, Shutdown{}.String(), "Shutdown"),
|
||||
Entry(nil, StartCompaction{}.String(), "StartCompaction"),
|
||||
)
|
||||
164
core/http/endpoints/openai/conncoord/conncoord.go
Normal file
164
core/http/endpoints/openai/conncoord/conncoord.go
Normal file
@@ -0,0 +1,164 @@
|
||||
// Package conncoord is the explicit state machine for the realtime API's
|
||||
// connection lifecycle (machine "M1" in docs/design/realtime-state-machines.md).
|
||||
//
|
||||
// In the legacy code this machine is implicit and fragile. The session handler
|
||||
// keeps a `vadServerStarted` bool plus a `done` channel that is REASSIGNED to a
|
||||
// fresh channel every time turn detection is toggled on (session.update) and
|
||||
// closed both at toggle-off and at teardown (Part 2, failure mode 6). It is
|
||||
// correct today only because one goroutine owns it; "one variable name meaning
|
||||
// different channels over time, closed from two sites guarded by a bool" is a
|
||||
// structural hazard, not an explicit lifecycle. Teardown likewise depends on the
|
||||
// bool to avoid closing an already-closed channel.
|
||||
//
|
||||
// This package makes the lifecycle explicit:
|
||||
// - a sealed sum type for State (Live{VADRunning} | Torn) — illegal states
|
||||
// such as "running after teardown" are unrepresentable,
|
||||
// - a total, pure transition function Next(state, event) -> (state, effects),
|
||||
// - a single-writer Coordinator that serializes every transition.
|
||||
//
|
||||
// The guarantees the spec checks:
|
||||
// - the VAD goroutine's done channel is closed exactly once per start (StopVAD
|
||||
// is emitted only while running, so never a double close / close of nil),
|
||||
// - teardown runs exactly once (Close from Live; any later Close is a no-op),
|
||||
// - nothing is started after teardown (no resurrection / no send-after-close).
|
||||
//
|
||||
// Like turncoord (M2), the connection machine is driven by the single session
|
||||
// goroutine; the Coordinator's lock keeps State() race-free and guards against a
|
||||
// future second writer. The effects are performed by a sink that owns the actual
|
||||
// channels/goroutines (see realtime_conncoord.go).
|
||||
package conncoord
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/coordinator"
|
||||
)
|
||||
|
||||
// State is the sealed sum type of connection states. The only implementations
|
||||
// are the marker-method structs in this file. Exhaustively: Live | Torn.
|
||||
type State interface {
|
||||
isState()
|
||||
String() string
|
||||
}
|
||||
|
||||
// Live: the session is active. VADRunning records whether the turn-detection
|
||||
// (handleVAD) goroutine is currently running — the single source of truth that
|
||||
// replaces the legacy vadServerStarted bool, so the per-run done channel is
|
||||
// closed exactly once.
|
||||
type Live struct{ VADRunning bool }
|
||||
|
||||
// Torn: the session has been torn down. Terminal — no effect is ever produced
|
||||
// from here again.
|
||||
type Torn struct{}
|
||||
|
||||
func (Live) isState() {}
|
||||
func (Torn) isState() {}
|
||||
|
||||
func (s Live) String() string { return fmt.Sprintf("Live(vad=%t)", s.VADRunning) }
|
||||
func (Torn) String() string { return "Torn" }
|
||||
|
||||
// Event is the sealed sum type of inputs. Exhaustively: SetVAD | Close.
|
||||
type Event interface {
|
||||
isEvent()
|
||||
String() string
|
||||
}
|
||||
|
||||
// SetVAD requests the turn-detection goroutine be running (Active) or not. It is
|
||||
// raised whenever session.update changes whether turn detection is active. It is
|
||||
// idempotent: setting the state it is already in is a no-op.
|
||||
type SetVAD struct{ Active bool }
|
||||
|
||||
// Close requests teardown (the transport read loop ended, or the session is
|
||||
// closing). It is idempotent — only the first Close from Live tears down.
|
||||
type Close struct{}
|
||||
|
||||
func (SetVAD) isEvent() {}
|
||||
func (Close) isEvent() {}
|
||||
|
||||
func (e SetVAD) String() string { return fmt.Sprintf("SetVAD(%t)", e.Active) }
|
||||
func (Close) String() string { return "Close" }
|
||||
|
||||
// Effect is a side effect returned by Next as data for the caller to perform.
|
||||
// Exhaustively: StartVAD | StopVAD | Teardown.
|
||||
type Effect interface {
|
||||
isEffect()
|
||||
String() string
|
||||
}
|
||||
|
||||
// StartVAD: create a fresh done channel and spawn the handleVAD goroutine on it.
|
||||
type StartVAD struct{}
|
||||
|
||||
// StopVAD: close the running VAD goroutine's done channel (signal it to exit).
|
||||
type StopVAD struct{}
|
||||
|
||||
// Teardown: the once-only teardown — stop the remaining input goroutines (opus
|
||||
// decode, sound window), join them, cancel in-flight responses, and remove the
|
||||
// session from the registry. Emitted exactly once.
|
||||
type Teardown struct{}
|
||||
|
||||
func (StartVAD) isEffect() {}
|
||||
func (StopVAD) isEffect() {}
|
||||
func (Teardown) isEffect() {}
|
||||
|
||||
func (StartVAD) String() string { return "StartVAD" }
|
||||
func (StopVAD) String() string { return "StopVAD" }
|
||||
func (Teardown) String() string { return "Teardown" }
|
||||
|
||||
// Next is the total, pure transition function. For every (state, event) it
|
||||
// returns the next state and the ordered effects to perform. It returns a
|
||||
// non-nil error only for an unknown State/Event implementation. Every in-domain
|
||||
// pair is defined; there are no forbidden transitions, only no-ops.
|
||||
//
|
||||
// The crux: Close moves to Torn, which absorbs every later event with no
|
||||
// effects. So teardown's channel closes happen exactly once even if Close is
|
||||
// raised again (e.g. an error path and the normal return both reaching it), and
|
||||
// no StartVAD can resurrect a torn session.
|
||||
func Next(s State, e Event) (State, []Effect, error) {
|
||||
switch st := s.(type) {
|
||||
case Live:
|
||||
switch ev := e.(type) {
|
||||
case SetVAD:
|
||||
switch {
|
||||
case ev.Active && !st.VADRunning:
|
||||
return Live{VADRunning: true}, []Effect{StartVAD{}}, nil
|
||||
case !ev.Active && st.VADRunning:
|
||||
return Live{VADRunning: false}, []Effect{StopVAD{}}, nil
|
||||
default:
|
||||
// Already in the requested state: idempotent no-op.
|
||||
return Live{VADRunning: st.VADRunning}, nil, nil
|
||||
}
|
||||
case Close:
|
||||
if st.VADRunning {
|
||||
return Torn{}, []Effect{StopVAD{}, Teardown{}}, nil
|
||||
}
|
||||
return Torn{}, []Effect{Teardown{}}, nil
|
||||
}
|
||||
case Torn:
|
||||
switch e.(type) {
|
||||
case SetVAD:
|
||||
// No resurrection: a toggle after teardown is ignored.
|
||||
return Torn{}, nil, nil
|
||||
case Close:
|
||||
// Idempotent: teardown already ran.
|
||||
return Torn{}, nil, nil
|
||||
}
|
||||
}
|
||||
return s, nil, fmt.Errorf("conncoord: unhandled transition %s <- %s", s, e)
|
||||
}
|
||||
|
||||
// EffectSink performs the effects produced by a transition. See coordinator.Sink:
|
||||
// Perform runs under the coordinator lock. The Teardown effect does join
|
||||
// goroutines (which can block) — acceptable here because the connection
|
||||
// coordinator is single-writer and torn down exactly once at the end of the
|
||||
// session goroutine, so no other Apply is contending the lock.
|
||||
type EffectSink = coordinator.Sink[Effect]
|
||||
|
||||
// Coordinator serializes the connection-lifecycle transitions.
|
||||
// See coordinator.Coordinator.
|
||||
type Coordinator = coordinator.Coordinator[State, Event, Effect]
|
||||
|
||||
// New returns a Coordinator in Live{VADRunning:false} that performs effects via
|
||||
// sink.
|
||||
func New(sink EffectSink) *Coordinator {
|
||||
return coordinator.New[State, Event, Effect](Live{VADRunning: false}, Next, sink)
|
||||
}
|
||||
13
core/http/endpoints/openai/conncoord/conncoord_suite_test.go
Normal file
13
core/http/endpoints/openai/conncoord/conncoord_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package conncoord
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestConncoord(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "conncoord (realtime M1) Suite")
|
||||
}
|
||||
212
core/http/endpoints/openai/conncoord/conncoord_test.go
Normal file
212
core/http/endpoints/openai/conncoord/conncoord_test.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package conncoord
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// recordingSink captures the ordered stream of effects so the invariants can be
|
||||
// checked independently of the transition function. Perform is called by
|
||||
// Coordinator.Apply under the coordinator lock; the mutex here only guards reads
|
||||
// from the spec goroutine.
|
||||
type recordingSink struct {
|
||||
mu sync.Mutex
|
||||
log []Effect
|
||||
}
|
||||
|
||||
func (s *recordingSink) Perform(e Effect) {
|
||||
s.mu.Lock()
|
||||
s.log = append(s.log, e)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *recordingSink) snapshot() []Effect {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]Effect, len(s.log))
|
||||
copy(out, s.log)
|
||||
return out
|
||||
}
|
||||
|
||||
// checkLog replays the effect log and asserts the lifecycle safety properties
|
||||
// from docs/design/realtime-state-machines.md, Part 4 (invariants #8, #10 and
|
||||
// failure mode 6):
|
||||
//
|
||||
// (1) the VAD done channel is closed exactly once per start -- StartVAD only
|
||||
// while stopped, StopVAD only while running (no double close / close-of-nil);
|
||||
// (2) teardown runs at most once;
|
||||
// (3) no resurrection -- no StartVAD after Teardown.
|
||||
func checkLog(log []Effect) {
|
||||
running := false
|
||||
torn := false
|
||||
teardowns := 0
|
||||
for i, eff := range log {
|
||||
switch eff.(type) {
|
||||
case StartVAD:
|
||||
Expect(torn).To(BeFalse(), "invariant (3): StartVAD after teardown (effect #%d)\nlog=%v", i, log)
|
||||
Expect(running).To(BeFalse(), "invariant (1): StartVAD while already running (effect #%d)\nlog=%v", i, log)
|
||||
running = true
|
||||
case StopVAD:
|
||||
Expect(running).To(BeTrue(), "invariant (1): StopVAD while not running (effect #%d)\nlog=%v", i, log)
|
||||
running = false
|
||||
case Teardown:
|
||||
Expect(torn).To(BeFalse(), "invariant (2): Teardown twice (effect #%d)\nlog=%v", i, log)
|
||||
torn = true
|
||||
teardowns++
|
||||
}
|
||||
}
|
||||
Expect(teardowns).To(BeNumerically("<=", 1), "invariant (2): teardown ran %d times\nlog=%v", teardowns, log)
|
||||
}
|
||||
|
||||
type unknownEvent struct{}
|
||||
|
||||
func (unknownEvent) isEvent() {}
|
||||
func (unknownEvent) String() string { return "unknownEvent" }
|
||||
|
||||
type unknownState struct{}
|
||||
|
||||
func (unknownState) isState() {}
|
||||
func (unknownState) String() string { return "unknownState" }
|
||||
|
||||
var _ = Describe("conncoord.Next", func() {
|
||||
DescribeTable("transitions",
|
||||
func(state State, event Event, wantState State, wantEff []Effect) {
|
||||
gotState, gotEff, err := Next(state, event)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(gotState).To(Equal(wantState))
|
||||
Expect(gotEff).To(Equal(wantEff))
|
||||
},
|
||||
Entry("stopped+setvad(on) -> running: start",
|
||||
Live{VADRunning: false}, SetVAD{Active: true},
|
||||
Live{VADRunning: true}, []Effect{StartVAD{}}),
|
||||
Entry("running+setvad(on) -> running, no-op",
|
||||
Live{VADRunning: true}, SetVAD{Active: true},
|
||||
Live{VADRunning: true}, []Effect(nil)),
|
||||
Entry("stopped+setvad(off) -> stopped, no-op",
|
||||
Live{VADRunning: false}, SetVAD{Active: false},
|
||||
Live{VADRunning: false}, []Effect(nil)),
|
||||
Entry("running+setvad(off) -> stopped: stop",
|
||||
Live{VADRunning: true}, SetVAD{Active: false},
|
||||
Live{VADRunning: false}, []Effect{StopVAD{}}),
|
||||
Entry("stopped+close -> torn: teardown",
|
||||
Live{VADRunning: false}, Close{},
|
||||
Torn{}, []Effect{Teardown{}}),
|
||||
Entry("running+close -> torn: stop + teardown",
|
||||
Live{VADRunning: true}, Close{},
|
||||
Torn{}, []Effect{StopVAD{}, Teardown{}}),
|
||||
Entry("torn+setvad(on) -> torn, no-op (no resurrection)",
|
||||
Torn{}, SetVAD{Active: true},
|
||||
Torn{}, []Effect(nil)),
|
||||
Entry("torn+close -> torn, no-op (idempotent)",
|
||||
Torn{}, Close{},
|
||||
Torn{}, []Effect(nil)),
|
||||
)
|
||||
|
||||
It("is total over the defined (state, event) pairs", func() {
|
||||
states := []State{Live{VADRunning: false}, Live{VADRunning: true}, Torn{}}
|
||||
events := []Event{SetVAD{Active: true}, SetVAD{Active: false}, Close{}}
|
||||
for _, s := range states {
|
||||
for _, e := range events {
|
||||
_, _, err := Next(s, e)
|
||||
Expect(err).NotTo(HaveOccurred(), "Next(%s, %s)", s, e)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("errors on an unknown event type", func() {
|
||||
_, _, err := Next(Live{}, unknownEvent{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors on an unknown state type", func() {
|
||||
_, _, err := Next(unknownState{}, Close{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("conncoord.Coordinator", func() {
|
||||
It("upholds the lifecycle invariants over random event sequences", func() {
|
||||
seeds := []uint64{1, 2, 3, 42, 1337, 0xC0FFEE}
|
||||
for _, seed := range seeds {
|
||||
r := rand.New(rand.NewPCG(seed, 0xA5A5A5A5))
|
||||
sink := &recordingSink{}
|
||||
c := New(sink)
|
||||
running := false
|
||||
torn := false
|
||||
|
||||
for range 5000 {
|
||||
switch r.IntN(3) {
|
||||
case 0:
|
||||
Expect(c.Apply(SetVAD{Active: true})).To(Succeed())
|
||||
if !torn {
|
||||
running = true
|
||||
}
|
||||
case 1:
|
||||
Expect(c.Apply(SetVAD{Active: false})).To(Succeed())
|
||||
if !torn {
|
||||
running = false
|
||||
}
|
||||
case 2:
|
||||
Expect(c.Apply(Close{})).To(Succeed())
|
||||
torn = true
|
||||
running = false
|
||||
}
|
||||
if torn {
|
||||
Expect(c.State()).To(Equal(State(Torn{})), "seed=%d", seed)
|
||||
} else {
|
||||
Expect(c.State()).To(Equal(State(Live{VADRunning: running})), "seed=%d", seed)
|
||||
}
|
||||
}
|
||||
checkLog(sink.snapshot())
|
||||
}
|
||||
})
|
||||
|
||||
It("tears down at most once under concurrent SetVAD/Close from two goroutines", func() {
|
||||
const perGoroutine = 2000
|
||||
sink := &recordingSink{}
|
||||
c := New(sink)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
drive := func(active bool) {
|
||||
defer wg.Done()
|
||||
for i := range perGoroutine {
|
||||
switch i % 3 {
|
||||
case 0:
|
||||
_ = c.Apply(SetVAD{Active: active})
|
||||
case 1:
|
||||
_ = c.Apply(SetVAD{Active: !active})
|
||||
case 2:
|
||||
if i > perGoroutine/2 {
|
||||
_ = c.Apply(Close{})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(2)
|
||||
go drive(true)
|
||||
go drive(false)
|
||||
wg.Wait()
|
||||
_ = c.Apply(Close{})
|
||||
|
||||
checkLog(sink.snapshot())
|
||||
Expect(c.State()).To(Equal(State(Torn{})))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = DescribeTable("conncoord stringers",
|
||||
func(got, want string) { Expect(got).To(Equal(want)) },
|
||||
Entry(nil, Live{VADRunning: true}.String(), "Live(vad=true)"),
|
||||
Entry(nil, Live{VADRunning: false}.String(), "Live(vad=false)"),
|
||||
Entry(nil, Torn{}.String(), "Torn"),
|
||||
|
||||
Entry(nil, SetVAD{Active: true}.String(), "SetVAD(true)"),
|
||||
Entry(nil, Close{}.String(), "Close"),
|
||||
|
||||
Entry(nil, StartVAD{}.String(), "StartVAD"),
|
||||
Entry(nil, StopVAD{}.String(), "StopVAD"),
|
||||
Entry(nil, Teardown{}.String(), "Teardown"),
|
||||
)
|
||||
@@ -5,4 +5,7 @@ const (
|
||||
FinishReasonStop = "stop"
|
||||
FinishReasonToolCalls = "tool_calls"
|
||||
FinishReasonFunctionCall = "function_call"
|
||||
// FinishReasonLength is reported when generation stopped because it
|
||||
// reached the max_tokens budget rather than a natural stop (issue #9716).
|
||||
FinishReasonLength = "length"
|
||||
)
|
||||
|
||||
82
core/http/endpoints/openai/coordinator/coordinator.go
Normal file
82
core/http/endpoints/openai/coordinator/coordinator.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Package coordinator is the shared single-writer state-machine runtime for the
|
||||
// realtime API's explicit coordinators (machines M1–M5 in
|
||||
// docs/design/realtime-state-machines.md).
|
||||
//
|
||||
// Each machine package (respcoord, turncoord, conncoord, compactcoord, ttscoord)
|
||||
// defines its OWN sealed sum types for State/Event/Effect and a total, pure
|
||||
// transition function Next(state, event) -> (state, []effect, error). The
|
||||
// plumbing around that — a single-writer Coordinator that serializes every
|
||||
// transition behind one lock and performs the returned effects in order — is
|
||||
// identical across all five, so it lives here once instead of being copied.
|
||||
//
|
||||
// A machine package wires itself up with three lines:
|
||||
//
|
||||
// type EffectSink = coordinator.Sink[Effect]
|
||||
// type Coordinator = coordinator.Coordinator[State, Event, Effect]
|
||||
// func New(sink EffectSink) *Coordinator { return coordinator.New[State, Event, Effect](Idle{}, Next, sink) }
|
||||
//
|
||||
// The aliases keep each package's public API (Coordinator, New, EffectSink,
|
||||
// Apply, State) unchanged. The single-writer serialization — the load-bearing
|
||||
// concurrency guarantee the FizzBee specs check — is therefore implemented and
|
||||
// reasoned about in exactly one place.
|
||||
package coordinator
|
||||
|
||||
import "sync"
|
||||
|
||||
// TransitionFunc is a machine's total, pure transition: given the current state
|
||||
// and an event it returns the next state, the ordered effects to perform, and a
|
||||
// non-nil error ONLY for an unhandled (programmer-error) state/event pair. It
|
||||
// must not perform I/O or block; side effects are returned as data (F) for the
|
||||
// Coordinator to hand to the Sink.
|
||||
type TransitionFunc[S, E, F any] func(state S, event E) (S, []F, error)
|
||||
|
||||
// Sink performs the effects a transition produces. Implementations MUST be
|
||||
// non-blocking: Perform is called while the Coordinator holds its lock, so it
|
||||
// must not block (it should spawn a goroutine, call a cancel func, or do a
|
||||
// non-blocking channel send) and MUST NOT call back into the same Coordinator's
|
||||
// Apply.
|
||||
type Sink[F any] interface {
|
||||
Perform(F)
|
||||
}
|
||||
|
||||
// Coordinator is the single-writer wrapper around a pure transition function.
|
||||
// Every Apply is serialized by mu, so multiple goroutines can drive the machine
|
||||
// without racing, and a transition's effects are performed in order under the
|
||||
// lock (before any subsequent Apply can observe the new state).
|
||||
type Coordinator[S, E, F any] struct {
|
||||
mu sync.Mutex
|
||||
state S
|
||||
next TransitionFunc[S, E, F]
|
||||
sink Sink[F]
|
||||
}
|
||||
|
||||
// New returns a Coordinator in the given initial state that transitions via next
|
||||
// and performs effects via sink.
|
||||
func New[S, E, F any](initial S, next TransitionFunc[S, E, F], sink Sink[F]) *Coordinator[S, E, F] {
|
||||
return &Coordinator[S, E, F]{state: initial, next: next, sink: sink}
|
||||
}
|
||||
|
||||
// Apply runs one transition under the lock and performs its effects in order. If
|
||||
// the transition function returns an error (an unhandled state/event), the state
|
||||
// is left unchanged and the error is returned to the caller — never silently
|
||||
// swallowed.
|
||||
func (c *Coordinator[S, E, F]) Apply(e E) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
ns, effects, err := c.next(c.state, e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.state = ns
|
||||
for _, eff := range effects {
|
||||
c.sink.Perform(eff)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// State returns the current state (a value; safe to call concurrently).
|
||||
func (c *Coordinator[S, E, F]) State() S {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.state
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package coordinator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestCoordinator(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "coordinator (shared runtime) Suite")
|
||||
}
|
||||
124
core/http/endpoints/openai/coordinator/coordinator_test.go
Normal file
124
core/http/endpoints/openai/coordinator/coordinator_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package coordinator
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// A tiny toy machine exercises the generic runtime directly (the five real
|
||||
// machines exercise it via their aliases, but the gate measures this package's
|
||||
// own coverage). off <-toggle-> on; burst emits three ordered effects; boom is
|
||||
// the unhandled/error path.
|
||||
type tstate int
|
||||
|
||||
const (
|
||||
off tstate = iota
|
||||
on
|
||||
)
|
||||
|
||||
type tevent int
|
||||
|
||||
const (
|
||||
toggle tevent = iota
|
||||
burst
|
||||
boom
|
||||
)
|
||||
|
||||
type teffect string
|
||||
|
||||
func tnext(s tstate, e tevent) (tstate, []teffect, error) {
|
||||
switch e {
|
||||
case toggle:
|
||||
if s == off {
|
||||
return on, []teffect{"on"}, nil
|
||||
}
|
||||
return off, []teffect{"off"}, nil
|
||||
case burst:
|
||||
return s, []teffect{"a", "b", "c"}, nil
|
||||
case boom:
|
||||
return s, nil, errors.New("boom: unhandled")
|
||||
}
|
||||
return s, nil, fmt.Errorf("unknown event %d", int(e))
|
||||
}
|
||||
|
||||
type recordingSink struct {
|
||||
mu sync.Mutex
|
||||
log []teffect
|
||||
}
|
||||
|
||||
func (s *recordingSink) Perform(e teffect) {
|
||||
s.mu.Lock()
|
||||
s.log = append(s.log, e)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *recordingSink) snapshot() []teffect {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]teffect, len(s.log))
|
||||
copy(out, s.log)
|
||||
return out
|
||||
}
|
||||
|
||||
var _ = Describe("coordinator.Coordinator", func() {
|
||||
It("starts in the initial state", func() {
|
||||
c := New[tstate, tevent, teffect](off, tnext, &recordingSink{})
|
||||
Expect(c.State()).To(Equal(off))
|
||||
})
|
||||
|
||||
It("advances state and performs the transition's effects", func() {
|
||||
sink := &recordingSink{}
|
||||
c := New[tstate, tevent, teffect](off, tnext, sink)
|
||||
|
||||
Expect(c.Apply(toggle)).To(Succeed())
|
||||
Expect(c.State()).To(Equal(on))
|
||||
Expect(c.Apply(toggle)).To(Succeed())
|
||||
Expect(c.State()).To(Equal(off))
|
||||
|
||||
Expect(sink.snapshot()).To(Equal([]teffect{"on", "off"}))
|
||||
})
|
||||
|
||||
It("performs multiple effects in order", func() {
|
||||
sink := &recordingSink{}
|
||||
c := New[tstate, tevent, teffect](off, tnext, sink)
|
||||
Expect(c.Apply(burst)).To(Succeed())
|
||||
Expect(sink.snapshot()).To(Equal([]teffect{"a", "b", "c"}))
|
||||
})
|
||||
|
||||
It("returns the transition error and leaves state unchanged", func() {
|
||||
sink := &recordingSink{}
|
||||
c := New[tstate, tevent, teffect](on, tnext, sink)
|
||||
err := c.Apply(boom)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(c.State()).To(Equal(on), "state unchanged on error")
|
||||
Expect(sink.snapshot()).To(BeEmpty(), "no effects performed on error")
|
||||
})
|
||||
|
||||
It("serializes concurrent Apply from many goroutines (run with -race)", func() {
|
||||
const goroutines = 8
|
||||
const each = 1000
|
||||
sink := &recordingSink{}
|
||||
c := New[tstate, tevent, teffect](off, tnext, sink)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range each {
|
||||
_ = c.Apply(toggle)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// goroutines*each toggles from off; an even total returns to off. The
|
||||
// point is race-freedom + a consistent final state, not the value itself.
|
||||
Expect(c.State()).To(Equal(off))
|
||||
Expect(sink.snapshot()).To(HaveLen(goroutines * each))
|
||||
})
|
||||
})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user