mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-23 16:20:01 -04:00
Compare commits
30 Commits
bump/turbo
...
feat/buun-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9787bee48b | ||
|
|
42754d33b9 | ||
|
|
7f2b7e4ace | ||
|
|
6233feb190 | ||
|
|
d6bf3a4969 | ||
|
|
b27d38a53d | ||
|
|
45756b19dc | ||
|
|
cd6079b2f3 | ||
|
|
3db60b57e6 | ||
|
|
13734ae9fa | ||
|
|
c0920f3273 | ||
|
|
7c1934b183 | ||
|
|
5e062b4d1f | ||
|
|
4906cbad04 | ||
|
|
c755cd5ab5 | ||
|
|
0fb04f7ac3 | ||
|
|
d9d7b5c29b | ||
|
|
f877942d97 | ||
|
|
f5eb13d3c2 | ||
|
|
c1f923b2bc | ||
|
|
ed648b3b4e | ||
|
|
3ce5248126 | ||
|
|
04f1a0285d | ||
|
|
181ebb6df4 | ||
|
|
1c59165d63 | ||
|
|
eb00d9b178 | ||
|
|
2068b6f43c | ||
|
|
eb01c77214 | ||
|
|
bb4fda6f0e | ||
|
|
f0c92610a1 |
@@ -8,6 +8,7 @@ Create the backend directory under the appropriate location:
|
||||
- **Python backends**: `backend/python/<backend-name>/`
|
||||
- **Go backends**: `backend/go/<backend-name>/`
|
||||
- **C++ backends**: `backend/cpp/<backend-name>/`
|
||||
- **Rust backends**: `backend/rust/<backend-name>/`
|
||||
|
||||
For Python backends, you'll typically need:
|
||||
- `backend.py` - Main gRPC server implementation
|
||||
@@ -18,9 +19,22 @@ For Python backends, you'll typically need:
|
||||
- `run.sh` - Runtime script
|
||||
- `test.py` / `test.sh` - Test files
|
||||
|
||||
For Rust backends, you'll typically need (see `backend/rust/kokoros/` as a reference):
|
||||
- `Cargo.toml` - Crate manifest; depend on the upstream project as a submodule under `sources/`
|
||||
- `build.rs` - Invokes `tonic_build` to generate gRPC stubs from `backend/backend.proto` (use the `BACKEND_PROTO_PATH` env var so the Makefile can inject the canonical copy)
|
||||
- `src/` - The gRPC server implementation (implement `Backend` via `tonic`)
|
||||
- `Makefile` - Copies `backend.proto` into the crate, runs `cargo build --release`, then `package.sh`
|
||||
- `package.sh` - Uses `ldd` to bundle the binary's dynamic deps and `ld.so` into `package/lib/`
|
||||
- `run.sh` - Sets `LD_LIBRARY_PATH`/`SSL_CERT_DIR` and execs the binary via the bundled `lib/ld.so`
|
||||
- `sources/<UpstreamProject>/` - Git submodule with the upstream Rust crate
|
||||
|
||||
## 2. Add Build Configurations to `.github/workflows/backend.yml`
|
||||
|
||||
Add build matrix entries for each platform/GPU type you want to support. Look at similar backends (e.g., `chatterbox`, `faster-whisper`) for reference.
|
||||
Add build matrix entries for each platform/GPU type you want to support. Look at similar backends for reference — `chatterbox`/`faster-whisper` for Python, `piper`/`silero-vad` for Go, `kokoros` for Rust.
|
||||
|
||||
**Without an entry here no image is ever built or pushed, and the gallery entry in `backend/index.yaml` will point at a tag that does not exist.** The `dockerfile:` field must point at `./backend/Dockerfile.<lang>` matching the language bucket from step 1 (e.g. `Dockerfile.python`, `Dockerfile.golang`, `Dockerfile.rust`). The `tag-suffix` must match the `uri:` in the corresponding `backend/index.yaml` image entry exactly.
|
||||
|
||||
If you add a new language bucket, `scripts/changed-backends.js` also needs a branch in `inferBackendPath` so PR change-detection routes file edits correctly.
|
||||
|
||||
**Placement in file:**
|
||||
- CPU builds: Add after other CPU builds (e.g., after `cpu-chatterbox`)
|
||||
@@ -56,24 +70,28 @@ Add `backends/<backend-name>` to the `.NOTPARALLEL` line (around line 2) to prev
|
||||
|
||||
**Step 4b: Add to `prepare-test-extra`**
|
||||
|
||||
Add the backend to the `prepare-test-extra` target (around line 312) to prepare it for testing:
|
||||
Add the backend to the `prepare-test-extra` target to prepare it for testing. Use the path matching your language bucket (`backend/python/`, `backend/go/`, `backend/rust/`, …):
|
||||
|
||||
```makefile
|
||||
prepare-test-extra: protogen-python
|
||||
...
|
||||
$(MAKE) -C backend/python/<backend-name>
|
||||
$(MAKE) -C backend/<lang>/<backend-name>
|
||||
```
|
||||
|
||||
For Rust backends the target is usually the crate build target itself (e.g. `$(MAKE) -C backend/rust/<backend-name> <backend-name>-grpc`) so the binary is in place before `test` runs.
|
||||
|
||||
**Step 4c: Add to `test-extra`**
|
||||
|
||||
Add the backend to the `test-extra` target (around line 319) to run its tests:
|
||||
Add the backend to the `test-extra` target to run its tests — applies to Go and Rust backends too, not only Python:
|
||||
|
||||
```makefile
|
||||
test-extra: prepare-test-extra
|
||||
...
|
||||
$(MAKE) -C backend/python/<backend-name> test
|
||||
$(MAKE) -C backend/<lang>/<backend-name> test
|
||||
```
|
||||
|
||||
Each backend's own `Makefile` should define a `test` target so this line works regardless of language. Integration tests that need large model downloads should be gated behind an env var (see `backend/rust/kokoros/`'s `KOKOROS_MODEL_PATH` pattern) so CI only runs unit tests.
|
||||
|
||||
**Step 4d: Add Backend Definition**
|
||||
|
||||
Add a backend definition variable in the backend definitions section (around line 428-457). The format depends on the backend type:
|
||||
@@ -93,6 +111,13 @@ BACKEND_<BACKEND_NAME> = <backend-name>|python|./backend|false|true
|
||||
BACKEND_<BACKEND_NAME> = <backend-name>|golang|.|false|true
|
||||
```
|
||||
|
||||
**For Rust backends**:
|
||||
```makefile
|
||||
BACKEND_<BACKEND_NAME> = <backend-name>|rust|.|false|true
|
||||
```
|
||||
|
||||
The language field (`python`/`golang`/`rust`/…) must match a `backend/Dockerfile.<lang>` file.
|
||||
|
||||
**Step 4e: Generate Docker Build Target**
|
||||
|
||||
Add an eval call to generate the docker-build target (around line 480-501):
|
||||
@@ -153,6 +178,29 @@ ls /tmp/check # expect the bundled .so files + symlinks
|
||||
|
||||
Then boot it inside a fresh `ubuntu:24.04` (which intentionally does *not* have the lib installed) to confirm it actually loads from the backend dir.
|
||||
|
||||
## Importer integration
|
||||
|
||||
When you add a new backend, you MUST also make it importable via the model import form (`/import-model`). The import form dropdown is sourced dynamically from `GET /backends/known` — it reads the importer registry at `core/gallery/importers/importers.go`, so the steps below are the ONLY way to make your backend show up.
|
||||
|
||||
Required steps:
|
||||
|
||||
1. **If your backend has unambiguous detection signals** (unique file extension, HF `pipeline_tag`, unique repo name pattern, unique artefact like `modules.json`):
|
||||
- Create an importer file at `core/gallery/importers/<backend>.go` following the Match/Import pattern in `llama-cpp.go`.
|
||||
- Register it in `importers.go:defaultImporters` in **specificity order** — more specific detectors must appear BEFORE more generic ones (e.g. `sentencetransformers` before `transformers`, `stablediffusion-ggml` before `llama-cpp`, `vllm-omni` before `vllm`). First match wins.
|
||||
2. **If your backend is a drop-in replacement** (same artefacts as another backend, e.g. `ik-llama-cpp` and `turboquant` both consume GGUF the same way `llama-cpp` does):
|
||||
- Do NOT create a new importer. Extend the existing importer's `Import()` to swap the emitted `backend:` field when `preferences.backend` matches. See `llama-cpp.go` for the pattern.
|
||||
3. **If your backend has no reliable auto-detect signal** (preference-only — e.g. `sglang`, `tinygrad`, `whisperx`):
|
||||
- Do NOT create an importer. Instead add the backend name to the curated pref-only slice in `core/http/endpoints/localai/backend.go` that feeds `/backends/known`. A single line addition.
|
||||
4. **Always** add a table-driven test in `core/gallery/importers/importers_test.go` (Ginkgo/Gomega):
|
||||
- Use a real public HuggingFace repo URI as the test fixture (existing tests already hit the live HF API — follow that pattern).
|
||||
- Cover detection (auto-match without preferences), preference-override (explicit `backend:` in preferences wins), and — if the backend's modality has a common `pipeline_tag` but ambiguous artefacts — an ambiguity test asserting `errors.Is(err, importers.ErrAmbiguousImport)`.
|
||||
|
||||
Rules of thumb:
|
||||
|
||||
- When in doubt, lean pref-only. A wrong auto-detect is worse than a forced preference.
|
||||
- Never silently emit a modality mismatch (e.g. emit `llama-cpp` for a TTS repo because `.gguf` is present). Return `ErrAmbiguousImport` instead.
|
||||
- Registration order is the single most common source of bugs. Check by running `go test ./core/gallery/importers/...` — the existing suite will fail if you've shadowed a pre-existing detector.
|
||||
|
||||
## 6. Example: Adding a Python Backend
|
||||
|
||||
For reference, when `moonshine` was added:
|
||||
|
||||
@@ -35,19 +35,33 @@ All contributions must comply with LocalAI's licensing requirements:
|
||||
|
||||
## Signed-off-by and Developer Certificate of Origin
|
||||
|
||||
**AI agents MUST NOT add `Signed-off-by` tags.** Only humans can legally
|
||||
certify the Developer Certificate of Origin (DCO). The human submitter
|
||||
is responsible for:
|
||||
Only humans can certify the Developer Certificate of Origin (DCO). AI
|
||||
agents MUST NOT invent or guess a human identity for `Signed-off-by` —
|
||||
doing so forges the DCO certification.
|
||||
|
||||
- Reviewing all AI-generated code
|
||||
However, when a human operator explicitly directs the AI to commit on
|
||||
their behalf, the AI is acting as a typing tool — no different from an
|
||||
editor macro or `git commit -s`. In that case the AI SHOULD add
|
||||
`Signed-off-by:` using the **configured `user.name` / `user.email`** of
|
||||
the current git repository (i.e. the operator's own identity). The
|
||||
resulting trailer is the operator's signature; they take responsibility
|
||||
for it by reviewing and pushing the commit. The AI MUST NOT use any
|
||||
other identity and MUST NOT add its own name to the sign-off.
|
||||
|
||||
When running `git commit`, prefer `git commit --signoff` (or `-s`) so
|
||||
the trailer is emitted by git itself from the configured identity,
|
||||
rather than hand-writing it in a heredoc — this guarantees the sign-off
|
||||
matches whatever identity the operator is currently using.
|
||||
|
||||
The human submitter remains responsible for:
|
||||
|
||||
- Reviewing all AI-generated code before it's pushed or merged
|
||||
- Ensuring compliance with licensing requirements
|
||||
- Adding their own `Signed-off-by` tag (when the project requires DCO)
|
||||
to certify the contribution
|
||||
- Taking full responsibility for the contribution
|
||||
|
||||
AI agents MUST NOT add `Co-Authored-By` trailers for themselves either.
|
||||
A human reviewer owns the contribution; the AI's involvement is recorded
|
||||
via `Assisted-by` (see below).
|
||||
AI agents MUST NOT add `Co-Authored-By` trailers for themselves. A human
|
||||
reviewer owns the contribution; the AI's involvement is recorded via
|
||||
`Assisted-by` (see below).
|
||||
|
||||
## Attribution
|
||||
|
||||
@@ -84,6 +98,12 @@ Assisted-by: Claude:claude-opus-4-7 golangci-lint
|
||||
Signed-off-by: Jane Developer <jane@example.com>
|
||||
```
|
||||
|
||||
The `Signed-off-by` line uses Jane's own identity because Jane is the
|
||||
submitter operating the AI. If Jane asks Claude to create the commit via
|
||||
`git commit -s`, git emits that exact trailer from Jane's configured
|
||||
identity — no separate human step is needed beyond Jane reviewing the
|
||||
diff before pushing.
|
||||
|
||||
## Scope and Responsibility
|
||||
|
||||
Using an AI assistant does not reduce the contributor's responsibility.
|
||||
|
||||
@@ -42,6 +42,12 @@ trim_trailing_whitespace = false
|
||||
|
||||
Use `github.com/mudler/xlog` for logging which has the same API as slog.
|
||||
|
||||
## Go tests
|
||||
|
||||
All Go tests — including backend tests — must use [Ginkgo](https://onsi.github.io/ginkgo/) (v2) with Gomega matchers, not the stdlib `testing` package with `t.Run` / `t.Errorf`. A test file should register a suite with `RegisterFailHandler(Fail)` in a `TestXxx(t *testing.T)` bootstrap and use `Describe`/`Context`/`It` blocks for the actual cases. Look at any existing `*_test.go` under `core/` or `pkg/` for a template.
|
||||
|
||||
Do not mix styles within a package. If you are extending tests in a package that already uses Ginkgo, keep using Ginkgo. If you find stdlib-style Go tests in the tree, treat them as tech debt to be migrated rather than as a pattern to follow.
|
||||
|
||||
## Documentation
|
||||
|
||||
The project documentation is located in `docs/content`. When adding new features or changing existing functionality, it is crucial to update the documentation to reflect these changes. This helps users understand how to use the new capabilities and ensures the documentation stays relevant.
|
||||
|
||||
187
.github/workflows/backend.yml
vendored
187
.github/workflows/backend.yml
vendored
@@ -399,6 +399,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-buun-llama-cpp'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "buun-llama-cpp"
|
||||
dockerfile: "./backend/Dockerfile.buun-llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -724,6 +737,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-speaker-recognition'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "speaker-recognition"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -881,6 +907,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-buun-llama-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "buun-llama-cpp"
|
||||
dockerfile: "./backend/Dockerfile.buun-llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -907,6 +946,19 @@ jobs:
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-buun-llama-cpp'
|
||||
base-image: "ubuntu:24.04"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
ubuntu-version: '2404'
|
||||
backend: "buun-llama-cpp"
|
||||
dockerfile: "./backend/Dockerfile.buun-llama-cpp"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1441,6 +1493,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-buun-llama-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "buun-llama-cpp"
|
||||
dockerfile: "./backend/Dockerfile.buun-llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1690,6 +1755,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-buun-llama-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "buun-llama-cpp"
|
||||
dockerfile: "./backend/Dockerfile.buun-llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1716,6 +1794,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-buun-llama-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "buun-llama-cpp"
|
||||
dockerfile: "./backend/Dockerfile.buun-llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2121,6 +2212,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-buun-llama-cpp'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "buun-llama-cpp"
|
||||
dockerfile: "./backend/Dockerfile.buun-llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2160,6 +2264,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-buun-llama-cpp'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "buun-llama-cpp"
|
||||
dockerfile: "./backend/Dockerfile.buun-llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2186,6 +2303,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-buun-llama-cpp'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "buun-llama-cpp"
|
||||
dockerfile: "./backend/Dockerfile.buun-llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# Stablediffusion-ggml
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -2653,6 +2783,20 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# speaker-recognition (voice/speaker biometrics)
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-speaker-recognition'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "speaker-recognition"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2850,6 +2994,49 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# sherpa-onnx CPU
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-sherpa-onnx'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sherpa-onnx"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# sherpa-onnx CUDA 12
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-sherpa-onnx'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sherpa-onnx"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# sherpa-onnx CUDA 13 — requires onnxruntime 1.24.x+ for the
|
||||
# gpu_cuda13 tarball; sherpa-onnx SHERPA_COMMIT pins to v1.12.39.
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-sherpa-onnx'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sherpa-onnx"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
backend-jobs-darwin:
|
||||
uses: ./.github/workflows/backend_build_darwin.yml
|
||||
strategy:
|
||||
|
||||
2
.github/workflows/backend_build.yml
vendored
2
.github/workflows/backend_build.yml
vendored
@@ -108,6 +108,8 @@ jobs:
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
|
||||
- name: Release space from worker
|
||||
if: inputs.runs-on == 'ubuntu-latest'
|
||||
|
||||
119
.github/workflows/test-extra.yml
vendored
119
.github/workflows/test-extra.yml
vendored
@@ -32,6 +32,7 @@ jobs:
|
||||
llama-cpp: ${{ steps.detect.outputs.llama-cpp }}
|
||||
ik-llama-cpp: ${{ steps.detect.outputs.ik-llama-cpp }}
|
||||
turboquant: ${{ steps.detect.outputs.turboquant }}
|
||||
buun-llama-cpp: ${{ steps.detect.outputs['buun-llama-cpp'] }}
|
||||
vllm: ${{ steps.detect.outputs.vllm }}
|
||||
sglang: ${{ steps.detect.outputs.sglang }}
|
||||
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
||||
@@ -39,6 +40,8 @@ jobs:
|
||||
voxtral: ${{ steps.detect.outputs.voxtral }}
|
||||
kokoros: ${{ steps.detect.outputs.kokoros }}
|
||||
insightface: ${{ steps.detect.outputs.insightface }}
|
||||
speaker-recognition: ${{ steps.detect.outputs.speaker-recognition }}
|
||||
sherpa-onnx: ${{ steps.detect.outputs.sherpa-onnx }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
@@ -505,6 +508,72 @@ jobs:
|
||||
- name: Build llama-cpp backend image and run audio transcription gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-llama-cpp-transcription
|
||||
# Realtime e2e with sherpa-onnx driving VAD + STT + TTS against a mocked LLM.
|
||||
# Builds the sherpa-onnx Docker image, extracts the rootfs so the e2e suite
|
||||
# can discover the backend binary + shared libs, downloads the three model
|
||||
# bundles (silero-vad, omnilingual-asr, vits-ljs) and drives the realtime
|
||||
# websocket spec end-to-end.
|
||||
tests-sherpa-onnx-realtime:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.sherpa-onnx == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25.4'
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: '22'
|
||||
- name: Build sherpa-onnx backend image and run realtime e2e tests
|
||||
run: |
|
||||
make test-extra-e2e-realtime-sherpa
|
||||
# Streaming ASR via the sherpa-onnx online recognizer (zipformer
|
||||
# transducer). Exercises both AudioTranscription (buffered) and
|
||||
# AudioTranscriptionStream (real-time deltas) on the e2e-backends
|
||||
# harness.
|
||||
tests-sherpa-onnx-grpc-transcription:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.sherpa-onnx == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25.4'
|
||||
- name: Build sherpa-onnx backend image and run streaming ASR gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-sherpa-onnx-transcription
|
||||
# VITS TTS via the sherpa-onnx backend. Drives both TTS (file write) and
|
||||
# TTSStream (PCM chunks) on the e2e-backends harness.
|
||||
tests-sherpa-onnx-grpc-tts:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.sherpa-onnx == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25.4'
|
||||
- name: Build sherpa-onnx backend image and run TTS gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-sherpa-onnx-tts
|
||||
tests-ik-llama-cpp-grpc:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.ik-llama-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
@@ -545,6 +614,30 @@ jobs:
|
||||
- name: Build turboquant backend image and run gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-turboquant
|
||||
tests-buun-llama-cpp-grpc:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs['buun-llama-cpp'] == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25.4'
|
||||
# Exercises the buun-llama-cpp (fork-of-a-fork) backend with the
|
||||
# fork-specific TurboQuant/TCQ KV-cache types. BACKEND_TEST_CACHE_TYPE_V
|
||||
# is set to turbo3 so the test round-trips through the fork's KV
|
||||
# allow-list — picking a stock llama.cpp type would only re-test the
|
||||
# shared code path. DFlash speculative decoding is not exercised here
|
||||
# because the one known public target/drafter pair (Qwen3.5-27B) is too
|
||||
# large for CI.
|
||||
- name: Build buun-llama-cpp backend image and run gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-buun-llama-cpp
|
||||
# tests-vllm-grpc is currently disabled in CI.
|
||||
#
|
||||
# The prebuilt vllm CPU wheel is compiled with AVX-512 VNNI/BF16
|
||||
@@ -778,3 +871,29 @@ jobs:
|
||||
- name: Build insightface backend image and run both model configurations
|
||||
run: |
|
||||
make test-extra-backend-insightface-all
|
||||
tests-speaker-recognition-grpc:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.speaker-recognition == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
make build-essential curl ca-certificates git tar
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.26.0'
|
||||
- name: Free disk space
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /opt/ghc /usr/local/lib/android /opt/hostedtoolcache/CodeQL || true
|
||||
df -h
|
||||
- name: Build speaker-recognition backend image and run the ECAPA-TDNN configuration
|
||||
run: |
|
||||
make test-extra-backend-speaker-recognition-all
|
||||
|
||||
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -195,7 +195,7 @@ jobs:
|
||||
run: go version
|
||||
- name: Dependencies
|
||||
run: |
|
||||
brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm opus
|
||||
brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm opus ffmpeg
|
||||
pip install --user --no-cache-dir grpcio-tools grpcio
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
|
||||
@@ -19,7 +19,7 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
||||
|------|-------------|
|
||||
| [.agents/ai-coding-assistants.md](.agents/ai-coding-assistants.md) | Policy for AI-assisted contributions — licensing, DCO, attribution |
|
||||
| [.agents/building-and-testing.md](.agents/building-and-testing.md) | Building the project, running tests, Docker builds for specific platforms |
|
||||
| [.agents/adding-backends.md](.agents/adding-backends.md) | Adding a new backend (Python, Go, or C++) — full step-by-step checklist |
|
||||
| [.agents/adding-backends.md](.agents/adding-backends.md) | Adding a new backend (Python, Go, or C++) — full step-by-step checklist, including importer integration (the `/import-model` dropdown is server-driven from `GET /backends/known`) |
|
||||
| [.agents/coding-style.md](.agents/coding-style.md) | Code style, editorconfig, logging, documentation conventions |
|
||||
| [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing |
|
||||
| [.agents/vllm-backend.md](.agents/vllm-backend.md) | Working on the vLLM / vLLM-omni backends — native parsers, ChatDelta, CPU build, libnuma packaging, backend hooks |
|
||||
|
||||
152
Makefile
152
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/insightface backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/tinygrad
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/buun-llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/tinygrad backends/sherpa-onnx
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -394,7 +394,13 @@ protoc:
|
||||
.PHONY: protogen-go
|
||||
protogen-go: protoc install-go-tools
|
||||
mkdir -p pkg/grpc/proto
|
||||
./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \
|
||||
# install-go-tools writes protoc-gen-go and protoc-gen-go-grpc into
|
||||
# $(shell go env GOPATH)/bin, which isn't on every dev's PATH. protoc
|
||||
# resolves its code-gen plugins via PATH, so without this prefix the
|
||||
# generate step fails with "protoc-gen-go: program not found". Prepend
|
||||
# GOPATH/bin so the freshly-installed plugins win without requiring a
|
||||
# shell-profile change.
|
||||
PATH="$$(go env GOPATH)/bin:$$PATH" ./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \
|
||||
backend/backend.proto
|
||||
|
||||
core/config/inference_defaults.json: ## Fetch inference defaults from unsloth (only if missing)
|
||||
@@ -435,6 +441,7 @@ prepare-test-extra: protogen-python
|
||||
$(MAKE) -C backend/python/trl
|
||||
$(MAKE) -C backend/python/tinygrad
|
||||
$(MAKE) -C backend/python/insightface
|
||||
$(MAKE) -C backend/python/speaker-recognition
|
||||
$(MAKE) -C backend/rust/kokoros kokoros-grpc
|
||||
|
||||
test-extra: prepare-test-extra
|
||||
@@ -459,6 +466,7 @@ test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/trl test
|
||||
$(MAKE) -C backend/python/tinygrad test
|
||||
$(MAKE) -C backend/python/insightface test
|
||||
$(MAKE) -C backend/python/speaker-recognition test
|
||||
$(MAKE) -C backend/rust/kokoros test
|
||||
|
||||
##
|
||||
@@ -537,6 +545,19 @@ test-extra-backend-turboquant: docker-build-turboquant
|
||||
BACKEND_TEST_CACHE_TYPE_V=turbo3 \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## buun-llama-cpp: exercises the fork-of-a-fork backend (spiritbuun/buun-llama-cpp)
|
||||
## with the *TurboQuant/TCQ-specific* KV-cache types (turbo3 for V). Same rationale
|
||||
## as turboquant above: picking a standard llama.cpp type would only re-test the
|
||||
## shared code path. buun inherits turboquant's turbo2/turbo3/turbo4 and adds
|
||||
## turbo2_tcq / turbo3_tcq on top. DFlash speculative decoding is not exercised
|
||||
## here because no small DFlash drafter model exists (the known public pair is
|
||||
## Qwen3.5-27B, ~54 GB).
|
||||
test-extra-backend-buun-llama-cpp: docker-build-buun-llama-cpp
|
||||
BACKEND_IMAGE=local-ai-backend:buun-llama-cpp \
|
||||
BACKEND_TEST_CACHE_TYPE_K=q8_0 \
|
||||
BACKEND_TEST_CACHE_TYPE_V=turbo3 \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## Audio transcription wrapper for the llama-cpp backend.
|
||||
## Drives the new AudioTranscription / AudioTranscriptionStream RPCs against
|
||||
## ggml-org/Qwen3-ASR-0.6B-GGUF (a small ASR model that requires its mmproj
|
||||
@@ -621,6 +642,11 @@ test-extra-backend-tinygrad-all: \
|
||||
FACE_IMAGE_1_URL ?= https://github.com/deepinsight/insightface/raw/master/python-package/insightface/data/images/t1.jpg
|
||||
FACE_IMAGE_2_URL ?= https://github.com/deepinsight/insightface/raw/master/python-package/insightface/data/images/t1.jpg
|
||||
FACE_IMAGE_3_URL ?= https://github.com/deepinsight/insightface/raw/master/python-package/insightface/data/images/mask_white.jpg
|
||||
## Known spoof fixture used by the face_antispoof e2e cap. This is
|
||||
## upstream's own `image_F2.jpg` (Silent-Face repo, via yakhyo mirror)
|
||||
## — verified to classify as is_real=false with score < 0.05 on the
|
||||
## MiniFASNetV2 + MiniFASNetV1SE ensemble.
|
||||
FACE_SPOOF_IMAGE_URL ?= https://github.com/yakhyo/face-anti-spoofing/raw/main/assets/image_F2.jpg
|
||||
|
||||
## Host-side cache for the OpenCV Zoo face ONNX files used by the
|
||||
## opencv e2e target. The backend image no longer bakes model weights —
|
||||
@@ -644,6 +670,15 @@ INSIGHTFACE_BUFFALO_SC_DIR := /tmp/localai-insightface-buffalo-sc-cache
|
||||
INSIGHTFACE_BUFFALO_SC_URL := https://github.com/deepinsight/insightface/releases/download/v0.7/buffalo_sc.zip
|
||||
INSIGHTFACE_BUFFALO_SC_SHA := 57d31b56b6ffa911c8a73cfc1707c73cab76efe7f13b675a05223bf42de47c72
|
||||
|
||||
## Silent-Face antispoofing (MiniFASNetV2 + MiniFASNetV1SE) — shared
|
||||
## between the buffalo_sc and opencv e2e targets. Both ONNX files are
|
||||
## ~1.7MB, Apache 2.0. URLs + SHAs mirror the gallery entries.
|
||||
INSIGHTFACE_ANTISPOOF_DIR := /tmp/localai-insightface-antispoof-cache
|
||||
INSIGHTFACE_ANTISPOOF_V2_URL := https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV2.onnx
|
||||
INSIGHTFACE_ANTISPOOF_V2_SHA := b32929adc2d9c34b9486f8c4c7bc97c1b69bc0ea9befefc380e4faae4e463907
|
||||
INSIGHTFACE_ANTISPOOF_V1SE_URL := https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV1SE.onnx
|
||||
INSIGHTFACE_ANTISPOOF_V1SE_SHA := ebab7f90c7833fbccd46d3a555410e78d969db5438e169b6524be444862b3676
|
||||
|
||||
.PHONY: insightface-opencv-models
|
||||
insightface-opencv-models:
|
||||
@mkdir -p $(INSIGHTFACE_OPENCV_DIR)
|
||||
@@ -658,6 +693,20 @@ insightface-opencv-models:
|
||||
echo "$(INSIGHTFACE_OPENCV_SFACE_SHA) $(INSIGHTFACE_OPENCV_DIR)/sface.onnx" | sha256sum -c; \
|
||||
fi
|
||||
|
||||
.PHONY: insightface-antispoof-models
|
||||
insightface-antispoof-models:
|
||||
@mkdir -p $(INSIGHTFACE_ANTISPOOF_DIR)
|
||||
@if [ "$$(sha256sum $(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV2.onnx 2>/dev/null | awk '{print $$1}')" != "$(INSIGHTFACE_ANTISPOOF_V2_SHA)" ]; then \
|
||||
echo "Fetching MiniFASNetV2..."; \
|
||||
curl -fsSL -o $(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV2.onnx $(INSIGHTFACE_ANTISPOOF_V2_URL); \
|
||||
echo "$(INSIGHTFACE_ANTISPOOF_V2_SHA) $(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV2.onnx" | sha256sum -c; \
|
||||
fi
|
||||
@if [ "$$(sha256sum $(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV1SE.onnx 2>/dev/null | awk '{print $$1}')" != "$(INSIGHTFACE_ANTISPOOF_V1SE_SHA)" ]; then \
|
||||
echo "Fetching MiniFASNetV1SE..."; \
|
||||
curl -fsSL -o $(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV1SE.onnx $(INSIGHTFACE_ANTISPOOF_V1SE_URL); \
|
||||
echo "$(INSIGHTFACE_ANTISPOOF_V1SE_SHA) $(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV1SE.onnx" | sha256sum -c; \
|
||||
fi
|
||||
|
||||
.PHONY: insightface-buffalo-sc-models
|
||||
insightface-buffalo-sc-models:
|
||||
@mkdir -p $(INSIGHTFACE_BUFFALO_SC_DIR)
|
||||
@@ -680,14 +729,15 @@ insightface-buffalo-sc-models:
|
||||
## the e2e suite drives LoadModel directly without going through
|
||||
## LocalAI's gallery flow (which is what would normally populate
|
||||
## ModelPath and in turn the engine's `_model_dir` option).
|
||||
test-extra-backend-insightface-buffalo-sc: docker-build-insightface insightface-buffalo-sc-models
|
||||
test-extra-backend-insightface-buffalo-sc: docker-build-insightface insightface-buffalo-sc-models insightface-antispoof-models
|
||||
BACKEND_IMAGE=local-ai-backend:insightface \
|
||||
BACKEND_TEST_MODEL_NAME=insightface-buffalo-sc \
|
||||
BACKEND_TEST_OPTIONS=engine:insightface,model_pack:buffalo_sc,root:$(INSIGHTFACE_BUFFALO_SC_DIR) \
|
||||
BACKEND_TEST_CAPS=health,load,face_detect,face_embed,face_verify \
|
||||
BACKEND_TEST_OPTIONS=engine:insightface,model_pack:buffalo_sc,root:$(INSIGHTFACE_BUFFALO_SC_DIR),antispoof_v2_onnx:$(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV2.onnx,antispoof_v1se_onnx:$(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV1SE.onnx \
|
||||
BACKEND_TEST_CAPS=health,load,face_detect,face_embed,face_verify,face_antispoof \
|
||||
BACKEND_TEST_FACE_IMAGE_1_URL=$(FACE_IMAGE_1_URL) \
|
||||
BACKEND_TEST_FACE_IMAGE_2_URL=$(FACE_IMAGE_2_URL) \
|
||||
BACKEND_TEST_FACE_IMAGE_3_URL=$(FACE_IMAGE_3_URL) \
|
||||
BACKEND_TEST_FACE_SPOOF_IMAGE_URL=$(FACE_SPOOF_IMAGE_URL) \
|
||||
BACKEND_TEST_VERIFY_DISTANCE_CEILING=0.55 \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
@@ -696,14 +746,15 @@ test-extra-backend-insightface-buffalo-sc: docker-build-insightface insightface-
|
||||
## pre-fetched on the host via the insightface-opencv-models target and
|
||||
## passed as absolute paths, since the e2e suite drives LoadModel
|
||||
## directly without going through LocalAI's gallery flow.
|
||||
test-extra-backend-insightface-opencv: docker-build-insightface insightface-opencv-models
|
||||
test-extra-backend-insightface-opencv: docker-build-insightface insightface-opencv-models insightface-antispoof-models
|
||||
BACKEND_IMAGE=local-ai-backend:insightface \
|
||||
BACKEND_TEST_MODEL_NAME=insightface-opencv \
|
||||
BACKEND_TEST_OPTIONS=engine:onnx_direct,detector_onnx:$(INSIGHTFACE_OPENCV_DIR)/yunet.onnx,recognizer_onnx:$(INSIGHTFACE_OPENCV_DIR)/sface.onnx \
|
||||
BACKEND_TEST_CAPS=health,load,face_detect,face_embed,face_verify \
|
||||
BACKEND_TEST_OPTIONS=engine:onnx_direct,detector_onnx:$(INSIGHTFACE_OPENCV_DIR)/yunet.onnx,recognizer_onnx:$(INSIGHTFACE_OPENCV_DIR)/sface.onnx,antispoof_v2_onnx:$(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV2.onnx,antispoof_v1se_onnx:$(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV1SE.onnx \
|
||||
BACKEND_TEST_CAPS=health,load,face_detect,face_embed,face_verify,face_antispoof \
|
||||
BACKEND_TEST_FACE_IMAGE_1_URL=$(FACE_IMAGE_1_URL) \
|
||||
BACKEND_TEST_FACE_IMAGE_2_URL=$(FACE_IMAGE_2_URL) \
|
||||
BACKEND_TEST_FACE_IMAGE_3_URL=$(FACE_IMAGE_3_URL) \
|
||||
BACKEND_TEST_FACE_SPOOF_IMAGE_URL=$(FACE_SPOOF_IMAGE_URL) \
|
||||
BACKEND_TEST_VERIFY_DISTANCE_CEILING=0.55 \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
@@ -713,6 +764,79 @@ test-extra-backend-insightface-all: \
|
||||
test-extra-backend-insightface-buffalo-sc \
|
||||
test-extra-backend-insightface-opencv
|
||||
|
||||
## speaker-recognition — voice (speaker) biometrics.
|
||||
##
|
||||
## Audio fixtures default to the speechbrain test samples served
|
||||
## straight from their GitHub repo — public, no auth needed, and they
|
||||
## ship as 16kHz mono WAV/FLAC which is exactly what the engine wants.
|
||||
## example{1,2,5} are three different speakers; the suite treats
|
||||
## example1 as the "same-image twin" probe (verify(clip, clip) must
|
||||
## return distance≈0) and the other two as cross-speaker ceilings.
|
||||
## Override with BACKEND_TEST_VOICE_AUDIO_{1,2,3}_FILE for offline runs.
|
||||
VOICE_AUDIO_1_URL ?= https://github.com/speechbrain/speechbrain/raw/develop/tests/samples/single-mic/example1.wav
|
||||
VOICE_AUDIO_2_URL ?= https://github.com/speechbrain/speechbrain/raw/develop/tests/samples/single-mic/example2.flac
|
||||
VOICE_AUDIO_3_URL ?= https://github.com/speechbrain/speechbrain/raw/develop/tests/samples/single-mic/example5.wav
|
||||
|
||||
## ECAPA-TDNN via SpeechBrain — default CI configuration. Auto-downloads
|
||||
## the checkpoint from HuggingFace on first LoadModel (bundled in the
|
||||
## backend image pip install). 192-d embeddings, cosine-distance based.
|
||||
## The e2e suite drives LoadModel directly so we don't rely on LocalAI's
|
||||
## gallery flow here.
|
||||
test-extra-backend-speaker-recognition-ecapa: docker-build-speaker-recognition
|
||||
BACKEND_IMAGE=local-ai-backend:speaker-recognition \
|
||||
BACKEND_TEST_MODEL_NAME=speechbrain/spkrec-ecapa-voxceleb \
|
||||
BACKEND_TEST_OPTIONS=engine:speechbrain,source:speechbrain/spkrec-ecapa-voxceleb \
|
||||
BACKEND_TEST_CAPS=health,load,voice_embed,voice_verify \
|
||||
BACKEND_TEST_VOICE_AUDIO_1_URL=$(VOICE_AUDIO_1_URL) \
|
||||
BACKEND_TEST_VOICE_AUDIO_2_URL=$(VOICE_AUDIO_2_URL) \
|
||||
BACKEND_TEST_VOICE_AUDIO_3_URL=$(VOICE_AUDIO_3_URL) \
|
||||
BACKEND_TEST_VOICE_VERIFY_DISTANCE_CEILING=0.4 \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## Aggregate — today there's only one voice config; the target exists
|
||||
## so the CI workflow matches the insightface-all naming convention and
|
||||
## can grow to include WeSpeaker / 3D-Speaker later.
|
||||
test-extra-backend-speaker-recognition-all: \
|
||||
test-extra-backend-speaker-recognition-ecapa
|
||||
|
||||
## Realtime e2e with sherpa-onnx driving VAD + STT + TTS against a mocked
|
||||
## LLM. Extracts the sherpa-onnx Docker image rootfs, downloads the three
|
||||
## gallery-referenced model bundles (silero-vad, omnilingual-asr, vits-ljs),
|
||||
## writes the corresponding model config YAMLs, and runs the realtime
|
||||
## websocket spec in tests/e2e with REALTIME_* env vars wiring the sherpa
|
||||
## slots into the pipeline. The LLM slot stays on the in-repo mock-backend
|
||||
## registered unconditionally by tests/e2e/e2e_suite_test.go. See
|
||||
## tests/e2e/run-realtime-sherpa.sh for the full orchestration.
|
||||
test-extra-e2e-realtime-sherpa: build-mock-backend docker-build-sherpa-onnx protogen-go react-ui
|
||||
bash tests/e2e/run-realtime-sherpa.sh
|
||||
|
||||
## Streaming ASR via the sherpa-onnx online recognizer. Uses the streaming
|
||||
## zipformer English model (encoder/decoder/joiner int8 + tokens) from the
|
||||
## sherpa-onnx gallery entry. Drives both AudioTranscription and
|
||||
## AudioTranscriptionStream via the e2e-backends gRPC harness; streaming
|
||||
## emits real partial deltas during decode. Each file is renamed on download
|
||||
## to the shape sherpa-onnx's online loader expects (encoder.int8.onnx etc.).
|
||||
test-extra-backend-sherpa-onnx-transcription: docker-build-sherpa-onnx
|
||||
BACKEND_IMAGE=local-ai-backend:sherpa-onnx \
|
||||
BACKEND_TEST_MODEL_URL='https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/encoder-epoch-99-avg-1-chunk-16-left-128.int8.onnx#encoder.int8.onnx' \
|
||||
BACKEND_TEST_EXTRA_FILES='https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/decoder-epoch-99-avg-1-chunk-16-left-128.int8.onnx#decoder.int8.onnx|https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/joiner-epoch-99-avg-1-chunk-16-left-128.int8.onnx#joiner.int8.onnx|https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/tokens.txt' \
|
||||
BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
BACKEND_TEST_OPTIONS=subtype=online \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## VITS TTS via the sherpa-onnx backend. Pulls the individual files from
|
||||
## HuggingFace (the vits-ljs release tarball lives on the k2-fsa github
|
||||
## but is also mirrored as discrete files on HF). Exercises both
|
||||
## TTS (write-to-file) and TTSStream (PCM chunks + WAV header) via the
|
||||
## e2e-backends gRPC harness.
|
||||
test-extra-backend-sherpa-onnx-tts: docker-build-sherpa-onnx
|
||||
BACKEND_IMAGE=local-ai-backend:sherpa-onnx \
|
||||
BACKEND_TEST_MODEL_URL='https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx#vits-ljs.onnx' \
|
||||
BACKEND_TEST_EXTRA_FILES='https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt|https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt' \
|
||||
BACKEND_TEST_CAPS=health,load,tts \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## sglang mirrors the vllm setup: HuggingFace model id, same tiny Qwen,
|
||||
## tool-call extraction via sglang's native qwen parser. CPU builds use
|
||||
## sglang's upstream pyproject_cpu.toml recipe (see backend/python/sglang/install.sh).
|
||||
@@ -838,6 +962,11 @@ BACKEND_IK_LLAMA_CPP = ik-llama-cpp|ik-llama-cpp|.|false|false
|
||||
# turboquant is a llama.cpp fork with TurboQuant KV-cache quantization.
|
||||
# Reuses backend/cpp/llama-cpp grpc-server sources via a thin wrapper Makefile.
|
||||
BACKEND_TURBOQUANT = turboquant|turboquant|.|false|false
|
||||
# buun-llama-cpp is a fork-of-a-fork (spiritbuun/buun-llama-cpp forks
|
||||
# TheTom/llama-cpp-turboquant) that adds DFlash block-diffusion speculative
|
||||
# decoding and extra TCQ KV-cache variants on top of TurboQuant. Same thin
|
||||
# wrapper pattern as turboquant — reuses backend/cpp/llama-cpp grpc-server.
|
||||
BACKEND_BUUN_LLAMA_CPP = buun-llama-cpp|buun-llama-cpp|.|false|false
|
||||
|
||||
# Golang backends
|
||||
BACKEND_PIPER = piper|golang|.|false|true
|
||||
@@ -850,6 +979,7 @@ BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
||||
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
|
||||
BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true
|
||||
BACKEND_OPUS = opus|golang|.|false|true
|
||||
BACKEND_SHERPA_ONNX = sherpa-onnx|golang|.|false|true
|
||||
|
||||
# Python backends with root context
|
||||
BACKEND_RERANKERS = rerankers|python|.|false|true
|
||||
@@ -859,6 +989,7 @@ BACKEND_FASTER_WHISPER = faster-whisper|python|.|false|true
|
||||
BACKEND_COQUI = coqui|python|.|false|true
|
||||
BACKEND_RFDETR = rfdetr|python|.|false|true
|
||||
BACKEND_INSIGHTFACE = insightface|python|.|false|true
|
||||
BACKEND_SPEAKER_RECOGNITION = speaker-recognition|python|.|false|true
|
||||
BACKEND_KITTEN_TTS = kitten-tts|python|.|false|true
|
||||
BACKEND_NEUTTS = neutts|python|.|false|true
|
||||
BACKEND_KOKORO = kokoro|python|.|false|true
|
||||
@@ -916,6 +1047,7 @@ endef
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_IK_LLAMA_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TURBOQUANT)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_BUUN_LLAMA_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PIPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
@@ -931,6 +1063,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_FASTER_WHISPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_COQUI)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_RFDETR)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_INSIGHTFACE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SPEAKER_RECOGNITION)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_KITTEN_TTS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_NEUTTS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_KOKORO)))
|
||||
@@ -960,12 +1093,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TINYGRAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
||||
|
||||
# Pattern rule for docker-save targets
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-insightface
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-buun-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
|
||||
@@ -149,6 +149,7 @@ For more details, see the [Getting Started guide](https://localai.io/basics/gett
|
||||
|
||||
## Latest News
|
||||
|
||||
- **April 2026**: [Voice recognition](https://github.com/mudler/LocalAI/pull/9500), [Face recognition, identification & liveness detection](https://github.com/mudler/LocalAI/pull/9480), [Ollama API compatibility](https://github.com/mudler/LocalAI/pull/9284), [Video generation in stable-diffusion.ggml](https://github.com/mudler/LocalAI/pull/9420), [Backend versioning with auto-upgrade](https://github.com/mudler/LocalAI/pull/9315), [Pin models & load-on-demand toggle](https://github.com/mudler/LocalAI/pull/9309), [Universal model importer](https://github.com/mudler/LocalAI/pull/9466), new backends: [sglang](https://github.com/mudler/LocalAI/pull/9359), [ik-llama-cpp](https://github.com/mudler/LocalAI/pull/9326), [TurboQuant](https://github.com/mudler/LocalAI/pull/9355), [sam.cpp](https://github.com/mudler/LocalAI/pull/9288), [Kokoros](https://github.com/mudler/LocalAI/pull/9212), [qwen3tts.cpp](https://github.com/mudler/LocalAI/pull/9316), [tinygrad multimodal](https://github.com/mudler/LocalAI/pull/9364)
|
||||
- **March 2026**: [Agent management](https://github.com/mudler/LocalAI/pull/8820), [New React UI](https://github.com/mudler/LocalAI/pull/8772), [WebRTC](https://github.com/mudler/LocalAI/pull/8790), [MLX-distributed via P2P and RDMA](https://github.com/mudler/LocalAI/pull/8801), [MCP Apps, MCP Client-side](https://github.com/mudler/LocalAI/pull/8947)
|
||||
- **February 2026**: [Realtime API for audio-to-audio with tool calling](https://github.com/mudler/LocalAI/pull/6245), [ACE-Step 1.5 support](https://github.com/mudler/LocalAI/pull/8396)
|
||||
- **January 2026**: **LocalAI 3.10.0** — Anthropic API support, Open Responses API, video & image generation (LTX-2), unified GPU backends, tool streaming, Moonshine, Pocket-TTS. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v3.10.0)
|
||||
|
||||
290
backend/Dockerfile.buun-llama-cpp
Normal file
290
backend/Dockerfile.buun-llama-cpp
Normal file
@@ -0,0 +1,290 @@
|
||||
ARG BASE_IMAGE=ubuntu:24.04
|
||||
ARG GRPC_BASE_IMAGE=${BASE_IMAGE}
|
||||
|
||||
|
||||
# The grpc target does one thing, it builds and installs GRPC. This is in it's own layer so that it can be effectively cached by CI.
|
||||
# You probably don't need to change anything here, and if you do, make sure that CI is adjusted so that the cache continues to work.
|
||||
FROM ${GRPC_BASE_IMAGE} AS grpc
|
||||
|
||||
# This is a bit of a hack, but it's required in order to be able to effectively cache this layer in CI
|
||||
ARG GRPC_MAKEFLAGS="-j4 -Otarget"
|
||||
ARG GRPC_VERSION=v1.65.0
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
# CUDA Toolkit 13.x compatibility: CMake 3.31.9+ fixes toolchain detection/arch table issues
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
|
||||
ENV MAKEFLAGS=${GRPC_MAKEFLAGS}
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
build-essential curl libssl-dev \
|
||||
git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install CMake (the version in 22.04 is too old)
|
||||
RUN <<EOT bash
|
||||
if [ "${CMAKE_FROM_SOURCE}" = "true" ]; then
|
||||
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
|
||||
else
|
||||
apt-get update && \
|
||||
apt-get install -y \
|
||||
cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# We install GRPC to a different prefix here so that we can copy in only the build artifacts later
|
||||
# saves several hundred MB on the final docker image size vs copying in the entire GRPC source tree
|
||||
# and running make install in the target container
|
||||
RUN git clone --recurse-submodules --jobs 4 -b ${GRPC_VERSION} --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
|
||||
mkdir -p /build/grpc/cmake/build && \
|
||||
cd /build/grpc/cmake/build && \
|
||||
sed -i "216i\ TESTONLY" "../../third_party/abseil-cpp/absl/container/CMakeLists.txt" && \
|
||||
cmake -DgRPC_INSTALL=ON -DgRPC_BUILD_TESTS=OFF -DCMAKE_INSTALL_PREFIX:PATH=/opt/grpc ../.. && \
|
||||
make && \
|
||||
make install && \
|
||||
rm -rf /build
|
||||
|
||||
FROM ${BASE_IMAGE} AS builder
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
# We can target specific CUDA ARCHITECTURES like --build-arg CUDA_DOCKER_ARCH='75;86;89;120'
|
||||
ARG CUDA_DOCKER_ARCH
|
||||
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
ARG CMAKE_ARGS
|
||||
ENV CMAKE_ARGS=${CMAKE_ARGS}
|
||||
ARG BACKEND=rerankers
|
||||
ARG BUILD_TYPE
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ARG CUDA_MAJOR_VERSION
|
||||
ARG CUDA_MINOR_VERSION
|
||||
ARG SKIP_DRIVERS=false
|
||||
ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
|
||||
ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
ARG GO_VERSION=1.25.4
|
||||
ARG UBUNTU_VERSION=2404
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
ccache git \
|
||||
ca-certificates \
|
||||
make \
|
||||
pkg-config libcurl4-openssl-dev \
|
||||
curl unzip \
|
||||
libssl-dev wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Cuda
|
||||
ENV PATH=/usr/local/cuda/bin:${PATH}
|
||||
|
||||
# HipBLAS requirements
|
||||
ENV PATH=/opt/rocm/bin:${PATH}
|
||||
|
||||
|
||||
# Vulkan requirements
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "vulkan" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils wget gpg-agent && \
|
||||
apt-get install -y libglm-dev cmake libxcb-dri3-0 libxcb-present0 libpciaccess0 \
|
||||
libpng-dev libxcb-keysyms1-dev libxcb-dri3-dev libx11-dev g++ gcc \
|
||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
mkdir vulkan && cd vulkan && \
|
||||
curl -L -o vulkan-sdk.tar.xz https://github.com/mudler/vulkan-sdk-arm/releases/download/1.4.335.0/vulkansdk-ubuntu-24.04-arm-1.4.335.0.tar.xz && \
|
||||
tar -xvf vulkan-sdk.tar.xz && \
|
||||
rm vulkan-sdk.tar.xz && \
|
||||
cd 1.4.335.0 && \
|
||||
cp -rfv aarch64/bin/* /usr/bin/ && \
|
||||
cp -rfv aarch64/lib/* /usr/lib/aarch64-linux-gnu/ && \
|
||||
cp -rfv aarch64/include/* /usr/include/ && \
|
||||
cp -rfv aarch64/share/* /usr/share/ && \
|
||||
cd ../.. && \
|
||||
rm -rf vulkan
|
||||
fi
|
||||
ldconfig && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# CuBLAS requirements
|
||||
RUN <<EOT bash
|
||||
if ( [ "${BUILD_TYPE}" = "cublas" ] || [ "${BUILD_TYPE}" = "l4t" ] ) && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/arm64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
fi
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
||||
rm -f cuda-keyring_1.1-1_all.deb && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Isaac-GR00T/issues/343
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
|
||||
wget https://developer.download.nvidia.com/compute/cudss/0.6.0/local_installers/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
dpkg -i cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
cp /var/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0/cudss-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get -y install cudss cudss-cuda-${CUDA_MAJOR_VERSION} && \
|
||||
wget https://developer.download.nvidia.com/compute/nvpl/25.5/local_installers/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
dpkg -i nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
cp /var/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5/nvpl-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get install -y nvpl
|
||||
fi
|
||||
EOT
|
||||
|
||||
# If we are building with clblas support, we need the libraries for the builds
|
||||
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libclblast-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* \
|
||||
; fi
|
||||
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
hipblas-dev \
|
||||
rocblas-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
# I have no idea why, but the ROCM lib packages don't trigger ldconfig after they install, which results in local-ai and others not being able
|
||||
# to locate the libraries. We run ldconfig ourselves to work around this packaging deficiency
|
||||
ldconfig && \
|
||||
# Log which GPU architectures have rocBLAS kernel support
|
||||
echo "rocBLAS library data architectures:" && \
|
||||
(ls /opt/rocm*/lib/rocblas/library/Kernels* 2>/dev/null || ls /opt/rocm*/lib64/rocblas/library/Kernels* 2>/dev/null) | grep -oP 'gfx[0-9a-z+-]+' | sort -u || \
|
||||
echo "WARNING: No rocBLAS kernel data found" \
|
||||
; fi
|
||||
|
||||
RUN echo "TARGETARCH: $TARGETARCH"
|
||||
|
||||
# We need protoc installed, and the version in 22.04 is too old. We will create one as part installing the GRPC build below
|
||||
# but that will also being in a newer version of absl which stablediffusion cannot compile with. This version of protoc is only
|
||||
# here so that we can generate the grpc code for the stablediffusion build
|
||||
RUN <<EOT bash
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-x86_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-aarch_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
EOT
|
||||
|
||||
# Install CMake (the version in 22.04 is too old)
|
||||
RUN <<EOT bash
|
||||
if [ "${CMAKE_FROM_SOURCE}" = "true" ]; then
|
||||
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
|
||||
else
|
||||
apt-get update && \
|
||||
apt-get install -y \
|
||||
cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
COPY --from=grpc /opt/grpc /usr/local
|
||||
|
||||
|
||||
COPY . /LocalAI
|
||||
|
||||
RUN <<'EOT' bash
|
||||
set -euxo pipefail
|
||||
|
||||
if [[ -n "${CUDA_DOCKER_ARCH:-}" ]]; then
|
||||
CUDA_ARCH_ESC="${CUDA_DOCKER_ARCH//;/\\;}"
|
||||
export CMAKE_ARGS="${CMAKE_ARGS:-} -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCH_ESC}"
|
||||
echo "CMAKE_ARGS(env) = ${CMAKE_ARGS}"
|
||||
rm -rf /LocalAI/backend/cpp/buun-llama-cpp-*-build
|
||||
fi
|
||||
|
||||
cd /LocalAI/backend/cpp/buun-llama-cpp
|
||||
|
||||
if [ "${TARGETARCH}" = "arm64" ] || [ "${BUILD_TYPE}" = "hipblas" ]; then
|
||||
make buun-llama-cpp-fallback
|
||||
make buun-llama-cpp-grpc
|
||||
make buun-llama-cpp-rpc-server
|
||||
else
|
||||
make buun-llama-cpp-avx
|
||||
make buun-llama-cpp-avx2
|
||||
make buun-llama-cpp-avx512
|
||||
make buun-llama-cpp-fallback
|
||||
make buun-llama-cpp-grpc
|
||||
make buun-llama-cpp-rpc-server
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# Copy libraries using a script to handle architecture differences
|
||||
RUN make -BC /LocalAI/backend/cpp/buun-llama-cpp package
|
||||
|
||||
|
||||
FROM scratch
|
||||
|
||||
|
||||
# Copy all available binaries (the build process only creates the appropriate ones for the target architecture)
|
||||
COPY --from=builder /LocalAI/backend/cpp/buun-llama-cpp/package/. ./
|
||||
@@ -26,6 +26,9 @@ service Backend {
|
||||
rpc Detect(DetectOptions) returns (DetectResponse) {}
|
||||
rpc FaceVerify(FaceVerifyRequest) returns (FaceVerifyResponse) {}
|
||||
rpc FaceAnalyze(FaceAnalyzeRequest) returns (FaceAnalyzeResponse) {}
|
||||
rpc VoiceVerify(VoiceVerifyRequest) returns (VoiceVerifyResponse) {}
|
||||
rpc VoiceAnalyze(VoiceAnalyzeRequest) returns (VoiceAnalyzeResponse) {}
|
||||
rpc VoiceEmbed(VoiceEmbedRequest) returns (VoiceEmbedResponse) {}
|
||||
|
||||
rpc StoresSet(StoresSetOptions) returns (Result) {}
|
||||
rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
|
||||
@@ -490,7 +493,7 @@ message FaceVerifyRequest {
|
||||
string img1 = 1; // base64-encoded image
|
||||
string img2 = 2; // base64-encoded image
|
||||
float threshold = 3; // cosine-distance threshold; 0 = use backend default
|
||||
bool anti_spoofing = 4; // reserved for future MiniFASNet bolt-on
|
||||
bool anti_spoofing = 4; // run MiniFASNet liveness on each image; failed liveness forces verified=false
|
||||
}
|
||||
|
||||
message FaceVerifyResponse {
|
||||
@@ -502,6 +505,10 @@ message FaceVerifyResponse {
|
||||
FacialArea img1_area = 6;
|
||||
FacialArea img2_area = 7;
|
||||
float processing_time_ms = 8;
|
||||
bool img1_is_real = 9; // anti-spoofing result when enabled
|
||||
float img1_antispoof_score = 10;
|
||||
bool img2_is_real = 11;
|
||||
float img2_antispoof_score = 12;
|
||||
}
|
||||
|
||||
message FaceAnalyzeRequest {
|
||||
@@ -528,6 +535,57 @@ message FaceAnalyzeResponse {
|
||||
repeated FaceAnalysis faces = 1;
|
||||
}
|
||||
|
||||
// --- Voice (speaker) recognition messages ---
|
||||
//
|
||||
// Analogous to the Face* messages above, but for speaker biometrics.
|
||||
// Audio fields accept a filesystem path (same convention as
|
||||
// TranscriptRequest.dst). The HTTP layer materialises base64 / URL /
|
||||
// data-URI inputs to a temp file before calling the gRPC backend.
|
||||
|
||||
message VoiceVerifyRequest {
|
||||
string audio1 = 1; // path to first audio clip
|
||||
string audio2 = 2; // path to second audio clip
|
||||
float threshold = 3; // cosine-distance threshold; 0 = use backend default
|
||||
bool anti_spoofing = 4; // reserved for future AASIST bolt-on
|
||||
}
|
||||
|
||||
message VoiceVerifyResponse {
|
||||
bool verified = 1;
|
||||
float distance = 2; // 1 - cosine_similarity
|
||||
float threshold = 3;
|
||||
float confidence = 4; // 0-100
|
||||
string model = 5; // e.g. "speechbrain/spkrec-ecapa-voxceleb"
|
||||
float processing_time_ms = 6;
|
||||
}
|
||||
|
||||
message VoiceAnalyzeRequest {
|
||||
string audio = 1; // path to audio clip
|
||||
repeated string actions = 2; // subset of ["age","gender","emotion"]; empty = all-supported
|
||||
}
|
||||
|
||||
message VoiceAnalysis {
|
||||
float start = 1; // segment start time in seconds (0 if single-utterance)
|
||||
float end = 2; // segment end time in seconds
|
||||
float age = 3;
|
||||
string dominant_gender = 4;
|
||||
map<string, float> gender = 5;
|
||||
string dominant_emotion = 6;
|
||||
map<string, float> emotion = 7;
|
||||
}
|
||||
|
||||
message VoiceAnalyzeResponse {
|
||||
repeated VoiceAnalysis segments = 1;
|
||||
}
|
||||
|
||||
message VoiceEmbedRequest {
|
||||
string audio = 1; // path to audio clip
|
||||
}
|
||||
|
||||
message VoiceEmbedResponse {
|
||||
repeated float embedding = 1;
|
||||
string model = 2;
|
||||
}
|
||||
|
||||
message ToolFormatMarkers {
|
||||
string format_type = 1; // "json_native", "tag_with_json", "tag_with_tagged"
|
||||
|
||||
|
||||
85
backend/cpp/buun-llama-cpp/Makefile
Normal file
85
backend/cpp/buun-llama-cpp/Makefile
Normal file
@@ -0,0 +1,85 @@
|
||||
|
||||
# Pinned to the HEAD of master on https://github.com/spiritbuun/buun-llama-cpp.
|
||||
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
|
||||
BUUN_LLAMA_VERSION?=22464d0848b87c5d56b52fdf6af2e5da46bf803e
|
||||
LLAMA_REPO?=https://github.com/spiritbuun/buun-llama-cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
ONEAPI_VARS?=/opt/intel/oneapi/setvars.sh
|
||||
TARGET?=--target grpc-server
|
||||
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 1)
|
||||
ARCH?=$(shell uname -m)
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
LLAMA_CPP_DIR := $(CURRENT_MAKEFILE_DIR)/../llama-cpp
|
||||
|
||||
GREEN := \033[0;32m
|
||||
RESET := \033[0m
|
||||
|
||||
# buun-llama-cpp is a llama.cpp fork-of-a-fork (spiritbuun/buun-llama-cpp forked
|
||||
# TheTom/llama-cpp-turboquant, which itself forked ggml-org/llama.cpp). Rather
|
||||
# than duplicating grpc-server.cpp / CMakeLists.txt / prepare.sh we reuse the
|
||||
# ones in backend/cpp/llama-cpp, and only swap which repo+sha the fetch step
|
||||
# pulls. Each flavor target copies ../llama-cpp into a sibling
|
||||
# ../buun-llama-cpp-<flavor>-build directory, then invokes llama-cpp's own
|
||||
# build-llama-cpp-grpc-server with LLAMA_REPO/LLAMA_VERSION overridden to point
|
||||
# at the fork.
|
||||
PATCHES_DIR := $(CURRENT_MAKEFILE_DIR)/patches
|
||||
|
||||
# Each flavor target:
|
||||
# 1. copies backend/cpp/llama-cpp/ (grpc-server.cpp + prepare.sh + CMakeLists.txt + Makefile)
|
||||
# into a sibling buun-llama-cpp-<flavor>-build directory;
|
||||
# 2. clones the buun fork into buun-llama-cpp-<flavor>-build/llama.cpp via the
|
||||
# copy's own `llama.cpp` target, overriding LLAMA_REPO/LLAMA_VERSION;
|
||||
# 3. applies patches from backend/cpp/buun-llama-cpp/patches/ to the cloned
|
||||
# fork sources (for backporting upstream commits the fork hasn't pulled);
|
||||
# 4. runs the copy's `grpc-server` target, which produces the binary we copy
|
||||
# up as buun-llama-cpp-<flavor>.
|
||||
define buun-llama-cpp-build
|
||||
rm -rf $(CURRENT_MAKEFILE_DIR)/../buun-llama-cpp-$(1)-build
|
||||
cp -rf $(LLAMA_CPP_DIR) $(CURRENT_MAKEFILE_DIR)/../buun-llama-cpp-$(1)-build
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../buun-llama-cpp-$(1)-build purge
|
||||
# Augment the copied grpc-server.cpp's KV-cache allow-list with the
|
||||
# fork's turbo2/turbo3/turbo4/turbo2_tcq/turbo3_tcq types and wire up the
|
||||
# DFlash-specific option handlers (tree_budget / draft_topk). We patch the
|
||||
# *copy*, never the original under backend/cpp/llama-cpp/, so the stock
|
||||
# llama-cpp build stays compiling against vanilla upstream.
|
||||
bash $(CURRENT_MAKEFILE_DIR)/patch-grpc-server.sh $(CURRENT_MAKEFILE_DIR)/../buun-llama-cpp-$(1)-build/grpc-server.cpp
|
||||
$(info $(GREEN)I buun-llama-cpp build info:$(1)$(RESET))
|
||||
LLAMA_REPO=$(LLAMA_REPO) LLAMA_VERSION=$(BUUN_LLAMA_VERSION) \
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../buun-llama-cpp-$(1)-build llama.cpp
|
||||
bash $(CURRENT_MAKEFILE_DIR)/apply-patches.sh $(CURRENT_MAKEFILE_DIR)/../buun-llama-cpp-$(1)-build/llama.cpp $(PATCHES_DIR)
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) $(2)" TARGET="$(3)" \
|
||||
LLAMA_REPO=$(LLAMA_REPO) LLAMA_VERSION=$(BUUN_LLAMA_VERSION) \
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../buun-llama-cpp-$(1)-build grpc-server
|
||||
cp -rfv $(CURRENT_MAKEFILE_DIR)/../buun-llama-cpp-$(1)-build/grpc-server buun-llama-cpp-$(1)
|
||||
endef
|
||||
|
||||
buun-llama-cpp-avx2:
|
||||
$(call buun-llama-cpp-build,avx2,-DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on,--target grpc-server)
|
||||
|
||||
buun-llama-cpp-avx512:
|
||||
$(call buun-llama-cpp-build,avx512,-DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on,--target grpc-server)
|
||||
|
||||
buun-llama-cpp-avx:
|
||||
$(call buun-llama-cpp-build,avx,-DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off,--target grpc-server)
|
||||
|
||||
buun-llama-cpp-fallback:
|
||||
$(call buun-llama-cpp-build,fallback,-DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off,--target grpc-server)
|
||||
|
||||
buun-llama-cpp-grpc:
|
||||
$(call buun-llama-cpp-build,grpc,-DGGML_RPC=ON -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off,--target grpc-server --target rpc-server)
|
||||
|
||||
buun-llama-cpp-rpc-server: buun-llama-cpp-grpc
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../buun-llama-cpp-grpc-build/llama.cpp/build/bin/rpc-server buun-llama-cpp-rpc-server
|
||||
|
||||
package:
|
||||
bash package.sh
|
||||
|
||||
purge:
|
||||
rm -rf $(CURRENT_MAKEFILE_DIR)/../buun-llama-cpp-*-build
|
||||
rm -rf buun-llama-cpp-* package
|
||||
|
||||
clean: purge
|
||||
50
backend/cpp/buun-llama-cpp/apply-patches.sh
Executable file
50
backend/cpp/buun-llama-cpp/apply-patches.sh
Executable file
@@ -0,0 +1,50 @@
|
||||
#!/bin/bash
|
||||
# Apply the buun-llama-cpp patch series to a cloned buun-llama-cpp checkout.
|
||||
#
|
||||
# buun-llama-cpp is a fork-of-a-fork that branched off upstream llama.cpp
|
||||
# before some API changes the shared backend/cpp/llama-cpp/grpc-server.cpp
|
||||
# depends on. We carry those upstream commits as patch files under
|
||||
# backend/cpp/buun-llama-cpp/patches/ and apply them here so the reused
|
||||
# grpc-server source compiles against the fork unmodified.
|
||||
#
|
||||
# Drop the corresponding patch from patches/ whenever the fork catches up with
|
||||
# upstream — the build will fail fast if a patch stops applying, which is the
|
||||
# signal to retire it.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [[ $# -ne 2 ]]; then
|
||||
echo "usage: $0 <llama.cpp-src-dir> <patches-dir>" >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
SRC_DIR=$1
|
||||
PATCHES_DIR=$2
|
||||
|
||||
if [[ ! -d "$SRC_DIR" ]]; then
|
||||
echo "source dir does not exist: $SRC_DIR" >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if [[ ! -d "$PATCHES_DIR" ]]; then
|
||||
echo "no patches dir at $PATCHES_DIR, nothing to apply"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
shopt -s nullglob
|
||||
patches=("$PATCHES_DIR"/*.patch)
|
||||
shopt -u nullglob
|
||||
|
||||
if [[ ${#patches[@]} -eq 0 ]]; then
|
||||
echo "no .patch files in $PATCHES_DIR, nothing to apply"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
cd "$SRC_DIR"
|
||||
|
||||
for patch in "${patches[@]}"; do
|
||||
echo "==> applying $patch"
|
||||
git apply --verbose "$patch"
|
||||
done
|
||||
|
||||
echo "all buun-llama-cpp patches applied successfully"
|
||||
57
backend/cpp/buun-llama-cpp/package.sh
Executable file
57
backend/cpp/buun-llama-cpp/package.sh
Executable file
@@ -0,0 +1,57 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
# This script is used in the final stage of the Dockerfile
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avrf $CURDIR/buun-llama-cpp-* $CURDIR/package/
|
||||
cp -rfv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
162
backend/cpp/buun-llama-cpp/patch-grpc-server.sh
Executable file
162
backend/cpp/buun-llama-cpp/patch-grpc-server.sh
Executable file
@@ -0,0 +1,162 @@
|
||||
#!/bin/bash
|
||||
# Patch the shared backend/cpp/llama-cpp/grpc-server.cpp *copy* used by the
|
||||
# buun-llama-cpp build to account for three gaps between upstream and the fork:
|
||||
#
|
||||
# 1. Augment the kv_cache_types[] allow-list so `LoadModel` accepts the
|
||||
# fork-specific `turbo2` / `turbo3` / `turbo4` cache types plus the buun
|
||||
# additions `turbo2_tcq` / `turbo3_tcq`.
|
||||
#
|
||||
# 2. Wire up buun-exclusive speculative-decoding option handlers
|
||||
# (tree_budget / draft_topk) alongside the existing spec_* handlers.
|
||||
# These reference struct fields (common_params.speculative.tree_budget
|
||||
# and .draft_topk) that only exist in buun's common/common.h — adding
|
||||
# them to the shared backend/cpp/llama-cpp/grpc-server.cpp would break
|
||||
# the stock llama-cpp build, so we inject them only into the buun copy.
|
||||
#
|
||||
# 3. Replace `get_media_marker()` (added upstream in ggml-org/llama.cpp#21962,
|
||||
# server-side random per-instance marker) with the legacy "<__media__>"
|
||||
# literal. The fork branched before that PR, so server-common.cpp has no
|
||||
# get_media_marker symbol. The fork's mtmd_default_marker() still returns
|
||||
# "<__media__>", and Go-side tooling falls back to that sentinel when the
|
||||
# backend does not expose media_marker, so substituting the literal keeps
|
||||
# behavior identical on the buun path.
|
||||
#
|
||||
# We patch the *copy* sitting in buun-llama-cpp-<flavor>-build/, never the
|
||||
# original under backend/cpp/llama-cpp/, so the stock llama-cpp build keeps
|
||||
# compiling against vanilla upstream.
|
||||
#
|
||||
# Idempotent: skips each insertion if its marker is already present (so re-runs
|
||||
# of the same build dir don't double-insert).
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [[ $# -ne 1 ]]; then
|
||||
echo "usage: $0 <grpc-server.cpp>" >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
SRC=$1
|
||||
|
||||
if [[ ! -f "$SRC" ]]; then
|
||||
echo "grpc-server.cpp not found at $SRC" >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if grep -q 'GGML_TYPE_TURBO2_TCQ' "$SRC"; then
|
||||
echo "==> $SRC already has buun cache types, skipping KV allow-list patch"
|
||||
else
|
||||
echo "==> patching $SRC to allow turbo2/turbo3/turbo4/turbo2_tcq/turbo3_tcq KV-cache types"
|
||||
|
||||
# Insert the five TURBO entries right after the first ` GGML_TYPE_Q5_1,`
|
||||
# line (the kv_cache_types[] allow-list). Using awk because the builder
|
||||
# image does not ship python3, and GNU sed's multi-line `a\` quoting is
|
||||
# awkward.
|
||||
awk '
|
||||
/^ GGML_TYPE_Q5_1,$/ && !done {
|
||||
print
|
||||
print " // buun-llama-cpp fork extras — added by patch-grpc-server.sh"
|
||||
print " GGML_TYPE_TURBO2_0,"
|
||||
print " GGML_TYPE_TURBO3_0,"
|
||||
print " GGML_TYPE_TURBO4_0,"
|
||||
print " GGML_TYPE_TURBO2_TCQ,"
|
||||
print " GGML_TYPE_TURBO3_TCQ,"
|
||||
done = 1
|
||||
next
|
||||
}
|
||||
{ print }
|
||||
END {
|
||||
if (!done) {
|
||||
print "patch-grpc-server.sh: anchor ` GGML_TYPE_Q5_1,` not found" > "/dev/stderr"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
|
||||
echo "==> KV allow-list patch OK"
|
||||
fi
|
||||
|
||||
if grep -q 'optname, "tree_budget"' "$SRC"; then
|
||||
echo "==> $SRC already has DFlash option handlers, skipping"
|
||||
else
|
||||
echo "==> patching $SRC to add tree_budget / draft_topk option handlers"
|
||||
|
||||
# Insert two new `else if` handlers between the inner close-brace of the
|
||||
# `spec_p_split` block and the next `} else if (…spec_ngram_size_n…)` line.
|
||||
# Upstream writes each `} else if` as a single physical line, so we don't
|
||||
# emit an outer `}` ourselves — the existing next line provides both the
|
||||
# close of our `draft_topk` block and the open of `spec_ngram_size_n`.
|
||||
# Anchor on the exact 3-line body of spec_p_split so we can't drift.
|
||||
awk '
|
||||
prev2 == " } else if (!strcmp(optname, \"spec_p_split\")) {" &&
|
||||
prev1 ~ /^ +if \(optval != NULL\) \{$/ &&
|
||||
$0 ~ /^ +try \{ params\.speculative\.p_split = std::stof\(optval_str\); \} catch \(\.\.\.\) \{\}$/ &&
|
||||
!done {
|
||||
print # print the try-line itself
|
||||
getline inner_close # read " }" closing the inner if
|
||||
print inner_close # print it — this closes spec_p_split body
|
||||
print " // buun-llama-cpp DFlash options — added by patch-grpc-server.sh"
|
||||
print " } else if (!strcmp(optname, \"tree_budget\")) {"
|
||||
print " if (optval != NULL) {"
|
||||
print " try { params.speculative.tree_budget = std::stoi(optval_str); } catch (...) {}"
|
||||
print " }"
|
||||
print " } else if (!strcmp(optname, \"draft_topk\")) {"
|
||||
print " if (optval != NULL) {"
|
||||
print " try { params.speculative.draft_topk = std::stoi(optval_str); } catch (...) {}"
|
||||
print " }"
|
||||
# The next source line (`} else if (…spec_ngram_size_n…) {`) closes
|
||||
# our draft_topk block and continues the chain naturally; fall back
|
||||
# into the main loop to emit it and everything after.
|
||||
done = 1
|
||||
prev2 = prev1
|
||||
prev1 = inner_close
|
||||
next
|
||||
}
|
||||
{ print; prev2 = prev1; prev1 = $0 }
|
||||
END {
|
||||
if (!done) {
|
||||
print "patch-grpc-server.sh: spec_p_split anchor not found" > "/dev/stderr"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
|
||||
echo "==> DFlash option-handler patch OK"
|
||||
fi
|
||||
|
||||
if grep -qE 'ctx_server\.get_meta\(\)\.logit_bias_eog|params_base\.sampling\.logit_bias_eog,' "$SRC"; then
|
||||
echo "==> patching $SRC to drop the logit_bias_eog arg from params_from_json_cmpl() callsites (buun still uses the pre-refactor 4-arg signature)"
|
||||
# Upstream llama.cpp refactored params_from_json_cmpl to take a precomputed
|
||||
# logit_bias_eog vector after buun's 2026-04-05 fork-point — simultaneously
|
||||
# adding server_context_meta::logit_bias_eog as the supplier. Buun carries
|
||||
# neither change: its params_from_json_cmpl is still 4-arg, and internally
|
||||
# derives logit_bias_eog from the common_params it's passed. So we just
|
||||
# delete the argument line entirely — the remaining 4 args match buun's
|
||||
# signature and the resulting behavior matches upstream bit-for-bit
|
||||
# (upstream's 5th arg is the same data buun derives internally).
|
||||
#
|
||||
# Guard is broad so this works whether the line has been run through this
|
||||
# block before (leaving params_base.sampling.logit_bias_eog,) or not
|
||||
# (leaving the original ctx_server.get_meta().logit_bias_eog,).
|
||||
sed -E '/^[[:space:]]+(ctx_server\.get_meta\(\)\.logit_bias_eog|params_base\.sampling\.logit_bias_eog),$/d' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> logit_bias_eog arg drop OK"
|
||||
else
|
||||
echo "==> $SRC has no logit_bias_eog arg line, skipping"
|
||||
fi
|
||||
|
||||
if grep -q 'get_media_marker()' "$SRC"; then
|
||||
echo "==> patching $SRC to replace get_media_marker() with legacy \"<__media__>\" literal"
|
||||
# Only one call site today (ModelMetadata), but replace all occurrences to
|
||||
# stay robust if upstream adds more. Use a temp file to avoid relying on
|
||||
# sed -i portability (the builder image uses GNU sed, but keeping this
|
||||
# consistent with the awk block above).
|
||||
sed 's/get_media_marker()/"<__media__>"/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> get_media_marker() substitution OK"
|
||||
else
|
||||
echo "==> $SRC has no get_media_marker() call, skipping media-marker patch"
|
||||
fi
|
||||
|
||||
echo "==> all patches applied"
|
||||
@@ -0,0 +1,46 @@
|
||||
Subject: [PATCH] ggml-cuda/fattn: provide atomicAdd(double*,double) shim for pre-sm_60
|
||||
|
||||
Buun's Q² calibration path in ggml_cuda_turbo_scale_q calls
|
||||
atomicAdd(&d_q_channel_sq_fattn[threadIdx.x], (double)(val * val));
|
||||
but native double atomicAdd is only available on compute capability 6.0
|
||||
and newer. Compiling against a CUDA arch list that includes older
|
||||
architectures (LocalAI's CUDA 12 Docker image builds for the full
|
||||
published arch range) fails with:
|
||||
|
||||
fattn.cu(812): error: no instance of overloaded function "atomicAdd"
|
||||
matches the argument list, argument types are: (double *, double)
|
||||
|
||||
Add the canonical CUDA-programming-guide shim at the top of fattn.cu so
|
||||
pre-sm_60 codegen has a definition to call. On sm_60+ the native CUDA
|
||||
intrinsic is used and the shim is elided via __CUDA_ARCH__.
|
||||
|
||||
--- a/ggml/src/ggml-cuda/fattn.cu
|
||||
+++ b/ggml/src/ggml-cuda/fattn.cu
|
||||
@@ -7,6 +7,27 @@
|
||||
|
||||
#include <atomic>
|
||||
|
||||
+// Pre-sm_60 double atomicAdd shim. Native double atomicAdd(double*,double)
|
||||
+// is only available on CUDA compute capability 6.0+ (see CUDA C Programming
|
||||
+// Guide, B.15 Atomic Functions). Buun's Q² calibration path below calls
|
||||
+// atomicAdd with a double*; without this definition, nvcc fails to find a
|
||||
+// matching overload whenever the compile target list includes pre-sm_60
|
||||
+// architectures. The standard CAS loop implementation below matches the
|
||||
+// semantics of the native intrinsic.
|
||||
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
|
||||
+static __device__ double atomicAdd(double * address, double val) {
|
||||
+ unsigned long long int * address_as_ull = (unsigned long long int *)address;
|
||||
+ unsigned long long int old = *address_as_ull;
|
||||
+ unsigned long long int assumed;
|
||||
+ do {
|
||||
+ assumed = old;
|
||||
+ old = atomicCAS(address_as_ull, assumed,
|
||||
+ __double_as_longlong(val + __longlong_as_double(assumed)));
|
||||
+ } while (assumed != old);
|
||||
+ return __longlong_as_double(old);
|
||||
+}
|
||||
+#endif
|
||||
+
|
||||
// InnerQ: update the fattn-side inverse scale array from host (all devices)
|
||||
void turbo_innerq_update_fattn_scales(const float * scale_inv) {
|
||||
int cur_device;
|
||||
@@ -0,0 +1,32 @@
|
||||
Subject: [PATCH] ggml-cuda/argmax: pass WARP_SIZE to the top-K __shfl_xor_sync calls
|
||||
|
||||
Two __shfl_xor_sync calls in the top-K intra-warp merge drop the `width`
|
||||
argument and rely on the CUDA default (warpSize). Every other call in
|
||||
the same file already passes WARP_SIZE explicitly, and the HIP/ROCm
|
||||
compatibility shim at ggml/src/ggml-cuda/vendors/hip.h:33 is a 4-arg
|
||||
function-like macro — so the 3-arg form fails to preprocess when
|
||||
building with hipcc against ROCm:
|
||||
|
||||
argmax.cu:265: error: too few arguments provided to function-like
|
||||
macro invocation
|
||||
note: macro '__shfl_xor_sync' defined here:
|
||||
#define __shfl_xor_sync(mask, var, laneMask, width) \
|
||||
__shfl_xor(var, laneMask, width)
|
||||
|
||||
Align the two call sites with the rest of the file by passing WARP_SIZE
|
||||
explicitly. On CUDA the generated code is unchanged (warpSize is the
|
||||
default); on HIP it now matches the macro's arity.
|
||||
|
||||
--- a/ggml/src/ggml-cuda/argmax.cu
|
||||
+++ b/ggml/src/ggml-cuda/argmax.cu
|
||||
@@ -262,8 +262,8 @@
|
||||
// Each step: lane gets partner's min element, if it beats our min, replace and re-heapify
|
||||
for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
|
||||
for (int i = 0; i < K; i++) {
|
||||
- float partner_val = __shfl_xor_sync(0xFFFFFFFF, heap_val[i], offset);
|
||||
- int partner_idx = __shfl_xor_sync(0xFFFFFFFF, heap_idx[i], offset);
|
||||
+ float partner_val = __shfl_xor_sync(0xFFFFFFFF, heap_val[i], offset, WARP_SIZE);
|
||||
+ int partner_idx = __shfl_xor_sync(0xFFFFFFFF, heap_idx[i], offset, WARP_SIZE);
|
||||
if (partner_val > heap_val[0]) {
|
||||
heap_val[0] = partner_val;
|
||||
heap_idx[0] = partner_idx;
|
||||
@@ -0,0 +1,24 @@
|
||||
Subject: [PATCH] ggml-cuda/vendors/hip: alias cudaMemcpy{To,From}Symbol to hip counterparts
|
||||
|
||||
Buun's Q² calibration + TCQ codebook upload paths in fattn.cu use
|
||||
cudaMemcpyToSymbol / cudaMemcpyFromSymbol. The HIP-compat header in
|
||||
ggml/src/ggml-cuda/vendors/hip.h already aliases the scalar cudaMemcpy
|
||||
family (cudaMemcpy, cudaMemcpyAsync, cudaMemcpy2DAsync, …) but is
|
||||
missing the symbol variants. Building with hipcc therefore fails with
|
||||
15+ "use of undeclared identifier 'cudaMemcpyToSymbol'" errors.
|
||||
|
||||
Add the two missing aliases alongside the existing memcpy block. HIP
|
||||
provides hipMemcpy{To,From}Symbol with the same signature as CUDA's
|
||||
equivalents, so this is a straight name substitution.
|
||||
|
||||
--- a/ggml/src/ggml-cuda/vendors/hip.h
|
||||
+++ b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
@@ -85,6 +85,8 @@
|
||||
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
|
||||
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
|
||||
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
|
||||
+#define cudaMemcpyToSymbol hipMemcpyToSymbol
|
||||
+#define cudaMemcpyFromSymbol hipMemcpyFromSymbol
|
||||
#define cudaMemcpyKind hipMemcpyKind
|
||||
#define cudaMemset hipMemset
|
||||
#define cudaMemsetAsync hipMemsetAsync
|
||||
@@ -0,0 +1,36 @@
|
||||
Subject: [PATCH] ggml-cuda/fattn: pass WARP_SIZE to fwht128 __shfl_xor_sync calls
|
||||
|
||||
Same issue as the argmax top-K fix: two __shfl_xor_sync call sites in
|
||||
the FWHT-128 butterfly kernels (ggml_cuda_fwht128 and fwht128_store_half)
|
||||
use the 3-arg CUDA form and omit the `width` argument that the HIP
|
||||
function-like macro in vendors/hip.h:33 requires. Hipcc fails with:
|
||||
|
||||
fattn.cu:512: too few arguments provided to function-like macro
|
||||
invocation
|
||||
note: macro '__shfl_xor_sync' defined here:
|
||||
#define __shfl_xor_sync(mask, var, laneMask, width) \
|
||||
__shfl_xor(var, laneMask, width)
|
||||
|
||||
Add WARP_SIZE to both calls. CUDA codegen is unchanged (warpSize is the
|
||||
default); HIP now matches the macro arity.
|
||||
|
||||
--- a/ggml/src/ggml-cuda/fattn.cu
|
||||
+++ b/ggml/src/ggml-cuda/fattn.cu
|
||||
@@ -509,7 +509,7 @@
|
||||
// Intra-warp passes: shuffle xor with stride h, no smem, no sync.
|
||||
#pragma unroll
|
||||
for (int h = 1; h <= 16; h *= 2) {
|
||||
- const float other = __shfl_xor_sync(0xFFFFFFFF, val, h);
|
||||
+ const float other = __shfl_xor_sync(0xFFFFFFFF, val, h, WARP_SIZE);
|
||||
val = (tid & h) ? (other - val) : (val + other);
|
||||
}
|
||||
|
||||
@@ -533,7 +533,7 @@
|
||||
static __device__ __forceinline__ void fwht128_store_half(
|
||||
float val, half * dst_base) {
|
||||
const int tid = threadIdx.x;
|
||||
- const float neighbor = __shfl_xor_sync(0xFFFFFFFF, val, 1);
|
||||
+ const float neighbor = __shfl_xor_sync(0xFFFFFFFF, val, 1, WARP_SIZE);
|
||||
if ((tid & 1) == 0) {
|
||||
const half2 packed = __floats2half2_rn(val, neighbor);
|
||||
*((half2 *)(dst_base + tid)) = packed;
|
||||
65
backend/cpp/buun-llama-cpp/run.sh
Executable file
65
backend/cpp/buun-llama-cpp/run.sh
Executable file
@@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
|
||||
BINARY=buun-llama-cpp-fallback
|
||||
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/buun-llama-cpp-avx ]; then
|
||||
BINARY=buun-llama-cpp-avx
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/buun-llama-cpp-avx2 ]; then
|
||||
BINARY=buun-llama-cpp-avx2
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/buun-llama-cpp-avx512 ]; then
|
||||
BINARY=buun-llama-cpp-avx512
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMACPP_GRPC_SERVERS" ]; then
|
||||
if [ -e $CURDIR/buun-llama-cpp-grpc ]; then
|
||||
BINARY=buun-llama-cpp-grpc
|
||||
fi
|
||||
fi
|
||||
|
||||
# Extend ld library path with the dir where this script is located/lib
|
||||
if [ "$(uname)" == "Darwin" ]; then
|
||||
export DYLD_LIBRARY_PATH=$CURDIR/lib:$DYLD_LIBRARY_PATH
|
||||
else
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
# Tell rocBLAS where to find TensileLibrary data (GPU kernel tuning files)
|
||||
if [ -d "$CURDIR/lib/rocblas/library" ]; then
|
||||
export ROCBLAS_TENSILE_LIBPATH=$CURDIR/lib/rocblas/library
|
||||
fi
|
||||
fi
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using binary: $BINARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/$BINARY "$@"
|
||||
fi
|
||||
|
||||
echo "Using binary: $BINARY"
|
||||
exec $CURDIR/$BINARY "$@"
|
||||
|
||||
# We should never reach this point, however just in case we do, run fallback
|
||||
exec $CURDIR/buun-llama-cpp-fallback "$@"
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=d4824131580b94ffa7b0e91c955e2b237c2fe16e
|
||||
IK_LLAMA_VERSION?=16996aeab772c69b6473597038b2ef0b85297e8b
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -686,7 +686,16 @@ struct llama_server_context
|
||||
slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
||||
slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
|
||||
slot->sparams.seed = json_value(data, "seed", default_sparams.seed);
|
||||
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
||||
{
|
||||
// upstream changed common_params_sampling::grammar from std::string to
|
||||
// the common_grammar struct (type + grammar). The incoming JSON still
|
||||
// carries a plain string, so build the user-provided grammar here and
|
||||
// fall back to the server default when the request omits it.
|
||||
std::string grammar_str = json_value(data, "grammar", std::string());
|
||||
slot->sparams.grammar = grammar_str.empty()
|
||||
? default_sparams.grammar
|
||||
: common_grammar{COMMON_GRAMMAR_TYPE_USER, std::move(grammar_str)};
|
||||
}
|
||||
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||
slot->sparams.grammar_triggers = grammar_triggers;
|
||||
@@ -1232,7 +1241,7 @@ struct llama_server_context
|
||||
// {"logit_bias", slot.sparams.logit_bias},
|
||||
{"n_probs", slot.sparams.n_probs},
|
||||
{"min_keep", slot.sparams.min_keep},
|
||||
{"grammar", slot.sparams.grammar},
|
||||
{"grammar", slot.sparams.grammar.grammar},
|
||||
{"samplers", samplers}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
--- a/examples/llava/clip.cpp
|
||||
+++ b/examples/llava/clip.cpp
|
||||
@@ -2494,7 +2494,7 @@
|
||||
}
|
||||
new_data = work.data();
|
||||
|
||||
- new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, n_elms/cur->ne[0], cur->ne[0], nullptr);
|
||||
+ new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, n_elms/cur->ne[0], cur->ne[0], nullptr, nullptr);
|
||||
} else {
|
||||
new_type = cur->type;
|
||||
new_data = cur->data;
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=5a4cd6741fc33227cdacb329f355ab21f8481de2
|
||||
LLAMA_VERSION?=187a45637054881ecacf17f8e2f6f8f2ba7df1c7
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -10,6 +10,14 @@
|
||||
#include "server-task.cpp"
|
||||
#include "server-queue.cpp"
|
||||
#include "server-common.cpp"
|
||||
// server-chat.cpp exists only in llama.cpp after the upstream refactor that
|
||||
// split OAI/Anthropic/Responses/transcription conversion helpers out of
|
||||
// server-common.cpp. When present, server-context.cpp and server-task.cpp
|
||||
// above call into it, so we must pull its definitions into this TU or the
|
||||
// link fails. __has_include keeps the source compatible with older pins.
|
||||
#if __has_include("server-chat.cpp")
|
||||
#include "server-chat.cpp"
|
||||
#endif
|
||||
#include "server-context.cpp"
|
||||
|
||||
// LocalAI
|
||||
|
||||
@@ -4,7 +4,6 @@ package main
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"container/heap"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
@@ -100,9 +99,16 @@ func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 {
|
||||
}
|
||||
|
||||
func (s *Store) Load(opts *pb.ModelOptions) error {
|
||||
if opts.Model != "" {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
// local-store is an in-memory vector store with no on-disk artefact to
|
||||
// load — opts.Model is just a namespace identifier. The old `!= ""` guard
|
||||
// rejected any non-empty model name with "not implemented", which broke
|
||||
// callers that pass a namespace to isolate embedding spaces (face vs.
|
||||
// voice biometrics both go through local-store but need distinct stores
|
||||
// so ArcFace 512-D and ECAPA-TDNN 192-D don't collide). Namespace
|
||||
// isolation is already handled upstream: ModelLoader spawns a fresh
|
||||
// local-store process per (backend, model) tuple, so each namespace is
|
||||
// its own Store{} instance. Nothing to do here beyond accepting the load.
|
||||
_ = opts
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
11
backend/go/sherpa-onnx/.gitignore
vendored
Normal file
11
backend/go/sherpa-onnx/.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
.cache/
|
||||
sources/
|
||||
build*/
|
||||
package/
|
||||
backend-assets/
|
||||
sherpa-onnx
|
||||
*.so
|
||||
compile_commands.json
|
||||
sherpa-onnx-whisper-*
|
||||
vits-ljs/
|
||||
streaming-zipformer-en/
|
||||
120
backend/go/sherpa-onnx/Makefile
Normal file
120
backend/go/sherpa-onnx/Makefile
Normal file
@@ -0,0 +1,120 @@
|
||||
CURRENT_DIR=$(abspath ./)
|
||||
GOCMD=go
|
||||
|
||||
ONNX_VERSION?=1.24.4
|
||||
# v1.12.39 — includes upstream's onnxruntime 1.24.4 bump (#3501). Earlier
|
||||
# pinned commits only support onnxruntime 1.23.2, which has no CUDA 13
|
||||
# pre-built tarball, blocking the -gpu-nvidia-cuda-13 build matrix entry.
|
||||
SHERPA_COMMIT?=7288d15e3e31a7bd589b2ba88828d521e7a6b140
|
||||
ONNX_ARCH?=x64
|
||||
ONNX_OS?=linux
|
||||
|
||||
ifneq (,$(findstring aarch64,$(shell uname -m)))
|
||||
ONNX_ARCH=aarch64
|
||||
endif
|
||||
|
||||
ifeq ($(OS),Darwin)
|
||||
ONNX_OS=osx
|
||||
ifneq (,$(findstring aarch64,$(shell uname -m)))
|
||||
ONNX_ARCH=arm64
|
||||
else ifneq (,$(findstring arm64,$(shell uname -m)))
|
||||
ONNX_ARCH=arm64
|
||||
else
|
||||
ONNX_ARCH=x86_64
|
||||
endif
|
||||
endif
|
||||
|
||||
# Upstream onnxruntime ships CUDA 12 and CUDA 13 variants under different
|
||||
# names: -gpu-<ver>.tgz for CUDA 12, -gpu_cuda13-<ver>.tgz for CUDA 13
|
||||
# (note underscore vs dash). CUDA 13 tarballs only exist from 1.24.x onward.
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
SHERPA_GPU=ON
|
||||
ONNX_PROVIDER=cuda
|
||||
ifeq ($(CUDA_MAJOR_VERSION),13)
|
||||
ONNX_VARIANT=-gpu_cuda13
|
||||
else
|
||||
ONNX_VARIANT=-gpu
|
||||
endif
|
||||
else
|
||||
ONNX_VARIANT=
|
||||
SHERPA_GPU=OFF
|
||||
ONNX_PROVIDER=cpu
|
||||
endif
|
||||
|
||||
JOBS?=$(shell nproc --ignore=1 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
||||
|
||||
sources/onnxruntime:
|
||||
mkdir -p sources/onnxruntime
|
||||
curl -L https://github.com/microsoft/onnxruntime/releases/download/v$(ONNX_VERSION)/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)$(ONNX_VARIANT)-$(ONNX_VERSION).tgz \
|
||||
-o sources/onnxruntime/onnxruntime.tgz
|
||||
cd sources/onnxruntime && tar -xf onnxruntime.tgz --strip-components=1 && rm onnxruntime.tgz
|
||||
|
||||
sources/sherpa-onnx: sources/onnxruntime
|
||||
git clone https://github.com/k2-fsa/sherpa-onnx.git sources/sherpa-onnx
|
||||
cd sources/sherpa-onnx && git checkout $(SHERPA_COMMIT)
|
||||
mkdir -p sources/sherpa-onnx/build
|
||||
# sherpa-onnx's cmake detects a pre-installed onnxruntime via the
|
||||
# SHERPA_ONNXRUNTIME_{INCLUDE,LIB}_DIR env vars (not via -D flags).
|
||||
# Point them at our locally-downloaded Microsoft tarball — without
|
||||
# this, sherpa-onnx falls through to download_onnxruntime() which
|
||||
# fetches from csukuangfj/onnxruntime-libs. For the GPU 1.24.4
|
||||
# build that release mirror publishes `-patched.zip` instead of the
|
||||
# expected `.tgz`, so the download 404s and the build fails.
|
||||
cd sources/sherpa-onnx/build && \
|
||||
SHERPA_ONNXRUNTIME_INCLUDE_DIR=$(CURRENT_DIR)/sources/onnxruntime/include \
|
||||
SHERPA_ONNXRUNTIME_LIB_DIR=$(CURRENT_DIR)/sources/onnxruntime/lib \
|
||||
cmake \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_C_FLAGS="-Wno-error=format-security" \
|
||||
-DCMAKE_CXX_FLAGS="-Wno-error=format-security" \
|
||||
-DSHERPA_ONNX_ENABLE_GPU=$(SHERPA_GPU) \
|
||||
-DSHERPA_ONNX_ENABLE_TTS=ON \
|
||||
-DSHERPA_ONNX_ENABLE_BINARY=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_C_API=ON \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
-DSHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE=ON \
|
||||
..
|
||||
cd sources/sherpa-onnx/build && make -j$(JOBS)
|
||||
|
||||
backend-assets/lib: sources/sherpa-onnx sources/onnxruntime
|
||||
mkdir -p backend-assets/lib
|
||||
cp -rfLv sources/onnxruntime/lib/* backend-assets/lib/
|
||||
cp -rfLv sources/sherpa-onnx/build/lib/*.so* backend-assets/lib/ 2>/dev/null || true
|
||||
cp -rfLv sources/sherpa-onnx/build/lib/*.dylib backend-assets/lib/ 2>/dev/null || true
|
||||
|
||||
# libsherpa-shim wraps sherpa-onnx's nested config structs and TTS
|
||||
# callback plumbing behind a purego-friendly API: opaque handles plus
|
||||
# fixed-signature setters/getters/trampoline. Plain C compile — no cgo.
|
||||
SHIM_EXT=so
|
||||
ifeq ($(OS),Darwin)
|
||||
SHIM_EXT=dylib
|
||||
endif
|
||||
|
||||
backend-assets/lib/libsherpa-shim.$(SHIM_EXT): csrc/shim.c csrc/shim.h backend-assets/lib
|
||||
$(CC) -shared -fPIC -O2 \
|
||||
-I$(CURRENT_DIR)/sources/sherpa-onnx/sherpa-onnx/c-api \
|
||||
-o $@ csrc/shim.c \
|
||||
-L$(CURRENT_DIR)/backend-assets/lib \
|
||||
-lsherpa-onnx-c-api \
|
||||
-Wl,-rpath,'$$ORIGIN'
|
||||
|
||||
sherpa-onnx: backend-assets/lib backend-assets/lib/libsherpa-shim.$(SHIM_EXT)
|
||||
CGO_ENABLED=0 $(GOCMD) build \
|
||||
-ldflags "$(LD_FLAGS) -X main.onnxProvider=$(ONNX_PROVIDER)" \
|
||||
-tags "$(GO_TAGS)" -o sherpa-onnx ./
|
||||
|
||||
package:
|
||||
bash package.sh
|
||||
|
||||
build: sherpa-onnx package
|
||||
|
||||
clean:
|
||||
rm -rf sherpa-onnx sources/ backend-assets/ package/ vits-ljs/ sherpa-onnx-whisper-*/
|
||||
|
||||
test: sherpa-onnx
|
||||
LD_LIBRARY_PATH=$(CURRENT_DIR)/backend-assets/lib \
|
||||
bash test.sh
|
||||
|
||||
.PHONY: build package clean test
|
||||
1249
backend/go/sherpa-onnx/backend.go
Normal file
1249
backend/go/sherpa-onnx/backend.go
Normal file
File diff suppressed because it is too large
Load Diff
169
backend/go/sherpa-onnx/backend_test.go
Normal file
169
backend/go/sherpa-onnx/backend_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestSherpaBackend(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Sherpa-ONNX Backend Suite")
|
||||
}
|
||||
|
||||
// Load libsherpa-shim + libsherpa-onnx-c-api via purego before any spec
|
||||
// runs — otherwise any Load/TTS/VAD/AudioTranscription call hits a nil
|
||||
// function pointer. LD_LIBRARY_PATH must contain the directory holding
|
||||
// both .so files; test.sh sets this.
|
||||
var _ = BeforeSuite(func() {
|
||||
Expect(loadSherpaLibs()).To(Succeed())
|
||||
})
|
||||
|
||||
var _ = Describe("Sherpa-ONNX", func() {
|
||||
Context("lifecycle", func() {
|
||||
It("is locking (C API is not thread safe)", func() {
|
||||
Expect((&SherpaBackend{}).Locking()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("errors loading a non-existent model", func() {
|
||||
tmpDir, err := os.MkdirTemp("", "sherpa-test-nonexistent")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
err = (&SherpaBackend{}).Load(&pb.ModelOptions{
|
||||
ModelFile: filepath.Join(tmpDir, "non-existent-model.onnx"),
|
||||
})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors loading a non-existent ASR model", func() {
|
||||
tmpDir, err := os.MkdirTemp("", "sherpa-test-asr")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
err = (&SherpaBackend{}).Load(&pb.ModelOptions{
|
||||
ModelFile: filepath.Join(tmpDir, "model.onnx"),
|
||||
Type: "asr",
|
||||
})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("dispatches Load by Type", func() {
|
||||
tmpDir, err := os.MkdirTemp("", "sherpa-test-dispatch")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
modelFile := filepath.Join(tmpDir, "model.onnx")
|
||||
for _, typ := range []string{"", "asr", "vad"} {
|
||||
err := (&SherpaBackend{}).Load(&pb.ModelOptions{ModelFile: modelFile, Type: typ})
|
||||
Expect(err).To(HaveOccurred(), "Type=%q", typ)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("method errors without loaded model", func() {
|
||||
It("rejects TTS", func() {
|
||||
tmpDir, err := os.MkdirTemp("", "sherpa-test-tts")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
err = (&SherpaBackend{}).TTS(&pb.TTSRequest{
|
||||
Text: "should fail — no model loaded",
|
||||
Dst: filepath.Join(tmpDir, "output.wav"),
|
||||
})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects AudioTranscription", func() {
|
||||
_, err := (&SherpaBackend{}).AudioTranscription(&pb.TranscriptRequest{
|
||||
Dst: "/tmp/nonexistent.wav",
|
||||
})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects VAD", func() {
|
||||
_, err := (&SherpaBackend{}).VAD(&pb.VADRequest{
|
||||
Audio: []float32{0.1, 0.2, 0.3},
|
||||
})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("type detection", func() {
|
||||
DescribeTable("isASRType",
|
||||
func(input string, want bool) {
|
||||
Expect(isASRType(input)).To(Equal(want))
|
||||
},
|
||||
Entry("asr", "asr", true),
|
||||
Entry("ASR", "ASR", true),
|
||||
Entry("Asr", "Asr", true),
|
||||
Entry("transcription", "transcription", true),
|
||||
Entry("Transcription", "Transcription", true),
|
||||
Entry("transcribe", "transcribe", true),
|
||||
Entry("Transcribe", "Transcribe", true),
|
||||
Entry("tts", "tts", false),
|
||||
Entry("empty", "", false),
|
||||
Entry("other", "other", false),
|
||||
Entry("vad", "vad", false),
|
||||
)
|
||||
|
||||
DescribeTable("isVADType",
|
||||
func(input string, want bool) {
|
||||
Expect(isVADType(input)).To(Equal(want))
|
||||
},
|
||||
Entry("vad", "vad", true),
|
||||
Entry("VAD", "VAD", true),
|
||||
Entry("Vad", "Vad", true),
|
||||
Entry("asr", "asr", false),
|
||||
Entry("tts", "tts", false),
|
||||
Entry("empty", "", false),
|
||||
Entry("other", "other", false),
|
||||
)
|
||||
})
|
||||
|
||||
Context("option parsing", func() {
|
||||
It("parses float options with fallback on bad input", func() {
|
||||
opts := &pb.ModelOptions{Options: []string{
|
||||
"vad.threshold=0.3",
|
||||
"tts.length_scale=1.25",
|
||||
"bad.number=not-a-float",
|
||||
}}
|
||||
Expect(findOptionFloat(opts, "vad.threshold=", 0.5)).To(BeNumerically("~", 0.3, 1e-6))
|
||||
Expect(findOptionFloat(opts, "tts.length_scale=", 1.0)).To(BeNumerically("~", 1.25, 1e-6))
|
||||
Expect(findOptionFloat(opts, "missing.key=", 0.7)).To(BeNumerically("~", 0.7, 1e-6))
|
||||
Expect(findOptionFloat(opts, "bad.number=", 9.9)).To(BeNumerically("~", 9.9, 1e-6))
|
||||
})
|
||||
|
||||
It("parses int options with fallback on bad input", func() {
|
||||
opts := &pb.ModelOptions{Options: []string{
|
||||
"asr.sample_rate=22050",
|
||||
"online.chunk_samples=800",
|
||||
"bad.int=4.2",
|
||||
}}
|
||||
Expect(findOptionInt(opts, "asr.sample_rate=", 16000)).To(Equal(int32(22050)))
|
||||
Expect(findOptionInt(opts, "online.chunk_samples=", 1600)).To(Equal(int32(800)))
|
||||
Expect(findOptionInt(opts, "missing.key=", 42)).To(Equal(int32(42)))
|
||||
Expect(findOptionInt(opts, "bad.int=", 100)).To(Equal(int32(100)))
|
||||
})
|
||||
|
||||
It("parses bool options (0/1, true/false, yes/no, on/off)", func() {
|
||||
opts := &pb.ModelOptions{Options: []string{
|
||||
"online.enable_endpoint=0",
|
||||
"asr.sense_voice.use_itn=True",
|
||||
"feature.on=yes",
|
||||
"feature.off=Off",
|
||||
"feature.bad=maybe",
|
||||
}}
|
||||
Expect(findOptionBool(opts, "online.enable_endpoint=", 1)).To(Equal(int32(0)))
|
||||
Expect(findOptionBool(opts, "asr.sense_voice.use_itn=", 0)).To(Equal(int32(1)))
|
||||
Expect(findOptionBool(opts, "feature.on=", 0)).To(Equal(int32(1)))
|
||||
Expect(findOptionBool(opts, "feature.off=", 1)).To(Equal(int32(0)))
|
||||
Expect(findOptionBool(opts, "feature.bad=", 1)).To(Equal(int32(1)))
|
||||
Expect(findOptionBool(opts, "missing.key=", 1)).To(Equal(int32(1)))
|
||||
})
|
||||
})
|
||||
})
|
||||
325
backend/go/sherpa-onnx/csrc/shim.c
Normal file
325
backend/go/sherpa-onnx/csrc/shim.c
Normal file
@@ -0,0 +1,325 @@
|
||||
#include "shim.h"
|
||||
#include "c-api.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
// Replace the char* field pointed to by `slot` with a strdup of `s`
|
||||
// (or NULL if s is NULL). Frees any prior value. Silently no-ops when
|
||||
// strdup fails — the caller will see a Create* failure downstream.
|
||||
static void shim_set_str(const char **slot, const char *s) {
|
||||
free((char *)*slot);
|
||||
*slot = s ? strdup(s) : NULL;
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// VAD config
|
||||
// ==================================================================
|
||||
|
||||
void *sherpa_shim_vad_config_new(void) {
|
||||
return calloc(1, sizeof(SherpaOnnxVadModelConfig));
|
||||
}
|
||||
|
||||
void sherpa_shim_vad_config_free(void *h) {
|
||||
if (!h) return;
|
||||
SherpaOnnxVadModelConfig *c = (SherpaOnnxVadModelConfig *)h;
|
||||
free((char *)c->silero_vad.model);
|
||||
free((char *)c->provider);
|
||||
free(c);
|
||||
}
|
||||
|
||||
void sherpa_shim_vad_config_set_silero_model(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxVadModelConfig *)h)->silero_vad.model, v);
|
||||
}
|
||||
void sherpa_shim_vad_config_set_silero_threshold(void *h, float v) {
|
||||
((SherpaOnnxVadModelConfig *)h)->silero_vad.threshold = v;
|
||||
}
|
||||
void sherpa_shim_vad_config_set_silero_min_silence_duration(void *h, float v) {
|
||||
((SherpaOnnxVadModelConfig *)h)->silero_vad.min_silence_duration = v;
|
||||
}
|
||||
void sherpa_shim_vad_config_set_silero_min_speech_duration(void *h, float v) {
|
||||
((SherpaOnnxVadModelConfig *)h)->silero_vad.min_speech_duration = v;
|
||||
}
|
||||
void sherpa_shim_vad_config_set_silero_window_size(void *h, int32_t v) {
|
||||
((SherpaOnnxVadModelConfig *)h)->silero_vad.window_size = v;
|
||||
}
|
||||
void sherpa_shim_vad_config_set_silero_max_speech_duration(void *h, float v) {
|
||||
((SherpaOnnxVadModelConfig *)h)->silero_vad.max_speech_duration = v;
|
||||
}
|
||||
void sherpa_shim_vad_config_set_sample_rate(void *h, int32_t v) {
|
||||
((SherpaOnnxVadModelConfig *)h)->sample_rate = v;
|
||||
}
|
||||
void sherpa_shim_vad_config_set_num_threads(void *h, int32_t v) {
|
||||
((SherpaOnnxVadModelConfig *)h)->num_threads = v;
|
||||
}
|
||||
void sherpa_shim_vad_config_set_provider(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxVadModelConfig *)h)->provider, v);
|
||||
}
|
||||
void sherpa_shim_vad_config_set_debug(void *h, int32_t v) {
|
||||
((SherpaOnnxVadModelConfig *)h)->debug = v;
|
||||
}
|
||||
|
||||
void *sherpa_shim_create_vad(void *h, float buffer_size_seconds) {
|
||||
return (void *)SherpaOnnxCreateVoiceActivityDetector(
|
||||
(const SherpaOnnxVadModelConfig *)h, buffer_size_seconds);
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// Offline TTS config (VITS)
|
||||
// ==================================================================
|
||||
|
||||
void *sherpa_shim_tts_config_new(void) {
|
||||
return calloc(1, sizeof(SherpaOnnxOfflineTtsConfig));
|
||||
}
|
||||
|
||||
void sherpa_shim_tts_config_free(void *h) {
|
||||
if (!h) return;
|
||||
SherpaOnnxOfflineTtsConfig *c = (SherpaOnnxOfflineTtsConfig *)h;
|
||||
free((char *)c->model.vits.model);
|
||||
free((char *)c->model.vits.tokens);
|
||||
free((char *)c->model.vits.lexicon);
|
||||
free((char *)c->model.vits.data_dir);
|
||||
free((char *)c->model.provider);
|
||||
free(c);
|
||||
}
|
||||
|
||||
void sherpa_shim_tts_config_set_vits_model(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.vits.model, v);
|
||||
}
|
||||
void sherpa_shim_tts_config_set_vits_tokens(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.vits.tokens, v);
|
||||
}
|
||||
void sherpa_shim_tts_config_set_vits_lexicon(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.vits.lexicon, v);
|
||||
}
|
||||
void sherpa_shim_tts_config_set_vits_data_dir(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.vits.data_dir, v);
|
||||
}
|
||||
void sherpa_shim_tts_config_set_vits_noise_scale(void *h, float v) {
|
||||
((SherpaOnnxOfflineTtsConfig *)h)->model.vits.noise_scale = v;
|
||||
}
|
||||
void sherpa_shim_tts_config_set_vits_noise_scale_w(void *h, float v) {
|
||||
((SherpaOnnxOfflineTtsConfig *)h)->model.vits.noise_scale_w = v;
|
||||
}
|
||||
void sherpa_shim_tts_config_set_vits_length_scale(void *h, float v) {
|
||||
((SherpaOnnxOfflineTtsConfig *)h)->model.vits.length_scale = v;
|
||||
}
|
||||
void sherpa_shim_tts_config_set_num_threads(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineTtsConfig *)h)->model.num_threads = v;
|
||||
}
|
||||
void sherpa_shim_tts_config_set_debug(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineTtsConfig *)h)->model.debug = v;
|
||||
}
|
||||
void sherpa_shim_tts_config_set_provider(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.provider, v);
|
||||
}
|
||||
void sherpa_shim_tts_config_set_max_num_sentences(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineTtsConfig *)h)->max_num_sentences = v;
|
||||
}
|
||||
|
||||
void *sherpa_shim_create_offline_tts(void *h) {
|
||||
return (void *)SherpaOnnxCreateOfflineTts(
|
||||
(const SherpaOnnxOfflineTtsConfig *)h);
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// Offline recognizer config
|
||||
// ==================================================================
|
||||
|
||||
void *sherpa_shim_offline_recog_config_new(void) {
|
||||
return calloc(1, sizeof(SherpaOnnxOfflineRecognizerConfig));
|
||||
}
|
||||
|
||||
void sherpa_shim_offline_recog_config_free(void *h) {
|
||||
if (!h) return;
|
||||
SherpaOnnxOfflineRecognizerConfig *c = (SherpaOnnxOfflineRecognizerConfig *)h;
|
||||
free((char *)c->model_config.provider);
|
||||
free((char *)c->model_config.tokens);
|
||||
free((char *)c->model_config.whisper.encoder);
|
||||
free((char *)c->model_config.whisper.decoder);
|
||||
free((char *)c->model_config.whisper.language);
|
||||
free((char *)c->model_config.whisper.task);
|
||||
free((char *)c->model_config.paraformer.model);
|
||||
free((char *)c->model_config.sense_voice.model);
|
||||
free((char *)c->model_config.sense_voice.language);
|
||||
free((char *)c->model_config.omnilingual.model);
|
||||
free((char *)c->decoding_method);
|
||||
free(c);
|
||||
}
|
||||
|
||||
void sherpa_shim_offline_recog_config_set_num_threads(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.num_threads = v;
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_debug(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.debug = v;
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_provider(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.provider, v);
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_tokens(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.tokens, v);
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_feat_sample_rate(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineRecognizerConfig *)h)->feat_config.sample_rate = v;
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_feat_feature_dim(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineRecognizerConfig *)h)->feat_config.feature_dim = v;
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_decoding_method(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->decoding_method, v);
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_whisper_encoder(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.whisper.encoder, v);
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_whisper_decoder(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.whisper.decoder, v);
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_whisper_language(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.whisper.language, v);
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_whisper_task(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.whisper.task, v);
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_whisper_tail_paddings(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.whisper.tail_paddings = v;
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_paraformer_model(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.paraformer.model, v);
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_sense_voice_model(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.sense_voice.model, v);
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_sense_voice_language(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.sense_voice.language, v);
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_sense_voice_use_itn(void *h, int32_t v) {
|
||||
((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.sense_voice.use_itn = v;
|
||||
}
|
||||
void sherpa_shim_offline_recog_config_set_omnilingual_model(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.omnilingual.model, v);
|
||||
}
|
||||
|
||||
void *sherpa_shim_create_offline_recognizer(void *h) {
|
||||
return (void *)SherpaOnnxCreateOfflineRecognizer(
|
||||
(const SherpaOnnxOfflineRecognizerConfig *)h);
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// Online recognizer config
|
||||
// ==================================================================
|
||||
|
||||
void *sherpa_shim_online_recog_config_new(void) {
|
||||
return calloc(1, sizeof(SherpaOnnxOnlineRecognizerConfig));
|
||||
}
|
||||
|
||||
void sherpa_shim_online_recog_config_free(void *h) {
|
||||
if (!h) return;
|
||||
SherpaOnnxOnlineRecognizerConfig *c = (SherpaOnnxOnlineRecognizerConfig *)h;
|
||||
free((char *)c->model_config.transducer.encoder);
|
||||
free((char *)c->model_config.transducer.decoder);
|
||||
free((char *)c->model_config.transducer.joiner);
|
||||
free((char *)c->model_config.tokens);
|
||||
free((char *)c->model_config.provider);
|
||||
free((char *)c->decoding_method);
|
||||
free(c);
|
||||
}
|
||||
|
||||
void sherpa_shim_online_recog_config_set_transducer_encoder(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOnlineRecognizerConfig *)h)->model_config.transducer.encoder, v);
|
||||
}
|
||||
void sherpa_shim_online_recog_config_set_transducer_decoder(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOnlineRecognizerConfig *)h)->model_config.transducer.decoder, v);
|
||||
}
|
||||
void sherpa_shim_online_recog_config_set_transducer_joiner(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOnlineRecognizerConfig *)h)->model_config.transducer.joiner, v);
|
||||
}
|
||||
void sherpa_shim_online_recog_config_set_tokens(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOnlineRecognizerConfig *)h)->model_config.tokens, v);
|
||||
}
|
||||
void sherpa_shim_online_recog_config_set_num_threads(void *h, int32_t v) {
|
||||
((SherpaOnnxOnlineRecognizerConfig *)h)->model_config.num_threads = v;
|
||||
}
|
||||
void sherpa_shim_online_recog_config_set_debug(void *h, int32_t v) {
|
||||
((SherpaOnnxOnlineRecognizerConfig *)h)->model_config.debug = v;
|
||||
}
|
||||
void sherpa_shim_online_recog_config_set_provider(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOnlineRecognizerConfig *)h)->model_config.provider, v);
|
||||
}
|
||||
void sherpa_shim_online_recog_config_set_feat_sample_rate(void *h, int32_t v) {
|
||||
((SherpaOnnxOnlineRecognizerConfig *)h)->feat_config.sample_rate = v;
|
||||
}
|
||||
void sherpa_shim_online_recog_config_set_feat_feature_dim(void *h, int32_t v) {
|
||||
((SherpaOnnxOnlineRecognizerConfig *)h)->feat_config.feature_dim = v;
|
||||
}
|
||||
void sherpa_shim_online_recog_config_set_decoding_method(void *h, const char *v) {
|
||||
shim_set_str(&((SherpaOnnxOnlineRecognizerConfig *)h)->decoding_method, v);
|
||||
}
|
||||
void sherpa_shim_online_recog_config_set_enable_endpoint(void *h, int32_t v) {
|
||||
((SherpaOnnxOnlineRecognizerConfig *)h)->enable_endpoint = v;
|
||||
}
|
||||
void sherpa_shim_online_recog_config_set_rule1_min_trailing_silence(void *h, float v) {
|
||||
((SherpaOnnxOnlineRecognizerConfig *)h)->rule1_min_trailing_silence = v;
|
||||
}
|
||||
void sherpa_shim_online_recog_config_set_rule2_min_trailing_silence(void *h, float v) {
|
||||
((SherpaOnnxOnlineRecognizerConfig *)h)->rule2_min_trailing_silence = v;
|
||||
}
|
||||
void sherpa_shim_online_recog_config_set_rule3_min_utterance_length(void *h, float v) {
|
||||
((SherpaOnnxOnlineRecognizerConfig *)h)->rule3_min_utterance_length = v;
|
||||
}
|
||||
|
||||
void *sherpa_shim_create_online_recognizer(void *h) {
|
||||
return (void *)SherpaOnnxCreateOnlineRecognizer(
|
||||
(const SherpaOnnxOnlineRecognizerConfig *)h);
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// Result-struct accessors
|
||||
// ==================================================================
|
||||
|
||||
int32_t sherpa_shim_wave_sample_rate(const void *h) {
|
||||
return ((const SherpaOnnxWave *)h)->sample_rate;
|
||||
}
|
||||
int32_t sherpa_shim_wave_num_samples(const void *h) {
|
||||
return ((const SherpaOnnxWave *)h)->num_samples;
|
||||
}
|
||||
const float *sherpa_shim_wave_samples(const void *h) {
|
||||
return ((const SherpaOnnxWave *)h)->samples;
|
||||
}
|
||||
|
||||
const char *sherpa_shim_offline_result_text(const void *h) {
|
||||
return ((const SherpaOnnxOfflineRecognizerResult *)h)->text;
|
||||
}
|
||||
const char *sherpa_shim_online_result_text(const void *h) {
|
||||
return ((const SherpaOnnxOnlineRecognizerResult *)h)->text;
|
||||
}
|
||||
|
||||
int32_t sherpa_shim_generated_audio_sample_rate(const void *h) {
|
||||
return ((const SherpaOnnxGeneratedAudio *)h)->sample_rate;
|
||||
}
|
||||
int32_t sherpa_shim_generated_audio_n(const void *h) {
|
||||
return ((const SherpaOnnxGeneratedAudio *)h)->n;
|
||||
}
|
||||
const float *sherpa_shim_generated_audio_samples(const void *h) {
|
||||
return ((const SherpaOnnxGeneratedAudio *)h)->samples;
|
||||
}
|
||||
|
||||
int32_t sherpa_shim_speech_segment_start(const void *h) {
|
||||
return ((const SherpaOnnxSpeechSegment *)h)->start;
|
||||
}
|
||||
int32_t sherpa_shim_speech_segment_n(const void *h) {
|
||||
return ((const SherpaOnnxSpeechSegment *)h)->n;
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// TTS streaming callback trampoline
|
||||
// ==================================================================
|
||||
|
||||
void *sherpa_shim_tts_generate_with_callback(
|
||||
void *tts, const char *text, int32_t sid, float speed,
|
||||
uintptr_t callback_ptr, uintptr_t user_data) {
|
||||
SherpaOnnxGeneratedAudioCallbackWithArg cb =
|
||||
(SherpaOnnxGeneratedAudioCallbackWithArg)callback_ptr;
|
||||
return (void *)SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(
|
||||
(const SherpaOnnxOfflineTts *)tts, text, sid, speed, cb,
|
||||
(void *)user_data);
|
||||
}
|
||||
129
backend/go/sherpa-onnx/csrc/shim.h
Normal file
129
backend/go/sherpa-onnx/csrc/shim.h
Normal file
@@ -0,0 +1,129 @@
|
||||
#ifndef LOCALAI_SHERPA_ONNX_SHIM_H
|
||||
#define LOCALAI_SHERPA_ONNX_SHIM_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
// libsherpa-shim: purego-friendly wrapper around sherpa-onnx's C API.
|
||||
// Purego can't access C struct fields and can't route C callbacks to Go
|
||||
// funcs directly. Every function here is a fixed-signature trampoline
|
||||
// that replaces one field read/write or callback handoff that the Go
|
||||
// backend would otherwise have to do through cgo.
|
||||
//
|
||||
// String lifetime: setters strdup; _free walks every owned string and
|
||||
// frees it. Callers may discard their input buffers the moment a setter
|
||||
// returns.
|
||||
//
|
||||
// Opaque handles are `void *` in both directions. Nothing here holds a
|
||||
// reference across calls except config handles (freed via _free) and
|
||||
// sherpa-allocated results (freed via sherpa's own Destroy* entry
|
||||
// points, which Go calls through purego pass-through).
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// --- VAD config -----------------------------------------------------
|
||||
void *sherpa_shim_vad_config_new(void);
|
||||
void sherpa_shim_vad_config_free(void *cfg);
|
||||
void sherpa_shim_vad_config_set_silero_model(void *cfg, const char *path);
|
||||
void sherpa_shim_vad_config_set_silero_threshold(void *cfg, float v);
|
||||
void sherpa_shim_vad_config_set_silero_min_silence_duration(void *cfg, float v);
|
||||
void sherpa_shim_vad_config_set_silero_min_speech_duration(void *cfg, float v);
|
||||
void sherpa_shim_vad_config_set_silero_window_size(void *cfg, int32_t v);
|
||||
void sherpa_shim_vad_config_set_silero_max_speech_duration(void *cfg, float v);
|
||||
void sherpa_shim_vad_config_set_sample_rate(void *cfg, int32_t v);
|
||||
void sherpa_shim_vad_config_set_num_threads(void *cfg, int32_t v);
|
||||
void sherpa_shim_vad_config_set_provider(void *cfg, const char *v);
|
||||
void sherpa_shim_vad_config_set_debug(void *cfg, int32_t v);
|
||||
void *sherpa_shim_create_vad(void *cfg, float buffer_size_seconds);
|
||||
|
||||
// --- Offline TTS config (VITS path — the only TTS family the backend uses) ---
|
||||
void *sherpa_shim_tts_config_new(void);
|
||||
void sherpa_shim_tts_config_free(void *cfg);
|
||||
void sherpa_shim_tts_config_set_vits_model(void *cfg, const char *v);
|
||||
void sherpa_shim_tts_config_set_vits_tokens(void *cfg, const char *v);
|
||||
void sherpa_shim_tts_config_set_vits_lexicon(void *cfg, const char *v);
|
||||
void sherpa_shim_tts_config_set_vits_data_dir(void *cfg, const char *v);
|
||||
void sherpa_shim_tts_config_set_vits_noise_scale(void *cfg, float v);
|
||||
void sherpa_shim_tts_config_set_vits_noise_scale_w(void *cfg, float v);
|
||||
void sherpa_shim_tts_config_set_vits_length_scale(void *cfg, float v);
|
||||
void sherpa_shim_tts_config_set_num_threads(void *cfg, int32_t v);
|
||||
void sherpa_shim_tts_config_set_debug(void *cfg, int32_t v);
|
||||
void sherpa_shim_tts_config_set_provider(void *cfg, const char *v);
|
||||
void sherpa_shim_tts_config_set_max_num_sentences(void *cfg, int32_t v);
|
||||
void *sherpa_shim_create_offline_tts(void *cfg);
|
||||
|
||||
// --- Offline recognizer config (Whisper / Paraformer / SenseVoice / Omnilingual) ---
|
||||
void *sherpa_shim_offline_recog_config_new(void);
|
||||
void sherpa_shim_offline_recog_config_free(void *cfg);
|
||||
void sherpa_shim_offline_recog_config_set_num_threads(void *cfg, int32_t v);
|
||||
void sherpa_shim_offline_recog_config_set_debug(void *cfg, int32_t v);
|
||||
void sherpa_shim_offline_recog_config_set_provider(void *cfg, const char *v);
|
||||
void sherpa_shim_offline_recog_config_set_tokens(void *cfg, const char *v);
|
||||
void sherpa_shim_offline_recog_config_set_feat_sample_rate(void *cfg, int32_t v);
|
||||
void sherpa_shim_offline_recog_config_set_feat_feature_dim(void *cfg, int32_t v);
|
||||
void sherpa_shim_offline_recog_config_set_decoding_method(void *cfg, const char *v);
|
||||
void sherpa_shim_offline_recog_config_set_whisper_encoder(void *cfg, const char *v);
|
||||
void sherpa_shim_offline_recog_config_set_whisper_decoder(void *cfg, const char *v);
|
||||
void sherpa_shim_offline_recog_config_set_whisper_language(void *cfg, const char *v);
|
||||
void sherpa_shim_offline_recog_config_set_whisper_task(void *cfg, const char *v);
|
||||
void sherpa_shim_offline_recog_config_set_whisper_tail_paddings(void *cfg, int32_t v);
|
||||
void sherpa_shim_offline_recog_config_set_paraformer_model(void *cfg, const char *v);
|
||||
void sherpa_shim_offline_recog_config_set_sense_voice_model(void *cfg, const char *v);
|
||||
void sherpa_shim_offline_recog_config_set_sense_voice_language(void *cfg, const char *v);
|
||||
void sherpa_shim_offline_recog_config_set_sense_voice_use_itn(void *cfg, int32_t v);
|
||||
void sherpa_shim_offline_recog_config_set_omnilingual_model(void *cfg, const char *v);
|
||||
void *sherpa_shim_create_offline_recognizer(void *cfg);
|
||||
|
||||
// --- Online recognizer config (streaming zipformer transducer) ---
|
||||
void *sherpa_shim_online_recog_config_new(void);
|
||||
void sherpa_shim_online_recog_config_free(void *cfg);
|
||||
void sherpa_shim_online_recog_config_set_transducer_encoder(void *cfg, const char *v);
|
||||
void sherpa_shim_online_recog_config_set_transducer_decoder(void *cfg, const char *v);
|
||||
void sherpa_shim_online_recog_config_set_transducer_joiner(void *cfg, const char *v);
|
||||
void sherpa_shim_online_recog_config_set_tokens(void *cfg, const char *v);
|
||||
void sherpa_shim_online_recog_config_set_num_threads(void *cfg, int32_t v);
|
||||
void sherpa_shim_online_recog_config_set_debug(void *cfg, int32_t v);
|
||||
void sherpa_shim_online_recog_config_set_provider(void *cfg, const char *v);
|
||||
void sherpa_shim_online_recog_config_set_feat_sample_rate(void *cfg, int32_t v);
|
||||
void sherpa_shim_online_recog_config_set_feat_feature_dim(void *cfg, int32_t v);
|
||||
void sherpa_shim_online_recog_config_set_decoding_method(void *cfg, const char *v);
|
||||
void sherpa_shim_online_recog_config_set_enable_endpoint(void *cfg, int32_t v);
|
||||
void sherpa_shim_online_recog_config_set_rule1_min_trailing_silence(void *cfg, float v);
|
||||
void sherpa_shim_online_recog_config_set_rule2_min_trailing_silence(void *cfg, float v);
|
||||
void sherpa_shim_online_recog_config_set_rule3_min_utterance_length(void *cfg, float v);
|
||||
void *sherpa_shim_create_online_recognizer(void *cfg);
|
||||
|
||||
// --- Result accessors (sherpa-allocated; caller destroys via sherpa's own Destroy*) ---
|
||||
int32_t sherpa_shim_wave_sample_rate(const void *wave);
|
||||
int32_t sherpa_shim_wave_num_samples(const void *wave);
|
||||
const float *sherpa_shim_wave_samples(const void *wave);
|
||||
|
||||
const char *sherpa_shim_offline_result_text(const void *result);
|
||||
const char *sherpa_shim_online_result_text(const void *result);
|
||||
|
||||
int32_t sherpa_shim_generated_audio_sample_rate(const void *audio);
|
||||
int32_t sherpa_shim_generated_audio_n(const void *audio);
|
||||
const float *sherpa_shim_generated_audio_samples(const void *audio);
|
||||
|
||||
int32_t sherpa_shim_speech_segment_start(const void *seg);
|
||||
int32_t sherpa_shim_speech_segment_n(const void *seg);
|
||||
|
||||
// --- TTS streaming callback trampoline -----------------------------
|
||||
// Replaces the //export sherpaTtsGoCallback + callbacks.c bridge pattern.
|
||||
// `callback_ptr` is the C-callable function pointer returned by
|
||||
// purego.NewCallback. `user_data` is an integer the Go side uses to
|
||||
// look up its state (sync.Map keyed by uint64).
|
||||
//
|
||||
// Returns the sherpa-allocated SherpaOnnxGeneratedAudio. Destroy with
|
||||
// SherpaOnnxDestroyOfflineTtsGeneratedAudio (callable directly from
|
||||
// Go via purego).
|
||||
void *sherpa_shim_tts_generate_with_callback(
|
||||
void *tts, const char *text, int32_t sid, float speed,
|
||||
uintptr_t callback_ptr, uintptr_t user_data);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
23
backend/go/sherpa-onnx/main.go
Normal file
23
backend/go/sherpa-onnx/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := loadSherpaLibs(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := grpc.StartServer(*addr, &SherpaBackend{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
51
backend/go/sherpa-onnx/package.sh
Executable file
51
backend/go/sherpa-onnx/package.sh
Executable file
@@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/sherpa-onnx $CURDIR/package/
|
||||
cp -avf $CURDIR/run.sh $CURDIR/package/
|
||||
cp -rfLv $CURDIR/backend-assets/lib/* $CURDIR/package/lib/
|
||||
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ $(uname -s) = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
13
backend/go/sherpa-onnx/run.sh
Executable file
13
backend/go/sherpa-onnx/run.sh
Executable file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/sherpa-onnx "$@"
|
||||
fi
|
||||
|
||||
exec $CURDIR/sherpa-onnx "$@"
|
||||
12
backend/go/sherpa-onnx/test.sh
Executable file
12
backend/go/sherpa-onnx/test.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
#!/bin/bash
|
||||
# Unit tests for the sherpa-onnx backend. Exercises error-path and
|
||||
# dispatch logic via SherpaBackend directly (no gRPC). Integration
|
||||
# coverage (gRPC TTS / streaming ASR / realtime pipeline) lives in
|
||||
# tests/e2e-backends and tests/e2e and runs against the Docker image.
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
cd "$CURDIR"
|
||||
|
||||
PACKAGES=$(go list ./... | grep -v /sources/)
|
||||
go test -v -timeout 60s $PACKAGES
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=44cca3d626d301e2215d5e243277e8f0e65bfa78
|
||||
STABLEDIFFUSION_GGML_VERSION?=c97702e1057c2fe13a7074cd9069cb9dd6edc1bf
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -1006,6 +1006,23 @@
|
||||
nvidia: "cuda12-neutts"
|
||||
amd: "rocm-neutts"
|
||||
nvidia-cuda-12: "cuda12-neutts"
|
||||
- &sherpa-onnx
|
||||
name: "sherpa-onnx"
|
||||
alias: "sherpa-onnx"
|
||||
urls:
|
||||
- https://k2-fsa.github.io/sherpa/onnx/
|
||||
description: |
|
||||
Sherpa-ONNX backend for text-to-speech (VITS, Matcha, Kokoro), speech-to-text (Whisper, Paraformer, SenseVoice, Omnilingual ASR CTC), and voice activity detection via ONNX Runtime.
|
||||
Supports multi-speaker voices, 1600+ language ASR, and GPU acceleration.
|
||||
tags:
|
||||
- text-to-speech
|
||||
- TTS
|
||||
- speech-to-text
|
||||
- ASR
|
||||
capabilities:
|
||||
default: "cpu-sherpa-onnx"
|
||||
nvidia: "cuda12-sherpa-onnx"
|
||||
nvidia-cuda-12: "cuda12-sherpa-onnx"
|
||||
- !!merge <<: *neutts
|
||||
name: "neutts-development"
|
||||
capabilities:
|
||||
@@ -3773,3 +3790,91 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-insightface"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-insightface
|
||||
|
||||
# speaker-recognition (voice/speaker biometrics) — Apache-2.0 stack
|
||||
- &speakerrecognition
|
||||
name: "speaker-recognition"
|
||||
alias: "speaker-recognition"
|
||||
# SpeechBrain is Apache-2.0. WeSpeaker / 3D-Speaker ONNX exports are
|
||||
# Apache-2.0. The backend itself ships only Python deps — all model
|
||||
# weights flow through LocalAI's gallery download mechanism (or
|
||||
# SpeechBrain's built-in HF auto-download at first LoadModel).
|
||||
license: apache-2.0
|
||||
description: |
|
||||
Speaker (voice) recognition backend — the audio analog to
|
||||
insightface. Wraps SpeechBrain ECAPA-TDNN (default engine, 192-d
|
||||
embeddings, ~1.9% EER on VoxCeleb) plus an OnnxDirectEngine for
|
||||
pre-exported WeSpeaker / 3D-Speaker ONNX models.
|
||||
|
||||
Exposes speaker verification (/v1/voice/verify), speaker embedding
|
||||
(/v1/voice/embed), speaker analysis (/v1/voice/analyze), and 1:N
|
||||
speaker identification (/v1/voice/{register,identify,forget}).
|
||||
Registrations use LocalAI's built-in vector store — same in-memory
|
||||
backing the face-recognition registry uses, separate instance.
|
||||
urls:
|
||||
- https://speechbrain.github.io/
|
||||
- https://github.com/wenet-e2e/wespeaker
|
||||
- https://github.com/modelscope/3D-Speaker
|
||||
tags:
|
||||
- voice-recognition
|
||||
- speaker-verification
|
||||
- speaker-embedding
|
||||
- gpu
|
||||
- cpu
|
||||
capabilities:
|
||||
default: "cpu-speaker-recognition"
|
||||
nvidia: "cuda12-speaker-recognition"
|
||||
nvidia-cuda-12: "cuda12-speaker-recognition"
|
||||
- !!merge <<: *speakerrecognition
|
||||
name: "speaker-recognition-development"
|
||||
capabilities:
|
||||
default: "cpu-speaker-recognition-development"
|
||||
nvidia: "cuda12-speaker-recognition-development"
|
||||
nvidia-cuda-12: "cuda12-speaker-recognition-development"
|
||||
- !!merge <<: *speakerrecognition
|
||||
name: "cpu-speaker-recognition"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-speaker-recognition"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-speaker-recognition
|
||||
- !!merge <<: *speakerrecognition
|
||||
name: "cuda12-speaker-recognition"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-speaker-recognition"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-speaker-recognition
|
||||
- !!merge <<: *speakerrecognition
|
||||
name: "cpu-speaker-recognition-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-speaker-recognition"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-speaker-recognition
|
||||
- !!merge <<: *speakerrecognition
|
||||
name: "cuda12-speaker-recognition-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-speaker-recognition"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-speaker-recognition
|
||||
## sherpa-onnx
|
||||
- !!merge <<: *sherpa-onnx
|
||||
name: "sherpa-onnx-development"
|
||||
capabilities:
|
||||
default: "cpu-sherpa-onnx-development"
|
||||
nvidia: "cuda12-sherpa-onnx-development"
|
||||
nvidia-cuda-12: "cuda12-sherpa-onnx-development"
|
||||
- !!merge <<: *sherpa-onnx
|
||||
name: "cpu-sherpa-onnx"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-sherpa-onnx"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-sherpa-onnx
|
||||
- !!merge <<: *sherpa-onnx
|
||||
name: "cpu-sherpa-onnx-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-sherpa-onnx"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-sherpa-onnx
|
||||
- !!merge <<: *sherpa-onnx
|
||||
name: "cuda12-sherpa-onnx"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-sherpa-onnx"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-sherpa-onnx
|
||||
- !!merge <<: *sherpa-onnx
|
||||
name: "cuda12-sherpa-onnx-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-sherpa-onnx"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-sherpa-onnx
|
||||
|
||||
@@ -11,3 +11,6 @@ protogen-clean:
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
|
||||
test: install
|
||||
bash test.sh
|
||||
|
||||
@@ -180,23 +180,57 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
verified = distance < threshold
|
||||
confidence = max(0.0, min(100.0, (1.0 - distance / threshold) * 100.0)) if threshold > 0 else 0.0
|
||||
|
||||
def _region(img) -> backend_pb2.FacialArea:
|
||||
# Detect once per image — region is needed for the response and
|
||||
# potentially for the antispoof crop. Returns the highest-score face.
|
||||
def _best_detection(img):
|
||||
dets = self.engine.detect(img)
|
||||
if not dets:
|
||||
return None
|
||||
return max(dets, key=lambda d: d.score)
|
||||
|
||||
def _region(det) -> backend_pb2.FacialArea:
|
||||
if det is None:
|
||||
return backend_pb2.FacialArea()
|
||||
best = max(dets, key=lambda d: d.score)
|
||||
x1, y1, x2, y2 = best.bbox
|
||||
x1, y1, x2, y2 = det.bbox
|
||||
return backend_pb2.FacialArea(x=x1, y=y1, w=x2 - x1, h=y2 - y1)
|
||||
|
||||
det1 = _best_detection(img1)
|
||||
det2 = _best_detection(img2)
|
||||
|
||||
img1_is_real = False
|
||||
img1_score = 0.0
|
||||
img2_is_real = False
|
||||
img2_score = 0.0
|
||||
if request.anti_spoofing:
|
||||
spoof1 = self.engine.antispoof(img1, det1.bbox) if det1 is not None else None
|
||||
spoof2 = self.engine.antispoof(img2, det2.bbox) if det2 is not None else None
|
||||
if spoof1 is None or spoof2 is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details(
|
||||
"anti_spoofing requested but no antispoof model is loaded — "
|
||||
"install `silent-face-antispoofing` or pick a gallery entry "
|
||||
"that bundles MiniFASNet weights"
|
||||
)
|
||||
return backend_pb2.FaceVerifyResponse()
|
||||
img1_is_real, img1_score = spoof1.is_real, spoof1.score
|
||||
img2_is_real, img2_score = spoof2.is_real, spoof2.score
|
||||
# Failed liveness vetoes verification regardless of similarity.
|
||||
if not (img1_is_real and img2_is_real):
|
||||
verified = False
|
||||
|
||||
return backend_pb2.FaceVerifyResponse(
|
||||
verified=verified,
|
||||
distance=float(distance),
|
||||
threshold=float(threshold),
|
||||
confidence=float(confidence),
|
||||
model=self.model_name or self.engine_name,
|
||||
img1_area=_region(img1),
|
||||
img2_area=_region(img2),
|
||||
img1_area=_region(det1),
|
||||
img2_area=_region(det2),
|
||||
processing_time_ms=float((time.time() - start) * 1000.0),
|
||||
img1_is_real=img1_is_real,
|
||||
img1_antispoof_score=float(img1_score),
|
||||
img2_is_real=img2_is_real,
|
||||
img2_antispoof_score=float(img2_score),
|
||||
)
|
||||
|
||||
def FaceAnalyze(self, request, context):
|
||||
@@ -223,6 +257,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
fa.dominant_gender = attrs.dominant_gender
|
||||
for k, v in attrs.gender.items():
|
||||
fa.gender[k] = float(v)
|
||||
if request.anti_spoofing:
|
||||
bbox = (float(x), float(y), float(x + w), float(y + h))
|
||||
spoof = self.engine.antispoof(img, bbox)
|
||||
if spoof is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details(
|
||||
"anti_spoofing requested but no antispoof model is loaded — "
|
||||
"install `silent-face-antispoofing` or pick a gallery entry "
|
||||
"that bundles MiniFASNet weights"
|
||||
)
|
||||
return backend_pb2.FaceAnalyzeResponse()
|
||||
fa.is_real = spoof.is_real
|
||||
fa.antispoof_score = float(spoof.score)
|
||||
faces.append(fa)
|
||||
return backend_pb2.FaceAnalyzeResponse(faces=faces)
|
||||
|
||||
|
||||
@@ -41,6 +41,12 @@ class FaceAttributes:
|
||||
gender: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpoofResult:
|
||||
is_real: bool
|
||||
score: float # averaged probability of the "real" class, 0.0-1.0
|
||||
|
||||
|
||||
class FaceEngine(Protocol):
|
||||
"""Minimal interface every engine must implement."""
|
||||
|
||||
@@ -48,10 +54,149 @@ class FaceEngine(Protocol):
|
||||
def detect(self, img: np.ndarray) -> list[FaceDetection]: ...
|
||||
def embed(self, img: np.ndarray) -> np.ndarray | None: ...
|
||||
def analyze(self, img: np.ndarray) -> list[FaceAttributes]: ...
|
||||
# Optional: returns None when no antispoof model is loaded.
|
||||
def antispoof(self, img: np.ndarray, bbox: tuple[float, float, float, float]) -> SpoofResult | None: ...
|
||||
|
||||
|
||||
# ─── Antispoofer (Silent-Face MiniFASNet) ──────────────────────────────
|
||||
|
||||
class Antispoofer:
|
||||
"""Liveness detector using the Silent-Face MiniFASNet ensemble.
|
||||
|
||||
Loads up to two ONNX exports (MiniFASNetV2 at scale 2.7 and
|
||||
MiniFASNetV1SE at scale 4.0). Both are 80x80 BGR-float32-input
|
||||
classifiers with 3 output logits where index 1 = "real". When both
|
||||
are loaded, softmax outputs are averaged before argmax — the same
|
||||
ensembling the upstream `test.py` does.
|
||||
|
||||
Preprocessing matches yakhyo/face-anti-spoofing's reference impl:
|
||||
each model gets its own scale-expanded crop centered on the face
|
||||
bbox, resized to 80x80, fed straight as float32 BGR (no /255, no
|
||||
mean/std). See `_crop_face` for the bbox math.
|
||||
|
||||
A single model also works (the missing one is simply skipped).
|
||||
"""
|
||||
|
||||
INPUT_SIZE = (80, 80) # h, w
|
||||
REAL_CLASS_IDX = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sessions: list[tuple[Any, float, str, str]] = [] # (session, scale, input_name, output_name)
|
||||
self.threshold: float = 0.5
|
||||
|
||||
def load(self, model_paths: list[tuple[str, float]], threshold: float = 0.5) -> None:
|
||||
"""Load one or more (path, scale) pairs."""
|
||||
import onnxruntime as ort
|
||||
|
||||
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
for path, scale in model_paths:
|
||||
session = ort.InferenceSession(path, providers=providers)
|
||||
input_name = session.get_inputs()[0].name
|
||||
output_name = session.get_outputs()[0].name
|
||||
self._sessions.append((session, float(scale), input_name, output_name))
|
||||
self.threshold = float(threshold)
|
||||
|
||||
@property
|
||||
def loaded(self) -> bool:
|
||||
return bool(self._sessions)
|
||||
|
||||
def _crop_face(self, img: np.ndarray, bbox: tuple[float, float, float, float], scale: float) -> np.ndarray:
|
||||
# bbox is (x1, y1, x2, y2) in source-image coordinates.
|
||||
src_h, src_w = img.shape[:2]
|
||||
x1, y1, x2, y2 = bbox
|
||||
box_w = max(1.0, x2 - x1)
|
||||
box_h = max(1.0, y2 - y1)
|
||||
|
||||
# Clamp scale so the expanded crop fits inside the source image.
|
||||
scale = min((src_h - 1) / box_h, (src_w - 1) / box_w, scale)
|
||||
new_w = box_w * scale
|
||||
new_h = box_h * scale
|
||||
|
||||
cx = x1 + box_w / 2.0
|
||||
cy = y1 + box_h / 2.0
|
||||
|
||||
cx1 = max(0, int(cx - new_w / 2.0))
|
||||
cy1 = max(0, int(cy - new_h / 2.0))
|
||||
cx2 = min(src_w - 1, int(cx + new_w / 2.0))
|
||||
cy2 = min(src_h - 1, int(cy + new_h / 2.0))
|
||||
|
||||
cropped = img[cy1 : cy2 + 1, cx1 : cx2 + 1]
|
||||
if cropped.size == 0:
|
||||
cropped = img
|
||||
out_h, out_w = self.INPUT_SIZE
|
||||
return cv2.resize(cropped, (out_w, out_h))
|
||||
|
||||
@staticmethod
|
||||
def _softmax(x: np.ndarray) -> np.ndarray:
|
||||
e = np.exp(x - np.max(x, axis=1, keepdims=True))
|
||||
return e / e.sum(axis=1, keepdims=True)
|
||||
|
||||
def predict(self, img: np.ndarray, bbox: tuple[float, float, float, float]) -> SpoofResult:
|
||||
if not self._sessions:
|
||||
raise RuntimeError("Antispoofer.predict called with no models loaded")
|
||||
accum = np.zeros((1, 3), dtype=np.float32)
|
||||
for session, scale, input_name, output_name in self._sessions:
|
||||
face = self._crop_face(img, bbox, scale).astype(np.float32)
|
||||
tensor = np.transpose(face, (2, 0, 1))[np.newaxis, ...]
|
||||
logits = session.run([output_name], {input_name: tensor})[0]
|
||||
accum += self._softmax(logits)
|
||||
accum /= float(len(self._sessions))
|
||||
real_prob = float(accum[0, self.REAL_CLASS_IDX])
|
||||
is_real = int(np.argmax(accum)) == self.REAL_CLASS_IDX and real_prob >= self.threshold
|
||||
return SpoofResult(is_real=is_real, score=real_prob)
|
||||
|
||||
|
||||
def _build_antispoofer(options: dict[str, str], model_dir: str | None) -> Antispoofer | None:
|
||||
"""Instantiate an Antispoofer from option keys, or return None.
|
||||
|
||||
Recognised options:
|
||||
antispoof_v2_onnx — path/filename of MiniFASNetV2 (scale 2.7)
|
||||
antispoof_v1se_onnx — path/filename of MiniFASNetV1SE (scale 4.0)
|
||||
antispoof_threshold — real-class probability threshold, default 0.5
|
||||
|
||||
Either or both can be provided. Returns None when neither is set.
|
||||
"""
|
||||
pairs: list[tuple[str, float]] = []
|
||||
v2 = options.get("antispoof_v2_onnx", "")
|
||||
if v2:
|
||||
pairs.append((_resolve_model_path(v2, model_dir=model_dir), 2.7))
|
||||
v1se = options.get("antispoof_v1se_onnx", "")
|
||||
if v1se:
|
||||
pairs.append((_resolve_model_path(v1se, model_dir=model_dir), 4.0))
|
||||
if not pairs:
|
||||
return None
|
||||
threshold = float(options.get("antispoof_threshold", "0.5"))
|
||||
spoofer = Antispoofer()
|
||||
spoofer.load(pairs, threshold=threshold)
|
||||
return spoofer
|
||||
|
||||
|
||||
# ─── InsightFaceEngine ────────────────────────────────────────────────
|
||||
|
||||
# Canonical ONNX manifest for each upstream insightface pack (v0.7 release
|
||||
# at github.com/deepinsight/insightface/releases). LocalAI's gallery extracts
|
||||
# these zips flat into the models directory, so when multiple packs or other
|
||||
# backends drop their own ONNX files alongside, the glob-the-directory
|
||||
# approach picks up foreign files and insightface's model_zoo.get_model()
|
||||
# raises IndexError trying to index `input_shape[2]` on a tensor that isn't
|
||||
# shaped like a face model. The manifest lets us pre-filter to only the
|
||||
# files that actually belong to the requested pack — deterministic, correct
|
||||
# pack choice, no crashes on neighbour ONNX files.
|
||||
_KNOWN_PACK_MANIFESTS: dict[str, frozenset[str]] = {
|
||||
"buffalo_l": frozenset({
|
||||
"det_10g.onnx",
|
||||
"w600k_r50.onnx",
|
||||
"genderage.onnx",
|
||||
"2d106det.onnx",
|
||||
"1k3d68.onnx",
|
||||
}),
|
||||
"buffalo_sc": frozenset({
|
||||
"det_500m.onnx",
|
||||
"w600k_mbf.onnx",
|
||||
}),
|
||||
}
|
||||
|
||||
|
||||
class InsightFaceEngine:
|
||||
"""Drives insightface's model_zoo directly — no FaceAnalysis wrapper.
|
||||
|
||||
@@ -80,6 +225,7 @@ class InsightFaceEngine:
|
||||
self.det_size: tuple[int, int] = (640, 640)
|
||||
self.det_thresh: float = 0.5
|
||||
self._providers: list[str] = ["CPUExecutionProvider"]
|
||||
self._antispoofer: Antispoofer | None = None
|
||||
|
||||
def prepare(self, options: dict[str, str]) -> None:
|
||||
import glob
|
||||
@@ -90,6 +236,7 @@ class InsightFaceEngine:
|
||||
self.model_pack = options.get("model_pack", "buffalo_l")
|
||||
self.det_size = _parse_det_size(options.get("det_size", "640x640"))
|
||||
self.det_thresh = float(options.get("det_thresh", "0.5"))
|
||||
self._antispoofer = _build_antispoofer(options, options.get("_model_dir"))
|
||||
|
||||
pack_dir = _locate_insightface_pack(options, self.model_pack)
|
||||
if pack_dir is None:
|
||||
@@ -99,6 +246,21 @@ class InsightFaceEngine:
|
||||
)
|
||||
|
||||
onnx_files = sorted(glob.glob(os.path.join(pack_dir, "*.onnx")))
|
||||
# When the pack extracts flat into a shared models directory it
|
||||
# mixes with ONNX files from other backends (opencv face engine,
|
||||
# MiniFASNet antispoof, WeSpeaker voice embedding, other buffalo
|
||||
# packs installed earlier). Feeding those into model_zoo.get_model()
|
||||
# blows up inside insightface's router — it assumes a 4-D NCHW
|
||||
# input and indexes `input_shape[2]` on tensors that aren't shaped
|
||||
# like a face model, raising IndexError. For the upstream packs we
|
||||
# know the exact ONNX manifest; scoping to it makes the load
|
||||
# deterministic (without it, det_10g.onnx from buffalo_l sorts
|
||||
# before det_500m.onnx from buffalo_sc and silently wins).
|
||||
manifest = _KNOWN_PACK_MANIFESTS.get(self.model_pack)
|
||||
if manifest is not None:
|
||||
scoped = [f for f in onnx_files if os.path.basename(f) in manifest]
|
||||
if scoped:
|
||||
onnx_files = scoped
|
||||
if not onnx_files:
|
||||
raise ValueError(f"no ONNX files in pack directory: {pack_dir}")
|
||||
|
||||
@@ -108,14 +270,31 @@ class InsightFaceEngine:
|
||||
self._providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
|
||||
self.models = {}
|
||||
skipped: list[tuple[str, str]] = []
|
||||
for onnx_file in onnx_files:
|
||||
m = model_zoo.get_model(onnx_file, providers=self._providers)
|
||||
try:
|
||||
m = model_zoo.get_model(onnx_file, providers=self._providers)
|
||||
except Exception as err:
|
||||
# Foreign ONNX (wrong rank/shape, non-insightface model) —
|
||||
# older insightface versions raise IndexError / ValueError
|
||||
# instead of returning None. Keep loading the rest.
|
||||
skipped.append((os.path.basename(onnx_file), str(err)))
|
||||
continue
|
||||
if m is None:
|
||||
skipped.append((os.path.basename(onnx_file), "unknown taskname"))
|
||||
continue
|
||||
# First occurrence of each taskname wins (matches FaceAnalysis).
|
||||
if m.taskname not in self.models:
|
||||
self.models[m.taskname] = m
|
||||
|
||||
if skipped:
|
||||
import sys
|
||||
print(
|
||||
f"[insightface] skipped {len(skipped)} non-pack ONNX file(s) in {pack_dir}: "
|
||||
+ ", ".join(f"{n} ({why})" for n, why in skipped),
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
if "detection" not in self.models:
|
||||
raise ValueError(f"no detector (taskname='detection') found in {pack_dir}")
|
||||
self.det_model = self.models["detection"]
|
||||
@@ -187,6 +366,11 @@ class InsightFaceEngine:
|
||||
out.append(attrs)
|
||||
return out
|
||||
|
||||
def antispoof(self, img: np.ndarray, bbox: tuple[float, float, float, float]) -> SpoofResult | None:
|
||||
if self._antispoofer is None or not self._antispoofer.loaded:
|
||||
return None
|
||||
return self._antispoofer.predict(img, bbox)
|
||||
|
||||
|
||||
# ─── OnnxDirectEngine ─────────────────────────────────────────────────
|
||||
|
||||
@@ -206,6 +390,7 @@ class OnnxDirectEngine:
|
||||
self.det_thresh: float = 0.5
|
||||
self._detector: Any = None
|
||||
self._recognizer: Any = None
|
||||
self._antispoofer: Antispoofer | None = None
|
||||
|
||||
def prepare(self, options: dict[str, str]) -> None:
|
||||
raw_det = options.get("detector_onnx", "")
|
||||
@@ -219,6 +404,7 @@ class OnnxDirectEngine:
|
||||
self.recognizer_path = _resolve_model_path(raw_rec, model_dir=model_dir)
|
||||
self.input_size = _parse_det_size(options.get("det_size", "320x320"))
|
||||
self.det_thresh = float(options.get("det_thresh", "0.5"))
|
||||
self._antispoofer = _build_antispoofer(options, model_dir)
|
||||
|
||||
# YuNet is a fixed-size detector; size is reset per detect() call to
|
||||
# match the input frame.
|
||||
@@ -286,6 +472,11 @@ class OnnxDirectEngine:
|
||||
for d in self.detect(img)
|
||||
]
|
||||
|
||||
def antispoof(self, img: np.ndarray, bbox: tuple[float, float, float, float]) -> SpoofResult | None:
|
||||
if self._antispoofer is None or not self._antispoofer.loaded:
|
||||
return None
|
||||
return self._antispoofer.predict(img, bbox)
|
||||
|
||||
|
||||
# ─── helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import sys
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import grpc
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
@@ -39,6 +40,44 @@ OPENCV_FILES = [
|
||||
),
|
||||
]
|
||||
|
||||
# Silent-Face MiniFASNet ONNX files for antispoofing tests.
|
||||
ANTISPOOF_FILES = [
|
||||
(
|
||||
"MiniFASNetV2.onnx",
|
||||
"https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV2.onnx",
|
||||
"b32929adc2d9c34b9486f8c4c7bc97c1b69bc0ea9befefc380e4faae4e463907",
|
||||
),
|
||||
(
|
||||
"MiniFASNetV1SE.onnx",
|
||||
"https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV1SE.onnx",
|
||||
"ebab7f90c7833fbccd46d3a555410e78d969db5438e169b6524be444862b3676",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _download_files(specs: list[tuple[str, str, str]], env_var: str, prefix: str) -> str | None:
|
||||
"""Download a list of (filename, uri, sha256) into a directory.
|
||||
|
||||
Returns the directory, or None if any download failed.
|
||||
"""
|
||||
import hashlib
|
||||
import tempfile
|
||||
import urllib.request
|
||||
|
||||
root = os.environ.get(env_var) or tempfile.mkdtemp(prefix=prefix)
|
||||
for filename, uri, sha256 in specs:
|
||||
dest = os.path.join(root, filename)
|
||||
if os.path.isfile(dest):
|
||||
if hashlib.sha256(open(dest, "rb").read()).hexdigest() == sha256:
|
||||
continue
|
||||
try:
|
||||
urllib.request.urlretrieve(uri, dest)
|
||||
except Exception:
|
||||
return None
|
||||
if hashlib.sha256(open(dest, "rb").read()).hexdigest() != sha256:
|
||||
return None
|
||||
return root
|
||||
|
||||
|
||||
def _encode(img: np.ndarray) -> str:
|
||||
_, buf = cv2.imencode(".jpg", img)
|
||||
@@ -48,14 +87,19 @@ def _encode(img: np.ndarray) -> str:
|
||||
def _load_insightface_samples() -> dict[str, str]:
|
||||
"""Return {'t1': <b64>, 't2': <b64>} from insightface.data.get_image.
|
||||
|
||||
t1 is a group photo, t2 a different one. We reuse both as
|
||||
stand-ins for "Alice photo 1/2" and "Bob".
|
||||
t1 is a group photo; t2 used to ship as a second sample but newer
|
||||
insightface releases dropped it. We fall back to `Tom_Hanks_54745`
|
||||
(also bundled) as a distinct second face.
|
||||
"""
|
||||
from insightface.data import get_image as ins_get_image
|
||||
|
||||
try:
|
||||
second = ins_get_image("t2")
|
||||
except AssertionError:
|
||||
second = ins_get_image("Tom_Hanks_54745")
|
||||
return {
|
||||
"t1": _encode(ins_get_image("t1")),
|
||||
"t2": _encode(ins_get_image("t2")),
|
||||
"t2": _encode(second),
|
||||
}
|
||||
|
||||
|
||||
@@ -97,17 +141,23 @@ class _Harness:
|
||||
)
|
||||
return res, ctx
|
||||
|
||||
def verify(self, a: str, b: str, threshold: float = 0.0):
|
||||
return self.svc.FaceVerify(
|
||||
backend_pb2.FaceVerifyRequest(img1=a, img2=b, threshold=threshold),
|
||||
_FakeContext(),
|
||||
def verify(self, a: str, b: str, threshold: float = 0.0, anti_spoofing: bool = False):
|
||||
ctx = _FakeContext()
|
||||
res = self.svc.FaceVerify(
|
||||
backend_pb2.FaceVerifyRequest(
|
||||
img1=a, img2=b, threshold=threshold, anti_spoofing=anti_spoofing
|
||||
),
|
||||
ctx,
|
||||
)
|
||||
return res, ctx
|
||||
|
||||
def analyze(self, img_b64: str):
|
||||
return self.svc.FaceAnalyze(
|
||||
backend_pb2.FaceAnalyzeRequest(img=img_b64),
|
||||
_FakeContext(),
|
||||
def analyze(self, img_b64: str, anti_spoofing: bool = False):
|
||||
ctx = _FakeContext()
|
||||
res = self.svc.FaceAnalyze(
|
||||
backend_pb2.FaceAnalyzeRequest(img=img_b64, anti_spoofing=anti_spoofing),
|
||||
ctx,
|
||||
)
|
||||
return res, ctx
|
||||
|
||||
|
||||
class InsightFaceEngineTest(unittest.TestCase):
|
||||
@@ -138,21 +188,21 @@ class InsightFaceEngineTest(unittest.TestCase):
|
||||
self.assertAlmostEqual(norm_sq, 1.0, places=2)
|
||||
|
||||
def test_verify_same_image(self):
|
||||
res = self.harness.verify(self.samples["t1"], self.samples["t1"])
|
||||
res, _ = self.harness.verify(self.samples["t1"], self.samples["t1"])
|
||||
self.assertTrue(res.verified)
|
||||
self.assertLess(res.distance, 0.05)
|
||||
|
||||
def test_verify_different_images(self):
|
||||
# t1 vs t2 depict different groups of people — top face on each
|
||||
# side is unlikely to match.
|
||||
res = self.harness.verify(self.samples["t1"], self.samples["t2"])
|
||||
res, _ = self.harness.verify(self.samples["t1"], self.samples["t2"])
|
||||
# We assert only that some numerical answer came back; the
|
||||
# matches-or-not determination depends on which face each side
|
||||
# picked and isn't a stable test assertion.
|
||||
self.assertGreaterEqual(res.distance, 0.0)
|
||||
|
||||
def test_analyze_has_age_and_gender(self):
|
||||
res = self.harness.analyze(self.samples["t1"])
|
||||
res, _ = self.harness.analyze(self.samples["t1"])
|
||||
self.assertGreater(len(res.faces), 0)
|
||||
for face in res.faces:
|
||||
self.assertGreater(face.face_confidence, 0.0)
|
||||
@@ -160,31 +210,29 @@ class InsightFaceEngineTest(unittest.TestCase):
|
||||
self.assertGreater(face.age, 0.0)
|
||||
self.assertIn(face.dominant_gender, ("Man", "Woman"))
|
||||
|
||||
def test_antispoof_requested_without_model_fails(self):
|
||||
# buffalo_l was loaded without antispoof options — requesting
|
||||
# liveness should surface a clear FAILED_PRECONDITION instead of
|
||||
# silently returning is_real=False.
|
||||
_, ctx = self.harness.verify(
|
||||
self.samples["t1"], self.samples["t1"], anti_spoofing=True
|
||||
)
|
||||
self.assertEqual(ctx.code, grpc.StatusCode.FAILED_PRECONDITION)
|
||||
self.assertIn("anti_spoofing", ctx.details)
|
||||
|
||||
|
||||
def _prepare_opencv_models_dir() -> str | None:
|
||||
"""Download OpenCV Zoo face ONNX files into a temp dir the way
|
||||
LocalAI's gallery would. Returns the directory, or None if
|
||||
downloads failed (network-restricted sandbox).
|
||||
"""
|
||||
import hashlib
|
||||
import tempfile
|
||||
import urllib.request
|
||||
return _download_files(OPENCV_FILES, "OPENCV_FACE_MODELS_DIR", "opencv-face-")
|
||||
|
||||
root = os.environ.get("OPENCV_FACE_MODELS_DIR") or tempfile.mkdtemp(
|
||||
prefix="opencv-face-"
|
||||
)
|
||||
for filename, uri, sha256 in OPENCV_FILES:
|
||||
dest = os.path.join(root, filename)
|
||||
if os.path.isfile(dest):
|
||||
if hashlib.sha256(open(dest, "rb").read()).hexdigest() == sha256:
|
||||
continue
|
||||
try:
|
||||
urllib.request.urlretrieve(uri, dest)
|
||||
except Exception:
|
||||
return None
|
||||
if hashlib.sha256(open(dest, "rb").read()).hexdigest() != sha256:
|
||||
return None
|
||||
return root
|
||||
|
||||
def _prepare_antispoof_models_dir(extra_dir: str | None = None) -> str | None:
|
||||
"""Download MiniFASNet ONNX files. If `extra_dir` is given, files
|
||||
are placed there alongside any existing weights so a single
|
||||
`model_path` can serve both detector/recognizer + antispoof.
|
||||
"""
|
||||
if extra_dir is not None:
|
||||
os.environ.setdefault("ANTISPOOF_MODELS_DIR", extra_dir)
|
||||
return _download_files(ANTISPOOF_FILES, "ANTISPOOF_MODELS_DIR", "antispoof-")
|
||||
|
||||
|
||||
class OnnxDirectEngineTest(unittest.TestCase):
|
||||
@@ -218,17 +266,79 @@ class OnnxDirectEngineTest(unittest.TestCase):
|
||||
self.assertGreater(len(res.embeddings), 0)
|
||||
|
||||
def test_verify_same_image(self):
|
||||
res = self.harness.verify(self.samples["t1"], self.samples["t1"], threshold=0.4)
|
||||
res, _ = self.harness.verify(self.samples["t1"], self.samples["t1"], threshold=0.4)
|
||||
self.assertTrue(res.verified)
|
||||
|
||||
def test_analyze_returns_regions_without_demographics(self):
|
||||
# OnnxDirectEngine intentionally doesn't populate age/gender.
|
||||
res = self.harness.analyze(self.samples["t1"])
|
||||
res, _ = self.harness.analyze(self.samples["t1"])
|
||||
self.assertGreater(len(res.faces), 0)
|
||||
for face in res.faces:
|
||||
self.assertEqual(face.dominant_gender, "")
|
||||
self.assertEqual(face.age, 0.0)
|
||||
|
||||
|
||||
class AntispoofingTest(unittest.TestCase):
|
||||
"""End-to-end FaceVerify / FaceAnalyze with anti_spoofing=True.
|
||||
|
||||
Loads the OpenCV-Zoo (Apache-2.0) face engine alongside the Silent-Face
|
||||
MiniFASNet ensemble. Real photos from insightface's bundled samples
|
||||
are expected to come back as is_real=True with score above threshold.
|
||||
A printed-photo style fake (the same photo re-encoded with heavy
|
||||
JPEG and a synthetic moiré overlay) is expected to flip the verdict.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Reuse one directory for both detector/recognizer + antispoof
|
||||
# weights so a single LoadModel options block points at all of them.
|
||||
opencv_dir = _prepare_opencv_models_dir()
|
||||
if opencv_dir is None:
|
||||
raise unittest.SkipTest("OpenCV Zoo ONNX files could not be downloaded")
|
||||
antispoof_dir = _prepare_antispoof_models_dir(extra_dir=opencv_dir)
|
||||
if antispoof_dir is None:
|
||||
raise unittest.SkipTest("MiniFASNet ONNX files could not be downloaded")
|
||||
|
||||
# Antispoof only needs a single real-face sample; `t1` ships in
|
||||
# insightface.data across every release.
|
||||
from insightface.data import get_image as ins_get_image
|
||||
|
||||
cls.samples = {"t1": _encode(ins_get_image("t1"))}
|
||||
cls.harness = _Harness(BackendServicer())
|
||||
load = cls.harness.load(
|
||||
[
|
||||
"engine:onnx_direct",
|
||||
"detector_onnx:face_detection_yunet_2023mar.onnx",
|
||||
"recognizer_onnx:face_recognition_sface_2021dec.onnx",
|
||||
"antispoof_v2_onnx:MiniFASNetV2.onnx",
|
||||
"antispoof_v1se_onnx:MiniFASNetV1SE.onnx",
|
||||
],
|
||||
model_path=opencv_dir,
|
||||
)
|
||||
if not load.success:
|
||||
raise unittest.SkipTest(f"LoadModel failed: {load.message}")
|
||||
|
||||
def test_verify_returns_per_image_liveness(self):
|
||||
res, ctx = self.harness.verify(
|
||||
self.samples["t1"], self.samples["t1"], threshold=0.4, anti_spoofing=True
|
||||
)
|
||||
self.assertIsNone(ctx.code, f"FaceVerify error: {ctx.details}")
|
||||
# Score is the averaged "real" probability; both images are the
|
||||
# same real photo so should both populate non-zero scores.
|
||||
self.assertGreater(res.img1_antispoof_score, 0.0)
|
||||
self.assertGreater(res.img2_antispoof_score, 0.0)
|
||||
# Self-comparison: similarity must still match; final verified
|
||||
# combines similarity AND liveness, so we only assert it's set.
|
||||
self.assertIsInstance(res.verified, bool)
|
||||
|
||||
def test_analyze_populates_is_real_and_score(self):
|
||||
res, ctx = self.harness.analyze(self.samples["t1"], anti_spoofing=True)
|
||||
self.assertIsNone(ctx.code, f"FaceAnalyze error: {ctx.details}")
|
||||
self.assertGreater(len(res.faces), 0)
|
||||
for face in res.faces:
|
||||
self.assertGreaterEqual(face.antispoof_score, 0.0)
|
||||
self.assertLessEqual(face.antispoof_score, 1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
13
backend/python/speaker-recognition/Makefile
Normal file
13
backend/python/speaker-recognition/Makefile
Normal file
@@ -0,0 +1,13 @@
|
||||
.DEFAULT_GOAL := install
|
||||
|
||||
.PHONY: install
|
||||
install:
|
||||
bash install.sh
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
40
backend/python/speaker-recognition/README.md
Normal file
40
backend/python/speaker-recognition/README.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# speaker-recognition
|
||||
|
||||
Speaker (voice) recognition backend for LocalAI. The audio analog to
|
||||
`insightface` — produces speaker embeddings and supports 1:1 voice
|
||||
verification and voice demographic analysis.
|
||||
|
||||
## Engines
|
||||
|
||||
- **SpeechBrainEngine** (default): ECAPA-TDNN trained on VoxCeleb.
|
||||
192-d L2-normalised embeddings, cosine distance for verification.
|
||||
Auto-downloads from HuggingFace on first LoadModel.
|
||||
- **OnnxDirectEngine**: Any pre-exported ONNX speaker encoder
|
||||
(WeSpeaker ResNet, 3D-Speaker ERes2Net, CAM++, …). Model path comes
|
||||
from the gallery `files:` entry.
|
||||
|
||||
Engine selection is gallery-driven: if the model config provides
|
||||
`model_path:` / `onnx:` the ONNX engine is used, otherwise the
|
||||
SpeechBrain engine.
|
||||
|
||||
## Endpoints
|
||||
|
||||
- `POST /v1/voice/verify` — 1:1 same-speaker check.
|
||||
- `POST /v1/voice/embed` — extract a speaker embedding vector.
|
||||
- `POST /v1/voice/analyze` — voice demographics, loaded lazily on
|
||||
the first analyze call:
|
||||
- **Emotion** (default, opt-out): `superb/wav2vec2-base-superb-er`
|
||||
(Apache-2.0), 4-way categorical (neutral / happy / angry / sad).
|
||||
- **Age + gender** (opt-in): no default — wire a checkpoint with a
|
||||
standard `Wav2Vec2ForSequenceClassification` head via
|
||||
`age_gender_model:<repo>` in options. The Audeering
|
||||
age-gender model is *not* usable as a drop-in because its
|
||||
multi-task head isn't loadable via `AutoModelForAudioClassification`.
|
||||
|
||||
Both heads are optional. When nothing loads, the engine returns 501.
|
||||
|
||||
## Audio input
|
||||
|
||||
Audio is materialised by the HTTP layer to a temp wav before calling
|
||||
the gRPC backend. Accepted input forms on the HTTP side: URL, data-URI,
|
||||
or raw base64. The backend itself always receives a filesystem path.
|
||||
205
backend/python/speaker-recognition/backend.py
Normal file
205
backend/python/speaker-recognition/backend.py
Normal file
@@ -0,0 +1,205 @@
|
||||
#!/usr/bin/env python3
|
||||
"""gRPC server for the LocalAI speaker-recognition backend.
|
||||
|
||||
Implements Health / LoadModel / Status plus the voice-specific methods:
|
||||
VoiceVerify, VoiceAnalyze, VoiceEmbed. The heavy lifting lives in
|
||||
engines.py — this file is just the gRPC plumbing, mirroring the
|
||||
insightface backend's two-engine split (SpeechBrain + OnnxDirect).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from concurrent import futures
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import grpc
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "common"))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "common"))
|
||||
from grpc_auth import get_auth_interceptors # noqa: E402
|
||||
|
||||
from engines import SpeakerEngine, build_engine # noqa: E402
|
||||
|
||||
_ONE_DAY = 60 * 60 * 24
|
||||
MAX_WORKERS = int(os.environ.get("PYTHON_GRPC_MAX_WORKERS", "1"))
|
||||
|
||||
# ECAPA-TDNN on VoxCeleb is the reference. Threshold is tuned for
|
||||
# cosine distance (1 - cosine_similarity). Clients may override.
|
||||
DEFAULT_VERIFY_THRESHOLD = 0.25
|
||||
|
||||
|
||||
def _parse_options(raw: list[str]) -> dict[str, str]:
|
||||
out: dict[str, str] = {}
|
||||
for entry in raw:
|
||||
if ":" not in entry:
|
||||
continue
|
||||
k, v = entry.split(":", 1)
|
||||
out[k.strip()] = v.strip()
|
||||
return out
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def __init__(self) -> None:
|
||||
self.engine: SpeakerEngine | None = None
|
||||
self.engine_name: str = ""
|
||||
self.model_name: str = ""
|
||||
self.verify_threshold: float = DEFAULT_VERIFY_THRESHOLD
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", "utf-8"))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
options = _parse_options(list(request.Options))
|
||||
# Surface LocalAI's models directory (ModelPath) so engines can
|
||||
# anchor relative paths and auto-download into a writable spot
|
||||
# alongside every other gallery-managed asset.
|
||||
options["_model_path"] = request.ModelPath or ""
|
||||
try:
|
||||
engine, engine_name = build_engine(request.Model, options)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return backend_pb2.Result(success=False, message=f"engine init failed: {exc}")
|
||||
|
||||
self.engine = engine
|
||||
self.engine_name = engine_name
|
||||
self.model_name = request.Model
|
||||
|
||||
threshold_opt = options.get("verify_threshold")
|
||||
if threshold_opt:
|
||||
try:
|
||||
self.verify_threshold = float(threshold_opt)
|
||||
except ValueError:
|
||||
pass
|
||||
return backend_pb2.Result(success=True, message=f"loaded {engine_name}")
|
||||
|
||||
def Status(self, request, context):
|
||||
state = backend_pb2.StatusResponse.State.READY if self.engine else backend_pb2.StatusResponse.State.UNINITIALIZED
|
||||
return backend_pb2.StatusResponse(state=state)
|
||||
|
||||
def _require_engine(self, context) -> SpeakerEngine | None:
|
||||
if self.engine is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("no speaker-recognition model loaded")
|
||||
return None
|
||||
return self.engine
|
||||
|
||||
def VoiceVerify(self, request, context):
|
||||
engine = self._require_engine(context)
|
||||
if engine is None:
|
||||
return backend_pb2.VoiceVerifyResponse()
|
||||
if not request.audio1 or not request.audio2:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("audio1 and audio2 are required")
|
||||
return backend_pb2.VoiceVerifyResponse()
|
||||
|
||||
threshold = request.threshold if request.threshold > 0 else self.verify_threshold
|
||||
started = time.time()
|
||||
try:
|
||||
distance = engine.compare(request.audio1, request.audio2)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"voice verify failed: {exc}")
|
||||
return backend_pb2.VoiceVerifyResponse()
|
||||
|
||||
elapsed_ms = (time.time() - started) * 1000.0
|
||||
# Confidence goes linearly from 100 at distance=0 to 0 at distance=threshold.
|
||||
confidence = max(0.0, min(100.0, (1.0 - distance / threshold) * 100.0))
|
||||
return backend_pb2.VoiceVerifyResponse(
|
||||
verified=distance <= threshold,
|
||||
distance=distance,
|
||||
threshold=threshold,
|
||||
confidence=confidence,
|
||||
model=self.model_name,
|
||||
processing_time_ms=elapsed_ms,
|
||||
)
|
||||
|
||||
def VoiceEmbed(self, request, context):
|
||||
engine = self._require_engine(context)
|
||||
if engine is None:
|
||||
return backend_pb2.VoiceEmbedResponse()
|
||||
if not request.audio:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("audio is required")
|
||||
return backend_pb2.VoiceEmbedResponse()
|
||||
try:
|
||||
vec = engine.embed(request.audio)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"voice embed failed: {exc}")
|
||||
return backend_pb2.VoiceEmbedResponse()
|
||||
return backend_pb2.VoiceEmbedResponse(embedding=list(vec), model=self.model_name)
|
||||
|
||||
def VoiceAnalyze(self, request, context):
|
||||
engine = self._require_engine(context)
|
||||
if engine is None:
|
||||
return backend_pb2.VoiceAnalyzeResponse()
|
||||
if not request.audio:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("audio is required")
|
||||
return backend_pb2.VoiceAnalyzeResponse()
|
||||
|
||||
actions = list(request.actions) or ["age", "gender", "emotion"]
|
||||
try:
|
||||
segments = engine.analyze(request.audio, actions)
|
||||
except NotImplementedError:
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details(f"analyze not supported by {self.engine_name}")
|
||||
return backend_pb2.VoiceAnalyzeResponse()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"voice analyze failed: {exc}")
|
||||
return backend_pb2.VoiceAnalyzeResponse()
|
||||
|
||||
proto_segments = []
|
||||
for seg in segments:
|
||||
proto_segments.append(
|
||||
backend_pb2.VoiceAnalysis(
|
||||
start=seg.get("start", 0.0),
|
||||
end=seg.get("end", 0.0),
|
||||
age=seg.get("age", 0.0),
|
||||
dominant_gender=seg.get("dominant_gender", ""),
|
||||
gender=seg.get("gender", {}),
|
||||
dominant_emotion=seg.get("dominant_emotion", ""),
|
||||
emotion=seg.get("emotion", {}),
|
||||
)
|
||||
)
|
||||
return backend_pb2.VoiceAnalyzeResponse(segments=proto_segments)
|
||||
|
||||
|
||||
def serve(address: str) -> None:
|
||||
interceptors = get_auth_interceptors()
|
||||
server = grpc.server(
|
||||
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
interceptors=interceptors,
|
||||
options=[
|
||||
("grpc.max_send_message_length", 128 * 1024 * 1024),
|
||||
("grpc.max_receive_message_length", 128 * 1024 * 1024),
|
||||
],
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("speaker-recognition backend listening on", address, flush=True)
|
||||
|
||||
def _stop(*_):
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, _stop)
|
||||
signal.signal(signal.SIGINT, _stop)
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--addr", default="localhost:50051")
|
||||
args = parser.parse_args()
|
||||
serve(args.addr)
|
||||
428
backend/python/speaker-recognition/engines.py
Normal file
428
backend/python/speaker-recognition/engines.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""Speaker-recognition engines.
|
||||
|
||||
Two engines are offered, mirroring the insightface backend's split:
|
||||
|
||||
* SpeechBrainEngine: full PyTorch / SpeechBrain path. Uses the
|
||||
ECAPA-TDNN recipe trained on VoxCeleb; 192-d L2-normalized
|
||||
embeddings, cosine distance for verification. Auto-downloads the
|
||||
checkpoint into LocalAI's models directory on first LoadModel.
|
||||
|
||||
* OnnxDirectEngine: CPU-friendly fallback that runs pre-exported
|
||||
ONNX speaker encoders (WeSpeaker ResNet34, 3D-Speaker ERes2Net,
|
||||
CAM++, etc.). Model paths come from the model config — the gallery
|
||||
`files:` flow drops them into the models directory.
|
||||
|
||||
Engine selection follows the same gallery-driven convention face
|
||||
recognition uses (insightface commits 9c6da0f7 / 405fec0b): the
|
||||
Python backend reads `engine` / `model_path` / `checkpoint` from the
|
||||
options dict and picks an engine accordingly.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Iterable, Protocol
|
||||
|
||||
|
||||
class SpeakerEngine(Protocol):
|
||||
"""Interface both concrete engines satisfy."""
|
||||
|
||||
name: str
|
||||
|
||||
def embed(self, audio_path: str) -> list[float]: # pragma: no cover - interface
|
||||
...
|
||||
|
||||
def compare(self, audio1: str, audio2: str) -> float: # pragma: no cover
|
||||
...
|
||||
|
||||
def analyze(self, audio_path: str, actions: Iterable[str]) -> list[dict[str, Any]]: # pragma: no cover
|
||||
...
|
||||
|
||||
|
||||
def _cosine_distance(a, b) -> float:
|
||||
import numpy as np
|
||||
|
||||
va = np.asarray(a, dtype=np.float32).reshape(-1)
|
||||
vb = np.asarray(b, dtype=np.float32).reshape(-1)
|
||||
na = float(np.linalg.norm(va))
|
||||
nb = float(np.linalg.norm(vb))
|
||||
if na == 0.0 or nb == 0.0:
|
||||
return 1.0
|
||||
return float(1.0 - np.dot(va, vb) / (na * nb))
|
||||
|
||||
|
||||
class AnalysisHead:
|
||||
"""Age / gender / emotion head, lazy-loaded on first analyze call.
|
||||
|
||||
Wraps two open-licence HuggingFace checkpoints:
|
||||
|
||||
* audeering/wav2vec2-large-robust-24-ft-age-gender — age
|
||||
regression (0–100 years) + 3-way gender (female/male/child).
|
||||
Apache 2.0.
|
||||
* superb/wav2vec2-base-superb-er — 4-way emotion classification
|
||||
(neutral / happy / angry / sad). Apache 2.0.
|
||||
|
||||
Either model is optional — the head degrades gracefully to only the
|
||||
attributes it could load. Override the checkpoint with the
|
||||
`age_gender_model` / `emotion_model` option if you want something
|
||||
else. Set either to an empty string to disable that head.
|
||||
"""
|
||||
|
||||
# Age + gender is OFF by default: the high-accuracy Apache-2.0
|
||||
# checkpoint (Audeering wav2vec2-large-robust-24-ft-age-gender) uses a
|
||||
# custom multi-task head that AutoModelForAudioClassification silently
|
||||
# mangles — it drops the age weights as UNEXPECTED and re-initialises
|
||||
# the classifier head with random values, so the output is noise. Users
|
||||
# who have a cleanly loadable age/gender classifier can opt in with
|
||||
# `age_gender_model:<repo>` in options. The emotion default below
|
||||
# (superb/wav2vec2-base-superb-er) loads via the standard audio-
|
||||
# classification pipeline with no such caveat.
|
||||
DEFAULT_AGE_GENDER_MODEL = ""
|
||||
DEFAULT_EMOTION_MODEL = "superb/wav2vec2-base-superb-er"
|
||||
AGE_GENDER_LABELS = ("female", "male", "child")
|
||||
|
||||
def __init__(self, options: dict[str, str]):
|
||||
self._options = options
|
||||
self._age_gender = None
|
||||
self._age_gender_processor = None
|
||||
self._age_gender_loaded = False
|
||||
self._age_gender_error: str | None = None
|
||||
self._emotion = None
|
||||
self._emotion_loaded = False
|
||||
self._emotion_error: str | None = None
|
||||
|
||||
# --- age / gender -------------------------------------------------
|
||||
def _ensure_age_gender(self):
|
||||
if self._age_gender_loaded:
|
||||
return
|
||||
self._age_gender_loaded = True
|
||||
model_id = self._options.get(
|
||||
"age_gender_model", self.DEFAULT_AGE_GENDER_MODEL
|
||||
)
|
||||
if not model_id:
|
||||
self._age_gender_error = "disabled"
|
||||
return
|
||||
try:
|
||||
# Late imports — torch / transformers are heavy and only
|
||||
# pulled in when the analyze head actually runs.
|
||||
import torch # type: ignore
|
||||
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification # type: ignore
|
||||
|
||||
self._torch = torch
|
||||
self._age_gender_processor = AutoFeatureExtractor.from_pretrained(model_id)
|
||||
self._age_gender = AutoModelForAudioClassification.from_pretrained(model_id)
|
||||
self._age_gender.eval()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._age_gender_error = f"{type(exc).__name__}: {exc}"
|
||||
|
||||
def _infer_age_gender(self, waveform_16k) -> dict[str, Any]:
|
||||
self._ensure_age_gender()
|
||||
if self._age_gender is None:
|
||||
return {}
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
inputs = self._age_gender_processor(
|
||||
waveform_16k, sampling_rate=16000, return_tensors="pt"
|
||||
)
|
||||
with self._torch.no_grad():
|
||||
outputs = self._age_gender(**inputs)
|
||||
|
||||
# Audeering's checkpoint is published with a custom head: the
|
||||
# official recipe exposes `(hidden_states, logits_age, logits_gender)`.
|
||||
# AutoModelForAudioClassification flattens that into a single
|
||||
# `logits` tensor of shape [batch, 4] — [age_regression, female, male, child].
|
||||
# Fall back gracefully when the shape is different (e.g. a
|
||||
# user-supplied age_gender_model checkpoint that returns a proper tuple).
|
||||
hidden = getattr(outputs, "logits", outputs)
|
||||
age_years = None
|
||||
gender_logits = None
|
||||
if isinstance(hidden, (tuple, list)) and len(hidden) >= 2:
|
||||
age_years = float(hidden[0].squeeze().item()) * 100.0
|
||||
gender_logits = hidden[1]
|
||||
else:
|
||||
flat = hidden.squeeze()
|
||||
if flat.ndim == 1 and flat.numel() >= 4:
|
||||
age_years = float(flat[0].item()) * 100.0
|
||||
gender_logits = flat[1:4]
|
||||
elif flat.ndim == 1 and flat.numel() == 1:
|
||||
age_years = float(flat.item()) * 100.0
|
||||
|
||||
if age_years is None and gender_logits is None:
|
||||
return {}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
if age_years is not None:
|
||||
result["age"] = age_years
|
||||
if gender_logits is not None:
|
||||
probs = self._torch.softmax(gender_logits, dim=-1).cpu().numpy()
|
||||
probs = np.asarray(probs).reshape(-1)
|
||||
gender_map = {
|
||||
label: float(probs[i])
|
||||
for i, label in enumerate(self.AGE_GENDER_LABELS[: len(probs)])
|
||||
}
|
||||
result["gender"] = gender_map
|
||||
if gender_map:
|
||||
dom = max(gender_map.items(), key=lambda kv: kv[1])[0]
|
||||
result["dominant_gender"] = {
|
||||
"female": "Female",
|
||||
"male": "Male",
|
||||
"child": "Child",
|
||||
}.get(dom, dom.capitalize())
|
||||
return result
|
||||
except Exception as exc: # noqa: BLE001
|
||||
# Analyze is a best-effort feature — never take down the
|
||||
# whole analyze call because the age/gender head had a bad
|
||||
# day. Mark the failure so the emotion branch still runs.
|
||||
self._age_gender_error = f"runtime: {type(exc).__name__}: {exc}"
|
||||
return {}
|
||||
|
||||
# --- emotion ------------------------------------------------------
|
||||
def _ensure_emotion(self):
|
||||
if self._emotion_loaded:
|
||||
return
|
||||
self._emotion_loaded = True
|
||||
model_id = self._options.get("emotion_model", self.DEFAULT_EMOTION_MODEL)
|
||||
if not model_id:
|
||||
self._emotion_error = "disabled"
|
||||
return
|
||||
try:
|
||||
from transformers import pipeline # type: ignore
|
||||
|
||||
self._emotion = pipeline("audio-classification", model=model_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._emotion_error = f"{type(exc).__name__}: {exc}"
|
||||
|
||||
def _infer_emotion(self, audio_path: str) -> dict[str, Any]:
|
||||
self._ensure_emotion()
|
||||
if self._emotion is None:
|
||||
return {}
|
||||
try:
|
||||
raw = self._emotion(audio_path, top_k=8)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
# Second-line defense: don't fail the whole analyze call
|
||||
# over a runtime inference hiccup.
|
||||
self._emotion_error = f"runtime: {type(exc).__name__}: {exc}"
|
||||
return {}
|
||||
emotion_map = {row["label"].lower(): float(row["score"]) for row in raw}
|
||||
if not emotion_map:
|
||||
return {}
|
||||
dom = max(emotion_map.items(), key=lambda kv: kv[1])[0]
|
||||
return {"emotion": emotion_map, "dominant_emotion": dom}
|
||||
|
||||
# --- orchestrator -------------------------------------------------
|
||||
def analyze(self, audio_path: str, waveform_16k, actions: Iterable[str]) -> dict[str, Any]:
|
||||
wanted = {a.strip().lower() for a in actions} if actions else {"age", "gender", "emotion"}
|
||||
result: dict[str, Any] = {}
|
||||
if "age" in wanted or "gender" in wanted:
|
||||
ag = self._infer_age_gender(waveform_16k)
|
||||
if "age" in wanted and "age" in ag:
|
||||
result["age"] = ag["age"]
|
||||
if "gender" in wanted:
|
||||
if "gender" in ag:
|
||||
result["gender"] = ag["gender"]
|
||||
if "dominant_gender" in ag:
|
||||
result["dominant_gender"] = ag["dominant_gender"]
|
||||
if "emotion" in wanted:
|
||||
em = self._infer_emotion(audio_path)
|
||||
result.update(em)
|
||||
return result
|
||||
|
||||
|
||||
class SpeechBrainEngine:
|
||||
"""ECAPA-TDNN via SpeechBrain. Auto-downloads on first use."""
|
||||
|
||||
name = "speechbrain-ecapa-tdnn"
|
||||
|
||||
def __init__(self, model_name: str, options: dict[str, str]):
|
||||
# Late imports so the module can be introspected / tested
|
||||
# without torch / speechbrain being installed.
|
||||
from speechbrain.inference.speaker import EncoderClassifier # type: ignore
|
||||
|
||||
source = options.get("source") or model_name or "speechbrain/spkrec-ecapa-voxceleb"
|
||||
savedir = options.get("_model_path") or os.environ.get("HF_HOME") or "./pretrained_models"
|
||||
self._model = EncoderClassifier.from_hparams(source=source, savedir=savedir)
|
||||
self._analysis = AnalysisHead(options)
|
||||
|
||||
def _load_waveform(self, path: str):
|
||||
# Use soundfile + torch directly — torchaudio.load in torchaudio
|
||||
# 2.8+ requires the torchcodec package for decoding, which adds
|
||||
# another heavy ffmpeg-linked dep. soundfile covers WAV/FLAC
|
||||
# which is what we care about here.
|
||||
import numpy as np
|
||||
import soundfile as sf # type: ignore
|
||||
import torch # type: ignore
|
||||
|
||||
audio, sr = sf.read(path, always_2d=False)
|
||||
if audio.ndim > 1:
|
||||
audio = audio.mean(axis=1)
|
||||
audio = np.asarray(audio, dtype=np.float32)
|
||||
if sr != 16000:
|
||||
# Simple linear resample — good enough for 16kHz downsampling
|
||||
# from 44.1/48kHz, and we expect 16kHz inputs in practice.
|
||||
ratio = 16000 / float(sr)
|
||||
n = int(round(len(audio) * ratio))
|
||||
audio = np.interp(
|
||||
np.linspace(0, len(audio), n, endpoint=False),
|
||||
np.arange(len(audio)),
|
||||
audio,
|
||||
).astype(np.float32)
|
||||
return torch.from_numpy(audio).unsqueeze(0) # [1, T]
|
||||
|
||||
def embed(self, audio_path: str) -> list[float]:
|
||||
waveform = self._load_waveform(audio_path)
|
||||
vec = self._model.encode_batch(waveform).squeeze().detach().cpu().numpy()
|
||||
return [float(x) for x in vec]
|
||||
|
||||
def compare(self, audio1: str, audio2: str) -> float:
|
||||
return _cosine_distance(self.embed(audio1), self.embed(audio2))
|
||||
|
||||
def analyze(self, audio_path: str, actions):
|
||||
# Age / gender / emotion aren't produced by ECAPA-TDNN itself;
|
||||
# delegate to AnalysisHead which wraps separate Apache-2.0
|
||||
# checkpoints. Returns a single segment spanning the clip —
|
||||
# segmentation / diarisation is a future enhancement.
|
||||
waveform = self._load_waveform(audio_path)
|
||||
mono = waveform.squeeze().detach().cpu().numpy()
|
||||
attrs = self._analysis.analyze(audio_path, mono, actions)
|
||||
if not attrs:
|
||||
raise NotImplementedError(
|
||||
"analyze head failed to load — install transformers + torch or pass age_gender_model/emotion_model options"
|
||||
)
|
||||
duration = float(mono.shape[-1]) / 16000.0 if mono.size else 0.0
|
||||
return [dict(start=0.0, end=duration, **attrs)]
|
||||
|
||||
|
||||
class OnnxDirectEngine:
|
||||
"""Run a pre-exported ONNX speaker encoder (WeSpeaker / 3D-Speaker)."""
|
||||
|
||||
name = "onnx-direct"
|
||||
|
||||
def __init__(self, model_name: str, options: dict[str, str]):
|
||||
import onnxruntime as ort # type: ignore
|
||||
|
||||
# The gallery is expected to have dropped the ONNX file under
|
||||
# the models directory; accept either an absolute path or a
|
||||
# filename relative to _model_path.
|
||||
onnx_path = options.get("model_path") or options.get("onnx")
|
||||
if not onnx_path:
|
||||
raise ValueError("OnnxDirectEngine requires `model_path: <file.onnx>` in options")
|
||||
if not os.path.isabs(onnx_path):
|
||||
onnx_path = os.path.join(options.get("_model_path", ""), onnx_path)
|
||||
if not os.path.isfile(onnx_path):
|
||||
raise FileNotFoundError(f"ONNX model not found: {onnx_path}")
|
||||
|
||||
providers = options.get("providers")
|
||||
if providers:
|
||||
provider_list = [p.strip() for p in providers.split(",") if p.strip()]
|
||||
else:
|
||||
provider_list = ["CPUExecutionProvider"]
|
||||
self._session = ort.InferenceSession(onnx_path, providers=provider_list)
|
||||
input_meta = self._session.get_inputs()[0]
|
||||
self._input_name = input_meta.name
|
||||
# Pre-exported speaker encoders come in two shapes:
|
||||
# rank-2 [batch, samples] — some 3D-Speaker exports feed raw waveform.
|
||||
# rank-3 [batch, frames, n_mels] — WeSpeaker and most Kaldi-lineage encoders
|
||||
# expect pre-computed Kaldi FBank features.
|
||||
# We detect this at load time and branch in embed(), because feeding raw audio
|
||||
# into a rank-3 graph is exactly what triggered
|
||||
# "Invalid rank for input: feats Got: 2 Expected: 3".
|
||||
self._input_rank = len(input_meta.shape) if input_meta.shape is not None else 2
|
||||
self._expected_sr = int(options.get("sample_rate", "16000"))
|
||||
self._fbank_mels = int(options.get("fbank_num_mel_bins", "80"))
|
||||
self._fbank_frame_length_ms = float(options.get("fbank_frame_length_ms", "25"))
|
||||
self._fbank_frame_shift_ms = float(options.get("fbank_frame_shift_ms", "10"))
|
||||
# Per-utterance cepstral mean normalisation — on for WeSpeaker by default,
|
||||
# toggleable for encoders that expect raw FBank.
|
||||
self._fbank_cmn = options.get("fbank_cmn", "true").lower() in ("1", "true", "yes")
|
||||
self._analysis = AnalysisHead(options)
|
||||
|
||||
def _load_waveform(self, path: str):
|
||||
import numpy as np
|
||||
import soundfile as sf # type: ignore
|
||||
|
||||
audio, sr = sf.read(path, always_2d=False)
|
||||
if sr != self._expected_sr:
|
||||
# Cheap linear resample — good enough for sanity; callers
|
||||
# should pre-resample for production.
|
||||
ratio = self._expected_sr / float(sr)
|
||||
n = int(round(len(audio) * ratio))
|
||||
audio = np.interp(
|
||||
np.linspace(0, len(audio), n, endpoint=False),
|
||||
np.arange(len(audio)),
|
||||
audio,
|
||||
)
|
||||
if audio.ndim > 1:
|
||||
audio = audio.mean(axis=1)
|
||||
return audio.astype("float32")
|
||||
|
||||
def embed(self, audio_path: str) -> list[float]:
|
||||
import numpy as np
|
||||
|
||||
audio = self._load_waveform(audio_path)
|
||||
if self._input_rank >= 3:
|
||||
feats = self._extract_fbank(audio) # [frames, n_mels]
|
||||
feed = feats[np.newaxis, :, :] # [1, frames, n_mels]
|
||||
else:
|
||||
feed = audio.reshape(1, -1) # [1, samples]
|
||||
out = self._session.run(None, {self._input_name: feed})
|
||||
vec = np.asarray(out[0]).reshape(-1)
|
||||
return [float(x) for x in vec]
|
||||
|
||||
def _extract_fbank(self, audio):
|
||||
"""Compute Kaldi-style 80-dim FBank features for speaker encoders that
|
||||
expect pre-featurised input (WeSpeaker, most 3D-Speaker exports).
|
||||
torchaudio is already a backend dependency for SpeechBrain — no new
|
||||
package required."""
|
||||
import numpy as np
|
||||
import torch # type: ignore
|
||||
import torchaudio.compliance.kaldi as kaldi # type: ignore
|
||||
|
||||
tensor = torch.from_numpy(audio).unsqueeze(0) # [1, samples]
|
||||
feats = kaldi.fbank(
|
||||
tensor,
|
||||
sample_frequency=self._expected_sr,
|
||||
num_mel_bins=self._fbank_mels,
|
||||
frame_length=self._fbank_frame_length_ms,
|
||||
frame_shift=self._fbank_frame_shift_ms,
|
||||
dither=0.0,
|
||||
) # [frames, n_mels]
|
||||
if self._fbank_cmn:
|
||||
feats = feats - feats.mean(dim=0, keepdim=True)
|
||||
return feats.numpy().astype(np.float32)
|
||||
|
||||
def compare(self, audio1: str, audio2: str) -> float:
|
||||
return _cosine_distance(self.embed(audio1), self.embed(audio2))
|
||||
|
||||
def analyze(self, audio_path: str, actions):
|
||||
# AnalysisHead expects 16kHz mono; _load_waveform already
|
||||
# resamples to self._expected_sr. If the user configured a
|
||||
# non-16k expected rate, resample one more time for analyze.
|
||||
audio = self._load_waveform(audio_path)
|
||||
if self._expected_sr != 16000:
|
||||
import numpy as np
|
||||
|
||||
ratio = 16000 / float(self._expected_sr)
|
||||
n = int(round(len(audio) * ratio))
|
||||
audio = np.interp(
|
||||
np.linspace(0, len(audio), n, endpoint=False),
|
||||
np.arange(len(audio)),
|
||||
audio,
|
||||
).astype("float32")
|
||||
attrs = self._analysis.analyze(audio_path, audio, actions)
|
||||
if not attrs:
|
||||
raise NotImplementedError(
|
||||
"analyze head failed to load — install transformers + torch or pass age_gender_model/emotion_model options"
|
||||
)
|
||||
duration = float(len(audio)) / 16000.0 if len(audio) else 0.0
|
||||
return [dict(start=0.0, end=duration, **attrs)]
|
||||
|
||||
|
||||
def build_engine(model_name: str, options: dict[str, str]) -> tuple[SpeakerEngine, str]:
|
||||
"""Pick an engine based on the options. ONNX path takes priority:
|
||||
if the gallery has dropped a `model_path:` or `onnx:` option, run
|
||||
the direct ONNX engine. Otherwise, fall back to SpeechBrain.
|
||||
"""
|
||||
engine_kind = (options.get("engine") or "").lower()
|
||||
if engine_kind == "onnx" or options.get("model_path") or options.get("onnx"):
|
||||
return OnnxDirectEngine(model_name, options), OnnxDirectEngine.name
|
||||
return SpeechBrainEngine(model_name, options), SpeechBrainEngine.name
|
||||
19
backend/python/speaker-recognition/install.sh
Executable file
19
backend/python/speaker-recognition/install.sh
Executable file
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
# No pre-baked model weights. Weights flow through LocalAI's gallery
|
||||
# `files:` mechanism — see gallery entries for speechbrain-ecapa-tdnn
|
||||
# and WeSpeaker / 3D-Speaker ONNX packs. SpeechBrain's
|
||||
# EncoderClassifier.from_hparams also knows how to auto-download from
|
||||
# HuggingFace into the configured savedir (we point it at ModelPath),
|
||||
# so the first LoadModel call bootstraps the checkpoint if the gallery
|
||||
# flow wasn't used.
|
||||
5
backend/python/speaker-recognition/requirements-cpu.txt
Normal file
5
backend/python/speaker-recognition/requirements-cpu.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
torch
|
||||
torchaudio
|
||||
speechbrain
|
||||
transformers
|
||||
onnxruntime
|
||||
@@ -0,0 +1,5 @@
|
||||
torch
|
||||
torchaudio
|
||||
speechbrain
|
||||
transformers
|
||||
onnxruntime-gpu
|
||||
5
backend/python/speaker-recognition/requirements.txt
Normal file
5
backend/python/speaker-recognition/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
grpcio-tools
|
||||
numpy
|
||||
soundfile
|
||||
9
backend/python/speaker-recognition/run.sh
Executable file
9
backend/python/speaker-recognition/run.sh
Executable file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
78
backend/python/speaker-recognition/test.py
Normal file
78
backend/python/speaker-recognition/test.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Unit tests for the speaker-recognition gRPC backend.
|
||||
|
||||
The servicer is instantiated in-process (no gRPC channel) and driven
|
||||
directly. The default path exercises SpeechBrain's ECAPA-TDNN — the
|
||||
first run downloads the checkpoint into a temp savedir. Tests are
|
||||
skipped gracefully when the heavy optional dependencies (torch /
|
||||
speechbrain / onnxruntime) are not installed, so the gRPC plumbing
|
||||
can still be verified on a bare image.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
import backend_pb2 # noqa: E402
|
||||
|
||||
from backend import BackendServicer # noqa: E402
|
||||
|
||||
|
||||
def _have(*mods: str) -> bool:
|
||||
for m in mods:
|
||||
if importlib.util.find_spec(m) is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class _FakeCtx:
|
||||
"""Minimal stand-in for a gRPC servicer context."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.code = None
|
||||
self.details = ""
|
||||
|
||||
def set_code(self, c):
|
||||
self.code = c
|
||||
|
||||
def set_details(self, d):
|
||||
self.details = d
|
||||
|
||||
|
||||
class ServicerPlumbingTest(unittest.TestCase):
|
||||
"""Checks that LoadModel returns a clear error when no engine deps
|
||||
are installed, and that Voice* calls on an uninitialised servicer
|
||||
surface FAILED_PRECONDITION — both verifying the gRPC wiring
|
||||
without requiring SpeechBrain or ONNX at test time."""
|
||||
|
||||
def test_pre_load_voice_calls_are_rejected(self):
|
||||
svc = BackendServicer()
|
||||
ctx = _FakeCtx()
|
||||
svc.VoiceVerify(backend_pb2.VoiceVerifyRequest(audio1="/tmp/a.wav", audio2="/tmp/b.wav"), ctx)
|
||||
self.assertEqual(str(ctx.code), "StatusCode.FAILED_PRECONDITION")
|
||||
|
||||
def test_load_without_deps_fails_cleanly(self):
|
||||
svc = BackendServicer()
|
||||
req = backend_pb2.ModelOptions(Model="speechbrain/spkrec-ecapa-voxceleb", ModelPath="")
|
||||
result = svc.LoadModel(req, _FakeCtx())
|
||||
# Either the deps are installed and it loaded, or they aren't
|
||||
# and we got a structured error instead of a crash.
|
||||
self.assertTrue(result.success or "engine init failed" in result.message)
|
||||
|
||||
|
||||
@unittest.skipUnless(_have("speechbrain", "torch", "torchaudio"), "speechbrain / torch missing")
|
||||
class SpeechBrainEngineSmokeTest(unittest.TestCase):
|
||||
def test_load_and_embed(self):
|
||||
svc = BackendServicer()
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
req = backend_pb2.ModelOptions(Model="speechbrain/spkrec-ecapa-voxceleb", ModelPath=td)
|
||||
result = svc.LoadModel(req, _FakeCtx())
|
||||
self.assertTrue(result.success, result.message)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
11
backend/python/speaker-recognition/test.sh
Executable file
11
backend/python/speaker-recognition/test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runUnittests
|
||||
@@ -372,6 +372,41 @@ impl Backend for KokorosService {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn face_verify(
|
||||
&self,
|
||||
_: Request<backend::FaceVerifyRequest>,
|
||||
) -> Result<Response<backend::FaceVerifyResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn face_analyze(
|
||||
&self,
|
||||
_: Request<backend::FaceAnalyzeRequest>,
|
||||
) -> Result<Response<backend::FaceAnalyzeResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn voice_verify(
|
||||
&self,
|
||||
_: Request<backend::VoiceVerifyRequest>,
|
||||
) -> Result<Response<backend::VoiceVerifyResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn voice_analyze(
|
||||
&self,
|
||||
_: Request<backend::VoiceAnalyzeRequest>,
|
||||
) -> Result<Response<backend::VoiceAnalyzeResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn voice_embed(
|
||||
&self,
|
||||
_: Request<backend::VoiceEmbedRequest>,
|
||||
) -> Result<Response<backend::VoiceEmbedResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stores_set(
|
||||
&self,
|
||||
_: Request<backend::StoresSetOptions>,
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/facerecognition"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/voicerecognition"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
pkggrpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -29,6 +30,12 @@ import (
|
||||
// family per deployment; we keep the door open instead.
|
||||
const faceEmbeddingDim = 0
|
||||
|
||||
// voiceEmbeddingDim is the expected dimension for speaker embeddings.
|
||||
// 0 so the Registry accepts whatever dim the loaded recognizer
|
||||
// produces — ECAPA-TDNN is 192, WeSpeaker ResNet34 is 256, 3D-Speaker
|
||||
// ERes2Net is 192, CAM++ is 512.
|
||||
const voiceEmbeddingDim = 0
|
||||
|
||||
type Application struct {
|
||||
backendLoader *config.ModelConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
@@ -39,6 +46,7 @@ type Application struct {
|
||||
agentJobService *agentpool.AgentJobService
|
||||
agentPoolService atomic.Pointer[agentpool.AgentPoolService]
|
||||
faceRegistry facerecognition.Registry
|
||||
voiceRegistry voicerecognition.Registry
|
||||
authDB *gorm.DB
|
||||
watchdogMutex sync.Mutex
|
||||
watchdogStop chan bool
|
||||
@@ -73,10 +81,30 @@ func newApplication(appConfig *config.ApplicationConfig) *Application {
|
||||
// The resolver closes over the ModelLoader so the Registry stays
|
||||
// decoupled from loader plumbing; swapping in a postgres-backed
|
||||
// implementation later is a single construction change here.
|
||||
//
|
||||
// `faceStoreName` is the default namespace passed to StoreBackend when
|
||||
// the request doesn't override it. Face and voice MUST use distinct
|
||||
// namespaces — the local-store gRPC surface rejects mixed dimensions
|
||||
// inside one namespace ("Try to add key with length N when existing
|
||||
// length is M"). ArcFace buffalo_l produces 512-dim embeddings while
|
||||
// ECAPA-TDNN produces 192-dim; enrolling one after the other into a
|
||||
// shared namespace is exactly how we hit that error.
|
||||
const (
|
||||
faceStoreName = "localai-face-biometrics"
|
||||
voiceStoreName = "localai-voice-biometrics"
|
||||
)
|
||||
faceStoreResolver := func(_ context.Context, storeName string) (pkggrpc.Backend, error) {
|
||||
return corebackend.StoreBackend(ml, appConfig, storeName, "")
|
||||
}
|
||||
app.faceRegistry = facerecognition.NewStoreRegistry(faceStoreResolver, "", faceEmbeddingDim)
|
||||
app.faceRegistry = facerecognition.NewStoreRegistry(faceStoreResolver, faceStoreName, faceEmbeddingDim)
|
||||
|
||||
// Voice (speaker) recognition registry — same plumbing, separate
|
||||
// namespace so embedding spaces stay isolated (a face vector and a
|
||||
// speaker vector are not comparable and differ in dimensionality).
|
||||
voiceStoreResolver := func(_ context.Context, storeName string) (pkggrpc.Backend, error) {
|
||||
return corebackend.StoreBackend(ml, appConfig, storeName, "")
|
||||
}
|
||||
app.voiceRegistry = voicerecognition.NewStoreRegistry(voiceStoreResolver, voiceStoreName, voiceEmbeddingDim)
|
||||
|
||||
return app
|
||||
}
|
||||
@@ -130,6 +158,14 @@ func (a *Application) FaceRegistry() facerecognition.Registry {
|
||||
return a.faceRegistry
|
||||
}
|
||||
|
||||
// VoiceRegistry returns the voice (speaker) recognition registry used
|
||||
// for 1:N identification. Same in-memory local-store backing as
|
||||
// FaceRegistry but a separate instance — voice embeddings live in
|
||||
// their own vector space.
|
||||
func (a *Application) VoiceRegistry() voicerecognition.Registry {
|
||||
return a.voiceRegistry
|
||||
}
|
||||
|
||||
// AuthDB returns the auth database connection, or nil if auth is not enabled.
|
||||
func (a *Application) AuthDB() *gorm.DB {
|
||||
return a.authDB
|
||||
|
||||
@@ -242,6 +242,12 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
bmFn := func() galleryop.BackendManager { return application.GalleryService().BackendManager() }
|
||||
uc := NewUpgradeChecker(options, application.ModelLoader(), application.distributedDB(), bmFn)
|
||||
application.upgradeChecker = uc
|
||||
// Refresh the upgrade cache the moment a backend op finishes — otherwise
|
||||
// the UI keeps showing a just-upgraded backend as upgradeable until the
|
||||
// next 6-hour tick. TriggerCheck is non-blocking.
|
||||
if gs := application.GalleryService(); gs != nil {
|
||||
gs.OnBackendOpCompleted = uc.TriggerCheck
|
||||
}
|
||||
go uc.Run(options.Context)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,8 +11,17 @@ func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, st
|
||||
if backend == "" {
|
||||
backend = model.LocalStoreBackend
|
||||
}
|
||||
// ModelLoader caches backend processes by `modelID`, not by the `model`
|
||||
// passed via WithModel. Without a distinct modelID, every StoreBackend
|
||||
// call collapses to the same `modelID=""` cache slot — face (512-D) and
|
||||
// voice (192-D) biometrics would then share the same local-store process
|
||||
// and the second enrollment would fail with
|
||||
// Try to add key with length N when existing length is M
|
||||
// Use the store namespace as modelID so each namespace gets its own
|
||||
// process instance and its own in-memory Store{}.
|
||||
sc := []model.Option{
|
||||
model.WithBackendString(backend),
|
||||
model.WithModelID(storeName),
|
||||
model.WithModel(storeName),
|
||||
}
|
||||
|
||||
|
||||
58
core/backend/voice_analyze.go
Normal file
58
core/backend/voice_analyze.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func VoiceAnalyze(
|
||||
audio string,
|
||||
actions []string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (*proto.VoiceAnalyzeResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
voiceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
if voiceModel == nil {
|
||||
return nil, fmt.Errorf("could not load voice recognition model")
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := voiceModel.VoiceAnalyze(context.Background(), &proto.VoiceAnalyzeRequest{
|
||||
Audio: audio,
|
||||
Actions: actions,
|
||||
})
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceVoiceAnalyze,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Error: errStr,
|
||||
})
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
66
core/backend/voice_embed.go
Normal file
66
core/backend/voice_embed.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// VoiceEmbed returns a speaker embedding (typically 192-d for ECAPA-TDNN)
|
||||
// for the audio file at audioPath. Unlike ModelEmbedding (which is
|
||||
// OpenAI-compatible and text-only), this call takes an audio path and
|
||||
// returns the backend's speaker-encoder output.
|
||||
func VoiceEmbed(
|
||||
audioPath string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (*proto.VoiceEmbedResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
voiceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
if voiceModel == nil {
|
||||
return nil, fmt.Errorf("could not load voice recognition model")
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := voiceModel.VoiceEmbed(context.Background(), &proto.VoiceEmbedRequest{
|
||||
Audio: audioPath,
|
||||
})
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceVoiceEmbed,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Error: errStr,
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res == nil || len(res.Embedding) == 0 {
|
||||
return nil, fmt.Errorf("voice embedding returned empty vector (no speech detected?)")
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
61
core/backend/voice_verify.go
Normal file
61
core/backend/voice_verify.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func VoiceVerify(
|
||||
audio1, audio2 string,
|
||||
threshold float32,
|
||||
antiSpoofing bool,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (*proto.VoiceVerifyResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
voiceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
if voiceModel == nil {
|
||||
return nil, fmt.Errorf("could not load voice recognition model")
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := voiceModel.VoiceVerify(context.Background(), &proto.VoiceVerifyRequest{
|
||||
Audio1: audio1,
|
||||
Audio2: audio2,
|
||||
Threshold: threshold,
|
||||
AntiSpoofing: antiSpoofing,
|
||||
})
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceVoiceVerify,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Error: errStr,
|
||||
})
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
@@ -37,6 +37,14 @@ var CacheTypeOptions = []FieldOption{
|
||||
{Value: "q4_1", Label: "Q4_1"},
|
||||
{Value: "q5_0", Label: "Q5_0"},
|
||||
{Value: "q5_1", Label: "Q5_1"},
|
||||
// TurboQuant KV-cache types — accepted by the turboquant and
|
||||
// buun-llama-cpp fork backends; stock llama-cpp will reject them at load.
|
||||
{Value: "turbo2", Label: "Turbo2 (TurboQuant)"},
|
||||
{Value: "turbo3", Label: "Turbo3 (TurboQuant)"},
|
||||
{Value: "turbo4", Label: "Turbo4 (TurboQuant)"},
|
||||
// Trellis-Coded Quantization variants — buun-llama-cpp only.
|
||||
{Value: "turbo2_tcq", Label: "Turbo2 TCQ (buun-llama-cpp)"},
|
||||
{Value: "turbo3_tcq", Label: "Turbo3 TCQ (buun-llama-cpp)"},
|
||||
}
|
||||
|
||||
var DiffusersPipelineOptions = []FieldOption{
|
||||
|
||||
@@ -588,7 +588,8 @@ const (
|
||||
FLAG_VAD ModelConfigUsecase = 0b010000000000
|
||||
FLAG_VIDEO ModelConfigUsecase = 0b100000000000
|
||||
FLAG_DETECTION ModelConfigUsecase = 0b1000000000000
|
||||
FLAG_FACE_RECOGNITION ModelConfigUsecase = 0b10000000000000
|
||||
FLAG_FACE_RECOGNITION ModelConfigUsecase = 0b10000000000000
|
||||
FLAG_SPEAKER_RECOGNITION ModelConfigUsecase = 0b100000000000000
|
||||
|
||||
// Common Subsets
|
||||
FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
|
||||
@@ -612,7 +613,8 @@ func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
|
||||
"FLAG_LLM": FLAG_LLM,
|
||||
"FLAG_VIDEO": FLAG_VIDEO,
|
||||
"FLAG_DETECTION": FLAG_DETECTION,
|
||||
"FLAG_FACE_RECOGNITION": FLAG_FACE_RECOGNITION,
|
||||
"FLAG_FACE_RECOGNITION": FLAG_FACE_RECOGNITION,
|
||||
"FLAG_SPEAKER_RECOGNITION": FLAG_SPEAKER_RECOGNITION,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -653,7 +655,7 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||
nonTextGenBackends := []string{
|
||||
"whisper", "piper", "kokoro",
|
||||
"diffusers", "stablediffusion", "stablediffusion-ggml",
|
||||
"rerankers", "silero-vad", "rfdetr", "insightface",
|
||||
"rerankers", "silero-vad", "rfdetr", "insightface", "speaker-recognition",
|
||||
"transformers-musicgen", "ace-step", "acestep-cpp",
|
||||
}
|
||||
|
||||
@@ -743,6 +745,13 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_SPEAKER_RECOGNITION) == FLAG_SPEAKER_RECOGNITION {
|
||||
speakerBackends := []string{"speaker-recognition"}
|
||||
if !slices.Contains(speakerBackends, c.Backend) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION {
|
||||
soundGenBackends := []string{"transformers-musicgen", "ace-step", "acestep-cpp", "mock-backend"}
|
||||
if !slices.Contains(soundGenBackends, c.Backend) {
|
||||
@@ -758,7 +767,7 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||
}
|
||||
|
||||
if (u & FLAG_VAD) == FLAG_VAD {
|
||||
if c.Backend != "silero-vad" && !(c.Backend == "whisper" && slices.Contains(c.Options, "vad_only")) {
|
||||
if c.Backend != "silero-vad" && c.Backend != "sherpa-onnx" && !(c.Backend == "whisper" && slices.Contains(c.Options, "vad_only")) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,6 +194,20 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
||||
|
||||
name := config.Name
|
||||
backendPath := filepath.Join(systemState.Backend.BackendsPath, name)
|
||||
// Clean up legacy flat-layout artefacts: earlier dev builds of the
|
||||
// golang backends dropped the compiled binary directly at
|
||||
// `<backendsPath>/<name>` (a plain file) instead of
|
||||
// `<backendsPath>/<name>/<name>` (the nested layout the current code
|
||||
// expects). MkdirAll below returns ENOTDIR when such a stale file
|
||||
// exists, permanently blocking any reinstall or upgrade. Remove the
|
||||
// file first so the install can proceed; the new install will write
|
||||
// the correct nested layout, including metadata.json + run.sh.
|
||||
if fi, statErr := os.Lstat(backendPath); statErr == nil && !fi.IsDir() {
|
||||
xlog.Warn("removing stale non-directory backend artefact to make room for fresh install", "path", backendPath)
|
||||
if rmErr := os.Remove(backendPath); rmErr != nil {
|
||||
return fmt.Errorf("failed to remove stale backend artefact at %s: %w", backendPath, rmErr)
|
||||
}
|
||||
}
|
||||
err = os.MkdirAll(backendPath, 0750)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create base path: %v", err)
|
||||
|
||||
126
core/gallery/importers/ace-step.go
Normal file
126
core/gallery/importers/ace-step.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &ACEStepImporter{}
|
||||
|
||||
// ACEStepImporter recognises ACE-Step music generation checkpoints
|
||||
// (ACE-Step/ACE-Step-v1-3.5B, ACE-Step/Ace-Step1.5, community finetunes).
|
||||
// Detection matches on "ace-step" in the repo name — case-insensitive —
|
||||
// so quantised mirrors still route here. The backend itself is
|
||||
// sound-generation / TTS-adjacent; the Modality() method returns "image"
|
||||
// purely to slot into the UI dropdown's image/video tab where it lives
|
||||
// with other generative media importers. preferences.backend="ace-step"
|
||||
// overrides detection.
|
||||
type ACEStepImporter struct{}
|
||||
|
||||
func (i *ACEStepImporter) Name() string { return "ace-step" }
|
||||
func (i *ACEStepImporter) Modality() string { return "image" }
|
||||
func (i *ACEStepImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *ACEStepImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "ace-step" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if strings.Contains(strings.ToLower(repoName), "ace-step") {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(details.HuggingFace.Author, "ACE-Step") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: hfapi recursion bug may leave HuggingFace nil — decide
|
||||
// from the URI owner/repo.
|
||||
if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if strings.EqualFold(owner, "ACE-Step") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(strings.ToLower(repo), "ace-step") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *ACEStepImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
} else if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
model = owner + "/" + repo
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "ace-step",
|
||||
// Mirrors gallery/index.yaml's ace-step-turbo entry which flags
|
||||
// both sound_generation and tts — ACE-Step is a music/sound model,
|
||||
// the UI groups it under image/video simply because there is no
|
||||
// first-class music tab yet.
|
||||
KnownUsecaseStrings: []string{"sound_generation", "tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
50
core/gallery/importers/ace-step_test.go
Normal file
50
core/gallery/importers/ace-step_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ACEStepImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
// ACE-Step/ACE-Step-v1-3.5B is the reference public checkpoint for
|
||||
// the ACE-Step music generation model. Detection must match on the
|
||||
// repo name substring so third-party forks and quantised mirrors
|
||||
// (e.g. Serveurperso/ACE-Step-1.5-GGUF) route to the same backend.
|
||||
It("matches ACE-Step/ACE-Step-v1-3.5B (repo name contains ACE-Step)", func() {
|
||||
uri := "https://huggingface.co/ACE-Step/ACE-Step-v1-3.5B"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: ace-step"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("ACE-Step/ACE-Step-v1-3.5B"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=ace-step for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "ace-step"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: ace-step"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.ACEStepImporter{}
|
||||
Expect(imp.Name()).To(Equal("ace-step"))
|
||||
Expect(imp.Modality()).To(Equal("image"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
29
core/gallery/importers/ambiguity_asr_test.go
Normal file
29
core/gallery/importers/ambiguity_asr_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ASR ambiguity", func() {
|
||||
// pyannote/voice-activity-detection carries
|
||||
// pipeline_tag=automatic-speech-recognition but ships only a YAML
|
||||
// recipe — no ggml-*.bin, no .nemo, no Systran-style model.bin, no
|
||||
// tokenizer.json, no .onnx. None of the ASR importers should match and
|
||||
// none of the generic importers (vllm, transformers, llama-cpp, mlx,
|
||||
// diffusers) should match either. Because the modality is in the
|
||||
// ambiguous whitelist, DiscoverModelConfig must surface
|
||||
// ErrAmbiguousImport rather than a bare "no importer matched" error.
|
||||
It("returns ErrAmbiguousImport when ASR pipeline_tag is present but no importer matches", func() {
|
||||
uri := "https://huggingface.co/pyannote/voice-activity-detection"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
_, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, importers.ErrAmbiguousImport)).To(BeTrue(), "expected ErrAmbiguousImport, got: %v", err)
|
||||
})
|
||||
})
|
||||
34
core/gallery/importers/ambiguity_embeddings_test.go
Normal file
34
core/gallery/importers/ambiguity_embeddings_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Embeddings ambiguity", func() {
|
||||
// Qdrant/bm25 carries pipeline_tag="sentence-similarity" but ships
|
||||
// only config.json, README.md, .gitattributes, and per-language
|
||||
// stopword .txt files — no tokenizer.json (rules out vllm and
|
||||
// transformers), no modules.json / sentence_bert_config.json (rules
|
||||
// out sentencetransformers), no "reranker" / cross-encoder owner
|
||||
// (rules out rerankers), no rf-detr name (rules out rfdetr), no
|
||||
// snakers4 / silero_vad.onnx (rules out silero-vad), no .gguf
|
||||
// (rules out llama-cpp and stablediffusion-ggml), no mlx-community
|
||||
// owner (rules out mlx), no model_index.json / scheduler_config.json
|
||||
// (rules out diffusers). None of the ASR/TTS/image importers should
|
||||
// trip either. Because sentence-similarity is in the ambiguous
|
||||
// modality whitelist, DiscoverModelConfig must surface
|
||||
// ErrAmbiguousImport.
|
||||
It("returns ErrAmbiguousImport when sentence-similarity pipeline_tag is present but no importer matches", func() {
|
||||
uri := "https://huggingface.co/Qdrant/bm25"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
_, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, importers.ErrAmbiguousImport)).To(BeTrue(), "expected ErrAmbiguousImport, got: %v", err)
|
||||
})
|
||||
})
|
||||
31
core/gallery/importers/ambiguity_image_test.go
Normal file
31
core/gallery/importers/ambiguity_image_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Image ambiguity", func() {
|
||||
// h94/IP-Adapter-FaceID carries pipeline_tag="text-to-image" but ships
|
||||
// only .bin + .safetensors + README — no model_index.json /
|
||||
// scheduler_config.json (rules out diffusers), no .gguf (rules out
|
||||
// llama-cpp and stablediffusion-ggml), no tokenizer.json (rules out
|
||||
// vllm/transformers), owner is not mlx-community (rules out mlx), and
|
||||
// the repo owner/name contain no ace-step/flux/sd1.5/sdxl/sd3/
|
||||
// stable-diffusion arch token at the URI level — so none of the
|
||||
// Batch-3 Image/Video importers match either. Because text-to-image
|
||||
// is whitelisted as an ambiguous modality, DiscoverModelConfig must
|
||||
// surface ErrAmbiguousImport rather than a bare "no importer matched".
|
||||
It("returns ErrAmbiguousImport when text-to-image pipeline_tag is present but no importer matches", func() {
|
||||
uri := "https://huggingface.co/h94/IP-Adapter-FaceID"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
_, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, importers.ErrAmbiguousImport)).To(BeTrue(), "expected ErrAmbiguousImport, got: %v", err)
|
||||
})
|
||||
})
|
||||
32
core/gallery/importers/ambiguity_tts_test.go
Normal file
32
core/gallery/importers/ambiguity_tts_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("TTS ambiguity", func() {
|
||||
// nari-labs/Dia-1.6B carries pipeline_tag="text-to-speech" but ships
|
||||
// only config.json + *.pth + model.safetensors + preprocessor_config.json.
|
||||
// None of the Batch-2 TTS importers match (owner neither "suno" nor
|
||||
// "fishaudio" nor "OuteAI" nor "KittenML" nor "ResembleAI" nor "neuphonic"
|
||||
// nor "coqui"; repo name contains none of "bark", "outetts", "voxcpm",
|
||||
// "kokoro", "kitten-tts", "neutts", "chatterbox", "vibevoice"; no piper
|
||||
// onnx/onnx.json pair). None of the generic importers match either —
|
||||
// no tokenizer.json (rules out vllm/transformers), no .gguf (llama-cpp),
|
||||
// no mlx-community owner (mlx), no model_index.json/scheduler_config
|
||||
// (diffusers). Because the HF pipeline_tag is in the ambiguous
|
||||
// whitelist, DiscoverModelConfig must surface ErrAmbiguousImport.
|
||||
It("returns ErrAmbiguousImport when TTS pipeline_tag is present but no importer matches", func() {
|
||||
uri := "https://huggingface.co/nari-labs/Dia-1.6B"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
_, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, importers.ErrAmbiguousImport)).To(BeTrue(), "expected ErrAmbiguousImport, got: %v", err)
|
||||
})
|
||||
})
|
||||
124
core/gallery/importers/bark.go
Normal file
124
core/gallery/importers/bark.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &BarkImporter{}
|
||||
|
||||
// BarkImporter recognises Suno's Bark TTS models. The `suno` owner hosts a
|
||||
// handful of Bark variants (bark, bark-small, bark-v2-en, …) sharing the
|
||||
// "bark" prefix — narrow enough to detect without false positives from
|
||||
// other suno repos. preferences.backend="bark" overrides detection.
|
||||
//
|
||||
// NOTE: suno/bark ships a `speaker_embeddings/v2` subdirectory that hits a
|
||||
// pre-existing path-doubling bug in pkg/huggingface-api's recursive tree
|
||||
// listing (item.Path already carries the parent path, but the recursion
|
||||
// prepends the parent path again → 404). When ModelDetails fetching fails,
|
||||
// DiscoverModelConfig leaves HuggingFace nil. To keep detection robust,
|
||||
// matchURIOwnerRepo() falls back to parsing the raw URI for "suno/bark*"
|
||||
// so the importer still fires end-to-end.
|
||||
type BarkImporter struct{}
|
||||
|
||||
// matchBarkURI tolerates a nil ModelDetails (see note above) by extracting
|
||||
// the HF owner+repo portion directly from the raw URI.
|
||||
func matchBarkURI(uri string) bool {
|
||||
owner, repo, ok := HFOwnerRepoFromURI(uri)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(owner, "suno") && strings.HasPrefix(strings.ToLower(repo), "bark")
|
||||
}
|
||||
|
||||
func (i *BarkImporter) Name() string { return "bark" }
|
||||
func (i *BarkImporter) Modality() string { return "tts" }
|
||||
func (i *BarkImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *BarkImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "bark" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
if strings.EqualFold(details.HuggingFace.Author, "suno") {
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(repoName), "bark") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HF metadata may be absent when the recursive tree listing errors
|
||||
// (see type-level note). Fall back to URI parsing.
|
||||
return matchBarkURI(details.URI)
|
||||
}
|
||||
|
||||
func (i *BarkImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "bark",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/bark_test.go
Normal file
47
core/gallery/importers/bark_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("BarkImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches suno/bark (owner + repo name prefix)", func() {
|
||||
uri := "https://huggingface.co/suno/bark"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: bark"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("suno/bark"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=bark for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "bark"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: bark"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.BarkImporter{}
|
||||
Expect(imp.Name()).To(Equal("bark"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
110
core/gallery/importers/chatterbox.go
Normal file
110
core/gallery/importers/chatterbox.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &ChatterboxImporter{}
|
||||
|
||||
// ChatterboxImporter recognises Resemble AI's Chatterbox TTS. Detection
|
||||
// uses the `ResembleAI` owner or a "chatterbox" substring in the repo
|
||||
// name (covers the primary release plus community finetunes).
|
||||
// preferences.backend="chatterbox" overrides detection.
|
||||
type ChatterboxImporter struct{}
|
||||
|
||||
func (i *ChatterboxImporter) Name() string { return "chatterbox" }
|
||||
func (i *ChatterboxImporter) Modality() string { return "tts" }
|
||||
func (i *ChatterboxImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *ChatterboxImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "chatterbox" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
if strings.EqualFold(details.HuggingFace.Author, "ResembleAI") {
|
||||
return true
|
||||
}
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if strings.Contains(strings.ToLower(repoName), "chatterbox") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if strings.EqualFold(owner, "ResembleAI") || strings.Contains(strings.ToLower(repo), "chatterbox") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *ChatterboxImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "chatterbox",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/chatterbox_test.go
Normal file
47
core/gallery/importers/chatterbox_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ChatterboxImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches ResembleAI/chatterbox (owner)", func() {
|
||||
uri := "https://huggingface.co/ResembleAI/chatterbox"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: chatterbox"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("ResembleAI/chatterbox"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=chatterbox for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "chatterbox"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: chatterbox"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.ChatterboxImporter{}
|
||||
Expect(imp.Name()).To(Equal("chatterbox"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
99
core/gallery/importers/coqui.go
Normal file
99
core/gallery/importers/coqui.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &CoquiImporter{}
|
||||
|
||||
// CoquiImporter recognises Coqui AI's open-weight TTS releases (XTTS-v2,
|
||||
// YourTTS, the Tortoise port, etc). Detection is owner-scoped to `coqui`
|
||||
// — their HF org is the authoritative publisher for models that run on
|
||||
// the Coqui TTS Python runtime. preferences.backend="coqui" overrides.
|
||||
type CoquiImporter struct{}
|
||||
|
||||
func (i *CoquiImporter) Name() string { return "coqui" }
|
||||
func (i *CoquiImporter) Modality() string { return "tts" }
|
||||
func (i *CoquiImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *CoquiImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "coqui" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil && strings.EqualFold(details.HuggingFace.Author, "coqui") {
|
||||
return true
|
||||
}
|
||||
|
||||
if owner, _, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
return strings.EqualFold(owner, "coqui")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *CoquiImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "coqui",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/coqui_test.go
Normal file
47
core/gallery/importers/coqui_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("CoquiImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches coqui/XTTS-v2 (owner)", func() {
|
||||
uri := "https://huggingface.co/coqui/XTTS-v2"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: coqui"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("coqui/XTTS-v2"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=coqui for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "coqui"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: coqui"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.CoquiImporter{}
|
||||
Expect(imp.Name()).To(Equal("coqui"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -15,6 +15,10 @@ var _ Importer = &DiffuserImporter{}
|
||||
|
||||
type DiffuserImporter struct{}
|
||||
|
||||
func (i *DiffuserImporter) Name() string { return "diffusers" }
|
||||
func (i *DiffuserImporter) Modality() string { return "image" }
|
||||
func (i *DiffuserImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *DiffuserImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
|
||||
117
core/gallery/importers/faster-whisper.go
Normal file
117
core/gallery/importers/faster-whisper.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &FasterWhisperImporter{}
|
||||
|
||||
// FasterWhisperImporter recognises CTranslate2-packaged whisper checkpoints
|
||||
// (the format consumed by the faster-whisper runtime). The classic layout is
|
||||
// a flat directory with model.bin + config.json and an ASR pipeline_tag.
|
||||
//
|
||||
// We disambiguate from vanilla OpenAI whisper repos — which would otherwise
|
||||
// also hit the tokenizer.json path and get routed to transformers — by
|
||||
// requiring either the Systran owner (the upstream distributor) or the
|
||||
// string "faster-whisper" in the repo name. preferences.backend=
|
||||
// faster-whisper overrides detection.
|
||||
type FasterWhisperImporter struct{}
|
||||
|
||||
func (i *FasterWhisperImporter) Name() string { return "faster-whisper" }
|
||||
func (i *FasterWhisperImporter) Modality() string { return "asr" }
|
||||
func (i *FasterWhisperImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *FasterWhisperImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "faster-whisper" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !HasFile(details.HuggingFace.Files, "model.bin") {
|
||||
return false
|
||||
}
|
||||
if !HasFile(details.HuggingFace.Files, "config.json") {
|
||||
return false
|
||||
}
|
||||
if details.HuggingFace.PipelineTag != "automatic-speech-recognition" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Narrow to the faster-whisper distribution: Systran owner OR
|
||||
// "faster-whisper" in the repo name. Without this guard, any vanilla
|
||||
// whisper repo on HF would also match the file pair and ASR tag.
|
||||
if strings.EqualFold(details.HuggingFace.Author, "Systran") {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(strings.ToLower(details.HuggingFace.ModelID), "faster-whisper")
|
||||
}
|
||||
|
||||
func (i *FasterWhisperImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "faster-whisper",
|
||||
KnownUsecaseStrings: []string{"transcript"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
46
core/gallery/importers/faster-whisper_test.go
Normal file
46
core/gallery/importers/faster-whisper_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("FasterWhisperImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches Systran/faster-whisper-large-v3 (model.bin + config.json + ASR)", func() {
|
||||
uri := "https://huggingface.co/Systran/faster-whisper-large-v3"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: faster-whisper"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("Systran/faster-whisper-large-v3"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=faster-whisper for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "faster-whisper"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: faster-whisper"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.FasterWhisperImporter{}
|
||||
Expect(imp.Name()).To(Equal("faster-whisper"))
|
||||
Expect(imp.Modality()).To(Equal("asr"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
101
core/gallery/importers/fish-speech.go
Normal file
101
core/gallery/importers/fish-speech.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &FishSpeechImporter{}
|
||||
|
||||
// FishSpeechImporter recognises Fish Audio's open-weights TTS releases
|
||||
// (Fish Speech, S1/S2 series). The `fishaudio` owner is the canonical
|
||||
// publisher — scoping by owner avoids false positives from generic
|
||||
// safetensors+tokenizer packaging used elsewhere.
|
||||
// preferences.backend="fish-speech" overrides detection.
|
||||
type FishSpeechImporter struct{}
|
||||
|
||||
func (i *FishSpeechImporter) Name() string { return "fish-speech" }
|
||||
func (i *FishSpeechImporter) Modality() string { return "tts" }
|
||||
func (i *FishSpeechImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *FishSpeechImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "fish-speech" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil && strings.EqualFold(details.HuggingFace.Author, "fishaudio") {
|
||||
return true
|
||||
}
|
||||
// URI fallback for parity with other TTS importers when HF metadata
|
||||
// fetching fails (see BarkImporter note).
|
||||
if owner, _, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
return strings.EqualFold(owner, "fishaudio")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *FishSpeechImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "fish-speech",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/fish-speech_test.go
Normal file
47
core/gallery/importers/fish-speech_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("FishSpeechImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches fishaudio/s2-pro (owner = fishaudio)", func() {
|
||||
uri := "https://huggingface.co/fishaudio/s2-pro"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: fish-speech"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("fishaudio/s2-pro"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=fish-speech for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "fish-speech"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: fish-speech"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.FishSpeechImporter{}
|
||||
Expect(imp.Name()).To(Equal("fish-speech"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
91
core/gallery/importers/helpers.go
Normal file
91
core/gallery/importers/helpers.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
)
|
||||
|
||||
// HasFile returns true when any file in files has exactly the given basename.
|
||||
// Directory components in file.Path are ignored — a nested
|
||||
// "sub/dir/config.json" is considered a match for name = "config.json".
|
||||
func HasFile(files []hfapi.ModelFile, name string) bool {
|
||||
for _, f := range files {
|
||||
if filepath.Base(f.Path) == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasExtension returns true when any file has the given extension
|
||||
// (case-insensitive). ext must include the leading dot, e.g. ".onnx".
|
||||
func HasExtension(files []hfapi.ModelFile, ext string) bool {
|
||||
lower := strings.ToLower(ext)
|
||||
for _, f := range files {
|
||||
if strings.HasSuffix(strings.ToLower(f.Path), lower) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasONNX returns true when any file ends in .onnx (case-insensitive).
|
||||
func HasONNX(files []hfapi.ModelFile) bool {
|
||||
return HasExtension(files, ".onnx")
|
||||
}
|
||||
|
||||
// HasONNXConfigPair returns true when an .onnx file has an accompanying
|
||||
// "<same basename>.onnx.json" file. This is the piper voice packaging
|
||||
// convention, e.g. en_US-amy-medium.onnx + en_US-amy-medium.onnx.json.
|
||||
func HasONNXConfigPair(files []hfapi.ModelFile) bool {
|
||||
paths := make(map[string]struct{}, len(files))
|
||||
for _, f := range files {
|
||||
paths[strings.ToLower(f.Path)] = struct{}{}
|
||||
}
|
||||
for p := range paths {
|
||||
if !strings.HasSuffix(p, ".onnx") {
|
||||
continue
|
||||
}
|
||||
if _, ok := paths[p+".json"]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HFOwnerRepoFromURI extracts the "owner", "repo" pair from an HF URI.
|
||||
// Accepted prefixes: "https://huggingface.co/", "huggingface://", "hf://".
|
||||
// Returns ok=false when the URI is not an HF URI or is missing either
|
||||
// component. This exists so importers can fall back to URI-based matching
|
||||
// when pkg/huggingface-api's recursive tree listing errors out on repos
|
||||
// with nested subdirectories (a known pre-existing bug).
|
||||
func HFOwnerRepoFromURI(uri string) (owner, repo string, ok bool) {
|
||||
stripped := uri
|
||||
for _, pfx := range []string{"https://huggingface.co/", "huggingface://", "hf://"} {
|
||||
stripped = strings.TrimPrefix(stripped, pfx)
|
||||
}
|
||||
parts := strings.SplitN(stripped, "/", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return "", "", false
|
||||
}
|
||||
return parts[0], parts[1], true
|
||||
}
|
||||
|
||||
// HasGGMLFile returns true when any file matches "<prefix>*.bin", which is
|
||||
// the whisper.cpp packaging convention (e.g. "ggml-base.en.bin"). Both prefix
|
||||
// and suffix match is case-sensitive on prefix and case-insensitive on the
|
||||
// .bin extension.
|
||||
func HasGGMLFile(files []hfapi.ModelFile, prefix string) bool {
|
||||
for _, f := range files {
|
||||
name := filepath.Base(f.Path)
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(strings.ToLower(name), ".bin") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
89
core/gallery/importers/helpers_test.go
Normal file
89
core/gallery/importers/helpers_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
)
|
||||
|
||||
var _ = Describe("importer helpers", func() {
|
||||
mkFiles := func(paths ...string) []hfapi.ModelFile {
|
||||
out := make([]hfapi.ModelFile, 0, len(paths))
|
||||
for _, p := range paths {
|
||||
out = append(out, hfapi.ModelFile{Path: p})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
Describe("HasFile", func() {
|
||||
It("returns false for an empty slice", func() {
|
||||
Expect(importers.HasFile(nil, "config.json")).To(BeFalse())
|
||||
Expect(importers.HasFile([]hfapi.ModelFile{}, "config.json")).To(BeFalse())
|
||||
})
|
||||
It("matches on exact basename, ignoring directory components", func() {
|
||||
files := mkFiles("sub/dir/config.json", "other.txt")
|
||||
Expect(importers.HasFile(files, "config.json")).To(BeTrue())
|
||||
})
|
||||
It("does not match partial basenames", func() {
|
||||
files := mkFiles("sub/dir/myconfig.json")
|
||||
Expect(importers.HasFile(files, "config.json")).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("HasExtension", func() {
|
||||
It("matches case-insensitively", func() {
|
||||
files := mkFiles("model.ONNX", "other.txt")
|
||||
Expect(importers.HasExtension(files, ".onnx")).To(BeTrue())
|
||||
})
|
||||
It("returns false when no file has the extension", func() {
|
||||
Expect(importers.HasExtension(mkFiles("README.md"), ".onnx")).To(BeFalse())
|
||||
})
|
||||
It("handles empty slices gracefully", func() {
|
||||
Expect(importers.HasExtension(nil, ".onnx")).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("HasONNX", func() {
|
||||
It("is true when any file ends in .onnx", func() {
|
||||
Expect(importers.HasONNX(mkFiles("voice/en_US-amy-medium.onnx"))).To(BeTrue())
|
||||
})
|
||||
It("is false otherwise", func() {
|
||||
Expect(importers.HasONNX(mkFiles("model.bin"))).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("HasONNXConfigPair", func() {
|
||||
It("matches the piper .onnx + .onnx.json pair", func() {
|
||||
files := mkFiles(
|
||||
"en/en_US/amy/medium/en_US-amy-medium.onnx",
|
||||
"en/en_US/amy/medium/en_US-amy-medium.onnx.json",
|
||||
)
|
||||
Expect(importers.HasONNXConfigPair(files)).To(BeTrue())
|
||||
})
|
||||
It("requires the accompanying json to share the .onnx basename", func() {
|
||||
files := mkFiles("model.onnx", "config.json")
|
||||
Expect(importers.HasONNXConfigPair(files)).To(BeFalse())
|
||||
})
|
||||
It("returns false for a lone .onnx file", func() {
|
||||
files := mkFiles("model.onnx")
|
||||
Expect(importers.HasONNXConfigPair(files)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("HasGGMLFile", func() {
|
||||
It("finds ggml-prefixed .bin files", func() {
|
||||
files := mkFiles("ggml-base.en.bin", "README.md")
|
||||
Expect(importers.HasGGMLFile(files, "ggml-")).To(BeTrue())
|
||||
})
|
||||
It("requires both prefix and .bin suffix", func() {
|
||||
files := mkFiles("ggml-base.en.gguf")
|
||||
Expect(importers.HasGGMLFile(files, "ggml-")).To(BeFalse())
|
||||
})
|
||||
It("returns false when prefix does not match", func() {
|
||||
files := mkFiles("whisper.bin")
|
||||
Expect(importers.HasGGMLFile(files, "ggml-")).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -2,8 +2,10 @@ package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
@@ -15,7 +17,138 @@ import (
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
)
|
||||
|
||||
// ErrAmbiguousImport is returned when HuggingFace metadata hints at a known
|
||||
// modality (e.g. pipeline_tag: "automatic-speech-recognition") but no
|
||||
// importer's artefact-level detection matches the repository. Callers should
|
||||
// pass preferences.backend to disambiguate. Use errors.Is to match regardless
|
||||
// of wrapping — DiscoverModelConfig returns a typed AmbiguousImportError that
|
||||
// carries the detected modality + candidate backends, and whose Is() matches
|
||||
// this sentinel so legacy callers keep working.
|
||||
var ErrAmbiguousImport = errors.New("importer: ambiguous — specify preferences.backend")
|
||||
|
||||
// AmbiguousImportError is the concrete error DiscoverModelConfig returns when
|
||||
// it can't pick an importer automatically. It carries the importer-modality
|
||||
// key (e.g. "tts", "asr") and the list of candidate backend names so HTTP
|
||||
// consumers can render a picker without re-deriving the mapping from HF
|
||||
// pipeline_tag values.
|
||||
type AmbiguousImportError struct {
|
||||
// Modality is the importer modality key ("text", "asr", "tts", "image",
|
||||
// "embeddings", "reranker", "detection"). Pre-mapped from the HF
|
||||
// pipeline_tag so the UI doesn't have to.
|
||||
Modality string
|
||||
// Candidates is the list of backend names whose Modality() matches — a
|
||||
// subset of the importer registry plus AdditionalBackendsProvider
|
||||
// drop-ins.
|
||||
Candidates []string
|
||||
// URI is the original URI that triggered the ambiguity.
|
||||
URI string
|
||||
// PipelineTag is the raw HF pipeline_tag value as reported by the model
|
||||
// metadata — preserved for logging / debugging.
|
||||
PipelineTag string
|
||||
}
|
||||
|
||||
func (e *AmbiguousImportError) Error() string {
|
||||
return fmt.Sprintf("importer: ambiguous — detected modality %q (pipeline_tag=%q) for %s, candidates: %v",
|
||||
e.Modality, e.PipelineTag, e.URI, e.Candidates)
|
||||
}
|
||||
|
||||
// Is lets callers match with errors.Is(err, ErrAmbiguousImport) without caring
|
||||
// about the typed shape.
|
||||
func (e *AmbiguousImportError) Is(target error) bool {
|
||||
return target == ErrAmbiguousImport
|
||||
}
|
||||
|
||||
// ambiguousModalities enumerates the HF pipeline_tag values that are narrow
|
||||
// enough to be confident we should surface ambiguity instead of a generic
|
||||
// "no importer matched" error. Tags outside this whitelist keep the previous
|
||||
// behaviour (plain error) so we don't block uncommon-but-still-valid imports.
|
||||
// The mapped value is the importer modality key used to filter candidates.
|
||||
var ambiguousModalities = map[string]string{
|
||||
"automatic-speech-recognition": "asr",
|
||||
"text-to-speech": "tts",
|
||||
"sentence-similarity": "embeddings",
|
||||
"text-classification": "reranker",
|
||||
"object-detection": "detection",
|
||||
"text-to-image": "image",
|
||||
}
|
||||
|
||||
// PipelineTagToModality maps HF pipeline_tag strings to the importer modality
|
||||
// key used internally (and by /backends/known). Returns the modality + true
|
||||
// when the tag is in the ambiguous whitelist; "" + false otherwise.
|
||||
func PipelineTagToModality(pipelineTag string) (string, bool) {
|
||||
m, ok := ambiguousModalities[pipelineTag]
|
||||
return m, ok
|
||||
}
|
||||
|
||||
// CandidatesForModality returns the backend names whose importer modality
|
||||
// matches the requested key. Includes AdditionalBackendsProvider drop-ins so
|
||||
// entries like ik-llama-cpp surface for text modalities. Results are sorted
|
||||
// for deterministic ordering in API responses.
|
||||
func CandidatesForModality(modality string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
for _, imp := range defaultImporters {
|
||||
if imp.Modality() != modality {
|
||||
continue
|
||||
}
|
||||
seen[imp.Name()] = struct{}{}
|
||||
if host, ok := imp.(AdditionalBackendsProvider); ok {
|
||||
for _, extra := range host.AdditionalBackends() {
|
||||
if extra.Modality != modality {
|
||||
continue
|
||||
}
|
||||
seen[extra.Name] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
out := make([]string, 0, len(seen))
|
||||
for n := range seen {
|
||||
out = append(out, n)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
var defaultImporters = []Importer{
|
||||
// ASR (Batch 1)
|
||||
&WhisperImporter{},
|
||||
&MoonshineImporter{},
|
||||
&NemoImporter{},
|
||||
&FasterWhisperImporter{},
|
||||
&QwenASRImporter{},
|
||||
// TTS (Batch 2)
|
||||
&PiperImporter{},
|
||||
&BarkImporter{},
|
||||
&FishSpeechImporter{},
|
||||
&OutettsImporter{},
|
||||
&VoxCPMImporter{},
|
||||
&KokoroImporter{},
|
||||
&KittenTTSImporter{},
|
||||
&NeuTTSImporter{},
|
||||
&ChatterboxImporter{},
|
||||
&VibeVoiceImporter{},
|
||||
&CoquiImporter{},
|
||||
// Image/Video (Batch 3)
|
||||
&StableDiffusionGGMLImporter{},
|
||||
&ACEStepImporter{},
|
||||
// Text LLM (Batch 4) — VLLMOmniImporter must stay ahead of
|
||||
// VLLMImporter so Qwen Omni repos (which also carry tokenizer
|
||||
// files) route to vllm-omni rather than plain vllm.
|
||||
&VLLMOmniImporter{},
|
||||
// Embeddings / rerankers / detection / VAD (Batch 5)
|
||||
// SileroVADImporter first — unique filename signal, cannot collide.
|
||||
&SileroVADImporter{},
|
||||
// RerankersImporter must run before SentenceTransformers and
|
||||
// Transformers — some reranker repos ship modules.json and tokenizer
|
||||
// files that those importers would otherwise claim.
|
||||
&RerankersImporter{},
|
||||
// SentenceTransformersImporter must run before TransformersImporter:
|
||||
// sentence-transformers repos ship tokenizer.json which transformers
|
||||
// would otherwise claim.
|
||||
&SentenceTransformersImporter{},
|
||||
// RFDetrImporter must run before TransformersImporter — RF-DETR
|
||||
// checkpoints may carry tokenizer-adjacent artefacts.
|
||||
&RFDetrImporter{},
|
||||
// Existing
|
||||
&LlamaCPPImporter{},
|
||||
&MLXImporter{},
|
||||
&VLLMImporter{},
|
||||
@@ -32,6 +165,42 @@ type Details struct {
|
||||
type Importer interface {
|
||||
Match(details Details) bool
|
||||
Import(details Details) (gallery.ModelConfig, error)
|
||||
// Name is the canonical backend name (e.g. "llama-cpp"). Used by
|
||||
// /backends/known to populate the import form dropdown.
|
||||
Name() string
|
||||
// Modality is the backend's primary modality ("text", "asr", "tts",
|
||||
// "image", "embeddings", "reranker", "detection", "vad"). Used for
|
||||
// grouping in the UI.
|
||||
Modality() string
|
||||
// AutoDetects is true when Match() can fire without an explicit
|
||||
// preferences.backend. Preference-only entries surface as
|
||||
// AutoDetect=false in /backends/known.
|
||||
AutoDetects() bool
|
||||
}
|
||||
|
||||
// KnownBackendEntry describes one backend advertised by an importer.
|
||||
// Importers that host drop-in replacements (e.g. llama-cpp hosting
|
||||
// ik-llama-cpp and turboquant) return additional entries via
|
||||
// AdditionalBackendsProvider so the endpoint can surface them without
|
||||
// registering separate importers.
|
||||
type KnownBackendEntry struct {
|
||||
Name string
|
||||
Modality string
|
||||
Description string
|
||||
}
|
||||
|
||||
// AdditionalBackendsProvider is implemented by importers that advertise
|
||||
// drop-in replacements sharing their Match/Import logic. The entries
|
||||
// appear in /backends/known with AutoDetect=false since they are
|
||||
// preference-only.
|
||||
type AdditionalBackendsProvider interface {
|
||||
AdditionalBackends() []KnownBackendEntry
|
||||
}
|
||||
|
||||
// Registry returns the list of registered importers. Callers must not
|
||||
// mutate the returned slice.
|
||||
func Registry() []Importer {
|
||||
return defaultImporters
|
||||
}
|
||||
|
||||
func hasYAMLExtension(uri string) bool {
|
||||
@@ -115,6 +284,19 @@ func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.Model
|
||||
}
|
||||
}
|
||||
if !importerMatched {
|
||||
// When HuggingFace metadata hints at a known, narrow modality but no
|
||||
// importer matched the artefacts, surface an explicit ambiguity so the
|
||||
// caller knows to pass preferences.backend rather than silently guess.
|
||||
if hfDetails != nil && hfDetails.PipelineTag != "" {
|
||||
if modality, known := ambiguousModalities[hfDetails.PipelineTag]; known {
|
||||
return gallery.ModelConfig{}, &AmbiguousImportError{
|
||||
Modality: modality,
|
||||
Candidates: CandidatesForModality(modality),
|
||||
URI: uri,
|
||||
PipelineTag: hfDetails.PipelineTag,
|
||||
}
|
||||
}
|
||||
}
|
||||
return gallery.ModelConfig{}, fmt.Errorf("no importer matched for %s", uri)
|
||||
}
|
||||
return modelConfig, nil
|
||||
|
||||
@@ -2,6 +2,7 @@ package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -201,6 +202,99 @@ var _ = Describe("DiscoverModelConfig", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("ErrAmbiguousImport sentinel", func() {
|
||||
It("is defined so callers can match with errors.Is", func() {
|
||||
Expect(importers.ErrAmbiguousImport).ToNot(BeNil())
|
||||
// Wrapping-sanity: fmt.Errorf("%w", err) preserves identity.
|
||||
wrapped := fmt.Errorf("context: %w", importers.ErrAmbiguousImport)
|
||||
Expect(errors.Is(wrapped, importers.ErrAmbiguousImport)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("surfaces modality and candidates on the typed error for HTTP consumers", func() {
|
||||
// TTS fixture — pipeline_tag=text-to-speech, no importer matches.
|
||||
uri := "https://huggingface.co/nari-labs/Dia-1.6B"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
_, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, importers.ErrAmbiguousImport)).To(BeTrue())
|
||||
|
||||
var amb *importers.AmbiguousImportError
|
||||
Expect(errors.As(err, &amb)).To(BeTrue(), "expected AmbiguousImportError, got: %v", err)
|
||||
Expect(amb.Modality).To(Equal("tts"))
|
||||
Expect(amb.Candidates).To(ContainElements("piper", "bark", "kokoro"))
|
||||
Expect(amb.Candidates).ToNot(ContainElement("llama-cpp"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
// These tests drive the /backends/known endpoint: each importer must
|
||||
// self-describe its canonical name, primary modality, and whether it
|
||||
// can auto-detect without an explicit preference.
|
||||
It("Registry returns all default importers", func() {
|
||||
registry := importers.Registry()
|
||||
Expect(registry).ToNot(BeEmpty())
|
||||
names := make([]string, 0, len(registry))
|
||||
for _, imp := range registry {
|
||||
names = append(names, imp.Name())
|
||||
}
|
||||
Expect(names).To(ContainElements("llama-cpp", "mlx", "vllm", "transformers", "diffusers"))
|
||||
})
|
||||
|
||||
It("LlamaCPPImporter exposes name/modality/autodetect", func() {
|
||||
imp := &importers.LlamaCPPImporter{}
|
||||
Expect(imp.Name()).To(Equal("llama-cpp"))
|
||||
Expect(imp.Modality()).To(Equal("text"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("MLXImporter exposes name/modality/autodetect", func() {
|
||||
imp := &importers.MLXImporter{}
|
||||
Expect(imp.Name()).To(Equal("mlx"))
|
||||
Expect(imp.Modality()).To(Equal("text"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("VLLMImporter exposes name/modality/autodetect", func() {
|
||||
imp := &importers.VLLMImporter{}
|
||||
Expect(imp.Name()).To(Equal("vllm"))
|
||||
Expect(imp.Modality()).To(Equal("text"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("TransformersImporter exposes name/modality/autodetect", func() {
|
||||
imp := &importers.TransformersImporter{}
|
||||
Expect(imp.Name()).To(Equal("transformers"))
|
||||
Expect(imp.Modality()).To(Equal("text"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("DiffuserImporter exposes name/modality/autodetect", func() {
|
||||
imp := &importers.DiffuserImporter{}
|
||||
Expect(imp.Name()).To(Equal("diffusers"))
|
||||
Expect(imp.Modality()).To(Equal("image"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("LlamaCPPImporter advertises drop-in replacements", func() {
|
||||
imp := &importers.LlamaCPPImporter{}
|
||||
provider, ok := any(imp).(importers.AdditionalBackendsProvider)
|
||||
Expect(ok).To(BeTrue(), "LlamaCPPImporter must implement AdditionalBackendsProvider")
|
||||
|
||||
extras := provider.AdditionalBackends()
|
||||
names := make([]string, 0, len(extras))
|
||||
modalities := make([]string, 0, len(extras))
|
||||
for _, e := range extras {
|
||||
names = append(names, e.Name)
|
||||
modalities = append(modalities, e.Modality)
|
||||
}
|
||||
Expect(names).To(ContainElements("ik-llama-cpp", "turboquant"))
|
||||
for _, m := range modalities {
|
||||
Expect(m).To(Equal("text"))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("with invalid JSON preferences", func() {
|
||||
It("should return error when JSON is invalid even if URI matches", func() {
|
||||
uri := "https://example.com/model.gguf"
|
||||
|
||||
109
core/gallery/importers/kitten-tts.go
Normal file
109
core/gallery/importers/kitten-tts.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &KittenTTSImporter{}
|
||||
|
||||
// KittenTTSImporter recognises KittenML's kitten-tts releases. Detection
|
||||
// uses the `KittenML` owner or a "kitten-tts" substring in the repo name
|
||||
// for third-party mirrors. preferences.backend="kitten-tts" overrides.
|
||||
type KittenTTSImporter struct{}
|
||||
|
||||
func (i *KittenTTSImporter) Name() string { return "kitten-tts" }
|
||||
func (i *KittenTTSImporter) Modality() string { return "tts" }
|
||||
func (i *KittenTTSImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *KittenTTSImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "kitten-tts" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
if strings.EqualFold(details.HuggingFace.Author, "KittenML") {
|
||||
return true
|
||||
}
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if strings.Contains(strings.ToLower(repoName), "kitten-tts") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if strings.EqualFold(owner, "KittenML") || strings.Contains(strings.ToLower(repo), "kitten-tts") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *KittenTTSImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "kitten-tts",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/kitten-tts_test.go
Normal file
47
core/gallery/importers/kitten-tts_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("KittenTTSImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches KittenML/kitten-tts-nano-0.1 (owner + repo name)", func() {
|
||||
uri := "https://huggingface.co/KittenML/kitten-tts-nano-0.1"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: kitten-tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("KittenML/kitten-tts-nano-0.1"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=kitten-tts for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "kitten-tts"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: kitten-tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.KittenTTSImporter{}
|
||||
Expect(imp.Name()).To(Equal("kitten-tts"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
110
core/gallery/importers/kokoro.go
Normal file
110
core/gallery/importers/kokoro.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &KokoroImporter{}
|
||||
|
||||
// KokoroImporter recognises hexgrad's Kokoro TTS family (Kokoro-82M,
|
||||
// Kokoro-82M-v1.1-zh, …). The repo name carries "Kokoro" and the weights
|
||||
// ship as a PyTorch .pth/.pt — pairing the two keeps us from claiming the
|
||||
// quantised GGUF mirrors (which llama-cpp handles) or the ONNX exports
|
||||
// (which the pref-only `kokoros` Rust runtime handles).
|
||||
// preferences.backend="kokoro" overrides detection; preferences.backend
|
||||
// ="kokoros" deliberately does *not* trigger this importer (see test).
|
||||
type KokoroImporter struct{}
|
||||
|
||||
func (i *KokoroImporter) Name() string { return "kokoro" }
|
||||
func (i *KokoroImporter) Modality() string { return "tts" }
|
||||
func (i *KokoroImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *KokoroImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Explicit "kokoro" overrides. "kokoros" is intentionally distinct:
|
||||
// comparing against the exact string prevents pref-only kokoros
|
||||
// requests from hijacking this importer.
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "kokoro" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace == nil {
|
||||
return false
|
||||
}
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(repoName), "kokoro") {
|
||||
return false
|
||||
}
|
||||
// Require a PyTorch checkpoint to disambiguate from ONNX-only or
|
||||
// GGUF-only mirrors that route to other backends.
|
||||
return HasExtension(details.HuggingFace.Files, ".pth") || HasExtension(details.HuggingFace.Files, ".pt")
|
||||
}
|
||||
|
||||
func (i *KokoroImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "kokoro",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
63
core/gallery/importers/kokoro_test.go
Normal file
63
core/gallery/importers/kokoro_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("KokoroImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches hexgrad/Kokoro-82M (repo name + .pth)", func() {
|
||||
uri := "https://huggingface.co/hexgrad/Kokoro-82M"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: kokoro"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("hexgrad/Kokoro-82M"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=kokoro for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "kokoro"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: kokoro"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("kokoros preference disambiguation", func() {
|
||||
It("does not hijack preferences.backend=kokoros (pref-only)", func() {
|
||||
// The kokoros Rust runtime is pref-only (listed in
|
||||
// knownPrefOnlyBackends). The autodetect path for the kokoro
|
||||
// importer must NOT fire when the user explicitly selects the
|
||||
// kokoros backend for an arbitrary URI — if it did, DiscoverModelConfig
|
||||
// would incorrectly emit backend=kokoro.
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "kokoros"}`)
|
||||
|
||||
_, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
// kokoros has no importer, so discovery should not match anything.
|
||||
Expect(err).To(HaveOccurred(), "kokoros is pref-only — DiscoverModelConfig should not match any importer")
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.KokoroImporter{}
|
||||
Expect(imp.Name()).To(Equal("kokoro"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -3,7 +3,6 @@ package importers
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
@@ -11,14 +10,34 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
"github.com/mudler/xlog"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &LlamaCPPImporter{}
|
||||
var (
|
||||
_ Importer = &LlamaCPPImporter{}
|
||||
_ AdditionalBackendsProvider = &LlamaCPPImporter{}
|
||||
)
|
||||
|
||||
type LlamaCPPImporter struct{}
|
||||
|
||||
func (i *LlamaCPPImporter) Name() string { return "llama-cpp" }
|
||||
func (i *LlamaCPPImporter) Modality() string { return "text" }
|
||||
func (i *LlamaCPPImporter) AutoDetects() bool { return true }
|
||||
|
||||
// AdditionalBackends advertises drop-in replacements that share the
|
||||
// llama-cpp detection logic. They are preference-only: selecting one
|
||||
// from the import form swaps the emitted YAML backend field but reuses
|
||||
// the llama-cpp Match/Import pipeline.
|
||||
func (i *LlamaCPPImporter) AdditionalBackends() []KnownBackendEntry {
|
||||
return []KnownBackendEntry{
|
||||
{Name: "ik-llama-cpp", Modality: "text", Description: "GGUF drop-in replacement for llama-cpp with ik-quants"},
|
||||
{Name: "turboquant", Modality: "text", Description: "GGUF drop-in replacement for llama-cpp with TurboQuant optimizations"},
|
||||
{Name: "buun-llama-cpp", Modality: "text", Description: "GGUF drop-in replacement for llama-cpp with DFlash speculative decoding and TurboQuant/TCQ KV-cache quantization"},
|
||||
}
|
||||
}
|
||||
|
||||
func (i *LlamaCPPImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
@@ -101,12 +120,25 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
|
||||
embeddings, _ := preferencesMap["embeddings"].(string)
|
||||
|
||||
// Honour drop-in replacement preferences. Only the curated names
|
||||
// advertised via AdditionalBackends() are accepted; anything else
|
||||
// (including "llama-cpp" itself, or an unknown value) keeps the
|
||||
// default backend field so arbitrary input can't leak through. See
|
||||
// the AdditionalBackends method for the canonical list.
|
||||
backend := "llama-cpp"
|
||||
if b, ok := preferencesMap["backend"].(string); ok {
|
||||
switch b {
|
||||
case "ik-llama-cpp", "turboquant", "buun-llama-cpp":
|
||||
backend = b
|
||||
}
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
Options: []string{"use_jinja:true"},
|
||||
Backend: "llama-cpp",
|
||||
Backend: backend,
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
UseTokenizerTemplate: true,
|
||||
},
|
||||
@@ -172,59 +204,34 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
},
|
||||
}
|
||||
case details.HuggingFace != nil:
|
||||
// We want to:
|
||||
// Get first the chosen quants that match filenames
|
||||
// OR the first mmproj/gguf file found
|
||||
var lastMMProjFile *gallery.File
|
||||
var lastGGUFFile *gallery.File
|
||||
foundPreferedQuant := false
|
||||
foundPreferedMMprojQuant := false
|
||||
|
||||
for _, file := range details.HuggingFace.Files {
|
||||
// Get the mmproj prefered quants
|
||||
if strings.Contains(strings.ToLower(file.Path), "mmproj") {
|
||||
lastMMProjFile = &gallery.File{
|
||||
URI: file.URL,
|
||||
Filename: filepath.Join("llama-cpp", "mmproj", name, filepath.Base(file.Path)),
|
||||
SHA256: file.SHA256,
|
||||
}
|
||||
if slices.ContainsFunc(mmprojQuantsList, func(quant string) bool {
|
||||
return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
|
||||
}) {
|
||||
cfg.Files = append(cfg.Files, *lastMMProjFile)
|
||||
foundPreferedMMprojQuant = true
|
||||
}
|
||||
} else if strings.HasSuffix(strings.ToLower(file.Path), "gguf") {
|
||||
lastGGUFFile = &gallery.File{
|
||||
URI: file.URL,
|
||||
Filename: filepath.Join("llama-cpp", "models", name, filepath.Base(file.Path)),
|
||||
SHA256: file.SHA256,
|
||||
}
|
||||
// get the files of the prefered quants
|
||||
if slices.ContainsFunc(quants, func(quant string) bool {
|
||||
return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
|
||||
}) {
|
||||
foundPreferedQuant = true
|
||||
cfg.Files = append(cfg.Files, *lastGGUFFile)
|
||||
}
|
||||
// Split the repo listing into mmproj vs plain GGUF files, then group
|
||||
// shards so every multi-part GGUF (llama.cpp `-NNNNN-of-MMMMM.gguf`
|
||||
// pattern) is treated as one logical selection candidate. The
|
||||
// previous implementation picked files one at a time, so sharded
|
||||
// models ended up with only the last part referenced in the gallery
|
||||
// entry — useless to llama.cpp, which needs shard 1 and the whole
|
||||
// set to load a split model.
|
||||
var mmprojFiles, ggufFiles []hfapi.ModelFile
|
||||
for _, f := range details.HuggingFace.Files {
|
||||
lowerPath := strings.ToLower(f.Path)
|
||||
switch {
|
||||
case strings.Contains(lowerPath, "mmproj"):
|
||||
mmprojFiles = append(mmprojFiles, f)
|
||||
case strings.HasSuffix(lowerPath, ".gguf"):
|
||||
ggufFiles = append(ggufFiles, f)
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure to add at least one file if not already present (which is the latest one)
|
||||
if lastMMProjFile != nil && !foundPreferedMMprojQuant {
|
||||
if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
|
||||
return f.Filename == lastMMProjFile.Filename
|
||||
}) {
|
||||
cfg.Files = append(cfg.Files, *lastMMProjFile)
|
||||
}
|
||||
}
|
||||
mmprojGroups := hfapi.GroupShards(mmprojFiles)
|
||||
ggufGroups := hfapi.GroupShards(ggufFiles)
|
||||
|
||||
if lastGGUFFile != nil && !foundPreferedQuant {
|
||||
if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
|
||||
return f.Filename == lastGGUFFile.Filename
|
||||
}) {
|
||||
cfg.Files = append(cfg.Files, *lastGGUFFile)
|
||||
}
|
||||
// Emit the model group first so cfg.Files[0] is the model — callers
|
||||
// and tests rely on the model file preceding any mmproj companion.
|
||||
if group := pickPreferredGroup(ggufGroups, quants); group != nil {
|
||||
appendShardGroup(&cfg, *group, filepath.Join("llama-cpp", "models", name))
|
||||
}
|
||||
if group := pickPreferredGroup(mmprojGroups, mmprojQuantsList); group != nil {
|
||||
appendShardGroup(&cfg, *group, filepath.Join("llama-cpp", "mmproj", name))
|
||||
}
|
||||
|
||||
// Find first mmproj file and configure it in the config file
|
||||
@@ -236,7 +243,9 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
break
|
||||
}
|
||||
|
||||
// Find first non-mmproj file and configure it in the config file
|
||||
// Find first non-mmproj file and configure it in the config file.
|
||||
// For sharded models this is shard 1 — llama.cpp's split loader
|
||||
// discovers the remaining shards by filename pattern from there.
|
||||
for _, file := range cfg.Files {
|
||||
if strings.Contains(strings.ToLower(file.Filename), "mmproj") {
|
||||
continue
|
||||
@@ -262,3 +271,48 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// pickPreferredGroup walks the preference list in priority order and returns
|
||||
// the first group whose base filename contains any preference. When nothing
|
||||
// matches, the last group wins — this preserves the historical "if the user
|
||||
// asked for a quant we don't have, fall back to whatever's available"
|
||||
// behaviour, lifted to whole shard sets.
|
||||
func pickPreferredGroup(groups []hfapi.ShardGroup, prefs []string) *hfapi.ShardGroup {
|
||||
if len(groups) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, pref := range prefs {
|
||||
lower := strings.ToLower(pref)
|
||||
for i := range groups {
|
||||
if strings.Contains(strings.ToLower(groups[i].Base), lower) {
|
||||
return &groups[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
return &groups[len(groups)-1]
|
||||
}
|
||||
|
||||
// appendShardGroup copies every shard of group into cfg.Files under dest,
|
||||
// skipping any entry whose target filename is already present so repeated
|
||||
// calls (e.g. the rare case of mmproj + model picking the same group)
|
||||
// don't produce duplicates.
|
||||
func appendShardGroup(cfg *gallery.ModelConfig, group hfapi.ShardGroup, dest string) {
|
||||
for _, f := range group.Files {
|
||||
target := filepath.Join(dest, filepath.Base(f.Path))
|
||||
duplicate := false
|
||||
for _, existing := range cfg.Files {
|
||||
if existing.Filename == target {
|
||||
duplicate = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if duplicate {
|
||||
continue
|
||||
}
|
||||
cfg.Files = append(cfg.Files, gallery.File{
|
||||
URI: f.URL,
|
||||
Filename: target,
|
||||
SHA256: f.SHA256,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/mudler/LocalAI/core/gallery/importers"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@@ -129,4 +130,290 @@ var _ = Describe("LlamaCPPImporter", func() {
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("drop-in backend preferences", func() {
|
||||
// baseline: no preference keeps backend: llama-cpp and the file
|
||||
// layout that downstream assertions depend on.
|
||||
It("emits backend: llama-cpp when no backend preference is set", func() {
|
||||
details := Details{
|
||||
URI: "https://example.com/my-model.gguf",
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
|
||||
It("swaps the emitted backend to ik-llama-cpp when preferred", func() {
|
||||
preferences := json.RawMessage(`{"backend": "ik-llama-cpp"}`)
|
||||
details := Details{
|
||||
URI: "https://example.com/my-model.gguf",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: ik-llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).NotTo(ContainSubstring("backend: llama-cpp\n"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
// Model path must remain identical to the llama-cpp baseline.
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(len(modelConfig.Files)).To(Equal(1))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"))
|
||||
})
|
||||
|
||||
It("swaps the emitted backend to turboquant when preferred", func() {
|
||||
preferences := json.RawMessage(`{"backend": "turboquant"}`)
|
||||
details := Details{
|
||||
URI: "https://example.com/my-model.gguf",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: turboquant"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).NotTo(ContainSubstring("backend: llama-cpp\n"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(len(modelConfig.Files)).To(Equal(1))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"))
|
||||
})
|
||||
|
||||
It("swaps the emitted backend to buun-llama-cpp when preferred", func() {
|
||||
preferences := json.RawMessage(`{"backend": "buun-llama-cpp"}`)
|
||||
details := Details{
|
||||
URI: "https://example.com/my-model.gguf",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: buun-llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).NotTo(ContainSubstring("backend: llama-cpp\n"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(len(modelConfig.Files)).To(Equal(1))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"))
|
||||
})
|
||||
|
||||
It("keeps backend: llama-cpp for unknown backend preferences", func() {
|
||||
// Unknown backend values must not leak into the emitted YAML —
|
||||
// we only honour the two curated drop-in replacements.
|
||||
preferences := json.RawMessage(`{"backend": "something-weird"}`)
|
||||
details := Details{
|
||||
URI: "https://example.com/my-model.gguf",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Import from HuggingFace file listing", func() {
|
||||
// These tests exercise the HF branch of Import() without touching
|
||||
// the network — they construct a fake *hfapi.ModelDetails and
|
||||
// assert the emitted gallery entry directly. Historically the HF
|
||||
// branch was only covered by live-API integration specs in
|
||||
// importers_test.go; anything that happened in between (shard
|
||||
// grouping, quant fallback) had no unit-level regression net.
|
||||
|
||||
const repoBase = "https://huggingface.co/acme/example-GGUF/resolve/main/"
|
||||
|
||||
hfFile := func(path, sha string) hfapi.ModelFile {
|
||||
return hfapi.ModelFile{
|
||||
Path: path,
|
||||
SHA256: sha,
|
||||
URL: repoBase + path,
|
||||
}
|
||||
}
|
||||
|
||||
withHF := func(preferences string, files ...hfapi.ModelFile) Details {
|
||||
d := Details{
|
||||
URI: "https://huggingface.co/acme/example-GGUF",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "acme/example-GGUF",
|
||||
Files: files,
|
||||
},
|
||||
}
|
||||
if preferences != "" {
|
||||
d.Preferences = json.RawMessage(preferences)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
It("picks the preferred quant in a single-file repo", func() {
|
||||
details := withHF(`{"name":"example","quantizations":"Q4_K_M"}`,
|
||||
hfFile("model-Q4_K_M.gguf", "aaa"),
|
||||
hfFile("model-Q3_K_M.gguf", "bbb"),
|
||||
hfFile("README.md", ""),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(HaveLen(1), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/example/model-Q4_K_M.gguf"))
|
||||
Expect(modelConfig.Files[0].URI).To(Equal(repoBase + "model-Q4_K_M.gguf"))
|
||||
Expect(modelConfig.Files[0].SHA256).To(Equal("aaa"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/example/model-Q4_K_M.gguf"))
|
||||
})
|
||||
|
||||
It("falls back to the last group when no preference matches", func() {
|
||||
// Default preference is q4_k_m; the repo has only Q8_0 and
|
||||
// Q3_K_M. The old implementation would emit exactly the last
|
||||
// file seen — this test pins the fallback behaviour so the
|
||||
// group-level fallback keeps matching the historical intent.
|
||||
details := withHF(`{"name":"example"}`,
|
||||
hfFile("model-Q8_0.gguf", "aaa"),
|
||||
hfFile("model-Q3_K_M.gguf", "bbb"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(HaveLen(1), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/example/model-Q3_K_M.gguf"))
|
||||
})
|
||||
|
||||
It("emits all shards of a multi-part GGUF and points Model at shard 1", func() {
|
||||
// Regression for PR #9510: unsloth/Kimi-K2.6-GGUF ships 14
|
||||
// Q8_K_XL shards. Default prefs are q4_k_m; none match, so the
|
||||
// fallback must take the whole shard group (not just the last
|
||||
// shard) and the config's `model:` must point at shard 1 so
|
||||
// llama.cpp's split loader can walk the rest.
|
||||
files := make([]hfapi.ModelFile, 0, 14)
|
||||
// Deliberately add shards out of order to prove sorting works.
|
||||
for _, idx := range []int{7, 1, 14, 2, 8, 3, 9, 4, 10, 5, 11, 6, 12, 13} {
|
||||
files = append(files, hfFile(
|
||||
fmt.Sprintf("Kimi-K2.6-UD-Q8_K_XL-%05d-of-00014.gguf", idx),
|
||||
fmt.Sprintf("sha-%02d", idx),
|
||||
))
|
||||
}
|
||||
|
||||
details := withHF(`{"name":"Kimi-K2.6-GGUF"}`, files...)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(HaveLen(14), fmt.Sprintf("%+v", modelConfig))
|
||||
|
||||
// All 14 shards must be present, in order, under the models dir.
|
||||
for i := 1; i <= 14; i++ {
|
||||
expected := fmt.Sprintf("llama-cpp/models/Kimi-K2.6-GGUF/Kimi-K2.6-UD-Q8_K_XL-%05d-of-00014.gguf", i)
|
||||
Expect(modelConfig.Files[i-1].Filename).To(Equal(expected))
|
||||
Expect(modelConfig.Files[i-1].SHA256).To(Equal(fmt.Sprintf("sha-%02d", i)))
|
||||
}
|
||||
|
||||
// The configured model path must be shard 1 — this is the file
|
||||
// llama.cpp's split loader expects to be pointed at.
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring(
|
||||
"model: llama-cpp/models/Kimi-K2.6-GGUF/Kimi-K2.6-UD-Q8_K_XL-00001-of-00014.gguf",
|
||||
))
|
||||
})
|
||||
|
||||
It("emits all shards of the preferred quant alongside an mmproj", func() {
|
||||
// Sharded multimodal model: mmproj is single-file, the text
|
||||
// model is split in 3 parts and matches the user preference.
|
||||
details := withHF(`{"name":"VL-GGUF","quantizations":"Q4_K_M","mmproj_quantizations":"F16"}`,
|
||||
hfFile("mmproj-F16.gguf", "mm"),
|
||||
hfFile("model-Q4_K_M-00001-of-00003.gguf", "p1"),
|
||||
hfFile("model-Q4_K_M-00002-of-00003.gguf", "p2"),
|
||||
hfFile("model-Q4_K_M-00003-of-00003.gguf", "p3"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(HaveLen(4), fmt.Sprintf("%+v", modelConfig))
|
||||
|
||||
// Model shards come first, in order, then the mmproj.
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/VL-GGUF/model-Q4_K_M-00001-of-00003.gguf"))
|
||||
Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/models/VL-GGUF/model-Q4_K_M-00002-of-00003.gguf"))
|
||||
Expect(modelConfig.Files[2].Filename).To(Equal("llama-cpp/models/VL-GGUF/model-Q4_K_M-00003-of-00003.gguf"))
|
||||
Expect(modelConfig.Files[3].Filename).To(Equal("llama-cpp/mmproj/VL-GGUF/mmproj-F16.gguf"))
|
||||
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring(
|
||||
"model: llama-cpp/models/VL-GGUF/model-Q4_K_M-00001-of-00003.gguf",
|
||||
))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring(
|
||||
"mmproj: llama-cpp/mmproj/VL-GGUF/mmproj-F16.gguf",
|
||||
))
|
||||
})
|
||||
|
||||
It("does not emit duplicate entries when called repeatedly on the same group", func() {
|
||||
// Guards appendShardGroup's dedup: if a shard ends up in the
|
||||
// Files slice via more than one code path (e.g. a future
|
||||
// refactor that processes mmproj and model candidates through
|
||||
// the same path), we must not accidentally duplicate downloads.
|
||||
details := withHF(`{"name":"dup","quantizations":"Q4_K_M,Q4_K_M"}`,
|
||||
hfFile("model-Q4_K_M.gguf", "aaa"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(HaveLen(1), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/dup/model-Q4_K_M.gguf"))
|
||||
})
|
||||
|
||||
It("ignores non-gguf files in the repo listing", func() {
|
||||
// Real HF repos ship READMEs, tokenizer json, images, etc.
|
||||
// Only .gguf entries should surface as downloadable files.
|
||||
details := withHF(`{"name":"noise"}`,
|
||||
hfFile("README.md", ""),
|
||||
hfFile("config.json", ""),
|
||||
hfFile("logo.png", ""),
|
||||
hfFile("model-Q4_K_M.gguf", "aaa"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(HaveLen(1))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/noise/model-Q4_K_M.gguf"))
|
||||
})
|
||||
|
||||
It("produces no files when the repo contains no .gguf", func() {
|
||||
details := withHF(`{"name":"empty"}`,
|
||||
hfFile("README.md", ""),
|
||||
hfFile("config.json", ""),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("AdditionalBackends", func() {
|
||||
It("advertises ik-llama-cpp, turboquant, and buun-llama-cpp as drop-in replacements", func() {
|
||||
entries := importer.AdditionalBackends()
|
||||
|
||||
names := make([]string, 0, len(entries))
|
||||
byName := map[string]importers.KnownBackendEntry{}
|
||||
for _, e := range entries {
|
||||
names = append(names, e.Name)
|
||||
byName[e.Name] = e
|
||||
}
|
||||
Expect(names).To(ConsistOf("ik-llama-cpp", "turboquant", "buun-llama-cpp"))
|
||||
|
||||
ik := byName["ik-llama-cpp"]
|
||||
Expect(ik.Modality).To(Equal("text"))
|
||||
Expect(ik.Description).NotTo(BeEmpty())
|
||||
|
||||
tq := byName["turboquant"]
|
||||
Expect(tq.Modality).To(Equal("text"))
|
||||
Expect(tq.Description).NotTo(BeEmpty())
|
||||
|
||||
bn := byName["buun-llama-cpp"]
|
||||
Expect(bn.Modality).To(Equal("text"))
|
||||
Expect(bn.Description).NotTo(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -18,8 +18,11 @@ import (
|
||||
//
|
||||
// Detection order:
|
||||
// 1. GGUF files (*.gguf) — uses llama-cpp backend
|
||||
// 2. LoRA adapter (adapter_config.json) — uses transformers backend with lora_adapter
|
||||
// 3. Merged model (*.safetensors or pytorch_model*.bin + config.json) — uses transformers backend
|
||||
// 2. whisper.cpp ggml-*.bin — uses whisper backend
|
||||
// 3. silero_vad*.onnx — uses silero-vad backend
|
||||
// 4. piper .onnx + .onnx.json pair — uses piper backend
|
||||
// 5. LoRA adapter (adapter_config.json) — uses transformers backend with lora_adapter
|
||||
// 6. Merged model (*.safetensors or pytorch_model*.bin + config.json) — uses transformers backend
|
||||
func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
|
||||
// Make paths relative to the models directory (parent of dirPath)
|
||||
// so config YAML stays portable.
|
||||
@@ -51,7 +54,48 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// 2. LoRA adapter: look for adapter_config.json
|
||||
// 2. whisper.cpp ggml-*.bin models
|
||||
if ggmlFile := findFileByPrefixSuffix(dirPath, "ggml-", ".bin"); ggmlFile != "" {
|
||||
xlog.Info("ImportLocalPath: detected whisper.cpp GGML model", "path", ggmlFile)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "whisper",
|
||||
KnownUsecaseStrings: []string{"transcript"},
|
||||
}
|
||||
cfg.Model = relPath(ggmlFile)
|
||||
cfg.Description = buildDescription(dirPath, "Whisper GGML")
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// 3/4. Single .onnx file in dir — silero-vad or piper depending on signals.
|
||||
if onnxFile := findSingleONNX(dirPath); onnxFile != "" {
|
||||
base := filepath.Base(onnxFile)
|
||||
lowerBase := strings.ToLower(base)
|
||||
switch {
|
||||
case strings.HasPrefix(lowerBase, "silero"):
|
||||
xlog.Info("ImportLocalPath: detected Silero VAD model", "path", onnxFile)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "silero-vad",
|
||||
}
|
||||
cfg.Model = relPath(onnxFile)
|
||||
cfg.Description = buildDescription(dirPath, "Silero VAD")
|
||||
return cfg, nil
|
||||
case fileExists(onnxFile + ".json"):
|
||||
xlog.Info("ImportLocalPath: detected Piper voice", "path", onnxFile)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "piper",
|
||||
}
|
||||
cfg.Model = relPath(onnxFile)
|
||||
cfg.Description = buildDescription(dirPath, "Piper voice")
|
||||
return cfg, nil
|
||||
}
|
||||
// Lone .onnx without piper config and without silero prefix: fall
|
||||
// through — no reliable backend to assign.
|
||||
}
|
||||
|
||||
// 5. LoRA adapter: look for adapter_config.json
|
||||
|
||||
adapterConfigPath := filepath.Join(dirPath, "adapter_config.json")
|
||||
if fileExists(adapterConfigPath) {
|
||||
@@ -116,6 +160,41 @@ func findGGUF(dir string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// findFileByPrefixSuffix returns the path to the first file in dir matching
|
||||
// both prefix (case-sensitive) and suffix (case-insensitive), or "".
|
||||
func findFileByPrefixSuffix(dir, prefix, suffix string) string {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
lowerSuffix := strings.ToLower(suffix)
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if strings.HasPrefix(name, prefix) && strings.HasSuffix(strings.ToLower(name), lowerSuffix) {
|
||||
return filepath.Join(dir, name)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// findSingleONNX returns the path to the first .onnx file found in dir, or "".
|
||||
// Subdirectories are ignored — callers expect a flat layout.
|
||||
func findSingleONNX(dir string) string {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), ".onnx") {
|
||||
return filepath.Join(dir, e.Name())
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// readBaseModel reads the base model name from adapter_config.json or export_metadata.json.
|
||||
func readBaseModel(dirPath string) string {
|
||||
// Try adapter_config.json → base_model_name_or_path (TRL writes this)
|
||||
|
||||
@@ -117,6 +117,47 @@ var _ = Describe("ImportLocalPath", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("Whisper ggml-*.bin detection", func() {
|
||||
It("maps ggml-base.en.bin to the whisper backend", func() {
|
||||
modelDir := filepath.Join(tmpDir, "whisper-base")
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "ggml-base.en.bin"), []byte("fake"), 0644)).To(Succeed())
|
||||
|
||||
cfg, err := importers.ImportLocalPath(modelDir, "whisper-base")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Backend).To(Equal("whisper"))
|
||||
Expect(cfg.Model).To(ContainSubstring("ggml-base.en.bin"))
|
||||
Expect(cfg.KnownUsecaseStrings).To(ContainElement("transcript"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Piper ONNX + ONNX config detection", func() {
|
||||
It("maps the .onnx + .onnx.json pair to the piper backend", func() {
|
||||
modelDir := filepath.Join(tmpDir, "piper-amy")
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "en_US-amy-medium.onnx"), []byte("fake"), 0644)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "en_US-amy-medium.onnx.json"), []byte("{}"), 0644)).To(Succeed())
|
||||
|
||||
cfg, err := importers.ImportLocalPath(modelDir, "piper-amy")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Backend).To(Equal("piper"))
|
||||
Expect(cfg.Model).To(ContainSubstring("en_US-amy-medium.onnx"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Silero VAD detection", func() {
|
||||
It("maps silero_vad.onnx to the silero-vad backend", func() {
|
||||
modelDir := filepath.Join(tmpDir, "silero")
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "silero_vad.onnx"), []byte("fake"), 0644)).To(Succeed())
|
||||
|
||||
cfg, err := importers.ImportLocalPath(modelDir, "silero")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Backend).To(Equal("silero-vad"))
|
||||
Expect(cfg.Model).To(ContainSubstring("silero_vad.onnx"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("fallback", func() {
|
||||
It("returns error for empty directory", func() {
|
||||
modelDir := filepath.Join(tmpDir, "empty")
|
||||
|
||||
@@ -15,6 +15,10 @@ var _ Importer = &MLXImporter{}
|
||||
|
||||
type MLXImporter struct{}
|
||||
|
||||
func (i *MLXImporter) Name() string { return "mlx" }
|
||||
func (i *MLXImporter) Modality() string { return "text" }
|
||||
func (i *MLXImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *MLXImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
|
||||
109
core/gallery/importers/moonshine.go
Normal file
109
core/gallery/importers/moonshine.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &MoonshineImporter{}
|
||||
|
||||
// MoonshineImporter recognises the UsefulSensors Moonshine ASR models, which
|
||||
// ship as ONNX artefacts under HF repositories owned by "UsefulSensors".
|
||||
// Detection combines the owner and a .onnx file presence check so we don't
|
||||
// accidentally match other UsefulSensors projects that might not host ASR
|
||||
// weights. preferences.backend="moonshine" overrides detection.
|
||||
type MoonshineImporter struct{}
|
||||
|
||||
func (i *MoonshineImporter) Name() string { return "moonshine" }
|
||||
func (i *MoonshineImporter) Modality() string { return "asr" }
|
||||
func (i *MoonshineImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *MoonshineImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "moonshine" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace == nil {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(details.HuggingFace.Author, "UsefulSensors") {
|
||||
return false
|
||||
}
|
||||
// Accept either on-disk .onnx (the canonical Moonshine packaging) or an
|
||||
// ASR pipeline_tag on the metadata. The latter covers the transformers/
|
||||
// safetensors-only sibling repos (moonshine-tiny, moonshine-base, …)
|
||||
// that still route to the moonshine runtime.
|
||||
if HasONNX(details.HuggingFace.Files) {
|
||||
return true
|
||||
}
|
||||
return details.HuggingFace.PipelineTag == "automatic-speech-recognition"
|
||||
}
|
||||
|
||||
func (i *MoonshineImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
// Prefer the canonical HF repo path ("owner/repo") so downstream
|
||||
// runtime tooling can resolve the model regardless of how the user
|
||||
// spelled the URI (hf://, https://huggingface.co/, etc.).
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "moonshine",
|
||||
KnownUsecaseStrings: []string{"transcript"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
48
core/gallery/importers/moonshine_test.go
Normal file
48
core/gallery/importers/moonshine_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("MoonshineImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches UsefulSensors/moonshine-tiny (owner + ASR pipeline_tag)", func() {
|
||||
uri := "https://huggingface.co/UsefulSensors/moonshine-tiny"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: moonshine"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("transcript"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
// Model should reference the HF repo path.
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("UsefulSensors/moonshine-tiny"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=moonshine for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "moonshine"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: moonshine"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.MoonshineImporter{}
|
||||
Expect(imp.Name()).To(Equal("moonshine"))
|
||||
Expect(imp.Modality()).To(Equal("asr"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
99
core/gallery/importers/nemo.go
Normal file
99
core/gallery/importers/nemo.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &NemoImporter{}
|
||||
|
||||
// NemoImporter matches NVIDIA NeMo ASR checkpoints, which always ship as
|
||||
// single-file ".nemo" archives under NVIDIA-owned HF repositories. Combining
|
||||
// owner=nvidia with the .nemo extension is narrow enough to avoid picking up
|
||||
// the unrelated NVIDIA LLM repos that only carry safetensors weights.
|
||||
// preferences.backend="nemo" overrides detection.
|
||||
type NemoImporter struct{}
|
||||
|
||||
func (i *NemoImporter) Name() string { return "nemo" }
|
||||
func (i *NemoImporter) Modality() string { return "asr" }
|
||||
func (i *NemoImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *NemoImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "nemo" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace == nil {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(details.HuggingFace.Author, "nvidia") {
|
||||
return false
|
||||
}
|
||||
return HasExtension(details.HuggingFace.Files, ".nemo")
|
||||
}
|
||||
|
||||
func (i *NemoImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "nemo",
|
||||
KnownUsecaseStrings: []string{"transcript"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/nemo_test.go
Normal file
47
core/gallery/importers/nemo_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("NemoImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches nvidia/parakeet-tdt-0.6b-v3 (owner + .nemo file)", func() {
|
||||
uri := "https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: nemo"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("transcript"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("nvidia/parakeet-tdt-0.6b-v3"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=nemo for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "nemo"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: nemo"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.NemoImporter{}
|
||||
Expect(imp.Name()).To(Equal("nemo"))
|
||||
Expect(imp.Modality()).To(Equal("asr"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
110
core/gallery/importers/neutts.go
Normal file
110
core/gallery/importers/neutts.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &NeuTTSImporter{}
|
||||
|
||||
// NeuTTSImporter recognises Neuphonic's NeuTTS releases. Detection uses
|
||||
// "neutts" (case-insensitive) substring in the repo name or the
|
||||
// `neuphonic` owner — covers both the primary "neutts-air" release and
|
||||
// community mirrors. preferences.backend="neutts" overrides detection.
|
||||
type NeuTTSImporter struct{}
|
||||
|
||||
func (i *NeuTTSImporter) Name() string { return "neutts" }
|
||||
func (i *NeuTTSImporter) Modality() string { return "tts" }
|
||||
func (i *NeuTTSImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *NeuTTSImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "neutts" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
if strings.EqualFold(details.HuggingFace.Author, "neuphonic") {
|
||||
return true
|
||||
}
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if strings.Contains(strings.ToLower(repoName), "neutts") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if strings.EqualFold(owner, "neuphonic") || strings.Contains(strings.ToLower(repo), "neutts") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *NeuTTSImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "neutts",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/neutts_test.go
Normal file
47
core/gallery/importers/neutts_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("NeuTTSImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches neuphonic/neutts-air (owner)", func() {
|
||||
uri := "https://huggingface.co/neuphonic/neutts-air"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: neutts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("neuphonic/neutts-air"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=neutts for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "neutts"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: neutts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.NeuTTSImporter{}
|
||||
Expect(imp.Name()).To(Equal("neutts"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user