mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-19 14:17:21 -04:00
Compare commits
18 Commits
issue-9478
...
docs/readm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d9d7b5c29b | ||
|
|
f877942d97 | ||
|
|
f5eb13d3c2 | ||
|
|
c1f923b2bc | ||
|
|
ed648b3b4e | ||
|
|
3ce5248126 | ||
|
|
04f1a0285d | ||
|
|
181ebb6df4 | ||
|
|
1c59165d63 | ||
|
|
eb00d9b178 | ||
|
|
2068b6f43c | ||
|
|
eb01c77214 | ||
|
|
bb4fda6f0e | ||
|
|
f0c92610a1 | ||
|
|
bbeacf140d | ||
|
|
6820ec468f | ||
|
|
20baec77ab | ||
|
|
d16f19f1eb |
@@ -8,6 +8,7 @@ Create the backend directory under the appropriate location:
|
||||
- **Python backends**: `backend/python/<backend-name>/`
|
||||
- **Go backends**: `backend/go/<backend-name>/`
|
||||
- **C++ backends**: `backend/cpp/<backend-name>/`
|
||||
- **Rust backends**: `backend/rust/<backend-name>/`
|
||||
|
||||
For Python backends, you'll typically need:
|
||||
- `backend.py` - Main gRPC server implementation
|
||||
@@ -18,9 +19,22 @@ For Python backends, you'll typically need:
|
||||
- `run.sh` - Runtime script
|
||||
- `test.py` / `test.sh` - Test files
|
||||
|
||||
For Rust backends, you'll typically need (see `backend/rust/kokoros/` as a reference):
|
||||
- `Cargo.toml` - Crate manifest; depend on the upstream project as a submodule under `sources/`
|
||||
- `build.rs` - Invokes `tonic_build` to generate gRPC stubs from `backend/backend.proto` (use the `BACKEND_PROTO_PATH` env var so the Makefile can inject the canonical copy)
|
||||
- `src/` - The gRPC server implementation (implement `Backend` via `tonic`)
|
||||
- `Makefile` - Copies `backend.proto` into the crate, runs `cargo build --release`, then `package.sh`
|
||||
- `package.sh` - Uses `ldd` to bundle the binary's dynamic deps and `ld.so` into `package/lib/`
|
||||
- `run.sh` - Sets `LD_LIBRARY_PATH`/`SSL_CERT_DIR` and execs the binary via the bundled `lib/ld.so`
|
||||
- `sources/<UpstreamProject>/` - Git submodule with the upstream Rust crate
|
||||
|
||||
## 2. Add Build Configurations to `.github/workflows/backend.yml`
|
||||
|
||||
Add build matrix entries for each platform/GPU type you want to support. Look at similar backends (e.g., `chatterbox`, `faster-whisper`) for reference.
|
||||
Add build matrix entries for each platform/GPU type you want to support. Look at similar backends for reference — `chatterbox`/`faster-whisper` for Python, `piper`/`silero-vad` for Go, `kokoros` for Rust.
|
||||
|
||||
**Without an entry here no image is ever built or pushed, and the gallery entry in `backend/index.yaml` will point at a tag that does not exist.** The `dockerfile:` field must point at `./backend/Dockerfile.<lang>` matching the language bucket from step 1 (e.g. `Dockerfile.python`, `Dockerfile.golang`, `Dockerfile.rust`). The `tag-suffix` must match the `uri:` in the corresponding `backend/index.yaml` image entry exactly.
|
||||
|
||||
If you add a new language bucket, `scripts/changed-backends.js` also needs a branch in `inferBackendPath` so PR change-detection routes file edits correctly.
|
||||
|
||||
**Placement in file:**
|
||||
- CPU builds: Add after other CPU builds (e.g., after `cpu-chatterbox`)
|
||||
@@ -56,24 +70,28 @@ Add `backends/<backend-name>` to the `.NOTPARALLEL` line (around line 2) to prev
|
||||
|
||||
**Step 4b: Add to `prepare-test-extra`**
|
||||
|
||||
Add the backend to the `prepare-test-extra` target (around line 312) to prepare it for testing:
|
||||
Add the backend to the `prepare-test-extra` target to prepare it for testing. Use the path matching your language bucket (`backend/python/`, `backend/go/`, `backend/rust/`, …):
|
||||
|
||||
```makefile
|
||||
prepare-test-extra: protogen-python
|
||||
...
|
||||
$(MAKE) -C backend/python/<backend-name>
|
||||
$(MAKE) -C backend/<lang>/<backend-name>
|
||||
```
|
||||
|
||||
For Rust backends the target is usually the crate build target itself (e.g. `$(MAKE) -C backend/rust/<backend-name> <backend-name>-grpc`) so the binary is in place before `test` runs.
|
||||
|
||||
**Step 4c: Add to `test-extra`**
|
||||
|
||||
Add the backend to the `test-extra` target (around line 319) to run its tests:
|
||||
Add the backend to the `test-extra` target to run its tests — applies to Go and Rust backends too, not only Python:
|
||||
|
||||
```makefile
|
||||
test-extra: prepare-test-extra
|
||||
...
|
||||
$(MAKE) -C backend/python/<backend-name> test
|
||||
$(MAKE) -C backend/<lang>/<backend-name> test
|
||||
```
|
||||
|
||||
Each backend's own `Makefile` should define a `test` target so this line works regardless of language. Integration tests that need large model downloads should be gated behind an env var (see `backend/rust/kokoros/`'s `KOKOROS_MODEL_PATH` pattern) so CI only runs unit tests.
|
||||
|
||||
**Step 4d: Add Backend Definition**
|
||||
|
||||
Add a backend definition variable in the backend definitions section (around line 428-457). The format depends on the backend type:
|
||||
@@ -93,6 +111,13 @@ BACKEND_<BACKEND_NAME> = <backend-name>|python|./backend|false|true
|
||||
BACKEND_<BACKEND_NAME> = <backend-name>|golang|.|false|true
|
||||
```
|
||||
|
||||
**For Rust backends**:
|
||||
```makefile
|
||||
BACKEND_<BACKEND_NAME> = <backend-name>|rust|.|false|true
|
||||
```
|
||||
|
||||
The language field (`python`/`golang`/`rust`/…) must match a `backend/Dockerfile.<lang>` file.
|
||||
|
||||
**Step 4e: Generate Docker Build Target**
|
||||
|
||||
Add an eval call to generate the docker-build target (around line 480-501):
|
||||
@@ -153,6 +178,29 @@ ls /tmp/check # expect the bundled .so files + symlinks
|
||||
|
||||
Then boot it inside a fresh `ubuntu:24.04` (which intentionally does *not* have the lib installed) to confirm it actually loads from the backend dir.
|
||||
|
||||
## Importer integration
|
||||
|
||||
When you add a new backend, you MUST also make it importable via the model import form (`/import-model`). The import form dropdown is sourced dynamically from `GET /backends/known` — it reads the importer registry at `core/gallery/importers/importers.go`, so the steps below are the ONLY way to make your backend show up.
|
||||
|
||||
Required steps:
|
||||
|
||||
1. **If your backend has unambiguous detection signals** (unique file extension, HF `pipeline_tag`, unique repo name pattern, unique artefact like `modules.json`):
|
||||
- Create an importer file at `core/gallery/importers/<backend>.go` following the Match/Import pattern in `llama-cpp.go`.
|
||||
- Register it in `importers.go:defaultImporters` in **specificity order** — more specific detectors must appear BEFORE more generic ones (e.g. `sentencetransformers` before `transformers`, `stablediffusion-ggml` before `llama-cpp`, `vllm-omni` before `vllm`). First match wins.
|
||||
2. **If your backend is a drop-in replacement** (same artefacts as another backend, e.g. `ik-llama-cpp` and `turboquant` both consume GGUF the same way `llama-cpp` does):
|
||||
- Do NOT create a new importer. Extend the existing importer's `Import()` to swap the emitted `backend:` field when `preferences.backend` matches. See `llama-cpp.go` for the pattern.
|
||||
3. **If your backend has no reliable auto-detect signal** (preference-only — e.g. `sglang`, `tinygrad`, `whisperx`):
|
||||
- Do NOT create an importer. Instead add the backend name to the curated pref-only slice in `core/http/endpoints/localai/backend.go` that feeds `/backends/known`. A single line addition.
|
||||
4. **Always** add a table-driven test in `core/gallery/importers/importers_test.go` (Ginkgo/Gomega):
|
||||
- Use a real public HuggingFace repo URI as the test fixture (existing tests already hit the live HF API — follow that pattern).
|
||||
- Cover detection (auto-match without preferences), preference-override (explicit `backend:` in preferences wins), and — if the backend's modality has a common `pipeline_tag` but ambiguous artefacts — an ambiguity test asserting `errors.Is(err, importers.ErrAmbiguousImport)`.
|
||||
|
||||
Rules of thumb:
|
||||
|
||||
- When in doubt, lean pref-only. A wrong auto-detect is worse than a forced preference.
|
||||
- Never silently emit a modality mismatch (e.g. emit `llama-cpp` for a TTS repo because `.gguf` is present). Return `ErrAmbiguousImport` instead.
|
||||
- Registration order is the single most common source of bugs. Check by running `go test ./core/gallery/importers/...` — the existing suite will fail if you've shadowed a pre-existing detector.
|
||||
|
||||
## 6. Example: Adding a Python Backend
|
||||
|
||||
For reference, when `moonshine` was added:
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
This guide covers how to add new API endpoints and properly integrate them with the auth/permissions system.
|
||||
|
||||
> **Before you ship a new endpoint or capability surface**, re-read the [checklist at the bottom of this file](#checklist). LocalAI advertises its feature surface in several independent places — miss any one of them and clients/admins/UI won't know the endpoint exists.
|
||||
|
||||
## Architecture overview
|
||||
|
||||
Authentication and authorization flow through three layers:
|
||||
@@ -234,6 +236,66 @@ Use these HTTP status codes:
|
||||
|
||||
If your endpoint should be tracked for usage (token counts, request counts), add the `usageMiddleware` to its middleware chain. See `core/http/middleware/usage.go` and how it's applied in `routes/openai.go`.
|
||||
|
||||
## Advertising surfaces — where to register a new capability
|
||||
|
||||
Beyond routing and auth, LocalAI publishes its capability surface in **four independent places**. When you add an endpoint — especially one introducing a net-new capability like a new media type or a new auth-gated feature — you must update every relevant surface. These aren't optional: missing them means the endpoint works but is invisible to clients, admins, and the UI.
|
||||
|
||||
### 1. Swagger `@Tags` annotation (mandatory)
|
||||
|
||||
Every handler needs a swagger block so the endpoint appears in `/swagger/index.html` and in the `/api/instructions` output. The `@Tags` value is what groups the endpoint into a capability area:
|
||||
|
||||
```go
|
||||
// MyEndpoint does X.
|
||||
// @Summary Do X.
|
||||
// @Tags my-capability
|
||||
// @Param request body schema.MyRequest true "payload"
|
||||
// @Success 200 {object} schema.MyResponse "Response"
|
||||
// @Router /v1/my-endpoint [post]
|
||||
func MyEndpoint(...) echo.HandlerFunc { ... }
|
||||
```
|
||||
|
||||
Use an existing tag when the endpoint extends an existing area (e.g. `audio`, `images`, `face-recognition`). Create a new tag only when the endpoint introduces a genuinely new capability surface — and in that case, also register it in step 2.
|
||||
|
||||
After adding endpoints, regenerate the embedded spec so the runtime serves it:
|
||||
|
||||
```bash
|
||||
make protogen-go # ensures gRPC codegen is fresh first
|
||||
make swagger # regenerates swagger/swagger.json
|
||||
```
|
||||
|
||||
### 2. `/api/instructions` registry (for new capability areas)
|
||||
|
||||
`core/http/endpoints/localai/api_instructions.go` defines `instructionDefs` — a lightweight, machine-readable index of capability areas that groups swagger endpoints by tag. It's the primary discovery surface for agents and SDKs ("what can this server do?").
|
||||
|
||||
**When to update:** only when adding a new capability area (a new swagger tag). Existing-tag additions automatically surface without any change here.
|
||||
|
||||
Add an entry to `instructionDefs`:
|
||||
|
||||
```go
|
||||
{
|
||||
Name: "my-capability", // URL segment at /api/instructions/my-capability
|
||||
Description: "Short sentence describing the capability",
|
||||
Tags: []string{"my-capability"}, // must match swagger @Tags
|
||||
Intro: "Optional gotcha/context that isn't in the swagger descriptions (caveats, defaults, cross-references to other endpoints).",
|
||||
},
|
||||
```
|
||||
|
||||
Also bump the expected-length count in `api_instructions_test.go` and add the name to the `ContainElements` assertion.
|
||||
|
||||
### 3. `capabilities.js` symbol (for new model-config FLAG_* flags)
|
||||
|
||||
If your feature needs a new `FLAG_*` usecase flag in `core/config/model_config.go` (so users can filter gallery models by it, and so `/v1/models` surfaces it), also declare the matching symbol in `core/http/react-ui/src/utils/capabilities.js`:
|
||||
|
||||
```js
|
||||
export const CAP_MY_CAPABILITY = 'FLAG_MY_CAPABILITY'
|
||||
```
|
||||
|
||||
React pages that want to filter the ModelSelector by capability import this symbol. Declare it even if you're not building the UI page yet — the declaration keeps the Go/JS vocabularies in sync.
|
||||
|
||||
### 4. `docs/content/` (user-facing documentation)
|
||||
|
||||
A new capability deserves its own page under `docs/content/features/`, plus cross-links from related features and an entry in `docs/content/whats-new.md`. See the pattern used by `face-recognition.md` / `object-detection.md`.
|
||||
|
||||
## Path protection rules
|
||||
|
||||
The global auth middleware classifies paths as API paths or non-API paths:
|
||||
@@ -248,12 +310,23 @@ If you add endpoints under a new top-level path prefix, add it to `isAPIPath()`
|
||||
|
||||
When adding a new endpoint:
|
||||
|
||||
**Routing & auth**
|
||||
- [ ] Handler in `core/http/endpoints/`
|
||||
- [ ] Route registered in appropriate `core/http/routes/` file
|
||||
- [ ] Auth level chosen: public / standard / admin / feature-gated
|
||||
- [ ] If feature-gated: constant in `permissions.go`, metadata in `features.go`, middleware in `app.go`
|
||||
- [ ] Entry added to `RouteFeatureRegistry` in `core/http/auth/features.go` (one row per route/method — all /v1/* routes gate through this, not per-route middleware)
|
||||
- [ ] If new feature: constant in `permissions.go`, added to the right slice (`APIFeatures` default-ON / `AgentFeatures` default-OFF), metadata in `features.go` `*FeatureMetas()`
|
||||
- [ ] If feature uses group middleware: wired in `core/http/app.go` and passed to the route registration function
|
||||
- [ ] If new path prefix: added to `isAPIPath()` in `middleware.go`
|
||||
- [ ] If OpenAI-compatible: entry in `RouteFeatureRegistry`
|
||||
- [ ] If token-counting: `usageMiddleware` added to middleware chain
|
||||
- [ ] Error responses use `schema.ErrorResponse` format
|
||||
|
||||
**Advertising surfaces (easy to miss — see the [Advertising surfaces](#advertising-surfaces--where-to-register-a-new-capability) section)**
|
||||
- [ ] Swagger block on the handler: `@Summary`, `@Tags`, `@Param`, `@Success`, `@Router`
|
||||
- [ ] If new capability area (new swagger tag): entry in `instructionDefs` in `core/http/endpoints/localai/api_instructions.go` + test count bumped in `api_instructions_test.go`
|
||||
- [ ] If new `FLAG_*` usecase flag: matching `CAP_*` symbol exported from `core/http/react-ui/src/utils/capabilities.js`
|
||||
- [ ] `docs/content/features/<feature>.md` created; cross-links from related feature pages; entry in `docs/content/whats-new.md`
|
||||
|
||||
**Quality**
|
||||
- [ ] Error responses use `schema.ErrorResponse` format (or `echo.NewHTTPError` with a mapped gRPC status — see the `mapBackendError` helper in `core/http/endpoints/localai/images.go`)
|
||||
- [ ] Tests cover both authenticated and unauthenticated access
|
||||
- [ ] Swagger regenerated (`make swagger`) if you changed any `@Router`/`@Tags`/`@Param` annotation
|
||||
|
||||
@@ -42,6 +42,12 @@ trim_trailing_whitespace = false
|
||||
|
||||
Use `github.com/mudler/xlog` for logging which has the same API as slog.
|
||||
|
||||
## Go tests
|
||||
|
||||
All Go tests — including backend tests — must use [Ginkgo](https://onsi.github.io/ginkgo/) (v2) with Gomega matchers, not the stdlib `testing` package with `t.Run` / `t.Errorf`. A test file should register a suite with `RegisterFailHandler(Fail)` in a `TestXxx(t *testing.T)` bootstrap and use `Describe`/`Context`/`It` blocks for the actual cases. Look at any existing `*_test.go` under `core/` or `pkg/` for a template.
|
||||
|
||||
Do not mix styles within a package. If you are extending tests in a package that already uses Ginkgo, keep using Ginkgo. If you find stdlib-style Go tests in the tree, treat them as tech debt to be migrated rather than as a pattern to follow.
|
||||
|
||||
## Documentation
|
||||
|
||||
The project documentation is located in `docs/content`. When adding new features or changing existing functionality, it is crucial to update the documentation to reflect these changes. This helps users understand how to use the new capabilities and ensures the documentation stays relevant.
|
||||
|
||||
68
.github/workflows/backend.yml
vendored
68
.github/workflows/backend.yml
vendored
@@ -711,6 +711,32 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-insightface'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "insightface"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-speaker-recognition'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "speaker-recognition"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -2584,6 +2610,20 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# kokoros (Rust TTS)
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-kokoros'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "kokoros"
|
||||
dockerfile: "./backend/Dockerfile.rust"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# local-store
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -2612,6 +2652,34 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# insightface (face recognition)
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-insightface'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "insightface"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# speaker-recognition (voice/speaker biometrics)
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-speaker-recognition'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "speaker-recognition"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
|
||||
2
.github/workflows/backend_build.yml
vendored
2
.github/workflows/backend_build.yml
vendored
@@ -108,6 +108,8 @@ jobs:
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
|
||||
- name: Release space from worker
|
||||
if: inputs.runs-on == 'ubuntu-latest'
|
||||
|
||||
54
.github/workflows/test-extra.yml
vendored
54
.github/workflows/test-extra.yml
vendored
@@ -38,6 +38,8 @@ jobs:
|
||||
qwen3-tts-cpp: ${{ steps.detect.outputs.qwen3-tts-cpp }}
|
||||
voxtral: ${{ steps.detect.outputs.voxtral }}
|
||||
kokoros: ${{ steps.detect.outputs.kokoros }}
|
||||
insightface: ${{ steps.detect.outputs.insightface }}
|
||||
speaker-recognition: ${{ steps.detect.outputs.speaker-recognition }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
@@ -751,3 +753,55 @@ jobs:
|
||||
- name: Test kokoros
|
||||
run: |
|
||||
make -C backend/rust/kokoros test
|
||||
tests-insightface-grpc:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.insightface == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
make build-essential curl unzip ca-certificates git tar
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.26.0'
|
||||
- name: Free disk space
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /opt/ghc /usr/local/lib/android /opt/hostedtoolcache/CodeQL || true
|
||||
df -h
|
||||
- name: Build insightface backend image and run both model configurations
|
||||
run: |
|
||||
make test-extra-backend-insightface-all
|
||||
tests-speaker-recognition-grpc:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.speaker-recognition == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
make build-essential curl ca-certificates git tar
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.26.0'
|
||||
- name: Free disk space
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /opt/ghc /usr/local/lib/android /opt/hostedtoolcache/CodeQL || true
|
||||
df -h
|
||||
- name: Build speaker-recognition backend image and run the ECAPA-TDNN configuration
|
||||
run: |
|
||||
make test-extra-backend-speaker-recognition-all
|
||||
|
||||
@@ -19,7 +19,7 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
||||
|------|-------------|
|
||||
| [.agents/ai-coding-assistants.md](.agents/ai-coding-assistants.md) | Policy for AI-assisted contributions — licensing, DCO, attribution |
|
||||
| [.agents/building-and-testing.md](.agents/building-and-testing.md) | Building the project, running tests, Docker builds for specific platforms |
|
||||
| [.agents/adding-backends.md](.agents/adding-backends.md) | Adding a new backend (Python, Go, or C++) — full step-by-step checklist |
|
||||
| [.agents/adding-backends.md](.agents/adding-backends.md) | Adding a new backend (Python, Go, or C++) — full step-by-step checklist, including importer integration (the `/import-model` dropdown is server-driven from `GET /backends/known`) |
|
||||
| [.agents/coding-style.md](.agents/coding-style.md) | Code style, editorconfig, logging, documentation conventions |
|
||||
| [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing |
|
||||
| [.agents/vllm-backend.md](.agents/vllm-backend.md) | Working on the vLLM / vLLM-omni backends — native parsers, ChatDelta, CPU build, libnuma packaging, backend hooks |
|
||||
@@ -34,5 +34,6 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
||||
- **Go style**: Prefer `any` over `interface{}`
|
||||
- **Comments**: Explain *why*, not *what*
|
||||
- **Docs**: Update `docs/content/` when adding features or changing config
|
||||
- **New API endpoints**: LocalAI advertises its capability surface in several independent places — swagger `@Tags`, `/api/instructions` registry, auth `RouteFeatureRegistry`, React UI `capabilities.js`, docs. Read [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) and follow its checklist — missing any surface means clients, admins, and the UI won't know the endpoint exists.
|
||||
- **Build**: Inspect `Makefile` and `.github/workflows/` — ask the user before running long builds
|
||||
- **UI**: The active UI is the React app in `core/http/react-ui/`. The older Alpine.js/HTML UI in `core/http/static/` is pending deprecation — all new UI work goes in the React UI
|
||||
|
||||
185
Makefile
185
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/tinygrad
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/tinygrad
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -434,6 +434,8 @@ prepare-test-extra: protogen-python
|
||||
$(MAKE) -C backend/python/ace-step
|
||||
$(MAKE) -C backend/python/trl
|
||||
$(MAKE) -C backend/python/tinygrad
|
||||
$(MAKE) -C backend/python/insightface
|
||||
$(MAKE) -C backend/python/speaker-recognition
|
||||
$(MAKE) -C backend/rust/kokoros kokoros-grpc
|
||||
|
||||
test-extra: prepare-test-extra
|
||||
@@ -457,6 +459,8 @@ test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/ace-step test
|
||||
$(MAKE) -C backend/python/trl test
|
||||
$(MAKE) -C backend/python/tinygrad test
|
||||
$(MAKE) -C backend/python/insightface test
|
||||
$(MAKE) -C backend/python/speaker-recognition test
|
||||
$(MAKE) -C backend/rust/kokoros test
|
||||
|
||||
##
|
||||
@@ -507,6 +511,13 @@ test-extra-backend: protogen-go
|
||||
BACKEND_TEST_TOOL_NAME="$$BACKEND_TEST_TOOL_NAME" \
|
||||
BACKEND_TEST_CACHE_TYPE_K="$$BACKEND_TEST_CACHE_TYPE_K" \
|
||||
BACKEND_TEST_CACHE_TYPE_V="$$BACKEND_TEST_CACHE_TYPE_V" \
|
||||
BACKEND_TEST_FACE_IMAGE_1_URL="$$BACKEND_TEST_FACE_IMAGE_1_URL" \
|
||||
BACKEND_TEST_FACE_IMAGE_1_FILE="$$BACKEND_TEST_FACE_IMAGE_1_FILE" \
|
||||
BACKEND_TEST_FACE_IMAGE_2_URL="$$BACKEND_TEST_FACE_IMAGE_2_URL" \
|
||||
BACKEND_TEST_FACE_IMAGE_2_FILE="$$BACKEND_TEST_FACE_IMAGE_2_FILE" \
|
||||
BACKEND_TEST_FACE_IMAGE_3_URL="$$BACKEND_TEST_FACE_IMAGE_3_URL" \
|
||||
BACKEND_TEST_FACE_IMAGE_3_FILE="$$BACKEND_TEST_FACE_IMAGE_3_FILE" \
|
||||
BACKEND_TEST_VERIFY_DISTANCE_CEILING="$$BACKEND_TEST_VERIFY_DISTANCE_CEILING" \
|
||||
go test -v -timeout 30m ./tests/e2e-backends/...
|
||||
|
||||
## Convenience wrappers: build the image, then exercise it.
|
||||
@@ -603,6 +614,172 @@ test-extra-backend-tinygrad-all: \
|
||||
test-extra-backend-tinygrad-sd \
|
||||
test-extra-backend-tinygrad-whisper
|
||||
|
||||
## insightface — face recognition.
|
||||
##
|
||||
## Face fixtures default to the sample images shipped in the
|
||||
## deepinsight/insightface repository (MIT-licensed). For offline/local
|
||||
## runs override with BACKEND_TEST_FACE_IMAGE_{1,2,3}_FILE pointing at
|
||||
## local paths.
|
||||
FACE_IMAGE_1_URL ?= https://github.com/deepinsight/insightface/raw/master/python-package/insightface/data/images/t1.jpg
|
||||
FACE_IMAGE_2_URL ?= https://github.com/deepinsight/insightface/raw/master/python-package/insightface/data/images/t1.jpg
|
||||
FACE_IMAGE_3_URL ?= https://github.com/deepinsight/insightface/raw/master/python-package/insightface/data/images/mask_white.jpg
|
||||
## Known spoof fixture used by the face_antispoof e2e cap. This is
|
||||
## upstream's own `image_F2.jpg` (Silent-Face repo, via yakhyo mirror)
|
||||
## — verified to classify as is_real=false with score < 0.05 on the
|
||||
## MiniFASNetV2 + MiniFASNetV1SE ensemble.
|
||||
FACE_SPOOF_IMAGE_URL ?= https://github.com/yakhyo/face-anti-spoofing/raw/main/assets/image_F2.jpg
|
||||
|
||||
## Host-side cache for the OpenCV Zoo face ONNX files used by the
|
||||
## opencv e2e target. The backend image no longer bakes model weights —
|
||||
## gallery installs bring them via `files:` — but the e2e suite drives
|
||||
## LoadModel over gRPC directly without going through the gallery. We
|
||||
## pre-download the ONNX files to a stable host path and pass absolute
|
||||
## paths in BACKEND_TEST_OPTIONS; `make` skips the downloads when the
|
||||
## SHA-256 already matches.
|
||||
INSIGHTFACE_OPENCV_DIR := /tmp/localai-insightface-opencv-cache
|
||||
INSIGHTFACE_OPENCV_YUNET_URL := https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx
|
||||
INSIGHTFACE_OPENCV_SFACE_URL := https://github.com/opencv/opencv_zoo/raw/main/models/face_recognition_sface/face_recognition_sface_2021dec.onnx
|
||||
INSIGHTFACE_OPENCV_YUNET_SHA := 8f2383e4dd3cfbb4553ea8718107fc0423210dc964f9f4280604804ed2552fa4
|
||||
INSIGHTFACE_OPENCV_SFACE_SHA := 0ba9fbfa01b5270c96627c4ef784da859931e02f04419c829e83484087c34e79
|
||||
|
||||
## buffalo_sc (insightface) — pack zip + SHA-256 mirrors the gallery
|
||||
## entry so the e2e target matches exactly what `local-ai models install
|
||||
## insightface-buffalo-sc` would have fetched. Smallest insightface pack
|
||||
## (~16MB) — keeps CI fast while still covering the insightface engine
|
||||
## code path end-to-end.
|
||||
INSIGHTFACE_BUFFALO_SC_DIR := /tmp/localai-insightface-buffalo-sc-cache
|
||||
INSIGHTFACE_BUFFALO_SC_URL := https://github.com/deepinsight/insightface/releases/download/v0.7/buffalo_sc.zip
|
||||
INSIGHTFACE_BUFFALO_SC_SHA := 57d31b56b6ffa911c8a73cfc1707c73cab76efe7f13b675a05223bf42de47c72
|
||||
|
||||
## Silent-Face antispoofing (MiniFASNetV2 + MiniFASNetV1SE) — shared
|
||||
## between the buffalo_sc and opencv e2e targets. Both ONNX files are
|
||||
## ~1.7MB, Apache 2.0. URLs + SHAs mirror the gallery entries.
|
||||
INSIGHTFACE_ANTISPOOF_DIR := /tmp/localai-insightface-antispoof-cache
|
||||
INSIGHTFACE_ANTISPOOF_V2_URL := https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV2.onnx
|
||||
INSIGHTFACE_ANTISPOOF_V2_SHA := b32929adc2d9c34b9486f8c4c7bc97c1b69bc0ea9befefc380e4faae4e463907
|
||||
INSIGHTFACE_ANTISPOOF_V1SE_URL := https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV1SE.onnx
|
||||
INSIGHTFACE_ANTISPOOF_V1SE_SHA := ebab7f90c7833fbccd46d3a555410e78d969db5438e169b6524be444862b3676
|
||||
|
||||
.PHONY: insightface-opencv-models
|
||||
insightface-opencv-models:
|
||||
@mkdir -p $(INSIGHTFACE_OPENCV_DIR)
|
||||
@if [ "$$(sha256sum $(INSIGHTFACE_OPENCV_DIR)/yunet.onnx 2>/dev/null | awk '{print $$1}')" != "$(INSIGHTFACE_OPENCV_YUNET_SHA)" ]; then \
|
||||
echo "Fetching YuNet..."; \
|
||||
curl -fsSL -o $(INSIGHTFACE_OPENCV_DIR)/yunet.onnx $(INSIGHTFACE_OPENCV_YUNET_URL); \
|
||||
echo "$(INSIGHTFACE_OPENCV_YUNET_SHA) $(INSIGHTFACE_OPENCV_DIR)/yunet.onnx" | sha256sum -c; \
|
||||
fi
|
||||
@if [ "$$(sha256sum $(INSIGHTFACE_OPENCV_DIR)/sface.onnx 2>/dev/null | awk '{print $$1}')" != "$(INSIGHTFACE_OPENCV_SFACE_SHA)" ]; then \
|
||||
echo "Fetching SFace..."; \
|
||||
curl -fsSL -o $(INSIGHTFACE_OPENCV_DIR)/sface.onnx $(INSIGHTFACE_OPENCV_SFACE_URL); \
|
||||
echo "$(INSIGHTFACE_OPENCV_SFACE_SHA) $(INSIGHTFACE_OPENCV_DIR)/sface.onnx" | sha256sum -c; \
|
||||
fi
|
||||
|
||||
.PHONY: insightface-antispoof-models
|
||||
insightface-antispoof-models:
|
||||
@mkdir -p $(INSIGHTFACE_ANTISPOOF_DIR)
|
||||
@if [ "$$(sha256sum $(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV2.onnx 2>/dev/null | awk '{print $$1}')" != "$(INSIGHTFACE_ANTISPOOF_V2_SHA)" ]; then \
|
||||
echo "Fetching MiniFASNetV2..."; \
|
||||
curl -fsSL -o $(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV2.onnx $(INSIGHTFACE_ANTISPOOF_V2_URL); \
|
||||
echo "$(INSIGHTFACE_ANTISPOOF_V2_SHA) $(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV2.onnx" | sha256sum -c; \
|
||||
fi
|
||||
@if [ "$$(sha256sum $(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV1SE.onnx 2>/dev/null | awk '{print $$1}')" != "$(INSIGHTFACE_ANTISPOOF_V1SE_SHA)" ]; then \
|
||||
echo "Fetching MiniFASNetV1SE..."; \
|
||||
curl -fsSL -o $(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV1SE.onnx $(INSIGHTFACE_ANTISPOOF_V1SE_URL); \
|
||||
echo "$(INSIGHTFACE_ANTISPOOF_V1SE_SHA) $(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV1SE.onnx" | sha256sum -c; \
|
||||
fi
|
||||
|
||||
.PHONY: insightface-buffalo-sc-models
|
||||
insightface-buffalo-sc-models:
|
||||
@mkdir -p $(INSIGHTFACE_BUFFALO_SC_DIR)
|
||||
@if [ "$$(sha256sum $(INSIGHTFACE_BUFFALO_SC_DIR)/buffalo_sc.zip 2>/dev/null | awk '{print $$1}')" != "$(INSIGHTFACE_BUFFALO_SC_SHA)" ]; then \
|
||||
echo "Fetching buffalo_sc..."; \
|
||||
curl -fsSL -o $(INSIGHTFACE_BUFFALO_SC_DIR)/buffalo_sc.zip $(INSIGHTFACE_BUFFALO_SC_URL); \
|
||||
echo "$(INSIGHTFACE_BUFFALO_SC_SHA) $(INSIGHTFACE_BUFFALO_SC_DIR)/buffalo_sc.zip" | sha256sum -c; \
|
||||
rm -f $(INSIGHTFACE_BUFFALO_SC_DIR)/*.onnx; \
|
||||
fi
|
||||
@if [ ! -f "$(INSIGHTFACE_BUFFALO_SC_DIR)/det_500m.onnx" ]; then \
|
||||
echo "Extracting buffalo_sc..."; \
|
||||
unzip -o -q $(INSIGHTFACE_BUFFALO_SC_DIR)/buffalo_sc.zip -d $(INSIGHTFACE_BUFFALO_SC_DIR); \
|
||||
fi
|
||||
|
||||
## buffalo_sc — smallest insightface pack (SCRFD-500MF detector + MBF
|
||||
## recognizer, ~16MB). Exercises the insightface engine code path
|
||||
## (model_zoo-backed inference) without the ~326MB buffalo_l download.
|
||||
## No age/gender/landmark heads — face_analyze is dropped from caps.
|
||||
## The pack is pre-fetched on the host and passed as `root:<dir>` since
|
||||
## the e2e suite drives LoadModel directly without going through
|
||||
## LocalAI's gallery flow (which is what would normally populate
|
||||
## ModelPath and in turn the engine's `_model_dir` option).
|
||||
test-extra-backend-insightface-buffalo-sc: docker-build-insightface insightface-buffalo-sc-models insightface-antispoof-models
|
||||
BACKEND_IMAGE=local-ai-backend:insightface \
|
||||
BACKEND_TEST_MODEL_NAME=insightface-buffalo-sc \
|
||||
BACKEND_TEST_OPTIONS=engine:insightface,model_pack:buffalo_sc,root:$(INSIGHTFACE_BUFFALO_SC_DIR),antispoof_v2_onnx:$(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV2.onnx,antispoof_v1se_onnx:$(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV1SE.onnx \
|
||||
BACKEND_TEST_CAPS=health,load,face_detect,face_embed,face_verify,face_antispoof \
|
||||
BACKEND_TEST_FACE_IMAGE_1_URL=$(FACE_IMAGE_1_URL) \
|
||||
BACKEND_TEST_FACE_IMAGE_2_URL=$(FACE_IMAGE_2_URL) \
|
||||
BACKEND_TEST_FACE_IMAGE_3_URL=$(FACE_IMAGE_3_URL) \
|
||||
BACKEND_TEST_FACE_SPOOF_IMAGE_URL=$(FACE_SPOOF_IMAGE_URL) \
|
||||
BACKEND_TEST_VERIFY_DISTANCE_CEILING=0.55 \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## OpenCV Zoo YuNet + SFace — Apache 2.0, commercial-safe. face_analyze
|
||||
## cap is dropped (SFace has no demographic head). The ONNX files are
|
||||
## pre-fetched on the host via the insightface-opencv-models target and
|
||||
## passed as absolute paths, since the e2e suite drives LoadModel
|
||||
## directly without going through LocalAI's gallery flow.
|
||||
test-extra-backend-insightface-opencv: docker-build-insightface insightface-opencv-models insightface-antispoof-models
|
||||
BACKEND_IMAGE=local-ai-backend:insightface \
|
||||
BACKEND_TEST_MODEL_NAME=insightface-opencv \
|
||||
BACKEND_TEST_OPTIONS=engine:onnx_direct,detector_onnx:$(INSIGHTFACE_OPENCV_DIR)/yunet.onnx,recognizer_onnx:$(INSIGHTFACE_OPENCV_DIR)/sface.onnx,antispoof_v2_onnx:$(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV2.onnx,antispoof_v1se_onnx:$(INSIGHTFACE_ANTISPOOF_DIR)/MiniFASNetV1SE.onnx \
|
||||
BACKEND_TEST_CAPS=health,load,face_detect,face_embed,face_verify,face_antispoof \
|
||||
BACKEND_TEST_FACE_IMAGE_1_URL=$(FACE_IMAGE_1_URL) \
|
||||
BACKEND_TEST_FACE_IMAGE_2_URL=$(FACE_IMAGE_2_URL) \
|
||||
BACKEND_TEST_FACE_IMAGE_3_URL=$(FACE_IMAGE_3_URL) \
|
||||
BACKEND_TEST_FACE_SPOOF_IMAGE_URL=$(FACE_SPOOF_IMAGE_URL) \
|
||||
BACKEND_TEST_VERIFY_DISTANCE_CEILING=0.55 \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## Aggregate — runs both face-recognition model configurations so CI
|
||||
## catches regressions across engines together.
|
||||
test-extra-backend-insightface-all: \
|
||||
test-extra-backend-insightface-buffalo-sc \
|
||||
test-extra-backend-insightface-opencv
|
||||
|
||||
## speaker-recognition — voice (speaker) biometrics.
|
||||
##
|
||||
## Audio fixtures default to the speechbrain test samples served
|
||||
## straight from their GitHub repo — public, no auth needed, and they
|
||||
## ship as 16kHz mono WAV/FLAC which is exactly what the engine wants.
|
||||
## example{1,2,5} are three different speakers; the suite treats
|
||||
## example1 as the "same-image twin" probe (verify(clip, clip) must
|
||||
## return distance≈0) and the other two as cross-speaker ceilings.
|
||||
## Override with BACKEND_TEST_VOICE_AUDIO_{1,2,3}_FILE for offline runs.
|
||||
VOICE_AUDIO_1_URL ?= https://github.com/speechbrain/speechbrain/raw/develop/tests/samples/single-mic/example1.wav
|
||||
VOICE_AUDIO_2_URL ?= https://github.com/speechbrain/speechbrain/raw/develop/tests/samples/single-mic/example2.flac
|
||||
VOICE_AUDIO_3_URL ?= https://github.com/speechbrain/speechbrain/raw/develop/tests/samples/single-mic/example5.wav
|
||||
|
||||
## ECAPA-TDNN via SpeechBrain — default CI configuration. Auto-downloads
|
||||
## the checkpoint from HuggingFace on first LoadModel (bundled in the
|
||||
## backend image pip install). 192-d embeddings, cosine-distance based.
|
||||
## The e2e suite drives LoadModel directly so we don't rely on LocalAI's
|
||||
## gallery flow here.
|
||||
test-extra-backend-speaker-recognition-ecapa: docker-build-speaker-recognition
|
||||
BACKEND_IMAGE=local-ai-backend:speaker-recognition \
|
||||
BACKEND_TEST_MODEL_NAME=speechbrain/spkrec-ecapa-voxceleb \
|
||||
BACKEND_TEST_OPTIONS=engine:speechbrain,source:speechbrain/spkrec-ecapa-voxceleb \
|
||||
BACKEND_TEST_CAPS=health,load,voice_embed,voice_verify \
|
||||
BACKEND_TEST_VOICE_AUDIO_1_URL=$(VOICE_AUDIO_1_URL) \
|
||||
BACKEND_TEST_VOICE_AUDIO_2_URL=$(VOICE_AUDIO_2_URL) \
|
||||
BACKEND_TEST_VOICE_AUDIO_3_URL=$(VOICE_AUDIO_3_URL) \
|
||||
BACKEND_TEST_VOICE_VERIFY_DISTANCE_CEILING=0.4 \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## Aggregate — today there's only one voice config; the target exists
|
||||
## so the CI workflow matches the insightface-all naming convention and
|
||||
## can grow to include WeSpeaker / 3D-Speaker later.
|
||||
test-extra-backend-speaker-recognition-all: \
|
||||
test-extra-backend-speaker-recognition-ecapa
|
||||
|
||||
## sglang mirrors the vllm setup: HuggingFace model id, same tiny Qwen,
|
||||
## tool-call extraction via sglang's native qwen parser. CPU builds use
|
||||
## sglang's upstream pyproject_cpu.toml recipe (see backend/python/sglang/install.sh).
|
||||
@@ -748,6 +925,8 @@ BACKEND_OUTETTS = outetts|python|.|false|true
|
||||
BACKEND_FASTER_WHISPER = faster-whisper|python|.|false|true
|
||||
BACKEND_COQUI = coqui|python|.|false|true
|
||||
BACKEND_RFDETR = rfdetr|python|.|false|true
|
||||
BACKEND_INSIGHTFACE = insightface|python|.|false|true
|
||||
BACKEND_SPEAKER_RECOGNITION = speaker-recognition|python|.|false|true
|
||||
BACKEND_KITTEN_TTS = kitten-tts|python|.|false|true
|
||||
BACKEND_NEUTTS = neutts|python|.|false|true
|
||||
BACKEND_KOKORO = kokoro|python|.|false|true
|
||||
@@ -819,6 +998,8 @@ $(eval $(call generate-docker-build-target,$(BACKEND_OUTETTS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_FASTER_WHISPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_COQUI)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_RFDETR)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_INSIGHTFACE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SPEAKER_RECOGNITION)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_KITTEN_TTS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_NEUTTS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_KOKORO)))
|
||||
@@ -853,7 +1034,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-insightface docker-build-speaker-recognition
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
|
||||
@@ -149,6 +149,7 @@ For more details, see the [Getting Started guide](https://localai.io/basics/gett
|
||||
|
||||
## Latest News
|
||||
|
||||
- **April 2026**: [Voice recognition](https://github.com/mudler/LocalAI/pull/9500), [Face recognition, identification & liveness detection](https://github.com/mudler/LocalAI/pull/9480), [Ollama API compatibility](https://github.com/mudler/LocalAI/pull/9284), [Video generation in stable-diffusion.ggml](https://github.com/mudler/LocalAI/pull/9420), [Backend versioning with auto-upgrade](https://github.com/mudler/LocalAI/pull/9315), [Pin models & load-on-demand toggle](https://github.com/mudler/LocalAI/pull/9309), [Universal model importer](https://github.com/mudler/LocalAI/pull/9466), new backends: [sglang](https://github.com/mudler/LocalAI/pull/9359), [ik-llama-cpp](https://github.com/mudler/LocalAI/pull/9326), [TurboQuant](https://github.com/mudler/LocalAI/pull/9355), [sam.cpp](https://github.com/mudler/LocalAI/pull/9288), [Kokoros](https://github.com/mudler/LocalAI/pull/9212), [qwen3tts.cpp](https://github.com/mudler/LocalAI/pull/9316), [tinygrad multimodal](https://github.com/mudler/LocalAI/pull/9364)
|
||||
- **March 2026**: [Agent management](https://github.com/mudler/LocalAI/pull/8820), [New React UI](https://github.com/mudler/LocalAI/pull/8772), [WebRTC](https://github.com/mudler/LocalAI/pull/8790), [MLX-distributed via P2P and RDMA](https://github.com/mudler/LocalAI/pull/8801), [MCP Apps, MCP Client-side](https://github.com/mudler/LocalAI/pull/8947)
|
||||
- **February 2026**: [Realtime API for audio-to-audio with tool calling](https://github.com/mudler/LocalAI/pull/6245), [ACE-Step 1.5 support](https://github.com/mudler/LocalAI/pull/8396)
|
||||
- **January 2026**: **LocalAI 3.10.0** — Anthropic API support, Open Responses API, video & image generation (LTX-2), unified GPU backends, tool streaming, Moonshine, Pocket-TTS. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v3.10.0)
|
||||
|
||||
@@ -24,6 +24,11 @@ service Backend {
|
||||
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
|
||||
rpc Status(HealthMessage) returns (StatusResponse) {}
|
||||
rpc Detect(DetectOptions) returns (DetectResponse) {}
|
||||
rpc FaceVerify(FaceVerifyRequest) returns (FaceVerifyResponse) {}
|
||||
rpc FaceAnalyze(FaceAnalyzeRequest) returns (FaceAnalyzeResponse) {}
|
||||
rpc VoiceVerify(VoiceVerifyRequest) returns (VoiceVerifyResponse) {}
|
||||
rpc VoiceAnalyze(VoiceAnalyzeRequest) returns (VoiceAnalyzeResponse) {}
|
||||
rpc VoiceEmbed(VoiceEmbedRequest) returns (VoiceEmbedResponse) {}
|
||||
|
||||
rpc StoresSet(StoresSetOptions) returns (Result) {}
|
||||
rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
|
||||
@@ -475,6 +480,112 @@ message DetectResponse {
|
||||
repeated Detection Detections = 1;
|
||||
}
|
||||
|
||||
// --- Face recognition messages ---
|
||||
|
||||
message FacialArea {
|
||||
float x = 1;
|
||||
float y = 2;
|
||||
float w = 3;
|
||||
float h = 4;
|
||||
}
|
||||
|
||||
message FaceVerifyRequest {
|
||||
string img1 = 1; // base64-encoded image
|
||||
string img2 = 2; // base64-encoded image
|
||||
float threshold = 3; // cosine-distance threshold; 0 = use backend default
|
||||
bool anti_spoofing = 4; // run MiniFASNet liveness on each image; failed liveness forces verified=false
|
||||
}
|
||||
|
||||
message FaceVerifyResponse {
|
||||
bool verified = 1;
|
||||
float distance = 2; // 1 - cosine_similarity
|
||||
float threshold = 3;
|
||||
float confidence = 4; // 0-100
|
||||
string model = 5; // e.g. "buffalo_l"
|
||||
FacialArea img1_area = 6;
|
||||
FacialArea img2_area = 7;
|
||||
float processing_time_ms = 8;
|
||||
bool img1_is_real = 9; // anti-spoofing result when enabled
|
||||
float img1_antispoof_score = 10;
|
||||
bool img2_is_real = 11;
|
||||
float img2_antispoof_score = 12;
|
||||
}
|
||||
|
||||
message FaceAnalyzeRequest {
|
||||
string img = 1; // base64-encoded image
|
||||
repeated string actions = 2; // subset of ["age","gender","emotion","race"]; empty = all-supported
|
||||
bool anti_spoofing = 3;
|
||||
}
|
||||
|
||||
message FaceAnalysis {
|
||||
FacialArea region = 1;
|
||||
float face_confidence = 2;
|
||||
float age = 3;
|
||||
string dominant_gender = 4; // "Man" | "Woman"
|
||||
map<string, float> gender = 5;
|
||||
string dominant_emotion = 6; // reserved; empty in MVP
|
||||
map<string, float> emotion = 7;
|
||||
string dominant_race = 8; // not populated
|
||||
map<string, float> race = 9;
|
||||
bool is_real = 10; // anti-spoofing result when enabled
|
||||
float antispoof_score = 11;
|
||||
}
|
||||
|
||||
message FaceAnalyzeResponse {
|
||||
repeated FaceAnalysis faces = 1;
|
||||
}
|
||||
|
||||
// --- Voice (speaker) recognition messages ---
|
||||
//
|
||||
// Analogous to the Face* messages above, but for speaker biometrics.
|
||||
// Audio fields accept a filesystem path (same convention as
|
||||
// TranscriptRequest.dst). The HTTP layer materialises base64 / URL /
|
||||
// data-URI inputs to a temp file before calling the gRPC backend.
|
||||
|
||||
message VoiceVerifyRequest {
|
||||
string audio1 = 1; // path to first audio clip
|
||||
string audio2 = 2; // path to second audio clip
|
||||
float threshold = 3; // cosine-distance threshold; 0 = use backend default
|
||||
bool anti_spoofing = 4; // reserved for future AASIST bolt-on
|
||||
}
|
||||
|
||||
message VoiceVerifyResponse {
|
||||
bool verified = 1;
|
||||
float distance = 2; // 1 - cosine_similarity
|
||||
float threshold = 3;
|
||||
float confidence = 4; // 0-100
|
||||
string model = 5; // e.g. "speechbrain/spkrec-ecapa-voxceleb"
|
||||
float processing_time_ms = 6;
|
||||
}
|
||||
|
||||
message VoiceAnalyzeRequest {
|
||||
string audio = 1; // path to audio clip
|
||||
repeated string actions = 2; // subset of ["age","gender","emotion"]; empty = all-supported
|
||||
}
|
||||
|
||||
message VoiceAnalysis {
|
||||
float start = 1; // segment start time in seconds (0 if single-utterance)
|
||||
float end = 2; // segment end time in seconds
|
||||
float age = 3;
|
||||
string dominant_gender = 4;
|
||||
map<string, float> gender = 5;
|
||||
string dominant_emotion = 6;
|
||||
map<string, float> emotion = 7;
|
||||
}
|
||||
|
||||
message VoiceAnalyzeResponse {
|
||||
repeated VoiceAnalysis segments = 1;
|
||||
}
|
||||
|
||||
message VoiceEmbedRequest {
|
||||
string audio = 1; // path to audio clip
|
||||
}
|
||||
|
||||
message VoiceEmbedResponse {
|
||||
repeated float embedding = 1;
|
||||
string model = 2;
|
||||
}
|
||||
|
||||
message ToolFormatMarkers {
|
||||
string format_type = 1; // "json_native", "tag_with_json", "tag_with_tagged"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=d4824131580b94ffa7b0e91c955e2b237c2fe16e
|
||||
IK_LLAMA_VERSION?=286ce324baed17c95faec77792eaa6bdb1c7a5f5
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -326,7 +326,7 @@ struct llama_client_slot
|
||||
char buffer[512];
|
||||
double t_token = t_prompt_processing / num_prompt_tokens_processed;
|
||||
double n_tokens_second = 1e3 / t_prompt_processing * num_prompt_tokens_processed;
|
||||
sprintf(buffer, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
|
||||
snprintf(buffer, sizeof(buffer), "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
|
||||
t_prompt_processing, num_prompt_tokens_processed,
|
||||
t_token, n_tokens_second);
|
||||
LOG_INFO(buffer, {
|
||||
@@ -340,7 +340,7 @@ struct llama_client_slot
|
||||
|
||||
t_token = t_token_generation / n_decoded;
|
||||
n_tokens_second = 1e3 / t_token_generation * n_decoded;
|
||||
sprintf(buffer, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
|
||||
snprintf(buffer, sizeof(buffer), "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
|
||||
t_token_generation, n_decoded,
|
||||
t_token, n_tokens_second);
|
||||
LOG_INFO(buffer, {
|
||||
@@ -352,7 +352,7 @@ struct llama_client_slot
|
||||
{"n_tokens_second", n_tokens_second},
|
||||
});
|
||||
|
||||
sprintf(buffer, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
|
||||
snprintf(buffer, sizeof(buffer), " total time = %10.2f ms", t_prompt_processing + t_token_generation);
|
||||
LOG_INFO(buffer, {
|
||||
{"slot_id", id},
|
||||
{"task_id", task_id},
|
||||
@@ -686,7 +686,16 @@ struct llama_server_context
|
||||
slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
||||
slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
|
||||
slot->sparams.seed = json_value(data, "seed", default_sparams.seed);
|
||||
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
||||
{
|
||||
// upstream changed common_params_sampling::grammar from std::string to
|
||||
// the common_grammar struct (type + grammar). The incoming JSON still
|
||||
// carries a plain string, so build the user-provided grammar here and
|
||||
// fall back to the server default when the request omits it.
|
||||
std::string grammar_str = json_value(data, "grammar", std::string());
|
||||
slot->sparams.grammar = grammar_str.empty()
|
||||
? default_sparams.grammar
|
||||
: common_grammar{COMMON_GRAMMAR_TYPE_USER, std::move(grammar_str)};
|
||||
}
|
||||
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||
slot->sparams.grammar_triggers = grammar_triggers;
|
||||
@@ -1232,7 +1241,7 @@ struct llama_server_context
|
||||
// {"logit_bias", slot.sparams.logit_bias},
|
||||
{"n_probs", slot.sparams.n_probs},
|
||||
{"min_keep", slot.sparams.min_keep},
|
||||
{"grammar", slot.sparams.grammar},
|
||||
{"grammar", slot.sparams.grammar.grammar},
|
||||
{"samplers", samplers}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=5a4cd6741fc33227cdacb329f355ab21f8481de2
|
||||
LLAMA_VERSION?=0d0764dfd257c0ae862525c05778207f87b99b1c
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -10,6 +10,14 @@
|
||||
#include "server-task.cpp"
|
||||
#include "server-queue.cpp"
|
||||
#include "server-common.cpp"
|
||||
// server-chat.cpp exists only in llama.cpp after the upstream refactor that
|
||||
// split OAI/Anthropic/Responses/transcription conversion helpers out of
|
||||
// server-common.cpp. When present, server-context.cpp and server-task.cpp
|
||||
// above call into it, so we must pull its definitions into this TU or the
|
||||
// link fails. __has_include keeps the source compatible with older pins.
|
||||
#if __has_include("server-chat.cpp")
|
||||
#include "server-chat.cpp"
|
||||
#endif
|
||||
#include "server-context.cpp"
|
||||
|
||||
// LocalAI
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=44cca3d626d301e2215d5e243277e8f0e65bfa78
|
||||
STABLEDIFFUSION_GGML_VERSION?=c97702e1057c2fe13a7074cd9069cb9dd6edc1bf
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -168,6 +168,43 @@
|
||||
nvidia-cuda-13: "cuda13-rfdetr"
|
||||
nvidia-cuda-12: "cuda12-rfdetr"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-rfdetr"
|
||||
- &insightface
|
||||
name: "insightface"
|
||||
alias: "insightface"
|
||||
# Upstream insightface library is MIT. The pretrained model packs
|
||||
# (buffalo_l, buffalo_s, antelopev2) are released for NON-COMMERCIAL
|
||||
# research use only. The backend image also pre-bakes OpenCV Zoo
|
||||
# YuNet + SFace (Apache 2.0) for commercial use. Pick the engine
|
||||
# via model-gallery entries (insightface-buffalo-l / insightface-opencv
|
||||
# / insightface-buffalo-s) or set `options` in your model YAML.
|
||||
license: "mixed"
|
||||
description: |
|
||||
Face recognition backend powered by `insightface` (ONNX Runtime).
|
||||
Provides face verification (/v1/face/verify), face analysis
|
||||
(/v1/face/analyze), face embedding (/v1/embeddings), face
|
||||
detection (/v1/detection), and 1:N identification
|
||||
(/v1/face/{register,identify,forget}).
|
||||
Ships two engines in a single image: one that drives the insightface
|
||||
model packs (buffalo_l/s/m/sc, antelopev2 — non-commercial research
|
||||
use only) and one that drives OpenCV Zoo's YuNet + SFace pair
|
||||
(Apache 2.0 — commercial-safe). Select via `options: ["engine:..."]`
|
||||
in your model YAML, or install one of the ready-made model-gallery
|
||||
entries under the `insightface-*` prefix.
|
||||
The backend image contains only code and Python deps; all model
|
||||
weights are managed by LocalAI's gallery download mechanism.
|
||||
urls:
|
||||
- https://github.com/deepinsight/insightface
|
||||
- https://github.com/opencv/opencv_zoo
|
||||
tags:
|
||||
- face-recognition
|
||||
- face-verification
|
||||
- face-embedding
|
||||
- gpu
|
||||
- cpu
|
||||
capabilities:
|
||||
default: "cpu-insightface"
|
||||
nvidia: "cuda12-insightface"
|
||||
nvidia-cuda-12: "cuda12-insightface"
|
||||
- &sam3cpp
|
||||
name: "sam3-cpp"
|
||||
alias: "sam3-cpp"
|
||||
@@ -3709,3 +3746,91 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp-quantization"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-llama-cpp-quantization
|
||||
# insightface (face recognition) — development and concrete image entries
|
||||
- !!merge <<: *insightface
|
||||
name: "insightface-development"
|
||||
capabilities:
|
||||
default: "cpu-insightface-development"
|
||||
nvidia: "cuda12-insightface-development"
|
||||
nvidia-cuda-12: "cuda12-insightface-development"
|
||||
- !!merge <<: *insightface
|
||||
name: "cpu-insightface"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-insightface"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-insightface
|
||||
- !!merge <<: *insightface
|
||||
name: "cuda12-insightface"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-insightface"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-insightface
|
||||
- !!merge <<: *insightface
|
||||
name: "cpu-insightface-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-insightface"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-insightface
|
||||
- !!merge <<: *insightface
|
||||
name: "cuda12-insightface-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-insightface"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-insightface
|
||||
|
||||
# speaker-recognition (voice/speaker biometrics) — Apache-2.0 stack
|
||||
- &speakerrecognition
|
||||
name: "speaker-recognition"
|
||||
alias: "speaker-recognition"
|
||||
# SpeechBrain is Apache-2.0. WeSpeaker / 3D-Speaker ONNX exports are
|
||||
# Apache-2.0. The backend itself ships only Python deps — all model
|
||||
# weights flow through LocalAI's gallery download mechanism (or
|
||||
# SpeechBrain's built-in HF auto-download at first LoadModel).
|
||||
license: apache-2.0
|
||||
description: |
|
||||
Speaker (voice) recognition backend — the audio analog to
|
||||
insightface. Wraps SpeechBrain ECAPA-TDNN (default engine, 192-d
|
||||
embeddings, ~1.9% EER on VoxCeleb) plus an OnnxDirectEngine for
|
||||
pre-exported WeSpeaker / 3D-Speaker ONNX models.
|
||||
|
||||
Exposes speaker verification (/v1/voice/verify), speaker embedding
|
||||
(/v1/voice/embed), speaker analysis (/v1/voice/analyze), and 1:N
|
||||
speaker identification (/v1/voice/{register,identify,forget}).
|
||||
Registrations use LocalAI's built-in vector store — same in-memory
|
||||
backing the face-recognition registry uses, separate instance.
|
||||
urls:
|
||||
- https://speechbrain.github.io/
|
||||
- https://github.com/wenet-e2e/wespeaker
|
||||
- https://github.com/modelscope/3D-Speaker
|
||||
tags:
|
||||
- voice-recognition
|
||||
- speaker-verification
|
||||
- speaker-embedding
|
||||
- gpu
|
||||
- cpu
|
||||
capabilities:
|
||||
default: "cpu-speaker-recognition"
|
||||
nvidia: "cuda12-speaker-recognition"
|
||||
nvidia-cuda-12: "cuda12-speaker-recognition"
|
||||
- !!merge <<: *speakerrecognition
|
||||
name: "speaker-recognition-development"
|
||||
capabilities:
|
||||
default: "cpu-speaker-recognition-development"
|
||||
nvidia: "cuda12-speaker-recognition-development"
|
||||
nvidia-cuda-12: "cuda12-speaker-recognition-development"
|
||||
- !!merge <<: *speakerrecognition
|
||||
name: "cpu-speaker-recognition"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-speaker-recognition"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-speaker-recognition
|
||||
- !!merge <<: *speakerrecognition
|
||||
name: "cuda12-speaker-recognition"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-speaker-recognition"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-speaker-recognition
|
||||
- !!merge <<: *speakerrecognition
|
||||
name: "cpu-speaker-recognition-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-speaker-recognition"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-speaker-recognition
|
||||
- !!merge <<: *speakerrecognition
|
||||
name: "cuda12-speaker-recognition-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-speaker-recognition"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-speaker-recognition
|
||||
|
||||
16
backend/python/insightface/Makefile
Normal file
16
backend/python/insightface/Makefile
Normal file
@@ -0,0 +1,16 @@
|
||||
.DEFAULT_GOAL := install
|
||||
|
||||
.PHONY: install
|
||||
install:
|
||||
bash install.sh
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
|
||||
test: install
|
||||
bash test.sh
|
||||
67
backend/python/insightface/README.md
Normal file
67
backend/python/insightface/README.md
Normal file
@@ -0,0 +1,67 @@
|
||||
# insightface backend (LocalAI)
|
||||
|
||||
Face recognition backend backed by ONNX Runtime. Provides face
|
||||
verification (1:1), face analysis (age/gender), face detection, face
|
||||
embedding, and — via LocalAI's built-in vector store — 1:N
|
||||
identification.
|
||||
|
||||
## Engines
|
||||
|
||||
This backend ships with **two** interchangeable engines selected via
|
||||
`LoadModel.Options["engine"]`:
|
||||
|
||||
| engine | Implementation | Models | License |
|
||||
|---|---|---|---|
|
||||
| `insightface` (default) | `insightface.app.FaceAnalysis` | `buffalo_l`, `buffalo_s`, `antelopev2` | **Non-commercial research use only** |
|
||||
| `onnx_direct` | OpenCV `FaceDetectorYN` + `FaceRecognizerSF` | OpenCV Zoo YuNet + SFace | Apache 2.0 (commercial-safe) |
|
||||
|
||||
Both engines implement the same `FaceEngine` protocol in `engines.py`,
|
||||
so the gRPC servicer in `backend.py` doesn't need to know which one is
|
||||
active.
|
||||
|
||||
## LoadModel options
|
||||
|
||||
Common:
|
||||
|
||||
| option | default | description |
|
||||
|---|---|---|
|
||||
| `engine` | `insightface` | one of `insightface`, `onnx_direct` |
|
||||
| `det_size` | `640x640` (insightface), `320x320` (onnx_direct) | detector input size |
|
||||
| `det_thresh` | `0.5` | detector confidence threshold |
|
||||
| `verify_threshold` | `0.35` | default cosine distance cutoff for FaceVerify |
|
||||
|
||||
`insightface` engine:
|
||||
|
||||
| option | default | description |
|
||||
|---|---|---|
|
||||
| `model_pack` | `buffalo_l` | which insightface pack to load |
|
||||
|
||||
`onnx_direct` engine:
|
||||
|
||||
| option | default | description |
|
||||
|---|---|---|
|
||||
| `detector_onnx` | *(required)* | path to YuNet-compatible ONNX |
|
||||
| `recognizer_onnx` | *(required)* | path to SFace-compatible ONNX |
|
||||
|
||||
## Adding a new model pack
|
||||
|
||||
1. If it's an insightface pack (auto-downloadable or manually extracted
|
||||
into `~/.insightface/models/<name>/`), just add a new gallery entry
|
||||
in `backend/index.yaml` with `options: ["engine:insightface",
|
||||
"model_pack:<name>"]`. No code change.
|
||||
2. If it's an Apache-licensed ONNX pair, add a gallery entry with
|
||||
`options: ["engine:onnx_direct", "detector_onnx:...",
|
||||
"recognizer_onnx:..."]`. If the detector or recognizer has a
|
||||
different input-tensor shape than YuNet/SFace, you may need a new
|
||||
engine implementation in `engines.py`; the two-engine seam makes
|
||||
that a self-contained change.
|
||||
|
||||
## Running tests locally
|
||||
|
||||
```bash
|
||||
make -C backend/python/insightface # install deps + bake models
|
||||
make -C backend/python/insightface test # run test.py
|
||||
```
|
||||
|
||||
The OpenCV Zoo tests skip gracefully when `/models/opencv/*.onnx` is
|
||||
absent (e.g. on dev boxes where `install.sh` wasn't run).
|
||||
312
backend/python/insightface/backend.py
Normal file
312
backend/python/insightface/backend.py
Normal file
@@ -0,0 +1,312 @@
|
||||
#!/usr/bin/env python3
|
||||
"""gRPC server for the insightface face recognition backend.
|
||||
|
||||
Implements Health / LoadModel / Status plus the face-specific methods:
|
||||
Embedding, Detect, FaceVerify, FaceAnalyze. The heavy lifting is
|
||||
delegated to engines.py — this file is just the gRPC plumbing.
|
||||
"""
|
||||
import argparse
|
||||
import base64
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from concurrent import futures
|
||||
from io import BytesIO
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import cv2
|
||||
import grpc
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "common"))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "common"))
|
||||
from grpc_auth import get_auth_interceptors # noqa: E402
|
||||
|
||||
from engines import FaceEngine, build_engine # noqa: E402
|
||||
|
||||
_ONE_DAY = 60 * 60 * 24
|
||||
MAX_WORKERS = int(os.environ.get("PYTHON_GRPC_MAX_WORKERS", "1"))
|
||||
|
||||
# Default cosine-distance threshold for "same person" on buffalo_l
|
||||
# ArcFace R50. Clients can override per-request; clients using SFace
|
||||
# should pass threshold≈0.4 since the distance distribution is wider.
|
||||
DEFAULT_VERIFY_THRESHOLD = 0.35
|
||||
|
||||
|
||||
def _decode_image(src: str) -> np.ndarray | None:
|
||||
"""Decode a base64-encoded image into an OpenCV BGR numpy array."""
|
||||
if not src:
|
||||
return None
|
||||
try:
|
||||
data = base64.b64decode(src, validate=False)
|
||||
except Exception:
|
||||
return None
|
||||
arr = np.frombuffer(data, dtype=np.uint8)
|
||||
if arr.size == 0:
|
||||
return None
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
return img
|
||||
|
||||
|
||||
def _parse_options(raw: list[str]) -> dict[str, str]:
|
||||
out: dict[str, str] = {}
|
||||
for entry in raw:
|
||||
if ":" not in entry:
|
||||
continue
|
||||
k, v = entry.split(":", 1)
|
||||
out[k.strip()] = v.strip()
|
||||
return out
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def __init__(self) -> None:
|
||||
self.engine: FaceEngine | None = None
|
||||
self.engine_name: str = ""
|
||||
self.model_name: str = ""
|
||||
self.verify_threshold: float = DEFAULT_VERIFY_THRESHOLD
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", "utf-8"))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
options = _parse_options(list(request.Options))
|
||||
# Surface LocalAI's models directory (ModelPath) so engines can
|
||||
# anchor relative paths — OnnxDirectEngine's detector_onnx /
|
||||
# recognizer_onnx point at gallery-managed files that LocalAI
|
||||
# dropped there, and InsightFaceEngine auto-downloads its packs
|
||||
# into that same directory alongside every other managed model.
|
||||
# Private key to avoid clashing with user-provided options.
|
||||
if request.ModelPath:
|
||||
options["_model_dir"] = request.ModelPath
|
||||
|
||||
engine_name = options.get("engine", "insightface")
|
||||
try:
|
||||
self.engine = build_engine(engine_name)
|
||||
self.engine.prepare(options)
|
||||
except Exception as err: # pragma: no cover - exercised via e2e
|
||||
return backend_pb2.Result(success=False, message=f"Failed to load face engine: {err}")
|
||||
|
||||
self.engine_name = engine_name
|
||||
self.model_name = request.Model or options.get("model_pack", "")
|
||||
if "verify_threshold" in options:
|
||||
try:
|
||||
self.verify_threshold = float(options["verify_threshold"])
|
||||
except ValueError:
|
||||
pass
|
||||
print(f"[insightface] engine={engine_name} model={self.model_name} loaded", file=sys.stderr)
|
||||
return backend_pb2.Result(success=True, message="Model loaded successfully")
|
||||
|
||||
def Status(self, request, context):
|
||||
state = (
|
||||
backend_pb2.StatusResponse.READY
|
||||
if self.engine is not None
|
||||
else backend_pb2.StatusResponse.UNINITIALIZED
|
||||
)
|
||||
return backend_pb2.StatusResponse(state=state)
|
||||
|
||||
def Embedding(self, request, context):
|
||||
if self.engine is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("face model not loaded")
|
||||
return backend_pb2.EmbeddingResult()
|
||||
if not request.Images:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("Embedding requires Images[0] to be a base64 image")
|
||||
return backend_pb2.EmbeddingResult()
|
||||
|
||||
img = _decode_image(request.Images[0])
|
||||
if img is None:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("failed to decode image")
|
||||
return backend_pb2.EmbeddingResult()
|
||||
|
||||
vec = self.engine.embed(img)
|
||||
if vec is None:
|
||||
context.set_code(grpc.StatusCode.NOT_FOUND)
|
||||
context.set_details("no face detected")
|
||||
return backend_pb2.EmbeddingResult()
|
||||
return backend_pb2.EmbeddingResult(embeddings=[float(x) for x in vec])
|
||||
|
||||
def Detect(self, request, context):
|
||||
if self.engine is None:
|
||||
return backend_pb2.DetectResponse()
|
||||
img = _decode_image(request.src)
|
||||
if img is None:
|
||||
return backend_pb2.DetectResponse()
|
||||
detections = []
|
||||
for d in self.engine.detect(img):
|
||||
x1, y1, x2, y2 = d.bbox
|
||||
detections.append(
|
||||
backend_pb2.Detection(
|
||||
x=float(x1),
|
||||
y=float(y1),
|
||||
width=float(x2 - x1),
|
||||
height=float(y2 - y1),
|
||||
confidence=float(d.score),
|
||||
class_name="face",
|
||||
)
|
||||
)
|
||||
return backend_pb2.DetectResponse(Detections=detections)
|
||||
|
||||
def FaceVerify(self, request, context):
|
||||
if self.engine is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("face model not loaded")
|
||||
return backend_pb2.FaceVerifyResponse()
|
||||
|
||||
img1 = _decode_image(request.img1)
|
||||
img2 = _decode_image(request.img2)
|
||||
if img1 is None or img2 is None:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("failed to decode one or both images")
|
||||
return backend_pb2.FaceVerifyResponse()
|
||||
|
||||
threshold = request.threshold if request.threshold > 0 else self.verify_threshold
|
||||
|
||||
start = time.time()
|
||||
e1 = self.engine.embed(img1)
|
||||
e2 = self.engine.embed(img2)
|
||||
if e1 is None or e2 is None:
|
||||
context.set_code(grpc.StatusCode.NOT_FOUND)
|
||||
context.set_details("no face detected in one or both images")
|
||||
return backend_pb2.FaceVerifyResponse()
|
||||
|
||||
# Both engines return L2-normalized vectors, so the dot product
|
||||
# is the cosine similarity directly.
|
||||
sim = float(np.dot(e1, e2))
|
||||
distance = 1.0 - sim
|
||||
verified = distance < threshold
|
||||
confidence = max(0.0, min(100.0, (1.0 - distance / threshold) * 100.0)) if threshold > 0 else 0.0
|
||||
|
||||
# Detect once per image — region is needed for the response and
|
||||
# potentially for the antispoof crop. Returns the highest-score face.
|
||||
def _best_detection(img):
|
||||
dets = self.engine.detect(img)
|
||||
if not dets:
|
||||
return None
|
||||
return max(dets, key=lambda d: d.score)
|
||||
|
||||
def _region(det) -> backend_pb2.FacialArea:
|
||||
if det is None:
|
||||
return backend_pb2.FacialArea()
|
||||
x1, y1, x2, y2 = det.bbox
|
||||
return backend_pb2.FacialArea(x=x1, y=y1, w=x2 - x1, h=y2 - y1)
|
||||
|
||||
det1 = _best_detection(img1)
|
||||
det2 = _best_detection(img2)
|
||||
|
||||
img1_is_real = False
|
||||
img1_score = 0.0
|
||||
img2_is_real = False
|
||||
img2_score = 0.0
|
||||
if request.anti_spoofing:
|
||||
spoof1 = self.engine.antispoof(img1, det1.bbox) if det1 is not None else None
|
||||
spoof2 = self.engine.antispoof(img2, det2.bbox) if det2 is not None else None
|
||||
if spoof1 is None or spoof2 is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details(
|
||||
"anti_spoofing requested but no antispoof model is loaded — "
|
||||
"install `silent-face-antispoofing` or pick a gallery entry "
|
||||
"that bundles MiniFASNet weights"
|
||||
)
|
||||
return backend_pb2.FaceVerifyResponse()
|
||||
img1_is_real, img1_score = spoof1.is_real, spoof1.score
|
||||
img2_is_real, img2_score = spoof2.is_real, spoof2.score
|
||||
# Failed liveness vetoes verification regardless of similarity.
|
||||
if not (img1_is_real and img2_is_real):
|
||||
verified = False
|
||||
|
||||
return backend_pb2.FaceVerifyResponse(
|
||||
verified=verified,
|
||||
distance=float(distance),
|
||||
threshold=float(threshold),
|
||||
confidence=float(confidence),
|
||||
model=self.model_name or self.engine_name,
|
||||
img1_area=_region(det1),
|
||||
img2_area=_region(det2),
|
||||
processing_time_ms=float((time.time() - start) * 1000.0),
|
||||
img1_is_real=img1_is_real,
|
||||
img1_antispoof_score=float(img1_score),
|
||||
img2_is_real=img2_is_real,
|
||||
img2_antispoof_score=float(img2_score),
|
||||
)
|
||||
|
||||
def FaceAnalyze(self, request, context):
|
||||
if self.engine is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("face model not loaded")
|
||||
return backend_pb2.FaceAnalyzeResponse()
|
||||
img = _decode_image(request.img)
|
||||
if img is None:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("failed to decode image")
|
||||
return backend_pb2.FaceAnalyzeResponse()
|
||||
|
||||
faces = []
|
||||
for attrs in self.engine.analyze(img):
|
||||
x, y, w, h = attrs.region
|
||||
fa = backend_pb2.FaceAnalysis(
|
||||
region=backend_pb2.FacialArea(x=float(x), y=float(y), w=float(w), h=float(h)),
|
||||
face_confidence=float(attrs.face_confidence),
|
||||
)
|
||||
if attrs.age is not None:
|
||||
fa.age = float(attrs.age)
|
||||
if attrs.dominant_gender:
|
||||
fa.dominant_gender = attrs.dominant_gender
|
||||
for k, v in attrs.gender.items():
|
||||
fa.gender[k] = float(v)
|
||||
if request.anti_spoofing:
|
||||
bbox = (float(x), float(y), float(x + w), float(y + h))
|
||||
spoof = self.engine.antispoof(img, bbox)
|
||||
if spoof is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details(
|
||||
"anti_spoofing requested but no antispoof model is loaded — "
|
||||
"install `silent-face-antispoofing` or pick a gallery entry "
|
||||
"that bundles MiniFASNet weights"
|
||||
)
|
||||
return backend_pb2.FaceAnalyzeResponse()
|
||||
fa.is_real = spoof.is_real
|
||||
fa.antispoof_score = float(spoof.score)
|
||||
faces.append(fa)
|
||||
return backend_pb2.FaceAnalyzeResponse(faces=faces)
|
||||
|
||||
|
||||
def serve(address: str) -> None:
|
||||
server = grpc.server(
|
||||
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
("grpc.max_message_length", 50 * 1024 * 1024),
|
||||
("grpc.max_send_message_length", 50 * 1024 * 1024),
|
||||
("grpc.max_receive_message_length", 50 * 1024 * 1024),
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("[insightface] Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
def _stop(sig, frame): # pragma: no cover
|
||||
print("[insightface] shutting down")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, _stop)
|
||||
signal.signal(signal.SIGTERM, _stop)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the insightface gRPC server.")
|
||||
parser.add_argument("--addr", default="localhost:50051", help="The address to bind the server to.")
|
||||
args = parser.parse_args()
|
||||
print(f"[insightface] startup: {args}", file=sys.stderr)
|
||||
serve(args.addr)
|
||||
517
backend/python/insightface/engines.py
Normal file
517
backend/python/insightface/engines.py
Normal file
@@ -0,0 +1,517 @@
|
||||
"""Face recognition engine implementations for the LocalAI insightface backend.
|
||||
|
||||
Two engines are provided:
|
||||
|
||||
* InsightFaceEngine — wraps insightface.app.FaceAnalysis. Supports
|
||||
buffalo_l / buffalo_s / antelopev2 model packs
|
||||
with SCRFD detector + ArcFace recognizer +
|
||||
genderage head. NON-COMMERCIAL research use
|
||||
only (upstream license).
|
||||
|
||||
* OnnxDirectEngine — loads detector + recognizer ONNX files directly
|
||||
via onnxruntime. Used for OpenCV Zoo models
|
||||
(YuNet + SFace) and any future Apache-licensed
|
||||
model set. Does not support analyze().
|
||||
|
||||
Both engines expose the same interface so the gRPC servicer (backend.py)
|
||||
can dispatch without knowing which one is active.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class FaceDetection:
|
||||
bbox: tuple[float, float, float, float] # x1, y1, x2, y2
|
||||
score: float
|
||||
landmarks: np.ndarray | None = None # 5x2 keypoints when available
|
||||
|
||||
|
||||
@dataclass
|
||||
class FaceAttributes:
|
||||
region: tuple[float, float, float, float] # x, y, w, h
|
||||
face_confidence: float
|
||||
age: float | None = None
|
||||
dominant_gender: str | None = None
|
||||
gender: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpoofResult:
|
||||
is_real: bool
|
||||
score: float # averaged probability of the "real" class, 0.0-1.0
|
||||
|
||||
|
||||
class FaceEngine(Protocol):
|
||||
"""Minimal interface every engine must implement."""
|
||||
|
||||
def prepare(self, options: dict[str, str]) -> None: ...
|
||||
def detect(self, img: np.ndarray) -> list[FaceDetection]: ...
|
||||
def embed(self, img: np.ndarray) -> np.ndarray | None: ...
|
||||
def analyze(self, img: np.ndarray) -> list[FaceAttributes]: ...
|
||||
# Optional: returns None when no antispoof model is loaded.
|
||||
def antispoof(self, img: np.ndarray, bbox: tuple[float, float, float, float]) -> SpoofResult | None: ...
|
||||
|
||||
|
||||
# ─── Antispoofer (Silent-Face MiniFASNet) ──────────────────────────────
|
||||
|
||||
class Antispoofer:
|
||||
"""Liveness detector using the Silent-Face MiniFASNet ensemble.
|
||||
|
||||
Loads up to two ONNX exports (MiniFASNetV2 at scale 2.7 and
|
||||
MiniFASNetV1SE at scale 4.0). Both are 80x80 BGR-float32-input
|
||||
classifiers with 3 output logits where index 1 = "real". When both
|
||||
are loaded, softmax outputs are averaged before argmax — the same
|
||||
ensembling the upstream `test.py` does.
|
||||
|
||||
Preprocessing matches yakhyo/face-anti-spoofing's reference impl:
|
||||
each model gets its own scale-expanded crop centered on the face
|
||||
bbox, resized to 80x80, fed straight as float32 BGR (no /255, no
|
||||
mean/std). See `_crop_face` for the bbox math.
|
||||
|
||||
A single model also works (the missing one is simply skipped).
|
||||
"""
|
||||
|
||||
INPUT_SIZE = (80, 80) # h, w
|
||||
REAL_CLASS_IDX = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sessions: list[tuple[Any, float, str, str]] = [] # (session, scale, input_name, output_name)
|
||||
self.threshold: float = 0.5
|
||||
|
||||
def load(self, model_paths: list[tuple[str, float]], threshold: float = 0.5) -> None:
|
||||
"""Load one or more (path, scale) pairs."""
|
||||
import onnxruntime as ort
|
||||
|
||||
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
for path, scale in model_paths:
|
||||
session = ort.InferenceSession(path, providers=providers)
|
||||
input_name = session.get_inputs()[0].name
|
||||
output_name = session.get_outputs()[0].name
|
||||
self._sessions.append((session, float(scale), input_name, output_name))
|
||||
self.threshold = float(threshold)
|
||||
|
||||
@property
|
||||
def loaded(self) -> bool:
|
||||
return bool(self._sessions)
|
||||
|
||||
def _crop_face(self, img: np.ndarray, bbox: tuple[float, float, float, float], scale: float) -> np.ndarray:
|
||||
# bbox is (x1, y1, x2, y2) in source-image coordinates.
|
||||
src_h, src_w = img.shape[:2]
|
||||
x1, y1, x2, y2 = bbox
|
||||
box_w = max(1.0, x2 - x1)
|
||||
box_h = max(1.0, y2 - y1)
|
||||
|
||||
# Clamp scale so the expanded crop fits inside the source image.
|
||||
scale = min((src_h - 1) / box_h, (src_w - 1) / box_w, scale)
|
||||
new_w = box_w * scale
|
||||
new_h = box_h * scale
|
||||
|
||||
cx = x1 + box_w / 2.0
|
||||
cy = y1 + box_h / 2.0
|
||||
|
||||
cx1 = max(0, int(cx - new_w / 2.0))
|
||||
cy1 = max(0, int(cy - new_h / 2.0))
|
||||
cx2 = min(src_w - 1, int(cx + new_w / 2.0))
|
||||
cy2 = min(src_h - 1, int(cy + new_h / 2.0))
|
||||
|
||||
cropped = img[cy1 : cy2 + 1, cx1 : cx2 + 1]
|
||||
if cropped.size == 0:
|
||||
cropped = img
|
||||
out_h, out_w = self.INPUT_SIZE
|
||||
return cv2.resize(cropped, (out_w, out_h))
|
||||
|
||||
@staticmethod
|
||||
def _softmax(x: np.ndarray) -> np.ndarray:
|
||||
e = np.exp(x - np.max(x, axis=1, keepdims=True))
|
||||
return e / e.sum(axis=1, keepdims=True)
|
||||
|
||||
def predict(self, img: np.ndarray, bbox: tuple[float, float, float, float]) -> SpoofResult:
|
||||
if not self._sessions:
|
||||
raise RuntimeError("Antispoofer.predict called with no models loaded")
|
||||
accum = np.zeros((1, 3), dtype=np.float32)
|
||||
for session, scale, input_name, output_name in self._sessions:
|
||||
face = self._crop_face(img, bbox, scale).astype(np.float32)
|
||||
tensor = np.transpose(face, (2, 0, 1))[np.newaxis, ...]
|
||||
logits = session.run([output_name], {input_name: tensor})[0]
|
||||
accum += self._softmax(logits)
|
||||
accum /= float(len(self._sessions))
|
||||
real_prob = float(accum[0, self.REAL_CLASS_IDX])
|
||||
is_real = int(np.argmax(accum)) == self.REAL_CLASS_IDX and real_prob >= self.threshold
|
||||
return SpoofResult(is_real=is_real, score=real_prob)
|
||||
|
||||
|
||||
def _build_antispoofer(options: dict[str, str], model_dir: str | None) -> Antispoofer | None:
|
||||
"""Instantiate an Antispoofer from option keys, or return None.
|
||||
|
||||
Recognised options:
|
||||
antispoof_v2_onnx — path/filename of MiniFASNetV2 (scale 2.7)
|
||||
antispoof_v1se_onnx — path/filename of MiniFASNetV1SE (scale 4.0)
|
||||
antispoof_threshold — real-class probability threshold, default 0.5
|
||||
|
||||
Either or both can be provided. Returns None when neither is set.
|
||||
"""
|
||||
pairs: list[tuple[str, float]] = []
|
||||
v2 = options.get("antispoof_v2_onnx", "")
|
||||
if v2:
|
||||
pairs.append((_resolve_model_path(v2, model_dir=model_dir), 2.7))
|
||||
v1se = options.get("antispoof_v1se_onnx", "")
|
||||
if v1se:
|
||||
pairs.append((_resolve_model_path(v1se, model_dir=model_dir), 4.0))
|
||||
if not pairs:
|
||||
return None
|
||||
threshold = float(options.get("antispoof_threshold", "0.5"))
|
||||
spoofer = Antispoofer()
|
||||
spoofer.load(pairs, threshold=threshold)
|
||||
return spoofer
|
||||
|
||||
|
||||
# ─── InsightFaceEngine ────────────────────────────────────────────────
|
||||
|
||||
class InsightFaceEngine:
|
||||
"""Drives insightface's model_zoo directly — no FaceAnalysis wrapper.
|
||||
|
||||
FaceAnalysis is a thin 50-line orchestration (glob for ONNX files
|
||||
in `<root>/models/<name>/`, route each through `model_zoo.get_model`,
|
||||
build a `{taskname: model}` dict, then loop per-face at inference).
|
||||
We reimplement the same loop here so we can:
|
||||
|
||||
1. Load packs from whatever directory LocalAI's gallery extracted
|
||||
them into — flat (buffalo_l/s/sc — ONNX at `<dir>/*.onnx`) or
|
||||
nested (buffalo_m/antelopev2 — ONNX at `<dir>/<name>/*.onnx`)
|
||||
without needing a specific layout on disk.
|
||||
2. Skip insightface's built-in auto-download entirely: weight
|
||||
delivery is LocalAI's gallery `files:` job now, checksum-
|
||||
verified and cached alongside every other managed model.
|
||||
|
||||
The actual inference classes (RetinaFace, ArcFaceONNX, Attribute,
|
||||
Landmark) stay in insightface — we only reimplement the ~50 lines
|
||||
of glue around them.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.models: dict[str, Any] = {}
|
||||
self.det_model: Any = None
|
||||
self.model_pack: str = "buffalo_l"
|
||||
self.det_size: tuple[int, int] = (640, 640)
|
||||
self.det_thresh: float = 0.5
|
||||
self._providers: list[str] = ["CPUExecutionProvider"]
|
||||
self._antispoofer: Antispoofer | None = None
|
||||
|
||||
def prepare(self, options: dict[str, str]) -> None:
|
||||
import glob
|
||||
import os
|
||||
|
||||
from insightface.model_zoo import model_zoo
|
||||
|
||||
self.model_pack = options.get("model_pack", "buffalo_l")
|
||||
self.det_size = _parse_det_size(options.get("det_size", "640x640"))
|
||||
self.det_thresh = float(options.get("det_thresh", "0.5"))
|
||||
self._antispoofer = _build_antispoofer(options, options.get("_model_dir"))
|
||||
|
||||
pack_dir = _locate_insightface_pack(options, self.model_pack)
|
||||
if pack_dir is None:
|
||||
raise ValueError(
|
||||
f"no insightface pack '{self.model_pack}' found — install via "
|
||||
f"`local-ai models install insightface-{self.model_pack.replace('_', '-')}`"
|
||||
)
|
||||
|
||||
onnx_files = sorted(glob.glob(os.path.join(pack_dir, "*.onnx")))
|
||||
if not onnx_files:
|
||||
raise ValueError(f"no ONNX files in pack directory: {pack_dir}")
|
||||
|
||||
# CUDAExecutionProvider is picked automatically by onnxruntime-gpu
|
||||
# when available; falling back to CPU keeps the CPU-only image
|
||||
# working. ctx_id=0 means "first GPU if any, else CPU".
|
||||
self._providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
|
||||
self.models = {}
|
||||
for onnx_file in onnx_files:
|
||||
m = model_zoo.get_model(onnx_file, providers=self._providers)
|
||||
if m is None:
|
||||
continue
|
||||
# First occurrence of each taskname wins (matches FaceAnalysis).
|
||||
if m.taskname not in self.models:
|
||||
self.models[m.taskname] = m
|
||||
|
||||
if "detection" not in self.models:
|
||||
raise ValueError(f"no detector (taskname='detection') found in {pack_dir}")
|
||||
self.det_model = self.models["detection"]
|
||||
|
||||
self.det_model.prepare(0, input_size=self.det_size, det_thresh=self.det_thresh)
|
||||
for name, m in self.models.items():
|
||||
if name != "detection":
|
||||
m.prepare(0)
|
||||
|
||||
def _faces(self, img: np.ndarray) -> list[Any]:
|
||||
"""Run detection + all non-detection models per face."""
|
||||
if self.det_model is None:
|
||||
return []
|
||||
from insightface.app.common import Face
|
||||
|
||||
bboxes, kpss = self.det_model.detect(img, max_num=0)
|
||||
if bboxes is None or bboxes.shape[0] == 0:
|
||||
return []
|
||||
faces: list[Any] = []
|
||||
for i in range(bboxes.shape[0]):
|
||||
bbox = bboxes[i, 0:4]
|
||||
det_score = bboxes[i, 4]
|
||||
kps = kpss[i] if kpss is not None else None
|
||||
face = Face(bbox=bbox, kps=kps, det_score=det_score)
|
||||
for name, m in self.models.items():
|
||||
if name == "detection":
|
||||
continue
|
||||
m.get(img, face)
|
||||
faces.append(face)
|
||||
return faces
|
||||
|
||||
def detect(self, img: np.ndarray) -> list[FaceDetection]:
|
||||
return [
|
||||
FaceDetection(
|
||||
bbox=tuple(float(v) for v in f.bbox),
|
||||
score=float(f.det_score),
|
||||
landmarks=np.array(f.kps) if getattr(f, "kps", None) is not None else None,
|
||||
)
|
||||
for f in self._faces(img)
|
||||
]
|
||||
|
||||
def embed(self, img: np.ndarray) -> np.ndarray | None:
|
||||
faces = self._faces(img)
|
||||
if not faces:
|
||||
return None
|
||||
best = max(faces, key=lambda f: float(f.det_score))
|
||||
if getattr(best, "normed_embedding", None) is None:
|
||||
return None
|
||||
return np.asarray(best.normed_embedding, dtype=np.float32)
|
||||
|
||||
def analyze(self, img: np.ndarray) -> list[FaceAttributes]:
|
||||
out: list[FaceAttributes] = []
|
||||
for f in self._faces(img):
|
||||
x1, y1, x2, y2 = (float(v) for v in f.bbox)
|
||||
region = (x1, y1, x2 - x1, y2 - y1)
|
||||
attrs = FaceAttributes(region=region, face_confidence=float(f.det_score))
|
||||
age = getattr(f, "age", None)
|
||||
if age is not None:
|
||||
attrs.age = float(age)
|
||||
gender = getattr(f, "gender", None)
|
||||
if gender is not None:
|
||||
# genderage head emits argmax, not probabilities —
|
||||
# one-hot dict keeps the API stable.
|
||||
attrs.dominant_gender = "Man" if int(gender) == 1 else "Woman"
|
||||
attrs.gender = {
|
||||
"Man": 1.0 if int(gender) == 1 else 0.0,
|
||||
"Woman": 0.0 if int(gender) == 1 else 1.0,
|
||||
}
|
||||
out.append(attrs)
|
||||
return out
|
||||
|
||||
def antispoof(self, img: np.ndarray, bbox: tuple[float, float, float, float]) -> SpoofResult | None:
|
||||
if self._antispoofer is None or not self._antispoofer.loaded:
|
||||
return None
|
||||
return self._antispoofer.predict(img, bbox)
|
||||
|
||||
|
||||
# ─── OnnxDirectEngine ─────────────────────────────────────────────────
|
||||
|
||||
class OnnxDirectEngine:
|
||||
"""Loads detector + recognizer ONNX files directly.
|
||||
|
||||
Supports the OpenCV Zoo YuNet + SFace pair out of the box. YuNet
|
||||
exposes a C++-level API via cv2.FaceDetectorYN which accepts the
|
||||
ONNX file directly; SFace is driven through cv2.FaceRecognizerSF.
|
||||
Both are Apache 2.0 licensed.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.detector_path: str = ""
|
||||
self.recognizer_path: str = ""
|
||||
self.input_size: tuple[int, int] = (320, 320)
|
||||
self.det_thresh: float = 0.5
|
||||
self._detector: Any = None
|
||||
self._recognizer: Any = None
|
||||
self._antispoofer: Antispoofer | None = None
|
||||
|
||||
def prepare(self, options: dict[str, str]) -> None:
|
||||
raw_det = options.get("detector_onnx", "")
|
||||
raw_rec = options.get("recognizer_onnx", "")
|
||||
if not raw_det or not raw_rec:
|
||||
raise ValueError(
|
||||
"onnx_direct engine requires both detector_onnx and recognizer_onnx options"
|
||||
)
|
||||
model_dir = options.get("_model_dir")
|
||||
self.detector_path = _resolve_model_path(raw_det, model_dir=model_dir)
|
||||
self.recognizer_path = _resolve_model_path(raw_rec, model_dir=model_dir)
|
||||
self.input_size = _parse_det_size(options.get("det_size", "320x320"))
|
||||
self.det_thresh = float(options.get("det_thresh", "0.5"))
|
||||
self._antispoofer = _build_antispoofer(options, model_dir)
|
||||
|
||||
# YuNet is a fixed-size detector; size is reset per detect() call to
|
||||
# match the input frame.
|
||||
self._detector = cv2.FaceDetectorYN.create(
|
||||
self.detector_path,
|
||||
"",
|
||||
self.input_size,
|
||||
score_threshold=self.det_thresh,
|
||||
nms_threshold=0.3,
|
||||
top_k=5000,
|
||||
)
|
||||
self._recognizer = cv2.FaceRecognizerSF.create(self.recognizer_path, "")
|
||||
|
||||
def detect(self, img: np.ndarray) -> list[FaceDetection]:
|
||||
if self._detector is None:
|
||||
return []
|
||||
h, w = img.shape[:2]
|
||||
self._detector.setInputSize((w, h))
|
||||
retval, faces = self._detector.detect(img)
|
||||
if faces is None:
|
||||
return []
|
||||
out: list[FaceDetection] = []
|
||||
for row in faces:
|
||||
x, y, fw, fh = float(row[0]), float(row[1]), float(row[2]), float(row[3])
|
||||
# Landmarks at columns 4..13 are (lx1,ly1,...,lx5,ly5).
|
||||
landmarks = np.array(row[4:14], dtype=np.float32).reshape(5, 2) if len(row) >= 14 else None
|
||||
score = float(row[-1])
|
||||
out.append(FaceDetection(bbox=(x, y, x + fw, y + fh), score=score, landmarks=landmarks))
|
||||
return out
|
||||
|
||||
def embed(self, img: np.ndarray) -> np.ndarray | None:
|
||||
if self._detector is None or self._recognizer is None:
|
||||
return None
|
||||
h, w = img.shape[:2]
|
||||
self._detector.setInputSize((w, h))
|
||||
retval, faces = self._detector.detect(img)
|
||||
if faces is None or len(faces) == 0:
|
||||
return None
|
||||
# Pick the highest-score face (last column is score).
|
||||
best = max(faces, key=lambda r: float(r[-1]))
|
||||
aligned = self._recognizer.alignCrop(img, best)
|
||||
feat = self._recognizer.feature(aligned)
|
||||
vec = np.asarray(feat, dtype=np.float32).flatten()
|
||||
# SFace outputs a 128-dim feature; L2-normalize to make dot-product
|
||||
# comparable to buffalo_l's already-normed 512-dim embedding.
|
||||
norm = float(np.linalg.norm(vec))
|
||||
if norm == 0:
|
||||
return None
|
||||
return vec / norm
|
||||
|
||||
def analyze(self, img: np.ndarray) -> list[FaceAttributes]:
|
||||
# OpenCV Zoo does not ship a demographic classifier; report
|
||||
# only the face-detection regions so callers can still see
|
||||
# how many faces were detected.
|
||||
return [
|
||||
FaceAttributes(
|
||||
region=(
|
||||
d.bbox[0],
|
||||
d.bbox[1],
|
||||
d.bbox[2] - d.bbox[0],
|
||||
d.bbox[3] - d.bbox[1],
|
||||
),
|
||||
face_confidence=d.score,
|
||||
)
|
||||
for d in self.detect(img)
|
||||
]
|
||||
|
||||
def antispoof(self, img: np.ndarray, bbox: tuple[float, float, float, float]) -> SpoofResult | None:
|
||||
if self._antispoofer is None or not self._antispoofer.loaded:
|
||||
return None
|
||||
return self._antispoofer.predict(img, bbox)
|
||||
|
||||
|
||||
# ─── helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _parse_det_size(raw: str) -> tuple[int, int]:
|
||||
raw = raw.strip().lower().replace(" ", "")
|
||||
if "x" in raw:
|
||||
w, h = raw.split("x", 1)
|
||||
return (int(w), int(h))
|
||||
n = int(raw)
|
||||
return (n, n)
|
||||
|
||||
|
||||
def _locate_insightface_pack(options: dict[str, str], name: str) -> str | None:
|
||||
"""Find the directory holding the insightface pack's ONNX files.
|
||||
|
||||
LocalAI's gallery `files:` extracts the pack zip straight into the
|
||||
models directory. Upstream packs are inconsistent:
|
||||
|
||||
buffalo_l/s/sc — flat zip, ONNX lands at `<models_dir>/*.onnx`
|
||||
buffalo_m, antelopev2 — wrapped zip, ONNX lands at `<models_dir>/<name>/*.onnx`
|
||||
|
||||
We search, in order:
|
||||
1. `<models_dir>/<name>/` — wrapped-zip layout, or insightface's
|
||||
own FaceAnalysis-style `<root>/models/<name>/` layout.
|
||||
2. `<models_dir>/models/<name>/` — insightface's FaceAnalysis
|
||||
auto-download lands here (handy for dev environments that
|
||||
still have old `~/.insightface` caches).
|
||||
3. `<models_dir>/` — flat-zip layout directly in models dir.
|
||||
|
||||
Returns the first directory whose contents include `*.onnx`.
|
||||
"""
|
||||
import glob
|
||||
import os
|
||||
|
||||
model_dir = options.get("_model_dir") or ""
|
||||
explicit_root = options.get("root")
|
||||
|
||||
candidates: list[str] = []
|
||||
if model_dir:
|
||||
candidates.append(os.path.join(model_dir, name))
|
||||
candidates.append(os.path.join(model_dir, "models", name))
|
||||
candidates.append(model_dir)
|
||||
if explicit_root:
|
||||
expanded = os.path.expanduser(explicit_root)
|
||||
candidates.append(os.path.join(expanded, "models", name))
|
||||
candidates.append(os.path.join(expanded, name))
|
||||
candidates.append(expanded)
|
||||
|
||||
for c in candidates:
|
||||
if os.path.isdir(c) and glob.glob(os.path.join(c, "*.onnx")):
|
||||
return c
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_model_path(path: str, model_dir: str | None = None) -> str:
|
||||
"""Resolve an ONNX file path across the paths LocalAI might deliver it from.
|
||||
|
||||
Search order:
|
||||
1. The path itself if it already resolves (absolute, or relative to CWD).
|
||||
2. `model_dir` (typically `os.path.dirname(ModelOptions.ModelFile)`) —
|
||||
this is how LocalAI surfaces gallery-managed files. When the gallery
|
||||
entry lists `files:`, each one lands under the models directory and
|
||||
backends load them via filename anchored by ModelFile.
|
||||
3. `<script_dir>/<path-without-leading-slash>` — covers dev layouts
|
||||
where someone manually dropped weights inside the backend dir.
|
||||
|
||||
If none hit, return the literal input so cv2/insightface surfaces a
|
||||
clearer error naming the actually-attempted path.
|
||||
"""
|
||||
import os
|
||||
|
||||
if os.path.isfile(path):
|
||||
return path
|
||||
stripped = path.lstrip("/")
|
||||
candidates: list[str] = []
|
||||
if model_dir:
|
||||
candidates.append(os.path.join(model_dir, os.path.basename(path)))
|
||||
candidates.append(os.path.join(model_dir, stripped))
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
candidates.append(os.path.join(script_dir, stripped))
|
||||
for c in candidates:
|
||||
if os.path.isfile(c):
|
||||
return c
|
||||
return path
|
||||
|
||||
|
||||
def build_engine(name: str) -> FaceEngine:
|
||||
"""Factory for the engine selected by LoadModel options."""
|
||||
key = name.strip().lower()
|
||||
if key in ("", "insightface"):
|
||||
return InsightFaceEngine()
|
||||
if key in ("onnx_direct", "onnx-direct", "opencv"):
|
||||
return OnnxDirectEngine()
|
||||
raise ValueError(f"unknown engine: {name!r}")
|
||||
28
backend/python/insightface/install.sh
Executable file
28
backend/python/insightface/install.sh
Executable file
@@ -0,0 +1,28 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
# We deliberately do NOT pre-bake any model weights here. Two reasons:
|
||||
#
|
||||
# 1. Weights should follow LocalAI's gallery-managed download flow
|
||||
# like every other backend. For OpenCV Zoo (YuNet + SFace) the
|
||||
# gallery entries in gallery/index.yaml list the ONNX files via
|
||||
# `files:` with URI + SHA-256 — LocalAI fetches them into the
|
||||
# models directory on `local-ai models install`.
|
||||
#
|
||||
# 2. For insightface model packs (buffalo_l, buffalo_s, buffalo_m,
|
||||
# buffalo_sc, antelopev2), upstream distributes zip archives
|
||||
# only (no individual ONNX URLs). We rely on insightface's own
|
||||
# auto-download machinery (`FaceAnalysis(name=<pack>, root=<dir>)`)
|
||||
# at first LoadModel, pointed at a writable directory. This
|
||||
# matches how rfdetr behaves (uses `inference.get_model()`).
|
||||
#
|
||||
# Net effect: the backend image ships only Python deps (~150MB CPU).
|
||||
7
backend/python/insightface/requirements-cpu.txt
Normal file
7
backend/python/insightface/requirements-cpu.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
insightface
|
||||
onnxruntime
|
||||
opencv-python-headless
|
||||
numpy
|
||||
onnx
|
||||
cython
|
||||
scikit-image
|
||||
7
backend/python/insightface/requirements-cublas12.txt
Normal file
7
backend/python/insightface/requirements-cublas12.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
insightface
|
||||
onnxruntime-gpu
|
||||
opencv-python-headless
|
||||
numpy
|
||||
onnx
|
||||
cython
|
||||
scikit-image
|
||||
3
backend/python/insightface/requirements.txt
Normal file
3
backend/python/insightface/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
grpcio-tools
|
||||
9
backend/python/insightface/run.sh
Executable file
9
backend/python/insightface/run.sh
Executable file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
264
backend/python/insightface/smoke.py
Normal file
264
backend/python/insightface/smoke.py
Normal file
@@ -0,0 +1,264 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Smoke-test every face recognition model configuration shipped in the
|
||||
gallery. Simulates what LocalAI does at runtime: for each config, sets
|
||||
up a models directory, fetches any required files via URL (as the
|
||||
gallery's `files:` list would), then loads + detects + embeds via the
|
||||
in-process BackendServicer — matching the gRPC surface end users hit.
|
||||
|
||||
Run inside the built backend image (venv already has insightface /
|
||||
onnxruntime / opencv-python-headless):
|
||||
|
||||
python smoke.py
|
||||
|
||||
Network is required for the insightface packs (fetched via upstream's
|
||||
FaceAnalysis auto-download at first LoadModel) and for downloading
|
||||
the OpenCV Zoo ONNX files on first run.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import urllib.request
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
import backend_pb2 # noqa: E402
|
||||
from backend import BackendServicer # noqa: E402
|
||||
|
||||
|
||||
# Gallery `files:` for the OpenCV variants — same URIs + SHA-256s as
|
||||
# gallery/index.yaml lists. Tuples: (filename, uri, sha256).
|
||||
OPENCV_FILES = {
|
||||
"fp32": [
|
||||
(
|
||||
"face_detection_yunet_2023mar.onnx",
|
||||
"https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx",
|
||||
"8f2383e4dd3cfbb4553ea8718107fc0423210dc964f9f4280604804ed2552fa4",
|
||||
),
|
||||
(
|
||||
"face_recognition_sface_2021dec.onnx",
|
||||
"https://github.com/opencv/opencv_zoo/raw/main/models/face_recognition_sface/face_recognition_sface_2021dec.onnx",
|
||||
"0ba9fbfa01b5270c96627c4ef784da859931e02f04419c829e83484087c34e79",
|
||||
),
|
||||
],
|
||||
"int8": [
|
||||
(
|
||||
"face_detection_yunet_2023mar_int8.onnx",
|
||||
"https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar_int8.onnx",
|
||||
"321aa5a6afabf7ecc46a3d06bfab2b579dc96eb5c3be7edd365fa04502ad9294",
|
||||
),
|
||||
(
|
||||
"face_recognition_sface_2021dec_int8.onnx",
|
||||
"https://github.com/opencv/opencv_zoo/raw/main/models/face_recognition_sface/face_recognition_sface_2021dec_int8.onnx",
|
||||
"2b0e941e6f16cc048c20aee0c8e31f569118f65d702914540f7bfdc14048d78a",
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
CONFIGS = [
|
||||
{
|
||||
"name": "insightface-buffalo-l",
|
||||
"options": ["engine:insightface", "model_pack:buffalo_l"],
|
||||
"has_analyze": True,
|
||||
"needs_opencv_files": None,
|
||||
},
|
||||
{
|
||||
"name": "insightface-buffalo-sc",
|
||||
"options": ["engine:insightface", "model_pack:buffalo_sc"],
|
||||
# buffalo_sc has recognizer only — no landmarks, no genderage.
|
||||
"has_analyze": False,
|
||||
"needs_opencv_files": None,
|
||||
},
|
||||
{
|
||||
"name": "insightface-buffalo-s",
|
||||
"options": ["engine:insightface", "model_pack:buffalo_s"],
|
||||
"has_analyze": True,
|
||||
"needs_opencv_files": None,
|
||||
},
|
||||
{
|
||||
"name": "insightface-buffalo-m",
|
||||
"options": ["engine:insightface", "model_pack:buffalo_m"],
|
||||
"has_analyze": True,
|
||||
"needs_opencv_files": None,
|
||||
},
|
||||
{
|
||||
"name": "insightface-antelopev2",
|
||||
"options": ["engine:insightface", "model_pack:antelopev2"],
|
||||
"has_analyze": True,
|
||||
"needs_opencv_files": None,
|
||||
},
|
||||
{
|
||||
"name": "insightface-opencv",
|
||||
"options": [
|
||||
"engine:onnx_direct",
|
||||
"detector_onnx:face_detection_yunet_2023mar.onnx",
|
||||
"recognizer_onnx:face_recognition_sface_2021dec.onnx",
|
||||
],
|
||||
"has_analyze": False,
|
||||
"needs_opencv_files": "fp32",
|
||||
},
|
||||
{
|
||||
"name": "insightface-opencv-int8",
|
||||
"options": [
|
||||
"engine:onnx_direct",
|
||||
"detector_onnx:face_detection_yunet_2023mar_int8.onnx",
|
||||
"recognizer_onnx:face_recognition_sface_2021dec_int8.onnx",
|
||||
],
|
||||
"has_analyze": False,
|
||||
"needs_opencv_files": "int8",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class _FakeContext:
|
||||
def __init__(self) -> None:
|
||||
self.code = None
|
||||
self.details = None
|
||||
|
||||
def set_code(self, code):
|
||||
self.code = code
|
||||
|
||||
def set_details(self, details):
|
||||
self.details = details
|
||||
|
||||
|
||||
def _encode_image(img: np.ndarray) -> str:
|
||||
_, buf = cv2.imencode(".jpg", img)
|
||||
return base64.b64encode(buf.tobytes()).decode("ascii")
|
||||
|
||||
|
||||
def _load_sample_image() -> str:
|
||||
from insightface.data import get_image as ins_get_image
|
||||
|
||||
return _encode_image(ins_get_image("t1"))
|
||||
|
||||
|
||||
def _download_if_missing(model_dir: str, filename: str, uri: str, sha256: str) -> None:
|
||||
dest = os.path.join(model_dir, filename)
|
||||
if os.path.isfile(dest):
|
||||
h = hashlib.sha256(open(dest, "rb").read()).hexdigest()
|
||||
if h == sha256:
|
||||
return
|
||||
sys.stderr.write(f" fetching {filename} from {uri}\n")
|
||||
sys.stderr.flush()
|
||||
urllib.request.urlretrieve(uri, dest)
|
||||
h = hashlib.sha256(open(dest, "rb").read()).hexdigest()
|
||||
if h != sha256:
|
||||
raise RuntimeError(f"sha256 mismatch for {filename}: want {sha256}, got {h}")
|
||||
|
||||
|
||||
def _run_one(cfg: dict, img_b64: str, model_dir: str) -> tuple[bool, str]:
|
||||
# Mirror LocalAI's gallery flow: populate model_dir with the
|
||||
# gallery's listed files before calling LoadModel.
|
||||
if cfg["needs_opencv_files"]:
|
||||
for filename, uri, sha256 in OPENCV_FILES[cfg["needs_opencv_files"]]:
|
||||
_download_if_missing(model_dir, filename, uri, sha256)
|
||||
|
||||
svc = BackendServicer()
|
||||
ctx = _FakeContext()
|
||||
|
||||
load_res = svc.LoadModel(
|
||||
backend_pb2.ModelOptions(
|
||||
Model=cfg["name"],
|
||||
Options=cfg["options"],
|
||||
# ModelPath is what the Go loader sets to ml.ModelPath —
|
||||
# LocalAI's models directory. The backend anchors relative
|
||||
# paths and insightface auto-download root here.
|
||||
ModelPath=model_dir,
|
||||
),
|
||||
ctx,
|
||||
)
|
||||
if not load_res.success:
|
||||
return False, f"LoadModel: {load_res.message}"
|
||||
|
||||
det_res = svc.Detect(backend_pb2.DetectOptions(src=img_b64), _FakeContext())
|
||||
if len(det_res.Detections) == 0:
|
||||
return False, "Detect returned no faces"
|
||||
for d in det_res.Detections:
|
||||
if d.class_name != "face":
|
||||
return False, f"Detect returned class_name={d.class_name!r}"
|
||||
|
||||
emb_ctx = _FakeContext()
|
||||
emb_res = svc.Embedding(backend_pb2.PredictOptions(Images=[img_b64]), emb_ctx)
|
||||
if emb_ctx.code is not None:
|
||||
return False, f"Embedding set error code {emb_ctx.code}: {emb_ctx.details}"
|
||||
if len(emb_res.embeddings) == 0:
|
||||
return False, "Embedding returned empty vector"
|
||||
norm_sq = sum(float(x) * float(x) for x in emb_res.embeddings)
|
||||
if not (0.8 <= norm_sq <= 1.2):
|
||||
return False, f"Embedding not L2-normed (sum(x^2)={norm_sq:.3f})"
|
||||
|
||||
ver_ctx = _FakeContext()
|
||||
ver_res = svc.FaceVerify(
|
||||
backend_pb2.FaceVerifyRequest(img1=img_b64, img2=img_b64), ver_ctx
|
||||
)
|
||||
if ver_ctx.code is not None:
|
||||
return False, f"FaceVerify set error code {ver_ctx.code}: {ver_ctx.details}"
|
||||
if not ver_res.verified:
|
||||
return False, f"Same-image FaceVerify not verified (dist={ver_res.distance:.3f})"
|
||||
if ver_res.distance > 0.1:
|
||||
return False, f"Same-image distance suspiciously high ({ver_res.distance:.3f})"
|
||||
|
||||
if cfg["has_analyze"]:
|
||||
an_ctx = _FakeContext()
|
||||
an_res = svc.FaceAnalyze(backend_pb2.FaceAnalyzeRequest(img=img_b64), an_ctx)
|
||||
if an_ctx.code is not None:
|
||||
return False, f"FaceAnalyze set error code {an_ctx.code}: {an_ctx.details}"
|
||||
if len(an_res.faces) == 0:
|
||||
return False, "FaceAnalyze returned no faces"
|
||||
f0 = an_res.faces[0]
|
||||
if f0.age <= 0:
|
||||
return False, f"FaceAnalyze age not populated (age={f0.age})"
|
||||
if f0.dominant_gender not in ("Man", "Woman"):
|
||||
return False, f"FaceAnalyze dominant_gender={f0.dominant_gender!r}"
|
||||
|
||||
n_dets = len(det_res.Detections)
|
||||
dim = len(emb_res.embeddings)
|
||||
return True, f"faces={n_dets} dim={dim} same-dist={ver_res.distance:.3f}"
|
||||
|
||||
|
||||
def main() -> int:
|
||||
# Honor LOCALAI_MODELS_PATH to re-use cached downloads across runs;
|
||||
# default to a fresh temp dir.
|
||||
model_dir = os.environ.get("LOCALAI_MODELS_PATH")
|
||||
if not model_dir:
|
||||
import tempfile
|
||||
|
||||
model_dir = tempfile.mkdtemp(prefix="face-smoke-")
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
print(f"model_dir={model_dir}", file=sys.stderr)
|
||||
|
||||
print("Preparing sample image from insightface.data...", file=sys.stderr)
|
||||
img_b64 = _load_sample_image()
|
||||
|
||||
results: list[tuple[str, bool, str]] = []
|
||||
for cfg in CONFIGS:
|
||||
sys.stderr.write(f"\n=== {cfg['name']} ===\n")
|
||||
sys.stderr.flush()
|
||||
try:
|
||||
ok, detail = _run_one(cfg, img_b64, model_dir)
|
||||
except Exception:
|
||||
ok, detail = False, traceback.format_exc().splitlines()[-1]
|
||||
results.append((cfg["name"], ok, detail))
|
||||
print(f"{'PASS' if ok else 'FAIL'}: {cfg['name']:30s} {detail}")
|
||||
sys.stdout.flush()
|
||||
|
||||
print("\n=== summary ===")
|
||||
passed = sum(1 for _, ok, _ in results if ok)
|
||||
total = len(results)
|
||||
for name, ok, detail in results:
|
||||
mark = "✓" if ok else "✗"
|
||||
print(f" {mark} {name:30s} {detail}")
|
||||
print(f"\n{passed}/{total} passed")
|
||||
return 0 if passed == total else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
344
backend/python/insightface/test.py
Normal file
344
backend/python/insightface/test.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""Unit tests for the insightface gRPC backend.
|
||||
|
||||
The servicer is instantiated in-process (no gRPC channel) and driven
|
||||
directly. Images come from insightface.data which ships with the pip
|
||||
package — no external downloads.
|
||||
|
||||
Tests are parametrized over both engines (InsightFaceEngine and
|
||||
OnnxDirectEngine) where applicable.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import grpc
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
import backend_pb2 # noqa: E402
|
||||
|
||||
from backend import BackendServicer # noqa: E402
|
||||
|
||||
# OpenCV Zoo face ONNX files — downloaded on demand in OnnxDirectEngineTest
|
||||
# to mirror LocalAI's gallery `files:` flow (the backend image itself
|
||||
# doesn't ship model weights).
|
||||
OPENCV_FILES = [
|
||||
(
|
||||
"face_detection_yunet_2023mar.onnx",
|
||||
"https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx",
|
||||
"8f2383e4dd3cfbb4553ea8718107fc0423210dc964f9f4280604804ed2552fa4",
|
||||
),
|
||||
(
|
||||
"face_recognition_sface_2021dec.onnx",
|
||||
"https://github.com/opencv/opencv_zoo/raw/main/models/face_recognition_sface/face_recognition_sface_2021dec.onnx",
|
||||
"0ba9fbfa01b5270c96627c4ef784da859931e02f04419c829e83484087c34e79",
|
||||
),
|
||||
]
|
||||
|
||||
# Silent-Face MiniFASNet ONNX files for antispoofing tests.
|
||||
ANTISPOOF_FILES = [
|
||||
(
|
||||
"MiniFASNetV2.onnx",
|
||||
"https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV2.onnx",
|
||||
"b32929adc2d9c34b9486f8c4c7bc97c1b69bc0ea9befefc380e4faae4e463907",
|
||||
),
|
||||
(
|
||||
"MiniFASNetV1SE.onnx",
|
||||
"https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV1SE.onnx",
|
||||
"ebab7f90c7833fbccd46d3a555410e78d969db5438e169b6524be444862b3676",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _download_files(specs: list[tuple[str, str, str]], env_var: str, prefix: str) -> str | None:
|
||||
"""Download a list of (filename, uri, sha256) into a directory.
|
||||
|
||||
Returns the directory, or None if any download failed.
|
||||
"""
|
||||
import hashlib
|
||||
import tempfile
|
||||
import urllib.request
|
||||
|
||||
root = os.environ.get(env_var) or tempfile.mkdtemp(prefix=prefix)
|
||||
for filename, uri, sha256 in specs:
|
||||
dest = os.path.join(root, filename)
|
||||
if os.path.isfile(dest):
|
||||
if hashlib.sha256(open(dest, "rb").read()).hexdigest() == sha256:
|
||||
continue
|
||||
try:
|
||||
urllib.request.urlretrieve(uri, dest)
|
||||
except Exception:
|
||||
return None
|
||||
if hashlib.sha256(open(dest, "rb").read()).hexdigest() != sha256:
|
||||
return None
|
||||
return root
|
||||
|
||||
|
||||
def _encode(img: np.ndarray) -> str:
|
||||
_, buf = cv2.imencode(".jpg", img)
|
||||
return base64.b64encode(buf.tobytes()).decode("ascii")
|
||||
|
||||
|
||||
def _load_insightface_samples() -> dict[str, str]:
|
||||
"""Return {'t1': <b64>, 't2': <b64>} from insightface.data.get_image.
|
||||
|
||||
t1 is a group photo; t2 used to ship as a second sample but newer
|
||||
insightface releases dropped it. We fall back to `Tom_Hanks_54745`
|
||||
(also bundled) as a distinct second face.
|
||||
"""
|
||||
from insightface.data import get_image as ins_get_image
|
||||
|
||||
try:
|
||||
second = ins_get_image("t2")
|
||||
except AssertionError:
|
||||
second = ins_get_image("Tom_Hanks_54745")
|
||||
return {
|
||||
"t1": _encode(ins_get_image("t1")),
|
||||
"t2": _encode(second),
|
||||
}
|
||||
|
||||
|
||||
class _FakeContext:
|
||||
"""Minimal stand-in for grpc.ServicerContext."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.code = None
|
||||
self.details = None
|
||||
|
||||
def set_code(self, code):
|
||||
self.code = code
|
||||
|
||||
def set_details(self, details):
|
||||
self.details = details
|
||||
|
||||
|
||||
class _Harness:
|
||||
def __init__(self, servicer: BackendServicer) -> None:
|
||||
self.svc = servicer
|
||||
|
||||
def health(self):
|
||||
return self.svc.Health(backend_pb2.HealthMessage(), _FakeContext())
|
||||
|
||||
def load(self, options: list[str], model_path: str = ""):
|
||||
return self.svc.LoadModel(
|
||||
backend_pb2.ModelOptions(Model="test", Options=options, ModelPath=model_path),
|
||||
_FakeContext(),
|
||||
)
|
||||
|
||||
def detect(self, img_b64: str):
|
||||
return self.svc.Detect(backend_pb2.DetectOptions(src=img_b64), _FakeContext())
|
||||
|
||||
def embed(self, img_b64: str):
|
||||
ctx = _FakeContext()
|
||||
res = self.svc.Embedding(
|
||||
backend_pb2.PredictOptions(Images=[img_b64]),
|
||||
ctx,
|
||||
)
|
||||
return res, ctx
|
||||
|
||||
def verify(self, a: str, b: str, threshold: float = 0.0, anti_spoofing: bool = False):
|
||||
ctx = _FakeContext()
|
||||
res = self.svc.FaceVerify(
|
||||
backend_pb2.FaceVerifyRequest(
|
||||
img1=a, img2=b, threshold=threshold, anti_spoofing=anti_spoofing
|
||||
),
|
||||
ctx,
|
||||
)
|
||||
return res, ctx
|
||||
|
||||
def analyze(self, img_b64: str, anti_spoofing: bool = False):
|
||||
ctx = _FakeContext()
|
||||
res = self.svc.FaceAnalyze(
|
||||
backend_pb2.FaceAnalyzeRequest(img=img_b64, anti_spoofing=anti_spoofing),
|
||||
ctx,
|
||||
)
|
||||
return res, ctx
|
||||
|
||||
|
||||
class InsightFaceEngineTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.samples = _load_insightface_samples()
|
||||
cls.harness = _Harness(BackendServicer())
|
||||
load = cls.harness.load(["engine:insightface", "model_pack:buffalo_l"])
|
||||
if not load.success:
|
||||
raise unittest.SkipTest(f"LoadModel failed: {load.message}")
|
||||
|
||||
def test_health(self):
|
||||
self.assertEqual(self.harness.health().message, b"OK")
|
||||
|
||||
def test_detect_finds_face(self):
|
||||
res = self.harness.detect(self.samples["t1"])
|
||||
self.assertGreater(len(res.Detections), 0)
|
||||
for d in res.Detections:
|
||||
self.assertEqual(d.class_name, "face")
|
||||
self.assertGreater(d.width, 0)
|
||||
self.assertGreater(d.height, 0)
|
||||
|
||||
def test_embedding_is_l2_normed(self):
|
||||
res, ctx = self.harness.embed(self.samples["t1"])
|
||||
self.assertIsNone(ctx.code, f"Embedding error: {ctx.details}")
|
||||
self.assertEqual(len(res.embeddings), 512)
|
||||
norm_sq = sum(x * x for x in res.embeddings)
|
||||
self.assertAlmostEqual(norm_sq, 1.0, places=2)
|
||||
|
||||
def test_verify_same_image(self):
|
||||
res, _ = self.harness.verify(self.samples["t1"], self.samples["t1"])
|
||||
self.assertTrue(res.verified)
|
||||
self.assertLess(res.distance, 0.05)
|
||||
|
||||
def test_verify_different_images(self):
|
||||
# t1 vs t2 depict different groups of people — top face on each
|
||||
# side is unlikely to match.
|
||||
res, _ = self.harness.verify(self.samples["t1"], self.samples["t2"])
|
||||
# We assert only that some numerical answer came back; the
|
||||
# matches-or-not determination depends on which face each side
|
||||
# picked and isn't a stable test assertion.
|
||||
self.assertGreaterEqual(res.distance, 0.0)
|
||||
|
||||
def test_analyze_has_age_and_gender(self):
|
||||
res, _ = self.harness.analyze(self.samples["t1"])
|
||||
self.assertGreater(len(res.faces), 0)
|
||||
for face in res.faces:
|
||||
self.assertGreater(face.face_confidence, 0.0)
|
||||
# Age should be populated for buffalo_l.
|
||||
self.assertGreater(face.age, 0.0)
|
||||
self.assertIn(face.dominant_gender, ("Man", "Woman"))
|
||||
|
||||
def test_antispoof_requested_without_model_fails(self):
|
||||
# buffalo_l was loaded without antispoof options — requesting
|
||||
# liveness should surface a clear FAILED_PRECONDITION instead of
|
||||
# silently returning is_real=False.
|
||||
_, ctx = self.harness.verify(
|
||||
self.samples["t1"], self.samples["t1"], anti_spoofing=True
|
||||
)
|
||||
self.assertEqual(ctx.code, grpc.StatusCode.FAILED_PRECONDITION)
|
||||
self.assertIn("anti_spoofing", ctx.details)
|
||||
|
||||
|
||||
def _prepare_opencv_models_dir() -> str | None:
|
||||
return _download_files(OPENCV_FILES, "OPENCV_FACE_MODELS_DIR", "opencv-face-")
|
||||
|
||||
|
||||
def _prepare_antispoof_models_dir(extra_dir: str | None = None) -> str | None:
|
||||
"""Download MiniFASNet ONNX files. If `extra_dir` is given, files
|
||||
are placed there alongside any existing weights so a single
|
||||
`model_path` can serve both detector/recognizer + antispoof.
|
||||
"""
|
||||
if extra_dir is not None:
|
||||
os.environ.setdefault("ANTISPOOF_MODELS_DIR", extra_dir)
|
||||
return _download_files(ANTISPOOF_FILES, "ANTISPOOF_MODELS_DIR", "antispoof-")
|
||||
|
||||
|
||||
class OnnxDirectEngineTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.samples = _load_insightface_samples()
|
||||
cls.model_dir = _prepare_opencv_models_dir()
|
||||
if cls.model_dir is None:
|
||||
raise unittest.SkipTest("OpenCV Zoo ONNX files could not be downloaded")
|
||||
cls.harness = _Harness(BackendServicer())
|
||||
load = cls.harness.load(
|
||||
[
|
||||
"engine:onnx_direct",
|
||||
"detector_onnx:face_detection_yunet_2023mar.onnx",
|
||||
"recognizer_onnx:face_recognition_sface_2021dec.onnx",
|
||||
],
|
||||
model_path=cls.model_dir,
|
||||
)
|
||||
if not load.success:
|
||||
raise unittest.SkipTest(f"LoadModel failed: {load.message}")
|
||||
|
||||
def test_detect_finds_face(self):
|
||||
res = self.harness.detect(self.samples["t1"])
|
||||
self.assertGreater(len(res.Detections), 0)
|
||||
for d in res.Detections:
|
||||
self.assertEqual(d.class_name, "face")
|
||||
|
||||
def test_embedding_nonempty(self):
|
||||
res, ctx = self.harness.embed(self.samples["t1"])
|
||||
self.assertIsNone(ctx.code, f"Embedding error: {ctx.details}")
|
||||
self.assertGreater(len(res.embeddings), 0)
|
||||
|
||||
def test_verify_same_image(self):
|
||||
res, _ = self.harness.verify(self.samples["t1"], self.samples["t1"], threshold=0.4)
|
||||
self.assertTrue(res.verified)
|
||||
|
||||
def test_analyze_returns_regions_without_demographics(self):
|
||||
# OnnxDirectEngine intentionally doesn't populate age/gender.
|
||||
res, _ = self.harness.analyze(self.samples["t1"])
|
||||
self.assertGreater(len(res.faces), 0)
|
||||
for face in res.faces:
|
||||
self.assertEqual(face.dominant_gender, "")
|
||||
self.assertEqual(face.age, 0.0)
|
||||
|
||||
|
||||
class AntispoofingTest(unittest.TestCase):
|
||||
"""End-to-end FaceVerify / FaceAnalyze with anti_spoofing=True.
|
||||
|
||||
Loads the OpenCV-Zoo (Apache-2.0) face engine alongside the Silent-Face
|
||||
MiniFASNet ensemble. Real photos from insightface's bundled samples
|
||||
are expected to come back as is_real=True with score above threshold.
|
||||
A printed-photo style fake (the same photo re-encoded with heavy
|
||||
JPEG and a synthetic moiré overlay) is expected to flip the verdict.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Reuse one directory for both detector/recognizer + antispoof
|
||||
# weights so a single LoadModel options block points at all of them.
|
||||
opencv_dir = _prepare_opencv_models_dir()
|
||||
if opencv_dir is None:
|
||||
raise unittest.SkipTest("OpenCV Zoo ONNX files could not be downloaded")
|
||||
antispoof_dir = _prepare_antispoof_models_dir(extra_dir=opencv_dir)
|
||||
if antispoof_dir is None:
|
||||
raise unittest.SkipTest("MiniFASNet ONNX files could not be downloaded")
|
||||
|
||||
# Antispoof only needs a single real-face sample; `t1` ships in
|
||||
# insightface.data across every release.
|
||||
from insightface.data import get_image as ins_get_image
|
||||
|
||||
cls.samples = {"t1": _encode(ins_get_image("t1"))}
|
||||
cls.harness = _Harness(BackendServicer())
|
||||
load = cls.harness.load(
|
||||
[
|
||||
"engine:onnx_direct",
|
||||
"detector_onnx:face_detection_yunet_2023mar.onnx",
|
||||
"recognizer_onnx:face_recognition_sface_2021dec.onnx",
|
||||
"antispoof_v2_onnx:MiniFASNetV2.onnx",
|
||||
"antispoof_v1se_onnx:MiniFASNetV1SE.onnx",
|
||||
],
|
||||
model_path=opencv_dir,
|
||||
)
|
||||
if not load.success:
|
||||
raise unittest.SkipTest(f"LoadModel failed: {load.message}")
|
||||
|
||||
def test_verify_returns_per_image_liveness(self):
|
||||
res, ctx = self.harness.verify(
|
||||
self.samples["t1"], self.samples["t1"], threshold=0.4, anti_spoofing=True
|
||||
)
|
||||
self.assertIsNone(ctx.code, f"FaceVerify error: {ctx.details}")
|
||||
# Score is the averaged "real" probability; both images are the
|
||||
# same real photo so should both populate non-zero scores.
|
||||
self.assertGreater(res.img1_antispoof_score, 0.0)
|
||||
self.assertGreater(res.img2_antispoof_score, 0.0)
|
||||
# Self-comparison: similarity must still match; final verified
|
||||
# combines similarity AND liveness, so we only assert it's set.
|
||||
self.assertIsInstance(res.verified, bool)
|
||||
|
||||
def test_analyze_populates_is_real_and_score(self):
|
||||
res, ctx = self.harness.analyze(self.samples["t1"], anti_spoofing=True)
|
||||
self.assertIsNone(ctx.code, f"FaceAnalyze error: {ctx.details}")
|
||||
self.assertGreater(len(res.faces), 0)
|
||||
for face in res.faces:
|
||||
self.assertGreaterEqual(face.antispoof_score, 0.0)
|
||||
self.assertLessEqual(face.antispoof_score, 1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
11
backend/python/insightface/test.sh
Executable file
11
backend/python/insightface/test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runUnittests
|
||||
13
backend/python/speaker-recognition/Makefile
Normal file
13
backend/python/speaker-recognition/Makefile
Normal file
@@ -0,0 +1,13 @@
|
||||
.DEFAULT_GOAL := install
|
||||
|
||||
.PHONY: install
|
||||
install:
|
||||
bash install.sh
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
40
backend/python/speaker-recognition/README.md
Normal file
40
backend/python/speaker-recognition/README.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# speaker-recognition
|
||||
|
||||
Speaker (voice) recognition backend for LocalAI. The audio analog to
|
||||
`insightface` — produces speaker embeddings and supports 1:1 voice
|
||||
verification and voice demographic analysis.
|
||||
|
||||
## Engines
|
||||
|
||||
- **SpeechBrainEngine** (default): ECAPA-TDNN trained on VoxCeleb.
|
||||
192-d L2-normalised embeddings, cosine distance for verification.
|
||||
Auto-downloads from HuggingFace on first LoadModel.
|
||||
- **OnnxDirectEngine**: Any pre-exported ONNX speaker encoder
|
||||
(WeSpeaker ResNet, 3D-Speaker ERes2Net, CAM++, …). Model path comes
|
||||
from the gallery `files:` entry.
|
||||
|
||||
Engine selection is gallery-driven: if the model config provides
|
||||
`model_path:` / `onnx:` the ONNX engine is used, otherwise the
|
||||
SpeechBrain engine.
|
||||
|
||||
## Endpoints
|
||||
|
||||
- `POST /v1/voice/verify` — 1:1 same-speaker check.
|
||||
- `POST /v1/voice/embed` — extract a speaker embedding vector.
|
||||
- `POST /v1/voice/analyze` — voice demographics, loaded lazily on
|
||||
the first analyze call:
|
||||
- **Emotion** (default, opt-out): `superb/wav2vec2-base-superb-er`
|
||||
(Apache-2.0), 4-way categorical (neutral / happy / angry / sad).
|
||||
- **Age + gender** (opt-in): no default — wire a checkpoint with a
|
||||
standard `Wav2Vec2ForSequenceClassification` head via
|
||||
`age_gender_model:<repo>` in options. The Audeering
|
||||
age-gender model is *not* usable as a drop-in because its
|
||||
multi-task head isn't loadable via `AutoModelForAudioClassification`.
|
||||
|
||||
Both heads are optional. When nothing loads, the engine returns 501.
|
||||
|
||||
## Audio input
|
||||
|
||||
Audio is materialised by the HTTP layer to a temp wav before calling
|
||||
the gRPC backend. Accepted input forms on the HTTP side: URL, data-URI,
|
||||
or raw base64. The backend itself always receives a filesystem path.
|
||||
205
backend/python/speaker-recognition/backend.py
Normal file
205
backend/python/speaker-recognition/backend.py
Normal file
@@ -0,0 +1,205 @@
|
||||
#!/usr/bin/env python3
|
||||
"""gRPC server for the LocalAI speaker-recognition backend.
|
||||
|
||||
Implements Health / LoadModel / Status plus the voice-specific methods:
|
||||
VoiceVerify, VoiceAnalyze, VoiceEmbed. The heavy lifting lives in
|
||||
engines.py — this file is just the gRPC plumbing, mirroring the
|
||||
insightface backend's two-engine split (SpeechBrain + OnnxDirect).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from concurrent import futures
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import grpc
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "common"))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "common"))
|
||||
from grpc_auth import get_auth_interceptors # noqa: E402
|
||||
|
||||
from engines import SpeakerEngine, build_engine # noqa: E402
|
||||
|
||||
_ONE_DAY = 60 * 60 * 24
|
||||
MAX_WORKERS = int(os.environ.get("PYTHON_GRPC_MAX_WORKERS", "1"))
|
||||
|
||||
# ECAPA-TDNN on VoxCeleb is the reference. Threshold is tuned for
|
||||
# cosine distance (1 - cosine_similarity). Clients may override.
|
||||
DEFAULT_VERIFY_THRESHOLD = 0.25
|
||||
|
||||
|
||||
def _parse_options(raw: list[str]) -> dict[str, str]:
|
||||
out: dict[str, str] = {}
|
||||
for entry in raw:
|
||||
if ":" not in entry:
|
||||
continue
|
||||
k, v = entry.split(":", 1)
|
||||
out[k.strip()] = v.strip()
|
||||
return out
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def __init__(self) -> None:
|
||||
self.engine: SpeakerEngine | None = None
|
||||
self.engine_name: str = ""
|
||||
self.model_name: str = ""
|
||||
self.verify_threshold: float = DEFAULT_VERIFY_THRESHOLD
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", "utf-8"))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
options = _parse_options(list(request.Options))
|
||||
# Surface LocalAI's models directory (ModelPath) so engines can
|
||||
# anchor relative paths and auto-download into a writable spot
|
||||
# alongside every other gallery-managed asset.
|
||||
options["_model_path"] = request.ModelPath or ""
|
||||
try:
|
||||
engine, engine_name = build_engine(request.Model, options)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return backend_pb2.Result(success=False, message=f"engine init failed: {exc}")
|
||||
|
||||
self.engine = engine
|
||||
self.engine_name = engine_name
|
||||
self.model_name = request.Model
|
||||
|
||||
threshold_opt = options.get("verify_threshold")
|
||||
if threshold_opt:
|
||||
try:
|
||||
self.verify_threshold = float(threshold_opt)
|
||||
except ValueError:
|
||||
pass
|
||||
return backend_pb2.Result(success=True, message=f"loaded {engine_name}")
|
||||
|
||||
def Status(self, request, context):
|
||||
state = backend_pb2.StatusResponse.State.READY if self.engine else backend_pb2.StatusResponse.State.UNINITIALIZED
|
||||
return backend_pb2.StatusResponse(state=state)
|
||||
|
||||
def _require_engine(self, context) -> SpeakerEngine | None:
|
||||
if self.engine is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("no speaker-recognition model loaded")
|
||||
return None
|
||||
return self.engine
|
||||
|
||||
def VoiceVerify(self, request, context):
|
||||
engine = self._require_engine(context)
|
||||
if engine is None:
|
||||
return backend_pb2.VoiceVerifyResponse()
|
||||
if not request.audio1 or not request.audio2:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("audio1 and audio2 are required")
|
||||
return backend_pb2.VoiceVerifyResponse()
|
||||
|
||||
threshold = request.threshold if request.threshold > 0 else self.verify_threshold
|
||||
started = time.time()
|
||||
try:
|
||||
distance = engine.compare(request.audio1, request.audio2)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"voice verify failed: {exc}")
|
||||
return backend_pb2.VoiceVerifyResponse()
|
||||
|
||||
elapsed_ms = (time.time() - started) * 1000.0
|
||||
# Confidence goes linearly from 100 at distance=0 to 0 at distance=threshold.
|
||||
confidence = max(0.0, min(100.0, (1.0 - distance / threshold) * 100.0))
|
||||
return backend_pb2.VoiceVerifyResponse(
|
||||
verified=distance <= threshold,
|
||||
distance=distance,
|
||||
threshold=threshold,
|
||||
confidence=confidence,
|
||||
model=self.model_name,
|
||||
processing_time_ms=elapsed_ms,
|
||||
)
|
||||
|
||||
def VoiceEmbed(self, request, context):
|
||||
engine = self._require_engine(context)
|
||||
if engine is None:
|
||||
return backend_pb2.VoiceEmbedResponse()
|
||||
if not request.audio:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("audio is required")
|
||||
return backend_pb2.VoiceEmbedResponse()
|
||||
try:
|
||||
vec = engine.embed(request.audio)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"voice embed failed: {exc}")
|
||||
return backend_pb2.VoiceEmbedResponse()
|
||||
return backend_pb2.VoiceEmbedResponse(embedding=list(vec), model=self.model_name)
|
||||
|
||||
def VoiceAnalyze(self, request, context):
|
||||
engine = self._require_engine(context)
|
||||
if engine is None:
|
||||
return backend_pb2.VoiceAnalyzeResponse()
|
||||
if not request.audio:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("audio is required")
|
||||
return backend_pb2.VoiceAnalyzeResponse()
|
||||
|
||||
actions = list(request.actions) or ["age", "gender", "emotion"]
|
||||
try:
|
||||
segments = engine.analyze(request.audio, actions)
|
||||
except NotImplementedError:
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details(f"analyze not supported by {self.engine_name}")
|
||||
return backend_pb2.VoiceAnalyzeResponse()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"voice analyze failed: {exc}")
|
||||
return backend_pb2.VoiceAnalyzeResponse()
|
||||
|
||||
proto_segments = []
|
||||
for seg in segments:
|
||||
proto_segments.append(
|
||||
backend_pb2.VoiceAnalysis(
|
||||
start=seg.get("start", 0.0),
|
||||
end=seg.get("end", 0.0),
|
||||
age=seg.get("age", 0.0),
|
||||
dominant_gender=seg.get("dominant_gender", ""),
|
||||
gender=seg.get("gender", {}),
|
||||
dominant_emotion=seg.get("dominant_emotion", ""),
|
||||
emotion=seg.get("emotion", {}),
|
||||
)
|
||||
)
|
||||
return backend_pb2.VoiceAnalyzeResponse(segments=proto_segments)
|
||||
|
||||
|
||||
def serve(address: str) -> None:
|
||||
interceptors = get_auth_interceptors()
|
||||
server = grpc.server(
|
||||
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
interceptors=interceptors,
|
||||
options=[
|
||||
("grpc.max_send_message_length", 128 * 1024 * 1024),
|
||||
("grpc.max_receive_message_length", 128 * 1024 * 1024),
|
||||
],
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("speaker-recognition backend listening on", address, flush=True)
|
||||
|
||||
def _stop(*_):
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, _stop)
|
||||
signal.signal(signal.SIGINT, _stop)
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--addr", default="localhost:50051")
|
||||
args = parser.parse_args()
|
||||
serve(args.addr)
|
||||
387
backend/python/speaker-recognition/engines.py
Normal file
387
backend/python/speaker-recognition/engines.py
Normal file
@@ -0,0 +1,387 @@
|
||||
"""Speaker-recognition engines.
|
||||
|
||||
Two engines are offered, mirroring the insightface backend's split:
|
||||
|
||||
* SpeechBrainEngine: full PyTorch / SpeechBrain path. Uses the
|
||||
ECAPA-TDNN recipe trained on VoxCeleb; 192-d L2-normalized
|
||||
embeddings, cosine distance for verification. Auto-downloads the
|
||||
checkpoint into LocalAI's models directory on first LoadModel.
|
||||
|
||||
* OnnxDirectEngine: CPU-friendly fallback that runs pre-exported
|
||||
ONNX speaker encoders (WeSpeaker ResNet34, 3D-Speaker ERes2Net,
|
||||
CAM++, etc.). Model paths come from the model config — the gallery
|
||||
`files:` flow drops them into the models directory.
|
||||
|
||||
Engine selection follows the same gallery-driven convention face
|
||||
recognition uses (insightface commits 9c6da0f7 / 405fec0b): the
|
||||
Python backend reads `engine` / `model_path` / `checkpoint` from the
|
||||
options dict and picks an engine accordingly.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Iterable, Protocol
|
||||
|
||||
|
||||
class SpeakerEngine(Protocol):
|
||||
"""Interface both concrete engines satisfy."""
|
||||
|
||||
name: str
|
||||
|
||||
def embed(self, audio_path: str) -> list[float]: # pragma: no cover - interface
|
||||
...
|
||||
|
||||
def compare(self, audio1: str, audio2: str) -> float: # pragma: no cover
|
||||
...
|
||||
|
||||
def analyze(self, audio_path: str, actions: Iterable[str]) -> list[dict[str, Any]]: # pragma: no cover
|
||||
...
|
||||
|
||||
|
||||
def _cosine_distance(a, b) -> float:
|
||||
import numpy as np
|
||||
|
||||
va = np.asarray(a, dtype=np.float32).reshape(-1)
|
||||
vb = np.asarray(b, dtype=np.float32).reshape(-1)
|
||||
na = float(np.linalg.norm(va))
|
||||
nb = float(np.linalg.norm(vb))
|
||||
if na == 0.0 or nb == 0.0:
|
||||
return 1.0
|
||||
return float(1.0 - np.dot(va, vb) / (na * nb))
|
||||
|
||||
|
||||
class AnalysisHead:
|
||||
"""Age / gender / emotion head, lazy-loaded on first analyze call.
|
||||
|
||||
Wraps two open-licence HuggingFace checkpoints:
|
||||
|
||||
* audeering/wav2vec2-large-robust-24-ft-age-gender — age
|
||||
regression (0–100 years) + 3-way gender (female/male/child).
|
||||
Apache 2.0.
|
||||
* superb/wav2vec2-base-superb-er — 4-way emotion classification
|
||||
(neutral / happy / angry / sad). Apache 2.0.
|
||||
|
||||
Either model is optional — the head degrades gracefully to only the
|
||||
attributes it could load. Override the checkpoint with the
|
||||
`age_gender_model` / `emotion_model` option if you want something
|
||||
else. Set either to an empty string to disable that head.
|
||||
"""
|
||||
|
||||
# Age + gender is OFF by default: the high-accuracy Apache-2.0
|
||||
# checkpoint (Audeering wav2vec2-large-robust-24-ft-age-gender) uses a
|
||||
# custom multi-task head that AutoModelForAudioClassification silently
|
||||
# mangles — it drops the age weights as UNEXPECTED and re-initialises
|
||||
# the classifier head with random values, so the output is noise. Users
|
||||
# who have a cleanly loadable age/gender classifier can opt in with
|
||||
# `age_gender_model:<repo>` in options. The emotion default below
|
||||
# (superb/wav2vec2-base-superb-er) loads via the standard audio-
|
||||
# classification pipeline with no such caveat.
|
||||
DEFAULT_AGE_GENDER_MODEL = ""
|
||||
DEFAULT_EMOTION_MODEL = "superb/wav2vec2-base-superb-er"
|
||||
AGE_GENDER_LABELS = ("female", "male", "child")
|
||||
|
||||
def __init__(self, options: dict[str, str]):
|
||||
self._options = options
|
||||
self._age_gender = None
|
||||
self._age_gender_processor = None
|
||||
self._age_gender_loaded = False
|
||||
self._age_gender_error: str | None = None
|
||||
self._emotion = None
|
||||
self._emotion_loaded = False
|
||||
self._emotion_error: str | None = None
|
||||
|
||||
# --- age / gender -------------------------------------------------
|
||||
def _ensure_age_gender(self):
|
||||
if self._age_gender_loaded:
|
||||
return
|
||||
self._age_gender_loaded = True
|
||||
model_id = self._options.get(
|
||||
"age_gender_model", self.DEFAULT_AGE_GENDER_MODEL
|
||||
)
|
||||
if not model_id:
|
||||
self._age_gender_error = "disabled"
|
||||
return
|
||||
try:
|
||||
# Late imports — torch / transformers are heavy and only
|
||||
# pulled in when the analyze head actually runs.
|
||||
import torch # type: ignore
|
||||
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification # type: ignore
|
||||
|
||||
self._torch = torch
|
||||
self._age_gender_processor = AutoFeatureExtractor.from_pretrained(model_id)
|
||||
self._age_gender = AutoModelForAudioClassification.from_pretrained(model_id)
|
||||
self._age_gender.eval()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._age_gender_error = f"{type(exc).__name__}: {exc}"
|
||||
|
||||
def _infer_age_gender(self, waveform_16k) -> dict[str, Any]:
|
||||
self._ensure_age_gender()
|
||||
if self._age_gender is None:
|
||||
return {}
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
inputs = self._age_gender_processor(
|
||||
waveform_16k, sampling_rate=16000, return_tensors="pt"
|
||||
)
|
||||
with self._torch.no_grad():
|
||||
outputs = self._age_gender(**inputs)
|
||||
|
||||
# Audeering's checkpoint is published with a custom head: the
|
||||
# official recipe exposes `(hidden_states, logits_age, logits_gender)`.
|
||||
# AutoModelForAudioClassification flattens that into a single
|
||||
# `logits` tensor of shape [batch, 4] — [age_regression, female, male, child].
|
||||
# Fall back gracefully when the shape is different (e.g. a
|
||||
# user-supplied age_gender_model checkpoint that returns a proper tuple).
|
||||
hidden = getattr(outputs, "logits", outputs)
|
||||
age_years = None
|
||||
gender_logits = None
|
||||
if isinstance(hidden, (tuple, list)) and len(hidden) >= 2:
|
||||
age_years = float(hidden[0].squeeze().item()) * 100.0
|
||||
gender_logits = hidden[1]
|
||||
else:
|
||||
flat = hidden.squeeze()
|
||||
if flat.ndim == 1 and flat.numel() >= 4:
|
||||
age_years = float(flat[0].item()) * 100.0
|
||||
gender_logits = flat[1:4]
|
||||
elif flat.ndim == 1 and flat.numel() == 1:
|
||||
age_years = float(flat.item()) * 100.0
|
||||
|
||||
if age_years is None and gender_logits is None:
|
||||
return {}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
if age_years is not None:
|
||||
result["age"] = age_years
|
||||
if gender_logits is not None:
|
||||
probs = self._torch.softmax(gender_logits, dim=-1).cpu().numpy()
|
||||
probs = np.asarray(probs).reshape(-1)
|
||||
gender_map = {
|
||||
label: float(probs[i])
|
||||
for i, label in enumerate(self.AGE_GENDER_LABELS[: len(probs)])
|
||||
}
|
||||
result["gender"] = gender_map
|
||||
if gender_map:
|
||||
dom = max(gender_map.items(), key=lambda kv: kv[1])[0]
|
||||
result["dominant_gender"] = {
|
||||
"female": "Female",
|
||||
"male": "Male",
|
||||
"child": "Child",
|
||||
}.get(dom, dom.capitalize())
|
||||
return result
|
||||
except Exception as exc: # noqa: BLE001
|
||||
# Analyze is a best-effort feature — never take down the
|
||||
# whole analyze call because the age/gender head had a bad
|
||||
# day. Mark the failure so the emotion branch still runs.
|
||||
self._age_gender_error = f"runtime: {type(exc).__name__}: {exc}"
|
||||
return {}
|
||||
|
||||
# --- emotion ------------------------------------------------------
|
||||
def _ensure_emotion(self):
|
||||
if self._emotion_loaded:
|
||||
return
|
||||
self._emotion_loaded = True
|
||||
model_id = self._options.get("emotion_model", self.DEFAULT_EMOTION_MODEL)
|
||||
if not model_id:
|
||||
self._emotion_error = "disabled"
|
||||
return
|
||||
try:
|
||||
from transformers import pipeline # type: ignore
|
||||
|
||||
self._emotion = pipeline("audio-classification", model=model_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._emotion_error = f"{type(exc).__name__}: {exc}"
|
||||
|
||||
def _infer_emotion(self, audio_path: str) -> dict[str, Any]:
|
||||
self._ensure_emotion()
|
||||
if self._emotion is None:
|
||||
return {}
|
||||
try:
|
||||
raw = self._emotion(audio_path, top_k=8)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
# Second-line defense: don't fail the whole analyze call
|
||||
# over a runtime inference hiccup.
|
||||
self._emotion_error = f"runtime: {type(exc).__name__}: {exc}"
|
||||
return {}
|
||||
emotion_map = {row["label"].lower(): float(row["score"]) for row in raw}
|
||||
if not emotion_map:
|
||||
return {}
|
||||
dom = max(emotion_map.items(), key=lambda kv: kv[1])[0]
|
||||
return {"emotion": emotion_map, "dominant_emotion": dom}
|
||||
|
||||
# --- orchestrator -------------------------------------------------
|
||||
def analyze(self, audio_path: str, waveform_16k, actions: Iterable[str]) -> dict[str, Any]:
|
||||
wanted = {a.strip().lower() for a in actions} if actions else {"age", "gender", "emotion"}
|
||||
result: dict[str, Any] = {}
|
||||
if "age" in wanted or "gender" in wanted:
|
||||
ag = self._infer_age_gender(waveform_16k)
|
||||
if "age" in wanted and "age" in ag:
|
||||
result["age"] = ag["age"]
|
||||
if "gender" in wanted:
|
||||
if "gender" in ag:
|
||||
result["gender"] = ag["gender"]
|
||||
if "dominant_gender" in ag:
|
||||
result["dominant_gender"] = ag["dominant_gender"]
|
||||
if "emotion" in wanted:
|
||||
em = self._infer_emotion(audio_path)
|
||||
result.update(em)
|
||||
return result
|
||||
|
||||
|
||||
class SpeechBrainEngine:
|
||||
"""ECAPA-TDNN via SpeechBrain. Auto-downloads on first use."""
|
||||
|
||||
name = "speechbrain-ecapa-tdnn"
|
||||
|
||||
def __init__(self, model_name: str, options: dict[str, str]):
|
||||
# Late imports so the module can be introspected / tested
|
||||
# without torch / speechbrain being installed.
|
||||
from speechbrain.inference.speaker import EncoderClassifier # type: ignore
|
||||
|
||||
source = options.get("source") or model_name or "speechbrain/spkrec-ecapa-voxceleb"
|
||||
savedir = options.get("_model_path") or os.environ.get("HF_HOME") or "./pretrained_models"
|
||||
self._model = EncoderClassifier.from_hparams(source=source, savedir=savedir)
|
||||
self._analysis = AnalysisHead(options)
|
||||
|
||||
def _load_waveform(self, path: str):
|
||||
# Use soundfile + torch directly — torchaudio.load in torchaudio
|
||||
# 2.8+ requires the torchcodec package for decoding, which adds
|
||||
# another heavy ffmpeg-linked dep. soundfile covers WAV/FLAC
|
||||
# which is what we care about here.
|
||||
import numpy as np
|
||||
import soundfile as sf # type: ignore
|
||||
import torch # type: ignore
|
||||
|
||||
audio, sr = sf.read(path, always_2d=False)
|
||||
if audio.ndim > 1:
|
||||
audio = audio.mean(axis=1)
|
||||
audio = np.asarray(audio, dtype=np.float32)
|
||||
if sr != 16000:
|
||||
# Simple linear resample — good enough for 16kHz downsampling
|
||||
# from 44.1/48kHz, and we expect 16kHz inputs in practice.
|
||||
ratio = 16000 / float(sr)
|
||||
n = int(round(len(audio) * ratio))
|
||||
audio = np.interp(
|
||||
np.linspace(0, len(audio), n, endpoint=False),
|
||||
np.arange(len(audio)),
|
||||
audio,
|
||||
).astype(np.float32)
|
||||
return torch.from_numpy(audio).unsqueeze(0) # [1, T]
|
||||
|
||||
def embed(self, audio_path: str) -> list[float]:
|
||||
waveform = self._load_waveform(audio_path)
|
||||
vec = self._model.encode_batch(waveform).squeeze().detach().cpu().numpy()
|
||||
return [float(x) for x in vec]
|
||||
|
||||
def compare(self, audio1: str, audio2: str) -> float:
|
||||
return _cosine_distance(self.embed(audio1), self.embed(audio2))
|
||||
|
||||
def analyze(self, audio_path: str, actions):
|
||||
# Age / gender / emotion aren't produced by ECAPA-TDNN itself;
|
||||
# delegate to AnalysisHead which wraps separate Apache-2.0
|
||||
# checkpoints. Returns a single segment spanning the clip —
|
||||
# segmentation / diarisation is a future enhancement.
|
||||
waveform = self._load_waveform(audio_path)
|
||||
mono = waveform.squeeze().detach().cpu().numpy()
|
||||
attrs = self._analysis.analyze(audio_path, mono, actions)
|
||||
if not attrs:
|
||||
raise NotImplementedError(
|
||||
"analyze head failed to load — install transformers + torch or pass age_gender_model/emotion_model options"
|
||||
)
|
||||
duration = float(mono.shape[-1]) / 16000.0 if mono.size else 0.0
|
||||
return [dict(start=0.0, end=duration, **attrs)]
|
||||
|
||||
|
||||
class OnnxDirectEngine:
|
||||
"""Run a pre-exported ONNX speaker encoder (WeSpeaker / 3D-Speaker)."""
|
||||
|
||||
name = "onnx-direct"
|
||||
|
||||
def __init__(self, model_name: str, options: dict[str, str]):
|
||||
import onnxruntime as ort # type: ignore
|
||||
|
||||
# The gallery is expected to have dropped the ONNX file under
|
||||
# the models directory; accept either an absolute path or a
|
||||
# filename relative to _model_path.
|
||||
onnx_path = options.get("model_path") or options.get("onnx")
|
||||
if not onnx_path:
|
||||
raise ValueError("OnnxDirectEngine requires `model_path: <file.onnx>` in options")
|
||||
if not os.path.isabs(onnx_path):
|
||||
onnx_path = os.path.join(options.get("_model_path", ""), onnx_path)
|
||||
if not os.path.isfile(onnx_path):
|
||||
raise FileNotFoundError(f"ONNX model not found: {onnx_path}")
|
||||
|
||||
providers = options.get("providers")
|
||||
if providers:
|
||||
provider_list = [p.strip() for p in providers.split(",") if p.strip()]
|
||||
else:
|
||||
provider_list = ["CPUExecutionProvider"]
|
||||
self._session = ort.InferenceSession(onnx_path, providers=provider_list)
|
||||
self._input_name = self._session.get_inputs()[0].name
|
||||
self._expected_sr = int(options.get("sample_rate", "16000"))
|
||||
self._analysis = AnalysisHead(options)
|
||||
|
||||
def _load_waveform(self, path: str):
|
||||
import numpy as np
|
||||
import soundfile as sf # type: ignore
|
||||
|
||||
audio, sr = sf.read(path, always_2d=False)
|
||||
if sr != self._expected_sr:
|
||||
# Cheap linear resample — good enough for sanity; callers
|
||||
# should pre-resample for production.
|
||||
ratio = self._expected_sr / float(sr)
|
||||
n = int(round(len(audio) * ratio))
|
||||
audio = np.interp(
|
||||
np.linspace(0, len(audio), n, endpoint=False),
|
||||
np.arange(len(audio)),
|
||||
audio,
|
||||
)
|
||||
if audio.ndim > 1:
|
||||
audio = audio.mean(axis=1)
|
||||
return audio.astype("float32")
|
||||
|
||||
def embed(self, audio_path: str) -> list[float]:
|
||||
import numpy as np
|
||||
|
||||
audio = self._load_waveform(audio_path)
|
||||
feed = audio.reshape(1, -1)
|
||||
out = self._session.run(None, {self._input_name: feed})
|
||||
vec = np.asarray(out[0]).reshape(-1)
|
||||
return [float(x) for x in vec]
|
||||
|
||||
def compare(self, audio1: str, audio2: str) -> float:
|
||||
return _cosine_distance(self.embed(audio1), self.embed(audio2))
|
||||
|
||||
def analyze(self, audio_path: str, actions):
|
||||
# AnalysisHead expects 16kHz mono; _load_waveform already
|
||||
# resamples to self._expected_sr. If the user configured a
|
||||
# non-16k expected rate, resample one more time for analyze.
|
||||
audio = self._load_waveform(audio_path)
|
||||
if self._expected_sr != 16000:
|
||||
import numpy as np
|
||||
|
||||
ratio = 16000 / float(self._expected_sr)
|
||||
n = int(round(len(audio) * ratio))
|
||||
audio = np.interp(
|
||||
np.linspace(0, len(audio), n, endpoint=False),
|
||||
np.arange(len(audio)),
|
||||
audio,
|
||||
).astype("float32")
|
||||
attrs = self._analysis.analyze(audio_path, audio, actions)
|
||||
if not attrs:
|
||||
raise NotImplementedError(
|
||||
"analyze head failed to load — install transformers + torch or pass age_gender_model/emotion_model options"
|
||||
)
|
||||
duration = float(len(audio)) / 16000.0 if len(audio) else 0.0
|
||||
return [dict(start=0.0, end=duration, **attrs)]
|
||||
|
||||
|
||||
def build_engine(model_name: str, options: dict[str, str]) -> tuple[SpeakerEngine, str]:
|
||||
"""Pick an engine based on the options. ONNX path takes priority:
|
||||
if the gallery has dropped a `model_path:` or `onnx:` option, run
|
||||
the direct ONNX engine. Otherwise, fall back to SpeechBrain.
|
||||
"""
|
||||
engine_kind = (options.get("engine") or "").lower()
|
||||
if engine_kind == "onnx" or options.get("model_path") or options.get("onnx"):
|
||||
return OnnxDirectEngine(model_name, options), OnnxDirectEngine.name
|
||||
return SpeechBrainEngine(model_name, options), SpeechBrainEngine.name
|
||||
19
backend/python/speaker-recognition/install.sh
Executable file
19
backend/python/speaker-recognition/install.sh
Executable file
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
# No pre-baked model weights. Weights flow through LocalAI's gallery
|
||||
# `files:` mechanism — see gallery entries for speechbrain-ecapa-tdnn
|
||||
# and WeSpeaker / 3D-Speaker ONNX packs. SpeechBrain's
|
||||
# EncoderClassifier.from_hparams also knows how to auto-download from
|
||||
# HuggingFace into the configured savedir (we point it at ModelPath),
|
||||
# so the first LoadModel call bootstraps the checkpoint if the gallery
|
||||
# flow wasn't used.
|
||||
5
backend/python/speaker-recognition/requirements-cpu.txt
Normal file
5
backend/python/speaker-recognition/requirements-cpu.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
torch
|
||||
torchaudio
|
||||
speechbrain
|
||||
transformers
|
||||
onnxruntime
|
||||
@@ -0,0 +1,5 @@
|
||||
torch
|
||||
torchaudio
|
||||
speechbrain
|
||||
transformers
|
||||
onnxruntime-gpu
|
||||
5
backend/python/speaker-recognition/requirements.txt
Normal file
5
backend/python/speaker-recognition/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
grpcio-tools
|
||||
numpy
|
||||
soundfile
|
||||
9
backend/python/speaker-recognition/run.sh
Executable file
9
backend/python/speaker-recognition/run.sh
Executable file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
78
backend/python/speaker-recognition/test.py
Normal file
78
backend/python/speaker-recognition/test.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Unit tests for the speaker-recognition gRPC backend.
|
||||
|
||||
The servicer is instantiated in-process (no gRPC channel) and driven
|
||||
directly. The default path exercises SpeechBrain's ECAPA-TDNN — the
|
||||
first run downloads the checkpoint into a temp savedir. Tests are
|
||||
skipped gracefully when the heavy optional dependencies (torch /
|
||||
speechbrain / onnxruntime) are not installed, so the gRPC plumbing
|
||||
can still be verified on a bare image.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
import backend_pb2 # noqa: E402
|
||||
|
||||
from backend import BackendServicer # noqa: E402
|
||||
|
||||
|
||||
def _have(*mods: str) -> bool:
|
||||
for m in mods:
|
||||
if importlib.util.find_spec(m) is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class _FakeCtx:
|
||||
"""Minimal stand-in for a gRPC servicer context."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.code = None
|
||||
self.details = ""
|
||||
|
||||
def set_code(self, c):
|
||||
self.code = c
|
||||
|
||||
def set_details(self, d):
|
||||
self.details = d
|
||||
|
||||
|
||||
class ServicerPlumbingTest(unittest.TestCase):
|
||||
"""Checks that LoadModel returns a clear error when no engine deps
|
||||
are installed, and that Voice* calls on an uninitialised servicer
|
||||
surface FAILED_PRECONDITION — both verifying the gRPC wiring
|
||||
without requiring SpeechBrain or ONNX at test time."""
|
||||
|
||||
def test_pre_load_voice_calls_are_rejected(self):
|
||||
svc = BackendServicer()
|
||||
ctx = _FakeCtx()
|
||||
svc.VoiceVerify(backend_pb2.VoiceVerifyRequest(audio1="/tmp/a.wav", audio2="/tmp/b.wav"), ctx)
|
||||
self.assertEqual(str(ctx.code), "StatusCode.FAILED_PRECONDITION")
|
||||
|
||||
def test_load_without_deps_fails_cleanly(self):
|
||||
svc = BackendServicer()
|
||||
req = backend_pb2.ModelOptions(Model="speechbrain/spkrec-ecapa-voxceleb", ModelPath="")
|
||||
result = svc.LoadModel(req, _FakeCtx())
|
||||
# Either the deps are installed and it loaded, or they aren't
|
||||
# and we got a structured error instead of a crash.
|
||||
self.assertTrue(result.success or "engine init failed" in result.message)
|
||||
|
||||
|
||||
@unittest.skipUnless(_have("speechbrain", "torch", "torchaudio"), "speechbrain / torch missing")
|
||||
class SpeechBrainEngineSmokeTest(unittest.TestCase):
|
||||
def test_load_and_embed(self):
|
||||
svc = BackendServicer()
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
req = backend_pb2.ModelOptions(Model="speechbrain/spkrec-ecapa-voxceleb", ModelPath=td)
|
||||
result = svc.LoadModel(req, _FakeCtx())
|
||||
self.assertTrue(result.success, result.message)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
11
backend/python/speaker-recognition/test.sh
Executable file
11
backend/python/speaker-recognition/test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runUnittests
|
||||
@@ -372,6 +372,41 @@ impl Backend for KokorosService {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn face_verify(
|
||||
&self,
|
||||
_: Request<backend::FaceVerifyRequest>,
|
||||
) -> Result<Response<backend::FaceVerifyResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn face_analyze(
|
||||
&self,
|
||||
_: Request<backend::FaceAnalyzeRequest>,
|
||||
) -> Result<Response<backend::FaceAnalyzeResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn voice_verify(
|
||||
&self,
|
||||
_: Request<backend::VoiceVerifyRequest>,
|
||||
) -> Result<Response<backend::VoiceVerifyResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn voice_analyze(
|
||||
&self,
|
||||
_: Request<backend::VoiceAnalyzeRequest>,
|
||||
) -> Result<Response<backend::VoiceAnalyzeResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn voice_embed(
|
||||
&self,
|
||||
_: Request<backend::VoiceEmbedRequest>,
|
||||
) -> Result<Response<backend::VoiceEmbedResponse>, Status> {
|
||||
Err(Status::unimplemented("Not supported"))
|
||||
}
|
||||
|
||||
async fn stores_set(
|
||||
&self,
|
||||
_: Request<backend::StoresSetOptions>,
|
||||
|
||||
@@ -7,17 +7,35 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
corebackend "github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
"github.com/mudler/LocalAI/core/services/facerecognition"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/voicerecognition"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
pkggrpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// faceEmbeddingDim is the expected dimension for face embeddings.
|
||||
// Set to 0 so the Registry accepts whatever dim the loaded recognizer
|
||||
// produces — ArcFace R50 is 512-d, MBF is 512-d, SFace is 128-d, and
|
||||
// the insightface backend can load any of them via LoadModel options.
|
||||
// Locking this to a specific value would force a single recognizer
|
||||
// family per deployment; we keep the door open instead.
|
||||
const faceEmbeddingDim = 0
|
||||
|
||||
// voiceEmbeddingDim is the expected dimension for speaker embeddings.
|
||||
// 0 so the Registry accepts whatever dim the loaded recognizer
|
||||
// produces — ECAPA-TDNN is 192, WeSpeaker ResNet34 is 256, 3D-Speaker
|
||||
// ERes2Net is 192, CAM++ is 512.
|
||||
const voiceEmbeddingDim = 0
|
||||
|
||||
type Application struct {
|
||||
backendLoader *config.ModelConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
@@ -27,6 +45,8 @@ type Application struct {
|
||||
galleryService *galleryop.GalleryService
|
||||
agentJobService *agentpool.AgentJobService
|
||||
agentPoolService atomic.Pointer[agentpool.AgentPoolService]
|
||||
faceRegistry facerecognition.Registry
|
||||
voiceRegistry voicerecognition.Registry
|
||||
authDB *gorm.DB
|
||||
watchdogMutex sync.Mutex
|
||||
watchdogStop chan bool
|
||||
@@ -50,12 +70,31 @@ func newApplication(appConfig *config.ApplicationConfig) *Application {
|
||||
mcpTools.CloseMCPSessions(modelName)
|
||||
})
|
||||
|
||||
return &Application{
|
||||
app := &Application{
|
||||
backendLoader: config.NewModelConfigLoader(appConfig.SystemState.Model.ModelsPath),
|
||||
modelLoader: ml,
|
||||
applicationConfig: appConfig,
|
||||
templatesEvaluator: templates.NewEvaluator(appConfig.SystemState.Model.ModelsPath),
|
||||
}
|
||||
|
||||
// Face-recognition registry backed by LocalAI's built-in vector store.
|
||||
// The resolver closes over the ModelLoader so the Registry stays
|
||||
// decoupled from loader plumbing; swapping in a postgres-backed
|
||||
// implementation later is a single construction change here.
|
||||
faceStoreResolver := func(_ context.Context, storeName string) (pkggrpc.Backend, error) {
|
||||
return corebackend.StoreBackend(ml, appConfig, storeName, "")
|
||||
}
|
||||
app.faceRegistry = facerecognition.NewStoreRegistry(faceStoreResolver, "", faceEmbeddingDim)
|
||||
|
||||
// Voice (speaker) recognition registry — same plumbing, separate
|
||||
// registry so embedding spaces stay isolated (a face vector and a
|
||||
// speaker vector are not comparable).
|
||||
voiceStoreResolver := func(_ context.Context, storeName string) (pkggrpc.Backend, error) {
|
||||
return corebackend.StoreBackend(ml, appConfig, storeName, "")
|
||||
}
|
||||
app.voiceRegistry = voicerecognition.NewStoreRegistry(voiceStoreResolver, "", voiceEmbeddingDim)
|
||||
|
||||
return app
|
||||
}
|
||||
|
||||
func (a *Application) ModelConfigLoader() *config.ModelConfigLoader {
|
||||
@@ -99,6 +138,22 @@ func (a *Application) AgentPoolService() *agentpool.AgentPoolService {
|
||||
return a.agentPoolService.Load()
|
||||
}
|
||||
|
||||
// FaceRegistry returns the face-recognition registry used for 1:N
|
||||
// identification. The current implementation is backed by the
|
||||
// in-memory local-store backend; see core/services/facerecognition
|
||||
// for the interface and the postgres TODO.
|
||||
func (a *Application) FaceRegistry() facerecognition.Registry {
|
||||
return a.faceRegistry
|
||||
}
|
||||
|
||||
// VoiceRegistry returns the voice (speaker) recognition registry used
|
||||
// for 1:N identification. Same in-memory local-store backing as
|
||||
// FaceRegistry but a separate instance — voice embeddings live in
|
||||
// their own vector space.
|
||||
func (a *Application) VoiceRegistry() voicerecognition.Registry {
|
||||
return a.voiceRegistry
|
||||
}
|
||||
|
||||
// AuthDB returns the auth database connection, or nil if auth is not enabled.
|
||||
func (a *Application) AuthDB() *gorm.DB {
|
||||
return a.authDB
|
||||
|
||||
60
core/backend/face_analyze.go
Normal file
60
core/backend/face_analyze.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func FaceAnalyze(
|
||||
img string,
|
||||
actions []string,
|
||||
antiSpoofing bool,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (*proto.FaceAnalyzeResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
faceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
if faceModel == nil {
|
||||
return nil, fmt.Errorf("could not load face recognition model")
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := faceModel.FaceAnalyze(context.Background(), &proto.FaceAnalyzeRequest{
|
||||
Img: img,
|
||||
Actions: actions,
|
||||
AntiSpoofing: antiSpoofing,
|
||||
})
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceFaceAnalyze,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Error: errStr,
|
||||
})
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
43
core/backend/face_embed.go
Normal file
43
core/backend/face_embed.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// FaceEmbed loads the face recognition backend and returns a 512-d
|
||||
// face embedding for the base64-encoded image. Unlike ModelEmbedding
|
||||
// it passes the image through PredictOptions.Images — the insightface
|
||||
// backend picks the highest-confidence face and returns its
|
||||
// L2-normalized embedding.
|
||||
func FaceEmbed(
|
||||
imgBase64 string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) ([]float32, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
faceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
if faceModel == nil {
|
||||
return nil, fmt.Errorf("could not load face recognition model")
|
||||
}
|
||||
|
||||
predictOpts := gRPCPredictOpts(modelConfig, loader.ModelPath)
|
||||
predictOpts.Images = []string{imgBase64}
|
||||
|
||||
res, err := faceModel.Embeddings(context.Background(), predictOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res.Embeddings) == 0 {
|
||||
return nil, fmt.Errorf("face embedding returned empty vector (no face detected?)")
|
||||
}
|
||||
return res.Embeddings, nil
|
||||
}
|
||||
61
core/backend/face_verify.go
Normal file
61
core/backend/face_verify.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func FaceVerify(
|
||||
img1, img2 string,
|
||||
threshold float32,
|
||||
antiSpoofing bool,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (*proto.FaceVerifyResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
faceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
if faceModel == nil {
|
||||
return nil, fmt.Errorf("could not load face recognition model")
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := faceModel.FaceVerify(context.Background(), &proto.FaceVerifyRequest{
|
||||
Img1: img1,
|
||||
Img2: img2,
|
||||
Threshold: threshold,
|
||||
AntiSpoofing: antiSpoofing,
|
||||
})
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceFaceVerify,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Error: errStr,
|
||||
})
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
58
core/backend/voice_analyze.go
Normal file
58
core/backend/voice_analyze.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func VoiceAnalyze(
|
||||
audio string,
|
||||
actions []string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (*proto.VoiceAnalyzeResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
voiceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
if voiceModel == nil {
|
||||
return nil, fmt.Errorf("could not load voice recognition model")
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := voiceModel.VoiceAnalyze(context.Background(), &proto.VoiceAnalyzeRequest{
|
||||
Audio: audio,
|
||||
Actions: actions,
|
||||
})
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceVoiceAnalyze,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Error: errStr,
|
||||
})
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
66
core/backend/voice_embed.go
Normal file
66
core/backend/voice_embed.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// VoiceEmbed returns a speaker embedding (typically 192-d for ECAPA-TDNN)
|
||||
// for the audio file at audioPath. Unlike ModelEmbedding (which is
|
||||
// OpenAI-compatible and text-only), this call takes an audio path and
|
||||
// returns the backend's speaker-encoder output.
|
||||
func VoiceEmbed(
|
||||
audioPath string,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (*proto.VoiceEmbedResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
voiceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
if voiceModel == nil {
|
||||
return nil, fmt.Errorf("could not load voice recognition model")
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := voiceModel.VoiceEmbed(context.Background(), &proto.VoiceEmbedRequest{
|
||||
Audio: audioPath,
|
||||
})
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceVoiceEmbed,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Error: errStr,
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res == nil || len(res.Embedding) == 0 {
|
||||
return nil, fmt.Errorf("voice embedding returned empty vector (no speech detected?)")
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
61
core/backend/voice_verify.go
Normal file
61
core/backend/voice_verify.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func VoiceVerify(
|
||||
audio1, audio2 string,
|
||||
threshold float32,
|
||||
antiSpoofing bool,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelConfig config.ModelConfig,
|
||||
) (*proto.VoiceVerifyResponse, error) {
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
voiceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
if voiceModel == nil {
|
||||
return nil, fmt.Errorf("could not load voice recognition model")
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := voiceModel.VoiceVerify(context.Background(), &proto.VoiceVerifyRequest{
|
||||
Audio1: audio1,
|
||||
Audio2: audio2,
|
||||
Threshold: threshold,
|
||||
AntiSpoofing: antiSpoofing,
|
||||
})
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceVoiceVerify,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Error: errStr,
|
||||
})
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
@@ -588,6 +588,8 @@ const (
|
||||
FLAG_VAD ModelConfigUsecase = 0b010000000000
|
||||
FLAG_VIDEO ModelConfigUsecase = 0b100000000000
|
||||
FLAG_DETECTION ModelConfigUsecase = 0b1000000000000
|
||||
FLAG_FACE_RECOGNITION ModelConfigUsecase = 0b10000000000000
|
||||
FLAG_SPEAKER_RECOGNITION ModelConfigUsecase = 0b100000000000000
|
||||
|
||||
// Common Subsets
|
||||
FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
|
||||
@@ -611,6 +613,8 @@ func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
|
||||
"FLAG_LLM": FLAG_LLM,
|
||||
"FLAG_VIDEO": FLAG_VIDEO,
|
||||
"FLAG_DETECTION": FLAG_DETECTION,
|
||||
"FLAG_FACE_RECOGNITION": FLAG_FACE_RECOGNITION,
|
||||
"FLAG_SPEAKER_RECOGNITION": FLAG_SPEAKER_RECOGNITION,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -651,7 +655,7 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||
nonTextGenBackends := []string{
|
||||
"whisper", "piper", "kokoro",
|
||||
"diffusers", "stablediffusion", "stablediffusion-ggml",
|
||||
"rerankers", "silero-vad", "rfdetr",
|
||||
"rerankers", "silero-vad", "rfdetr", "insightface", "speaker-recognition",
|
||||
"transformers-musicgen", "ace-step", "acestep-cpp",
|
||||
}
|
||||
|
||||
@@ -728,12 +732,26 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||
}
|
||||
|
||||
if (u & FLAG_DETECTION) == FLAG_DETECTION {
|
||||
detectionBackends := []string{"rfdetr", "sam3-cpp"}
|
||||
detectionBackends := []string{"rfdetr", "sam3-cpp", "insightface"}
|
||||
if !slices.Contains(detectionBackends, c.Backend) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_FACE_RECOGNITION) == FLAG_FACE_RECOGNITION {
|
||||
faceBackends := []string{"insightface"}
|
||||
if !slices.Contains(faceBackends, c.Backend) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_SPEAKER_RECOGNITION) == FLAG_SPEAKER_RECOGNITION {
|
||||
speakerBackends := []string{"speaker-recognition"}
|
||||
if !slices.Contains(speakerBackends, c.Backend) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION {
|
||||
soundGenBackends := []string{"transformers-musicgen", "ace-step", "acestep-cpp", "mock-backend"}
|
||||
if !slices.Contains(soundGenBackends, c.Backend) {
|
||||
|
||||
126
core/gallery/importers/ace-step.go
Normal file
126
core/gallery/importers/ace-step.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &ACEStepImporter{}
|
||||
|
||||
// ACEStepImporter recognises ACE-Step music generation checkpoints
|
||||
// (ACE-Step/ACE-Step-v1-3.5B, ACE-Step/Ace-Step1.5, community finetunes).
|
||||
// Detection matches on "ace-step" in the repo name — case-insensitive —
|
||||
// so quantised mirrors still route here. The backend itself is
|
||||
// sound-generation / TTS-adjacent; the Modality() method returns "image"
|
||||
// purely to slot into the UI dropdown's image/video tab where it lives
|
||||
// with other generative media importers. preferences.backend="ace-step"
|
||||
// overrides detection.
|
||||
type ACEStepImporter struct{}
|
||||
|
||||
func (i *ACEStepImporter) Name() string { return "ace-step" }
|
||||
func (i *ACEStepImporter) Modality() string { return "image" }
|
||||
func (i *ACEStepImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *ACEStepImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "ace-step" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if strings.Contains(strings.ToLower(repoName), "ace-step") {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(details.HuggingFace.Author, "ACE-Step") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: hfapi recursion bug may leave HuggingFace nil — decide
|
||||
// from the URI owner/repo.
|
||||
if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if strings.EqualFold(owner, "ACE-Step") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(strings.ToLower(repo), "ace-step") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *ACEStepImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
} else if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
model = owner + "/" + repo
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "ace-step",
|
||||
// Mirrors gallery/index.yaml's ace-step-turbo entry which flags
|
||||
// both sound_generation and tts — ACE-Step is a music/sound model,
|
||||
// the UI groups it under image/video simply because there is no
|
||||
// first-class music tab yet.
|
||||
KnownUsecaseStrings: []string{"sound_generation", "tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
50
core/gallery/importers/ace-step_test.go
Normal file
50
core/gallery/importers/ace-step_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ACEStepImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
// ACE-Step/ACE-Step-v1-3.5B is the reference public checkpoint for
|
||||
// the ACE-Step music generation model. Detection must match on the
|
||||
// repo name substring so third-party forks and quantised mirrors
|
||||
// (e.g. Serveurperso/ACE-Step-1.5-GGUF) route to the same backend.
|
||||
It("matches ACE-Step/ACE-Step-v1-3.5B (repo name contains ACE-Step)", func() {
|
||||
uri := "https://huggingface.co/ACE-Step/ACE-Step-v1-3.5B"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: ace-step"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("ACE-Step/ACE-Step-v1-3.5B"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=ace-step for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "ace-step"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: ace-step"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.ACEStepImporter{}
|
||||
Expect(imp.Name()).To(Equal("ace-step"))
|
||||
Expect(imp.Modality()).To(Equal("image"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
29
core/gallery/importers/ambiguity_asr_test.go
Normal file
29
core/gallery/importers/ambiguity_asr_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ASR ambiguity", func() {
|
||||
// pyannote/voice-activity-detection carries
|
||||
// pipeline_tag=automatic-speech-recognition but ships only a YAML
|
||||
// recipe — no ggml-*.bin, no .nemo, no Systran-style model.bin, no
|
||||
// tokenizer.json, no .onnx. None of the ASR importers should match and
|
||||
// none of the generic importers (vllm, transformers, llama-cpp, mlx,
|
||||
// diffusers) should match either. Because the modality is in the
|
||||
// ambiguous whitelist, DiscoverModelConfig must surface
|
||||
// ErrAmbiguousImport rather than a bare "no importer matched" error.
|
||||
It("returns ErrAmbiguousImport when ASR pipeline_tag is present but no importer matches", func() {
|
||||
uri := "https://huggingface.co/pyannote/voice-activity-detection"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
_, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, importers.ErrAmbiguousImport)).To(BeTrue(), "expected ErrAmbiguousImport, got: %v", err)
|
||||
})
|
||||
})
|
||||
34
core/gallery/importers/ambiguity_embeddings_test.go
Normal file
34
core/gallery/importers/ambiguity_embeddings_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Embeddings ambiguity", func() {
|
||||
// Qdrant/bm25 carries pipeline_tag="sentence-similarity" but ships
|
||||
// only config.json, README.md, .gitattributes, and per-language
|
||||
// stopword .txt files — no tokenizer.json (rules out vllm and
|
||||
// transformers), no modules.json / sentence_bert_config.json (rules
|
||||
// out sentencetransformers), no "reranker" / cross-encoder owner
|
||||
// (rules out rerankers), no rf-detr name (rules out rfdetr), no
|
||||
// snakers4 / silero_vad.onnx (rules out silero-vad), no .gguf
|
||||
// (rules out llama-cpp and stablediffusion-ggml), no mlx-community
|
||||
// owner (rules out mlx), no model_index.json / scheduler_config.json
|
||||
// (rules out diffusers). None of the ASR/TTS/image importers should
|
||||
// trip either. Because sentence-similarity is in the ambiguous
|
||||
// modality whitelist, DiscoverModelConfig must surface
|
||||
// ErrAmbiguousImport.
|
||||
It("returns ErrAmbiguousImport when sentence-similarity pipeline_tag is present but no importer matches", func() {
|
||||
uri := "https://huggingface.co/Qdrant/bm25"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
_, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, importers.ErrAmbiguousImport)).To(BeTrue(), "expected ErrAmbiguousImport, got: %v", err)
|
||||
})
|
||||
})
|
||||
31
core/gallery/importers/ambiguity_image_test.go
Normal file
31
core/gallery/importers/ambiguity_image_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Image ambiguity", func() {
|
||||
// h94/IP-Adapter-FaceID carries pipeline_tag="text-to-image" but ships
|
||||
// only .bin + .safetensors + README — no model_index.json /
|
||||
// scheduler_config.json (rules out diffusers), no .gguf (rules out
|
||||
// llama-cpp and stablediffusion-ggml), no tokenizer.json (rules out
|
||||
// vllm/transformers), owner is not mlx-community (rules out mlx), and
|
||||
// the repo owner/name contain no ace-step/flux/sd1.5/sdxl/sd3/
|
||||
// stable-diffusion arch token at the URI level — so none of the
|
||||
// Batch-3 Image/Video importers match either. Because text-to-image
|
||||
// is whitelisted as an ambiguous modality, DiscoverModelConfig must
|
||||
// surface ErrAmbiguousImport rather than a bare "no importer matched".
|
||||
It("returns ErrAmbiguousImport when text-to-image pipeline_tag is present but no importer matches", func() {
|
||||
uri := "https://huggingface.co/h94/IP-Adapter-FaceID"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
_, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, importers.ErrAmbiguousImport)).To(BeTrue(), "expected ErrAmbiguousImport, got: %v", err)
|
||||
})
|
||||
})
|
||||
32
core/gallery/importers/ambiguity_tts_test.go
Normal file
32
core/gallery/importers/ambiguity_tts_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("TTS ambiguity", func() {
|
||||
// nari-labs/Dia-1.6B carries pipeline_tag="text-to-speech" but ships
|
||||
// only config.json + *.pth + model.safetensors + preprocessor_config.json.
|
||||
// None of the Batch-2 TTS importers match (owner neither "suno" nor
|
||||
// "fishaudio" nor "OuteAI" nor "KittenML" nor "ResembleAI" nor "neuphonic"
|
||||
// nor "coqui"; repo name contains none of "bark", "outetts", "voxcpm",
|
||||
// "kokoro", "kitten-tts", "neutts", "chatterbox", "vibevoice"; no piper
|
||||
// onnx/onnx.json pair). None of the generic importers match either —
|
||||
// no tokenizer.json (rules out vllm/transformers), no .gguf (llama-cpp),
|
||||
// no mlx-community owner (mlx), no model_index.json/scheduler_config
|
||||
// (diffusers). Because the HF pipeline_tag is in the ambiguous
|
||||
// whitelist, DiscoverModelConfig must surface ErrAmbiguousImport.
|
||||
It("returns ErrAmbiguousImport when TTS pipeline_tag is present but no importer matches", func() {
|
||||
uri := "https://huggingface.co/nari-labs/Dia-1.6B"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
_, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, importers.ErrAmbiguousImport)).To(BeTrue(), "expected ErrAmbiguousImport, got: %v", err)
|
||||
})
|
||||
})
|
||||
124
core/gallery/importers/bark.go
Normal file
124
core/gallery/importers/bark.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &BarkImporter{}
|
||||
|
||||
// BarkImporter recognises Suno's Bark TTS models. The `suno` owner hosts a
|
||||
// handful of Bark variants (bark, bark-small, bark-v2-en, …) sharing the
|
||||
// "bark" prefix — narrow enough to detect without false positives from
|
||||
// other suno repos. preferences.backend="bark" overrides detection.
|
||||
//
|
||||
// NOTE: suno/bark ships a `speaker_embeddings/v2` subdirectory that hits a
|
||||
// pre-existing path-doubling bug in pkg/huggingface-api's recursive tree
|
||||
// listing (item.Path already carries the parent path, but the recursion
|
||||
// prepends the parent path again → 404). When ModelDetails fetching fails,
|
||||
// DiscoverModelConfig leaves HuggingFace nil. To keep detection robust,
|
||||
// matchURIOwnerRepo() falls back to parsing the raw URI for "suno/bark*"
|
||||
// so the importer still fires end-to-end.
|
||||
type BarkImporter struct{}
|
||||
|
||||
// matchBarkURI tolerates a nil ModelDetails (see note above) by extracting
|
||||
// the HF owner+repo portion directly from the raw URI.
|
||||
func matchBarkURI(uri string) bool {
|
||||
owner, repo, ok := HFOwnerRepoFromURI(uri)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(owner, "suno") && strings.HasPrefix(strings.ToLower(repo), "bark")
|
||||
}
|
||||
|
||||
func (i *BarkImporter) Name() string { return "bark" }
|
||||
func (i *BarkImporter) Modality() string { return "tts" }
|
||||
func (i *BarkImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *BarkImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "bark" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
if strings.EqualFold(details.HuggingFace.Author, "suno") {
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(repoName), "bark") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HF metadata may be absent when the recursive tree listing errors
|
||||
// (see type-level note). Fall back to URI parsing.
|
||||
return matchBarkURI(details.URI)
|
||||
}
|
||||
|
||||
func (i *BarkImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "bark",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/bark_test.go
Normal file
47
core/gallery/importers/bark_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("BarkImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches suno/bark (owner + repo name prefix)", func() {
|
||||
uri := "https://huggingface.co/suno/bark"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: bark"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("suno/bark"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=bark for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "bark"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: bark"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.BarkImporter{}
|
||||
Expect(imp.Name()).To(Equal("bark"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
110
core/gallery/importers/chatterbox.go
Normal file
110
core/gallery/importers/chatterbox.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &ChatterboxImporter{}
|
||||
|
||||
// ChatterboxImporter recognises Resemble AI's Chatterbox TTS. Detection
|
||||
// uses the `ResembleAI` owner or a "chatterbox" substring in the repo
|
||||
// name (covers the primary release plus community finetunes).
|
||||
// preferences.backend="chatterbox" overrides detection.
|
||||
type ChatterboxImporter struct{}
|
||||
|
||||
func (i *ChatterboxImporter) Name() string { return "chatterbox" }
|
||||
func (i *ChatterboxImporter) Modality() string { return "tts" }
|
||||
func (i *ChatterboxImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *ChatterboxImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "chatterbox" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
if strings.EqualFold(details.HuggingFace.Author, "ResembleAI") {
|
||||
return true
|
||||
}
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if strings.Contains(strings.ToLower(repoName), "chatterbox") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if strings.EqualFold(owner, "ResembleAI") || strings.Contains(strings.ToLower(repo), "chatterbox") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *ChatterboxImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "chatterbox",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/chatterbox_test.go
Normal file
47
core/gallery/importers/chatterbox_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ChatterboxImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches ResembleAI/chatterbox (owner)", func() {
|
||||
uri := "https://huggingface.co/ResembleAI/chatterbox"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: chatterbox"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("ResembleAI/chatterbox"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=chatterbox for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "chatterbox"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: chatterbox"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.ChatterboxImporter{}
|
||||
Expect(imp.Name()).To(Equal("chatterbox"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
99
core/gallery/importers/coqui.go
Normal file
99
core/gallery/importers/coqui.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &CoquiImporter{}
|
||||
|
||||
// CoquiImporter recognises Coqui AI's open-weight TTS releases (XTTS-v2,
|
||||
// YourTTS, the Tortoise port, etc). Detection is owner-scoped to `coqui`
|
||||
// — their HF org is the authoritative publisher for models that run on
|
||||
// the Coqui TTS Python runtime. preferences.backend="coqui" overrides.
|
||||
type CoquiImporter struct{}
|
||||
|
||||
func (i *CoquiImporter) Name() string { return "coqui" }
|
||||
func (i *CoquiImporter) Modality() string { return "tts" }
|
||||
func (i *CoquiImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *CoquiImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "coqui" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil && strings.EqualFold(details.HuggingFace.Author, "coqui") {
|
||||
return true
|
||||
}
|
||||
|
||||
if owner, _, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
return strings.EqualFold(owner, "coqui")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *CoquiImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "coqui",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/coqui_test.go
Normal file
47
core/gallery/importers/coqui_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("CoquiImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches coqui/XTTS-v2 (owner)", func() {
|
||||
uri := "https://huggingface.co/coqui/XTTS-v2"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: coqui"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("coqui/XTTS-v2"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=coqui for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "coqui"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: coqui"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.CoquiImporter{}
|
||||
Expect(imp.Name()).To(Equal("coqui"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -15,6 +15,10 @@ var _ Importer = &DiffuserImporter{}
|
||||
|
||||
type DiffuserImporter struct{}
|
||||
|
||||
func (i *DiffuserImporter) Name() string { return "diffusers" }
|
||||
func (i *DiffuserImporter) Modality() string { return "image" }
|
||||
func (i *DiffuserImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *DiffuserImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
|
||||
117
core/gallery/importers/faster-whisper.go
Normal file
117
core/gallery/importers/faster-whisper.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &FasterWhisperImporter{}
|
||||
|
||||
// FasterWhisperImporter recognises CTranslate2-packaged whisper checkpoints
|
||||
// (the format consumed by the faster-whisper runtime). The classic layout is
|
||||
// a flat directory with model.bin + config.json and an ASR pipeline_tag.
|
||||
//
|
||||
// We disambiguate from vanilla OpenAI whisper repos — which would otherwise
|
||||
// also hit the tokenizer.json path and get routed to transformers — by
|
||||
// requiring either the Systran owner (the upstream distributor) or the
|
||||
// string "faster-whisper" in the repo name. preferences.backend=
|
||||
// faster-whisper overrides detection.
|
||||
type FasterWhisperImporter struct{}
|
||||
|
||||
func (i *FasterWhisperImporter) Name() string { return "faster-whisper" }
|
||||
func (i *FasterWhisperImporter) Modality() string { return "asr" }
|
||||
func (i *FasterWhisperImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *FasterWhisperImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "faster-whisper" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !HasFile(details.HuggingFace.Files, "model.bin") {
|
||||
return false
|
||||
}
|
||||
if !HasFile(details.HuggingFace.Files, "config.json") {
|
||||
return false
|
||||
}
|
||||
if details.HuggingFace.PipelineTag != "automatic-speech-recognition" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Narrow to the faster-whisper distribution: Systran owner OR
|
||||
// "faster-whisper" in the repo name. Without this guard, any vanilla
|
||||
// whisper repo on HF would also match the file pair and ASR tag.
|
||||
if strings.EqualFold(details.HuggingFace.Author, "Systran") {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(strings.ToLower(details.HuggingFace.ModelID), "faster-whisper")
|
||||
}
|
||||
|
||||
func (i *FasterWhisperImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "faster-whisper",
|
||||
KnownUsecaseStrings: []string{"transcript"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
46
core/gallery/importers/faster-whisper_test.go
Normal file
46
core/gallery/importers/faster-whisper_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("FasterWhisperImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches Systran/faster-whisper-large-v3 (model.bin + config.json + ASR)", func() {
|
||||
uri := "https://huggingface.co/Systran/faster-whisper-large-v3"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: faster-whisper"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("Systran/faster-whisper-large-v3"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=faster-whisper for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "faster-whisper"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: faster-whisper"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.FasterWhisperImporter{}
|
||||
Expect(imp.Name()).To(Equal("faster-whisper"))
|
||||
Expect(imp.Modality()).To(Equal("asr"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
101
core/gallery/importers/fish-speech.go
Normal file
101
core/gallery/importers/fish-speech.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &FishSpeechImporter{}
|
||||
|
||||
// FishSpeechImporter recognises Fish Audio's open-weights TTS releases
|
||||
// (Fish Speech, S1/S2 series). The `fishaudio` owner is the canonical
|
||||
// publisher — scoping by owner avoids false positives from generic
|
||||
// safetensors+tokenizer packaging used elsewhere.
|
||||
// preferences.backend="fish-speech" overrides detection.
|
||||
type FishSpeechImporter struct{}
|
||||
|
||||
func (i *FishSpeechImporter) Name() string { return "fish-speech" }
|
||||
func (i *FishSpeechImporter) Modality() string { return "tts" }
|
||||
func (i *FishSpeechImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *FishSpeechImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "fish-speech" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil && strings.EqualFold(details.HuggingFace.Author, "fishaudio") {
|
||||
return true
|
||||
}
|
||||
// URI fallback for parity with other TTS importers when HF metadata
|
||||
// fetching fails (see BarkImporter note).
|
||||
if owner, _, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
return strings.EqualFold(owner, "fishaudio")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *FishSpeechImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "fish-speech",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/fish-speech_test.go
Normal file
47
core/gallery/importers/fish-speech_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("FishSpeechImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches fishaudio/s2-pro (owner = fishaudio)", func() {
|
||||
uri := "https://huggingface.co/fishaudio/s2-pro"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: fish-speech"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("fishaudio/s2-pro"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=fish-speech for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "fish-speech"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: fish-speech"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.FishSpeechImporter{}
|
||||
Expect(imp.Name()).To(Equal("fish-speech"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
91
core/gallery/importers/helpers.go
Normal file
91
core/gallery/importers/helpers.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
)
|
||||
|
||||
// HasFile returns true when any file in files has exactly the given basename.
|
||||
// Directory components in file.Path are ignored — a nested
|
||||
// "sub/dir/config.json" is considered a match for name = "config.json".
|
||||
func HasFile(files []hfapi.ModelFile, name string) bool {
|
||||
for _, f := range files {
|
||||
if filepath.Base(f.Path) == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasExtension returns true when any file has the given extension
|
||||
// (case-insensitive). ext must include the leading dot, e.g. ".onnx".
|
||||
func HasExtension(files []hfapi.ModelFile, ext string) bool {
|
||||
lower := strings.ToLower(ext)
|
||||
for _, f := range files {
|
||||
if strings.HasSuffix(strings.ToLower(f.Path), lower) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasONNX returns true when any file ends in .onnx (case-insensitive).
|
||||
func HasONNX(files []hfapi.ModelFile) bool {
|
||||
return HasExtension(files, ".onnx")
|
||||
}
|
||||
|
||||
// HasONNXConfigPair returns true when an .onnx file has an accompanying
|
||||
// "<same basename>.onnx.json" file. This is the piper voice packaging
|
||||
// convention, e.g. en_US-amy-medium.onnx + en_US-amy-medium.onnx.json.
|
||||
func HasONNXConfigPair(files []hfapi.ModelFile) bool {
|
||||
paths := make(map[string]struct{}, len(files))
|
||||
for _, f := range files {
|
||||
paths[strings.ToLower(f.Path)] = struct{}{}
|
||||
}
|
||||
for p := range paths {
|
||||
if !strings.HasSuffix(p, ".onnx") {
|
||||
continue
|
||||
}
|
||||
if _, ok := paths[p+".json"]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HFOwnerRepoFromURI extracts the "owner", "repo" pair from an HF URI.
|
||||
// Accepted prefixes: "https://huggingface.co/", "huggingface://", "hf://".
|
||||
// Returns ok=false when the URI is not an HF URI or is missing either
|
||||
// component. This exists so importers can fall back to URI-based matching
|
||||
// when pkg/huggingface-api's recursive tree listing errors out on repos
|
||||
// with nested subdirectories (a known pre-existing bug).
|
||||
func HFOwnerRepoFromURI(uri string) (owner, repo string, ok bool) {
|
||||
stripped := uri
|
||||
for _, pfx := range []string{"https://huggingface.co/", "huggingface://", "hf://"} {
|
||||
stripped = strings.TrimPrefix(stripped, pfx)
|
||||
}
|
||||
parts := strings.SplitN(stripped, "/", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return "", "", false
|
||||
}
|
||||
return parts[0], parts[1], true
|
||||
}
|
||||
|
||||
// HasGGMLFile returns true when any file matches "<prefix>*.bin", which is
|
||||
// the whisper.cpp packaging convention (e.g. "ggml-base.en.bin"). Both prefix
|
||||
// and suffix match is case-sensitive on prefix and case-insensitive on the
|
||||
// .bin extension.
|
||||
func HasGGMLFile(files []hfapi.ModelFile, prefix string) bool {
|
||||
for _, f := range files {
|
||||
name := filepath.Base(f.Path)
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(strings.ToLower(name), ".bin") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
89
core/gallery/importers/helpers_test.go
Normal file
89
core/gallery/importers/helpers_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
)
|
||||
|
||||
var _ = Describe("importer helpers", func() {
|
||||
mkFiles := func(paths ...string) []hfapi.ModelFile {
|
||||
out := make([]hfapi.ModelFile, 0, len(paths))
|
||||
for _, p := range paths {
|
||||
out = append(out, hfapi.ModelFile{Path: p})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
Describe("HasFile", func() {
|
||||
It("returns false for an empty slice", func() {
|
||||
Expect(importers.HasFile(nil, "config.json")).To(BeFalse())
|
||||
Expect(importers.HasFile([]hfapi.ModelFile{}, "config.json")).To(BeFalse())
|
||||
})
|
||||
It("matches on exact basename, ignoring directory components", func() {
|
||||
files := mkFiles("sub/dir/config.json", "other.txt")
|
||||
Expect(importers.HasFile(files, "config.json")).To(BeTrue())
|
||||
})
|
||||
It("does not match partial basenames", func() {
|
||||
files := mkFiles("sub/dir/myconfig.json")
|
||||
Expect(importers.HasFile(files, "config.json")).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("HasExtension", func() {
|
||||
It("matches case-insensitively", func() {
|
||||
files := mkFiles("model.ONNX", "other.txt")
|
||||
Expect(importers.HasExtension(files, ".onnx")).To(BeTrue())
|
||||
})
|
||||
It("returns false when no file has the extension", func() {
|
||||
Expect(importers.HasExtension(mkFiles("README.md"), ".onnx")).To(BeFalse())
|
||||
})
|
||||
It("handles empty slices gracefully", func() {
|
||||
Expect(importers.HasExtension(nil, ".onnx")).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("HasONNX", func() {
|
||||
It("is true when any file ends in .onnx", func() {
|
||||
Expect(importers.HasONNX(mkFiles("voice/en_US-amy-medium.onnx"))).To(BeTrue())
|
||||
})
|
||||
It("is false otherwise", func() {
|
||||
Expect(importers.HasONNX(mkFiles("model.bin"))).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("HasONNXConfigPair", func() {
|
||||
It("matches the piper .onnx + .onnx.json pair", func() {
|
||||
files := mkFiles(
|
||||
"en/en_US/amy/medium/en_US-amy-medium.onnx",
|
||||
"en/en_US/amy/medium/en_US-amy-medium.onnx.json",
|
||||
)
|
||||
Expect(importers.HasONNXConfigPair(files)).To(BeTrue())
|
||||
})
|
||||
It("requires the accompanying json to share the .onnx basename", func() {
|
||||
files := mkFiles("model.onnx", "config.json")
|
||||
Expect(importers.HasONNXConfigPair(files)).To(BeFalse())
|
||||
})
|
||||
It("returns false for a lone .onnx file", func() {
|
||||
files := mkFiles("model.onnx")
|
||||
Expect(importers.HasONNXConfigPair(files)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("HasGGMLFile", func() {
|
||||
It("finds ggml-prefixed .bin files", func() {
|
||||
files := mkFiles("ggml-base.en.bin", "README.md")
|
||||
Expect(importers.HasGGMLFile(files, "ggml-")).To(BeTrue())
|
||||
})
|
||||
It("requires both prefix and .bin suffix", func() {
|
||||
files := mkFiles("ggml-base.en.gguf")
|
||||
Expect(importers.HasGGMLFile(files, "ggml-")).To(BeFalse())
|
||||
})
|
||||
It("returns false when prefix does not match", func() {
|
||||
files := mkFiles("whisper.bin")
|
||||
Expect(importers.HasGGMLFile(files, "ggml-")).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -2,8 +2,10 @@ package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
@@ -15,7 +17,138 @@ import (
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
)
|
||||
|
||||
// ErrAmbiguousImport is returned when HuggingFace metadata hints at a known
|
||||
// modality (e.g. pipeline_tag: "automatic-speech-recognition") but no
|
||||
// importer's artefact-level detection matches the repository. Callers should
|
||||
// pass preferences.backend to disambiguate. Use errors.Is to match regardless
|
||||
// of wrapping — DiscoverModelConfig returns a typed AmbiguousImportError that
|
||||
// carries the detected modality + candidate backends, and whose Is() matches
|
||||
// this sentinel so legacy callers keep working.
|
||||
var ErrAmbiguousImport = errors.New("importer: ambiguous — specify preferences.backend")
|
||||
|
||||
// AmbiguousImportError is the concrete error DiscoverModelConfig returns when
|
||||
// it can't pick an importer automatically. It carries the importer-modality
|
||||
// key (e.g. "tts", "asr") and the list of candidate backend names so HTTP
|
||||
// consumers can render a picker without re-deriving the mapping from HF
|
||||
// pipeline_tag values.
|
||||
type AmbiguousImportError struct {
|
||||
// Modality is the importer modality key ("text", "asr", "tts", "image",
|
||||
// "embeddings", "reranker", "detection"). Pre-mapped from the HF
|
||||
// pipeline_tag so the UI doesn't have to.
|
||||
Modality string
|
||||
// Candidates is the list of backend names whose Modality() matches — a
|
||||
// subset of the importer registry plus AdditionalBackendsProvider
|
||||
// drop-ins.
|
||||
Candidates []string
|
||||
// URI is the original URI that triggered the ambiguity.
|
||||
URI string
|
||||
// PipelineTag is the raw HF pipeline_tag value as reported by the model
|
||||
// metadata — preserved for logging / debugging.
|
||||
PipelineTag string
|
||||
}
|
||||
|
||||
func (e *AmbiguousImportError) Error() string {
|
||||
return fmt.Sprintf("importer: ambiguous — detected modality %q (pipeline_tag=%q) for %s, candidates: %v",
|
||||
e.Modality, e.PipelineTag, e.URI, e.Candidates)
|
||||
}
|
||||
|
||||
// Is lets callers match with errors.Is(err, ErrAmbiguousImport) without caring
|
||||
// about the typed shape.
|
||||
func (e *AmbiguousImportError) Is(target error) bool {
|
||||
return target == ErrAmbiguousImport
|
||||
}
|
||||
|
||||
// ambiguousModalities enumerates the HF pipeline_tag values that are narrow
|
||||
// enough to be confident we should surface ambiguity instead of a generic
|
||||
// "no importer matched" error. Tags outside this whitelist keep the previous
|
||||
// behaviour (plain error) so we don't block uncommon-but-still-valid imports.
|
||||
// The mapped value is the importer modality key used to filter candidates.
|
||||
var ambiguousModalities = map[string]string{
|
||||
"automatic-speech-recognition": "asr",
|
||||
"text-to-speech": "tts",
|
||||
"sentence-similarity": "embeddings",
|
||||
"text-classification": "reranker",
|
||||
"object-detection": "detection",
|
||||
"text-to-image": "image",
|
||||
}
|
||||
|
||||
// PipelineTagToModality maps HF pipeline_tag strings to the importer modality
|
||||
// key used internally (and by /backends/known). Returns the modality + true
|
||||
// when the tag is in the ambiguous whitelist; "" + false otherwise.
|
||||
func PipelineTagToModality(pipelineTag string) (string, bool) {
|
||||
m, ok := ambiguousModalities[pipelineTag]
|
||||
return m, ok
|
||||
}
|
||||
|
||||
// CandidatesForModality returns the backend names whose importer modality
|
||||
// matches the requested key. Includes AdditionalBackendsProvider drop-ins so
|
||||
// entries like ik-llama-cpp surface for text modalities. Results are sorted
|
||||
// for deterministic ordering in API responses.
|
||||
func CandidatesForModality(modality string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
for _, imp := range defaultImporters {
|
||||
if imp.Modality() != modality {
|
||||
continue
|
||||
}
|
||||
seen[imp.Name()] = struct{}{}
|
||||
if host, ok := imp.(AdditionalBackendsProvider); ok {
|
||||
for _, extra := range host.AdditionalBackends() {
|
||||
if extra.Modality != modality {
|
||||
continue
|
||||
}
|
||||
seen[extra.Name] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
out := make([]string, 0, len(seen))
|
||||
for n := range seen {
|
||||
out = append(out, n)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
var defaultImporters = []Importer{
|
||||
// ASR (Batch 1)
|
||||
&WhisperImporter{},
|
||||
&MoonshineImporter{},
|
||||
&NemoImporter{},
|
||||
&FasterWhisperImporter{},
|
||||
&QwenASRImporter{},
|
||||
// TTS (Batch 2)
|
||||
&PiperImporter{},
|
||||
&BarkImporter{},
|
||||
&FishSpeechImporter{},
|
||||
&OutettsImporter{},
|
||||
&VoxCPMImporter{},
|
||||
&KokoroImporter{},
|
||||
&KittenTTSImporter{},
|
||||
&NeuTTSImporter{},
|
||||
&ChatterboxImporter{},
|
||||
&VibeVoiceImporter{},
|
||||
&CoquiImporter{},
|
||||
// Image/Video (Batch 3)
|
||||
&StableDiffusionGGMLImporter{},
|
||||
&ACEStepImporter{},
|
||||
// Text LLM (Batch 4) — VLLMOmniImporter must stay ahead of
|
||||
// VLLMImporter so Qwen Omni repos (which also carry tokenizer
|
||||
// files) route to vllm-omni rather than plain vllm.
|
||||
&VLLMOmniImporter{},
|
||||
// Embeddings / rerankers / detection / VAD (Batch 5)
|
||||
// SileroVADImporter first — unique filename signal, cannot collide.
|
||||
&SileroVADImporter{},
|
||||
// RerankersImporter must run before SentenceTransformers and
|
||||
// Transformers — some reranker repos ship modules.json and tokenizer
|
||||
// files that those importers would otherwise claim.
|
||||
&RerankersImporter{},
|
||||
// SentenceTransformersImporter must run before TransformersImporter:
|
||||
// sentence-transformers repos ship tokenizer.json which transformers
|
||||
// would otherwise claim.
|
||||
&SentenceTransformersImporter{},
|
||||
// RFDetrImporter must run before TransformersImporter — RF-DETR
|
||||
// checkpoints may carry tokenizer-adjacent artefacts.
|
||||
&RFDetrImporter{},
|
||||
// Existing
|
||||
&LlamaCPPImporter{},
|
||||
&MLXImporter{},
|
||||
&VLLMImporter{},
|
||||
@@ -32,6 +165,42 @@ type Details struct {
|
||||
type Importer interface {
|
||||
Match(details Details) bool
|
||||
Import(details Details) (gallery.ModelConfig, error)
|
||||
// Name is the canonical backend name (e.g. "llama-cpp"). Used by
|
||||
// /backends/known to populate the import form dropdown.
|
||||
Name() string
|
||||
// Modality is the backend's primary modality ("text", "asr", "tts",
|
||||
// "image", "embeddings", "reranker", "detection", "vad"). Used for
|
||||
// grouping in the UI.
|
||||
Modality() string
|
||||
// AutoDetects is true when Match() can fire without an explicit
|
||||
// preferences.backend. Preference-only entries surface as
|
||||
// AutoDetect=false in /backends/known.
|
||||
AutoDetects() bool
|
||||
}
|
||||
|
||||
// KnownBackendEntry describes one backend advertised by an importer.
|
||||
// Importers that host drop-in replacements (e.g. llama-cpp hosting
|
||||
// ik-llama-cpp and turboquant) return additional entries via
|
||||
// AdditionalBackendsProvider so the endpoint can surface them without
|
||||
// registering separate importers.
|
||||
type KnownBackendEntry struct {
|
||||
Name string
|
||||
Modality string
|
||||
Description string
|
||||
}
|
||||
|
||||
// AdditionalBackendsProvider is implemented by importers that advertise
|
||||
// drop-in replacements sharing their Match/Import logic. The entries
|
||||
// appear in /backends/known with AutoDetect=false since they are
|
||||
// preference-only.
|
||||
type AdditionalBackendsProvider interface {
|
||||
AdditionalBackends() []KnownBackendEntry
|
||||
}
|
||||
|
||||
// Registry returns the list of registered importers. Callers must not
|
||||
// mutate the returned slice.
|
||||
func Registry() []Importer {
|
||||
return defaultImporters
|
||||
}
|
||||
|
||||
func hasYAMLExtension(uri string) bool {
|
||||
@@ -115,6 +284,19 @@ func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.Model
|
||||
}
|
||||
}
|
||||
if !importerMatched {
|
||||
// When HuggingFace metadata hints at a known, narrow modality but no
|
||||
// importer matched the artefacts, surface an explicit ambiguity so the
|
||||
// caller knows to pass preferences.backend rather than silently guess.
|
||||
if hfDetails != nil && hfDetails.PipelineTag != "" {
|
||||
if modality, known := ambiguousModalities[hfDetails.PipelineTag]; known {
|
||||
return gallery.ModelConfig{}, &AmbiguousImportError{
|
||||
Modality: modality,
|
||||
Candidates: CandidatesForModality(modality),
|
||||
URI: uri,
|
||||
PipelineTag: hfDetails.PipelineTag,
|
||||
}
|
||||
}
|
||||
}
|
||||
return gallery.ModelConfig{}, fmt.Errorf("no importer matched for %s", uri)
|
||||
}
|
||||
return modelConfig, nil
|
||||
|
||||
@@ -2,6 +2,7 @@ package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -201,6 +202,99 @@ var _ = Describe("DiscoverModelConfig", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("ErrAmbiguousImport sentinel", func() {
|
||||
It("is defined so callers can match with errors.Is", func() {
|
||||
Expect(importers.ErrAmbiguousImport).ToNot(BeNil())
|
||||
// Wrapping-sanity: fmt.Errorf("%w", err) preserves identity.
|
||||
wrapped := fmt.Errorf("context: %w", importers.ErrAmbiguousImport)
|
||||
Expect(errors.Is(wrapped, importers.ErrAmbiguousImport)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("surfaces modality and candidates on the typed error for HTTP consumers", func() {
|
||||
// TTS fixture — pipeline_tag=text-to-speech, no importer matches.
|
||||
uri := "https://huggingface.co/nari-labs/Dia-1.6B"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
_, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, importers.ErrAmbiguousImport)).To(BeTrue())
|
||||
|
||||
var amb *importers.AmbiguousImportError
|
||||
Expect(errors.As(err, &amb)).To(BeTrue(), "expected AmbiguousImportError, got: %v", err)
|
||||
Expect(amb.Modality).To(Equal("tts"))
|
||||
Expect(amb.Candidates).To(ContainElements("piper", "bark", "kokoro"))
|
||||
Expect(amb.Candidates).ToNot(ContainElement("llama-cpp"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
// These tests drive the /backends/known endpoint: each importer must
|
||||
// self-describe its canonical name, primary modality, and whether it
|
||||
// can auto-detect without an explicit preference.
|
||||
It("Registry returns all default importers", func() {
|
||||
registry := importers.Registry()
|
||||
Expect(registry).ToNot(BeEmpty())
|
||||
names := make([]string, 0, len(registry))
|
||||
for _, imp := range registry {
|
||||
names = append(names, imp.Name())
|
||||
}
|
||||
Expect(names).To(ContainElements("llama-cpp", "mlx", "vllm", "transformers", "diffusers"))
|
||||
})
|
||||
|
||||
It("LlamaCPPImporter exposes name/modality/autodetect", func() {
|
||||
imp := &importers.LlamaCPPImporter{}
|
||||
Expect(imp.Name()).To(Equal("llama-cpp"))
|
||||
Expect(imp.Modality()).To(Equal("text"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("MLXImporter exposes name/modality/autodetect", func() {
|
||||
imp := &importers.MLXImporter{}
|
||||
Expect(imp.Name()).To(Equal("mlx"))
|
||||
Expect(imp.Modality()).To(Equal("text"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("VLLMImporter exposes name/modality/autodetect", func() {
|
||||
imp := &importers.VLLMImporter{}
|
||||
Expect(imp.Name()).To(Equal("vllm"))
|
||||
Expect(imp.Modality()).To(Equal("text"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("TransformersImporter exposes name/modality/autodetect", func() {
|
||||
imp := &importers.TransformersImporter{}
|
||||
Expect(imp.Name()).To(Equal("transformers"))
|
||||
Expect(imp.Modality()).To(Equal("text"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("DiffuserImporter exposes name/modality/autodetect", func() {
|
||||
imp := &importers.DiffuserImporter{}
|
||||
Expect(imp.Name()).To(Equal("diffusers"))
|
||||
Expect(imp.Modality()).To(Equal("image"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("LlamaCPPImporter advertises drop-in replacements", func() {
|
||||
imp := &importers.LlamaCPPImporter{}
|
||||
provider, ok := any(imp).(importers.AdditionalBackendsProvider)
|
||||
Expect(ok).To(BeTrue(), "LlamaCPPImporter must implement AdditionalBackendsProvider")
|
||||
|
||||
extras := provider.AdditionalBackends()
|
||||
names := make([]string, 0, len(extras))
|
||||
modalities := make([]string, 0, len(extras))
|
||||
for _, e := range extras {
|
||||
names = append(names, e.Name)
|
||||
modalities = append(modalities, e.Modality)
|
||||
}
|
||||
Expect(names).To(ContainElements("ik-llama-cpp", "turboquant"))
|
||||
for _, m := range modalities {
|
||||
Expect(m).To(Equal("text"))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("with invalid JSON preferences", func() {
|
||||
It("should return error when JSON is invalid even if URI matches", func() {
|
||||
uri := "https://example.com/model.gguf"
|
||||
|
||||
109
core/gallery/importers/kitten-tts.go
Normal file
109
core/gallery/importers/kitten-tts.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &KittenTTSImporter{}
|
||||
|
||||
// KittenTTSImporter recognises KittenML's kitten-tts releases. Detection
|
||||
// uses the `KittenML` owner or a "kitten-tts" substring in the repo name
|
||||
// for third-party mirrors. preferences.backend="kitten-tts" overrides.
|
||||
type KittenTTSImporter struct{}
|
||||
|
||||
func (i *KittenTTSImporter) Name() string { return "kitten-tts" }
|
||||
func (i *KittenTTSImporter) Modality() string { return "tts" }
|
||||
func (i *KittenTTSImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *KittenTTSImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "kitten-tts" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
if strings.EqualFold(details.HuggingFace.Author, "KittenML") {
|
||||
return true
|
||||
}
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if strings.Contains(strings.ToLower(repoName), "kitten-tts") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if strings.EqualFold(owner, "KittenML") || strings.Contains(strings.ToLower(repo), "kitten-tts") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *KittenTTSImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "kitten-tts",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/kitten-tts_test.go
Normal file
47
core/gallery/importers/kitten-tts_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("KittenTTSImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches KittenML/kitten-tts-nano-0.1 (owner + repo name)", func() {
|
||||
uri := "https://huggingface.co/KittenML/kitten-tts-nano-0.1"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: kitten-tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("KittenML/kitten-tts-nano-0.1"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=kitten-tts for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "kitten-tts"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: kitten-tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.KittenTTSImporter{}
|
||||
Expect(imp.Name()).To(Equal("kitten-tts"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
110
core/gallery/importers/kokoro.go
Normal file
110
core/gallery/importers/kokoro.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &KokoroImporter{}
|
||||
|
||||
// KokoroImporter recognises hexgrad's Kokoro TTS family (Kokoro-82M,
|
||||
// Kokoro-82M-v1.1-zh, …). The repo name carries "Kokoro" and the weights
|
||||
// ship as a PyTorch .pth/.pt — pairing the two keeps us from claiming the
|
||||
// quantised GGUF mirrors (which llama-cpp handles) or the ONNX exports
|
||||
// (which the pref-only `kokoros` Rust runtime handles).
|
||||
// preferences.backend="kokoro" overrides detection; preferences.backend
|
||||
// ="kokoros" deliberately does *not* trigger this importer (see test).
|
||||
type KokoroImporter struct{}
|
||||
|
||||
func (i *KokoroImporter) Name() string { return "kokoro" }
|
||||
func (i *KokoroImporter) Modality() string { return "tts" }
|
||||
func (i *KokoroImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *KokoroImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Explicit "kokoro" overrides. "kokoros" is intentionally distinct:
|
||||
// comparing against the exact string prevents pref-only kokoros
|
||||
// requests from hijacking this importer.
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "kokoro" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace == nil {
|
||||
return false
|
||||
}
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(repoName), "kokoro") {
|
||||
return false
|
||||
}
|
||||
// Require a PyTorch checkpoint to disambiguate from ONNX-only or
|
||||
// GGUF-only mirrors that route to other backends.
|
||||
return HasExtension(details.HuggingFace.Files, ".pth") || HasExtension(details.HuggingFace.Files, ".pt")
|
||||
}
|
||||
|
||||
func (i *KokoroImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "kokoro",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
63
core/gallery/importers/kokoro_test.go
Normal file
63
core/gallery/importers/kokoro_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("KokoroImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches hexgrad/Kokoro-82M (repo name + .pth)", func() {
|
||||
uri := "https://huggingface.co/hexgrad/Kokoro-82M"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: kokoro"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("hexgrad/Kokoro-82M"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=kokoro for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "kokoro"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: kokoro"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("kokoros preference disambiguation", func() {
|
||||
It("does not hijack preferences.backend=kokoros (pref-only)", func() {
|
||||
// The kokoros Rust runtime is pref-only (listed in
|
||||
// knownPrefOnlyBackends). The autodetect path for the kokoro
|
||||
// importer must NOT fire when the user explicitly selects the
|
||||
// kokoros backend for an arbitrary URI — if it did, DiscoverModelConfig
|
||||
// would incorrectly emit backend=kokoro.
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "kokoros"}`)
|
||||
|
||||
_, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
// kokoros has no importer, so discovery should not match anything.
|
||||
Expect(err).To(HaveOccurred(), "kokoros is pref-only — DiscoverModelConfig should not match any importer")
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.KokoroImporter{}
|
||||
Expect(imp.Name()).To(Equal("kokoro"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -3,7 +3,6 @@ package importers
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
@@ -11,14 +10,33 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
"github.com/mudler/xlog"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &LlamaCPPImporter{}
|
||||
var (
|
||||
_ Importer = &LlamaCPPImporter{}
|
||||
_ AdditionalBackendsProvider = &LlamaCPPImporter{}
|
||||
)
|
||||
|
||||
type LlamaCPPImporter struct{}
|
||||
|
||||
func (i *LlamaCPPImporter) Name() string { return "llama-cpp" }
|
||||
func (i *LlamaCPPImporter) Modality() string { return "text" }
|
||||
func (i *LlamaCPPImporter) AutoDetects() bool { return true }
|
||||
|
||||
// AdditionalBackends advertises drop-in replacements that share the
|
||||
// llama-cpp detection logic. They are preference-only: selecting one
|
||||
// from the import form swaps the emitted YAML backend field but reuses
|
||||
// the llama-cpp Match/Import pipeline.
|
||||
func (i *LlamaCPPImporter) AdditionalBackends() []KnownBackendEntry {
|
||||
return []KnownBackendEntry{
|
||||
{Name: "ik-llama-cpp", Modality: "text", Description: "GGUF drop-in replacement for llama-cpp with ik-quants"},
|
||||
{Name: "turboquant", Modality: "text", Description: "GGUF drop-in replacement for llama-cpp with TurboQuant optimizations"},
|
||||
}
|
||||
}
|
||||
|
||||
func (i *LlamaCPPImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
@@ -101,12 +119,25 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
|
||||
embeddings, _ := preferencesMap["embeddings"].(string)
|
||||
|
||||
// Honour drop-in replacement preferences. Only the curated names
|
||||
// advertised via AdditionalBackends() are accepted; anything else
|
||||
// (including "llama-cpp" itself, or an unknown value) keeps the
|
||||
// default backend field so arbitrary input can't leak through. See
|
||||
// the AdditionalBackends method for the canonical list.
|
||||
backend := "llama-cpp"
|
||||
if b, ok := preferencesMap["backend"].(string); ok {
|
||||
switch b {
|
||||
case "ik-llama-cpp", "turboquant":
|
||||
backend = b
|
||||
}
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
Options: []string{"use_jinja:true"},
|
||||
Backend: "llama-cpp",
|
||||
Backend: backend,
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
UseTokenizerTemplate: true,
|
||||
},
|
||||
@@ -172,59 +203,34 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
},
|
||||
}
|
||||
case details.HuggingFace != nil:
|
||||
// We want to:
|
||||
// Get first the chosen quants that match filenames
|
||||
// OR the first mmproj/gguf file found
|
||||
var lastMMProjFile *gallery.File
|
||||
var lastGGUFFile *gallery.File
|
||||
foundPreferedQuant := false
|
||||
foundPreferedMMprojQuant := false
|
||||
|
||||
for _, file := range details.HuggingFace.Files {
|
||||
// Get the mmproj prefered quants
|
||||
if strings.Contains(strings.ToLower(file.Path), "mmproj") {
|
||||
lastMMProjFile = &gallery.File{
|
||||
URI: file.URL,
|
||||
Filename: filepath.Join("llama-cpp", "mmproj", name, filepath.Base(file.Path)),
|
||||
SHA256: file.SHA256,
|
||||
}
|
||||
if slices.ContainsFunc(mmprojQuantsList, func(quant string) bool {
|
||||
return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
|
||||
}) {
|
||||
cfg.Files = append(cfg.Files, *lastMMProjFile)
|
||||
foundPreferedMMprojQuant = true
|
||||
}
|
||||
} else if strings.HasSuffix(strings.ToLower(file.Path), "gguf") {
|
||||
lastGGUFFile = &gallery.File{
|
||||
URI: file.URL,
|
||||
Filename: filepath.Join("llama-cpp", "models", name, filepath.Base(file.Path)),
|
||||
SHA256: file.SHA256,
|
||||
}
|
||||
// get the files of the prefered quants
|
||||
if slices.ContainsFunc(quants, func(quant string) bool {
|
||||
return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
|
||||
}) {
|
||||
foundPreferedQuant = true
|
||||
cfg.Files = append(cfg.Files, *lastGGUFFile)
|
||||
}
|
||||
// Split the repo listing into mmproj vs plain GGUF files, then group
|
||||
// shards so every multi-part GGUF (llama.cpp `-NNNNN-of-MMMMM.gguf`
|
||||
// pattern) is treated as one logical selection candidate. The
|
||||
// previous implementation picked files one at a time, so sharded
|
||||
// models ended up with only the last part referenced in the gallery
|
||||
// entry — useless to llama.cpp, which needs shard 1 and the whole
|
||||
// set to load a split model.
|
||||
var mmprojFiles, ggufFiles []hfapi.ModelFile
|
||||
for _, f := range details.HuggingFace.Files {
|
||||
lowerPath := strings.ToLower(f.Path)
|
||||
switch {
|
||||
case strings.Contains(lowerPath, "mmproj"):
|
||||
mmprojFiles = append(mmprojFiles, f)
|
||||
case strings.HasSuffix(lowerPath, ".gguf"):
|
||||
ggufFiles = append(ggufFiles, f)
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure to add at least one file if not already present (which is the latest one)
|
||||
if lastMMProjFile != nil && !foundPreferedMMprojQuant {
|
||||
if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
|
||||
return f.Filename == lastMMProjFile.Filename
|
||||
}) {
|
||||
cfg.Files = append(cfg.Files, *lastMMProjFile)
|
||||
}
|
||||
}
|
||||
mmprojGroups := hfapi.GroupShards(mmprojFiles)
|
||||
ggufGroups := hfapi.GroupShards(ggufFiles)
|
||||
|
||||
if lastGGUFFile != nil && !foundPreferedQuant {
|
||||
if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
|
||||
return f.Filename == lastGGUFFile.Filename
|
||||
}) {
|
||||
cfg.Files = append(cfg.Files, *lastGGUFFile)
|
||||
}
|
||||
// Emit the model group first so cfg.Files[0] is the model — callers
|
||||
// and tests rely on the model file preceding any mmproj companion.
|
||||
if group := pickPreferredGroup(ggufGroups, quants); group != nil {
|
||||
appendShardGroup(&cfg, *group, filepath.Join("llama-cpp", "models", name))
|
||||
}
|
||||
if group := pickPreferredGroup(mmprojGroups, mmprojQuantsList); group != nil {
|
||||
appendShardGroup(&cfg, *group, filepath.Join("llama-cpp", "mmproj", name))
|
||||
}
|
||||
|
||||
// Find first mmproj file and configure it in the config file
|
||||
@@ -236,7 +242,9 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
break
|
||||
}
|
||||
|
||||
// Find first non-mmproj file and configure it in the config file
|
||||
// Find first non-mmproj file and configure it in the config file.
|
||||
// For sharded models this is shard 1 — llama.cpp's split loader
|
||||
// discovers the remaining shards by filename pattern from there.
|
||||
for _, file := range cfg.Files {
|
||||
if strings.Contains(strings.ToLower(file.Filename), "mmproj") {
|
||||
continue
|
||||
@@ -262,3 +270,48 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// pickPreferredGroup walks the preference list in priority order and returns
|
||||
// the first group whose base filename contains any preference. When nothing
|
||||
// matches, the last group wins — this preserves the historical "if the user
|
||||
// asked for a quant we don't have, fall back to whatever's available"
|
||||
// behaviour, lifted to whole shard sets.
|
||||
func pickPreferredGroup(groups []hfapi.ShardGroup, prefs []string) *hfapi.ShardGroup {
|
||||
if len(groups) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, pref := range prefs {
|
||||
lower := strings.ToLower(pref)
|
||||
for i := range groups {
|
||||
if strings.Contains(strings.ToLower(groups[i].Base), lower) {
|
||||
return &groups[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
return &groups[len(groups)-1]
|
||||
}
|
||||
|
||||
// appendShardGroup copies every shard of group into cfg.Files under dest,
|
||||
// skipping any entry whose target filename is already present so repeated
|
||||
// calls (e.g. the rare case of mmproj + model picking the same group)
|
||||
// don't produce duplicates.
|
||||
func appendShardGroup(cfg *gallery.ModelConfig, group hfapi.ShardGroup, dest string) {
|
||||
for _, f := range group.Files {
|
||||
target := filepath.Join(dest, filepath.Base(f.Path))
|
||||
duplicate := false
|
||||
for _, existing := range cfg.Files {
|
||||
if existing.Filename == target {
|
||||
duplicate = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if duplicate {
|
||||
continue
|
||||
}
|
||||
cfg.Files = append(cfg.Files, gallery.File{
|
||||
URI: f.URL,
|
||||
Filename: target,
|
||||
SHA256: f.SHA256,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/mudler/LocalAI/core/gallery/importers"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@@ -129,4 +130,269 @@ var _ = Describe("LlamaCPPImporter", func() {
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("drop-in backend preferences", func() {
|
||||
// baseline: no preference keeps backend: llama-cpp and the file
|
||||
// layout that downstream assertions depend on.
|
||||
It("emits backend: llama-cpp when no backend preference is set", func() {
|
||||
details := Details{
|
||||
URI: "https://example.com/my-model.gguf",
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
|
||||
It("swaps the emitted backend to ik-llama-cpp when preferred", func() {
|
||||
preferences := json.RawMessage(`{"backend": "ik-llama-cpp"}`)
|
||||
details := Details{
|
||||
URI: "https://example.com/my-model.gguf",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: ik-llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).NotTo(ContainSubstring("backend: llama-cpp\n"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
// Model path must remain identical to the llama-cpp baseline.
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(len(modelConfig.Files)).To(Equal(1))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"))
|
||||
})
|
||||
|
||||
It("swaps the emitted backend to turboquant when preferred", func() {
|
||||
preferences := json.RawMessage(`{"backend": "turboquant"}`)
|
||||
details := Details{
|
||||
URI: "https://example.com/my-model.gguf",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: turboquant"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).NotTo(ContainSubstring("backend: llama-cpp\n"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(len(modelConfig.Files)).To(Equal(1))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"))
|
||||
})
|
||||
|
||||
It("keeps backend: llama-cpp for unknown backend preferences", func() {
|
||||
// Unknown backend values must not leak into the emitted YAML —
|
||||
// we only honour the two curated drop-in replacements.
|
||||
preferences := json.RawMessage(`{"backend": "something-weird"}`)
|
||||
details := Details{
|
||||
URI: "https://example.com/my-model.gguf",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Import from HuggingFace file listing", func() {
|
||||
// These tests exercise the HF branch of Import() without touching
|
||||
// the network — they construct a fake *hfapi.ModelDetails and
|
||||
// assert the emitted gallery entry directly. Historically the HF
|
||||
// branch was only covered by live-API integration specs in
|
||||
// importers_test.go; anything that happened in between (shard
|
||||
// grouping, quant fallback) had no unit-level regression net.
|
||||
|
||||
const repoBase = "https://huggingface.co/acme/example-GGUF/resolve/main/"
|
||||
|
||||
hfFile := func(path, sha string) hfapi.ModelFile {
|
||||
return hfapi.ModelFile{
|
||||
Path: path,
|
||||
SHA256: sha,
|
||||
URL: repoBase + path,
|
||||
}
|
||||
}
|
||||
|
||||
withHF := func(preferences string, files ...hfapi.ModelFile) Details {
|
||||
d := Details{
|
||||
URI: "https://huggingface.co/acme/example-GGUF",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "acme/example-GGUF",
|
||||
Files: files,
|
||||
},
|
||||
}
|
||||
if preferences != "" {
|
||||
d.Preferences = json.RawMessage(preferences)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
It("picks the preferred quant in a single-file repo", func() {
|
||||
details := withHF(`{"name":"example","quantizations":"Q4_K_M"}`,
|
||||
hfFile("model-Q4_K_M.gguf", "aaa"),
|
||||
hfFile("model-Q3_K_M.gguf", "bbb"),
|
||||
hfFile("README.md", ""),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(HaveLen(1), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/example/model-Q4_K_M.gguf"))
|
||||
Expect(modelConfig.Files[0].URI).To(Equal(repoBase + "model-Q4_K_M.gguf"))
|
||||
Expect(modelConfig.Files[0].SHA256).To(Equal("aaa"))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/example/model-Q4_K_M.gguf"))
|
||||
})
|
||||
|
||||
It("falls back to the last group when no preference matches", func() {
|
||||
// Default preference is q4_k_m; the repo has only Q8_0 and
|
||||
// Q3_K_M. The old implementation would emit exactly the last
|
||||
// file seen — this test pins the fallback behaviour so the
|
||||
// group-level fallback keeps matching the historical intent.
|
||||
details := withHF(`{"name":"example"}`,
|
||||
hfFile("model-Q8_0.gguf", "aaa"),
|
||||
hfFile("model-Q3_K_M.gguf", "bbb"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(HaveLen(1), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/example/model-Q3_K_M.gguf"))
|
||||
})
|
||||
|
||||
It("emits all shards of a multi-part GGUF and points Model at shard 1", func() {
|
||||
// Regression for PR #9510: unsloth/Kimi-K2.6-GGUF ships 14
|
||||
// Q8_K_XL shards. Default prefs are q4_k_m; none match, so the
|
||||
// fallback must take the whole shard group (not just the last
|
||||
// shard) and the config's `model:` must point at shard 1 so
|
||||
// llama.cpp's split loader can walk the rest.
|
||||
files := make([]hfapi.ModelFile, 0, 14)
|
||||
// Deliberately add shards out of order to prove sorting works.
|
||||
for _, idx := range []int{7, 1, 14, 2, 8, 3, 9, 4, 10, 5, 11, 6, 12, 13} {
|
||||
files = append(files, hfFile(
|
||||
fmt.Sprintf("Kimi-K2.6-UD-Q8_K_XL-%05d-of-00014.gguf", idx),
|
||||
fmt.Sprintf("sha-%02d", idx),
|
||||
))
|
||||
}
|
||||
|
||||
details := withHF(`{"name":"Kimi-K2.6-GGUF"}`, files...)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(HaveLen(14), fmt.Sprintf("%+v", modelConfig))
|
||||
|
||||
// All 14 shards must be present, in order, under the models dir.
|
||||
for i := 1; i <= 14; i++ {
|
||||
expected := fmt.Sprintf("llama-cpp/models/Kimi-K2.6-GGUF/Kimi-K2.6-UD-Q8_K_XL-%05d-of-00014.gguf", i)
|
||||
Expect(modelConfig.Files[i-1].Filename).To(Equal(expected))
|
||||
Expect(modelConfig.Files[i-1].SHA256).To(Equal(fmt.Sprintf("sha-%02d", i)))
|
||||
}
|
||||
|
||||
// The configured model path must be shard 1 — this is the file
|
||||
// llama.cpp's split loader expects to be pointed at.
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring(
|
||||
"model: llama-cpp/models/Kimi-K2.6-GGUF/Kimi-K2.6-UD-Q8_K_XL-00001-of-00014.gguf",
|
||||
))
|
||||
})
|
||||
|
||||
It("emits all shards of the preferred quant alongside an mmproj", func() {
|
||||
// Sharded multimodal model: mmproj is single-file, the text
|
||||
// model is split in 3 parts and matches the user preference.
|
||||
details := withHF(`{"name":"VL-GGUF","quantizations":"Q4_K_M","mmproj_quantizations":"F16"}`,
|
||||
hfFile("mmproj-F16.gguf", "mm"),
|
||||
hfFile("model-Q4_K_M-00001-of-00003.gguf", "p1"),
|
||||
hfFile("model-Q4_K_M-00002-of-00003.gguf", "p2"),
|
||||
hfFile("model-Q4_K_M-00003-of-00003.gguf", "p3"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(HaveLen(4), fmt.Sprintf("%+v", modelConfig))
|
||||
|
||||
// Model shards come first, in order, then the mmproj.
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/VL-GGUF/model-Q4_K_M-00001-of-00003.gguf"))
|
||||
Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/models/VL-GGUF/model-Q4_K_M-00002-of-00003.gguf"))
|
||||
Expect(modelConfig.Files[2].Filename).To(Equal("llama-cpp/models/VL-GGUF/model-Q4_K_M-00003-of-00003.gguf"))
|
||||
Expect(modelConfig.Files[3].Filename).To(Equal("llama-cpp/mmproj/VL-GGUF/mmproj-F16.gguf"))
|
||||
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring(
|
||||
"model: llama-cpp/models/VL-GGUF/model-Q4_K_M-00001-of-00003.gguf",
|
||||
))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring(
|
||||
"mmproj: llama-cpp/mmproj/VL-GGUF/mmproj-F16.gguf",
|
||||
))
|
||||
})
|
||||
|
||||
It("does not emit duplicate entries when called repeatedly on the same group", func() {
|
||||
// Guards appendShardGroup's dedup: if a shard ends up in the
|
||||
// Files slice via more than one code path (e.g. a future
|
||||
// refactor that processes mmproj and model candidates through
|
||||
// the same path), we must not accidentally duplicate downloads.
|
||||
details := withHF(`{"name":"dup","quantizations":"Q4_K_M,Q4_K_M"}`,
|
||||
hfFile("model-Q4_K_M.gguf", "aaa"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(HaveLen(1), fmt.Sprintf("%+v", modelConfig))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/dup/model-Q4_K_M.gguf"))
|
||||
})
|
||||
|
||||
It("ignores non-gguf files in the repo listing", func() {
|
||||
// Real HF repos ship READMEs, tokenizer json, images, etc.
|
||||
// Only .gguf entries should surface as downloadable files.
|
||||
details := withHF(`{"name":"noise"}`,
|
||||
hfFile("README.md", ""),
|
||||
hfFile("config.json", ""),
|
||||
hfFile("logo.png", ""),
|
||||
hfFile("model-Q4_K_M.gguf", "aaa"),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(HaveLen(1))
|
||||
Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/noise/model-Q4_K_M.gguf"))
|
||||
})
|
||||
|
||||
It("produces no files when the repo contains no .gguf", func() {
|
||||
details := withHF(`{"name":"empty"}`,
|
||||
hfFile("README.md", ""),
|
||||
hfFile("config.json", ""),
|
||||
)
|
||||
|
||||
modelConfig, err := importer.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Files).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("AdditionalBackends", func() {
|
||||
It("advertises ik-llama-cpp and turboquant as drop-in replacements", func() {
|
||||
entries := importer.AdditionalBackends()
|
||||
|
||||
names := make([]string, 0, len(entries))
|
||||
byName := map[string]importers.KnownBackendEntry{}
|
||||
for _, e := range entries {
|
||||
names = append(names, e.Name)
|
||||
byName[e.Name] = e
|
||||
}
|
||||
Expect(names).To(ConsistOf("ik-llama-cpp", "turboquant"))
|
||||
|
||||
ik := byName["ik-llama-cpp"]
|
||||
Expect(ik.Modality).To(Equal("text"))
|
||||
Expect(ik.Description).NotTo(BeEmpty())
|
||||
|
||||
tq := byName["turboquant"]
|
||||
Expect(tq.Modality).To(Equal("text"))
|
||||
Expect(tq.Description).NotTo(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -18,8 +18,11 @@ import (
|
||||
//
|
||||
// Detection order:
|
||||
// 1. GGUF files (*.gguf) — uses llama-cpp backend
|
||||
// 2. LoRA adapter (adapter_config.json) — uses transformers backend with lora_adapter
|
||||
// 3. Merged model (*.safetensors or pytorch_model*.bin + config.json) — uses transformers backend
|
||||
// 2. whisper.cpp ggml-*.bin — uses whisper backend
|
||||
// 3. silero_vad*.onnx — uses silero-vad backend
|
||||
// 4. piper .onnx + .onnx.json pair — uses piper backend
|
||||
// 5. LoRA adapter (adapter_config.json) — uses transformers backend with lora_adapter
|
||||
// 6. Merged model (*.safetensors or pytorch_model*.bin + config.json) — uses transformers backend
|
||||
func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
|
||||
// Make paths relative to the models directory (parent of dirPath)
|
||||
// so config YAML stays portable.
|
||||
@@ -51,7 +54,48 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// 2. LoRA adapter: look for adapter_config.json
|
||||
// 2. whisper.cpp ggml-*.bin models
|
||||
if ggmlFile := findFileByPrefixSuffix(dirPath, "ggml-", ".bin"); ggmlFile != "" {
|
||||
xlog.Info("ImportLocalPath: detected whisper.cpp GGML model", "path", ggmlFile)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "whisper",
|
||||
KnownUsecaseStrings: []string{"transcript"},
|
||||
}
|
||||
cfg.Model = relPath(ggmlFile)
|
||||
cfg.Description = buildDescription(dirPath, "Whisper GGML")
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// 3/4. Single .onnx file in dir — silero-vad or piper depending on signals.
|
||||
if onnxFile := findSingleONNX(dirPath); onnxFile != "" {
|
||||
base := filepath.Base(onnxFile)
|
||||
lowerBase := strings.ToLower(base)
|
||||
switch {
|
||||
case strings.HasPrefix(lowerBase, "silero"):
|
||||
xlog.Info("ImportLocalPath: detected Silero VAD model", "path", onnxFile)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "silero-vad",
|
||||
}
|
||||
cfg.Model = relPath(onnxFile)
|
||||
cfg.Description = buildDescription(dirPath, "Silero VAD")
|
||||
return cfg, nil
|
||||
case fileExists(onnxFile + ".json"):
|
||||
xlog.Info("ImportLocalPath: detected Piper voice", "path", onnxFile)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "piper",
|
||||
}
|
||||
cfg.Model = relPath(onnxFile)
|
||||
cfg.Description = buildDescription(dirPath, "Piper voice")
|
||||
return cfg, nil
|
||||
}
|
||||
// Lone .onnx without piper config and without silero prefix: fall
|
||||
// through — no reliable backend to assign.
|
||||
}
|
||||
|
||||
// 5. LoRA adapter: look for adapter_config.json
|
||||
|
||||
adapterConfigPath := filepath.Join(dirPath, "adapter_config.json")
|
||||
if fileExists(adapterConfigPath) {
|
||||
@@ -116,6 +160,41 @@ func findGGUF(dir string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// findFileByPrefixSuffix returns the path to the first file in dir matching
|
||||
// both prefix (case-sensitive) and suffix (case-insensitive), or "".
|
||||
func findFileByPrefixSuffix(dir, prefix, suffix string) string {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
lowerSuffix := strings.ToLower(suffix)
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if strings.HasPrefix(name, prefix) && strings.HasSuffix(strings.ToLower(name), lowerSuffix) {
|
||||
return filepath.Join(dir, name)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// findSingleONNX returns the path to the first .onnx file found in dir, or "".
|
||||
// Subdirectories are ignored — callers expect a flat layout.
|
||||
func findSingleONNX(dir string) string {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), ".onnx") {
|
||||
return filepath.Join(dir, e.Name())
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// readBaseModel reads the base model name from adapter_config.json or export_metadata.json.
|
||||
func readBaseModel(dirPath string) string {
|
||||
// Try adapter_config.json → base_model_name_or_path (TRL writes this)
|
||||
|
||||
@@ -117,6 +117,47 @@ var _ = Describe("ImportLocalPath", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("Whisper ggml-*.bin detection", func() {
|
||||
It("maps ggml-base.en.bin to the whisper backend", func() {
|
||||
modelDir := filepath.Join(tmpDir, "whisper-base")
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "ggml-base.en.bin"), []byte("fake"), 0644)).To(Succeed())
|
||||
|
||||
cfg, err := importers.ImportLocalPath(modelDir, "whisper-base")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Backend).To(Equal("whisper"))
|
||||
Expect(cfg.Model).To(ContainSubstring("ggml-base.en.bin"))
|
||||
Expect(cfg.KnownUsecaseStrings).To(ContainElement("transcript"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Piper ONNX + ONNX config detection", func() {
|
||||
It("maps the .onnx + .onnx.json pair to the piper backend", func() {
|
||||
modelDir := filepath.Join(tmpDir, "piper-amy")
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "en_US-amy-medium.onnx"), []byte("fake"), 0644)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "en_US-amy-medium.onnx.json"), []byte("{}"), 0644)).To(Succeed())
|
||||
|
||||
cfg, err := importers.ImportLocalPath(modelDir, "piper-amy")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Backend).To(Equal("piper"))
|
||||
Expect(cfg.Model).To(ContainSubstring("en_US-amy-medium.onnx"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Silero VAD detection", func() {
|
||||
It("maps silero_vad.onnx to the silero-vad backend", func() {
|
||||
modelDir := filepath.Join(tmpDir, "silero")
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "silero_vad.onnx"), []byte("fake"), 0644)).To(Succeed())
|
||||
|
||||
cfg, err := importers.ImportLocalPath(modelDir, "silero")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Backend).To(Equal("silero-vad"))
|
||||
Expect(cfg.Model).To(ContainSubstring("silero_vad.onnx"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("fallback", func() {
|
||||
It("returns error for empty directory", func() {
|
||||
modelDir := filepath.Join(tmpDir, "empty")
|
||||
|
||||
@@ -15,6 +15,10 @@ var _ Importer = &MLXImporter{}
|
||||
|
||||
type MLXImporter struct{}
|
||||
|
||||
func (i *MLXImporter) Name() string { return "mlx" }
|
||||
func (i *MLXImporter) Modality() string { return "text" }
|
||||
func (i *MLXImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *MLXImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
|
||||
109
core/gallery/importers/moonshine.go
Normal file
109
core/gallery/importers/moonshine.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &MoonshineImporter{}
|
||||
|
||||
// MoonshineImporter recognises the UsefulSensors Moonshine ASR models, which
|
||||
// ship as ONNX artefacts under HF repositories owned by "UsefulSensors".
|
||||
// Detection combines the owner and a .onnx file presence check so we don't
|
||||
// accidentally match other UsefulSensors projects that might not host ASR
|
||||
// weights. preferences.backend="moonshine" overrides detection.
|
||||
type MoonshineImporter struct{}
|
||||
|
||||
func (i *MoonshineImporter) Name() string { return "moonshine" }
|
||||
func (i *MoonshineImporter) Modality() string { return "asr" }
|
||||
func (i *MoonshineImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *MoonshineImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "moonshine" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace == nil {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(details.HuggingFace.Author, "UsefulSensors") {
|
||||
return false
|
||||
}
|
||||
// Accept either on-disk .onnx (the canonical Moonshine packaging) or an
|
||||
// ASR pipeline_tag on the metadata. The latter covers the transformers/
|
||||
// safetensors-only sibling repos (moonshine-tiny, moonshine-base, …)
|
||||
// that still route to the moonshine runtime.
|
||||
if HasONNX(details.HuggingFace.Files) {
|
||||
return true
|
||||
}
|
||||
return details.HuggingFace.PipelineTag == "automatic-speech-recognition"
|
||||
}
|
||||
|
||||
func (i *MoonshineImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
// Prefer the canonical HF repo path ("owner/repo") so downstream
|
||||
// runtime tooling can resolve the model regardless of how the user
|
||||
// spelled the URI (hf://, https://huggingface.co/, etc.).
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "moonshine",
|
||||
KnownUsecaseStrings: []string{"transcript"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
48
core/gallery/importers/moonshine_test.go
Normal file
48
core/gallery/importers/moonshine_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("MoonshineImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches UsefulSensors/moonshine-tiny (owner + ASR pipeline_tag)", func() {
|
||||
uri := "https://huggingface.co/UsefulSensors/moonshine-tiny"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: moonshine"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("transcript"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
// Model should reference the HF repo path.
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("UsefulSensors/moonshine-tiny"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=moonshine for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "moonshine"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: moonshine"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.MoonshineImporter{}
|
||||
Expect(imp.Name()).To(Equal("moonshine"))
|
||||
Expect(imp.Modality()).To(Equal("asr"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
99
core/gallery/importers/nemo.go
Normal file
99
core/gallery/importers/nemo.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &NemoImporter{}
|
||||
|
||||
// NemoImporter matches NVIDIA NeMo ASR checkpoints, which always ship as
|
||||
// single-file ".nemo" archives under NVIDIA-owned HF repositories. Combining
|
||||
// owner=nvidia with the .nemo extension is narrow enough to avoid picking up
|
||||
// the unrelated NVIDIA LLM repos that only carry safetensors weights.
|
||||
// preferences.backend="nemo" overrides detection.
|
||||
type NemoImporter struct{}
|
||||
|
||||
func (i *NemoImporter) Name() string { return "nemo" }
|
||||
func (i *NemoImporter) Modality() string { return "asr" }
|
||||
func (i *NemoImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *NemoImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "nemo" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace == nil {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(details.HuggingFace.Author, "nvidia") {
|
||||
return false
|
||||
}
|
||||
return HasExtension(details.HuggingFace.Files, ".nemo")
|
||||
}
|
||||
|
||||
func (i *NemoImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "nemo",
|
||||
KnownUsecaseStrings: []string{"transcript"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/nemo_test.go
Normal file
47
core/gallery/importers/nemo_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("NemoImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches nvidia/parakeet-tdt-0.6b-v3 (owner + .nemo file)", func() {
|
||||
uri := "https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: nemo"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("transcript"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("nvidia/parakeet-tdt-0.6b-v3"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=nemo for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "nemo"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: nemo"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.NemoImporter{}
|
||||
Expect(imp.Name()).To(Equal("nemo"))
|
||||
Expect(imp.Modality()).To(Equal("asr"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
110
core/gallery/importers/neutts.go
Normal file
110
core/gallery/importers/neutts.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &NeuTTSImporter{}
|
||||
|
||||
// NeuTTSImporter recognises Neuphonic's NeuTTS releases. Detection uses
|
||||
// "neutts" (case-insensitive) substring in the repo name or the
|
||||
// `neuphonic` owner — covers both the primary "neutts-air" release and
|
||||
// community mirrors. preferences.backend="neutts" overrides detection.
|
||||
type NeuTTSImporter struct{}
|
||||
|
||||
func (i *NeuTTSImporter) Name() string { return "neutts" }
|
||||
func (i *NeuTTSImporter) Modality() string { return "tts" }
|
||||
func (i *NeuTTSImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *NeuTTSImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "neutts" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
if strings.EqualFold(details.HuggingFace.Author, "neuphonic") {
|
||||
return true
|
||||
}
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if strings.Contains(strings.ToLower(repoName), "neutts") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if strings.EqualFold(owner, "neuphonic") || strings.Contains(strings.ToLower(repo), "neutts") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *NeuTTSImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "neutts",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/neutts_test.go
Normal file
47
core/gallery/importers/neutts_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("NeuTTSImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches neuphonic/neutts-air (owner)", func() {
|
||||
uri := "https://huggingface.co/neuphonic/neutts-air"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: neutts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("neuphonic/neutts-air"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=neutts for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "neutts"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: neutts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.NeuTTSImporter{}
|
||||
Expect(imp.Name()).To(Equal("neutts"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
112
core/gallery/importers/outetts.go
Normal file
112
core/gallery/importers/outetts.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &OutettsImporter{}
|
||||
|
||||
// OutettsImporter recognises OuteAI's OuteTTS releases. Detection uses the
|
||||
// `OuteAI` owner or a case-insensitive "OuteTTS" substring in the repo
|
||||
// name so third-party forks (e.g. community finetunes re-hosted outside
|
||||
// the OuteAI org) still route to this backend.
|
||||
// preferences.backend="outetts" overrides detection.
|
||||
type OutettsImporter struct{}
|
||||
|
||||
func (i *OutettsImporter) Name() string { return "outetts" }
|
||||
func (i *OutettsImporter) Modality() string { return "tts" }
|
||||
func (i *OutettsImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *OutettsImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "outetts" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
if strings.EqualFold(details.HuggingFace.Author, "OuteAI") {
|
||||
return true
|
||||
}
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if strings.Contains(strings.ToLower(repoName), "outetts") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// URI fallback (parity with other TTS importers).
|
||||
if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if strings.EqualFold(owner, "OuteAI") || strings.Contains(strings.ToLower(repo), "outetts") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *OutettsImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "outetts",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
47
core/gallery/importers/outetts_test.go
Normal file
47
core/gallery/importers/outetts_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("OutettsImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches OuteAI/OuteTTS-0.3-1B (owner + repo name)", func() {
|
||||
uri := "https://huggingface.co/OuteAI/OuteTTS-0.3-1B"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: outetts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("OuteAI/OuteTTS-0.3-1B"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=outetts for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "outetts"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: outetts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.OutettsImporter{}
|
||||
Expect(imp.Name()).To(Equal("outetts"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
109
core/gallery/importers/piper.go
Normal file
109
core/gallery/importers/piper.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &PiperImporter{}
|
||||
|
||||
// PiperImporter recognises Piper TTS voices. Piper ships each voice as a pair
|
||||
// of "<voice>.onnx" + "<voice>.onnx.json" files (e.g.
|
||||
// en_US-amy-medium.onnx + en_US-amy-medium.onnx.json) — the JSON sidecar is
|
||||
// what disambiguates these from generic ONNX exports used by other backends
|
||||
// (Moonshine, sentence-transformers, etc). preferences.backend="piper"
|
||||
// overrides detection.
|
||||
type PiperImporter struct{}
|
||||
|
||||
func (i *PiperImporter) Name() string { return "piper" }
|
||||
func (i *PiperImporter) Modality() string { return "tts" }
|
||||
func (i *PiperImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *PiperImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "piper" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace == nil {
|
||||
return false
|
||||
}
|
||||
return HasONNXConfigPair(details.HuggingFace.Files)
|
||||
}
|
||||
|
||||
func (i *PiperImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
// Default to the HF repo path so users can resolve the voice at runtime.
|
||||
// If the repo ships onnx pairs, surface the first voice file name so the
|
||||
// config is ready-to-run for single-voice repositories.
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
if details.HuggingFace != nil {
|
||||
for _, f := range details.HuggingFace.Files {
|
||||
base := filepath.Base(f.Path)
|
||||
if strings.HasSuffix(strings.ToLower(base), ".onnx") {
|
||||
model = base
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "piper",
|
||||
KnownUsecaseStrings: []string{"tts"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
50
core/gallery/importers/piper_test.go
Normal file
50
core/gallery/importers/piper_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("PiperImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches a single-voice piper repo by .onnx + .onnx.json pair", func() {
|
||||
// rhasspy/piper-voices is the canonical piper distribution but
|
||||
// its tree is too deep to recurse via the HF API inside a unit
|
||||
// test — per-voice mirrors exercise the same onnx+onnx.json
|
||||
// packaging with a flat directory.
|
||||
uri := "https://huggingface.co/HirCoir/piper-voice-es-mx-lucas-melor"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: piper"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("tts"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=piper for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "piper"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: piper"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.PiperImporter{}
|
||||
Expect(imp.Name()).To(Equal("piper"))
|
||||
Expect(imp.Modality()).To(Equal("tts"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
105
core/gallery/importers/qwen-asr.go
Normal file
105
core/gallery/importers/qwen-asr.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &QwenASRImporter{}
|
||||
|
||||
// QwenASRImporter matches Qwen's dedicated ASR checkpoints, e.g.
|
||||
// Qwen/Qwen3-ASR-1.7B. Detection is scoped to the Qwen owner with an "ASR"
|
||||
// substring in the repo name — narrow enough to avoid other Qwen Audio
|
||||
// variants that run on different backends (Qwen-Audio, Qwen2-Audio, Qwen3-
|
||||
// Omni, Qwen TTS). preferences.backend=qwen-asr forces detection.
|
||||
type QwenASRImporter struct{}
|
||||
|
||||
func (i *QwenASRImporter) Name() string { return "qwen-asr" }
|
||||
func (i *QwenASRImporter) Modality() string { return "asr" }
|
||||
func (i *QwenASRImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *QwenASRImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "qwen-asr" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace == nil {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(details.HuggingFace.Author, "Qwen") {
|
||||
return false
|
||||
}
|
||||
// Extract the repo-name portion so we don't accidentally match when
|
||||
// "asr" only appears as a substring in the owner field.
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
return strings.Contains(strings.ToLower(repoName), "asr")
|
||||
}
|
||||
|
||||
func (i *QwenASRImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "qwen-asr",
|
||||
KnownUsecaseStrings: []string{"transcript"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
46
core/gallery/importers/qwen-asr_test.go
Normal file
46
core/gallery/importers/qwen-asr_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("QwenASRImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches Qwen/Qwen3-ASR-1.7B (owner + ASR in name)", func() {
|
||||
uri := "https://huggingface.co/Qwen/Qwen3-ASR-1.7B"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: qwen-asr"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("Qwen/Qwen3-ASR-1.7B"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=qwen-asr for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "qwen-asr"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: qwen-asr"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.QwenASRImporter{}
|
||||
Expect(imp.Name()).To(Equal("qwen-asr"))
|
||||
Expect(imp.Modality()).To(Equal("asr"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
136
core/gallery/importers/rerankers.go
Normal file
136
core/gallery/importers/rerankers.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &RerankersImporter{}
|
||||
|
||||
// RerankersImporter routes cross-encoder / reranker repositories to the
|
||||
// "rerankers" backend. It must be registered BEFORE SentenceTransformers
|
||||
// and Transformers because reranker repos typically ship tokenizer files
|
||||
// (and sometimes modules.json) that would otherwise be claimed by those
|
||||
// generic importers.
|
||||
//
|
||||
// Detection signals:
|
||||
// - preferences.backend="rerankers" (explicit override);
|
||||
// - HF owner == "cross-encoder" (the canonical sentence-transformers
|
||||
// cross-encoder organisation);
|
||||
// - repo name contains "reranker" (case-insensitive) — catches BAAI
|
||||
// bge-reranker variants, Alibaba-NLP/gte-reranker-*, etc.
|
||||
type RerankersImporter struct{}
|
||||
|
||||
func (i *RerankersImporter) Name() string { return "rerankers" }
|
||||
func (i *RerankersImporter) Modality() string { return "reranker" }
|
||||
func (i *RerankersImporter) AutoDetects() bool { return true }
|
||||
|
||||
func repoLooksLikeReranker(repo string) bool {
|
||||
return strings.Contains(strings.ToLower(repo), "reranker")
|
||||
}
|
||||
|
||||
func (i *RerankersImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "rerankers" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
if strings.EqualFold(details.HuggingFace.Author, "cross-encoder") {
|
||||
return true
|
||||
}
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if repoLooksLikeReranker(repoName) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: hfapi recursion bug may leave HuggingFace nil — decide
|
||||
// from the URI owner/repo.
|
||||
if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if strings.EqualFold(owner, "cross-encoder") {
|
||||
return true
|
||||
}
|
||||
if repoLooksLikeReranker(repo) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *RerankersImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
// Prefer the canonical HF "owner/repo" identifier so emitted YAML
|
||||
// mirrors the gallery rerankers entries.
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
} else if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
model = owner + "/" + repo
|
||||
}
|
||||
|
||||
trueV := true
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "rerankers",
|
||||
KnownUsecaseStrings: []string{"rerank"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
// Reranking is a field of the embedded config.LLMConfig; set it after
|
||||
// the literal so the intent stays obvious.
|
||||
modelConfig.Reranking = &trueV
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
158
core/gallery/importers/rerankers_test.go
Normal file
158
core/gallery/importers/rerankers_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("RerankersImporter", func() {
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.RerankersImporter{}
|
||||
Expect(imp.Name()).To(Equal("rerankers"))
|
||||
Expect(imp.Modality()).To(Equal("reranker"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Match", func() {
|
||||
It("matches when backend preference is rerankers", func() {
|
||||
imp := &importers.RerankersImporter{}
|
||||
preferences := json.RawMessage(`{"backend": "rerankers"}`)
|
||||
details := importers.Details{
|
||||
URI: "https://example.com/some-model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches cross-encoder owner via HuggingFace details", func() {
|
||||
imp := &importers.RerankersImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
Author: "cross-encoder",
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches when the repo name contains 'reranker' (case-insensitive)", func() {
|
||||
imp := &importers.RerankersImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/BAAI/bge-reranker-v2-m3",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "BAAI/bge-reranker-v2-m3",
|
||||
Author: "BAAI",
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches Alibaba-NLP/gte-reranker repos", func() {
|
||||
imp := &importers.RerankersImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/Alibaba-NLP/gte-reranker-modernbert-base",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "Alibaba-NLP/gte-reranker-modernbert-base",
|
||||
Author: "Alibaba-NLP",
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches via URI fallback when HuggingFace details are missing", func() {
|
||||
imp := &importers.RerankersImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/BAAI/bge-reranker-v2-m3",
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does not match unrelated models without reranker signals", func() {
|
||||
imp := &importers.RerankersImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/meta-llama/Llama-3-8B",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "meta-llama/Llama-3-8B",
|
||||
Author: "meta-llama",
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("returns false for invalid preferences JSON", func() {
|
||||
imp := &importers.RerankersImporter{}
|
||||
preferences := json.RawMessage(`not valid json`)
|
||||
details := importers.Details{
|
||||
URI: "https://example.com/model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Import", func() {
|
||||
It("produces a YAML with backend rerankers, reranking true, and the repo as the model", func() {
|
||||
imp := &importers.RerankersImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/BAAI/bge-reranker-v2-m3",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "BAAI/bge-reranker-v2-m3",
|
||||
Author: "BAAI",
|
||||
},
|
||||
}
|
||||
|
||||
modelConfig, err := imp.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: rerankers"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("BAAI/bge-reranker-v2-m3"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("reranking: true"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
|
||||
It("respects custom name and description from preferences", func() {
|
||||
imp := &importers.RerankersImporter{}
|
||||
preferences := json.RawMessage(`{"name": "my-reranker", "description": "Custom"}`)
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/BAAI/bge-reranker-v2-m3",
|
||||
Preferences: preferences,
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "BAAI/bge-reranker-v2-m3",
|
||||
Author: "BAAI",
|
||||
},
|
||||
}
|
||||
|
||||
modelConfig, err := imp.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("my-reranker"))
|
||||
Expect(modelConfig.Description).To(Equal("Custom"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("registration order vs TransformersImporter", func() {
|
||||
It("routes BAAI/bge-reranker HF URIs to rerankers rather than transformers", func() {
|
||||
uri := "https://huggingface.co/BAAI/bge-reranker-v2-m3"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: rerankers"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
})
|
||||
122
core/gallery/importers/rfdetr.go
Normal file
122
core/gallery/importers/rfdetr.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &RFDetrImporter{}
|
||||
|
||||
// RFDetrImporter routes RF-DETR object-detection repositories to the
|
||||
// "rfdetr" backend. It must be registered BEFORE TransformersImporter
|
||||
// because RF-DETR checkpoints often ship tokenizer-adjacent artefacts.
|
||||
//
|
||||
// Detection signals:
|
||||
// - preferences.backend="rfdetr" (explicit override);
|
||||
// - repo name contains "rf-detr" or "rfdetr" (case-insensitive).
|
||||
type RFDetrImporter struct{}
|
||||
|
||||
func (i *RFDetrImporter) Name() string { return "rfdetr" }
|
||||
func (i *RFDetrImporter) Modality() string { return "detection" }
|
||||
func (i *RFDetrImporter) AutoDetects() bool { return true }
|
||||
|
||||
func repoLooksLikeRFDetr(repo string) bool {
|
||||
lower := strings.ToLower(repo)
|
||||
return strings.Contains(lower, "rf-detr") || strings.Contains(lower, "rfdetr")
|
||||
}
|
||||
|
||||
func (i *RFDetrImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "rfdetr" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if repoLooksLikeRFDetr(repoName) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: hfapi recursion bug may leave HuggingFace nil — decide
|
||||
// from the URI owner/repo.
|
||||
if _, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if repoLooksLikeRFDetr(repo) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *RFDetrImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
// Prefer the canonical HF "owner/repo" identifier so the emitted
|
||||
// YAML mirrors gallery rfdetr entries.
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
} else if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
model = owner + "/" + repo
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "rfdetr",
|
||||
KnownUsecaseStrings: []string{"detection"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
132
core/gallery/importers/rfdetr_test.go
Normal file
132
core/gallery/importers/rfdetr_test.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("RFDetrImporter", func() {
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.RFDetrImporter{}
|
||||
Expect(imp.Name()).To(Equal("rfdetr"))
|
||||
Expect(imp.Modality()).To(Equal("detection"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Match", func() {
|
||||
It("matches when backend preference is rfdetr", func() {
|
||||
imp := &importers.RFDetrImporter{}
|
||||
preferences := json.RawMessage(`{"backend": "rfdetr"}`)
|
||||
details := importers.Details{
|
||||
URI: "https://example.com/some-model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches when the repo name contains 'rf-detr' (case-insensitive)", func() {
|
||||
imp := &importers.RFDetrImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/roboflow/rf-detr-base",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "roboflow/rf-detr-base",
|
||||
Author: "roboflow",
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches when the repo name contains 'rfdetr' (case-insensitive)", func() {
|
||||
imp := &importers.RFDetrImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/some/rfdetr-whatever",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "some/rfdetr-whatever",
|
||||
Author: "some",
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches via URI fallback when HuggingFace details are missing", func() {
|
||||
imp := &importers.RFDetrImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/roboflow/rf-detr-base",
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does not match unrelated repos without rfdetr signals", func() {
|
||||
imp := &importers.RFDetrImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/meta-llama/Llama-3-8B",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "meta-llama/Llama-3-8B",
|
||||
Author: "meta-llama",
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("returns false for invalid preferences JSON", func() {
|
||||
imp := &importers.RFDetrImporter{}
|
||||
preferences := json.RawMessage(`not valid json`)
|
||||
details := importers.Details{
|
||||
URI: "https://example.com/model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Import", func() {
|
||||
It("produces a YAML with backend rfdetr and the repo as the model", func() {
|
||||
imp := &importers.RFDetrImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/roboflow/rf-detr-base",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "roboflow/rf-detr-base",
|
||||
Author: "roboflow",
|
||||
},
|
||||
}
|
||||
|
||||
modelConfig, err := imp.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: rfdetr"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("roboflow/rf-detr-base"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
|
||||
It("respects custom name and description from preferences", func() {
|
||||
imp := &importers.RFDetrImporter{}
|
||||
preferences := json.RawMessage(`{"name": "my-detr", "description": "Custom"}`)
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/roboflow/rf-detr-base",
|
||||
Preferences: preferences,
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "roboflow/rf-detr-base",
|
||||
Author: "roboflow",
|
||||
},
|
||||
}
|
||||
|
||||
modelConfig, err := imp.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("my-detr"))
|
||||
Expect(modelConfig.Description).To(Equal("Custom"))
|
||||
})
|
||||
})
|
||||
})
|
||||
124
core/gallery/importers/sentencetransformers.go
Normal file
124
core/gallery/importers/sentencetransformers.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &SentenceTransformersImporter{}
|
||||
|
||||
// SentenceTransformersImporter routes sentence-transformers embedding
|
||||
// repositories to the "sentencetransformers" backend. It MUST be
|
||||
// registered BEFORE TransformersImporter — ST repos ship tokenizer.json
|
||||
// which would otherwise be claimed by the transformers importer.
|
||||
//
|
||||
// Detection signals:
|
||||
// - preferences.backend="sentencetransformers" (explicit override);
|
||||
// - repo ships "modules.json" (the ST pipeline manifest);
|
||||
// - repo ships "sentence_bert_config.json" (legacy ST marker);
|
||||
// - HF owner == "sentence-transformers".
|
||||
type SentenceTransformersImporter struct{}
|
||||
|
||||
func (i *SentenceTransformersImporter) Name() string { return "sentencetransformers" }
|
||||
func (i *SentenceTransformersImporter) Modality() string { return "embeddings" }
|
||||
func (i *SentenceTransformersImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *SentenceTransformersImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "sentencetransformers" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
if HasFile(details.HuggingFace.Files, "modules.json") {
|
||||
return true
|
||||
}
|
||||
if HasFile(details.HuggingFace.Files, "sentence_bert_config.json") {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(details.HuggingFace.Author, "sentence-transformers") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: hfapi recursion bug may leave HuggingFace nil — decide
|
||||
// from the URI owner.
|
||||
if owner, _, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if strings.EqualFold(owner, "sentence-transformers") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *SentenceTransformersImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
// Prefer the canonical HF "owner/repo" identifier so the emitted YAML
|
||||
// mirrors the gallery sentencetransformers entries.
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
} else if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
model = owner + "/" + repo
|
||||
}
|
||||
|
||||
trueV := true
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "sentencetransformers",
|
||||
KnownUsecaseStrings: []string{"embeddings"},
|
||||
Embeddings: &trueV,
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
return gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
ConfigFile: string(data),
|
||||
}, nil
|
||||
}
|
||||
171
core/gallery/importers/sentencetransformers_test.go
Normal file
171
core/gallery/importers/sentencetransformers_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("SentenceTransformersImporter", func() {
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.SentenceTransformersImporter{}
|
||||
Expect(imp.Name()).To(Equal("sentencetransformers"))
|
||||
Expect(imp.Modality()).To(Equal("embeddings"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Match", func() {
|
||||
It("matches when backend preference is sentencetransformers", func() {
|
||||
imp := &importers.SentenceTransformersImporter{}
|
||||
preferences := json.RawMessage(`{"backend": "sentencetransformers"}`)
|
||||
details := importers.Details{
|
||||
URI: "https://example.com/some-model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches when HF repo ships modules.json", func() {
|
||||
imp := &importers.SentenceTransformersImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "sentence-transformers/all-MiniLM-L6-v2",
|
||||
Author: "sentence-transformers",
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "modules.json"},
|
||||
{Path: "tokenizer.json"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches when HF repo ships sentence_bert_config.json", func() {
|
||||
imp := &importers.SentenceTransformersImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/some/st-model",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "some/st-model",
|
||||
Author: "some",
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "sentence_bert_config.json"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches sentence-transformers owner even without marker files", func() {
|
||||
imp := &importers.SentenceTransformersImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/sentence-transformers/foo",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "sentence-transformers/foo",
|
||||
Author: "sentence-transformers",
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches via URI fallback when HuggingFace details are missing", func() {
|
||||
imp := &importers.SentenceTransformersImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does not match unrelated plain transformers models", func() {
|
||||
imp := &importers.SentenceTransformersImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/meta-llama/Llama-3-8B",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "meta-llama/Llama-3-8B",
|
||||
Author: "meta-llama",
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "tokenizer.json"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("returns false for invalid preferences JSON", func() {
|
||||
imp := &importers.SentenceTransformersImporter{}
|
||||
preferences := json.RawMessage(`not valid json`)
|
||||
details := importers.Details{
|
||||
URI: "https://example.com/model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Import", func() {
|
||||
It("produces a YAML with backend sentencetransformers, embeddings true, and the repo as the model", func() {
|
||||
imp := &importers.SentenceTransformersImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "sentence-transformers/all-MiniLM-L6-v2",
|
||||
Author: "sentence-transformers",
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "modules.json"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modelConfig, err := imp.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: sentencetransformers"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("sentence-transformers/all-MiniLM-L6-v2"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("embeddings: true"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
|
||||
It("respects custom name and description from preferences", func() {
|
||||
imp := &importers.SentenceTransformersImporter{}
|
||||
preferences := json.RawMessage(`{"name": "my-embed", "description": "Custom"}`)
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
|
||||
Preferences: preferences,
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "sentence-transformers/all-MiniLM-L6-v2",
|
||||
Author: "sentence-transformers",
|
||||
},
|
||||
}
|
||||
|
||||
modelConfig, err := imp.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("my-embed"))
|
||||
Expect(modelConfig.Description).To(Equal("Custom"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("registration order vs TransformersImporter", func() {
|
||||
It("routes sentence-transformers HF URIs to sentencetransformers rather than transformers", func() {
|
||||
uri := "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: sentencetransformers"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
})
|
||||
130
core/gallery/importers/silero-vad.go
Normal file
130
core/gallery/importers/silero-vad.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &SileroVADImporter{}
|
||||
|
||||
// SileroVADImporter recognises the Silero Voice Activity Detection models
|
||||
// distributed as ONNX weights. The canonical packaging ships a file named
|
||||
// exactly "silero_vad.onnx" (see snakers4/silero-vad); we additionally
|
||||
// accept any ONNX file under the "snakers4" owner so community-mirrored
|
||||
// copies still route here. preferences.backend="silero-vad" overrides
|
||||
// detection.
|
||||
type SileroVADImporter struct{}
|
||||
|
||||
func (i *SileroVADImporter) Name() string { return "silero-vad" }
|
||||
func (i *SileroVADImporter) Modality() string { return "vad" }
|
||||
func (i *SileroVADImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *SileroVADImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "silero-vad" {
|
||||
return true
|
||||
}
|
||||
|
||||
if details.HuggingFace != nil {
|
||||
if HasFile(details.HuggingFace.Files, "silero_vad.onnx") {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(details.HuggingFace.Author, "snakers4") && HasONNX(details.HuggingFace.Files) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: hfapi recursion bug may leave HuggingFace nil — decide
|
||||
// from the URI owner/repo. The snakers4 organisation ships only
|
||||
// silero-* projects, so URI-level ownership is a safe signal.
|
||||
if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if strings.EqualFold(owner, "snakers4") && strings.Contains(strings.ToLower(repo), "silero") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *SileroVADImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
cfg := gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
}
|
||||
|
||||
// Prefer the canonical silero_vad.onnx filename when available so the
|
||||
// emitted YAML points at the actual weights. Fall back to the HF repo
|
||||
// path otherwise — users can adjust after import.
|
||||
model := details.URI
|
||||
if details.HuggingFace != nil {
|
||||
for _, f := range details.HuggingFace.Files {
|
||||
if filepath.Base(f.Path) == "silero_vad.onnx" {
|
||||
cfg.Files = append(cfg.Files, gallery.File{
|
||||
URI: f.URL,
|
||||
Filename: "silero_vad.onnx",
|
||||
SHA256: f.SHA256,
|
||||
})
|
||||
model = "silero_vad.onnx"
|
||||
break
|
||||
}
|
||||
}
|
||||
if model == details.URI && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
}
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "silero-vad",
|
||||
KnownUsecaseStrings: []string{"vad"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
cfg.ConfigFile = string(data)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
159
core/gallery/importers/silero-vad_test.go
Normal file
159
core/gallery/importers/silero-vad_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("SileroVADImporter", func() {
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.SileroVADImporter{}
|
||||
Expect(imp.Name()).To(Equal("silero-vad"))
|
||||
Expect(imp.Modality()).To(Equal("vad"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Match", func() {
|
||||
It("matches when backend preference is silero-vad", func() {
|
||||
imp := &importers.SileroVADImporter{}
|
||||
preferences := json.RawMessage(`{"backend": "silero-vad"}`)
|
||||
details := importers.Details{
|
||||
URI: "https://example.com/some-model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches when HF repo ships silero_vad.onnx", func() {
|
||||
imp := &importers.SileroVADImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/snakers4/silero-vad",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "snakers4/silero-vad",
|
||||
Author: "snakers4",
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "silero_vad.onnx"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches snakers4 owner with ONNX files even without the canonical filename", func() {
|
||||
imp := &importers.SileroVADImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/snakers4/silero-vad",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "snakers4/silero-vad",
|
||||
Author: "snakers4",
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "some-other.onnx"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches via URI fallback when HuggingFace details are missing", func() {
|
||||
imp := &importers.SileroVADImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/snakers4/silero-vad",
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does not match unrelated repos without silero_vad.onnx or snakers4 owner", func() {
|
||||
imp := &importers.SileroVADImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/someone/random-model",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "someone/random-model",
|
||||
Author: "someone",
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "config.json"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("returns false for invalid preferences JSON", func() {
|
||||
imp := &importers.SileroVADImporter{}
|
||||
preferences := json.RawMessage(`not valid json`)
|
||||
details := importers.Details{
|
||||
URI: "https://example.com/model",
|
||||
Preferences: preferences,
|
||||
}
|
||||
|
||||
Expect(imp.Match(details)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Import", func() {
|
||||
It("produces a YAML with backend silero-vad and the vad known_usecase", func() {
|
||||
imp := &importers.SileroVADImporter{}
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/snakers4/silero-vad",
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "snakers4/silero-vad",
|
||||
Author: "snakers4",
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "silero_vad.onnx", URL: "https://huggingface.co/snakers4/silero-vad/resolve/main/silero_vad.onnx"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modelConfig, err := imp.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: silero-vad"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("- vad"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
|
||||
It("respects custom name and description from preferences", func() {
|
||||
imp := &importers.SileroVADImporter{}
|
||||
preferences := json.RawMessage(`{"name": "my-vad", "description": "Custom"}`)
|
||||
details := importers.Details{
|
||||
URI: "https://huggingface.co/snakers4/silero-vad",
|
||||
Preferences: preferences,
|
||||
HuggingFace: &hfapi.ModelDetails{
|
||||
ModelID: "snakers4/silero-vad",
|
||||
Author: "snakers4",
|
||||
Files: []hfapi.ModelFile{
|
||||
{Path: "silero_vad.onnx"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modelConfig, err := imp.Import(details)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(modelConfig.Name).To(Equal("my-vad"))
|
||||
Expect(modelConfig.Description).To(Equal("Custom"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("detection from HuggingFace", func() {
|
||||
It("matches snakers4/silero-vad via live HF metadata", func() {
|
||||
uri := "https://huggingface.co/snakers4/silero-vad"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: silero-vad"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
})
|
||||
220
core/gallery/importers/stablediffusion-ggml.go
Normal file
220
core/gallery/importers/stablediffusion-ggml.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &StableDiffusionGGMLImporter{}
|
||||
|
||||
// sdGGMLArchTokens enumerates the filename/repo substrings that reliably
|
||||
// indicate a Stable Diffusion / FLUX GGUF checkpoint. Matching is
|
||||
// case-insensitive so leejet's "stable-diffusion", city96's "FLUX.1-dev-gguf"
|
||||
// and assorted SDXL/SD3 mirrors all land here instead of being stolen by
|
||||
// llama-cpp (which otherwise claims every .gguf). This is a heuristic —
|
||||
// GGUF metadata would be authoritative but requires a parser we do not
|
||||
// ship in this package.
|
||||
var sdGGMLArchTokens = []string{
|
||||
"flux",
|
||||
"sd1.5",
|
||||
"sdxl",
|
||||
"sd3",
|
||||
"stable-diffusion",
|
||||
"stable_diffusion",
|
||||
}
|
||||
|
||||
// StableDiffusionGGMLImporter recognises GGUF-packaged Stable Diffusion /
|
||||
// FLUX checkpoints (leejet/stable-diffusion.cpp outputs, city96's FLUX GGUF
|
||||
// mirrors, second-state's SD 3.5 dumps, etc). It must be registered BEFORE
|
||||
// LlamaCPPImporter so llama-cpp does not steal the .gguf match.
|
||||
// preferences.backend="stablediffusion-ggml" overrides detection.
|
||||
type StableDiffusionGGMLImporter struct{}
|
||||
|
||||
func (i *StableDiffusionGGMLImporter) Name() string { return "stablediffusion-ggml" }
|
||||
func (i *StableDiffusionGGMLImporter) Modality() string { return "image" }
|
||||
func (i *StableDiffusionGGMLImporter) AutoDetects() bool { return true }
|
||||
|
||||
// containsArchToken reports whether s (compared case-insensitively) includes
|
||||
// any of the known SD/FLUX arch markers.
|
||||
func containsArchToken(s string) bool {
|
||||
lower := strings.ToLower(s)
|
||||
for _, tok := range sdGGMLArchTokens {
|
||||
if strings.Contains(lower, tok) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *StableDiffusionGGMLImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "stablediffusion-ggml" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Raw .gguf URI with an arch token in the filename/URI.
|
||||
if strings.HasSuffix(strings.ToLower(details.URI), ".gguf") && containsArchToken(details.URI) {
|
||||
return true
|
||||
}
|
||||
|
||||
// HF repo (when the API succeeded) with at least one .gguf file and
|
||||
// either a leejet owner or an arch token in the repo name.
|
||||
if details.HuggingFace != nil {
|
||||
if hasGGUF(details.HuggingFace.Files) {
|
||||
if strings.EqualFold(details.HuggingFace.Author, "leejet") {
|
||||
return true
|
||||
}
|
||||
repoName := details.HuggingFace.ModelID
|
||||
if idx := strings.Index(repoName, "/"); idx >= 0 {
|
||||
repoName = repoName[idx+1:]
|
||||
}
|
||||
if containsArchToken(repoName) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: HF details are nil because of the known hfapi tree-listing
|
||||
// bug on repos with nested paths — decide from the URI owner/repo alone.
|
||||
if owner, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
if strings.EqualFold(owner, "leejet") || containsArchToken(repo) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *StableDiffusionGGMLImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
// Default: raw .gguf URL — basename is the model name.
|
||||
model := filepath.Base(details.URI)
|
||||
cfg := gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(strings.ToLower(details.URI), ".gguf"):
|
||||
// Raw .gguf URI: mirror llama-cpp's flat layout.
|
||||
cfg.Files = append(cfg.Files, gallery.File{
|
||||
URI: details.URI,
|
||||
Filename: filepath.Base(details.URI),
|
||||
})
|
||||
model = filepath.Base(details.URI)
|
||||
case details.HuggingFace != nil && hasGGUF(details.HuggingFace.Files):
|
||||
chosen := pickSDGGUF(details.HuggingFace.Files)
|
||||
if chosen != nil {
|
||||
cfg.Files = append(cfg.Files, gallery.File{
|
||||
URI: chosen.URL,
|
||||
Filename: filepath.Base(chosen.Path),
|
||||
SHA256: chosen.SHA256,
|
||||
})
|
||||
model = filepath.Base(chosen.Path)
|
||||
}
|
||||
default:
|
||||
// Pure preference-driven import with a bare URI — best-effort model
|
||||
// name; the operator is expected to top up parameters post-import.
|
||||
if details.HuggingFace != nil && details.HuggingFace.ModelID != "" {
|
||||
model = details.HuggingFace.ModelID
|
||||
} else {
|
||||
model = details.URI
|
||||
}
|
||||
}
|
||||
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "stablediffusion-ggml",
|
||||
KnownUsecaseStrings: []string{"FLAG_IMAGE"},
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: model},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
|
||||
cfg.ConfigFile = string(data)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// hasGGUF reports whether files contains at least one .gguf entry.
|
||||
func hasGGUF(files []hfapi.ModelFile) bool {
|
||||
for _, f := range files {
|
||||
if strings.HasSuffix(strings.ToLower(f.Path), ".gguf") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// pickSDGGUF selects the best .gguf file for a SD/FLUX repo. Preference
|
||||
// order: Q4_K, then Q8_0, then the first .gguf in the tree. Quantisation
|
||||
// naming follows leejet/stable-diffusion.cpp and city96's FLUX mirrors.
|
||||
func pickSDGGUF(files []hfapi.ModelFile) *hfapi.ModelFile {
|
||||
var q4k, q8, first *hfapi.ModelFile
|
||||
for idx := range files {
|
||||
f := &files[idx]
|
||||
if !strings.HasSuffix(strings.ToLower(f.Path), ".gguf") {
|
||||
continue
|
||||
}
|
||||
if first == nil {
|
||||
first = f
|
||||
}
|
||||
lower := strings.ToLower(filepath.Base(f.Path))
|
||||
if q4k == nil && strings.Contains(lower, "q4_k") {
|
||||
q4k = f
|
||||
}
|
||||
if q8 == nil && strings.Contains(lower, "q8_0") {
|
||||
q8 = f
|
||||
}
|
||||
}
|
||||
switch {
|
||||
case q4k != nil:
|
||||
return q4k
|
||||
case q8 != nil:
|
||||
return q8
|
||||
default:
|
||||
return first
|
||||
}
|
||||
}
|
||||
60
core/gallery/importers/stablediffusion-ggml_test.go
Normal file
60
core/gallery/importers/stablediffusion-ggml_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("StableDiffusionGGMLImporter", func() {
|
||||
Context("detection from HuggingFace", func() {
|
||||
// city96/FLUX.1-dev-gguf is the canonical community GGUF mirror for
|
||||
// FLUX.1-dev and ships a flat tree of .gguf quantisations
|
||||
// (flux1-dev-Q4_K.gguf, flux1-dev-Q8_0.gguf, etc.). Detection must
|
||||
// route this to stablediffusion-ggml (and NOT to llama-cpp, which
|
||||
// otherwise steals every .gguf repo).
|
||||
It("matches a HF repo with GGUF files whose owner/repo contains flux/sd/sdxl tokens", func() {
|
||||
uri := "https://huggingface.co/city96/FLUX.1-dev-gguf"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: stablediffusion-ggml"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
|
||||
It("matches a raw .gguf URL containing flux/sd arch tokens", func() {
|
||||
uri := "https://example.com/models/flux1-dev-Q4_K.gguf"
|
||||
preferences := json.RawMessage(`{}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: stablediffusion-ggml"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("preference override", func() {
|
||||
It("honours preferences.backend=stablediffusion-ggml for arbitrary URIs", func() {
|
||||
uri := "https://example.com/some-unrelated-model"
|
||||
preferences := json.RawMessage(`{"backend": "stablediffusion-ggml"}`)
|
||||
|
||||
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
|
||||
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: stablediffusion-ggml"), fmt.Sprintf("Model config: %+v", modelConfig))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
imp := &importers.StableDiffusionGGMLImporter{}
|
||||
Expect(imp.Name()).To(Equal("stablediffusion-ggml"))
|
||||
Expect(imp.Modality()).To(Equal("image"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user