mirror of
https://github.com/mudler/LocalAI.git
synced 2026-02-26 19:58:58 -05:00
Compare commits
117 Commits
v3.11.0
...
faster-qwe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e169492543 | ||
|
|
51c26f1f39 | ||
|
|
65082b3a6f | ||
|
|
0483d47674 | ||
|
|
8ad40091a6 | ||
|
|
8bfe458fbc | ||
|
|
657ba8cdad | ||
|
|
fb86f6461d | ||
|
|
1027c487a6 | ||
|
|
bb226d1eaa | ||
|
|
b032cf489b | ||
|
|
3ac7301f31 | ||
|
|
c4783a0a05 | ||
|
|
c44f03b882 | ||
|
|
eeec92af78 | ||
|
|
842033b8b5 | ||
|
|
a2941228a7 | ||
|
|
791e6b84ee | ||
|
|
d845c39963 | ||
|
|
1331e23b67 | ||
|
|
36ff2a0138 | ||
|
|
db6ba4ef07 | ||
|
|
d19dcac863 | ||
|
|
fd42675bec | ||
|
|
3391538806 | ||
|
|
c4f879c4ea | ||
|
|
b7e0de54fe | ||
|
|
f0868acdf3 | ||
|
|
9a5b5ee8a9 | ||
|
|
ed0bfb8732 | ||
|
|
be84b1d258 | ||
|
|
cbedcc9091 | ||
|
|
e45d63c86e | ||
|
|
f40c8dd0ce | ||
|
|
559ab99890 | ||
|
|
91f2dd5820 | ||
|
|
8250815763 | ||
|
|
b1b67b973e | ||
|
|
fcecc12e57 | ||
|
|
51902df7ba | ||
|
|
05f3ae31de | ||
|
|
bb0924dff1 | ||
|
|
51eec4e6b8 | ||
|
|
462c82fad2 | ||
|
|
352b8aaa1b | ||
|
|
df792d6243 | ||
|
|
b1c434f0fc | ||
|
|
bb42b342de | ||
|
|
e555057f8b | ||
|
|
76fba02e56 | ||
|
|
dadc7158fb | ||
|
|
68c7077491 | ||
|
|
b471619ad9 | ||
|
|
a0476d5567 | ||
|
|
a2228f1418 | ||
|
|
7dd9a155a3 | ||
|
|
4fe830ff58 | ||
|
|
86b3bc9313 | ||
|
|
2fabdc08e6 | ||
|
|
ed832cf0e0 | ||
|
|
95db1da309 | ||
|
|
9e692967c3 | ||
|
|
ecba23d44e | ||
|
|
067a255435 | ||
|
|
637ecba382 | ||
|
|
46c64e59f5 | ||
|
|
f806838c37 | ||
|
|
074a982853 | ||
|
|
109f29cc24 | ||
|
|
587e4a21b3 | ||
|
|
3f1f58b2ab | ||
|
|
01eb70caff | ||
|
|
d784851337 | ||
|
|
1c4e5aa5c0 | ||
|
|
94df096fb9 | ||
|
|
820bd7dd01 | ||
|
|
42cb7bda19 | ||
|
|
2fb9940b8a | ||
|
|
2ff0ad4190 | ||
|
|
bd12103ed4 | ||
|
|
2e17edd72a | ||
|
|
24aab68b3f | ||
|
|
5bdbb10593 | ||
|
|
2fd026e958 | ||
|
|
08718b656e | ||
|
|
7121b189f7 | ||
|
|
f6c80a6987 | ||
|
|
4a4d65f8e8 | ||
|
|
2858e71606 | ||
|
|
088205339c | ||
|
|
8616397d59 | ||
|
|
1698f92bd0 | ||
|
|
02c95a57ae | ||
|
|
2ab6be1d0c | ||
|
|
9d78ec1bd8 | ||
|
|
b10b85de52 | ||
|
|
1479bee894 | ||
|
|
cff972094c | ||
|
|
79a25f7ae9 | ||
|
|
7270a98ce5 | ||
|
|
0ee92317ec | ||
|
|
743d2d1947 | ||
|
|
df04843f34 | ||
|
|
780877d1d0 | ||
|
|
08eeed61f4 | ||
|
|
5207ff84dc | ||
|
|
4ade2e61ab | ||
|
|
818be98314 | ||
|
|
056c438452 | ||
|
|
0c040beb59 | ||
|
|
bf5a1dd840 | ||
|
|
f44200bec8 | ||
|
|
3b1b08efd6 | ||
|
|
3d8791067f | ||
|
|
da8207b73b | ||
|
|
aa9ca401fa | ||
|
|
e43c0c3ffc |
@@ -10,7 +10,8 @@ services:
|
||||
- 8080:8080
|
||||
volumes:
|
||||
- localai_workspace:/workspace
|
||||
- ../models:/host-models
|
||||
- models:/host-models
|
||||
- backends:/host-backends
|
||||
- ./customization:/devcontainer-customization
|
||||
command: /bin/sh -c "while sleep 1000; do :; done"
|
||||
cap_add:
|
||||
@@ -39,6 +40,9 @@ services:
|
||||
- GF_SECURITY_ADMIN_PASSWORD=grafana
|
||||
volumes:
|
||||
- ./grafana:/etc/grafana/provisioning/datasources
|
||||
|
||||
volumes:
|
||||
prom_data:
|
||||
localai_workspace:
|
||||
localai_workspace:
|
||||
models:
|
||||
backends:
|
||||
|
||||
3
.env
3
.env
@@ -26,6 +26,9 @@
|
||||
## Disables COMPEL (Diffusers)
|
||||
# COMPEL=0
|
||||
|
||||
## Disables SD_EMBED (Diffusers)
|
||||
# SD_EMBED=0
|
||||
|
||||
## Enable/Disable single backend (useful if only one GPU is available)
|
||||
# LOCALAI_SINGLE_ACTIVE_BACKEND=true
|
||||
|
||||
|
||||
2
.github/gallery-agent/agent.go
vendored
2
.github/gallery-agent/agent.go
vendored
@@ -146,7 +146,7 @@ func getRealReadme(ctx context.Context, repository string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
content := newFragment.LastMessage().Content
|
||||
content := result.LastMessage().Content
|
||||
return cleanTextContent(content), nil
|
||||
}
|
||||
|
||||
|
||||
85
.github/workflows/backend.yml
vendored
85
.github/workflows/backend.yml
vendored
@@ -210,6 +210,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-faster-qwen3-tts'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "faster-qwen3-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -575,6 +588,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-faster-qwen3-tts'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "faster-qwen3-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -705,6 +731,19 @@ jobs:
|
||||
backend: "qwen-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-faster-qwen3-tts'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
ubuntu-version: '2404'
|
||||
backend: "faster-qwen3-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -718,6 +757,19 @@ jobs:
|
||||
backend: "pocket-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-chatterbox'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
ubuntu-version: '2404'
|
||||
backend: "chatterbox"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1293,6 +1345,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-faster-qwen3-tts'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
skip-drivers: 'true'
|
||||
backend: "faster-qwen3-tts"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1674,6 +1739,20 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# voxtral
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-voxtral'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "voxtral"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
#silero-vad
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -1878,7 +1957,7 @@ jobs:
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-voxcpm'
|
||||
runs-on: 'ubuntu-latest'
|
||||
@@ -1945,6 +2024,10 @@ jobs:
|
||||
tag-suffix: "-metal-darwin-arm64-whisper"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "voxtral"
|
||||
tag-suffix: "-metal-darwin-arm64-voxtral"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "vibevoice"
|
||||
tag-suffix: "-metal-darwin-arm64-vibevoice"
|
||||
build-type: "mps"
|
||||
|
||||
8
.github/workflows/bump_deps.yaml
vendored
8
.github/workflows/bump_deps.yaml
vendored
@@ -18,10 +18,6 @@ jobs:
|
||||
variable: "WHISPER_CPP_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/whisper/Makefile"
|
||||
- repository: "PABannier/bark.cpp"
|
||||
variable: "BARKCPP_VERSION"
|
||||
branch: "main"
|
||||
file: "Makefile"
|
||||
- repository: "leejet/stable-diffusion.cpp"
|
||||
variable: "STABLEDIFFUSION_GGML_VERSION"
|
||||
branch: "master"
|
||||
@@ -30,6 +26,10 @@ jobs:
|
||||
variable: "PIPER_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/piper/Makefile"
|
||||
- repository: "antirez/voxtral.c"
|
||||
variable: "VOXTRAL_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/voxtral/Makefile"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
2
.github/workflows/localaibot_automerge.yml
vendored
2
.github/workflows/localaibot_automerge.yml
vendored
@@ -10,7 +10,7 @@ permissions:
|
||||
actions: write # to dispatch publish workflow
|
||||
jobs:
|
||||
dependabot:
|
||||
if: github.repository == 'mudler/LocalAI' && github.actor == 'localai-bot' && !contains(github.event.pull_request.title, 'chore(model gallery):')
|
||||
if: github.repository == 'mudler/LocalAI' && github.actor == 'localai-bot' && contains(github.event.pull_request.title, 'chore:')
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
|
||||
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
with:
|
||||
go-version: 1.23
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
uses: goreleaser/goreleaser-action@v7
|
||||
with:
|
||||
version: v2.11.0
|
||||
args: release --clean
|
||||
|
||||
2
.github/workflows/stalebot.yml
vendored
2
.github/workflows/stalebot.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
if: github.repository == 'mudler/LocalAI'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v9
|
||||
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v9
|
||||
with:
|
||||
stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
|
||||
stale-pr-message: 'This PR is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 10 days.'
|
||||
|
||||
31
.github/workflows/test-extra.yml
vendored
31
.github/workflows/test-extra.yml
vendored
@@ -361,3 +361,34 @@ jobs:
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/voxcpm
|
||||
make --jobs=5 --output-sync=target -C backend/python/voxcpm test
|
||||
tests-voxtral:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential cmake curl libopenblas-dev ffmpeg
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
# You can test your matrix by printing the current Go version
|
||||
- name: Display Go version
|
||||
run: go version
|
||||
- name: Proto Dependencies
|
||||
run: |
|
||||
# Install protoc
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
|
||||
PATH="$PATH:$HOME/go/bin" make protogen-go
|
||||
- name: Build voxtral
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/go/voxtral
|
||||
- name: Test voxtral
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/go/voxtral test
|
||||
|
||||
10
Makefile
10
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/voxtral
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -317,6 +317,7 @@ prepare-test-extra: protogen-python
|
||||
$(MAKE) -C backend/python/moonshine
|
||||
$(MAKE) -C backend/python/pocket-tts
|
||||
$(MAKE) -C backend/python/qwen-tts
|
||||
$(MAKE) -C backend/python/faster-qwen3-tts
|
||||
$(MAKE) -C backend/python/qwen-asr
|
||||
$(MAKE) -C backend/python/nemo
|
||||
$(MAKE) -C backend/python/voxcpm
|
||||
@@ -334,6 +335,7 @@ test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/moonshine test
|
||||
$(MAKE) -C backend/python/pocket-tts test
|
||||
$(MAKE) -C backend/python/qwen-tts test
|
||||
$(MAKE) -C backend/python/faster-qwen3-tts test
|
||||
$(MAKE) -C backend/python/qwen-asr test
|
||||
$(MAKE) -C backend/python/nemo test
|
||||
$(MAKE) -C backend/python/voxcpm test
|
||||
@@ -453,6 +455,7 @@ BACKEND_HUGGINGFACE = huggingface|golang|.|false|true
|
||||
BACKEND_SILERO_VAD = silero-vad|golang|.|false|true
|
||||
BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true
|
||||
BACKEND_WHISPER = whisper|golang|.|false|true
|
||||
BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
||||
|
||||
# Python backends with root context
|
||||
BACKEND_RERANKERS = rerankers|python|.|false|true
|
||||
@@ -472,6 +475,7 @@ BACKEND_VIBEVOICE = vibevoice|python|.|--progress=plain|true
|
||||
BACKEND_MOONSHINE = moonshine|python|.|false|true
|
||||
BACKEND_POCKET_TTS = pocket-tts|python|.|false|true
|
||||
BACKEND_QWEN_TTS = qwen-tts|python|.|false|true
|
||||
BACKEND_FASTER_QWEN3_TTS = faster-qwen3-tts|python|.|false|true
|
||||
BACKEND_QWEN_ASR = qwen-asr|python|.|false|true
|
||||
BACKEND_NEMO = nemo|python|.|false|true
|
||||
BACKEND_VOXCPM = voxcpm|python|.|false|true
|
||||
@@ -506,6 +510,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TRANSFORMERS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_OUTETTS)))
|
||||
@@ -523,6 +528,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_VIBEVOICE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MOONSHINE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_POCKET_TTS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_QWEN_TTS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_FASTER_QWEN3_TTS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_QWEN_ASR)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_NEMO)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VOXCPM)))
|
||||
@@ -533,7 +539,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-voxtral
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
|
||||
11
README.md
11
README.md
@@ -93,16 +93,7 @@ Liking LocalAI? LocalAI is part of an integrated suite of AI infrastructure tool
|
||||
|
||||
## 💻 Quickstart
|
||||
|
||||
> ⚠️ **Note:** The `install.sh` script is currently experiencing issues due to the heavy changes currently undergoing in LocalAI and may produce broken or misconfigured installations. Please use Docker installation (see below) or manual binary installation until [issue #8032](https://github.com/mudler/LocalAI/issues/8032) is resolved.
|
||||
|
||||
Run the installer script:
|
||||
|
||||
```bash
|
||||
# Basic installation
|
||||
curl https://localai.io/install.sh | sh
|
||||
```
|
||||
|
||||
For more installation options, see [Installer Options](https://localai.io/installation/).
|
||||
|
||||
### macOS Download:
|
||||
|
||||
@@ -237,7 +228,7 @@ Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3A
|
||||
- 🧩 [Backend Gallery](https://localai.io/backends/): Install/remove backends on the fly, powered by OCI images — fully customizable and API-driven.
|
||||
- 📖 [Text generation with GPTs](https://localai.io/features/text-generation/) (`llama.cpp`, `transformers`, `vllm` ... [:book: and more](https://localai.io/model-compatibility/index.html#model-compatibility-table))
|
||||
- 🗣 [Text to Audio](https://localai.io/features/text-to-audio/)
|
||||
- 🔈 [Audio to Text](https://localai.io/features/audio-to-text/) (Audio transcription with `whisper.cpp`)
|
||||
- 🔈 [Audio to Text](https://localai.io/features/audio-to-text/)
|
||||
- 🎨 [Image generation](https://localai.io/features/image-generation)
|
||||
- 🔥 [OpenAI-alike tools API](https://localai.io/features/openai-functions/)
|
||||
- ⚡ [Realtime API](https://localai.io/features/openai-realtime/) (Speech-to-speech)
|
||||
|
||||
@@ -20,7 +20,7 @@ RUN apt-get update && \
|
||||
build-essential \
|
||||
git ccache \
|
||||
ca-certificates \
|
||||
make cmake wget \
|
||||
make cmake wget libopenblas-dev \
|
||||
curl unzip \
|
||||
libssl-dev && \
|
||||
apt-get clean && \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=8872ad2125336d209a9911a82101f80095a9831d
|
||||
LLAMA_VERSION?=723c71064da0908c19683f8c344715fbf6d986fd
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -417,6 +417,12 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// n_ctx_checkpoints: max context checkpoints per slot (default: 8)
|
||||
params.n_ctx_checkpoints = 8;
|
||||
|
||||
// llama memory fit fails if we don't provide a buffer for tensor overrides
|
||||
const size_t ntbo = llama_max_tensor_buft_overrides();
|
||||
while (params.tensor_buft_overrides.size() < ntbo) {
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
|
||||
// decode options. Options are in form optname:optvale, or if booleans only optname.
|
||||
for (int i = 0; i < request->options_size(); i++) {
|
||||
std::string opt = request->options(i);
|
||||
@@ -1255,6 +1261,42 @@ public:
|
||||
body_json["add_generation_prompt"] = data["add_generation_prompt"];
|
||||
}
|
||||
|
||||
// Pass sampling parameters to body_json so oaicompat_chat_params_parse respects them
|
||||
// and doesn't overwrite them with defaults in the returned parsed_data
|
||||
if (data.contains("n_predict")) {
|
||||
body_json["max_tokens"] = data["n_predict"];
|
||||
}
|
||||
if (data.contains("ignore_eos")) {
|
||||
body_json["ignore_eos"] = data["ignore_eos"];
|
||||
}
|
||||
if (data.contains("stop")) {
|
||||
body_json["stop"] = data["stop"];
|
||||
}
|
||||
if (data.contains("temperature")) {
|
||||
body_json["temperature"] = data["temperature"];
|
||||
}
|
||||
if (data.contains("top_p")) {
|
||||
body_json["top_p"] = data["top_p"];
|
||||
}
|
||||
if (data.contains("frequency_penalty")) {
|
||||
body_json["frequency_penalty"] = data["frequency_penalty"];
|
||||
}
|
||||
if (data.contains("presence_penalty")) {
|
||||
body_json["presence_penalty"] = data["presence_penalty"];
|
||||
}
|
||||
if (data.contains("seed")) {
|
||||
body_json["seed"] = data["seed"];
|
||||
}
|
||||
if (data.contains("logit_bias")) {
|
||||
body_json["logit_bias"] = data["logit_bias"];
|
||||
}
|
||||
if (data.contains("top_k")) {
|
||||
body_json["top_k"] = data["top_k"];
|
||||
}
|
||||
if (data.contains("min_p")) {
|
||||
body_json["min_p"] = data["min_p"];
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
@@ -1986,6 +2028,42 @@ public:
|
||||
body_json["add_generation_prompt"] = data["add_generation_prompt"];
|
||||
}
|
||||
|
||||
// Pass sampling parameters to body_json so oaicompat_chat_params_parse respects them
|
||||
// and doesn't overwrite them with defaults in the returned parsed_data
|
||||
if (data.contains("n_predict")) {
|
||||
body_json["max_tokens"] = data["n_predict"];
|
||||
}
|
||||
if (data.contains("ignore_eos")) {
|
||||
body_json["ignore_eos"] = data["ignore_eos"];
|
||||
}
|
||||
if (data.contains("stop")) {
|
||||
body_json["stop"] = data["stop"];
|
||||
}
|
||||
if (data.contains("temperature")) {
|
||||
body_json["temperature"] = data["temperature"];
|
||||
}
|
||||
if (data.contains("top_p")) {
|
||||
body_json["top_p"] = data["top_p"];
|
||||
}
|
||||
if (data.contains("frequency_penalty")) {
|
||||
body_json["frequency_penalty"] = data["frequency_penalty"];
|
||||
}
|
||||
if (data.contains("presence_penalty")) {
|
||||
body_json["presence_penalty"] = data["presence_penalty"];
|
||||
}
|
||||
if (data.contains("seed")) {
|
||||
body_json["seed"] = data["seed"];
|
||||
}
|
||||
if (data.contains("logit_bias")) {
|
||||
body_json["logit_bias"] = data["logit_bias"];
|
||||
}
|
||||
if (data.contains("top_k")) {
|
||||
body_json["top_k"] = data["top_k"];
|
||||
}
|
||||
if (data.contains("min_p")) {
|
||||
body_json["min_p"] = data["min_p"];
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] Predict: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
|
||||
2
backend/go/stablediffusion-ggml/.gitignore
vendored
2
backend/go/stablediffusion-ggml/.gitignore
vendored
@@ -2,5 +2,5 @@ package/
|
||||
sources/
|
||||
.cache/
|
||||
build/
|
||||
libgosd.so
|
||||
*.so
|
||||
stablediffusion-ggml
|
||||
|
||||
@@ -66,15 +66,18 @@ sources/stablediffusion-ggml.cpp:
|
||||
git checkout $(STABLEDIFFUSION_GGML_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
libgosd.so: sources/stablediffusion-ggml.cpp CMakeLists.txt gosd.cpp gosd.h
|
||||
mkdir -p build && \
|
||||
cd build && \
|
||||
cmake .. $(CMAKE_ARGS) && \
|
||||
cmake --build . --config Release -j$(JOBS) && \
|
||||
cd .. && \
|
||||
mv build/libgosd.so ./
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
stablediffusion-ggml: main.go gosd.go libgosd.so
|
||||
# Only build CPU variants on Linux
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
VARIANT_TARGETS = libgosd-avx.so libgosd-avx2.so libgosd-avx512.so libgosd-fallback.so
|
||||
else
|
||||
# On non-Linux (e.g., Darwin), build only fallback variant
|
||||
VARIANT_TARGETS = libgosd-fallback.so
|
||||
endif
|
||||
|
||||
stablediffusion-ggml: main.go gosd.go $(VARIANT_TARGETS)
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o stablediffusion-ggml ./
|
||||
|
||||
package: stablediffusion-ggml
|
||||
@@ -82,5 +85,46 @@ package: stablediffusion-ggml
|
||||
|
||||
build: package
|
||||
|
||||
clean:
|
||||
rm -rf libgosd.so build stablediffusion-ggml package sources
|
||||
clean: purge
|
||||
rm -rf libgosd*.so stablediffusion-ggml package sources
|
||||
|
||||
purge:
|
||||
rm -rf build*
|
||||
|
||||
# Build all variants (Linux only)
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
libgosd-avx.so: sources/stablediffusion-ggml.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I stablediffusion-ggml build info:avx${RESET})
|
||||
SO_TARGET=libgosd-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgosd-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgosd-avx2.so: sources/stablediffusion-ggml.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I stablediffusion-ggml build info:avx2${RESET})
|
||||
SO_TARGET=libgosd-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgosd-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgosd-avx512.so: sources/stablediffusion-ggml.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I stablediffusion-ggml build info:avx512${RESET})
|
||||
SO_TARGET=libgosd-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgosd-custom
|
||||
rm -rfv build*
|
||||
endif
|
||||
|
||||
# Build fallback variant (all platforms)
|
||||
libgosd-fallback.so: sources/stablediffusion-ggml.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I stablediffusion-ggml build info:fallback${RESET})
|
||||
SO_TARGET=libgosd-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgosd-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgosd-custom: CMakeLists.txt gosd.cpp gosd.h
|
||||
mkdir -p build-$(SO_TARGET) && \
|
||||
cd build-$(SO_TARGET) && \
|
||||
cmake .. $(CMAKE_ARGS) && \
|
||||
cmake --build . --config Release -j$(JOBS) && \
|
||||
cd .. && \
|
||||
mv build-$(SO_TARGET)/libgosd.so ./$(SO_TARGET)
|
||||
|
||||
all: stablediffusion-ggml package
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
@@ -17,7 +18,13 @@ type LibFuncs struct {
|
||||
}
|
||||
|
||||
func main() {
|
||||
gosd, err := purego.Dlopen("./libgosd.so", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
// Get library name from environment variable, default to fallback
|
||||
libName := os.Getenv("SD_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgosd-fallback.so"
|
||||
}
|
||||
|
||||
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ REPO_ROOT="${CURDIR}/../../.."
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/libgosd.so $CURDIR/package/
|
||||
cp -avf $CURDIR/libgosd-*.so $CURDIR/package/
|
||||
cp -avf $CURDIR/stablediffusion-ggml $CURDIR/package/
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
|
||||
@@ -1,14 +1,52 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
fi
|
||||
|
||||
LIBRARY="$CURDIR/libgosd-fallback.so"
|
||||
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/libgosd-avx.so ]; then
|
||||
LIBRARY="$CURDIR/libgosd-avx.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/libgosd-avx2.so ]; then
|
||||
LIBRARY="$CURDIR/libgosd-avx2.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/libgosd-avx512.so ]; then
|
||||
LIBRARY="$CURDIR/libgosd-avx512.so"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
export SD_LIBRARY=$LIBRARY
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/stablediffusion-ggml "$@"
|
||||
fi
|
||||
|
||||
exec $CURDIR/stablediffusion-ggml "$@"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/stablediffusion-ggml "$@"
|
||||
|
||||
9
backend/go/voxtral/.gitignore
vendored
Normal file
9
backend/go/voxtral/.gitignore
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
.cache/
|
||||
sources/
|
||||
build/
|
||||
build-*/
|
||||
package/
|
||||
voxtral
|
||||
*.so
|
||||
*.dylib
|
||||
compile_commands.json
|
||||
84
backend/go/voxtral/CMakeLists.txt
Normal file
84
backend/go/voxtral/CMakeLists.txt
Normal file
@@ -0,0 +1,84 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
|
||||
if(USE_METAL)
|
||||
project(govoxtral LANGUAGES C OBJC)
|
||||
else()
|
||||
project(govoxtral LANGUAGES C)
|
||||
endif()
|
||||
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
# Workaround: CMake + GCC linker depfile generation fails for MODULE libraries
|
||||
set(CMAKE_C_LINKER_DEPFILE_SUPPORTED FALSE)
|
||||
|
||||
# Build voxtral.c as a library
|
||||
set(VOXTRAL_SOURCES
|
||||
sources/voxtral.c/voxtral.c
|
||||
sources/voxtral.c/voxtral_kernels.c
|
||||
sources/voxtral.c/voxtral_audio.c
|
||||
sources/voxtral.c/voxtral_encoder.c
|
||||
sources/voxtral.c/voxtral_decoder.c
|
||||
sources/voxtral.c/voxtral_tokenizer.c
|
||||
sources/voxtral.c/voxtral_safetensors.c
|
||||
)
|
||||
|
||||
# Metal GPU acceleration (macOS arm64 only)
|
||||
if(USE_METAL)
|
||||
# Generate embedded shader header from .metal source via xxd
|
||||
add_custom_command(
|
||||
OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/sources/voxtral.c/voxtral_shaders_source.h
|
||||
COMMAND xxd -i voxtral_shaders.metal > voxtral_shaders_source.h
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/sources/voxtral.c
|
||||
DEPENDS sources/voxtral.c/voxtral_shaders.metal
|
||||
COMMENT "Generating embedded Metal shaders header"
|
||||
)
|
||||
list(APPEND VOXTRAL_SOURCES sources/voxtral.c/voxtral_metal.m)
|
||||
set_source_files_properties(sources/voxtral.c/voxtral_metal.m PROPERTIES
|
||||
COMPILE_FLAGS "-fobjc-arc"
|
||||
)
|
||||
endif()
|
||||
|
||||
add_library(govoxtral MODULE csrc/govoxtral.c ${VOXTRAL_SOURCES})
|
||||
|
||||
target_include_directories(govoxtral PRIVATE sources/voxtral.c csrc)
|
||||
|
||||
target_compile_options(govoxtral PRIVATE -O3 -ffast-math)
|
||||
|
||||
if(USE_METAL)
|
||||
target_compile_definitions(govoxtral PRIVATE USE_BLAS USE_METAL ACCELERATE_NEW_LAPACK)
|
||||
target_link_libraries(govoxtral PRIVATE
|
||||
"-framework Accelerate"
|
||||
"-framework Metal"
|
||||
"-framework MetalPerformanceShaders"
|
||||
"-framework MetalPerformanceShadersGraph"
|
||||
"-framework Foundation"
|
||||
"-framework AudioToolbox"
|
||||
"-framework CoreFoundation"
|
||||
m
|
||||
)
|
||||
# Ensure the generated shader header is built before compiling
|
||||
target_sources(govoxtral PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sources/voxtral.c/voxtral_shaders_source.h
|
||||
)
|
||||
elseif(USE_OPENBLAS)
|
||||
# Try to find OpenBLAS; use it if available, otherwise fall back to pure C
|
||||
find_package(BLAS)
|
||||
if(BLAS_FOUND)
|
||||
target_compile_definitions(govoxtral PRIVATE USE_BLAS USE_OPENBLAS)
|
||||
target_link_libraries(govoxtral PRIVATE ${BLAS_LIBRARIES} m)
|
||||
target_include_directories(govoxtral PRIVATE /usr/include/openblas)
|
||||
else()
|
||||
message(WARNING "OpenBLAS requested but not found, building without BLAS")
|
||||
target_link_libraries(govoxtral PRIVATE m)
|
||||
endif()
|
||||
elseif(APPLE)
|
||||
# macOS without Metal: use Accelerate framework
|
||||
target_compile_definitions(govoxtral PRIVATE USE_BLAS ACCELERATE_NEW_LAPACK)
|
||||
target_link_libraries(govoxtral PRIVATE "-framework Accelerate" m)
|
||||
else()
|
||||
target_link_libraries(govoxtral PRIVATE m)
|
||||
endif()
|
||||
|
||||
set_property(TARGET govoxtral PROPERTY C_STANDARD 11)
|
||||
set_target_properties(govoxtral PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||
107
backend/go/voxtral/Makefile
Normal file
107
backend/go/voxtral/Makefile
Normal file
@@ -0,0 +1,107 @@
|
||||
.NOTPARALLEL:
|
||||
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=true
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
||||
|
||||
# voxtral.c version
|
||||
VOXTRAL_REPO?=https://github.com/antirez/voxtral.c
|
||||
VOXTRAL_VERSION?=134d366c24d20c64b614a3dcc8bda2a6922d077d
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
# Shared library extension
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
SO_EXT=dylib
|
||||
else
|
||||
SO_EXT=so
|
||||
endif
|
||||
|
||||
SO_TARGET?=libgovoxtral.$(SO_EXT)
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
ifneq ($(UNAME_S),Darwin)
|
||||
CMAKE_ARGS+=-DCMAKE_C_FLAGS="-march=x86-64"
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DUSE_OPENBLAS=OFF
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DUSE_OPENBLAS=OFF
|
||||
else ifeq ($(BUILD_TYPE),metal)
|
||||
CMAKE_ARGS+=-DUSE_OPENBLAS=OFF -DUSE_METAL=ON
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
# Default on macOS: use Accelerate (no OpenBLAS needed)
|
||||
CMAKE_ARGS+=-DUSE_OPENBLAS=OFF
|
||||
else
|
||||
CMAKE_ARGS+=-DUSE_OPENBLAS=ON
|
||||
endif
|
||||
|
||||
# Single library target
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
VARIANT_TARGETS = libgovoxtral.dylib
|
||||
else
|
||||
VARIANT_TARGETS = libgovoxtral.so
|
||||
endif
|
||||
|
||||
sources/voxtral.c:
|
||||
mkdir -p sources/voxtral.c
|
||||
cd sources/voxtral.c && \
|
||||
git init && \
|
||||
git remote add origin $(VOXTRAL_REPO) && \
|
||||
git fetch origin && \
|
||||
git checkout $(VOXTRAL_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
voxtral: main.go govoxtral.go $(VARIANT_TARGETS)
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o voxtral ./
|
||||
|
||||
package: voxtral
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
clean: purge
|
||||
rm -rf libgovoxtral.so libgovoxtral.dylib package sources/voxtral.c voxtral
|
||||
|
||||
purge:
|
||||
rm -rf build*
|
||||
|
||||
# Build single library
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
libgovoxtral.dylib: sources/voxtral.c
|
||||
$(MAKE) purge
|
||||
$(info Building voxtral: darwin)
|
||||
SO_TARGET=libgovoxtral.dylib NATIVE=true $(MAKE) libgovoxtral-custom
|
||||
rm -rfv build*
|
||||
else
|
||||
libgovoxtral.so: sources/voxtral.c
|
||||
$(MAKE) purge
|
||||
$(info Building voxtral)
|
||||
SO_TARGET=libgovoxtral.so $(MAKE) libgovoxtral-custom
|
||||
rm -rfv build*
|
||||
endif
|
||||
|
||||
libgovoxtral-custom: CMakeLists.txt csrc/govoxtral.c csrc/govoxtral.h
|
||||
mkdir -p build-$(SO_TARGET) && \
|
||||
cd build-$(SO_TARGET) && \
|
||||
cmake .. $(CMAKE_ARGS) && \
|
||||
cmake --build . --config Release -j$(JOBS) && \
|
||||
cd .. && \
|
||||
(mv build-$(SO_TARGET)/libgovoxtral.so ./$(SO_TARGET) 2>/dev/null || \
|
||||
mv build-$(SO_TARGET)/libgovoxtral.dylib ./$(SO_TARGET) 2>/dev/null)
|
||||
|
||||
test: voxtral
|
||||
@echo "Running voxtral tests..."
|
||||
bash test.sh
|
||||
@echo "voxtral tests completed."
|
||||
|
||||
all: voxtral package
|
||||
62
backend/go/voxtral/csrc/govoxtral.c
Normal file
62
backend/go/voxtral/csrc/govoxtral.c
Normal file
@@ -0,0 +1,62 @@
|
||||
#include "govoxtral.h"
|
||||
#include "voxtral.h"
|
||||
#include "voxtral_audio.h"
|
||||
#ifdef USE_METAL
|
||||
#include "voxtral_metal.h"
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <stdio.h>
|
||||
|
||||
static vox_ctx_t *ctx = NULL;
|
||||
static char *last_result = NULL;
|
||||
static int metal_initialized = 0;
|
||||
|
||||
int load_model(const char *model_dir) {
|
||||
if (ctx != NULL) {
|
||||
vox_free(ctx);
|
||||
ctx = NULL;
|
||||
}
|
||||
|
||||
#ifdef USE_METAL
|
||||
if (!metal_initialized) {
|
||||
vox_metal_init();
|
||||
metal_initialized = 1;
|
||||
}
|
||||
#endif
|
||||
|
||||
ctx = vox_load(model_dir);
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr, "error: failed to load voxtral model from %s\n", model_dir);
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
const char *transcribe(const char *wav_path) {
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr, "error: model not loaded\n");
|
||||
return "";
|
||||
}
|
||||
|
||||
if (last_result != NULL) {
|
||||
free(last_result);
|
||||
last_result = NULL;
|
||||
}
|
||||
|
||||
last_result = vox_transcribe(ctx, wav_path);
|
||||
if (last_result == NULL) {
|
||||
fprintf(stderr, "error: transcription failed for %s\n", wav_path);
|
||||
return "";
|
||||
}
|
||||
|
||||
return last_result;
|
||||
}
|
||||
|
||||
void free_result(void) {
|
||||
if (last_result != NULL) {
|
||||
free(last_result);
|
||||
last_result = NULL;
|
||||
}
|
||||
}
|
||||
8
backend/go/voxtral/csrc/govoxtral.h
Normal file
8
backend/go/voxtral/csrc/govoxtral.h
Normal file
@@ -0,0 +1,8 @@
|
||||
#ifndef GOVOXTRAL_H
|
||||
#define GOVOXTRAL_H
|
||||
|
||||
extern int load_model(const char *model_dir);
|
||||
extern const char *transcribe(const char *wav_path);
|
||||
extern void free_result(void);
|
||||
|
||||
#endif /* GOVOXTRAL_H */
|
||||
60
backend/go/voxtral/govoxtral.go
Normal file
60
backend/go/voxtral/govoxtral.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
CppLoadModel func(modelDir string) int
|
||||
CppTranscribe func(wavPath string) string
|
||||
CppFreeResult func()
|
||||
)
|
||||
|
||||
type Voxtral struct {
|
||||
base.SingleThread
|
||||
}
|
||||
|
||||
func (v *Voxtral) Load(opts *pb.ModelOptions) error {
|
||||
if ret := CppLoadModel(opts.ModelFile); ret != 0 {
|
||||
return fmt.Errorf("failed to load Voxtral model from %s", opts.ModelFile)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *Voxtral) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
dir, err := os.MkdirTemp("", "voxtral")
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
convertedPath := dir + "/converted.wav"
|
||||
|
||||
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
result := strings.Clone(CppTranscribe(convertedPath))
|
||||
CppFreeResult()
|
||||
|
||||
text := strings.TrimSpace(result)
|
||||
|
||||
segments := []*pb.TranscriptSegment{}
|
||||
if text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{
|
||||
Id: 0,
|
||||
Text: text,
|
||||
})
|
||||
}
|
||||
|
||||
return pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: text,
|
||||
}, nil
|
||||
}
|
||||
53
backend/go/voxtral/main.go
Normal file
53
backend/go/voxtral/main.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Get library name from environment variable, default to fallback
|
||||
libName := os.Getenv("VOXTRAL_LIBRARY")
|
||||
if libName == "" {
|
||||
if runtime.GOOS == "darwin" {
|
||||
libName = "./libgovoxtral.dylib"
|
||||
} else {
|
||||
libName = "./libgovoxtral.so"
|
||||
}
|
||||
}
|
||||
|
||||
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppLoadModel, "load_model"},
|
||||
{&CppTranscribe, "transcribe"},
|
||||
{&CppFreeResult, "free_result"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, gosd, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &Voxtral{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
68
backend/go/voxtral/package.sh
Normal file
68
backend/go/voxtral/package.sh
Normal file
@@ -0,0 +1,68 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/voxtral $CURDIR/package/
|
||||
cp -fv $CURDIR/libgovoxtral-*.so $CURDIR/package/ 2>/dev/null || true
|
||||
cp -fv $CURDIR/libgovoxtral-*.dylib $CURDIR/package/ 2>/dev/null || true
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
# OpenBLAS if available
|
||||
if [ -f /usr/lib/x86_64-linux-gnu/libopenblas.so.0 ]; then
|
||||
cp -arfLv /usr/lib/x86_64-linux-gnu/libopenblas.so.0 $CURDIR/package/lib/
|
||||
fi
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
# OpenBLAS if available
|
||||
if [ -f /usr/lib/aarch64-linux-gnu/libopenblas.so.0 ]; then
|
||||
cp -arfLv /usr/lib/aarch64-linux-gnu/libopenblas.so.0 $CURDIR/package/lib/
|
||||
fi
|
||||
elif [ $(uname -s) = "Darwin" ]; then
|
||||
echo "Detected Darwin — system frameworks linked dynamically, no bundled libs needed"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
49
backend/go/voxtral/run.sh
Normal file
49
backend/go/voxtral/run.sh
Normal file
@@ -0,0 +1,49 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
fi
|
||||
|
||||
if [ "$(uname)" = "Darwin" ]; then
|
||||
# macOS: single dylib variant (Metal or Accelerate)
|
||||
LIBRARY="$CURDIR/libgovoxtral-fallback.dylib"
|
||||
export DYLD_LIBRARY_PATH=$CURDIR/lib:$DYLD_LIBRARY_PATH
|
||||
else
|
||||
LIBRARY="$CURDIR/libgovoxtral-fallback.so"
|
||||
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/libgovoxtral-avx.so ]; then
|
||||
LIBRARY="$CURDIR/libgovoxtral-avx.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/libgovoxtral-avx2.so ]; then
|
||||
LIBRARY="$CURDIR/libgovoxtral-avx2.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
fi
|
||||
|
||||
export VOXTRAL_LIBRARY=$LIBRARY
|
||||
|
||||
# If there is a lib/ld.so, use it (Linux only)
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/voxtral "$@"
|
||||
fi
|
||||
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/voxtral "$@"
|
||||
48
backend/go/voxtral/test.sh
Normal file
48
backend/go/voxtral/test.sh
Normal file
@@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
echo "Running voxtral backend tests..."
|
||||
|
||||
# The test requires:
|
||||
# - VOXTRAL_MODEL_DIR: path to directory containing consolidated.safetensors + tekken.json
|
||||
# - VOXTRAL_BINARY: path to the voxtral binary (defaults to ./voxtral)
|
||||
#
|
||||
# Tests that require the model will be skipped if VOXTRAL_MODEL_DIR is not set.
|
||||
|
||||
cd "$CURDIR"
|
||||
export VOXTRAL_MODEL_DIR="${VOXTRAL_MODEL_DIR:-./voxtral-model}"
|
||||
|
||||
if [ ! -d "$VOXTRAL_MODEL_DIR" ]; then
|
||||
echo "Creating voxtral-model directory for tests..."
|
||||
mkdir -p "$VOXTRAL_MODEL_DIR"
|
||||
MODEL_ID="mistralai/Voxtral-Mini-4B-Realtime-2602"
|
||||
echo "Model: ${MODEL_ID}"
|
||||
echo ""
|
||||
|
||||
# Files to download
|
||||
FILES=(
|
||||
"consolidated.safetensors"
|
||||
"params.json"
|
||||
"tekken.json"
|
||||
)
|
||||
|
||||
BASE_URL="https://huggingface.co/${MODEL_ID}/resolve/main"
|
||||
|
||||
for file in "${FILES[@]}"; do
|
||||
dest="${VOXTRAL_MODEL_DIR}/${file}"
|
||||
if [ -f "${dest}" ]; then
|
||||
echo " [skip] ${file} (already exists)"
|
||||
else
|
||||
echo " [download] ${file}..."
|
||||
curl -L -o "${dest}" "${BASE_URL}/${file}" --progress-bar
|
||||
echo " [done] ${file}"
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
# Run Go tests
|
||||
go test -v -timeout 300s ./...
|
||||
|
||||
echo "All voxtral tests passed."
|
||||
201
backend/go/voxtral/voxtral_test.go
Normal file
201
backend/go/voxtral/voxtral_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
const (
|
||||
testAddr = "localhost:50051"
|
||||
sampleAudio = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav"
|
||||
startupWait = 5 * time.Second
|
||||
)
|
||||
|
||||
func skipIfNoModel(t *testing.T) string {
|
||||
t.Helper()
|
||||
modelDir := os.Getenv("VOXTRAL_MODEL_DIR")
|
||||
if modelDir == "" {
|
||||
t.Skip("VOXTRAL_MODEL_DIR not set, skipping test (set to voxtral model directory)")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(modelDir, "consolidated.safetensors")); os.IsNotExist(err) {
|
||||
t.Skipf("Model file not found in %s, skipping", modelDir)
|
||||
}
|
||||
return modelDir
|
||||
}
|
||||
|
||||
func startServer(t *testing.T) *exec.Cmd {
|
||||
t.Helper()
|
||||
binary := os.Getenv("VOXTRAL_BINARY")
|
||||
if binary == "" {
|
||||
binary = "./voxtral"
|
||||
}
|
||||
if _, err := os.Stat(binary); os.IsNotExist(err) {
|
||||
t.Skipf("Backend binary not found at %s, skipping", binary)
|
||||
}
|
||||
cmd := exec.Command(binary, "--addr", testAddr)
|
||||
cmd.Stdout = os.Stderr
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
time.Sleep(startupWait)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func stopServer(cmd *exec.Cmd) {
|
||||
if cmd != nil && cmd.Process != nil {
|
||||
cmd.Process.Kill()
|
||||
cmd.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func dialGRPC(t *testing.T) *grpc.ClientConn {
|
||||
t.Helper()
|
||||
conn, err := grpc.Dial(testAddr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithDefaultCallOptions(
|
||||
grpc.MaxCallRecvMsgSize(50*1024*1024),
|
||||
grpc.MaxCallSendMsgSize(50*1024*1024),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial gRPC: %v", err)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
func downloadFile(url, dest string) error {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("HTTP GET failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("bad status: %s", resp.Status)
|
||||
}
|
||||
f, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
_, err = io.Copy(f, resp.Body)
|
||||
return err
|
||||
}
|
||||
|
||||
func TestServerHealth(t *testing.T) {
|
||||
cmd := startServer(t)
|
||||
defer stopServer(cmd)
|
||||
|
||||
conn := dialGRPC(t)
|
||||
defer conn.Close()
|
||||
|
||||
client := pb.NewBackendClient(conn)
|
||||
resp, err := client.Health(context.Background(), &pb.HealthMessage{})
|
||||
if err != nil {
|
||||
t.Fatalf("Health check failed: %v", err)
|
||||
}
|
||||
if string(resp.Message) != "OK" {
|
||||
t.Fatalf("Expected OK, got %s", string(resp.Message))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadModel(t *testing.T) {
|
||||
modelDir := skipIfNoModel(t)
|
||||
cmd := startServer(t)
|
||||
defer stopServer(cmd)
|
||||
|
||||
conn := dialGRPC(t)
|
||||
defer conn.Close()
|
||||
|
||||
client := pb.NewBackendClient(conn)
|
||||
resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
||||
ModelFile: modelDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModel failed: %v", err)
|
||||
}
|
||||
if !resp.Success {
|
||||
t.Fatalf("LoadModel returned failure: %s", resp.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAudioTranscription(t *testing.T) {
|
||||
modelDir := skipIfNoModel(t)
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "voxtral-test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Download sample audio — JFK "ask not what your country can do for you" clip
|
||||
audioFile := filepath.Join(tmpDir, "sample.wav")
|
||||
t.Log("Downloading sample audio...")
|
||||
if err := downloadFile(sampleAudio, audioFile); err != nil {
|
||||
t.Fatalf("Failed to download sample audio: %v", err)
|
||||
}
|
||||
|
||||
cmd := startServer(t)
|
||||
defer stopServer(cmd)
|
||||
|
||||
conn := dialGRPC(t)
|
||||
defer conn.Close()
|
||||
|
||||
client := pb.NewBackendClient(conn)
|
||||
|
||||
// Load model
|
||||
loadResp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
||||
ModelFile: modelDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModel failed: %v", err)
|
||||
}
|
||||
if !loadResp.Success {
|
||||
t.Fatalf("LoadModel returned failure: %s", loadResp.Message)
|
||||
}
|
||||
|
||||
// Transcribe
|
||||
transcriptResp, err := client.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioFile,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("AudioTranscription failed: %v", err)
|
||||
}
|
||||
if transcriptResp == nil {
|
||||
t.Fatal("AudioTranscription returned nil")
|
||||
}
|
||||
|
||||
t.Logf("Transcribed text: %s", transcriptResp.Text)
|
||||
t.Logf("Number of segments: %d", len(transcriptResp.Segments))
|
||||
|
||||
if transcriptResp.Text == "" {
|
||||
t.Fatal("Transcription returned empty text")
|
||||
}
|
||||
|
||||
allText := strings.ToLower(transcriptResp.Text)
|
||||
for _, seg := range transcriptResp.Segments {
|
||||
allText += " " + strings.ToLower(seg.Text)
|
||||
}
|
||||
t.Logf("All text: %s", allText)
|
||||
|
||||
if !strings.Contains(allText, "big") {
|
||||
t.Errorf("Expected 'big' in transcription, got: %s", allText)
|
||||
}
|
||||
|
||||
// The sample audio should contain recognizable speech
|
||||
if len(allText) < 10 {
|
||||
t.Errorf("Transcription too short: %q", allText)
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=941bdabbe4561bc6de68981aea01bc5ab05781c5
|
||||
WHISPER_CPP_VERSION?=21411d81ea736ed5d9cdea4df360d3c4b60a4adb
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
@@ -88,19 +88,19 @@ ifeq ($(UNAME_S),Linux)
|
||||
libgowhisper-avx.so: sources/whisper.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I whisper build info:avx${RESET})
|
||||
SO_TARGET=libgowhisper-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off" $(MAKE) libgowhisper-custom
|
||||
SO_TARGET=libgowhisper-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgowhisper-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgowhisper-avx2.so: sources/whisper.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I whisper build info:avx2${RESET})
|
||||
SO_TARGET=libgowhisper-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) libgowhisper-custom
|
||||
SO_TARGET=libgowhisper-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgowhisper-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgowhisper-avx512.so: sources/whisper.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I whisper build info:avx512${RESET})
|
||||
SO_TARGET=libgowhisper-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) libgowhisper-custom
|
||||
SO_TARGET=libgowhisper-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgowhisper-custom
|
||||
rm -rfv build*
|
||||
endif
|
||||
|
||||
@@ -108,7 +108,7 @@ endif
|
||||
libgowhisper-fallback.so: sources/whisper.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I whisper build info:fallback${RESET})
|
||||
SO_TARGET=libgowhisper-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off" $(MAKE) libgowhisper-custom
|
||||
SO_TARGET=libgowhisper-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgowhisper-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgowhisper-custom: CMakeLists.txt gowhisper.cpp gowhisper.h
|
||||
|
||||
@@ -56,6 +56,21 @@
|
||||
nvidia-cuda-12: "cuda12-whisper"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-whisper"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-whisper"
|
||||
- &voxtral
|
||||
name: "voxtral"
|
||||
alias: "voxtral"
|
||||
license: mit
|
||||
description: |
|
||||
Voxtral Realtime 4B Pure C speech-to-text inference engine
|
||||
urls:
|
||||
- https://github.com/mudler/voxtral.c
|
||||
tags:
|
||||
- audio-transcription
|
||||
- CPU
|
||||
- Metal
|
||||
capabilities:
|
||||
default: "cpu-voxtral"
|
||||
metal-darwin-arm64: "metal-voxtral"
|
||||
- &stablediffusionggml
|
||||
name: "stablediffusion-ggml"
|
||||
alias: "stablediffusion-ggml"
|
||||
@@ -513,6 +528,28 @@
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-qwen-tts"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-qwen-tts"
|
||||
icon: https://cdn-avatars.huggingface.co/v1/production/uploads/620760a26e3b7210c2ff1943/-s1gyJfvbE1RgO5iBeNOi.png
|
||||
- &faster-qwen3-tts
|
||||
urls:
|
||||
- https://github.com/andimarafioti/faster-qwen3-tts
|
||||
- https://pypi.org/project/faster-qwen3-tts/
|
||||
description: |
|
||||
Real-time Qwen3-TTS inference using CUDA graph capture. Voice clone only; requires NVIDIA GPU with CUDA.
|
||||
tags:
|
||||
- text-to-speech
|
||||
- TTS
|
||||
- voice-clone
|
||||
license: apache-2.0
|
||||
name: "faster-qwen3-tts"
|
||||
alias: "faster-qwen3-tts"
|
||||
capabilities:
|
||||
nvidia: "cuda12-faster-qwen3-tts"
|
||||
default: "cuda12-faster-qwen3-tts"
|
||||
nvidia-cuda-13: "cuda13-faster-qwen3-tts"
|
||||
nvidia-cuda-12: "cuda12-faster-qwen3-tts"
|
||||
nvidia-l4t: "nvidia-l4t-faster-qwen3-tts"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-faster-qwen3-tts"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-faster-qwen3-tts"
|
||||
icon: https://cdn-avatars.huggingface.co/v1/production/uploads/620760a26e3b7210c2ff1943/-s1gyJfvbE1RgO5iBeNOi.png
|
||||
- &qwen-asr
|
||||
urls:
|
||||
- https://github.com/QwenLM/Qwen3-ASR
|
||||
@@ -2015,7 +2052,7 @@
|
||||
nvidia-cuda-13: "cuda13-chatterbox-development"
|
||||
nvidia-cuda-12: "cuda12-chatterbox-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-chatterbox"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-chatterbox"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-chatterbox-development"
|
||||
- !!merge <<: *chatterbox
|
||||
name: "cpu-chatterbox"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-chatterbox"
|
||||
@@ -2264,6 +2301,57 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-qwen-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-qwen-tts
|
||||
## faster-qwen3-tts
|
||||
- !!merge <<: *faster-qwen3-tts
|
||||
name: "faster-qwen3-tts-development"
|
||||
capabilities:
|
||||
nvidia: "cuda12-faster-qwen3-tts-development"
|
||||
default: "cuda12-faster-qwen3-tts-development"
|
||||
nvidia-cuda-13: "cuda13-faster-qwen3-tts-development"
|
||||
nvidia-cuda-12: "cuda12-faster-qwen3-tts-development"
|
||||
nvidia-l4t: "nvidia-l4t-faster-qwen3-tts-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-faster-qwen3-tts-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-faster-qwen3-tts-development"
|
||||
- !!merge <<: *faster-qwen3-tts
|
||||
name: "cuda12-faster-qwen3-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-faster-qwen3-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-faster-qwen3-tts
|
||||
- !!merge <<: *faster-qwen3-tts
|
||||
name: "cuda12-faster-qwen3-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-faster-qwen3-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-faster-qwen3-tts
|
||||
- !!merge <<: *faster-qwen3-tts
|
||||
name: "cuda13-faster-qwen3-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-faster-qwen3-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-faster-qwen3-tts
|
||||
- !!merge <<: *faster-qwen3-tts
|
||||
name: "cuda13-faster-qwen3-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-faster-qwen3-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-faster-qwen3-tts
|
||||
- !!merge <<: *faster-qwen3-tts
|
||||
name: "nvidia-l4t-faster-qwen3-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-faster-qwen3-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-faster-qwen3-tts
|
||||
- !!merge <<: *faster-qwen3-tts
|
||||
name: "nvidia-l4t-faster-qwen3-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-faster-qwen3-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-faster-qwen3-tts
|
||||
- !!merge <<: *faster-qwen3-tts
|
||||
name: "cuda13-nvidia-l4t-arm64-faster-qwen3-tts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-faster-qwen3-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-faster-qwen3-tts
|
||||
- !!merge <<: *faster-qwen3-tts
|
||||
name: "cuda13-nvidia-l4t-arm64-faster-qwen3-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-faster-qwen3-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-faster-qwen3-tts
|
||||
## qwen-asr
|
||||
- !!merge <<: *qwen-asr
|
||||
name: "qwen-asr-development"
|
||||
@@ -2594,3 +2682,24 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-pocket-tts"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-pocket-tts
|
||||
## voxtral
|
||||
- !!merge <<: *voxtral
|
||||
name: "cpu-voxtral"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-voxtral"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-voxtral
|
||||
- !!merge <<: *voxtral
|
||||
name: "cpu-voxtral-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-voxtral"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-voxtral
|
||||
- !!merge <<: *voxtral
|
||||
name: "metal-voxtral"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-voxtral"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-voxtral
|
||||
- !!merge <<: *voxtral
|
||||
name: "metal-voxtral-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-voxtral
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
grpcio==1.76.0
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
grpcio-tools
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.76.0
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
certifi
|
||||
packaging==24.1
|
||||
@@ -115,6 +115,7 @@ Available pipelines: AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline, ...
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `COMPEL` | `0` | Enable Compel for prompt weighting |
|
||||
| `SD_EMBED` | `0` | Enable sd_embed for prompt weighting |
|
||||
| `XPU` | `0` | Enable Intel XPU support |
|
||||
| `CLIPSKIP` | `1` | Enable CLIP skip support |
|
||||
| `SAFETENSORS` | `1` | Use safetensors format |
|
||||
|
||||
@@ -40,6 +40,21 @@ from compel import Compel, ReturnedEmbeddingsType
|
||||
from optimum.quanto import freeze, qfloat8, quantize
|
||||
from transformers import T5EncoderModel
|
||||
from safetensors.torch import load_file
|
||||
# Try to import sd_embed - it might not always be available
|
||||
try:
|
||||
from sd_embed.embedding_funcs import (
|
||||
get_weighted_text_embeddings_sd15,
|
||||
get_weighted_text_embeddings_sdxl,
|
||||
get_weighted_text_embeddings_sd3,
|
||||
get_weighted_text_embeddings_flux1,
|
||||
)
|
||||
SD_EMBED_AVAILABLE = True
|
||||
except ImportError:
|
||||
get_weighted_text_embeddings_sd15 = None
|
||||
get_weighted_text_embeddings_sdxl = None
|
||||
get_weighted_text_embeddings_sd3 = None
|
||||
get_weighted_text_embeddings_flux1 = None
|
||||
SD_EMBED_AVAILABLE = False
|
||||
|
||||
# Import LTX-2 specific utilities
|
||||
from diffusers.pipelines.ltx2.export_utils import encode_video as ltx2_encode_video
|
||||
@@ -47,6 +62,10 @@ from diffusers import LTX2VideoTransformer3DModel, GGUFQuantizationConfig
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
COMPEL = os.environ.get("COMPEL", "0") == "1"
|
||||
SD_EMBED = os.environ.get("SD_EMBED", "0") == "1"
|
||||
# Warn if SD_EMBED is enabled but the module is not available
|
||||
if SD_EMBED and not SD_EMBED_AVAILABLE:
|
||||
print("WARNING: SD_EMBED is enabled but sd_embed module is not available. Falling back to standard prompt processing.", file=sys.stderr)
|
||||
XPU = os.environ.get("XPU", "0") == "1"
|
||||
CLIPSKIP = os.environ.get("CLIPSKIP", "1") == "1"
|
||||
SAFETENSORS = os.environ.get("SAFETENSORS", "1") == "1"
|
||||
@@ -177,7 +196,7 @@ def get_scheduler(name: str, config: dict = {}):
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant):
|
||||
def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant, device_map=None):
|
||||
"""
|
||||
Load a diffusers pipeline dynamically using the dynamic loader.
|
||||
|
||||
@@ -191,6 +210,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
fromSingleFile: Whether to use from_single_file() vs from_pretrained()
|
||||
torchType: The torch dtype to use
|
||||
variant: Model variant (e.g., "fp16")
|
||||
device_map: Device mapping strategy (e.g., "auto" for multi-GPU)
|
||||
|
||||
Returns:
|
||||
The loaded pipeline instance
|
||||
@@ -212,14 +232,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
dtype = torch.bfloat16
|
||||
bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")
|
||||
|
||||
transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype)
|
||||
transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype, device_map=device_map)
|
||||
quantize(transformer, weights=qfloat8)
|
||||
freeze(transformer)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype, device_map=device_map)
|
||||
quantize(text_encoder_2, weights=qfloat8)
|
||||
freeze(text_encoder_2)
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
|
||||
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype, device_map=device_map)
|
||||
pipe.transformer = transformer
|
||||
pipe.text_encoder_2 = text_encoder_2
|
||||
|
||||
@@ -232,13 +252,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
request.Model,
|
||||
subfolder="vae",
|
||||
torch_dtype=torch.float32
|
||||
torch_dtype=torch.float32,
|
||||
device_map=device_map
|
||||
)
|
||||
pipe = load_diffusers_pipeline(
|
||||
class_name="WanPipeline",
|
||||
model_id=request.Model,
|
||||
vae=vae,
|
||||
torch_dtype=torchType
|
||||
torch_dtype=torchType,
|
||||
device_map=device_map
|
||||
)
|
||||
self.txt2vid = True
|
||||
return pipe
|
||||
@@ -248,13 +270,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
request.Model,
|
||||
subfolder="vae",
|
||||
torch_dtype=torch.float32
|
||||
torch_dtype=torch.float32,
|
||||
device_map=device_map
|
||||
)
|
||||
pipe = load_diffusers_pipeline(
|
||||
class_name="WanImageToVideoPipeline",
|
||||
model_id=request.Model,
|
||||
vae=vae,
|
||||
torch_dtype=torchType
|
||||
torch_dtype=torchType,
|
||||
device_map=device_map
|
||||
)
|
||||
self.img2vid = True
|
||||
return pipe
|
||||
@@ -265,7 +289,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
class_name="SanaPipeline",
|
||||
model_id=request.Model,
|
||||
variant="bf16",
|
||||
torch_dtype=torch.bfloat16
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device_map
|
||||
)
|
||||
pipe.vae.to(torch.bfloat16)
|
||||
pipe.text_encoder.to(torch.bfloat16)
|
||||
@@ -277,7 +302,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
pipe = load_diffusers_pipeline(
|
||||
class_name="DiffusionPipeline",
|
||||
model_id=request.Model,
|
||||
torch_dtype=torchType
|
||||
torch_dtype=torchType,
|
||||
device_map=device_map
|
||||
)
|
||||
return pipe
|
||||
|
||||
@@ -288,7 +314,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
class_name="StableVideoDiffusionPipeline",
|
||||
model_id=request.Model,
|
||||
torch_dtype=torchType,
|
||||
variant=variant
|
||||
variant=variant,
|
||||
device_map=device_map
|
||||
)
|
||||
if not DISABLE_CPU_OFFLOAD:
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -312,6 +339,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
modelFile,
|
||||
config=request.Model, # Use request.Model as the config/model_id
|
||||
subfolder="transformer",
|
||||
device_map=device_map,
|
||||
**transformer_kwargs,
|
||||
)
|
||||
|
||||
@@ -321,6 +349,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
model_id=request.Model,
|
||||
transformer=transformer,
|
||||
torch_dtype=torchType,
|
||||
device_map=device_map,
|
||||
)
|
||||
else:
|
||||
# Single file but not GGUF - use standard single file loading
|
||||
@@ -329,6 +358,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
model_id=modelFile,
|
||||
from_single_file=True,
|
||||
torch_dtype=torchType,
|
||||
device_map=device_map,
|
||||
)
|
||||
else:
|
||||
# Standard loading from pretrained
|
||||
@@ -336,7 +366,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
class_name="LTX2ImageToVideoPipeline",
|
||||
model_id=request.Model,
|
||||
torch_dtype=torchType,
|
||||
variant=variant
|
||||
variant=variant,
|
||||
device_map=device_map
|
||||
)
|
||||
|
||||
if not DISABLE_CPU_OFFLOAD:
|
||||
@@ -361,6 +392,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
modelFile,
|
||||
config=request.Model, # Use request.Model as the config/model_id
|
||||
subfolder="transformer",
|
||||
device_map=device_map,
|
||||
**transformer_kwargs,
|
||||
)
|
||||
|
||||
@@ -370,6 +402,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
model_id=request.Model,
|
||||
transformer=transformer,
|
||||
torch_dtype=torchType,
|
||||
device_map=device_map,
|
||||
)
|
||||
else:
|
||||
# Single file but not GGUF - use standard single file loading
|
||||
@@ -378,6 +411,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
model_id=modelFile,
|
||||
from_single_file=True,
|
||||
torch_dtype=torchType,
|
||||
device_map=device_map,
|
||||
)
|
||||
else:
|
||||
# Standard loading from pretrained
|
||||
@@ -385,7 +419,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
class_name="LTX2Pipeline",
|
||||
model_id=request.Model,
|
||||
torch_dtype=torchType,
|
||||
variant=variant
|
||||
variant=variant,
|
||||
device_map=device_map
|
||||
)
|
||||
|
||||
if not DISABLE_CPU_OFFLOAD:
|
||||
@@ -408,6 +443,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if not fromSingleFile:
|
||||
load_kwargs["use_safetensors"] = SAFETENSORS
|
||||
|
||||
# Add device_map for multi-GPU support (when TensorParallelSize > 1)
|
||||
if device_map:
|
||||
load_kwargs["device_map"] = device_map
|
||||
|
||||
# Determine pipeline class name - default to AutoPipelineForText2Image
|
||||
effective_pipeline_type = pipeline_type if pipeline_type else "AutoPipelineForText2Image"
|
||||
|
||||
@@ -510,6 +549,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
print(f"LoadModel: PipelineType from request: {request.PipelineType}", file=sys.stderr)
|
||||
|
||||
# Determine device_map for multi-GPU support based on TensorParallelSize
|
||||
# When TensorParallelSize > 1, use device_map='auto' to distribute model across GPUs
|
||||
device_map = None
|
||||
if hasattr(request, 'TensorParallelSize') and request.TensorParallelSize > 1:
|
||||
device_map = "auto"
|
||||
print(f"LoadModel: Multi-GPU mode enabled with TensorParallelSize={request.TensorParallelSize}, using device_map='auto'", file=sys.stderr)
|
||||
|
||||
# Load pipeline using dynamic loader
|
||||
# Special cases that require custom initialization are handled first
|
||||
self.pipe = self._load_pipeline(
|
||||
@@ -517,7 +563,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
modelFile=modelFile,
|
||||
fromSingleFile=fromSingleFile,
|
||||
torchType=torchType,
|
||||
variant=variant
|
||||
variant=variant,
|
||||
device_map=device_map
|
||||
)
|
||||
|
||||
print(f"LoadModel: After loading - ltx2_pipeline: {self.ltx2_pipeline}, img2vid: {self.img2vid}, txt2vid: {self.txt2vid}, PipelineType: {self.PipelineType}", file=sys.stderr)
|
||||
@@ -542,7 +589,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
if request.ControlNet:
|
||||
self.controlnet = ControlNetModel.from_pretrained(
|
||||
request.ControlNet, torch_dtype=torchType, variant=variant
|
||||
request.ControlNet, torch_dtype=torchType, variant=variant, device_map=device_map
|
||||
)
|
||||
self.pipe.controlnet = self.controlnet
|
||||
else:
|
||||
@@ -581,7 +628,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
self.pipe.set_adapters(adapters_name, adapter_weights=adapters_weights)
|
||||
|
||||
if device != "cpu":
|
||||
# Only move pipeline to device if NOT using device_map
|
||||
# device_map handles device placement automatically
|
||||
if device_map is None and device != "cpu":
|
||||
self.pipe.to(device)
|
||||
if self.controlnet:
|
||||
self.controlnet.to(device)
|
||||
@@ -737,6 +786,51 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
kwargs["prompt_embeds"] = conditioning
|
||||
kwargs["pooled_prompt_embeds"] = pooled
|
||||
# pass the kwargs dictionary to the self.pipe method
|
||||
image = self.pipe(
|
||||
guidance_scale=self.cfg_scale,
|
||||
**kwargs
|
||||
).images[0]
|
||||
elif SD_EMBED and SD_EMBED_AVAILABLE:
|
||||
if self.PipelineType == "StableDiffusionPipeline":
|
||||
(
|
||||
kwargs["prompt_embeds"],
|
||||
kwargs["negative_prompt_embeds"],
|
||||
) = get_weighted_text_embeddings_sd15(
|
||||
pipe = self.pipe,
|
||||
prompt = prompt,
|
||||
neg_prompt = request.negative_prompt if hasattr(request, 'negative_prompt') else None,
|
||||
)
|
||||
if self.PipelineType == "StableDiffusionXLPipeline":
|
||||
(
|
||||
kwargs["prompt_embeds"],
|
||||
kwargs["negative_prompt_embeds"],
|
||||
kwargs["pooled_prompt_embeds"],
|
||||
kwargs["negative_pooled_prompt_embeds"],
|
||||
) = get_weighted_text_embeddings_sdxl(
|
||||
pipe = self.pipe,
|
||||
prompt = prompt,
|
||||
neg_prompt = request.negative_prompt if hasattr(request, 'negative_prompt') else None
|
||||
)
|
||||
if self.PipelineType == "StableDiffusion3Pipeline":
|
||||
(
|
||||
kwargs["prompt_embeds"],
|
||||
kwargs["negative_prompt_embeds"],
|
||||
kwargs["pooled_prompt_embeds"],
|
||||
kwargs["negative_pooled_prompt_embeds"],
|
||||
) = get_weighted_text_embeddings_sd3(
|
||||
pipe = self.pipe,
|
||||
prompt = prompt,
|
||||
neg_prompt = request.negative_prompt if hasattr(request, 'negative_prompt') else None
|
||||
)
|
||||
if self.PipelineType == "FluxTransformer2DModel":
|
||||
(
|
||||
kwargs["prompt_embeds"],
|
||||
kwargs["pooled_prompt_embeds"],
|
||||
) = get_weighted_text_embeddings_flux1(
|
||||
pipe = self.pipe,
|
||||
prompt = prompt,
|
||||
)
|
||||
|
||||
image = self.pipe(
|
||||
guidance_scale=self.cfg_scale,
|
||||
**kwargs
|
||||
|
||||
@@ -5,6 +5,7 @@ transformers
|
||||
torchvision==0.22.1
|
||||
accelerate
|
||||
compel
|
||||
git+https://github.com/xhinker/sd_embed
|
||||
peft
|
||||
sentencepiece
|
||||
torch==2.7.1
|
||||
|
||||
@@ -5,6 +5,7 @@ transformers
|
||||
torchvision
|
||||
accelerate
|
||||
compel
|
||||
git+https://github.com/xhinker/sd_embed
|
||||
peft
|
||||
sentencepiece
|
||||
torch
|
||||
|
||||
@@ -5,6 +5,7 @@ transformers
|
||||
torchvision
|
||||
accelerate
|
||||
compel
|
||||
git+https://github.com/xhinker/sd_embed
|
||||
peft
|
||||
sentencepiece
|
||||
torch
|
||||
|
||||
@@ -8,6 +8,7 @@ opencv-python
|
||||
transformers
|
||||
accelerate
|
||||
compel
|
||||
git+https://github.com/xhinker/sd_embed
|
||||
peft
|
||||
sentencepiece
|
||||
optimum-quanto
|
||||
|
||||
23
backend/python/faster-qwen3-tts/Makefile
Normal file
23
backend/python/faster-qwen3-tts/Makefile
Normal file
@@ -0,0 +1,23 @@
|
||||
.PHONY: faster-qwen3-tts
|
||||
faster-qwen3-tts:
|
||||
bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: faster-qwen3-tts
|
||||
@echo "Running faster-qwen3-tts..."
|
||||
bash run.sh
|
||||
@echo "faster-qwen3-tts run."
|
||||
|
||||
.PHONY: test
|
||||
test: faster-qwen3-tts
|
||||
@echo "Testing faster-qwen3-tts..."
|
||||
bash test.sh
|
||||
@echo "faster-qwen3-tts tested."
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
193
backend/python/faster-qwen3-tts/backend.py
Normal file
193
backend/python/faster-qwen3-tts/backend.py
Normal file
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
gRPC server of LocalAI for Faster Qwen3-TTS (CUDA graph capture, voice clone only).
|
||||
"""
|
||||
from concurrent import futures
|
||||
import time
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import traceback
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import torch
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
def is_float(s):
|
||||
try:
|
||||
float(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def is_int(s):
|
||||
try:
|
||||
int(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
if not torch.cuda.is_available():
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="faster-qwen3-tts requires NVIDIA GPU with CUDA"
|
||||
)
|
||||
|
||||
self.options = {}
|
||||
for opt in request.Options:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1)
|
||||
if is_float(value):
|
||||
value = float(value)
|
||||
elif is_int(value):
|
||||
value = int(value)
|
||||
elif value.lower() in ["true", "false"]:
|
||||
value = value.lower() == "true"
|
||||
self.options[key] = value
|
||||
|
||||
model_path = request.Model or "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
|
||||
self.audio_path = request.AudioPath if hasattr(request, 'AudioPath') and request.AudioPath else None
|
||||
self.model_file = request.ModelFile if hasattr(request, 'ModelFile') and request.ModelFile else None
|
||||
self.model_path = request.ModelPath if hasattr(request, 'ModelPath') and request.ModelPath else None
|
||||
|
||||
from faster_qwen3_tts import FasterQwen3TTS
|
||||
print(f"Loading model from: {model_path}", file=sys.stderr)
|
||||
try:
|
||||
self.model = FasterQwen3TTS.from_pretrained(model_path)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Loading model: {type(e).__name__}: {e}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=str(e))
|
||||
|
||||
print(f"Model loaded successfully: {model_path}", file=sys.stderr)
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def _get_ref_audio_path(self, request):
|
||||
if not self.audio_path:
|
||||
return None
|
||||
if os.path.isabs(self.audio_path):
|
||||
return self.audio_path
|
||||
if self.model_file:
|
||||
model_file_base = os.path.dirname(self.model_file)
|
||||
ref_path = os.path.join(model_file_base, self.audio_path)
|
||||
if os.path.exists(ref_path):
|
||||
return ref_path
|
||||
if self.model_path:
|
||||
ref_path = os.path.join(self.model_path, self.audio_path)
|
||||
if os.path.exists(ref_path):
|
||||
return ref_path
|
||||
return self.audio_path
|
||||
|
||||
def TTS(self, request, context):
|
||||
try:
|
||||
if not request.dst:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="dst (output path) is required"
|
||||
)
|
||||
text = request.text.strip()
|
||||
if not text:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="Text is empty"
|
||||
)
|
||||
|
||||
language = request.language if hasattr(request, 'language') and request.language else None
|
||||
if not language or language == "":
|
||||
language = "English"
|
||||
|
||||
ref_audio = self._get_ref_audio_path(request)
|
||||
if not ref_audio:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="AudioPath is required for voice clone (set in LoadModel)"
|
||||
)
|
||||
ref_text = self.options.get("ref_text")
|
||||
if not ref_text and hasattr(request, 'ref_text') and request.ref_text:
|
||||
ref_text = request.ref_text
|
||||
if not ref_text:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="ref_text is required for voice clone (set via LoadModel Options, e.g. ref_text:Your reference transcript)"
|
||||
)
|
||||
|
||||
chunk_size = self.options.get("chunk_size")
|
||||
generation_kwargs = {}
|
||||
if chunk_size is not None:
|
||||
generation_kwargs["chunk_size"] = int(chunk_size)
|
||||
|
||||
audio_list, sr = self.model.generate_voice_clone(
|
||||
text=text,
|
||||
language=language,
|
||||
ref_audio=ref_audio,
|
||||
ref_text=ref_text,
|
||||
**generation_kwargs
|
||||
)
|
||||
|
||||
if audio_list is None or (isinstance(audio_list, list) and len(audio_list) == 0):
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="No audio output generated"
|
||||
)
|
||||
audio_data = audio_list[0] if isinstance(audio_list, list) else audio_list
|
||||
sf.write(request.dst, audio_data, sr)
|
||||
print(f"Saved output to {request.dst}", file=sys.stderr)
|
||||
|
||||
except Exception as err:
|
||||
print(f"Error in TTS: {err}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(
|
||||
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
]
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
print("Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument("--addr", default="localhost:50051", help="The address to bind the server to.")
|
||||
args = parser.parse_args()
|
||||
serve(args.addr)
|
||||
13
backend/python/faster-qwen3-tts/install.sh
Normal file
13
backend/python/faster-qwen3-tts/install.sh
Normal file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation"
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
@@ -0,0 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu121
|
||||
torch
|
||||
torchaudio
|
||||
faster-qwen3-tts
|
||||
@@ -0,0 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch
|
||||
torchaudio
|
||||
faster-qwen3-tts
|
||||
4
backend/python/faster-qwen3-tts/requirements-l4t12.txt
Normal file
4
backend/python/faster-qwen3-tts/requirements-l4t12.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
|
||||
torch
|
||||
torchaudio
|
||||
faster-qwen3-tts
|
||||
4
backend/python/faster-qwen3-tts/requirements-l4t13.txt
Normal file
4
backend/python/faster-qwen3-tts/requirements-l4t13.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch
|
||||
torchaudio
|
||||
faster-qwen3-tts
|
||||
8
backend/python/faster-qwen3-tts/requirements.txt
Normal file
8
backend/python/faster-qwen3-tts/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
certifi
|
||||
packaging==24.1
|
||||
soundfile
|
||||
setuptools
|
||||
six
|
||||
sox
|
||||
9
backend/python/faster-qwen3-tts/run.sh
Normal file
9
backend/python/faster-qwen3-tts/run.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
104
backend/python/faster-qwen3-tts/test.py
Normal file
104
backend/python/faster-qwen3-tts/test.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
Tests for the faster-qwen3-tts gRPC backend.
|
||||
"""
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import grpc
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.service = subprocess.Popen(
|
||||
["python3", "backend.py", "--addr", "localhost:50052"],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
cwd=os.path.dirname(os.path.abspath(__file__)),
|
||||
)
|
||||
time.sleep(15)
|
||||
|
||||
def tearDown(self):
|
||||
self.service.terminate()
|
||||
try:
|
||||
self.service.communicate(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.service.kill()
|
||||
self.service.communicate()
|
||||
|
||||
def test_health(self):
|
||||
with grpc.insecure_channel("localhost:50052") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
reply = stub.Health(backend_pb2.HealthMessage(), timeout=5.0)
|
||||
self.assertEqual(reply.message, b"OK")
|
||||
|
||||
def test_load_model_requires_cuda(self):
|
||||
with grpc.insecure_channel("localhost:50052") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(
|
||||
backend_pb2.ModelOptions(
|
||||
Model="Qwen/Qwen3-TTS-12Hz-0.6B-Base",
|
||||
CUDA=True,
|
||||
),
|
||||
timeout=10.0,
|
||||
)
|
||||
self.assertFalse(response.success)
|
||||
|
||||
@unittest.skipUnless(
|
||||
__import__("torch").cuda.is_available(),
|
||||
"faster-qwen3-tts TTS requires CUDA",
|
||||
)
|
||||
def test_tts(self):
|
||||
import soundfile as sf
|
||||
try:
|
||||
with grpc.insecure_channel("localhost:50052") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
ref_audio = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
|
||||
ref_audio.close()
|
||||
try:
|
||||
sr = 22050
|
||||
duration = 1.0
|
||||
samples = int(sr * duration)
|
||||
sf.write(ref_audio.name, [0.0] * samples, sr)
|
||||
|
||||
response = stub.LoadModel(
|
||||
backend_pb2.ModelOptions(
|
||||
Model="Qwen/Qwen3-TTS-12Hz-0.6B-Base",
|
||||
AudioPath=ref_audio.name,
|
||||
Options=["ref_text:Hello world"],
|
||||
),
|
||||
timeout=600.0,
|
||||
)
|
||||
self.assertTrue(response.success, response.message)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as out:
|
||||
output_path = out.name
|
||||
try:
|
||||
tts_response = stub.TTS(
|
||||
backend_pb2.TTSRequest(
|
||||
text="Test output.",
|
||||
dst=output_path,
|
||||
language="English",
|
||||
),
|
||||
timeout=120.0,
|
||||
)
|
||||
self.assertTrue(tts_response.success, tts_response.message)
|
||||
self.assertTrue(os.path.exists(output_path))
|
||||
self.assertGreater(os.path.getsize(output_path), 0)
|
||||
finally:
|
||||
if os.path.exists(output_path):
|
||||
os.unlink(output_path)
|
||||
finally:
|
||||
if os.path.exists(ref_audio.name):
|
||||
os.unlink(ref_audio.name)
|
||||
except Exception as err:
|
||||
self.fail(f"TTS test failed: {err}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
11
backend/python/faster-qwen3-tts/test.sh
Normal file
11
backend/python/faster-qwen3-tts/test.sh
Normal file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runUnittests
|
||||
@@ -10,7 +10,11 @@ import sys
|
||||
import os
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import moonshine_onnx
|
||||
from moonshine_voice import (
|
||||
Transcriber,
|
||||
get_model_for_language,
|
||||
load_wav_file,
|
||||
)
|
||||
|
||||
import grpc
|
||||
|
||||
@@ -25,16 +29,49 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""
|
||||
BackendServicer is the class that implements the gRPC service
|
||||
"""
|
||||
def __init__(self):
|
||||
self.transcriber = None
|
||||
self.model_name = None
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
try:
|
||||
print("Preparing models, please wait", file=sys.stderr)
|
||||
# Store the model name for use in transcription
|
||||
# Model name format: e.g., "moonshine/tiny"
|
||||
self.model_name = request.Model
|
||||
print(f"Model name set to: {self.model_name}", file=sys.stderr)
|
||||
|
||||
# Default values
|
||||
language = "en"
|
||||
model_arch = None
|
||||
|
||||
# Parse options from request
|
||||
options = request.Options
|
||||
self.options = {}
|
||||
|
||||
# The options are a list of strings in this form optname:optvalue
|
||||
for opt in options:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1)
|
||||
self.options[key] = value
|
||||
|
||||
print(f"Options: {self.options}", file=sys.stderr)
|
||||
|
||||
# Extract language and model_arch from options
|
||||
if "language" in self.options:
|
||||
language = self.options["language"]
|
||||
if "model_arch" in self.options:
|
||||
model_arch = self.options["model_arch"]
|
||||
|
||||
# Get the model path and architecture
|
||||
model_path, model_arch = get_model_for_language(language, model_arch)
|
||||
print(f"Loading model: {model_path} with architecture: {model_arch} for language: {language}", file=sys.stderr)
|
||||
|
||||
# Initialize the transcriber
|
||||
self.transcriber = Transcriber(model_path=model_path, model_arch=model_arch)
|
||||
print("Model loaded successfully", file=sys.stderr)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
@@ -43,33 +80,44 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
resultSegments = []
|
||||
text = ""
|
||||
try:
|
||||
# moonshine_onnx.transcribe returns a list of strings
|
||||
transcriptions = moonshine_onnx.transcribe(request.dst, self.model_name)
|
||||
if self.transcriber is None:
|
||||
raise Exception("Model not loaded. Call LoadModel first.")
|
||||
|
||||
# Load the audio file
|
||||
audio_data, sample_rate = load_wav_file(request.dst)
|
||||
print(f"Loaded audio file: {request.dst} with sample rate: {sample_rate}", file=sys.stderr)
|
||||
|
||||
# Transcribe without streaming
|
||||
transcript = self.transcriber.transcribe_without_streaming(
|
||||
audio_data, sample_rate=sample_rate, flags=0
|
||||
)
|
||||
|
||||
# Process transcript lines
|
||||
full_text_parts = []
|
||||
for idx, line in enumerate(transcript.lines):
|
||||
line_text = line.text.strip()
|
||||
full_text_parts.append(line_text)
|
||||
|
||||
# Create segment with timing information
|
||||
start_ms = int(line.start_time * 1000)
|
||||
end_ms = int((line.start_time + line.duration) * 1000)
|
||||
|
||||
resultSegments.append(backend_pb2.TranscriptSegment(
|
||||
id=idx,
|
||||
start=start_ms,
|
||||
end=end_ms,
|
||||
text=line_text
|
||||
))
|
||||
|
||||
print(f"Segment {idx}: [{line.start_time:.2f}s - {line.start_time + line.duration:.2f}s] {line_text}", file=sys.stderr)
|
||||
|
||||
# Combine all transcriptions into a single text
|
||||
if isinstance(transcriptions, list):
|
||||
text = " ".join(transcriptions)
|
||||
# Create segments for each transcription in the list
|
||||
for id, trans in enumerate(transcriptions):
|
||||
# Since moonshine doesn't provide timing info, we'll create a single segment
|
||||
# with id and text, using approximate timing
|
||||
resultSegments.append(backend_pb2.TranscriptSegment(
|
||||
id=id,
|
||||
start=0,
|
||||
end=0,
|
||||
text=trans
|
||||
))
|
||||
else:
|
||||
# Handle case where it's not a list (shouldn't happen, but be safe)
|
||||
text = str(transcriptions)
|
||||
resultSegments.append(backend_pb2.TranscriptSegment(
|
||||
id=0,
|
||||
start=0,
|
||||
end=0,
|
||||
text=text
|
||||
))
|
||||
text = " ".join(full_text_parts)
|
||||
|
||||
except Exception as err:
|
||||
print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return backend_pb2.TranscriptResult(segments=[], text="")
|
||||
|
||||
return backend_pb2.TranscriptResult(segments=resultSegments, text=text)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
grpcio-tools
|
||||
useful-moonshine-onnx@git+https://git@github.com/moonshine-ai/moonshine.git#subdirectory=moonshine-onnx
|
||||
moonshine-voice
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
grpcio-tools
|
||||
useful-moonshine-onnx@git+https://git@github.com/moonshine-ai/moonshine.git#subdirectory=moonshine-onnx
|
||||
moonshine-voice
|
||||
@@ -112,7 +112,7 @@ class TestBackendServicer(unittest.TestCase):
|
||||
self.assertGreaterEqual(len(transcript_response.segments), 0)
|
||||
|
||||
# Verify the transcription contains the expected text
|
||||
expected_text = "This is the micro machine man presenting the most midget miniature"
|
||||
expected_text = "This is the micro machine man"
|
||||
self.assertIn(
|
||||
expected_text.lower(),
|
||||
transcript_response.text.lower(),
|
||||
|
||||
@@ -32,7 +32,14 @@ if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
fi
|
||||
|
||||
|
||||
git clone https://github.com/neuphonic/neutts-air neutts-air
|
||||
git clone --depth 100 https://github.com/neuphonic/neutts-air neutts-air
|
||||
|
||||
cd neutts-air
|
||||
|
||||
git checkout 1737487debe5b40a0bb97875edce8c66b391722b
|
||||
|
||||
cd ..
|
||||
|
||||
|
||||
cp -rfv neutts-air/neuttsair ./
|
||||
|
||||
|
||||
@@ -3,3 +3,6 @@ protobuf
|
||||
certifi
|
||||
packaging==24.1
|
||||
setuptools
|
||||
h11
|
||||
gradio
|
||||
uvicorn
|
||||
@@ -4,4 +4,6 @@ certifi
|
||||
packaging==24.1
|
||||
soundfile
|
||||
setuptools
|
||||
six
|
||||
six
|
||||
scipy
|
||||
librosa
|
||||
@@ -1,3 +1,3 @@
|
||||
grpcio==1.76.0
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
certifi
|
||||
@@ -4,5 +4,5 @@ numba==0.60.0
|
||||
accelerate
|
||||
transformers
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.2
|
||||
sentence-transformers==5.2.3
|
||||
protobuf==6.33.5
|
||||
@@ -4,5 +4,5 @@ llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.2
|
||||
sentence-transformers==5.2.3
|
||||
protobuf==6.33.5
|
||||
@@ -4,5 +4,5 @@ llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.2
|
||||
sentence-transformers==5.2.3
|
||||
protobuf==6.33.5
|
||||
@@ -5,5 +5,5 @@ transformers
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.2
|
||||
sentence-transformers==5.2.3
|
||||
protobuf==6.33.5
|
||||
@@ -5,5 +5,5 @@ llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.2
|
||||
sentence-transformers==5.2.3
|
||||
protobuf==6.33.5
|
||||
@@ -4,5 +4,5 @@ numba==0.60.0
|
||||
accelerate
|
||||
transformers
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.2
|
||||
sentence-transformers==5.2.3
|
||||
protobuf==6.33.5
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.76.0
|
||||
grpcio==1.78.1
|
||||
protobuf==6.33.5
|
||||
certifi
|
||||
setuptools
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.76.0
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
certifi
|
||||
setuptools
|
||||
@@ -9,7 +9,12 @@ else
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
|
||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||
pip install "setuptools<70.0.0"
|
||||
else
|
||||
uv pip install "setuptools<70.0.0"
|
||||
fi
|
||||
# Apply patch to fix PyTorch compatibility issue in voxcpm
|
||||
# This fixes the "Dimension out of range" error in scaled_dot_product_attention
|
||||
# by changing .contiguous() to .unsqueeze(0) in the attention module
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
setuptools
|
||||
grpcio==1.76.0
|
||||
protobuf
|
||||
certifi
|
||||
|
||||
@@ -319,6 +319,29 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
options.MemoryReclaimerThreshold = *settings.MemoryReclaimerThreshold
|
||||
}
|
||||
}
|
||||
if settings.ForceEvictionWhenBusy != nil {
|
||||
// Only apply if current value is default (false), suggesting it wasn't set from env var
|
||||
if !options.ForceEvictionWhenBusy {
|
||||
options.ForceEvictionWhenBusy = *settings.ForceEvictionWhenBusy
|
||||
}
|
||||
}
|
||||
if settings.LRUEvictionMaxRetries != nil {
|
||||
// Only apply if current value is default (30), suggesting it wasn't set from env var
|
||||
if options.LRUEvictionMaxRetries == 0 {
|
||||
options.LRUEvictionMaxRetries = *settings.LRUEvictionMaxRetries
|
||||
}
|
||||
}
|
||||
if settings.LRUEvictionRetryInterval != nil {
|
||||
// Only apply if current value is default (1s), suggesting it wasn't set from env var
|
||||
if options.LRUEvictionRetryInterval == 0 {
|
||||
dur, err := time.ParseDuration(*settings.LRUEvictionRetryInterval)
|
||||
if err == nil {
|
||||
options.LRUEvictionRetryInterval = dur
|
||||
} else {
|
||||
xlog.Warn("invalid LRU eviction retry interval in runtime_settings.json", "error", err, "interval", *settings.LRUEvictionRetryInterval)
|
||||
}
|
||||
}
|
||||
}
|
||||
if settings.AgentJobRetentionDays != nil {
|
||||
// Only apply if current value is default (0), suggesting it wasn't set from env var
|
||||
if options.AgentJobRetentionDays == 0 {
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
@@ -37,11 +35,15 @@ func (a *Application) startWatchdog() error {
|
||||
model.WithMemoryReclaimer(appConfig.MemoryReclaimerEnabled, appConfig.MemoryReclaimerThreshold),
|
||||
model.WithForceEvictionWhenBusy(appConfig.ForceEvictionWhenBusy),
|
||||
)
|
||||
a.modelLoader.SetWatchDog(wd)
|
||||
|
||||
// Create new stop channel
|
||||
// Create new stop channel BEFORE setting up any goroutines
|
||||
// This prevents race conditions where the old shutdown handler might
|
||||
// receive the closed channel and try to shut down the new watchdog
|
||||
a.watchdogStop = make(chan bool, 1)
|
||||
|
||||
// Set the watchdog on the model loader
|
||||
a.modelLoader.SetWatchDog(wd)
|
||||
|
||||
// Start watchdog goroutine if any periodic checks are enabled
|
||||
// LRU eviction doesn't need the Run() loop - it's triggered on model load
|
||||
// But memory reclaimer needs the Run() loop for periodic checking
|
||||
@@ -49,15 +51,19 @@ func (a *Application) startWatchdog() error {
|
||||
go wd.Run()
|
||||
}
|
||||
|
||||
// Setup shutdown handler
|
||||
// Setup shutdown handler - this goroutine will wait on a.watchdogStop
|
||||
// which is now a fresh channel, so it won't receive any stale signals
|
||||
// Note: We capture wd in a local variable to ensure this handler operates
|
||||
// on the correct watchdog instance (not a later one that gets assigned to wd)
|
||||
wdForShutdown := wd
|
||||
go func() {
|
||||
select {
|
||||
case <-a.watchdogStop:
|
||||
xlog.Debug("Watchdog stop signal received")
|
||||
wd.Shutdown()
|
||||
wdForShutdown.Shutdown()
|
||||
case <-appConfig.Context.Done():
|
||||
xlog.Debug("Context canceled, shutting down watchdog")
|
||||
wd.Shutdown()
|
||||
wdForShutdown.Shutdown()
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -82,20 +88,41 @@ func (a *Application) RestartWatchdog() error {
|
||||
a.watchdogMutex.Lock()
|
||||
defer a.watchdogMutex.Unlock()
|
||||
|
||||
// Shutdown existing watchdog if running
|
||||
// Get the old watchdog before we shut it down
|
||||
oldWD := a.modelLoader.GetWatchDog()
|
||||
|
||||
// Get the state from the old watchdog before shutting it down
|
||||
// This preserves information about loaded models
|
||||
var oldState model.WatchDogState
|
||||
if oldWD != nil {
|
||||
oldState = oldWD.GetState()
|
||||
}
|
||||
|
||||
// Signal all handlers to stop by closing the stop channel
|
||||
// This will cause any goroutine waiting on <-a.watchdogStop to unblock
|
||||
if a.watchdogStop != nil {
|
||||
close(a.watchdogStop)
|
||||
a.watchdogStop = nil
|
||||
}
|
||||
|
||||
// Shutdown existing watchdog if running
|
||||
currentWD := a.modelLoader.GetWatchDog()
|
||||
if currentWD != nil {
|
||||
currentWD.Shutdown()
|
||||
// Wait a bit for shutdown to complete
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Shutdown existing watchdog - this triggers the stop signal
|
||||
if oldWD != nil {
|
||||
oldWD.Shutdown()
|
||||
// Wait for the old watchdog's Run() goroutine to fully shut down
|
||||
oldWD.WaitDone()
|
||||
}
|
||||
|
||||
// Start watchdog with new settings
|
||||
return a.startWatchdog()
|
||||
if err := a.startWatchdog(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Restore the model state from the old watchdog to the new one
|
||||
// This ensures the new watchdog knows about already-loaded models
|
||||
newWD := a.modelLoader.GetWatchDog()
|
||||
if newWD != nil && len(oldState.AddressModelMap) > 0 {
|
||||
newWD.RestoreState(oldState)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,8 +2,10 @@ package backend
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -53,7 +55,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConf
|
||||
}
|
||||
}
|
||||
|
||||
return func() ([]float32, error) {
|
||||
wrappedFn := func() ([]float32, error) {
|
||||
embeds, err := fn()
|
||||
if err != nil {
|
||||
return embeds, err
|
||||
@@ -67,5 +69,48 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConf
|
||||
}
|
||||
}
|
||||
return embeds, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
|
||||
traceData := map[string]any{
|
||||
"input_text": trace.TruncateString(s, 1000),
|
||||
"input_tokens_count": len(tokens),
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
originalFn := wrappedFn
|
||||
wrappedFn = func() ([]float32, error) {
|
||||
result, err := originalFn()
|
||||
duration := time.Since(startTime)
|
||||
|
||||
traceData["embedding_dimensions"] = len(result)
|
||||
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
|
||||
summary := trace.TruncateString(s, 200)
|
||||
if summary == "" {
|
||||
summary = fmt.Sprintf("tokens[%d]", len(tokens))
|
||||
}
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: duration,
|
||||
Type: trace.BackendTraceEmbedding,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: summary,
|
||||
Error: errStr,
|
||||
Data: traceData,
|
||||
})
|
||||
|
||||
return result, err
|
||||
}
|
||||
}
|
||||
|
||||
return wrappedFn, nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -36,6 +39,46 @@ func ImageGeneration(height, width, step, seed int, positive_prompt, negative_pr
|
||||
return err
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
|
||||
traceData := map[string]any{
|
||||
"positive_prompt": positive_prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"step": step,
|
||||
"seed": seed,
|
||||
"source_image": src,
|
||||
"destination": dst,
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
originalFn := fn
|
||||
fn = func() error {
|
||||
err := originalFn()
|
||||
duration := time.Since(startTime)
|
||||
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: duration,
|
||||
Type: trace.BackendTraceImageGeneration,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(positive_prompt, 200),
|
||||
Error: errStr,
|
||||
Data: traceData,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,11 +7,13 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
|
||||
@@ -220,6 +222,84 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
}
|
||||
|
||||
if o.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(o.TracingMaxItems)
|
||||
|
||||
traceData := map[string]any{
|
||||
"prompt": s,
|
||||
"use_tokenizer_template": c.TemplateConfig.UseTokenizerTemplate,
|
||||
"chat_template": c.TemplateConfig.Chat,
|
||||
"function_template": c.TemplateConfig.Functions,
|
||||
"grammar": c.Grammar,
|
||||
"stop_words": c.StopWords,
|
||||
"streaming": tokenCallback != nil,
|
||||
"images_count": len(images),
|
||||
"videos_count": len(videos),
|
||||
"audios_count": len(audios),
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
if msgJSON, err := json.Marshal(messages); err == nil {
|
||||
traceData["messages"] = string(msgJSON)
|
||||
}
|
||||
}
|
||||
if tools != "" {
|
||||
traceData["tools"] = tools
|
||||
}
|
||||
if toolChoice != "" {
|
||||
traceData["tool_choice"] = toolChoice
|
||||
}
|
||||
if reasoningJSON, err := json.Marshal(c.ReasoningConfig); err == nil {
|
||||
traceData["reasoning_config"] = string(reasoningJSON)
|
||||
}
|
||||
traceData["functions_config"] = map[string]any{
|
||||
"grammar_disabled": c.FunctionsConfig.GrammarConfig.NoGrammar,
|
||||
"parallel_calls": c.FunctionsConfig.GrammarConfig.ParallelCalls,
|
||||
"mixed_mode": c.FunctionsConfig.GrammarConfig.MixedMode,
|
||||
"xml_format_preset": c.FunctionsConfig.XMLFormatPreset,
|
||||
}
|
||||
if c.Temperature != nil {
|
||||
traceData["temperature"] = *c.Temperature
|
||||
}
|
||||
if c.TopP != nil {
|
||||
traceData["top_p"] = *c.TopP
|
||||
}
|
||||
if c.Maxtokens != nil {
|
||||
traceData["max_tokens"] = *c.Maxtokens
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
originalFn := fn
|
||||
fn = func() (LLMResponse, error) {
|
||||
resp, err := originalFn()
|
||||
duration := time.Since(startTime)
|
||||
|
||||
traceData["response"] = resp.Response
|
||||
traceData["token_usage"] = map[string]any{
|
||||
"prompt": resp.Usage.Prompt,
|
||||
"completion": resp.Usage.Completion,
|
||||
}
|
||||
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: duration,
|
||||
Type: trace.BackendTraceLLM,
|
||||
ModelName: c.Name,
|
||||
Backend: c.Backend,
|
||||
Summary: trace.GenerateLLMSummary(messages, s),
|
||||
Error: errStr,
|
||||
Data: traceData,
|
||||
})
|
||||
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,10 @@ package backend
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
@@ -20,7 +22,35 @@ func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *
|
||||
return nil, fmt.Errorf("could not load rerank model")
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := rerankModel.Rerank(context.Background(), request)
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceRerank,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(request.Query, 200),
|
||||
Error: errStr,
|
||||
Data: map[string]any{
|
||||
"query": request.Query,
|
||||
"documents_count": len(request.Documents),
|
||||
"top_n": request.TopN,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
@@ -92,7 +94,51 @@ func SoundGeneration(
|
||||
req.Instrumental = instrumental
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := soundGenModel.SoundGeneration(context.Background(), req)
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
} else if res != nil && !res.Success {
|
||||
errStr = fmt.Sprintf("sound generation error: %s", res.Message)
|
||||
}
|
||||
|
||||
summary := trace.TruncateString(text, 200)
|
||||
if summary == "" && caption != "" {
|
||||
summary = trace.TruncateString(caption, 200)
|
||||
}
|
||||
|
||||
traceData := map[string]any{
|
||||
"text": text,
|
||||
"caption": caption,
|
||||
"lyrics": lyrics,
|
||||
}
|
||||
if duration != nil {
|
||||
traceData["duration"] = *duration
|
||||
}
|
||||
if temperature != nil {
|
||||
traceData["temperature"] = *temperature
|
||||
}
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceSoundGeneration,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: summary,
|
||||
Error: errStr,
|
||||
Data: traceData,
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -21,8 +24,41 @@ func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.Model
|
||||
predictOptions := gRPCPredictOpts(modelConfig, loader.ModelPath)
|
||||
predictOptions.Prompt = s
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
// tokenize the string
|
||||
resp, err := inferenceModel.TokenizeString(appConfig.Context, predictOptions)
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
|
||||
tokenCount := 0
|
||||
if resp.Tokens != nil {
|
||||
tokenCount = len(resp.Tokens)
|
||||
}
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceTokenize,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(s, 200),
|
||||
Error: errStr,
|
||||
Data: map[string]any{
|
||||
"input_text": trace.TruncateString(s, 1000),
|
||||
"token_count": tokenCount,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return schema.TokenizeResponse{}, err
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -28,6 +29,12 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
||||
return nil, fmt.Errorf("could not load transcription model")
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
r, err := transcriptionModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||
Dst: audio,
|
||||
Language: language,
|
||||
@@ -37,6 +44,24 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
||||
Prompt: prompt,
|
||||
})
|
||||
if err != nil {
|
||||
if appConfig.EnableTracing {
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceTranscription,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(audio, 200),
|
||||
Error: err.Error(),
|
||||
Data: map[string]any{
|
||||
"audio_file": audio,
|
||||
"language": language,
|
||||
"translate": translate,
|
||||
"diarize": diarize,
|
||||
"prompt": prompt,
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
tr := &schema.TranscriptionResult{
|
||||
@@ -57,5 +82,26 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
||||
Speaker: s.Speaker,
|
||||
})
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceTranscription,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(audio+" -> "+tr.Text, 200),
|
||||
Data: map[string]any{
|
||||
"audio_file": audio,
|
||||
"language": language,
|
||||
"translate": translate,
|
||||
"diarize": diarize,
|
||||
"prompt": prompt,
|
||||
"result_text": tr.Text,
|
||||
"segments_count": len(tr.Segments),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return tr, err
|
||||
}
|
||||
|
||||
@@ -8,8 +8,10 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
laudio "github.com/mudler/LocalAI/pkg/audio"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -60,6 +62,12 @@ func ModelTTS(
|
||||
modelPath = modelConfig.Model // skip this step if it fails?????
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
@@ -67,6 +75,31 @@ func ModelTTS(
|
||||
Dst: filePath,
|
||||
Language: &language,
|
||||
})
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
} else if !res.Success {
|
||||
errStr = fmt.Sprintf("TTS error: %s", res.Message)
|
||||
}
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceTTS,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(text, 200),
|
||||
Error: errStr,
|
||||
Data: map[string]any{
|
||||
"text": text,
|
||||
"voice": voice,
|
||||
"language": language,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
@@ -115,6 +148,12 @@ func ModelTTSStream(
|
||||
modelPath = modelConfig.Model // skip this step if it fails?????
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
var sampleRate uint32 = 16000 // default
|
||||
headerSent := false
|
||||
var callbackErr error
|
||||
@@ -171,6 +210,34 @@ func ModelTTSStream(
|
||||
}
|
||||
})
|
||||
|
||||
resultErr := err
|
||||
if callbackErr != nil {
|
||||
resultErr = callbackErr
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if resultErr != nil {
|
||||
errStr = resultErr.Error()
|
||||
}
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceTTS,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(text, 200),
|
||||
Error: errStr,
|
||||
Data: map[string]any{
|
||||
"text": text,
|
||||
"voice": voice,
|
||||
"language": language,
|
||||
"streaming": true,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if callbackErr != nil {
|
||||
return callbackErr
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -37,5 +40,46 @@ func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, en
|
||||
return err
|
||||
}
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
|
||||
traceData := map[string]any{
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negativePrompt,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"num_frames": numFrames,
|
||||
"fps": fps,
|
||||
"seed": seed,
|
||||
"cfg_scale": cfgScale,
|
||||
"step": step,
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
originalFn := fn
|
||||
fn = func() error {
|
||||
err := originalFn()
|
||||
duration := time.Since(startTime)
|
||||
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: duration,
|
||||
Type: trace.BackendTraceVideoGeneration,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(prompt, 200),
|
||||
Error: errStr,
|
||||
Data: traceData,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
@@ -71,6 +71,7 @@ type RunCMD struct {
|
||||
WatchdogIdleTimeout string `env:"LOCALAI_WATCHDOG_IDLE_TIMEOUT,WATCHDOG_IDLE_TIMEOUT" default:"15m" help:"Threshold beyond which an idle backend should be stopped" group:"backends"`
|
||||
EnableWatchdogBusy bool `env:"LOCALAI_WATCHDOG_BUSY,WATCHDOG_BUSY" default:"false" help:"Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout" group:"backends"`
|
||||
WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"`
|
||||
WatchdogInterval string `env:"LOCALAI_WATCHDOG_INTERVAL,WATCHDOG_INTERVAL" default:"500ms" help:"Interval between watchdog checks (e.g., 500ms, 5s, 1m) (default: 500ms)" group:"backends"`
|
||||
EnableMemoryReclaimer bool `env:"LOCALAI_MEMORY_RECLAIMER,MEMORY_RECLAIMER,LOCALAI_GPU_RECLAIMER,GPU_RECLAIMER" default:"false" help:"Enable memory threshold monitoring to auto-evict backends when memory usage exceeds threshold (uses GPU VRAM if available, otherwise RAM)" group:"backends"`
|
||||
MemoryReclaimerThreshold float64 `env:"LOCALAI_MEMORY_RECLAIMER_THRESHOLD,MEMORY_RECLAIMER_THRESHOLD,LOCALAI_GPU_RECLAIMER_THRESHOLD,GPU_RECLAIMER_THRESHOLD" default:"0.95" help:"Memory usage threshold (0.0-1.0) that triggers backend eviction (default 0.95 = 95%%)" group:"backends"`
|
||||
ForceEvictionWhenBusy bool `env:"LOCALAI_FORCE_EVICTION_WHEN_BUSY,FORCE_EVICTION_WHEN_BUSY" default:"false" help:"Force eviction even when models have active API calls (default: false for safety)" group:"backends"`
|
||||
@@ -83,7 +84,7 @@ type RunCMD struct {
|
||||
EnableTracing bool `env:"LOCALAI_ENABLE_TRACING,ENABLE_TRACING" help:"Enable API tracing" group:"api"`
|
||||
TracingMaxItems int `env:"LOCALAI_TRACING_MAX_ITEMS" default:"1024" help:"Maximum number of traces to keep" group:"api"`
|
||||
AgentJobRetentionDays int `env:"LOCALAI_AGENT_JOB_RETENTION_DAYS,AGENT_JOB_RETENTION_DAYS" default:"30" help:"Number of days to keep agent job history (default: 30)" group:"api"`
|
||||
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
|
||||
OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"`
|
||||
|
||||
Version bool
|
||||
}
|
||||
@@ -215,6 +216,13 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
opts = append(opts, config.SetWatchDogBusyTimeout(dur))
|
||||
}
|
||||
if r.WatchdogInterval != "" {
|
||||
dur, err := time.ParseDuration(r.WatchdogInterval)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts = append(opts, config.SetWatchDogInterval(dur))
|
||||
}
|
||||
}
|
||||
|
||||
// Handle memory reclaimer (uses GPU VRAM if available, otherwise RAM)
|
||||
|
||||
@@ -31,8 +31,8 @@ type TranscriptCMD struct {
|
||||
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||
Prompt string `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"`
|
||||
ResponseFormat schema.TranscriptionResponseFormatType `short:"f" default:"" help:"Response format for Whisper models, can be one of (txt, lrc, srt, vtt, json, json_verbose)"`
|
||||
PrettyPrint bool `help:"Used with response_format json or json_verbose for pretty printing"`
|
||||
ResponseFormat schema.TranscriptionResponseFormatType `short:"f" default:"" help:"Response format for Whisper models, can be one of (txt, lrc, srt, vtt, json, verbose_json)"`
|
||||
PrettyPrint bool `help:"Used with response_format json or verbose_json for pretty printing"`
|
||||
}
|
||||
|
||||
func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
|
||||
|
||||
@@ -98,10 +98,11 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||
Context: context.Background(),
|
||||
UploadLimitMB: 15,
|
||||
Debug: true,
|
||||
AgentJobRetentionDays: 30, // Default: 30 days
|
||||
LRUEvictionMaxRetries: 30, // Default: 30 retries
|
||||
LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
|
||||
TracingMaxItems: 1024,
|
||||
AgentJobRetentionDays: 30, // Default: 30 days
|
||||
LRUEvictionMaxRetries: 30, // Default: 30 retries
|
||||
LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
|
||||
WatchDogInterval: 500 * time.Millisecond, // Default: 500ms
|
||||
TracingMaxItems: 1024,
|
||||
PathWithoutAuth: []string{
|
||||
"/static/",
|
||||
"/generated-audio/",
|
||||
@@ -208,6 +209,12 @@ func SetWatchDogIdleTimeout(t time.Duration) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func SetWatchDogInterval(t time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.WatchDogInterval = t
|
||||
}
|
||||
}
|
||||
|
||||
// EnableMemoryReclaimer enables memory threshold monitoring.
|
||||
// When enabled, the watchdog will evict backends if memory usage exceeds the threshold.
|
||||
// Works with GPU VRAM if available, otherwise uses system RAM.
|
||||
@@ -642,7 +649,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
AutoloadBackendGalleries: &autoloadBackendGalleries,
|
||||
ApiKeys: &apiKeys,
|
||||
AgentJobRetentionDays: &agentJobRetentionDays,
|
||||
OpenResponsesStoreTTL: &openResponsesStoreTTL,
|
||||
OpenResponsesStoreTTL: &openResponsesStoreTTL,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -99,6 +99,10 @@ type AgentConfig struct {
|
||||
EnablePlanning bool `yaml:"enable_planning,omitempty" json:"enable_planning,omitempty"`
|
||||
EnableMCPPrompts bool `yaml:"enable_mcp_prompts,omitempty" json:"enable_mcp_prompts,omitempty"`
|
||||
EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator,omitempty" json:"enable_plan_re_evaluator,omitempty"`
|
||||
DisableSinkState bool `yaml:"disable_sink_state,omitempty" json:"disable_sink_state,omitempty"`
|
||||
LoopDetection int `yaml:"loop_detection,omitempty" json:"loop_detection,omitempty"`
|
||||
MaxAdjustmentAttempts int `yaml:"max_adjustment_attempts,omitempty" json:"max_adjustment_attempts,omitempty"`
|
||||
ForceReasoningTool bool `yaml:"force_reasoning_tool,omitempty" json:"force_reasoning_tool,omitempty"`
|
||||
}
|
||||
|
||||
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
|
||||
@@ -704,7 +708,7 @@ func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
|
||||
|
||||
// Apply agent configuration options
|
||||
if c.Agent.EnableReasoning {
|
||||
cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
|
||||
cogitoOpts = append(cogitoOpts, cogito.WithForceReasoning())
|
||||
}
|
||||
|
||||
if c.Agent.EnablePlanning {
|
||||
@@ -727,5 +731,21 @@ func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
|
||||
cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(c.Agent.MaxAttempts))
|
||||
}
|
||||
|
||||
if c.Agent.DisableSinkState {
|
||||
cogitoOpts = append(cogitoOpts, cogito.DisableSinkState)
|
||||
}
|
||||
|
||||
if c.Agent.LoopDetection != 0 {
|
||||
cogitoOpts = append(cogitoOpts, cogito.WithLoopDetection(c.Agent.LoopDetection))
|
||||
}
|
||||
|
||||
if c.Agent.MaxAdjustmentAttempts != 0 {
|
||||
cogitoOpts = append(cogitoOpts, cogito.WithMaxAdjustmentAttempts(c.Agent.MaxAdjustmentAttempts))
|
||||
}
|
||||
|
||||
if c.Agent.ForceReasoningTool {
|
||||
cogitoOpts = append(cogitoOpts, cogito.WithForceReasoningTool())
|
||||
}
|
||||
|
||||
return cogitoOpts
|
||||
}
|
||||
|
||||
@@ -76,42 +76,35 @@ func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: either in the next PR or the next commit, I want to merge these down into a single function that looks at the first few characters of the file to determine if we need to deserialize to []BackendConfig or BackendConfig
|
||||
func readMultipleModelConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*ModelConfig, error) {
|
||||
c := &[]*ModelConfig{}
|
||||
// readModelConfigsFromFile reads a config file that may contain either a single
|
||||
// ModelConfig or an array of ModelConfigs. It tries to unmarshal as an array first,
|
||||
// then falls back to a single config if that fails.
|
||||
func readModelConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*ModelConfig, error) {
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("readMultipleModelConfigsFromFile cannot read config file %q: %w", file, err)
|
||||
}
|
||||
if err := yaml.Unmarshal(f, c); err != nil {
|
||||
return nil, fmt.Errorf("readMultipleModelConfigsFromFile cannot unmarshal config file %q: %w", file, err)
|
||||
return nil, fmt.Errorf("readModelConfigsFromFile cannot read config file %q: %w", file, err)
|
||||
}
|
||||
|
||||
for _, cc := range *c {
|
||||
cc.modelConfigFile = file
|
||||
cc.SetDefaults(opts...)
|
||||
// Try to unmarshal as array first
|
||||
var configs []*ModelConfig
|
||||
if err := yaml.Unmarshal(f, &configs); err == nil && len(configs) > 0 {
|
||||
for _, cc := range configs {
|
||||
cc.modelConfigFile = file
|
||||
cc.SetDefaults(opts...)
|
||||
}
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
func readModelConfigFromFile(file string, opts ...ConfigLoaderOption) (*ModelConfig, error) {
|
||||
lo := &LoadOptions{}
|
||||
lo.Apply(opts...)
|
||||
|
||||
// Fall back to single config
|
||||
c := &ModelConfig{}
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("readModelConfigFromFile cannot read config file %q: %w", file, err)
|
||||
}
|
||||
if err := yaml.Unmarshal(f, c); err != nil {
|
||||
return nil, fmt.Errorf("readModelConfigFromFile cannot unmarshal config file %q: %w", file, err)
|
||||
return nil, fmt.Errorf("readModelConfigsFromFile cannot unmarshal config file %q: %w", file, err)
|
||||
}
|
||||
|
||||
c.SetDefaults(opts...)
|
||||
|
||||
c.modelConfigFile = file
|
||||
return c, nil
|
||||
c.SetDefaults(opts...)
|
||||
|
||||
return []*ModelConfig{c}, nil
|
||||
}
|
||||
|
||||
// Load a config file for a model
|
||||
@@ -163,7 +156,7 @@ func (bcl *ModelConfigLoader) LoadModelConfigFileByNameDefaultOptions(modelName
|
||||
func (bcl *ModelConfigLoader) LoadMultipleModelConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
c, err := readMultipleModelConfigsFromFile(file, opts...)
|
||||
c, err := readModelConfigsFromFile(file, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot load config file: %w", err)
|
||||
}
|
||||
@@ -181,11 +174,18 @@ func (bcl *ModelConfigLoader) LoadMultipleModelConfigsSingleFile(file string, op
|
||||
func (bcl *ModelConfigLoader) ReadModelConfig(file string, opts ...ConfigLoaderOption) error {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
c, err := readModelConfigFromFile(file, opts...)
|
||||
configs, err := readModelConfigsFromFile(file, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ReadModelConfig cannot read config file %q: %w", file, err)
|
||||
}
|
||||
if len(configs) == 0 {
|
||||
return fmt.Errorf("ReadModelConfig: no configs found in file %q", file)
|
||||
}
|
||||
if len(configs) > 1 {
|
||||
xlog.Warn("ReadModelConig: read more than one config from file, only using first", "file", file, "configs", len(configs))
|
||||
}
|
||||
|
||||
c := configs[0]
|
||||
if valid, err := c.Validate(); valid {
|
||||
bcl.configs[c.Name] = *c
|
||||
} else {
|
||||
@@ -375,15 +375,23 @@ func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...Conf
|
||||
strings.HasPrefix(file.Name(), ".") {
|
||||
continue
|
||||
}
|
||||
c, err := readModelConfigFromFile(filepath.Join(path, file.Name()), opts...)
|
||||
|
||||
filePath := filepath.Join(path, file.Name())
|
||||
|
||||
// Read config(s) - handles both single and array formats
|
||||
configs, err := readModelConfigsFromFile(filePath, opts...)
|
||||
if err != nil {
|
||||
xlog.Error("LoadModelConfigsFromPath cannot read config file", "error", err, "File Name", file.Name())
|
||||
continue
|
||||
}
|
||||
if valid, validationErr := c.Validate(); valid {
|
||||
bcl.configs[c.Name] = *c
|
||||
} else {
|
||||
xlog.Error("config is not valid", "error", validationErr, "Name", c.Name)
|
||||
|
||||
// Validate and store each config
|
||||
for _, c := range configs {
|
||||
if valid, validationErr := c.Validate(); valid {
|
||||
bcl.configs[c.Name] = *c
|
||||
} else {
|
||||
xlog.Error("config is not valid", "error", validationErr, "Name", c.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,8 @@ known_usecases:
|
||||
- COMPLETION
|
||||
`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
config, err := readModelConfigFromFile(tmp.Name())
|
||||
configs, err := readModelConfigsFromFile(tmp.Name())
|
||||
config := configs[0]
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
valid, err := config.Validate()
|
||||
@@ -43,7 +44,8 @@ backend: "foo-bar"
|
||||
parameters:
|
||||
model: "foo-bar"`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
config, err := readModelConfigFromFile(tmp.Name())
|
||||
configs, err := readModelConfigsFromFile(tmp.Name())
|
||||
config := configs[0]
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
// two configs in config.yaml
|
||||
@@ -62,7 +64,8 @@ parameters:
|
||||
defer os.Remove(tmp.Name())
|
||||
_, err = io.Copy(tmp, resp.Body)
|
||||
Expect(err).To(BeNil())
|
||||
config, err = readModelConfigFromFile(tmp.Name())
|
||||
configs, err = readModelConfigsFromFile(tmp.Name())
|
||||
config = configs[0]
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
// two configs in config.yaml
|
||||
@@ -188,7 +191,8 @@ mcp:
|
||||
}
|
||||
}`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
config, err := readModelConfigFromFile(tmp.Name())
|
||||
configs, err := readModelConfigsFromFile(tmp.Name())
|
||||
config := configs[0]
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
valid, err := config.Validate()
|
||||
@@ -218,7 +222,8 @@ mcp:
|
||||
}
|
||||
}`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
config, err := readModelConfigFromFile(tmp.Name())
|
||||
configs, err := readModelConfigsFromFile(tmp.Name())
|
||||
config := configs[0]
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
valid, err := config.Validate()
|
||||
|
||||
@@ -16,7 +16,7 @@ var _ = Describe("Test cases for config related functions", func() {
|
||||
Context("Test Read configuration functions", func() {
|
||||
configFile = os.Getenv("CONFIG_FILE")
|
||||
It("Test readConfigFile", func() {
|
||||
config, err := readMultipleModelConfigsFromFile(configFile)
|
||||
config, err := readModelConfigsFromFile(configFile)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
// two configs in config.yaml
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/xsync"
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
|
||||
|
||||
@@ -4,11 +4,12 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"dario.cat/mergo"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/gallery"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var _ = Describe("Gallery", func() {
|
||||
@@ -462,4 +463,60 @@ var _ = Describe("Gallery", func() {
|
||||
Expect(result).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("YAML merge with nested maps", func() {
|
||||
It("should handle YAML anchors and merges with nested overrides (regression test for nanbeige4.1)", func() {
|
||||
// This tests the fix for the panic that occurred with yaml.v2:
|
||||
// yaml.v2 produces map[interface{}]interface{} for nested maps
|
||||
// which caused mergo.Merge to panic with "value of type interface {} is not assignable to type string"
|
||||
// The exact YAML structure from gallery/index.yaml nanbeige4.1 entries
|
||||
yamlContent := `---
|
||||
- &nanbeige4
|
||||
name: "nanbeige4.1-3b-q8"
|
||||
overrides:
|
||||
parameters:
|
||||
model: nanbeige4.1-3b-q8_0.gguf
|
||||
- !!merge <<: *nanbeige4
|
||||
name: "nanbeige4.1-3b-q4"
|
||||
overrides:
|
||||
parameters:
|
||||
model: nanbeige4.1-3b-q4_k_m.gguf
|
||||
`
|
||||
var models []GalleryModel
|
||||
err := yaml.Unmarshal([]byte(yamlContent), &models)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(models).To(HaveLen(2))
|
||||
|
||||
// Verify first model
|
||||
Expect(models[0].Name).To(Equal("nanbeige4.1-3b-q8"))
|
||||
Expect(models[0].Overrides).NotTo(BeNil())
|
||||
Expect(models[0].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{}))
|
||||
params := models[0].Overrides["parameters"].(map[string]interface{})
|
||||
Expect(params["model"]).To(Equal("nanbeige4.1-3b-q8_0.gguf"))
|
||||
|
||||
// Verify second model (merged)
|
||||
Expect(models[1].Name).To(Equal("nanbeige4.1-3b-q4"))
|
||||
Expect(models[1].Overrides).NotTo(BeNil())
|
||||
Expect(models[1].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{}))
|
||||
params = models[1].Overrides["parameters"].(map[string]interface{})
|
||||
Expect(params["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf"))
|
||||
|
||||
// Simulate the mergo.Merge call that was failing in models.go:251
|
||||
// This should not panic with yaml.v3
|
||||
configMap := make(map[string]interface{})
|
||||
configMap["name"] = "test"
|
||||
configMap["backend"] = "llama-cpp"
|
||||
configMap["parameters"] = map[string]interface{}{
|
||||
"model": "original.gguf",
|
||||
}
|
||||
|
||||
err = mergo.Merge(&configMap, models[1].Overrides, mergo.WithOverride)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(configMap["parameters"]).NotTo(BeNil())
|
||||
|
||||
// Verify the merge worked correctly
|
||||
mergedParams := configMap["parameters"].(map[string]interface{})
|
||||
Expect(mergedParams["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -215,7 +215,7 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver
|
||||
return nil, fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err)
|
||||
}
|
||||
// Create and write file content
|
||||
err = os.WriteFile(filePath, []byte(template.Content), 0600)
|
||||
err = os.WriteFile(filePath, []byte(template.Content), 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write prompt template %q: %v", template.Name, err)
|
||||
}
|
||||
@@ -268,7 +268,7 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver
|
||||
return nil, fmt.Errorf("failed to validate updated config YAML: %v", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(configFilePath, updatedConfigYAML, 0600)
|
||||
err = os.WriteFile(configFilePath, updatedConfigYAML, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write updated config file: %v", err)
|
||||
}
|
||||
@@ -285,7 +285,7 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver
|
||||
|
||||
xlog.Debug("Written gallery file", "file", modelFile)
|
||||
|
||||
return &modelConfig, os.WriteFile(modelFile, data, 0600)
|
||||
return &modelConfig, os.WriteFile(modelFile, data, 0644)
|
||||
}
|
||||
|
||||
func galleryFileName(name string) string {
|
||||
|
||||
@@ -29,6 +29,8 @@ import (
|
||||
//go:embed static/*
|
||||
var embedDirStatic embed.FS
|
||||
|
||||
var quietPaths = []string{"/api/operations", "/api/resources", "/healthz", "/readyz"}
|
||||
|
||||
// @title LocalAI API
|
||||
// @version 2.0.0
|
||||
// @description The LocalAI Rest API.
|
||||
@@ -109,10 +111,17 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
res := c.Response()
|
||||
err := next(c)
|
||||
|
||||
// Fix for #7989: Reduce log verbosity of Web UI polling
|
||||
// If the path is /api/operations and the request was successful (200),
|
||||
// we log it at DEBUG level (hidden by default) instead of INFO.
|
||||
if req.URL.Path == "/api/operations" && res.Status == 200 {
|
||||
// Fix for #7989: Reduce log verbosity of Web UI polling, resources API, and health checks
|
||||
// These paths are logged at DEBUG level (hidden by default) instead of INFO.
|
||||
isQuietPath := false
|
||||
for _, path := range quietPaths {
|
||||
if req.URL.Path == path {
|
||||
isQuietPath = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isQuietPath && res.Status == 200 {
|
||||
xlog.Debug("HTTP request", "method", req.Method, "path", req.URL.Path, "status", res.Status)
|
||||
} else {
|
||||
xlog.Info("HTTP request", "method", req.Method, "path", req.URL.Path, "status", res.Status)
|
||||
|
||||
@@ -336,6 +336,7 @@ var _ = Describe("API test", func() {
|
||||
Name: "bert",
|
||||
URL: bertEmbeddingsURL,
|
||||
},
|
||||
Overrides: map[string]interface{}{"backend": "llama-cpp"},
|
||||
},
|
||||
{
|
||||
Metadata: gallery.Metadata{
|
||||
@@ -953,7 +954,8 @@ parameters:
|
||||
It("returns the models list", func() {
|
||||
models, err := client.ListModels(context.TODO())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(models.Models)).To(Equal(7)) // If "config.yaml" should be included, this should be 8?
|
||||
// A model called "bert" can be present in the model directory depending on the order of the tests
|
||||
Expect(len(models.Models)).To(BeNumerically(">=", 8))
|
||||
})
|
||||
It("can generate completions via ggml", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -55,20 +56,22 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
|
||||
// Render the edit page with the current configuration
|
||||
templateData := struct {
|
||||
Title string
|
||||
ModelName string
|
||||
Config *config.ModelConfig
|
||||
ConfigJSON string
|
||||
ConfigYAML string
|
||||
BaseURL string
|
||||
Version string
|
||||
Title string
|
||||
ModelName string
|
||||
Config *config.ModelConfig
|
||||
ConfigJSON string
|
||||
ConfigYAML string
|
||||
BaseURL string
|
||||
Version string
|
||||
DisableRuntimeSettings bool
|
||||
}{
|
||||
Title: "LocalAI - Edit Model " + modelName,
|
||||
ModelName: modelName,
|
||||
Config: &modelConfig,
|
||||
ConfigYAML: string(configData),
|
||||
BaseURL: httpUtils.BaseURL(c),
|
||||
Version: internal.PrintableVersion(),
|
||||
Title: "LocalAI - Edit Model " + modelName,
|
||||
ModelName: modelName,
|
||||
Config: &modelConfig,
|
||||
ConfigYAML: string(configData),
|
||||
BaseURL: httpUtils.BaseURL(c),
|
||||
Version: internal.PrintableVersion(),
|
||||
DisableRuntimeSettings: appConfig.DisableRuntimeSettings,
|
||||
}
|
||||
|
||||
return c.Render(http.StatusOK, "views/model-editor", templateData)
|
||||
@@ -76,7 +79,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
}
|
||||
|
||||
// EditModelEndpoint handles updating existing model configurations
|
||||
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
func EditModelEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
if modelName == "" {
|
||||
@@ -172,6 +175,14 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Shutdown the running model to apply new configuration (e.g., context_size)
|
||||
// The model will be reloaded on the next inference request
|
||||
if err := ml.ShutdownModel(modelName); err != nil {
|
||||
// Log the error but don't fail the request - the config was saved successfully
|
||||
// The model can still be manually reloaded or restarted
|
||||
fmt.Printf("Warning: Failed to shutdown model '%s': %v\n", modelName, err)
|
||||
}
|
||||
|
||||
// Preload the model
|
||||
if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
response := ModelResponse{
|
||||
@@ -184,7 +195,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
// Return success response
|
||||
response := ModelResponse{
|
||||
Success: true,
|
||||
Message: fmt.Sprintf("Model '%s' updated successfully", modelName),
|
||||
Message: fmt.Sprintf("Model '%s' updated successfully. Model has been reloaded with new configuration.", modelName),
|
||||
Filename: configPath,
|
||||
Config: req,
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ func MCPEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
// Build fragment from messages
|
||||
fragment := cogito.NewEmptyFragment()
|
||||
for _, message := range input.Messages {
|
||||
fragment = fragment.AddMessage(message.Role, message.StringContent)
|
||||
fragment = fragment.AddMessage(cogito.MessageRole(message.Role), message.StringContent)
|
||||
}
|
||||
|
||||
_, port, err := net.SplitHostPort(appConfig.APIAddress)
|
||||
@@ -162,11 +162,6 @@ func MCPEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
return err
|
||||
}
|
||||
|
||||
f, err = defaultLLM.Ask(ctxWithCancellation, f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
@@ -252,17 +247,6 @@ func MCPEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
return
|
||||
}
|
||||
|
||||
// Get final response
|
||||
f, err = defaultLLM.Ask(ctxWithCancellation, f)
|
||||
if err != nil {
|
||||
events <- MCPErrorEvent{
|
||||
Type: "error",
|
||||
Message: fmt.Sprintf("Failed to get response: %v", err),
|
||||
}
|
||||
ended <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Stream final assistant response
|
||||
content := f.LastMessage().Content
|
||||
events <- MCPAssistantEvent{
|
||||
|
||||
@@ -79,6 +79,14 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
return err
|
||||
}
|
||||
|
||||
// Resample to requested sample rate if specified
|
||||
if input.SampleRate > 0 {
|
||||
filePath, err = utils.AudioResample(filePath, input.SampleRate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Convert generated file to target format
|
||||
filePath, err = utils.AudioConvert(filePath, input.Format)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,10 +23,15 @@ import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
func downloadFile(url string) (string, error) {
|
||||
if err := utils.ValidateExternalURL(url); err != nil {
|
||||
return "", fmt.Errorf("URL validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Get the data
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
|
||||
@@ -27,18 +27,36 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/reasoning"
|
||||
"github.com/mudler/LocalAI/pkg/sound"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
const (
|
||||
localSampleRate = 16000
|
||||
remoteSampleRate = 24000
|
||||
// XXX: Presently it seems all ASR/VAD backends use 16Khz. If a backend uses 24Khz then it will likely still work, but have reduced performance
|
||||
localSampleRate = 16000
|
||||
defaultRemoteSampleRate = 24000
|
||||
// Maximum audio buffer size in bytes (100MB) to prevent memory exhaustion
|
||||
maxAudioBufferSize = 100 * 1024 * 1024
|
||||
// Maximum WebSocket message size in bytes (10MB) to prevent DoS attacks
|
||||
maxWebSocketMessageSize = 10 * 1024 * 1024
|
||||
)
|
||||
|
||||
// A model can be "emulated" that is: transcribe audio to text -> feed text to the LLM -> generate audio as result
|
||||
// If the model support instead audio-to-audio, we will use the specific gRPC calls instead
|
||||
|
||||
// LockedWebsocket wraps a websocket connection with a mutex for safe concurrent writes
|
||||
type LockedWebsocket struct {
|
||||
*websocket.Conn
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (l *LockedWebsocket) WriteMessage(messageType int, data []byte) error {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
return l.Conn.WriteMessage(messageType, data)
|
||||
}
|
||||
|
||||
// Session represents a single WebSocket connection and its state
|
||||
type Session struct {
|
||||
ID string
|
||||
@@ -58,7 +76,9 @@ type Session struct {
|
||||
DefaultConversationID string
|
||||
ModelInterface Model
|
||||
// The pipeline model config or the config for an any-to-any model
|
||||
ModelConfig *config.ModelConfig
|
||||
ModelConfig *config.ModelConfig
|
||||
InputSampleRate int
|
||||
MaxOutputTokens types.IntOrInf
|
||||
}
|
||||
|
||||
func (s *Session) FromClient(session *types.SessionUnion) {
|
||||
@@ -80,12 +100,13 @@ func (s *Session) ToServer() types.SessionUnion {
|
||||
} else {
|
||||
return types.SessionUnion{
|
||||
Realtime: &types.RealtimeSession{
|
||||
ID: s.ID,
|
||||
Object: "realtime.session",
|
||||
Model: s.Model,
|
||||
Instructions: s.Instructions,
|
||||
Tools: s.Tools,
|
||||
ToolChoice: s.ToolChoice,
|
||||
ID: s.ID,
|
||||
Object: "realtime.session",
|
||||
Model: s.Model,
|
||||
Instructions: s.Instructions,
|
||||
Tools: s.Tools,
|
||||
ToolChoice: s.ToolChoice,
|
||||
MaxOutputTokens: s.MaxOutputTokens,
|
||||
Audio: &types.RealtimeSessionAudio{
|
||||
Input: &types.SessionAudioInput{
|
||||
TurnDetection: s.TurnDetection,
|
||||
@@ -153,6 +174,9 @@ func Realtime(application *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
// Set maximum message size to prevent DoS attacks
|
||||
ws.SetReadLimit(maxWebSocketMessageSize)
|
||||
|
||||
// Extract query parameters from Echo context before passing to websocket handler
|
||||
model := c.QueryParam("model")
|
||||
|
||||
@@ -162,7 +186,8 @@ func Realtime(application *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
|
||||
func registerRealtime(application *application.Application, model string) func(c *websocket.Conn) {
|
||||
return func(c *websocket.Conn) {
|
||||
return func(conn *websocket.Conn) {
|
||||
c := &LockedWebsocket{Conn: conn}
|
||||
|
||||
evaluator := application.TemplatesEvaluator()
|
||||
xlog.Debug("Realtime WebSocket connection established", "address", c.RemoteAddr().String(), "model", model)
|
||||
@@ -183,14 +208,13 @@ func registerRealtime(application *application.Application, model string) func(c
|
||||
}
|
||||
|
||||
sttModel := cfg.Pipeline.Transcription
|
||||
ttsModel := cfg.Pipeline.TTS
|
||||
|
||||
sessionID := generateSessionID()
|
||||
session := &Session{
|
||||
ID: sessionID,
|
||||
TranscriptionOnly: false,
|
||||
Model: model,
|
||||
Voice: ttsModel,
|
||||
Voice: cfg.TTSConfig.Voice,
|
||||
ModelConfig: cfg,
|
||||
TurnDetection: &types.TurnDetectionUnion{
|
||||
ServerVad: &types.ServerVad{
|
||||
@@ -203,7 +227,8 @@ func registerRealtime(application *application.Application, model string) func(c
|
||||
InputAudioTranscription: &types.AudioTranscription{
|
||||
Model: sttModel,
|
||||
},
|
||||
Conversations: make(map[string]*Conversation),
|
||||
Conversations: make(map[string]*Conversation),
|
||||
InputSampleRate: defaultRemoteSampleRate,
|
||||
}
|
||||
|
||||
// Create a default conversation
|
||||
@@ -355,8 +380,17 @@ func registerRealtime(application *application.Application, model string) func(c
|
||||
continue
|
||||
}
|
||||
|
||||
// Append to InputAudioBuffer
|
||||
// Check buffer size limits before appending
|
||||
session.AudioBufferLock.Lock()
|
||||
newSize := len(session.InputAudioBuffer) + len(decodedAudio)
|
||||
if newSize > maxAudioBufferSize {
|
||||
session.AudioBufferLock.Unlock()
|
||||
xlog.Error("audio buffer size limit exceeded", "current_size", len(session.InputAudioBuffer), "incoming_size", len(decodedAudio), "limit", maxAudioBufferSize)
|
||||
sendError(c, "buffer_size_exceeded", fmt.Sprintf("Audio buffer size limit exceeded (max %d bytes)", maxAudioBufferSize), "", "")
|
||||
continue
|
||||
}
|
||||
|
||||
// Append to InputAudioBuffer
|
||||
session.InputAudioBuffer = append(session.InputAudioBuffer, decodedAudio...)
|
||||
session.AudioBufferLock.Unlock()
|
||||
|
||||
@@ -383,7 +417,36 @@ func registerRealtime(application *application.Application, model string) func(c
|
||||
|
||||
case types.ConversationItemCreateEvent:
|
||||
xlog.Debug("recv", "message", string(msg))
|
||||
sendNotImplemented(c, "conversation.item.create")
|
||||
// Add the item to the conversation
|
||||
item := e.Item
|
||||
// Ensure IDs are present
|
||||
if item.User != nil && item.User.ID == "" {
|
||||
item.User.ID = generateItemID()
|
||||
}
|
||||
if item.Assistant != nil && item.Assistant.ID == "" {
|
||||
item.Assistant.ID = generateItemID()
|
||||
}
|
||||
if item.System != nil && item.System.ID == "" {
|
||||
item.System.ID = generateItemID()
|
||||
}
|
||||
if item.FunctionCall != nil && item.FunctionCall.ID == "" {
|
||||
item.FunctionCall.ID = generateItemID()
|
||||
}
|
||||
if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" {
|
||||
item.FunctionCallOutput.ID = generateItemID()
|
||||
}
|
||||
|
||||
conversation.Lock.Lock()
|
||||
conversation.Items = append(conversation.Items, &item)
|
||||
conversation.Lock.Unlock()
|
||||
|
||||
sendEvent(c, types.ConversationItemAddedEvent{
|
||||
ServerEventBase: types.ServerEventBase{
|
||||
EventID: e.EventID,
|
||||
},
|
||||
PreviousItemID: e.PreviousItemID,
|
||||
Item: item,
|
||||
})
|
||||
|
||||
case types.ConversationItemDeleteEvent:
|
||||
sendError(c, "not_implemented", "Deleting items not implemented", "", "event_TODO")
|
||||
@@ -429,7 +492,34 @@ func registerRealtime(application *application.Application, model string) func(c
|
||||
|
||||
case types.ResponseCreateEvent:
|
||||
xlog.Debug("recv", "message", string(msg))
|
||||
sendNotImplemented(c, "response.create")
|
||||
|
||||
// Handle optional items to add to context
|
||||
if len(e.Response.Input) > 0 {
|
||||
conversation.Lock.Lock()
|
||||
for _, item := range e.Response.Input {
|
||||
// Ensure IDs are present
|
||||
if item.User != nil && item.User.ID == "" {
|
||||
item.User.ID = generateItemID()
|
||||
}
|
||||
if item.Assistant != nil && item.Assistant.ID == "" {
|
||||
item.Assistant.ID = generateItemID()
|
||||
}
|
||||
if item.System != nil && item.System.ID == "" {
|
||||
item.System.ID = generateItemID()
|
||||
}
|
||||
if item.FunctionCall != nil && item.FunctionCall.ID == "" {
|
||||
item.FunctionCall.ID = generateItemID()
|
||||
}
|
||||
if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" {
|
||||
item.FunctionCallOutput.ID = generateItemID()
|
||||
}
|
||||
|
||||
conversation.Items = append(conversation.Items, &item)
|
||||
}
|
||||
conversation.Lock.Unlock()
|
||||
}
|
||||
|
||||
go triggerResponse(session, conversation, c, &e.Response)
|
||||
|
||||
case types.ResponseCancelEvent:
|
||||
xlog.Debug("recv", "message", string(msg))
|
||||
@@ -456,7 +546,7 @@ func registerRealtime(application *application.Application, model string) func(c
|
||||
}
|
||||
|
||||
// Helper function to send events to the client
|
||||
func sendEvent(c *websocket.Conn, event types.ServerEvent) {
|
||||
func sendEvent(c *LockedWebsocket, event types.ServerEvent) {
|
||||
eventBytes, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
xlog.Error("failed to marshal event", "error", err)
|
||||
@@ -468,7 +558,7 @@ func sendEvent(c *websocket.Conn, event types.ServerEvent) {
|
||||
}
|
||||
|
||||
// Helper function to send errors to the client
|
||||
func sendError(c *websocket.Conn, code, message, param, eventID string) {
|
||||
func sendError(c *LockedWebsocket, code, message, param, eventID string) {
|
||||
errorEvent := types.ErrorEvent{
|
||||
ServerEventBase: types.ServerEventBase{
|
||||
EventID: eventID,
|
||||
@@ -485,7 +575,7 @@ func sendError(c *websocket.Conn, code, message, param, eventID string) {
|
||||
sendEvent(c, errorEvent)
|
||||
}
|
||||
|
||||
func sendNotImplemented(c *websocket.Conn, message string) {
|
||||
func sendNotImplemented(c *LockedWebsocket, message string) {
|
||||
sendError(c, "not_implemented", message, "", "event_TODO")
|
||||
}
|
||||
|
||||
@@ -530,6 +620,12 @@ func updateTransSession(session *Session, update *types.SessionUnion, cl *config
|
||||
session.TurnDetection = update.Transcription.Audio.Input.TurnDetection
|
||||
}
|
||||
|
||||
if update.Transcription.Audio.Input.Format != nil && update.Transcription.Audio.Input.Format.PCM != nil {
|
||||
if update.Transcription.Audio.Input.Format.PCM.Rate > 0 {
|
||||
session.InputSampleRate = update.Transcription.Audio.Input.Format.PCM.Rate
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -557,13 +653,13 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
|
||||
session.InputAudioTranscription = &types.AudioTranscription{}
|
||||
}
|
||||
session.InputAudioTranscription.Model = cfg.Pipeline.Transcription
|
||||
session.Voice = cfg.Pipeline.TTS
|
||||
session.Voice = cfg.TTSConfig.Voice
|
||||
session.Model = rt.Model
|
||||
session.ModelConfig = cfg
|
||||
}
|
||||
|
||||
if rt.Audio != nil && rt.Audio.Output != nil && rt.Audio.Output.Voice != "" {
|
||||
xlog.Warn("Ignoring voice setting; not implemented", "voice", rt.Audio.Output.Voice)
|
||||
session.Voice = string(rt.Audio.Output.Voice)
|
||||
}
|
||||
|
||||
if rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.Transcription != nil {
|
||||
@@ -583,6 +679,12 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
|
||||
session.TurnDetection = rt.Audio.Input.TurnDetection
|
||||
}
|
||||
|
||||
if rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.Format != nil && rt.Audio.Input.Format.PCM != nil {
|
||||
if rt.Audio.Input.Format.PCM.Rate > 0 {
|
||||
session.InputSampleRate = rt.Audio.Input.Format.PCM.Rate
|
||||
}
|
||||
}
|
||||
|
||||
if rt.Instructions != "" {
|
||||
session.Instructions = rt.Instructions
|
||||
}
|
||||
@@ -594,12 +696,16 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
|
||||
session.ToolChoice = rt.ToolChoice
|
||||
}
|
||||
|
||||
if rt.MaxOutputTokens != 0 {
|
||||
session.MaxOutputTokens = rt.MaxOutputTokens
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleVAD is a goroutine that listens for audio data from the client,
|
||||
// runs VAD on the audio data, and commits utterances to the conversation
|
||||
func handleVAD(session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
|
||||
func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done chan struct{}) {
|
||||
vadContext, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
<-done
|
||||
@@ -628,12 +734,12 @@ func handleVAD(session *Session, conv *Conversation, c *websocket.Conn, done cha
|
||||
session.AudioBufferLock.Unlock()
|
||||
|
||||
aints := sound.BytesToInt16sLE(allAudio)
|
||||
if len(aints) == 0 || len(aints) < int(silenceThreshold)*remoteSampleRate {
|
||||
if len(aints) == 0 || len(aints) < int(silenceThreshold)*session.InputSampleRate {
|
||||
continue
|
||||
}
|
||||
|
||||
// Resample from 24kHz to 16kHz
|
||||
aints = sound.ResampleInt16(aints, remoteSampleRate, localSampleRate)
|
||||
// Resample from InputSampleRate to 16kHz
|
||||
aints = sound.ResampleInt16(aints, session.InputSampleRate, localSampleRate)
|
||||
|
||||
segments, err := runVAD(vadContext, session, aints)
|
||||
if err != nil {
|
||||
@@ -649,18 +755,18 @@ func handleVAD(session *Session, conv *Conversation, c *websocket.Conn, done cha
|
||||
audioLength := float64(len(aints)) / localSampleRate
|
||||
|
||||
// TODO: When resetting the buffer we should retain a small postfix
|
||||
// TODO: The OpenAI documentation seems to suggest that only the client decides when to clear the buffer
|
||||
if len(segments) == 0 && audioLength > silenceThreshold {
|
||||
session.AudioBufferLock.Lock()
|
||||
session.InputAudioBuffer = nil
|
||||
session.AudioBufferLock.Unlock()
|
||||
xlog.Debug("Detected silence for a while, clearing audio buffer")
|
||||
|
||||
sendEvent(c, types.InputAudioBufferClearedEvent{
|
||||
ServerEventBase: types.ServerEventBase{
|
||||
EventID: "event_TODO",
|
||||
},
|
||||
})
|
||||
// NOTE: OpenAI doesn't send this message unless the client requests it
|
||||
// xlog.Debug("Detected silence for a while, clearing audio buffer")
|
||||
// sendEvent(c, types.InputAudioBufferClearedEvent{
|
||||
// ServerEventBase: types.ServerEventBase{
|
||||
// EventID: "event_TODO",
|
||||
// },
|
||||
// })
|
||||
|
||||
continue
|
||||
} else if len(segments) == 0 {
|
||||
@@ -713,7 +819,7 @@ func handleVAD(session *Session, conv *Conversation, c *websocket.Conn, done cha
|
||||
}
|
||||
}
|
||||
|
||||
func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, c *websocket.Conn) {
|
||||
func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, c *LockedWebsocket) {
|
||||
if len(utt) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -746,6 +852,10 @@ func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Co
|
||||
tr, err := session.ModelInterface.Transcribe(ctx, f.Name(), session.InputAudioTranscription.Language, false, false, session.InputAudioTranscription.Prompt)
|
||||
if err != nil {
|
||||
sendError(c, "transcription_failed", err.Error(), "", "event_TODO")
|
||||
return
|
||||
} else if tr == nil {
|
||||
sendError(c, "transcription_failed", "trancribe result is nil", "", "event_TODO")
|
||||
return
|
||||
}
|
||||
|
||||
transcript = tr.Text
|
||||
@@ -791,11 +901,10 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]schema.VADS
|
||||
}
|
||||
|
||||
// Function to generate a response based on the conversation
|
||||
func generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *websocket.Conn, mt int) {
|
||||
func generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *LockedWebsocket, mt int) {
|
||||
xlog.Debug("Generating realtime response...")
|
||||
|
||||
config := session.ModelInterface.PredictConfig()
|
||||
|
||||
// Create user message item
|
||||
item := types.MessageItemUnion{
|
||||
User: &types.MessageItemUser{
|
||||
ID: generateItemID(),
|
||||
@@ -817,33 +926,100 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con
|
||||
Item: item,
|
||||
})
|
||||
|
||||
triggerResponse(session, conv, c, nil)
|
||||
}
|
||||
|
||||
func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, overrides *types.ResponseCreateParams) {
|
||||
config := session.ModelInterface.PredictConfig()
|
||||
|
||||
// Default values
|
||||
tools := session.Tools
|
||||
toolChoice := session.ToolChoice
|
||||
instructions := session.Instructions
|
||||
maxOutputTokens := session.MaxOutputTokens
|
||||
// Overrides
|
||||
if overrides != nil {
|
||||
if overrides.Tools != nil {
|
||||
tools = overrides.Tools
|
||||
}
|
||||
if overrides.ToolChoice != nil {
|
||||
toolChoice = overrides.ToolChoice
|
||||
}
|
||||
if overrides.Instructions != "" {
|
||||
instructions = overrides.Instructions
|
||||
}
|
||||
if overrides.MaxOutputTokens != 0 {
|
||||
maxOutputTokens = overrides.MaxOutputTokens
|
||||
}
|
||||
}
|
||||
|
||||
// Apply MaxOutputTokens to model config if specified
|
||||
// Save original value to restore after prediction
|
||||
var originalMaxTokens *int
|
||||
if config != nil {
|
||||
originalMaxTokens = config.Maxtokens
|
||||
if maxOutputTokens != 0 && !maxOutputTokens.IsInf() {
|
||||
tokenValue := int(maxOutputTokens)
|
||||
config.Maxtokens = &tokenValue
|
||||
xlog.Debug("Applied max_output_tokens to config", "value", tokenValue)
|
||||
}
|
||||
}
|
||||
// Defer restoration of original value
|
||||
defer func() {
|
||||
if config != nil {
|
||||
config.Maxtokens = originalMaxTokens
|
||||
}
|
||||
}()
|
||||
|
||||
var conversationHistory schema.Messages
|
||||
conversationHistory = append(conversationHistory, schema.Message{
|
||||
Role: string(types.MessageRoleSystem),
|
||||
StringContent: session.Instructions,
|
||||
Content: session.Instructions,
|
||||
StringContent: instructions,
|
||||
Content: instructions,
|
||||
})
|
||||
|
||||
imgIndex := 0
|
||||
conv.Lock.Lock()
|
||||
for _, item := range conv.Items {
|
||||
if item.User != nil {
|
||||
msg := schema.Message{
|
||||
Role: string(types.MessageRoleUser),
|
||||
}
|
||||
textContent := ""
|
||||
nrOfImgsInMessage := 0
|
||||
for _, content := range item.User.Content {
|
||||
switch content.Type {
|
||||
case types.MessageContentTypeInputText:
|
||||
conversationHistory = append(conversationHistory, schema.Message{
|
||||
Role: string(types.MessageRoleUser),
|
||||
StringContent: content.Text,
|
||||
Content: content.Text,
|
||||
})
|
||||
textContent += content.Text
|
||||
case types.MessageContentTypeInputAudio:
|
||||
conversationHistory = append(conversationHistory, schema.Message{
|
||||
Role: string(types.MessageRoleUser),
|
||||
StringContent: content.Transcript,
|
||||
Content: content.Transcript,
|
||||
StringAudios: []string{content.Audio},
|
||||
})
|
||||
textContent += content.Transcript
|
||||
case types.MessageContentTypeInputImage:
|
||||
img, err := utils.GetContentURIAsBase64(content.ImageURL)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to process image", "error", err)
|
||||
continue
|
||||
}
|
||||
msg.StringImages = append(msg.StringImages, img)
|
||||
imgIndex++
|
||||
nrOfImgsInMessage++
|
||||
}
|
||||
}
|
||||
if nrOfImgsInMessage > 0 {
|
||||
templated, err := templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
|
||||
TotalImages: imgIndex,
|
||||
ImagesInMessage: nrOfImgsInMessage,
|
||||
}, textContent)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to apply multimodal template", "error", err)
|
||||
templated = textContent
|
||||
}
|
||||
msg.StringContent = templated
|
||||
msg.Content = templated
|
||||
} else {
|
||||
msg.StringContent = textContent
|
||||
msg.Content = textContent
|
||||
}
|
||||
conversationHistory = append(conversationHistory, msg)
|
||||
} else if item.Assistant != nil {
|
||||
for _, content := range item.Assistant.Content {
|
||||
switch content.Type {
|
||||
@@ -870,10 +1046,36 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con
|
||||
Content: content.Text,
|
||||
})
|
||||
}
|
||||
} else if item.FunctionCall != nil {
|
||||
conversationHistory = append(conversationHistory, schema.Message{
|
||||
Role: string(types.MessageRoleAssistant),
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
ID: item.FunctionCall.CallID,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: item.FunctionCall.Name,
|
||||
Arguments: item.FunctionCall.Arguments,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
} else if item.FunctionCallOutput != nil {
|
||||
conversationHistory = append(conversationHistory, schema.Message{
|
||||
Role: "tool",
|
||||
Name: item.FunctionCallOutput.CallID,
|
||||
Content: item.FunctionCallOutput.Output,
|
||||
StringContent: item.FunctionCallOutput.Output,
|
||||
})
|
||||
}
|
||||
}
|
||||
conv.Lock.Unlock()
|
||||
|
||||
var images []string
|
||||
for _, m := range conversationHistory {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
|
||||
responseID := generateUniqueID()
|
||||
sendEvent(c, types.ResponseCreatedEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
@@ -884,26 +1086,47 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con
|
||||
},
|
||||
})
|
||||
|
||||
predFunc, err := session.ModelInterface.Predict(context.TODO(), conversationHistory, nil, nil, nil, nil, session.Tools, session.ToolChoice, nil, nil, nil)
|
||||
predFunc, err := session.ModelInterface.Predict(context.TODO(), conversationHistory, images, nil, nil, nil, tools, toolChoice, nil, nil, nil)
|
||||
if err != nil {
|
||||
sendError(c, "inference_failed", fmt.Sprintf("backend error: %v", err), "", item.Assistant.ID)
|
||||
sendError(c, "inference_failed", fmt.Sprintf("backend error: %v", err), "", "") // item.Assistant.ID is unknown here
|
||||
return
|
||||
}
|
||||
|
||||
pred, err := predFunc()
|
||||
if err != nil {
|
||||
sendError(c, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", item.Assistant.ID)
|
||||
sendError(c, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", "")
|
||||
return
|
||||
}
|
||||
|
||||
xlog.Debug("Function config for parsing", "function_name_key", config.FunctionsConfig.FunctionNameKey, "function_arguments_key", config.FunctionsConfig.FunctionArgumentsKey)
|
||||
xlog.Debug("LLM raw response", "text", pred.Response, "response_length", len(pred.Response), "usage", pred.Usage)
|
||||
|
||||
// Safely dereference pointer fields for logging
|
||||
maxTokens := "nil"
|
||||
if config.Maxtokens != nil {
|
||||
maxTokens = fmt.Sprintf("%d", *config.Maxtokens)
|
||||
}
|
||||
contextSize := "nil"
|
||||
if config.ContextSize != nil {
|
||||
contextSize = fmt.Sprintf("%d", *config.ContextSize)
|
||||
}
|
||||
xlog.Debug("Model parameters", "max_tokens", maxTokens, "context_size", contextSize, "stopwords", config.StopWords)
|
||||
|
||||
rawResponse := pred.Response
|
||||
if config.TemplateConfig.ReplyPrefix != "" {
|
||||
rawResponse = config.TemplateConfig.ReplyPrefix + rawResponse
|
||||
}
|
||||
|
||||
reasoningText, responseWithoutReasoning := reasoning.ExtractReasoningWithConfig(rawResponse, "", config.ReasoningConfig)
|
||||
// Detect thinking start token from template for reasoning extraction
|
||||
var template string
|
||||
if config.TemplateConfig.UseTokenizerTemplate {
|
||||
template = config.GetModelTemplate()
|
||||
} else {
|
||||
template = config.TemplateConfig.Chat
|
||||
}
|
||||
thinkingStartToken := reasoning.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
||||
|
||||
reasoningText, responseWithoutReasoning := reasoning.ExtractReasoningWithConfig(rawResponse, thinkingStartToken, config.ReasoningConfig)
|
||||
xlog.Debug("LLM Response", "reasoning", reasoningText, "response_without_reasoning", responseWithoutReasoning)
|
||||
|
||||
textContent := functions.ParseTextContent(responseWithoutReasoning, config.FunctionsConfig)
|
||||
@@ -1006,7 +1229,16 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con
|
||||
sendError(c, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID)
|
||||
return
|
||||
}
|
||||
audioString := base64.StdEncoding.EncodeToString(audioBytes)
|
||||
|
||||
// Strip WAV header (44 bytes) to get raw PCM data
|
||||
// The OpenAI Realtime API expects raw PCM, not WAV files
|
||||
const wavHeaderSize = 44
|
||||
pcmData := audioBytes
|
||||
if len(audioBytes) > wavHeaderSize {
|
||||
pcmData = audioBytes[wavHeaderSize:]
|
||||
}
|
||||
|
||||
audioString := base64.StdEncoding.EncodeToString(pcmData)
|
||||
|
||||
sendEvent(c, types.ResponseOutputAudioTranscriptDeltaEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
@@ -1131,7 +1363,6 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con
|
||||
Status: types.ResponseStatusCompleted,
|
||||
},
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// Helper functions to generate unique IDs
|
||||
|
||||
@@ -194,7 +194,40 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im
|
||||
|
||||
var toolsJSON string
|
||||
if len(tools) > 0 {
|
||||
b, _ := json.Marshal(tools)
|
||||
// Convert tools to OpenAI Chat Completions format (nested)
|
||||
// as expected by most backends (including llama.cpp)
|
||||
var chatTools []functions.Tool
|
||||
for _, t := range tools {
|
||||
if t.Function != nil {
|
||||
var params map[string]interface{}
|
||||
switch p := t.Function.Parameters.(type) {
|
||||
case map[string]interface{}:
|
||||
params = p
|
||||
case string:
|
||||
if err := json.Unmarshal([]byte(p), ¶ms); err != nil {
|
||||
xlog.Warn("Failed to parse parameters JSON string", "error", err, "function", t.Function.Name)
|
||||
}
|
||||
case nil:
|
||||
params = map[string]interface{}{}
|
||||
default:
|
||||
// Try to marshal/unmarshal to get map
|
||||
b, err := json.Marshal(p)
|
||||
if err == nil {
|
||||
_ = json.Unmarshal(b, ¶ms)
|
||||
}
|
||||
}
|
||||
|
||||
chatTools = append(chatTools, functions.Tool{
|
||||
Type: "function",
|
||||
Function: functions.Function{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: params,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
b, _ := json.Marshal(chatTools)
|
||||
toolsJSON = string(b)
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user