mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-06 15:56:06 -04:00
Compare commits
213 Commits
v4.2.6
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b15627c864 | ||
|
|
352b7ec604 | ||
|
|
ba706422fb | ||
|
|
e837921c2c | ||
|
|
73385713ca | ||
|
|
a4e671779a | ||
|
|
7051b2e0a1 | ||
|
|
469737101a | ||
|
|
858257eaf0 | ||
|
|
ef80a0e825 | ||
|
|
92726f7631 | ||
|
|
994063ba9a | ||
|
|
c1a55cf72d | ||
|
|
96758841d8 | ||
|
|
7a59260621 | ||
|
|
27e63b9a78 | ||
|
|
55c0911c23 | ||
|
|
f6cb6ab6d9 | ||
|
|
9f11b09c6a | ||
|
|
a5c4f822f0 | ||
|
|
fb36c262fe | ||
|
|
0e4e8980e6 | ||
|
|
3a932a9803 | ||
|
|
9d10418593 | ||
|
|
5470051d4d | ||
|
|
68c5eeebc3 | ||
|
|
1531fabe23 | ||
|
|
b7673d5b76 | ||
|
|
b64bdaf406 | ||
|
|
eebf08ff1d | ||
|
|
42e51894c3 | ||
|
|
d9ae6481fb | ||
|
|
f1c495a748 | ||
|
|
415b561947 | ||
|
|
e6a0d4c375 | ||
|
|
7e59a5c7c5 | ||
|
|
aea954a482 | ||
|
|
595e448714 | ||
|
|
860f9d63ad | ||
|
|
a5a0b3dc4e | ||
|
|
94eca04c60 | ||
|
|
35bd485d6a | ||
|
|
1fe96f8d9a | ||
|
|
c508e9d7c6 | ||
|
|
55e754fd05 | ||
|
|
a17753f7d1 | ||
|
|
c61838dba6 | ||
|
|
7013e13f05 | ||
|
|
5a0013defe | ||
|
|
c01ed631d6 | ||
|
|
d47464cb06 | ||
|
|
63f176346e | ||
|
|
af94d08729 | ||
|
|
6795d38f50 | ||
|
|
718223f33b | ||
|
|
39e050d9e2 | ||
|
|
c222161291 | ||
|
|
aa80d4681b | ||
|
|
0d57957ebb | ||
|
|
76fe0bb929 | ||
|
|
baa11133f1 | ||
|
|
1bdd3338a6 | ||
|
|
e08492a2c3 | ||
|
|
d5d8fe909d | ||
|
|
8a82753277 | ||
|
|
51ca109067 | ||
|
|
07f6c15a37 | ||
|
|
a44bdb29d4 | ||
|
|
aee4611ab2 | ||
|
|
486467623c | ||
|
|
4912c9b73a | ||
|
|
12d1f3a697 | ||
|
|
a7cad704b9 | ||
|
|
7e4df67556 | ||
|
|
5b24b4dacc | ||
|
|
52fdb46892 | ||
|
|
b389f0fe5f | ||
|
|
74281be340 | ||
|
|
cacf2f7a2c | ||
|
|
4a2cc64d07 | ||
|
|
4647770316 | ||
|
|
3c9b9529c0 | ||
|
|
fc2bd0986c | ||
|
|
a473a32678 | ||
|
|
3e220373b0 | ||
|
|
fbcd886a47 | ||
|
|
e1a782b70f | ||
|
|
73cfedc023 | ||
|
|
b982c977d5 | ||
|
|
532ca1b3a2 | ||
|
|
00ad55b590 | ||
|
|
4c58fd302f | ||
|
|
66582e7035 | ||
|
|
1d13949588 | ||
|
|
c8ad67bbca | ||
|
|
1c92b00918 | ||
|
|
b81a6d01b3 | ||
|
|
0fd666ee6e | ||
|
|
7763fb23a3 | ||
|
|
324277ccfd | ||
|
|
10d02e6c59 | ||
|
|
05ae06c17b | ||
|
|
2671e0c6f7 | ||
|
|
81b6b94f0b | ||
|
|
373dc44992 | ||
|
|
02a0e70396 | ||
|
|
7a4ca8f60d | ||
|
|
893e69cbf8 | ||
|
|
c9a1a7e6a0 | ||
|
|
4d01298048 | ||
|
|
b6055e7ecf | ||
|
|
51bad74bf8 | ||
|
|
f3236b74cf | ||
|
|
eed3ecff82 | ||
|
|
df7623fd87 | ||
|
|
4e5ec6f67b | ||
|
|
8d70855ea6 | ||
|
|
90c29f9258 | ||
|
|
66aaa525e5 | ||
|
|
a29a06f6b6 | ||
|
|
80893a298b | ||
|
|
9d90e04418 | ||
|
|
11f43ad8b0 | ||
|
|
437f0fa193 | ||
|
|
aa743f8824 | ||
|
|
2162611dca | ||
|
|
4aad97971c | ||
|
|
e4c70fca7a | ||
|
|
4b398c9798 | ||
|
|
4a5219fa9c | ||
|
|
b5a620294e | ||
|
|
5d544a7868 | ||
|
|
87e01aa290 | ||
|
|
f17d99f6e5 | ||
|
|
597daa925b | ||
|
|
3e0612b8b4 | ||
|
|
de2ce74bea | ||
|
|
1c6c3adad6 | ||
|
|
c2cd3b9ada | ||
|
|
9ff270eb65 | ||
|
|
8d6548c0b9 | ||
|
|
b02e3ffe61 | ||
|
|
a891eedd08 | ||
|
|
06e777b75e | ||
|
|
90ea327178 | ||
|
|
6a80e23733 | ||
|
|
1dcd1ae915 | ||
|
|
acad78a95a | ||
|
|
c94d1e1f5b | ||
|
|
270c256409 | ||
|
|
1a30020a82 | ||
|
|
8bbe89a537 | ||
|
|
dcc5599f89 | ||
|
|
a95f4e63e0 | ||
|
|
dfd19a3f88 | ||
|
|
d7387c725c | ||
|
|
63d84a5705 | ||
|
|
1198d10b58 | ||
|
|
a0f3e26245 | ||
|
|
e4cc1f11f3 | ||
|
|
6ed269d0b9 | ||
|
|
5756fb046d | ||
|
|
7980629bc5 | ||
|
|
d0a59be9de | ||
|
|
5cda4f1ccf | ||
|
|
c500461c69 | ||
|
|
834ecc36bf | ||
|
|
61bf34ea2f | ||
|
|
0b2ae3c6ca | ||
|
|
4735345105 | ||
|
|
7384fd800b | ||
|
|
6942713d85 | ||
|
|
0cf52c44d4 | ||
|
|
0d34cf7cbd | ||
|
|
f0cb02afb8 | ||
|
|
a39e025d64 | ||
|
|
05e8e1e9f4 | ||
|
|
a7f6cc8956 | ||
|
|
f15b9178ec | ||
|
|
959de86761 | ||
|
|
4c234abc2c | ||
|
|
c68818a62e | ||
|
|
11d5bd0cc3 | ||
|
|
12e056e96d | ||
|
|
308aa8908a | ||
|
|
b2d68a53a2 | ||
|
|
e3706c0512 | ||
|
|
1ffd82a050 | ||
|
|
f515168dbe | ||
|
|
ef6ca34513 | ||
|
|
9413c3767f | ||
|
|
3bf3cce232 | ||
|
|
06f8159035 | ||
|
|
f6a73f54fa | ||
|
|
24e04d8e81 | ||
|
|
b9a49449ae | ||
|
|
1879e11042 | ||
|
|
403d391316 | ||
|
|
fc3980dadd | ||
|
|
2009544b44 | ||
|
|
e859345b12 | ||
|
|
f30712f8e8 | ||
|
|
a19c77c5f8 | ||
|
|
4b02d23c0c | ||
|
|
21140e96b2 | ||
|
|
fc803e8d48 | ||
|
|
ca51606bfe | ||
|
|
cb502de309 | ||
|
|
5d0b549049 | ||
|
|
11cff1b309 | ||
|
|
4ca3d2cdc0 | ||
|
|
3cba35ed32 | ||
|
|
265ae35231 |
@@ -112,6 +112,8 @@ Add a YAML anchor definition in the `## metas` section (around line 2-300). Look
|
||||
|
||||
Add image entries at the end of the file, following the pattern of similar backends such as `diffusers` or `chatterbox`. Include both `latest` (production) and `master` (development) tags.
|
||||
|
||||
**Note on integrity:** OCI backends installed from a gallery whose `verification:` block is set are verified against a keyless-cosign policy before extraction; tarball/HTTP backends use the optional `sha256:` field. New backends do not need any extra YAML — the gallery-level `verification:` block covers every entry. See [.agents/backend-signing.md](backend-signing.md) for the producer-side CI step.
|
||||
|
||||
## 4. Update the Makefile
|
||||
|
||||
The Makefile needs to be updated in several places to support building and testing the new backend:
|
||||
|
||||
126
.agents/backend-signing.md
Normal file
126
.agents/backend-signing.md
Normal file
@@ -0,0 +1,126 @@
|
||||
# Backend image signing & verification
|
||||
|
||||
LocalAI verifies backend OCI images against a per-gallery keyless-cosign
|
||||
policy. This page documents the trust model, the producer side
|
||||
(`.github/workflows/backend_merge.yml` in this repo), and the consumer
|
||||
side (`pkg/oci/cosignverify` plus the gallery YAML).
|
||||
|
||||
## Trust model
|
||||
|
||||
- **Producer:** `.github/workflows/backend_merge.yml` signs each pushed
|
||||
manifest list with `cosign sign --recursive` in keyless mode after
|
||||
`docker buildx imagetools create`. The signing cert is issued by
|
||||
Fulcio bound to the workflow's OIDC identity. There is no long-lived
|
||||
signing key. `--recursive` signs both the manifest list and every
|
||||
per-arch entry — needed because our consumer resolves a tag to a
|
||||
per-arch manifest before checking signatures.
|
||||
- **Storage:** Signatures are written as OCI 1.1 referrers
|
||||
(`--registry-referrers-mode=oci-1-1`) in the new Sigstore bundle format
|
||||
(current cosign releases do this by default; no `--new-bundle-format`
|
||||
flag). No `:sha256-<hex>.sig` tag clutter.
|
||||
- **Consumer:** `pkg/oci/cosignverify` discovers the bundle via the
|
||||
referrers API, hands it to `sigstore-go`, and verifies it against the
|
||||
policy declared in the gallery YAML (`Gallery.Verification`).
|
||||
- **Revocation:** Keyless cosign certs are ephemeral (10-minute Fulcio
|
||||
validity), so revocation is policy-side, not CA-side. The gallery's
|
||||
`verification.not_before` (RFC3339) is the kill-switch — advance it to
|
||||
invalidate every signature produced before a known compromise window.
|
||||
|
||||
## Producer setup
|
||||
|
||||
`backend_merge.yml` is the workflow that joins per-arch digests into the
|
||||
multi-arch manifest list users actually pull, so it's also the right place
|
||||
to sign. The job needs:
|
||||
|
||||
- `permissions: { id-token: write, contents: read }` at the job level so
|
||||
the runner can exchange its GitHub OIDC token for a Fulcio cert.
|
||||
- `sigstore/cosign-installer@v3` step (current cosign releases already
|
||||
default to the new bundle format).
|
||||
- After each `docker buildx imagetools create`, resolve the resulting
|
||||
list digest with `docker buildx imagetools inspect <tag> --format
|
||||
'{{.Manifest.Digest}}'` and sign:
|
||||
|
||||
```sh
|
||||
cosign sign --yes --recursive \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"${REGISTRY_REPO}@${DIGEST}"
|
||||
```
|
||||
|
||||
Sign by digest, never by tag — signing by tag binds the signature to
|
||||
whatever the tag points at *now*, and a subsequent tag push orphans it.
|
||||
|
||||
`--registry-referrers-mode=oci-1-1` is still gated behind
|
||||
`COSIGN_EXPERIMENTAL=1` in cosign v2.4.x (set at the job env level in
|
||||
`backend_merge.yml`). Re-evaluate when bumping the pinned cosign release
|
||||
— newer versions are expected to graduate this flag and the env var can
|
||||
then be dropped.
|
||||
|
||||
`backend_build_darwin.yml` builds and pushes single-arch darwin images
|
||||
that bypass the manifest-list merge. If/when those entries get a gallery
|
||||
`verification:` policy, the equivalent cosign step has to land there
|
||||
too.
|
||||
|
||||
## Consumer setup (in `mudler/LocalAI` gallery YAML)
|
||||
|
||||
Once CI is signing, add a `verification:` block to the backend gallery
|
||||
entry (`backend/index.yaml`):
|
||||
|
||||
```yaml
|
||||
- name: localai
|
||||
url: github:mudler/LocalAI/backend/index.yaml@master
|
||||
verification:
|
||||
issuer: "https://token.actions.githubusercontent.com"
|
||||
identity_regex: "^https://github\\.com/mudler/LocalAI/\\.github/workflows/backend_merge\\.yml@refs/heads/master$"
|
||||
# Optional revocation cutoff; advance during incident response.
|
||||
# not_before: "2026-06-01T00:00:00Z"
|
||||
```
|
||||
|
||||
Identity matching pins the OIDC subject Fulcio issued the signing cert
|
||||
to. Without this, any image signed by *anyone* with a Fulcio cert would
|
||||
pass — the regex is what makes a signature mean "produced by our CI".
|
||||
|
||||
## Strict mode
|
||||
|
||||
Default behaviour: OCI backends without a `verification:` block install
|
||||
with a warning (logs include `installing OCI backend without signature
|
||||
verification`). Tarball/HTTP backends without a `sha256` field log a
|
||||
similar warning.
|
||||
|
||||
For production, set `LOCALAI_REQUIRE_BACKEND_INTEGRITY=1` (or pass
|
||||
`--require-backend-integrity` to `local-ai run` / `local-ai backends
|
||||
install` / `local-ai models install`). The warning becomes a hard error
|
||||
and unverifiable backends refuse to install.
|
||||
|
||||
## Revocation playbook
|
||||
|
||||
If `backend_merge.yml` (or any workflow with `id-token: write`) is
|
||||
compromised and we've shipped malicious signed images:
|
||||
|
||||
1. **Identify the compromise window.** Find the earliest IntegratedTime
|
||||
from the bad signatures (Rekor search by `subject` filter).
|
||||
2. **Set `verification.not_before`** in `backend/index.yaml` to a
|
||||
timestamp just *after* that window's start.
|
||||
3. **Push the YAML.** Deployed LocalAI instances pick it up on next
|
||||
gallery refresh (1-hour cache in `core/gallery/gallery.go`).
|
||||
4. **Fix the underlying compromise** in the workflow and re-sign images
|
||||
with the new build, which will have IntegratedTime > `not_before`.
|
||||
5. **Optional:** for absolute decisiveness, also rotate to a new
|
||||
workflow path (`backend_merge_v2.yml`) and update `identity_regex`.
|
||||
|
||||
## Where the code lives
|
||||
|
||||
- `pkg/oci/cosignverify/` — verifier, policy, OCI referrer fetch, NotBefore enforcement.
|
||||
- `pkg/downloader/uri.go` — `WithImageVerifier` option threaded through `DownloadFileWithContext`.
|
||||
- `core/gallery/backends.go` — `backendDownloadOptions` builds the verifier from the gallery's policy.
|
||||
- `core/config/gallery.go` — `Gallery.Verification` YAML schema.
|
||||
- `core/cli/run.go`, `core/cli/backends.go`, `core/cli/models.go` — `--require-backend-integrity` flag propagation.
|
||||
- `.github/workflows/backend_merge.yml` — producer-side `cosign sign --recursive` after each multi-arch manifest list push.
|
||||
|
||||
## Out of scope (follow-ups)
|
||||
|
||||
- **Signing the gallery YAML itself.** The index is fetched over HTTPS
|
||||
from GitHub; we trust the host. A cosign blob signature on the YAML
|
||||
would close that gap but adds key-management overhead. Revisit this
|
||||
page if/when added.
|
||||
- **Tarball/HTTP backend signing.** Cosign can sign arbitrary blobs, but
|
||||
for now non-OCI backends keep using the `sha256:` field in YAML.
|
||||
@@ -15,3 +15,35 @@ Let's say the user wants to build a particular backend for a given platform. For
|
||||
- Unless the user specifies that they want you to run the command, then just print it because not all agent frontends handle long running jobs well and the output may overflow your context
|
||||
- The user may say they want to build AMD or ROCM instead of hipblas, or Intel instead of SYCL or NVIDIA insted of l4t or cublas. Ask for confirmation if there is ambiguity.
|
||||
- Sometimes the user may need extra parameters to be added to `docker build` (e.g. `--platform` for cross-platform builds or `--progress` to view the full logs), in which case you can generate the `docker build` command directly.
|
||||
|
||||
## Test coverage gate
|
||||
|
||||
The core Go suites (`./pkg`, `./core`, plus the in-process integration suite `./tests/e2e`) are covered by a **strict, monotonic coverage ratchet**:
|
||||
|
||||
- `make test-coverage` — runs the suites with `covermode=atomic` instrumentation and writes a merged profile to `coverage/coverage.out`. Uses the same prerequisites as `make test`.
|
||||
- **`--coverpkg` (`COVERAGE_COVERPKG = core/...,pkg/...`):** coverage is attributed to the core+pkg packages, not just the package under test. This is what lets the in-process `tests/e2e` suite (which drives the real HTTP server over loopback via `application.New`) credit the `core/http/endpoints/...` handlers it exercises — folding it in roughly doubled endpoint coverage (e.g. `endpoints/openai` 13.6% → 52%). The denominator is therefore *all* of `core`+`pkg` (minus generated proto, dropped via `COVERAGE_EXCLUDE_RE`), so the number isn't comparable to a plain per-package figure.
|
||||
- **Integration suites (`COVERAGE_E2E_ROOTS = ./tests/e2e`)** run non-recursively (excludes `tests/e2e/distributed`, which needs containers) with `--label-filter=!real-models` (those need a downloaded model) against the mock backend built by `prepare-test`. `tests/integration` is deliberately excluded — it needs `make backends/local-store`, which the coverage CI job doesn't build.
|
||||
- **Flake note:** folding integration tests into a *strict* gate means a hard e2e failure (or a spec that silently stops running) can fail the coverage gate, not just the test. `--flake-attempts` absorbs transient retryable failures; covermode=atomic keeps line coverage deterministic otherwise.
|
||||
- **Why one ginkgo run per root (`scripts/run-coverage.sh`):** passing several recursive roots to a *single* ginkgo invocation (e.g. `ginkgo -r ./pkg ./core`) only merges **one** root's coverprofile into `--output-dir`/`--coverprofile` — the others are silently dropped. Verified with ginkgo 2.29.0: `-r ./pkg ./core` yields only `./pkg` coverage, while `-r ./core` alone yields all 34 core packages. So the script runs each root separately and concatenates the (disjoint) profiles. Don't "simplify" it back to a single multi-root invocation — that's how `core/` (including all of `core/http`, ~7.4k statements) silently vanished from the number before.
|
||||
- **Build tags (`COVERAGE_TAGS`, passed via `GINKGO_TAGS`):** defaults to `debug auth`. The `auth` tag is required to compile the real (sqlite-backed) auth implementation and its ~150 `//go:build auth` tests — without it those files aren't built, the tests don't run, and the gate scores auth against a stub (~3.7% instead of ~38%). If you add new tag-gated tests, extend `COVERAGE_TAGS` or they won't count (and likely won't run in CI at all).
|
||||
- `make test-coverage-check` — runs `test-coverage`, then `scripts/coverage-check.sh` fails the build if total coverage is **below** the committed baseline in `coverage-baseline.txt`. The Linux job in `.github/workflows/test.yml` runs this instead of `make test`.
|
||||
- `make test-coverage-baseline` — regenerates and overwrites `coverage-baseline.txt` from the current run.
|
||||
- `make install-hooks` — sets `core.hooksPath` to the versioned `.githooks/`, whose `pre-commit` runs checks scoped to what's staged: Go changes → `make lint` + `make test-coverage-check`; `core/http/react-ui/` changes → `make test-ui-coverage-check` (Playwright e2e + UI coverage gate). A commit touching neither is skipped; bypass with `git commit --no-verify`. The hook resolves golangci-lint's new-from base to `upstream/master` → `origin/master` → `master`, so it works from a fork clone where `origin/master` is stale (passed to `make lint` via `LINT_NEW_FROM`).
|
||||
|
||||
### React UI coverage
|
||||
|
||||
The React UI (`core/http/react-ui/`) has **no component/unit tests** — its only tests are the Playwright e2e specs in `e2e/`, which run against the real app served by `tests/e2e-ui/ui-test-server` (the dist is `//go:embed`ed, so the server is rebuilt per coverage run). Those specs do genuinely exercise the UI (clicks, `fill`, `setInputFiles`, `getByRole`/`getByText`, visibility/value assertions).
|
||||
|
||||
- `make test-ui-coverage` — builds an istanbul-instrumented bundle (`COVERAGE=true`, via `vite-plugin-istanbul` with `forceBuildInstrument: true` — the plugin skips production builds otherwise), re-embeds it into `ui-test-server` (the dist is `//go:embed`ed), runs the Playwright specs, and writes an `nyc` report to `core/http/react-ui/coverage/`. The specs import `{ test, expect }` from `e2e/coverage-fixtures.js` (re-exports Playwright's, plus harvests `window.__coverage__` into `.nyc_output/` after each test). Instrumentation is off unless `COVERAGE=true`, so dev/prod builds and plain `make test-ui-e2e` are unaffected (the fixture no-ops when `window.__coverage__` is absent).
|
||||
- **Browser:** the flake dev shell ships `chromium` and exports `PLAYWRIGHT_CHROMIUM_PATH`; `playwright.config.js` uses it via `launchOptions.executablePath`, and the Makefile skips `playwright install` when it's set. This avoids Playwright's downloaded browser, which can't resolve system libs (`libglib-2.0`, …) on NixOS. In CI (no `PLAYWRIGHT_CHROMIUM_PATH`) the Makefile falls back to `playwright install --with-deps chromium`.
|
||||
- The app is a React SPA, so coverage accumulates across in-app navigation within a test; a full `page.goto`/reload resets it.
|
||||
- `.nycrc.json` uses `all: true`, so **every `src/**` file is in the report**, including 0%-coverage ones — that's how you spot features with no test at all (sort the HTML report or `coverage-summary.json` by line% ascending).
|
||||
- **UI coverage gate:** `make test-ui-coverage-check` runs the suite then `scripts/ui-coverage-check.sh`, failing if total line coverage drops more than `UI_COVERAGE_TOLERANCE` below `core/http/react-ui/coverage-baseline.txt`. `make test-ui-coverage-baseline` regenerates the baseline. Runs in CI (`tests-ui-e2e.yml`) and pre-commit on `core/http/react-ui/` changes.
|
||||
- **Why it has a tolerance (unlike the strict Go gate):** UI e2e coverage is *non-deterministic*. Specs that assert on state and end while async/lazy render work is still in flight collect those lines only when the render beats the coverage teardown — so the total drifts with machine speed/load (a fast local box reads higher than a slow CI runner), diffusely across many specs. The tolerance absorbs that drift, so set the baseline *below* the slow-CI floor, never to a fast-local `make test-ui-coverage-baseline` number, or CI flaps.
|
||||
- **Raising coverage is cheap:** a *render-smoke* spec (navigate to a route, assert its header renders) mounts a lazy page and runs its full render + initial effects, capturing most of its lines in a few lines of test — see `e2e/page-render-smoke.spec.js`. Auth is disabled in the test server (`isAdmin=true`), so `RequireAdmin`/`RequireFeature` routes render without a mock. The most *deterministic* win is removing a race: make a spec `await` a rendered element before ending (see `e2e/agents.spec.js` → AgentCreate) so its lines count every run.
|
||||
|
||||
Rules (both gates):
|
||||
- **Install the hooks:** `make install-hooks` once per clone so lint + coverage run pre-commit. Don't lean on CI for what the hook catches.
|
||||
- **Don't work around the gate:** never `git commit --no-verify`, and never hand-lower a baseline or widen a tolerance to turn a red gate green. The ratchet only moves up.
|
||||
- If a change drops coverage, **add tests** (sort `coverage-summary.json` by line% ascending to find untested code) rather than editing the baseline. When coverage legitimately rises, commit the regenerated baseline (`make test-coverage-baseline` / `test-ui-coverage-baseline`).
|
||||
- The Go gate is **strict — no tolerance**; `covermode=atomic` keeps it deterministic. The UI gate keeps a small tolerance only because its e2e coverage isn't.
|
||||
|
||||
@@ -50,6 +50,17 @@ Do not mix styles within a package. If you are extending tests in a package that
|
||||
|
||||
This is enforced by `golangci-lint` via the `forbidigo` linter (see `.golangci.yml`); calls like `t.Errorf` / `t.Fatalf` / `t.Run` / `t.Skip` / `t.Logf` are flagged. Run `make lint` locally before submitting; the same check runs in CI (`.github/workflows/lint.yml`).
|
||||
|
||||
## Outbound HTTP
|
||||
|
||||
All outbound HTTP must go through `github.com/mudler/LocalAI/pkg/httpclient` rather than the standard library's default client. Use `httpclient.New(...)` (no body deadline — safe for streaming/SSE) or `httpclient.NewWithTimeout(d, ...)` (simple request/response). Both **refuse redirects by default** and set a TLS 1.2 floor.
|
||||
|
||||
The reason is GHSA-3mj3-57v2-4636: the std default client follows redirects, and on a *cross-host* redirect Go forwards custom credential headers (e.g. Anthropic's `x-api-key`) to the redirect target, leaking the secret. `httpclient` fails closed instead.
|
||||
|
||||
- Need to follow redirects (download CDNs, registry blobs, GitHub asset URLs)? Pass `httpclient.WithFollowRedirects()` — it still strips credential headers on any cross-host hop.
|
||||
- Have a custom transport (IP-pinned dialer, HTTP/2 tuning, a credential-injecting `RoundTripper`)? Pass `httpclient.WithTransport(rt)`, basing the transport on `httpclient.HardenedTransport()` to keep the TLS floor. Handed a `*http.Client` by a library? `httpclient.Harden(c)` applies the policy in place.
|
||||
|
||||
This is enforced by `forbidigo` (see `.golangci.yml`): `http.DefaultClient` and `http.Get`/`Post`/`PostForm`/`Head` are flagged. The `&http.Client{}` composite literal can't be matched precisely by forbidigo without also flagging legitimate `*http.Client` type references, so that form is caught by review — don't construct raw clients.
|
||||
|
||||
## Documentation
|
||||
|
||||
The project documentation is located in `docs/content`. When adding new features or changing existing functionality, it is crucial to update the documentation to reflect these changes. This helps users understand how to use the new capabilities and ensures the documentation stays relevant.
|
||||
|
||||
@@ -68,6 +68,34 @@ go test -count=1 -timeout=30m -v ./tests/e2e-backends/...
|
||||
|
||||
CI does not load the model; the suite is opt-in via env vars.
|
||||
|
||||
## Distributed mode
|
||||
|
||||
ds4 supports **layer-split** distributed inference (a model too big for one host,
|
||||
split by transformer layer; the GGUF must be present on every machine, each loads
|
||||
only its slice). Topology is **inverted** vs llama.cpp: the coordinator listens,
|
||||
workers dial in.
|
||||
|
||||
- **`ds4-worker` binary**: built and packaged next to `grpc-server` (`package.sh`
|
||||
copies it into `package/`). Links the same engine objects plus `ds4_distributed.o`;
|
||||
**no gRPC/protobuf dependency** (speaks ds4's own TCP transport), so it builds
|
||||
even where `grpc-server` can't. Runs the worker serving loop (`ds4_dist_run`).
|
||||
- **Coordinator wiring**: the ds4 `grpc-server` acts as coordinator when `LoadModel`
|
||||
`ModelOptions.Options` (from model-YAML `options:`) carry:
|
||||
- `ds4_role:coordinator` (enables distributed mode; absent → single-node, back-compat)
|
||||
- `ds4_layers:0:19` (coordinator's own slice, inclusive; `N:output` includes the head)
|
||||
- `ds4_listen:0.0.0.0:1234` (address workers dial into)
|
||||
- `ds4_route_timeout:60` (optional; seconds Predict/PredictStream wait for the route
|
||||
to form before returning gRPC `UNAVAILABLE`; default 60)
|
||||
- **Worker CLI**: `local-ai worker ds4-distributed -- <ds4-worker args>` resolves the
|
||||
ds4 backend and execs the packaged `ds4-worker` (raw passthrough), e.g.
|
||||
`--role worker --model /models/ds4flash.gguf --layers 20:output --coordinator <host> 1234`.
|
||||
|
||||
Opt-in e2e in `tests/e2e-backends/backend_test.go`, gated by
|
||||
`BACKEND_TEST_DS4_DISTRIBUTED=1` (plus `BACKEND_TEST_DS4_WORKER_BINARY`,
|
||||
`BACKEND_TEST_DS4_WORKER_LAYERS`, `BACKEND_TEST_DS4_COORDINATOR_LAYERS`,
|
||||
`BACKEND_TEST_DS4_LISTEN`). Design spec:
|
||||
`docs/superpowers/specs/2026-05-30-ds4-distributed-inference-design.md`.
|
||||
|
||||
## Importer
|
||||
|
||||
`core/gallery/importers/ds4.go` (`DS4Importer`) auto-detects ds4 weights by
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
.devcontainer
|
||||
models
|
||||
backends
|
||||
volumes
|
||||
examples/chatbot-ui/models
|
||||
backend/go/image/stablediffusion-ggml/build/
|
||||
backend/go/*/build
|
||||
@@ -21,3 +22,27 @@ __pycache__
|
||||
# backend virtual environments
|
||||
**/venv
|
||||
backend/python/**/source
|
||||
|
||||
# In-place llama.cpp clone + per-variant build copies. The Makefile
|
||||
# clones llama.cpp itself at the pinned LLAMA_VERSION; if a stale
|
||||
# local checkout is COPY'd into the image, the `llama.cpp:` target
|
||||
# sees the directory and skips re-cloning, so grpc-server.cpp ends
|
||||
# up compiled against whatever (likely older) commit the host had.
|
||||
backend/cpp/llama-cpp/llama.cpp
|
||||
backend/cpp/llama-cpp-*-build
|
||||
|
||||
# Rust backend build output (sources are tracked; target/ is generated)
|
||||
backend/rust/*/target
|
||||
|
||||
# Local-only artifacts that bloat the build context but the image never needs.
|
||||
# Saved image tarballs, locally-installed backends, the host-built binary, and
|
||||
# assorted tool/scratch dirs. None of these are git-tracked.
|
||||
backend-images
|
||||
local-backends
|
||||
local-ai
|
||||
.crush
|
||||
protoc
|
||||
tests
|
||||
|
||||
# Installed via npm inside the build stage; no need to ship the host copy.
|
||||
**/node_modules
|
||||
|
||||
60
.githooks/pre-commit
Executable file
60
.githooks/pre-commit
Executable file
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env sh
|
||||
#
|
||||
# LocalAI pre-commit hook. Install it (once per clone) with:
|
||||
#
|
||||
# make install-hooks
|
||||
#
|
||||
# Runs only the checks relevant to what's staged:
|
||||
# - Go files -> make lint + make test-coverage-check
|
||||
# - core/http/react-ui -> make test-ui-coverage-check (Playwright e2e + gate)
|
||||
# A commit touching neither is skipped entirely (docs/YAML/etc. can't change
|
||||
# lint findings, Go coverage, or the UI).
|
||||
#
|
||||
# To bypass for a single commit (e.g. a WIP checkpoint): git commit --no-verify
|
||||
set -eu
|
||||
|
||||
repo_root="$(git rev-parse --show-toplevel)"
|
||||
cd "$repo_root"
|
||||
|
||||
staged="$(git diff --cached --name-only --diff-filter=ACMRD)"
|
||||
|
||||
go_changed=0
|
||||
ui_changed=0
|
||||
if echo "$staged" | grep -qE '\.go$'; then go_changed=1; fi
|
||||
if echo "$staged" | grep -qE '^core/http/react-ui/'; then ui_changed=1; fi
|
||||
|
||||
if [ "$go_changed" -eq 0 ] && [ "$ui_changed" -eq 0 ]; then
|
||||
echo "pre-commit: no Go or React UI changes staged — skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ "$go_changed" -eq 1 ]; then
|
||||
# Resolve the ref golangci-lint's new-from-merge-base should compare
|
||||
# against. .golangci.yml pins origin/master, which is correct in CI
|
||||
# (origin == the canonical repo) but wrong from a fork clone, where
|
||||
# origin/master lags behind and lint would report the whole upstream
|
||||
# backlog. Prefer upstream/master, then origin/master, then master.
|
||||
lint_base=""
|
||||
for ref in upstream/master origin/master master; do
|
||||
if git rev-parse --verify --quiet "${ref}^{commit}" >/dev/null 2>&1; then
|
||||
lint_base="$ref"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
echo "pre-commit ▶ golangci-lint (make lint${lint_base:+, new-from $lint_base})"
|
||||
make lint LINT_NEW_FROM="$lint_base"
|
||||
|
||||
echo "pre-commit ▶ coverage gate (make test-coverage-check) — builds and runs the"
|
||||
echo " pkg/core suites plus tests/e2e; can take a few minutes."
|
||||
make test-coverage-check
|
||||
fi
|
||||
|
||||
if [ "$ui_changed" -eq 1 ]; then
|
||||
echo "pre-commit ▶ React UI e2e + coverage gate (make test-ui-coverage-check) —"
|
||||
echo " rebuilds the UI + ui-test-server, runs the Playwright specs, and"
|
||||
echo " fails if line coverage regressed; can take a couple of minutes."
|
||||
make test-ui-coverage-check
|
||||
fi
|
||||
|
||||
echo "pre-commit ✓ all relevant checks passed"
|
||||
423
.github/backend-matrix.yml
vendored
423
.github/backend-matrix.yml
vendored
@@ -690,6 +690,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-rfdetr-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "rfdetr-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -703,6 +716,32 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -1491,6 +1530,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-rfdetr-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "rfdetr-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1504,6 +1556,19 @@ include:
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-rfdetr-cpp'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "rfdetr-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1517,6 +1582,32 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1530,6 +1621,32 @@ include:
|
||||
backend: "whisper"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-crispasr'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-parakeet-cpp'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -2635,6 +2752,74 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# rfdetr-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-rfdetr-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "rfdetr-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-rfdetr-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "rfdetr-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-rfdetr-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "rfdetr-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-rfdetr-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "rfdetr-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-rfdetr-cpp'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "rfdetr-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2715,6 +2900,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-rfdetr-cpp'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "rfdetr-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
# whisper
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -2730,6 +2928,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2744,6 +2956,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-crispasr'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2757,6 +2983,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2770,6 +3009,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2784,6 +3036,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-crispasr'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2798,6 +3064,20 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-crispasr'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
@@ -2811,6 +3091,19 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-crispasr'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2824,6 +3117,128 @@ include:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-crispasr'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
runs-on: 'ubuntu-latest'
|
||||
skip-drivers: 'false'
|
||||
backend: "crispasr"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# parakeet-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-parakeet-cpp'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
platform-tag: 'amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-parakeet-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/arm64'
|
||||
platform-tag: 'arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-parakeet-cpp'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-parakeet-cpp'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-parakeet-cpp'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
runs-on: 'ubuntu-latest'
|
||||
skip-drivers: 'false'
|
||||
backend: "parakeet-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# acestep-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -3856,6 +4271,14 @@ includeDarwin:
|
||||
tag-suffix: "-metal-darwin-arm64-whisper"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "crispasr"
|
||||
tag-suffix: "-metal-darwin-arm64-crispasr"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "parakeet-cpp"
|
||||
tag-suffix: "-metal-darwin-arm64-parakeet-cpp"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "acestep-cpp"
|
||||
tag-suffix: "-metal-darwin-arm64-acestep-cpp"
|
||||
build-type: "metal"
|
||||
|
||||
13
.github/gallery-agent/main.go
vendored
13
.github/gallery-agent/main.go
vendored
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -113,6 +114,17 @@ func main() {
|
||||
fmt.Println("Searching for trending models on HuggingFace...")
|
||||
rawModels, err := client.GetTrending(searchTerm, limit)
|
||||
if err != nil {
|
||||
if errors.Is(err, hfapi.ErrRateLimited) {
|
||||
fmt.Printf("HuggingFace API is rate limited after retries, skipping this run: %v\n", err)
|
||||
writeSummary(AddedModelSummary{
|
||||
SearchTerm: searchTerm,
|
||||
TotalFound: 0,
|
||||
ModelsAdded: 0,
|
||||
Quantization: quantization,
|
||||
ProcessingTime: time.Since(startTime).String(),
|
||||
})
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Error fetching models: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -277,4 +289,3 @@ func truncateString(s string, maxLen int) string {
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
|
||||
59
.github/workflows/backend_merge.yml
vendored
59
.github/workflows/backend_merge.yml
vendored
@@ -31,8 +31,20 @@ on:
|
||||
jobs:
|
||||
merge:
|
||||
runs-on: ubuntu-latest
|
||||
# id-token: write is required for keyless cosign — the workflow
|
||||
# exchanges the GitHub OIDC token for a short-lived Fulcio cert that
|
||||
# signs each pushed manifest. Without this permission the runner
|
||||
# cannot mint the token, and `cosign sign` fails with "no token".
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
env:
|
||||
quay_username: ${{ secrets.quayUsername }}
|
||||
# cosign v2.4.x still gates --registry-referrers-mode=oci-1-1 behind
|
||||
# this flag. Without it, signing fails with:
|
||||
# invalid argument "oci-1-1" for "--registry-referrers-mode" flag:
|
||||
# in order to use mode "oci-1-1", you must set COSIGN_EXPERIMENTAL=1
|
||||
COSIGN_EXPERIMENTAL: '1'
|
||||
steps:
|
||||
# Sparse checkout: the merge job needs `.github/scripts/` (for the
|
||||
# keepalive cleanup script) but none of the source tree.
|
||||
@@ -57,6 +69,16 @@ jobs:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@master
|
||||
|
||||
# cosign signs each pushed manifest list with --recursive so the
|
||||
# index and every per-arch entry get an attached Sigstore bundle.
|
||||
# Recent cosign releases always emit the new bundle format, so
|
||||
# there's no extra CLI flag to opt into it.
|
||||
- name: Install cosign
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: sigstore/cosign-installer@v3
|
||||
with:
|
||||
cosign-release: 'v2.4.1'
|
||||
|
||||
- name: Login to DockerHub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v4
|
||||
@@ -120,11 +142,25 @@ jobs:
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
if [ -z "$tags" ]; then
|
||||
echo "No quay.io tags from docker/metadata-action; skipping quay merge"
|
||||
else
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags \
|
||||
$(printf 'quay.io/go-skynet/ci-cache@sha256:%s ' *)
|
||||
exit 0
|
||||
fi
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags \
|
||||
$(printf 'quay.io/go-skynet/ci-cache@sha256:%s ' *)
|
||||
# Resolve the manifest-list digest (any tag points at it) so
|
||||
# cosign can sign by digest. Signing by tag would leave the
|
||||
# signature orphaned the next time the tag moves.
|
||||
first_tag=$(jq -cr '
|
||||
.tags | map(select(startswith("quay.io/"))) | .[0]
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
digest=$(docker buildx imagetools inspect "$first_tag" --format '{{.Manifest.Digest}}')
|
||||
# --recursive walks the list and signs every per-arch entry
|
||||
# too — clients that resolve a tag to a platform-specific
|
||||
# manifest before checking signatures need the per-arch
|
||||
# signatures, not just the list-level one.
|
||||
cosign sign --yes --recursive \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"quay.io/go-skynet/local-ai-backends@${digest}"
|
||||
|
||||
- name: Create manifest list and push (dockerhub)
|
||||
if: github.event_name != 'pull_request'
|
||||
@@ -139,11 +175,18 @@ jobs:
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
if [ -z "$tags" ]; then
|
||||
echo "No dockerhub tags from docker/metadata-action; skipping dockerhub merge"
|
||||
else
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags \
|
||||
$(printf 'localai/localai-backends@sha256:%s ' *)
|
||||
exit 0
|
||||
fi
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags \
|
||||
$(printf 'localai/localai-backends@sha256:%s ' *)
|
||||
first_tag=$(jq -cr '
|
||||
.tags | map(select(startswith("localai/"))) | .[0]
|
||||
' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
digest=$(docker buildx imagetools inspect "$first_tag" --format '{{.Manifest.Digest}}')
|
||||
cosign sign --yes --recursive \
|
||||
--registry-referrers-mode=oci-1-1 \
|
||||
"localai/localai-backends@${digest}"
|
||||
|
||||
- name: Inspect manifest
|
||||
if: github.event_name != 'pull_request'
|
||||
|
||||
12
.github/workflows/bump_deps.yaml
vendored
12
.github/workflows/bump_deps.yaml
vendored
@@ -30,6 +30,14 @@ jobs:
|
||||
variable: "WHISPER_CPP_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/whisper/Makefile"
|
||||
- repository: "CrispStrobe/CrispASR"
|
||||
variable: "CRISPASR_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/crispasr/Makefile"
|
||||
- repository: "mudler/parakeet.cpp"
|
||||
variable: "PARAKEET_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/parakeet-cpp/Makefile"
|
||||
- repository: "leejet/stable-diffusion.cpp"
|
||||
variable: "STABLEDIFFUSION_GGML_VERSION"
|
||||
branch: "master"
|
||||
@@ -50,6 +58,10 @@ jobs:
|
||||
variable: "SAM3_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/sam3-cpp/Makefile"
|
||||
- repository: "mudler/rf-detr.cpp"
|
||||
variable: "RFDETR_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/rfdetr-cpp/Makefile"
|
||||
- repository: "predict-woo/qwen3-tts.cpp"
|
||||
variable: "QWEN3TTS_CPP_VERSION"
|
||||
branch: "main"
|
||||
|
||||
1
.github/workflows/image_build.yml
vendored
1
.github/workflows/image_build.yml
vendored
@@ -106,6 +106,7 @@ jobs:
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{raw}}
|
||||
type=sha
|
||||
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
|
||||
flavor: |
|
||||
latest=${{ inputs.tag-latest }}
|
||||
suffix=${{ inputs.tag-suffix }},onlatest=true
|
||||
|
||||
1
.github/workflows/image_merge.yml
vendored
1
.github/workflows/image_merge.yml
vendored
@@ -80,6 +80,7 @@ jobs:
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{raw}}
|
||||
type=sha
|
||||
type=raw,value={{branch}}-{{date 'X'}}-{{sha}},enable={{is_default_branch}}
|
||||
flavor: |
|
||||
latest=${{ inputs.tag-latest }}
|
||||
suffix=${{ inputs.tag-suffix }},onlatest=true
|
||||
|
||||
2
.github/workflows/secscan.yaml
vendored
2
.github/workflows/secscan.yaml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||
- name: Run Gosec Security Scanner
|
||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||
uses: securego/gosec@v2.22.9
|
||||
uses: securego/gosec@v2.27.1
|
||||
with:
|
||||
# we let the report trigger content trigger a failure using the GitHub Security features.
|
||||
args: '-no-fail -fmt sarif -out results.sarif ./...'
|
||||
|
||||
2
.github/workflows/stalebot.yml
vendored
2
.github/workflows/stalebot.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
if: github.repository == 'mudler/LocalAI'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v9
|
||||
- uses: actions/stale@eb5cf3af3ac0a1aa4c9c45633dd1ae542a27a899 # v9
|
||||
with:
|
||||
stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
|
||||
stale-pr-message: 'This PR is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 10 days.'
|
||||
|
||||
58
.github/workflows/test-extra.yml
vendored
58
.github/workflows/test-extra.yml
vendored
@@ -37,6 +37,7 @@ jobs:
|
||||
sglang: ${{ steps.detect.outputs.sglang }}
|
||||
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
||||
qwen3-tts-cpp: ${{ steps.detect.outputs.qwen3-tts-cpp }}
|
||||
rfdetr-cpp: ${{ steps.detect.outputs.rfdetr-cpp }}
|
||||
vibevoice-cpp: ${{ steps.detect.outputs.vibevoice-cpp }}
|
||||
localvqe: ${{ steps.detect.outputs.localvqe }}
|
||||
voxtral: ${{ steps.detect.outputs.voxtral }}
|
||||
@@ -45,6 +46,7 @@ jobs:
|
||||
speaker-recognition: ${{ steps.detect.outputs.speaker-recognition }}
|
||||
sherpa-onnx: ${{ steps.detect.outputs.sherpa-onnx }}
|
||||
whisper: ${{ steps.detect.outputs.whisper }}
|
||||
parakeet-cpp: ${{ steps.detect.outputs.parakeet-cpp }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
@@ -632,6 +634,26 @@ jobs:
|
||||
- name: Build whisper backend image and run transcription gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-whisper-transcription
|
||||
# Parakeet ASR via the parakeet-cpp backend (C++/ggml port of NeMo
|
||||
# Parakeet). Drives AudioTranscription (offline, with word timestamps) on
|
||||
# tdt_ctc-110m + the JFK 11s clip.
|
||||
tests-parakeet-cpp-grpc-transcription:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.parakeet-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25.4'
|
||||
- name: Build parakeet-cpp backend image and run transcription gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-parakeet-cpp-transcription
|
||||
# VITS TTS via the sherpa-onnx backend. Drives both TTS (file write) and
|
||||
# TTSStream (PCM chunks) on the e2e-backends harness.
|
||||
tests-sherpa-onnx-grpc-tts:
|
||||
@@ -843,6 +865,42 @@ jobs:
|
||||
- name: Test qwen3-tts-cpp
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/go/qwen3-tts-cpp test
|
||||
# Per-backend smoke for rfdetr-cpp: builds the .so + Go binary and runs
|
||||
# `make -C backend/go/rfdetr-cpp test`. test.sh fetches the small (~20 MB)
|
||||
# rfdetr-nano-q8_0 GGUF from the published mudler/rfdetr-cpp-nano HF repo
|
||||
# via curl and synthesises a tiny PNG to exercise the wire protocol.
|
||||
tests-rfdetr-cpp:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.rfdetr-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential cmake curl libopenblas-dev
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
- name: Display Go version
|
||||
run: go version
|
||||
- name: Proto Dependencies
|
||||
run: |
|
||||
# Install protoc
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
|
||||
PATH="$PATH:$HOME/go/bin" make protogen-go
|
||||
- name: Build rfdetr-cpp
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/go/rfdetr-cpp
|
||||
- name: Test rfdetr-cpp
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/go/rfdetr-cpp test
|
||||
# Per-backend smoke for vibevoice-cpp: builds the .so + Go binary and
|
||||
# runs `make -C backend/go/vibevoice-cpp test`. test.sh auto-downloads
|
||||
# the published mudler/vibevoice.cpp-models bundle (TTS Q8_0 + ASR Q4_K
|
||||
|
||||
17
.github/workflows/test.yml
vendored
17
.github/workflows/test.yml
vendored
@@ -53,9 +53,22 @@ jobs:
|
||||
node-version: '22'
|
||||
- name: Build React UI
|
||||
run: make react-ui
|
||||
- name: Test
|
||||
# Runs the core suite with coverage and fails if total coverage dropped
|
||||
# below the committed baseline (coverage-baseline.txt). The gate is
|
||||
# strict — any decrease fails. Raise the baseline with
|
||||
# `make test-coverage-baseline` and commit it when coverage rises.
|
||||
- name: Test (with coverage gate)
|
||||
run: |
|
||||
PATH="$PATH:/root/go/bin" make --jobs 5 --output-sync=target test
|
||||
PATH="$PATH:/root/go/bin" make --jobs 5 --output-sync=target test-coverage-check
|
||||
- name: Upload coverage report
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: coverage-linux
|
||||
path: |
|
||||
coverage/coverage.out
|
||||
coverage/coverage.html
|
||||
if-no-files-found: ignore
|
||||
- name: Setup tmate session if tests fail
|
||||
if: ${{ failure() }}
|
||||
uses: mxschmitt/action-tmate@v3.23
|
||||
|
||||
28
.github/workflows/tests-ui-e2e.yml
vendored
28
.github/workflows/tests-ui-e2e.yml
vendored
@@ -37,6 +37,10 @@ jobs:
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: '22'
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: '1.3.11'
|
||||
- name: Proto Dependencies
|
||||
run: |
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
|
||||
@@ -48,16 +52,12 @@ jobs:
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential libopus-dev
|
||||
- name: Build UI test server
|
||||
run: PATH="$PATH:$HOME/go/bin" make build-ui-test-server
|
||||
- name: Install Playwright
|
||||
working-directory: core/http/react-ui
|
||||
run: |
|
||||
npm install
|
||||
npx playwright install --with-deps chromium
|
||||
- name: Run Playwright tests
|
||||
working-directory: core/http/react-ui
|
||||
run: npx playwright test
|
||||
# Builds an instrumented UI bundle, runs the Playwright specs, and fails
|
||||
# if line coverage regressed beyond the jitter tolerance (the gate is
|
||||
# in `make test-ui-coverage-check`). PLAYWRIGHT_CHROMIUM_PATH is unset
|
||||
# here, so scripts/ensure-playwright-browser.sh installs Chromium via apt.
|
||||
- name: Run UI e2e + coverage gate
|
||||
run: PATH="$PATH:$HOME/go/bin" make test-ui-coverage-check
|
||||
- name: Upload Playwright report
|
||||
if: ${{ failure() }}
|
||||
uses: actions/upload-artifact@v7
|
||||
@@ -65,6 +65,14 @@ jobs:
|
||||
name: playwright-report
|
||||
path: core/http/react-ui/playwright-report/
|
||||
retention-days: 7
|
||||
- name: Upload UI coverage report
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: ui-coverage
|
||||
path: core/http/react-ui/coverage/
|
||||
if-no-files-found: ignore
|
||||
retention-days: 7
|
||||
- name: Setup tmate session if tests fail
|
||||
if: ${{ failure() }}
|
||||
uses: mxschmitt/action-tmate@v3.23
|
||||
|
||||
14
.gitignore
vendored
14
.gitignore
vendored
@@ -26,6 +26,10 @@ go-bert
|
||||
LocalAI
|
||||
/local-ai
|
||||
/local-ai-launcher
|
||||
# Root-level build artifacts when running `go build ./...` against
|
||||
# Go backend packages whose main lives under backend/go/.
|
||||
/cloud-proxy
|
||||
/local-store
|
||||
# prevent above rules from omitting the helm chart
|
||||
!charts/*
|
||||
# prevent above rules from omitting the api/localai folder
|
||||
@@ -66,10 +70,17 @@ docs/static/gallery.html
|
||||
# per-developer customization files for the development container
|
||||
.devcontainer/customization/*
|
||||
|
||||
# Coverage profiles (the committed baseline is coverage-baseline.txt)
|
||||
/coverage/
|
||||
|
||||
# React UI build artifacts (keep placeholder dist/index.html)
|
||||
core/http/react-ui/node_modules/
|
||||
core/http/react-ui/dist
|
||||
|
||||
# React UI coverage (vite-plugin-istanbul + nyc, via `make test-ui-coverage`)
|
||||
core/http/react-ui/.nyc_output/
|
||||
core/http/react-ui/coverage/
|
||||
|
||||
# Extracted backend binaries for container-based testing
|
||||
local-backends/
|
||||
|
||||
@@ -77,3 +88,6 @@ local-backends/
|
||||
tests/e2e-ui/ui-test-server
|
||||
core/http/react-ui/playwright-report/
|
||||
core/http/react-ui/test-results/
|
||||
|
||||
# Local worktrees
|
||||
.worktrees/
|
||||
|
||||
@@ -46,8 +46,81 @@ linters:
|
||||
msg: 'LocalAI tests must use Ginkgo/Gomega; use Fail(...) instead of t.Fail. See .agents/coding-style.md.'
|
||||
- pattern: '^t\.FailNow$'
|
||||
msg: 'LocalAI tests must use Ginkgo/Gomega; use Fail(...) instead of t.FailNow. See .agents/coding-style.md.'
|
||||
# In-process config should flow through ApplicationConfig / kong-bound
|
||||
# CLI flags, not via os.Getenv. The CLI layer is the legitimate
|
||||
# env→struct boundary (kong's `env:"..."` tag); anything deeper that
|
||||
# reads env directly leaks process state into business logic and
|
||||
# makes flags impossible to test or override per-request. Backend
|
||||
# subprocesses, the system/capabilities probe, and a few places that
|
||||
# read non-LocalAI env vars (HOME, PATH, AUTH_TOKEN passed by parent)
|
||||
# are exempt — see linters.exclusions.rules below.
|
||||
- pattern: '^os\.(Getenv|LookupEnv|Environ)$'
|
||||
msg: 'Plumb config through ApplicationConfig (or the relevant CLI struct) instead of reading env directly. CLI entry points (core/cli/) bind env vars via kong''s `env:` tag — that is the only sanctioned env→struct boundary. See .agents/coding-style.md.'
|
||||
# Outbound HTTP must go through pkg/httpclient, which refuses redirects
|
||||
# by default and sets a TLS floor. The std-library default client and
|
||||
# the http.Get/Post/... convenience helpers follow redirects (up to 10)
|
||||
# and, on a cross-host redirect, forward custom credential headers such
|
||||
# as Anthropic's x-api-key to the redirect target — leaking the secret
|
||||
# (GHSA-3mj3-57v2-4636). forbidigo can't precisely match the
|
||||
# `&http.Client{}` composite literal without also flagging legitimate
|
||||
# `*http.Client` type references, so that form is enforced by
|
||||
# convention + review; these two patterns catch the implicit-default
|
||||
# client, which is the common footgun.
|
||||
- pattern: '^http\.DefaultClient$'
|
||||
msg: 'Use pkg/httpclient (httpclient.New / NewWithTimeout) instead of http.DefaultClient — the std client follows redirects and leaks credential headers cross-host (GHSA-3mj3-57v2-4636). See .agents/coding-style.md.'
|
||||
- pattern: '^http\.(Get|Post|PostForm|Head)$'
|
||||
msg: 'Use pkg/httpclient (httpclient.New / NewWithTimeout) instead of http.Get/Post/PostForm/Head — these use http.DefaultClient, which follows redirects and leaks credential headers cross-host (GHSA-3mj3-57v2-4636). See .agents/coding-style.md.'
|
||||
exclusions:
|
||||
paths:
|
||||
# Upstream whisper.cpp source tree fetched by the whisper backend Makefile.
|
||||
- 'backend/go/whisper/sources'
|
||||
- 'docs/'
|
||||
rules:
|
||||
# CLI entry points: kong's `env:"..."` tag is the legitimate env→struct
|
||||
# boundary, and a handful of subcommands legitimately propagate values
|
||||
# to spawned subprocesses (LLAMACPP_GRPC_SERVERS, MLX hostfile, ...).
|
||||
- path: ^core/cli/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# Backend subprocesses are independent binaries with their own env
|
||||
# surface; they're not "in-process config" of the LocalAI server.
|
||||
- path: ^backend/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# System capability probe reads HOME, PATH-style vars to discover
|
||||
# GPUs, default paths, etc. — not LocalAI config.
|
||||
- path: ^pkg/system/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# gRPC server reads AUTH_TOKEN passed in by the parent process at spawn
|
||||
# time; model.Loader sets/inherits env to communicate with subprocesses.
|
||||
- path: ^pkg/grpc/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
- path: ^pkg/model/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# Top-level main binaries (local-ai, launcher) are entry points.
|
||||
- path: ^cmd/
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# Tests legitimately read $HOME, $TMPDIR, and gating env vars
|
||||
# (LOCALAI_COSIGN_LIVE, etc.) to skip live-network specs.
|
||||
- path: _test\.go$
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# pkg/httpclient is the sanctioned home for outbound HTTP clients; it
|
||||
# necessarily references net/http directly.
|
||||
- path: ^pkg/httpclient/
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
# Tests drive local httptest servers where redirect/TLS hardening is
|
||||
# irrelevant; the std client is fine there.
|
||||
- path: _test\.go$
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
# Vendored upstream whisper.cpp Go bindings are a separate module and
|
||||
# cannot import pkg/httpclient.
|
||||
- path: ^backend/go/whisper/sources/
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
|
||||
@@ -31,9 +31,11 @@ LocalAI follows the Linux kernel project's [guidelines for AI coding assistants]
|
||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||
| [.agents/adding-gallery-models.md](.agents/adding-gallery-models.md) | Adding GGUF models from HuggingFace to the model gallery |
|
||||
| [.agents/localai-assistant-mcp.md](.agents/localai-assistant-mcp.md) | LocalAI Assistant chat modality — adding admin tools to the in-process MCP server, editing skill prompts, keeping REST + MCP + skills in sync |
|
||||
| [.agents/backend-signing.md](.agents/backend-signing.md) | Backend OCI image signing (keyless cosign + sigstore-go) — producer-side CI setup, consumer-side gallery `verification:` block, strict mode (`LOCALAI_REQUIRE_BACKEND_INTEGRITY`), revocation via `not_before` |
|
||||
|
||||
## Quick Reference
|
||||
|
||||
- **Git hooks & coverage gates**: Run `make install-hooks` once per clone so the pre-commit lint + coverage gates run. **Never bypass them with `git commit --no-verify`, and never lower a coverage baseline or widen a gate's tolerance to turn a red gate green** — the coverage ratchet only moves up. If a change drops coverage, add tests to raise it (e.g. render-smoke specs). See [.agents/building-and-testing.md](.agents/building-and-testing.md).
|
||||
- **Logging**: Use `github.com/mudler/xlog` (same API as slog)
|
||||
- **Go style**: Prefer `any` over `interface{}`
|
||||
- **Comments**: Explain *why*, not *what*
|
||||
|
||||
@@ -198,6 +198,7 @@ For AI-assisted development, see [`AGENTS.md`](AGENTS.md) (or the equivalent [`C
|
||||
|
||||
- Prefer modern Go idioms — for example, use `any` instead of `interface{}`.
|
||||
- Use [`golangci-lint`](https://golangci-lint.run) to catch common issues before submitting a PR.
|
||||
- Run `make install-hooks` once per clone to enable the pre-commit hook: Go changes run `make lint` + the coverage gate (`make test-coverage-check`); `core/http/react-ui/` changes run the Playwright e2e suite (`make test-ui`). Bypass a single commit with `git commit --no-verify`.
|
||||
- Use [`github.com/mudler/xlog`](https://github.com/mudler/xlog) for logging (same API as `slog`). Do not use `fmt.Println` or the standard `log` package for operational logging.
|
||||
- Use tab indentation for Go files (as defined in `.editorconfig`).
|
||||
|
||||
@@ -265,6 +266,12 @@ The e2e tests run LocalAI in a Docker container and exercise the API:
|
||||
make test-e2e
|
||||
```
|
||||
|
||||
### React UI tests and coverage
|
||||
|
||||
The React UI (`core/http/react-ui/`) is covered by Playwright e2e specs, gated by a **monotonic line-coverage ratchet** (`make test-ui-coverage-check`, run in CI and pre-commit). The metric is non-deterministic — a fast local box reads higher than a slow CI runner for the same code — so a small tolerance is unavoidable.
|
||||
|
||||
**If your change lowers UI coverage, raise it back by adding specs — do not widen the tolerance or hand-lower the baseline.** A *render-smoke* spec (navigate to a page, assert its header is visible) cheaply covers an entire lazy page. See `core/http/react-ui/e2e/page-render-smoke.spec.js` and the full policy in [.agents/building-and-testing.md](.agents/building-and-testing.md#react-ui-coverage).
|
||||
|
||||
### Running E2E container tests
|
||||
|
||||
These tests build a standard LocalAI Docker image and run it with pre-configured model configs to verify that most endpoints work correctly:
|
||||
|
||||
172
Makefile
172
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/crispasr backends/parakeet-cpp backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/rfdetr-cpp backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/vibevoice-cpp backends/localvqe backends/tinygrad backends/sherpa-onnx backends/ds4 backends/ds4-darwin backends/liquid-audio
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -69,10 +69,41 @@ else
|
||||
GORELEASER=$(shell which goreleaser)
|
||||
endif
|
||||
|
||||
TEST_PATHS?=./api/... ./pkg/... ./core/...
|
||||
TEST_PATHS?=./api/... ./pkg/... ./core/... ./backend/go/cloud-proxy/... ./backend/go/local-store/...
|
||||
|
||||
## Coverage output and the committed baseline that CI compares against.
|
||||
## The gate is strict: total coverage must never decrease (no tolerance).
|
||||
## covermode=atomic makes line coverage deterministic regardless of test
|
||||
## ordering or flake retries, so there is no run-to-run jitter to absorb.
|
||||
COVERAGE_DIR?=$(abspath ./coverage)
|
||||
COVERAGE_PROFILE?=$(COVERAGE_DIR)/coverage.out
|
||||
COVERAGE_BASELINE?=coverage-baseline.txt
|
||||
## Coverage is collected one recursive root at a time and merged (see
|
||||
## scripts/run-coverage.sh): passing several recursive roots to a single
|
||||
## ginkgo invocation only keeps one root's coverprofile. Mirrors TEST_PATHS
|
||||
## minus ./api (which doesn't exist).
|
||||
COVERAGE_ROOTS?=./pkg ./core
|
||||
## Build tags for the coverage build. `auth` is required to compile the real
|
||||
## auth implementation and its ~150 `//go:build auth` tests (otherwise they're
|
||||
## invisible and the gate scores auth against a stub). `debug` matches `test`.
|
||||
COVERAGE_TAGS?=debug auth
|
||||
## Coverage is attributed to these packages via --coverpkg, so the in-process
|
||||
## integration suites (COVERAGE_E2E_ROOTS) credit the core/http handlers they
|
||||
## drive over HTTP — not just their own test package.
|
||||
COVERAGE_COVERPKG?=github.com/mudler/LocalAI/core/...,github.com/mudler/LocalAI/pkg/...
|
||||
## In-process integration suites folded into coverage. Run non-recursively
|
||||
## (excludes tests/e2e/distributed, which needs containers) with the mock
|
||||
## backend built by prepare-test. real-models specs need a downloaded model,
|
||||
## so they're filtered out. NOTE: tests/integration is intentionally NOT here —
|
||||
## it needs the local-store backend built (`make backends/local-store`), which
|
||||
## the coverage CI job doesn't do.
|
||||
COVERAGE_E2E_ROOTS?=./tests/e2e
|
||||
COVERAGE_E2E_LABELS?=!real-models
|
||||
## Drop generated protobuf from the denominator (it has no tests by design).
|
||||
COVERAGE_EXCLUDE_RE?=grpc/proto/.*[.]pb[.]go
|
||||
|
||||
|
||||
.PHONY: all test build vendor lint lint-all
|
||||
.PHONY: all test test-coverage test-coverage-baseline test-coverage-check test-ui test-ui-coverage-baseline test-ui-coverage-check install-hooks build vendor lint lint-all
|
||||
|
||||
all: help
|
||||
|
||||
@@ -170,6 +201,36 @@ test: prepare-test
|
||||
OPUS_SHIM_LIBRARY=$(abspath ./pkg/opus/shim/libopusshim.so) \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS)
|
||||
|
||||
## Runs the core suite ($(TEST_PATHS)) with statement-coverage instrumentation
|
||||
## and writes a merged profile to $(COVERAGE_PROFILE). Deliberately omits
|
||||
## --fail-fast so a single failure doesn't truncate the coverage number, and
|
||||
## uses covermode=atomic so the result is deterministic. Prints the total.
|
||||
test-coverage: prepare-test
|
||||
@echo 'Running tests with coverage'
|
||||
GINKGO_TAGS="$(COVERAGE_TAGS)" \
|
||||
COVERAGE_COVERPKG="$(COVERAGE_COVERPKG)" \
|
||||
COVERAGE_E2E_ROOTS="$(COVERAGE_E2E_ROOTS)" \
|
||||
COVERAGE_E2E_LABELS="$(COVERAGE_E2E_LABELS)" \
|
||||
COVERAGE_EXCLUDE_RE='$(COVERAGE_EXCLUDE_RE)' \
|
||||
OPUS_SHIM_LIBRARY=$(abspath ./pkg/opus/shim/libopusshim.so) \
|
||||
scripts/run-coverage.sh $(COVERAGE_DIR) $(COVERAGE_PROFILE) $(TEST_FLAKES) $(COVERAGE_ROOTS)
|
||||
@$(GOCMD) tool cover -html=$(COVERAGE_PROFILE) -o $(COVERAGE_DIR)/coverage.html
|
||||
@$(GOCMD) tool cover -func=$(COVERAGE_PROFILE) | tail -n1
|
||||
|
||||
## Writes the current total coverage to $(COVERAGE_BASELINE). Run this (and
|
||||
## commit the result) whenever a change legitimately raises coverage so the
|
||||
## ratchet moves up. Never lower it by hand.
|
||||
test-coverage-baseline: test-coverage
|
||||
@$(GOCMD) tool cover -func=$(COVERAGE_PROFILE) | awk '/^total:/{gsub(/%/,"",$$NF); print $$NF}' > $(COVERAGE_BASELINE)
|
||||
@echo "Saved coverage baseline: $$(cat $(COVERAGE_BASELINE))%"
|
||||
|
||||
## CI gate: fails if total coverage dropped more than COVERAGE_TOLERANCE
|
||||
## (default 0.5pp) below the committed baseline. A small tolerance absorbs the
|
||||
## run-to-run jitter from the in-process tests/e2e suite folded in via
|
||||
## --coverpkg (timing-dependent which handler lines execute).
|
||||
test-coverage-check: test-coverage
|
||||
@scripts/coverage-check.sh $(COVERAGE_PROFILE) $(COVERAGE_BASELINE)
|
||||
|
||||
########################################################
|
||||
## Lint
|
||||
########################################################
|
||||
@@ -185,12 +246,17 @@ test: prepare-test
|
||||
## everything else automatically, so new packages are scanned by default.
|
||||
LINT_EXCLUDE_DIRS_RE=/(backend/go/(piper|silero-vad|llm)|cmd/launcher)(/|$$)
|
||||
|
||||
## Set LINT_NEW_FROM to a git ref to override .golangci.yml's
|
||||
## new-from-merge-base (origin/master). Useful from a fork clone where
|
||||
## origin/master is stale relative to the canonical repo — the pre-commit
|
||||
## hook passes the resolved upstream ref here so local lint matches CI.
|
||||
LINT_NEW_FROM?=
|
||||
lint:
|
||||
@command -v golangci-lint >/dev/null 2>&1 || { \
|
||||
echo 'golangci-lint not installed. Install: go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest'; \
|
||||
exit 1; \
|
||||
}
|
||||
golangci-lint run $$(go list -e -f '{{.Dir}}' ./... | grep -vE '$(LINT_EXCLUDE_DIRS_RE)')
|
||||
golangci-lint run $(if $(LINT_NEW_FROM),--new-from-merge-base=$(LINT_NEW_FROM),) $$(go list -e -f '{{.Dir}}' ./... | grep -vE '$(LINT_EXCLUDE_DIRS_RE)')
|
||||
|
||||
## Like `lint` but reports every issue, including the pre-existing baseline
|
||||
## that `lint` ignores via .golangci.yml's new-from-merge-base. Use this to
|
||||
@@ -202,6 +268,17 @@ lint-all:
|
||||
}
|
||||
golangci-lint run --new=false --new-from-merge-base= --new-from-rev= $$(go list -e -f '{{.Dir}}' ./... | grep -vE '$(LINT_EXCLUDE_DIRS_RE)')
|
||||
|
||||
########################################################
|
||||
## Git hooks
|
||||
########################################################
|
||||
## Points git at the versioned .githooks/ directory so the pre-commit hook
|
||||
## (lint + coverage gate) runs locally. Run once per clone. Undo with:
|
||||
## `git config --unset core.hooksPath`. Skip a single commit with
|
||||
## `git commit --no-verify`.
|
||||
install-hooks:
|
||||
git config core.hooksPath .githooks
|
||||
@echo 'Installed git hooks: core.hooksPath -> .githooks (pre-commit runs lint + test-coverage-check on Go changes)'
|
||||
|
||||
########################################################
|
||||
## E2E AIO tests (uses standard image with pre-configured models)
|
||||
########################################################
|
||||
@@ -232,13 +309,20 @@ run-e2e-aio: protogen-go
|
||||
@echo 'Running e2e AIO tests'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
|
||||
|
||||
# Distributed architecture e2e (PostgreSQL + NATS via testcontainers).
|
||||
# Includes NatsJWT specs (JWT-enabled NATS). Requires Docker.
|
||||
# VLLMMultinode is excluded here; use test-e2e-vllm-multinode for that.
|
||||
test-e2e-distributed: protogen-go
|
||||
@echo 'Running distributed e2e tests (label Distributed, incl. NatsJWT)'
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter='Distributed && !VLLMMultinode' --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e/distributed
|
||||
|
||||
# vLLM multi-node DP smoke (CPU). Builds local-ai:tests and the
|
||||
# cpu-vllm backend from the current working tree, then drives a
|
||||
# head + headless follower via testcontainers-go and asserts a chat
|
||||
# completion. BuildKit caches both images, so re-runs only rebuild
|
||||
# what changed. The test lives under tests/e2e/distributed and is
|
||||
# selected by the VLLMMultinode label so it doesn't run alongside
|
||||
# the other distributed-suite tests by default.
|
||||
# test-e2e-distributed.
|
||||
test-e2e-vllm-multinode: docker-build-e2e extract-backend-vllm protogen-go
|
||||
@echo 'Running e2e vLLM multi-node DP test'
|
||||
LOCALAI_IMAGE=local-ai \
|
||||
@@ -268,12 +352,13 @@ prepare-e2e:
|
||||
run-e2e-image:
|
||||
docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --name e2e-tests-$(RANDOM) localai-tests
|
||||
|
||||
test-e2e: build-mock-backend prepare-e2e run-e2e-image
|
||||
test-e2e: build-mock-backend build-cloud-proxy-backend prepare-e2e run-e2e-image
|
||||
@echo 'Running e2e tests'
|
||||
BUILD_TYPE=$(BUILD_TYPE) \
|
||||
LOCALAI_API=http://$(E2E_BRIDGE_IP):5390 \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e
|
||||
$(MAKE) clean-mock-backend
|
||||
$(MAKE) clean-cloud-proxy-backend
|
||||
$(MAKE) teardown-e2e
|
||||
docker rmi localai-tests
|
||||
|
||||
@@ -480,6 +565,7 @@ prepare-test-extra: protogen-python
|
||||
$(MAKE) -C backend/python/insightface
|
||||
$(MAKE) -C backend/python/speaker-recognition
|
||||
$(MAKE) -C backend/rust/kokoros kokoros-grpc
|
||||
$(MAKE) -C backend/go/rfdetr-cpp
|
||||
|
||||
test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/transformers test
|
||||
@@ -506,6 +592,7 @@ test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/insightface test
|
||||
$(MAKE) -C backend/python/speaker-recognition test
|
||||
$(MAKE) -C backend/rust/kokoros test
|
||||
$(MAKE) -C backend/go/rfdetr-cpp test
|
||||
|
||||
##
|
||||
## End-to-end gRPC tests that exercise a built backend container image.
|
||||
@@ -911,6 +998,19 @@ test-extra-backend-whisper-transcription: docker-build-whisper
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## Audio transcription wrapper for the parakeet-cpp (parakeet.cpp ggml port)
|
||||
## backend. Mirrors test-extra-backend-whisper-transcription: drives the
|
||||
## AudioTranscription / AudioTranscriptionStream RPCs against a published
|
||||
## Parakeet GGUF using the JFK 11s clip from whisper.cpp's CI samples. Not
|
||||
## part of the default test suite - run explicitly once the pinned model URL
|
||||
## is reachable.
|
||||
test-extra-backend-parakeet-cpp-transcription: docker-build-parakeet-cpp
|
||||
BACKEND_IMAGE=local-ai-backend:parakeet-cpp \
|
||||
BACKEND_TEST_MODEL_URL=https://huggingface.co/mudler/parakeet-cpp-gguf/resolve/main/tdt_ctc-110m-f16.gguf \
|
||||
BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## LocalVQE audio transform (joint AEC + noise suppression + dereverb).
|
||||
## Exercises the audio_transform capability end-to-end: batch transform
|
||||
## of a real WAV fixture and bidi streaming of synthetic silent frames.
|
||||
@@ -1064,10 +1164,13 @@ BACKEND_DS4 = ds4|ds4|.|false|false
|
||||
# Golang backends
|
||||
BACKEND_PIPER = piper|golang|.|false|true
|
||||
BACKEND_LOCAL_STORE = local-store|golang|.|false|true
|
||||
BACKEND_CLOUD_PROXY = cloud-proxy|golang|.|false|true
|
||||
BACKEND_HUGGINGFACE = huggingface|golang|.|false|true
|
||||
BACKEND_SILERO_VAD = silero-vad|golang|.|false|true
|
||||
BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true
|
||||
BACKEND_WHISPER = whisper|golang|.|false|true
|
||||
BACKEND_CRISPASR = crispasr|golang|.|false|true
|
||||
BACKEND_PARAKEET_CPP = parakeet-cpp|golang|.|false|true
|
||||
BACKEND_VOXTRAL = voxtral|golang|.|false|true
|
||||
BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true
|
||||
BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true
|
||||
@@ -1117,6 +1220,7 @@ BACKEND_KOKOROS = kokoros|rust|.|false|true
|
||||
|
||||
# C++ backends (Go wrapper with purego)
|
||||
BACKEND_SAM3_CPP = sam3-cpp|golang|.|false|true
|
||||
BACKEND_RFDETR_CPP = rfdetr-cpp|golang|.|false|true
|
||||
|
||||
# Helper function to build docker image for a backend
|
||||
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
||||
@@ -1149,10 +1253,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_TURBOQUANT)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_DS4)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PIPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CLOUD_PROXY)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CRISPASR)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PARAKEET_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_OPUS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS)))
|
||||
@@ -1195,13 +1302,14 @@ $(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TINYGRAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_RFDETR_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX)))
|
||||
|
||||
# Pattern rule for docker-save targets
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-ds4 docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-crispasr docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-liquid-audio docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-rfdetr-cpp docker-build-qwen3-tts-cpp docker-build-vibevoice-cpp docker-build-localvqe docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx docker-build-cloud-proxy
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
@@ -1213,6 +1321,12 @@ build-mock-backend: protogen-go
|
||||
clean-mock-backend:
|
||||
rm -f tests/e2e/mock-backend/mock-backend
|
||||
|
||||
build-cloud-proxy-backend: protogen-go
|
||||
$(GOCMD) build -o tests/e2e/mock-backend/cloud-proxy ./backend/go/cloud-proxy
|
||||
|
||||
clean-cloud-proxy-backend:
|
||||
rm -f tests/e2e/mock-backend/cloud-proxy
|
||||
|
||||
########################################################
|
||||
### UI E2E Test Server
|
||||
########################################################
|
||||
@@ -1223,6 +1337,50 @@ build-ui-test-server: build-mock-backend react-ui protogen-go
|
||||
test-ui-e2e: build-ui-test-server
|
||||
cd core/http/react-ui && npm install && npx playwright install --with-deps chromium && npx playwright test
|
||||
|
||||
## Optional Playwright worker count for the UI e2e targets below. Pass
|
||||
## UI_TEST_WORKERS=N (e.g. `make test-ui-coverage UI_TEST_WORKERS=20`) to
|
||||
## override Playwright's default (cores/2). Empty by default so Playwright
|
||||
## picks its own worker count.
|
||||
UI_TEST_WORKERS ?=
|
||||
PLAYWRIGHT_WORKERS_FLAG = $(if $(UI_TEST_WORKERS),--workers=$(UI_TEST_WORKERS),)
|
||||
|
||||
## Fast Playwright e2e run used by the pre-commit hook on React UI changes.
|
||||
## Force-rebuilds the (non-instrumented) dist so the suite tests the working
|
||||
## tree — not a stale dist the `react-ui` skip-guard would leave — re-embeds
|
||||
## it into ui-test-server, and runs the specs. Uses the nix-provided browser
|
||||
## when PLAYWRIGHT_CHROMIUM_PATH is set (flake dev shell), else falls back to
|
||||
## downloading it as `test-ui-e2e` does.
|
||||
test-ui: build-mock-backend protogen-go
|
||||
cd core/http/react-ui && bun install && bun run build
|
||||
$(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui
|
||||
cd core/http/react-ui && sh $(CURDIR)/scripts/ensure-playwright-browser.sh && bunx playwright test $(PLAYWRIGHT_WORKERS_FLAG)
|
||||
|
||||
## React UI code coverage from the Playwright e2e suite. Builds a
|
||||
## NON-instrumented bundle with source maps (COVERAGE_V8=true), re-embeds it
|
||||
## into the ui-test-server (the dist is //go:embed'ed at compile time), runs the
|
||||
## Playwright specs which collect native Chromium V8 coverage (PW_V8_COVERAGE=1)
|
||||
## — far cheaper than istanbul's build-time counters (~40% faster end-to-end) —
|
||||
## convert it to istanbul via v8-to-istanbul in the coverage fixture, and write
|
||||
## an nyc report to core/http/react-ui/coverage/. Removes the dist afterwards so
|
||||
## normal builds aren't served source-mapped assets. (The legacy istanbul path
|
||||
## still exists: `bun run build:coverage` + unset PW_V8_COVERAGE.)
|
||||
test-ui-coverage: build-mock-backend protogen-go
|
||||
trap 'rm -rf "$(CURDIR)/core/http/react-ui/dist"' EXIT; \
|
||||
( cd core/http/react-ui && bun install && bun run build:coverage-v8 ) && \
|
||||
$(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui && \
|
||||
( cd core/http/react-ui && rm -rf .nyc_output coverage && \
|
||||
sh $(CURDIR)/scripts/ensure-playwright-browser.sh && \
|
||||
PW_V8_COVERAGE=1 bunx playwright test $(PLAYWRIGHT_WORKERS_FLAG) && bun run coverage:report )
|
||||
|
||||
## UI coverage baseline (committed) and the strict gate that compares against
|
||||
## it — the React mirror of test-coverage-baseline / test-coverage-check.
|
||||
test-ui-coverage-baseline: test-ui-coverage
|
||||
@node -e 'const fs=require("fs");process.stdout.write(String(JSON.parse(fs.readFileSync("core/http/react-ui/coverage/coverage-summary.json")).total.lines.pct))' > core/http/react-ui/coverage-baseline.txt
|
||||
@echo "Saved UI coverage baseline: $$(cat core/http/react-ui/coverage-baseline.txt)% lines"
|
||||
|
||||
test-ui-coverage-check: test-ui-coverage
|
||||
sh $(CURDIR)/scripts/ui-coverage-check.sh core/http/react-ui/coverage/coverage-summary.json core/http/react-ui/coverage-baseline.txt
|
||||
|
||||
test-ui-e2e-docker:
|
||||
docker build -t localai-ui-e2e -f tests/e2e-ui/Dockerfile .
|
||||
docker run --rm localai-ui-e2e
|
||||
|
||||
35
README.md
35
README.md
@@ -31,12 +31,18 @@
|
||||
|
||||
**LocalAI** is the open-source AI engine. Run any model - LLMs, vision, voice, image, video - on any hardware. No GPU required.
|
||||
|
||||
- **Drop-in API compatibility** — OpenAI, Anthropic, ElevenLabs APIs
|
||||
- **36+ backends** — llama.cpp, vLLM, transformers, whisper, diffusers, MLX...
|
||||
- **Any hardware** — NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready** — API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents** — autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first** — your data never leaves your infrastructure
|
||||
**A small core, not a bundle.** Each backend wraps a best-in-class engine (llama.cpp, vLLM, whisper.cpp, stable-diffusion, MLX...) in its own image, pulled only when a model needs it. You install nothing you don't use.
|
||||
|
||||
- **Composable by design**: backends are separate and pulled on demand, so you install only what your model needs
|
||||
- **Open and extensible**: load any model, or build your own backend in any language against an open interface
|
||||
- **Drop-in API compatibility**: OpenAI, Anthropic, and ElevenLabs APIs across every backend
|
||||
- **Any model, any modality**: LLMs, vision, voice, image, and video behind one API
|
||||
- **Any hardware**: NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready**: API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents**: autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first**: your data never leaves your infrastructure
|
||||
|
||||

|
||||
|
||||
Created by [Ettore Di Giacinto](https://github.com/mudler) and maintained by the [LocalAI team](#team).
|
||||
|
||||
@@ -149,8 +155,10 @@ For more details, see the [Getting Started guide](https://localai.io/basics/gett
|
||||
|
||||
## Latest News
|
||||
|
||||
- **April 2026**: [Voice recognition](https://github.com/mudler/LocalAI/pull/9500), [Face recognition, identification & liveness detection](https://github.com/mudler/LocalAI/pull/9480), [Ollama API compatibility](https://github.com/mudler/LocalAI/pull/9284), [Video generation in stable-diffusion.ggml](https://github.com/mudler/LocalAI/pull/9420), [Backend versioning with auto-upgrade](https://github.com/mudler/LocalAI/pull/9315), [Pin models & load-on-demand toggle](https://github.com/mudler/LocalAI/pull/9309), [Universal model importer](https://github.com/mudler/LocalAI/pull/9466), new backends: [sglang](https://github.com/mudler/LocalAI/pull/9359), [ik-llama-cpp](https://github.com/mudler/LocalAI/pull/9326), [TurboQuant](https://github.com/mudler/LocalAI/pull/9355), [sam.cpp](https://github.com/mudler/LocalAI/pull/9288), [Kokoros](https://github.com/mudler/LocalAI/pull/9212), [qwen3tts.cpp](https://github.com/mudler/LocalAI/pull/9316), [tinygrad multimodal](https://github.com/mudler/LocalAI/pull/9364)
|
||||
- **March 2026**: [Agent management](https://github.com/mudler/LocalAI/pull/8820), [New React UI](https://github.com/mudler/LocalAI/pull/8772), [WebRTC](https://github.com/mudler/LocalAI/pull/8790), [MLX-distributed via P2P and RDMA](https://github.com/mudler/LocalAI/pull/8801), [MCP Apps, MCP Client-side](https://github.com/mudler/LocalAI/pull/8947)
|
||||
- **May 2026**: **LocalAI 4.3.0** - `llama.cpp` [prompt cache on by default](https://github.com/mudler/LocalAI/pull/9925) (repeated system prompts collapse from minutes to seconds), [keyless cosign signing of backend OCI images](https://github.com/mudler/LocalAI/pull/9823), [per-API-key + per-user usage attribution](https://github.com/mudler/LocalAI/pull/9920), Distributed v3 with [per-request replica routing](https://github.com/mudler/LocalAI/pull/9968). [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.3.0)
|
||||
- **May 2026**: **LocalAI 4.2.0** - LocalAI sees and hears: [voice recognition](https://github.com/mudler/LocalAI/pull/9500), [face recognition + antispoofing liveness](https://github.com/mudler/LocalAI/pull/9480), speaker diarization. Plus [drop-in Ollama API](https://github.com/mudler/LocalAI/pull/9284), [video generation](https://github.com/mudler/LocalAI/pull/9420), redesigned UI with i18n + admin-configurable branding, vLLM at feature parity with llama.cpp, and 11 new backends. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.2.0)
|
||||
- **April 2026**: **LocalAI 4.1.0** - LocalAI becomes a control tower: distributed cluster mode with VRAM-aware smart routing + autoscaling, multi-user platform with OIDC and API keys, per-user quotas with predictive analytics, in-UI fine-tuning with TRL (auto-export to GGUF), on-the-fly quantization backend, visual pipeline editor. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.1.0)
|
||||
- **March 2026**: **LocalAI 4.0.0** - native agentic orchestration with the new [Agenthub](https://agenthub.localai.io) community hub, full React UI rewrite with Canvas mode, [MCP Apps + client-side](https://github.com/mudler/LocalAI/pull/8947) with tool streaming, [WebRTC realtime audio](https://github.com/mudler/LocalAI/pull/8790), [MLX-distributed](https://github.com/mudler/LocalAI/pull/8801). [Release notes](https://github.com/mudler/LocalAI/releases/tag/v4.0.0)
|
||||
- **February 2026**: [Realtime API for audio-to-audio with tool calling](https://github.com/mudler/LocalAI/pull/6245), [ACE-Step 1.5 support](https://github.com/mudler/LocalAI/pull/8396)
|
||||
- **January 2026**: **LocalAI 3.10.0** — Anthropic API support, Open Responses API, video & image generation (LTX-2), unified GPU backends, tool streaming, Moonshine, Pocket-TTS. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v3.10.0)
|
||||
- **December 2025**: [Dynamic Memory Resource reclaimer](https://github.com/mudler/LocalAI/pull/7583), [Automatic multi-GPU model fitting (llama.cpp)](https://github.com/mudler/LocalAI/pull/7584), [Vibevoice backend](https://github.com/mudler/LocalAI/pull/7494)
|
||||
@@ -236,11 +244,22 @@ A huge thank you to our generous sponsors who support this project covering CI e
|
||||
<a href="https://www.spectrocloud.com/" target="blank">
|
||||
<img height="200" src="https://github.com/user-attachments/assets/72eab1dd-8b93-4fc0-9ade-84db49f24962">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>
|
||||
Past sponsors
|
||||
</summary>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.premai.io/" target="blank">
|
||||
<img height="200" src="https://github.com/mudler/LocalAI/assets/2420543/42e4ca83-661e-4f79-8e46-ae43689683d6"> <br>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
</details>
|
||||
|
||||
### Individual sponsors
|
||||
|
||||
A special thanks to individual sponsors, a full list is on [GitHub](https://github.com/sponsors/mudler) and [buymeacoffee](https://buymeacoffee.com/mudler). Special shout out to [drikster80](https://github.com/drikster80) for being generous. Thank you everyone!
|
||||
|
||||
@@ -37,6 +37,22 @@ service Backend {
|
||||
|
||||
rpc Rerank(RerankRequest) returns (RerankResult) {}
|
||||
|
||||
// TokenClassify runs a token-classification (NER) model on the
|
||||
// supplied text and returns each detected entity span. Used by the
|
||||
// PII redactor's optional NER tier — the regex tier still handles
|
||||
// formatted hits cheaply, while this catches names, locations, and
|
||||
// other unformatted PII that regex misses.
|
||||
rpc TokenClassify(TokenClassifyRequest) returns (TokenClassifyResponse) {}
|
||||
|
||||
// Score evaluates the model's joint log-probability of each
|
||||
// supplied candidate continuation given a shared prompt. The
|
||||
// prompt's KV cache is computed once and reused across candidates.
|
||||
// Used for routing-policy multi-label classification, reranking,
|
||||
// calibrated confidence, and reward-model scoring — any task where
|
||||
// the consumer wants the model's confidence in a pre-specified
|
||||
// continuation rather than a generated one.
|
||||
rpc Score(ScoreRequest) returns (ScoreResponse) {}
|
||||
|
||||
rpc GetMetrics(MetricsRequest) returns (MetricsResponse);
|
||||
|
||||
rpc VAD(VADRequest) returns (VADResponse) {}
|
||||
@@ -68,6 +84,23 @@ service Backend {
|
||||
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
|
||||
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
|
||||
|
||||
// Forward proxies a raw HTTP request to an upstream provider. The
|
||||
// cloud-proxy backend implements this for passthrough-mode model
|
||||
// configs: the client wire format is preserved end-to-end (no
|
||||
// translation through internal proto), which means new provider
|
||||
// fields work the day they ship. Translation-mode proxies use the
|
||||
// standard Predict/PredictStream RPCs instead. Backends that don't
|
||||
// support this return UNIMPLEMENTED.
|
||||
//
|
||||
// The request is bidirectionally streamed so large bodies can flow
|
||||
// without buffering. In practice the first ForwardRequest carries
|
||||
// path, method, headers, and the initial body chunk; subsequent
|
||||
// messages append body chunks. The first ForwardReply carries the
|
||||
// upstream status and response headers; subsequent messages stream
|
||||
// body chunks (SSE frames or chunked transfer). Cancellation of the
|
||||
// gRPC context closes the upstream connection.
|
||||
rpc Forward(stream ForwardRequest) returns (stream ForwardReply) {}
|
||||
|
||||
}
|
||||
|
||||
// Define the empty request
|
||||
@@ -81,6 +114,76 @@ message MetricsResponse {
|
||||
int32 prompt_tokens_processed = 5;
|
||||
}
|
||||
|
||||
// TokenClassifyRequest carries the text to classify plus an optional
|
||||
// score threshold. The transformers backend interprets threshold as
|
||||
// the minimum confidence to include in the response; 0 = include all.
|
||||
message TokenClassifyRequest {
|
||||
string text = 1;
|
||||
float threshold = 2;
|
||||
}
|
||||
|
||||
// TokenClassifyEntity is one detected entity span. Byte offsets are
|
||||
// into the original UTF-8 text — start..end is a half-open range that
|
||||
// addresses the substring corresponding to entity_group.
|
||||
//
|
||||
// entity_group follows HuggingFace's aggregated-tag convention (e.g.
|
||||
// "PER", "LOC", "ORG", or a PII-specific label like "EMAIL" /
|
||||
// "SSN" depending on the model). The redactor's per-pattern action
|
||||
// map keys off this string.
|
||||
message TokenClassifyEntity {
|
||||
string entity_group = 1;
|
||||
int32 start = 2;
|
||||
int32 end = 3;
|
||||
float score = 4;
|
||||
string text = 5;
|
||||
}
|
||||
|
||||
message TokenClassifyResponse {
|
||||
repeated TokenClassifyEntity entities = 1;
|
||||
}
|
||||
|
||||
// ScoreRequest carries one shared prompt and one or more continuations
|
||||
// to score against it. The backend tokenises the prompt once and reuses
|
||||
// the resulting KV cache across all candidates in this request.
|
||||
message ScoreRequest {
|
||||
string prompt = 1;
|
||||
repeated string candidates = 2;
|
||||
// Return per-token logprobs for each candidate when true. Default
|
||||
// false to keep the wire response small; the joint log_prob field
|
||||
// covers the common ranking case.
|
||||
bool include_token_logprobs = 3;
|
||||
// When true, the response also populates length_normalized_log_prob
|
||||
// (joint log-prob divided by candidate token count). Useful when
|
||||
// candidates differ in length and the consumer wants a per-token
|
||||
// measure comparable across them (PMI-style scoring).
|
||||
bool length_normalize = 4;
|
||||
}
|
||||
|
||||
// CandidateScore is one row in the ScoreResponse, matching by index
|
||||
// the candidate in ScoreRequest.candidates.
|
||||
message CandidateScore {
|
||||
// Sum of log P(token_i | prompt, candidate_token_<i) across the
|
||||
// candidate's tokens. The primary ranking signal.
|
||||
double log_prob = 1;
|
||||
// log_prob / num_tokens — populated when length_normalize=true on
|
||||
// the request.
|
||||
double length_normalized_log_prob = 2;
|
||||
// Per-token detail — populated when include_token_logprobs=true.
|
||||
repeated TokenLogProb tokens = 3;
|
||||
// Number of tokens the backend tokenised this candidate into, after
|
||||
// any backend-specific normalisation (e.g. leading-space handling).
|
||||
int32 num_tokens = 4;
|
||||
}
|
||||
|
||||
message TokenLogProb {
|
||||
string token = 1;
|
||||
double log_prob = 2;
|
||||
}
|
||||
|
||||
message ScoreResponse {
|
||||
repeated CandidateScore candidates = 1;
|
||||
}
|
||||
|
||||
message RerankRequest {
|
||||
string query = 1;
|
||||
repeated string documents = 2;
|
||||
@@ -325,6 +428,25 @@ message ModelOptions {
|
||||
// applied verbatim to the backend's engine constructor (e.g. vLLM AsyncEngineArgs).
|
||||
// Unknown keys produce an error at LoadModel time.
|
||||
string EngineArgs = 73;
|
||||
|
||||
// Proxy carries the cloud-proxy backend's per-model configuration.
|
||||
// Empty for non-proxy backends.
|
||||
ProxyOptions Proxy = 74;
|
||||
}
|
||||
|
||||
// ProxyOptions configures the cloud-proxy backend. UpstreamURL and
|
||||
// Mode are always meaningful; Provider only matters in translate mode.
|
||||
// The two api_key_* fields are mutually exclusive and resolved by the
|
||||
// backend at LoadModel — core forwards the references rather than the
|
||||
// plaintext key.
|
||||
message ProxyOptions {
|
||||
string upstream_url = 1;
|
||||
string mode = 2;
|
||||
string provider = 3;
|
||||
string api_key_env = 4;
|
||||
string api_key_file = 5;
|
||||
string upstream_model = 6;
|
||||
int32 request_timeout_seconds = 7;
|
||||
}
|
||||
|
||||
message Result {
|
||||
@@ -415,6 +537,15 @@ message TTSRequest {
|
||||
string dst = 3;
|
||||
string voice = 4;
|
||||
optional string language = 5;
|
||||
// instructions is a free-form, per-request style/voice description (maps to
|
||||
// the OpenAI `instructions` field). Backends that support expressive synthesis
|
||||
// (e.g. Qwen3-TTS CustomVoice/VoiceDesign) prefer this over the static YAML
|
||||
// option when set; backends that don't simply ignore it.
|
||||
optional string instructions = 6;
|
||||
// params carries optional, backend-specific per-request generation parameters
|
||||
// (e.g. Chatterbox exaggeration/cfg_weight/temperature). Values are strings and
|
||||
// coerced by the backend; unset leaves the backend's configured defaults.
|
||||
map<string, string> params = 7;
|
||||
}
|
||||
|
||||
message VADRequest {
|
||||
@@ -1002,3 +1133,32 @@ message QuantizationStopRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
// ForwardHeader is one HTTP header on the request or response. Headers
|
||||
// like Authorization are typically injected by the backend (from the
|
||||
// resolved API key) rather than passed through from the client.
|
||||
message ForwardHeader {
|
||||
string name = 1;
|
||||
string value = 2;
|
||||
}
|
||||
|
||||
// ForwardRequest is a streamed HTTP request to the upstream. First
|
||||
// message carries path/method/headers; subsequent messages carry
|
||||
// body_chunk only. All fields except body_chunk are honoured on the
|
||||
// first message and ignored thereafter.
|
||||
message ForwardRequest {
|
||||
string path = 1; // e.g. "/v1/chat/completions" — appended to the model's upstream_url
|
||||
string method = 2; // usually "POST"
|
||||
repeated ForwardHeader headers = 3;
|
||||
bytes body_chunk = 4;
|
||||
}
|
||||
|
||||
// ForwardReply is a streamed HTTP response from the upstream. First
|
||||
// message carries status/headers; subsequent messages carry body_chunk
|
||||
// only. SSE responses arrive as a sequence of body_chunk frames; the
|
||||
// caller is responsible for any parsing.
|
||||
message ForwardReply {
|
||||
int32 status = 1;
|
||||
repeated ForwardHeader headers = 2;
|
||||
bytes body_chunk = 3;
|
||||
}
|
||||
|
||||
|
||||
1
backend/cpp/ds4/.gitignore
vendored
1
backend/cpp/ds4/.gitignore
vendored
@@ -2,6 +2,7 @@ ds4/
|
||||
build/
|
||||
package/
|
||||
grpc-server
|
||||
ds4-worker
|
||||
*.o
|
||||
backend.pb.cc
|
||||
backend.pb.h
|
||||
|
||||
@@ -60,6 +60,11 @@ elseif(DS4_GPU STREQUAL "cpu")
|
||||
set(DS4_OBJS "${DS4_DIR}/ds4_cpu.o")
|
||||
endif()
|
||||
|
||||
# ds4.c now references ds4_distributed.c (distributed inference was split into
|
||||
# its own translation unit upstream). It is a single GPU-agnostic object shared
|
||||
# by every GPU mode, so link it in regardless of DS4_GPU.
|
||||
list(APPEND DS4_OBJS "${DS4_DIR}/ds4_distributed.o")
|
||||
|
||||
add_executable(${TARGET}
|
||||
grpc-server.cpp
|
||||
dsml_parser.cpp
|
||||
@@ -99,3 +104,36 @@ if(DS4_NATIVE)
|
||||
target_compile_options(${TARGET} PRIVATE -march=native)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# ds4-worker: standalone distributed worker. Links the same ds4 engine objects
|
||||
# (including ds4_distributed.o) but has NO gRPC/protobuf dependency - it speaks
|
||||
# ds4's own TCP transport via ds4_dist_run(). Buildable wherever the engine
|
||||
# objects build, even on hosts without protobuf/grpc dev headers.
|
||||
add_executable(ds4-worker worker_main.c)
|
||||
target_include_directories(ds4-worker PRIVATE ${DS4_DIR})
|
||||
foreach(obj ${DS4_OBJS})
|
||||
target_sources(ds4-worker PRIVATE ${obj})
|
||||
set_source_files_properties(${obj} PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE)
|
||||
endforeach()
|
||||
# worker_main.c is C, but the engine objects built by nvcc (ds4_cuda.o) and the
|
||||
# Metal path (ds4_metal.o, Obj-C++) reference the C++ runtime (libstdc++). Force
|
||||
# the C++ linker driver so those symbols resolve; the C driver would not link
|
||||
# libstdc++ and the CUDA/Metal builds fail with undefined std:: references.
|
||||
set_target_properties(ds4-worker PROPERTIES LINKER_LANGUAGE CXX)
|
||||
target_link_libraries(ds4-worker PRIVATE Threads::Threads m)
|
||||
|
||||
if(DS4_GPU STREQUAL "cuda")
|
||||
target_link_libraries(ds4-worker PRIVATE CUDA::cudart CUDA::cublas)
|
||||
elseif(DS4_GPU STREQUAL "metal")
|
||||
target_link_libraries(ds4-worker PRIVATE ${FOUNDATION_LIB} ${METAL_LIB})
|
||||
elseif(DS4_GPU STREQUAL "cpu")
|
||||
target_compile_definitions(ds4-worker PRIVATE DS4_NO_GPU)
|
||||
endif()
|
||||
|
||||
if(DS4_NATIVE)
|
||||
if(APPLE)
|
||||
target_compile_options(ds4-worker PRIVATE -mcpu=native)
|
||||
else()
|
||||
target_compile_options(ds4-worker PRIVATE -march=native)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# ds4 backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as DS4_VERSION?=ef0a4905d05263df8e63689f2dd1efac618a752c
|
||||
# Upstream pin lives below as DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# llama-cpp / ik-llama-cpp / turboquant convention.
|
||||
|
||||
DS4_VERSION?=ef0a4905d05263df8e63689f2dd1efac618a752c
|
||||
DS4_VERSION?=477c0e82e2699b35a65fd0a1ed6fe66b41087dfe
|
||||
DS4_REPO?=https://github.com/antirez/ds4
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
@@ -18,16 +18,19 @@ UNAME_S := $(shell uname -s)
|
||||
|
||||
CMAKE_ARGS ?= -DCMAKE_BUILD_TYPE=Release
|
||||
|
||||
# ds4_distributed.o is a GPU-agnostic translation unit that ds4.c/ds4_cpu.o now
|
||||
# reference (upstream split distributed inference into its own .c). The same
|
||||
# object is shared by every GPU mode, so it is appended unconditionally below.
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS += -DDS4_GPU=cuda
|
||||
DS4_OBJ_TARGET := ds4.o ds4_cuda.o
|
||||
DS4_OBJ_TARGET := ds4.o ds4_cuda.o ds4_distributed.o
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
CMAKE_ARGS += -DDS4_GPU=metal
|
||||
DS4_OBJ_TARGET := ds4.o ds4_metal.o
|
||||
DS4_OBJ_TARGET := ds4.o ds4_metal.o ds4_distributed.o
|
||||
else
|
||||
# CPU reference path (Linux only - macOS CPU path is broken by VM bug per ds4 README).
|
||||
CMAKE_ARGS += -DDS4_GPU=cpu
|
||||
DS4_OBJ_TARGET := ds4_cpu.o
|
||||
DS4_OBJ_TARGET := ds4_cpu.o ds4_distributed.o
|
||||
endif
|
||||
|
||||
ifneq ($(NATIVE),true)
|
||||
@@ -52,17 +55,18 @@ ds4:
|
||||
# the right per-platform compile flags (Objective-C/Metal on Darwin, nvcc on Linux+CUDA).
|
||||
ds4/ds4.o: ds4
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
+$(MAKE) -C ds4 ds4.o ds4_cuda.o
|
||||
+$(MAKE) -C ds4 ds4.o ds4_cuda.o ds4_distributed.o
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
+$(MAKE) -C ds4 ds4.o ds4_metal.o
|
||||
+$(MAKE) -C ds4 ds4.o ds4_metal.o ds4_distributed.o
|
||||
else
|
||||
+$(MAKE) -C ds4 ds4_cpu.o
|
||||
+$(MAKE) -C ds4 ds4_cpu.o ds4_distributed.o
|
||||
endif
|
||||
|
||||
grpc-server: ds4/ds4.o
|
||||
mkdir -p $(BUILD_DIR)
|
||||
cd $(BUILD_DIR) && cmake $(CMAKE_ARGS) $(CURRENT_MAKEFILE_DIR) && cmake --build . --config Release -j $(JOBS)
|
||||
cp $(BUILD_DIR)/grpc-server grpc-server
|
||||
cp $(BUILD_DIR)/ds4-worker ds4-worker
|
||||
|
||||
package: grpc-server
|
||||
bash package.sh
|
||||
@@ -71,7 +75,7 @@ test:
|
||||
@echo "ds4 backend: e2e coverage at tests/e2e-backends/ (BACKEND_BINARY mode)"
|
||||
|
||||
clean:
|
||||
rm -rf $(BUILD_DIR) grpc-server package
|
||||
rm -rf $(BUILD_DIR) grpc-server ds4-worker package
|
||||
if [ -d ds4 ]; then $(MAKE) -C ds4 clean; fi
|
||||
|
||||
purge: clean
|
||||
|
||||
@@ -23,8 +23,11 @@ extern "C" {
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <climits>
|
||||
#include <csignal>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
@@ -51,6 +54,12 @@ ds4_session *g_session = nullptr;
|
||||
int g_ctx_size = 32768;
|
||||
std::string g_kv_cache_dir; // empty disables disk cache
|
||||
|
||||
// Distributed coordinator state. g_distributed is set true when LoadModel is
|
||||
// given 'ds4_role:coordinator'; generation then waits for the worker route to
|
||||
// form before running. Single-node behavior is unchanged when unset.
|
||||
bool g_distributed = false;
|
||||
int g_route_timeout_sec = 60;
|
||||
|
||||
std::atomic<Server *> g_server{nullptr};
|
||||
|
||||
// Parse a "key:value" option string. Returns empty when no colon.
|
||||
@@ -60,6 +69,77 @@ static std::pair<std::string, std::string> split_option(const std::string &opt)
|
||||
return {opt.substr(0, colon), opt.substr(colon + 1)};
|
||||
}
|
||||
|
||||
// Parse a positive base-10 integer. Returns false (without throwing) on empty,
|
||||
// trailing garbage, non-positive, or overflow - unlike std::stoi.
|
||||
static bool parse_positive_int(const std::string &s, int *out) {
|
||||
if (s.empty()) return false;
|
||||
char *end = nullptr;
|
||||
long v = std::strtol(s.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || v <= 0 || v > INT_MAX) return false;
|
||||
*out = static_cast<int>(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parse a ds4 layer spec "START:END" or "START:output" into the engine's
|
||||
// distributed layer fields. Returns false on malformed input.
|
||||
static bool parse_layers_spec(const std::string &spec, ds4_distributed_layers *out) {
|
||||
auto colon = spec.find(':');
|
||||
if (colon == std::string::npos) return false;
|
||||
std::string lhs = spec.substr(0, colon);
|
||||
std::string rhs = spec.substr(colon + 1);
|
||||
if (lhs.empty() || rhs.empty()) return false;
|
||||
char *end = nullptr;
|
||||
long start = std::strtol(lhs.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || start < 0) return false;
|
||||
out->start = static_cast<uint32_t>(start);
|
||||
out->has_output = false;
|
||||
if (rhs == "output") {
|
||||
out->has_output = true;
|
||||
out->end = out->start; // engine treats has_output as "through final layer"
|
||||
} else {
|
||||
long e = std::strtol(rhs.c_str(), &end, 10);
|
||||
if (!end || *end != '\0' || e < start) return false;
|
||||
out->end = static_cast<uint32_t>(e);
|
||||
}
|
||||
out->set = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
// When acting as a distributed coordinator, block until the worker route
|
||||
// covers all layers (ds4_session_distributed_route_ready == 1) or the timeout
|
||||
// elapses. Returns an empty string on success, or an error message to return
|
||||
// to the client. No-op when not distributed.
|
||||
//
|
||||
// Takes the g_engine_mu lock by reference and RELEASES it during each poll
|
||||
// sleep. The wait can span up to g_route_timeout_sec seconds while workers
|
||||
// connect; holding g_engine_mu the whole time would block the Status/Health
|
||||
// readiness probes (they also lock g_engine_mu), making LocalAI's loader treat
|
||||
// a still-starting worker as hung.
|
||||
static std::string wait_route_ready(std::unique_lock<std::mutex> &lock) {
|
||||
if (!g_distributed) return "";
|
||||
char err[256] = {0};
|
||||
const int deadline_polls = g_route_timeout_sec * 10; // 100ms per poll
|
||||
for (int i = 0; i <= deadline_polls; ++i) {
|
||||
int ready = ds4_session_distributed_route_ready(g_session, err, sizeof(err));
|
||||
if (ready == 1) return "";
|
||||
if (ready < 0) {
|
||||
return std::string("ds4 distributed route error: ") +
|
||||
(err[0] ? err : "unknown");
|
||||
}
|
||||
// Release the lock while sleeping so Status/Health and other RPCs can
|
||||
// interleave during worker startup.
|
||||
lock.unlock();
|
||||
struct timespec ts = {0, 100L * 1000L * 1000L}; // 100ms
|
||||
nanosleep(&ts, nullptr);
|
||||
lock.lock();
|
||||
// A concurrent Free() may have torn down the engine while we slept.
|
||||
if (!g_engine || !g_session) {
|
||||
return "ds4: model unloaded while waiting for distributed route";
|
||||
}
|
||||
}
|
||||
return "ds4 distributed route incomplete: workers not connected (layers uncovered)";
|
||||
}
|
||||
|
||||
static void append_token_text(ds4_engine *engine, int token, std::string &out) {
|
||||
size_t len = 0;
|
||||
const char *text = ds4_token_text(engine, token, &len);
|
||||
@@ -377,6 +457,11 @@ public:
|
||||
backend::Result *result) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
|
||||
// Reset distributed state so a model swap (a second LoadModel without
|
||||
// ds4_role) doesn't inherit a stale coordinator configuration.
|
||||
g_distributed = false;
|
||||
g_route_timeout_sec = 60;
|
||||
|
||||
if (g_engine) {
|
||||
if (g_session) { ds4_session_free(g_session); g_session = nullptr; }
|
||||
ds4_engine_close(g_engine);
|
||||
@@ -394,12 +479,23 @@ public:
|
||||
std::string mtp_path;
|
||||
int mtp_draft = 0;
|
||||
float mtp_margin = 3.0f;
|
||||
std::string ds4_role, ds4_layers, ds4_listen;
|
||||
for (const auto &opt : request->options()) {
|
||||
auto [k, v] = split_option(opt);
|
||||
if (k == "mtp_path") mtp_path = v;
|
||||
else if (k == "mtp_draft") mtp_draft = std::stoi(v);
|
||||
else if (k == "mtp_margin") mtp_margin = std::stof(v);
|
||||
else if (k == "kv_cache_dir") g_kv_cache_dir = v;
|
||||
else if (k == "ds4_role") ds4_role = v;
|
||||
else if (k == "ds4_layers") ds4_layers = v;
|
||||
else if (k == "ds4_listen") ds4_listen = v;
|
||||
else if (k == "ds4_route_timeout") {
|
||||
if (!parse_positive_int(v, &g_route_timeout_sec)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_route_timeout must be a positive integer");
|
||||
return GStatus::OK;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
g_kv_cache.SetDir(g_kv_cache_dir);
|
||||
@@ -422,6 +518,49 @@ public:
|
||||
opt.backend = DS4_BACKEND_CUDA;
|
||||
#endif
|
||||
|
||||
// Coordinator wiring. 'ds4_role:coordinator' enables layer-split
|
||||
// distributed inference: this process listens on ds4_listen and owns
|
||||
// the ds4_layers slice; workers dial in (see `local-ai worker
|
||||
// ds4-distributed`). Absent ds4_role => unchanged single-node path.
|
||||
// Must be static: opt.distributed.listen_host is a const char* the
|
||||
// engine retains past this call, so it cannot point at a local that
|
||||
// goes out of scope (otherwise a future "simplify to local" refactor
|
||||
// reintroduces a dangling pointer).
|
||||
static std::string s_listen_host;
|
||||
if (ds4_role == "coordinator") {
|
||||
if (ds4_layers.empty() || ds4_listen.empty()) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_role:coordinator requires ds4_layers and ds4_listen");
|
||||
return GStatus::OK;
|
||||
}
|
||||
// host:port for IPv4/hostname; IPv6 literals are unsupported (the
|
||||
// first colon would split inside the address).
|
||||
auto host_port = split_option(ds4_listen); // "host:port" -> {host, port}
|
||||
if (host_port.second.empty()) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_listen must be host:port");
|
||||
return GStatus::OK;
|
||||
}
|
||||
int listen_port = 0;
|
||||
if (!parse_positive_int(host_port.second, &listen_port)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: ds4_listen port must be a positive integer");
|
||||
return GStatus::OK;
|
||||
}
|
||||
ds4_distributed_layers layers = {};
|
||||
if (!parse_layers_spec(ds4_layers, &layers)) {
|
||||
result->set_success(false);
|
||||
result->set_message("ds4: invalid ds4_layers (want START:END or START:output)");
|
||||
return GStatus::OK;
|
||||
}
|
||||
s_listen_host = host_port.first;
|
||||
opt.distributed.role = DS4_DISTRIBUTED_COORDINATOR;
|
||||
opt.distributed.layers = layers;
|
||||
opt.distributed.listen_host = s_listen_host.c_str();
|
||||
opt.distributed.listen_port = listen_port;
|
||||
g_distributed = true;
|
||||
}
|
||||
|
||||
int rc = ds4_engine_open(&g_engine, &opt);
|
||||
if (rc != 0 || !g_engine) {
|
||||
result->set_success(false);
|
||||
@@ -458,10 +597,13 @@ public:
|
||||
|
||||
GStatus Predict(ServerContext *, const backend::PredictOptions *request,
|
||||
backend::Reply *reply) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
std::unique_lock<std::mutex> lock(g_engine_mu);
|
||||
if (!g_engine || !g_session) {
|
||||
return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
|
||||
}
|
||||
if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
|
||||
return GStatus(StatusCode::UNAVAILABLE, route_err);
|
||||
}
|
||||
ds4_tokens prompt = {};
|
||||
build_prompt(g_engine, request, &prompt);
|
||||
int n_predict = request->tokens() > 0 ? request->tokens() : 256;
|
||||
@@ -554,10 +696,13 @@ public:
|
||||
|
||||
GStatus PredictStream(ServerContext *, const backend::PredictOptions *request,
|
||||
ServerWriter<backend::Reply> *writer) override {
|
||||
std::lock_guard<std::mutex> lock(g_engine_mu);
|
||||
std::unique_lock<std::mutex> lock(g_engine_mu);
|
||||
if (!g_engine || !g_session) {
|
||||
return GStatus(StatusCode::FAILED_PRECONDITION, "ds4: model not loaded");
|
||||
}
|
||||
if (std::string route_err = wait_route_ready(lock); !route_err.empty()) {
|
||||
return GStatus(StatusCode::UNAVAILABLE, route_err);
|
||||
}
|
||||
ds4_tokens prompt = {};
|
||||
build_prompt(g_engine, request, &prompt);
|
||||
int n_predict = request->tokens() > 0 ? request->tokens() : 256;
|
||||
|
||||
@@ -5,7 +5,8 @@ REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
cp -avf "$CURDIR/grpc-server" "$CURDIR/package/"
|
||||
cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/ds4-worker" "$CURDIR/package/"
|
||||
cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
UNAME_S=$(uname -s)
|
||||
if [ "$UNAME_S" = "Darwin" ]; then
|
||||
|
||||
126
backend/cpp/ds4/worker_main.c
Normal file
126
backend/cpp/ds4/worker_main.c
Normal file
@@ -0,0 +1,126 @@
|
||||
// ds4-worker: standalone distributed worker for the LocalAI ds4 backend.
|
||||
//
|
||||
// A ds4 distributed worker owns a slice of the model's transformer layers,
|
||||
// dials the coordinator, and serves activations for its slice. It does NOT
|
||||
// speak backend.proto - it speaks ds4's own TCP transport via ds4_dist_run().
|
||||
// This binary is intentionally minimal (no HTTP/web/kvstore/linenoise): it
|
||||
// only needs the engine objects + ds4_distributed.o, which the backend already
|
||||
// builds. It is launched by `local-ai worker ds4-distributed`.
|
||||
//
|
||||
// Usage:
|
||||
// ds4-worker --role worker --model <gguf> --layers 20:output \
|
||||
// --coordinator <host> <port> [--cpu|--cuda|--metal] [-c CTX] [-t N]
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <signal.h>
|
||||
#include <limits.h>
|
||||
|
||||
#include "ds4.h"
|
||||
#include "ds4_distributed.h"
|
||||
|
||||
static const char *need_arg(int *i, int argc, char **argv, const char *flag) {
|
||||
if (*i + 1 >= argc) {
|
||||
fprintf(stderr, "ds4-worker: missing value for %s\n", flag);
|
||||
exit(2);
|
||||
}
|
||||
return argv[++(*i)];
|
||||
}
|
||||
|
||||
static int parse_int_arg(const char *s, const char *flag) {
|
||||
char *end = NULL;
|
||||
long v = strtol(s, &end, 10);
|
||||
if (!s[0] || *end || v <= 0 || v > INT_MAX) {
|
||||
fprintf(stderr, "ds4-worker: invalid value for %s: %s\n", flag, s);
|
||||
exit(2);
|
||||
}
|
||||
return (int)v;
|
||||
}
|
||||
|
||||
static ds4_backend default_backend(void) {
|
||||
#if defined(DS4_NO_GPU)
|
||||
return DS4_BACKEND_CPU;
|
||||
#elif defined(__APPLE__)
|
||||
return DS4_BACKEND_METAL;
|
||||
#else
|
||||
return DS4_BACKEND_CUDA;
|
||||
#endif
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
signal(SIGPIPE, SIG_IGN);
|
||||
|
||||
ds4_engine_options opt = {0};
|
||||
opt.backend = default_backend();
|
||||
int ctx_size = 32768;
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
const char *arg = argv[i];
|
||||
if (!strcmp(arg, "-h") || !strcmp(arg, "--help")) {
|
||||
fprintf(stdout, "ds4-worker: standalone ds4 distributed worker\n");
|
||||
ds4_dist_usage(stdout);
|
||||
fprintf(stdout, " -m, --model PATH model GGUF (the worker loads only its --layers slice)\n");
|
||||
fprintf(stdout, " -c, --ctx N context size (default 32768)\n");
|
||||
fprintf(stdout, " -t, --threads N CPU threads\n");
|
||||
fprintf(stdout, " --cpu|--cuda|--metal backend override\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
char dist_err[256] = {0};
|
||||
ds4_dist_cli_parse_result dist_parse =
|
||||
ds4_dist_parse_cli_arg(arg, &i, argc, argv, &opt.distributed,
|
||||
dist_err, sizeof(dist_err));
|
||||
if (dist_parse == DS4_DIST_CLI_ERROR) {
|
||||
fprintf(stderr, "ds4-worker: %s\n",
|
||||
dist_err[0] ? dist_err : "invalid distributed option");
|
||||
return 2;
|
||||
}
|
||||
if (dist_parse == DS4_DIST_CLI_MATCHED) continue;
|
||||
|
||||
if (!strcmp(arg, "-m") || !strcmp(arg, "--model")) {
|
||||
opt.model_path = need_arg(&i, argc, argv, arg);
|
||||
} else if (!strcmp(arg, "-c") || !strcmp(arg, "--ctx")) {
|
||||
ctx_size = parse_int_arg(need_arg(&i, argc, argv, arg), arg);
|
||||
} else if (!strcmp(arg, "-t") || !strcmp(arg, "--threads")) {
|
||||
opt.n_threads = parse_int_arg(need_arg(&i, argc, argv, arg), arg);
|
||||
} else if (!strcmp(arg, "--cpu")) {
|
||||
opt.backend = DS4_BACKEND_CPU;
|
||||
} else if (!strcmp(arg, "--cuda")) {
|
||||
opt.backend = DS4_BACKEND_CUDA;
|
||||
} else if (!strcmp(arg, "--metal")) {
|
||||
opt.backend = DS4_BACKEND_METAL;
|
||||
} else {
|
||||
fprintf(stderr, "ds4-worker: unknown option: %s\n", arg);
|
||||
return 2;
|
||||
}
|
||||
}
|
||||
|
||||
if (opt.distributed.role != DS4_DISTRIBUTED_WORKER) {
|
||||
fprintf(stderr, "ds4-worker: --role worker is required\n");
|
||||
return 2;
|
||||
}
|
||||
if (!opt.model_path) {
|
||||
fprintf(stderr, "ds4-worker: --model is required\n");
|
||||
return 2;
|
||||
}
|
||||
|
||||
char prep_err[256] = {0};
|
||||
if (ds4_dist_prepare_engine_options(&opt.distributed, &opt,
|
||||
prep_err, sizeof(prep_err)) != 0) {
|
||||
fprintf(stderr, "ds4-worker: %s\n", prep_err);
|
||||
return 2;
|
||||
}
|
||||
|
||||
ds4_engine *engine = NULL;
|
||||
if (ds4_engine_open(&engine, &opt) != 0 || !engine) {
|
||||
fprintf(stderr, "ds4-worker: failed to open engine\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
ds4_dist_generation_options gen = {0};
|
||||
gen.ctx_size = ctx_size;
|
||||
int rc = ds4_dist_run(engine, &opt.distributed, &gen);
|
||||
ds4_engine_close(engine);
|
||||
return rc;
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=3e573cfea6e0a332eff822ffbdb1dd3b112e9051
|
||||
IK_LLAMA_VERSION?=1520eda980564241434b791ce2bbbd128c4be9ea
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=0253fb21f595246f54c192fe8332f34173be251b
|
||||
LLAMA_VERSION?=7c158fbb4aec1bdc9c81d6ca0e785139f4826fae
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
#include <regex>
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
@@ -121,6 +122,40 @@ static std::string base64_encode_bytes(const unsigned char* data, size_t len) {
|
||||
|
||||
bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model
|
||||
|
||||
// Score bypasses the slot loop (see the comment on Score below) so it
|
||||
// must not run concurrently with any slot-loop RPC. These counters
|
||||
// are a defence-in-depth tripwire — ModelConfig.Validate already
|
||||
// rejects llama-cpp configs that mix score with chat/completion/
|
||||
// embeddings, so a healthy deployment never trips them. seq_cst is
|
||||
// load-bearing for the increment-then-check pattern below.
|
||||
static std::atomic<int> slot_loop_inflight{0};
|
||||
static std::atomic<int> score_inflight{0};
|
||||
|
||||
// Increment-then-check, not check-then-increment: two simultaneous
|
||||
// racers both observe the other's increment and both abort cleanly.
|
||||
// Reversed, both could see zero and proceed.
|
||||
struct conflict_guard {
|
||||
std::atomic<int>& self;
|
||||
conflict_guard(const char* rpc, std::atomic<int>& self_, std::atomic<int>& other, const char* other_name)
|
||||
: self(self_) {
|
||||
self.fetch_add(1, std::memory_order_seq_cst);
|
||||
int o = other.load(std::memory_order_seq_cst);
|
||||
if (o > 0) {
|
||||
fprintf(stderr,
|
||||
"FATAL: %s called with %s=%d. The llama-cpp backend cannot "
|
||||
"service Score and slot-loop RPCs concurrently — Score "
|
||||
"bypasses the slot loop and races the llama_context. Bind "
|
||||
"Score-using features to a model dedicated to scoring "
|
||||
"(known_usecases: [score] with no chat/completion/embeddings).\n",
|
||||
rpc, other_name, o);
|
||||
std::abort();
|
||||
}
|
||||
}
|
||||
~conflict_guard() {
|
||||
self.fetch_sub(1, std::memory_order_seq_cst);
|
||||
}
|
||||
};
|
||||
|
||||
static std::function<void(int)> shutdown_handler;
|
||||
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
|
||||
|
||||
@@ -517,16 +552,33 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.warmup = true;
|
||||
// no_op_offload: disable host tensor op offload (default: false)
|
||||
params.no_op_offload = false;
|
||||
// kv_unified: enable unified KV cache (default: false)
|
||||
params.kv_unified = false;
|
||||
// n_ctx_checkpoints: max context checkpoints per slot (default: 8)
|
||||
params.n_ctx_checkpoints = 8;
|
||||
|
||||
// llama memory fit fails if we don't provide a buffer for tensor overrides
|
||||
const size_t ntbo = llama_max_tensor_buft_overrides();
|
||||
while (params.tensor_buft_overrides.size() < ntbo) {
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
// kv_unified: enable unified KV cache. Upstream's server auto-enables this
|
||||
// when the slot count is auto (-np <0), bumping n_parallel to 4 alongside.
|
||||
// LocalAI keeps n_parallel=1 by default, which would skip that auto path
|
||||
// and leave kv_unified=false. We flip the default to true here so the
|
||||
// server-side prompt cache (cache_idle_slots) is actually usable on the
|
||||
// single-slot path that LocalAI ships with: without it, idle slots are
|
||||
// never persisted across requests and the prompt cache is dead weight.
|
||||
// Users can opt out with `options: [ "kv_unified:false" ]`.
|
||||
params.kv_unified = true;
|
||||
// n_ctx_checkpoints: max context checkpoints per slot. Match upstream's
|
||||
// default (32); the previous LocalAI-specific 8 was unnecessarily tight
|
||||
// and limits partial-prefix recovery without a clear memory rationale.
|
||||
params.n_ctx_checkpoints = 32;
|
||||
// cache_idle_slots: save and clear idle slot KV to the prompt cache on
|
||||
// task switch. Upstream default is true; the server auto-disables it if
|
||||
// kv_unified=false or cache_ram_mib=0, so flipping kv_unified above is
|
||||
// what actually unlocks it.
|
||||
params.cache_idle_slots = true;
|
||||
// checkpoint_min_step: minimum spacing between context checkpoints in
|
||||
// tokens (0 disables the minimum). Match upstream's default (256). This
|
||||
// field was renamed from `checkpoint_every_nt` in llama.cpp; the semantics
|
||||
// also shifted from a fixed cadence to a minimum spacing. The turboquant
|
||||
// fork branched before the field existed, so skip it on the legacy path
|
||||
// (LOCALAI_LEGACY_LLAMA_CPP_SPEC is injected by patch-grpc-server.sh).
|
||||
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
params.checkpoint_min_step = 256;
|
||||
#endif
|
||||
|
||||
// decode options. Options are in form optname:optvale, or if booleans only optname.
|
||||
for (int i = 0; i < request->options_size(); i++) {
|
||||
@@ -685,10 +737,44 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
try {
|
||||
params.n_ctx_checkpoints = std::stoi(optval_str);
|
||||
} catch (const std::exception& e) {
|
||||
// If conversion fails, keep default value (8)
|
||||
// If conversion fails, keep default value (32)
|
||||
}
|
||||
}
|
||||
|
||||
// --- server-side idle-slot prompt cache toggle (upstream --cache-idle-slots) ---
|
||||
// Saves the slot's KV state into the host-side prompt cache on task
|
||||
// switch so a later request with the same prefix can warm-load it.
|
||||
// Auto-disabled by the server if kv_unified=false or cache_ram=0.
|
||||
} else if (!strcmp(optname, "cache_idle_slots") || !strcmp(optname, "idle_slots_cache")) {
|
||||
if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") {
|
||||
params.cache_idle_slots = true;
|
||||
} else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") {
|
||||
params.cache_idle_slots = false;
|
||||
}
|
||||
|
||||
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
// --- minimum context-checkpoint spacing (upstream -cms / --checkpoint-min-step) ---
|
||||
// 0 disables the minimum-spacing gate. Old option names (`checkpoint_every_nt`,
|
||||
// `checkpoint_every_n_tokens`) are kept as aliases for backward compatibility
|
||||
// with existing user configs: upstream renamed the field and shifted its
|
||||
// semantics from a fixed cadence to a minimum spacing.
|
||||
//
|
||||
// Gated out for the turboquant fork, which lacks common_params::
|
||||
// checkpoint_min_step. The leading `}` closing the cache_idle_slots
|
||||
// branch is removed with this block; the next `} else if` (n_ubatch)
|
||||
// then closes cache_idle_slots, so braces stay balanced under both
|
||||
// preprocessor branches.
|
||||
} else if (!strcmp(optname, "checkpoint_min_step") || !strcmp(optname, "checkpoint_min_spacing") ||
|
||||
!strcmp(optname, "checkpoint_every_nt") || !strcmp(optname, "checkpoint_every_n_tokens")) {
|
||||
if (optval != NULL) {
|
||||
try {
|
||||
params.checkpoint_min_step = std::stoi(optval_str);
|
||||
} catch (const std::exception& e) {
|
||||
// If conversion fails, keep default value (256)
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// --- physical batch size (upstream -ub / --ubatch-size) ---
|
||||
// Note: line ~482 already aliases n_ubatch to n_batch as a default; this
|
||||
// option lets users decouple the two (useful for embeddings/rerank).
|
||||
@@ -1081,6 +1167,26 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
params.kv_overrides.back().key[0] = 0;
|
||||
}
|
||||
|
||||
// tensor_buft_overrides sentinel termination (mirrors upstream common/arg.cpp).
|
||||
// Real entries are pushed during option parsing; here we pad/terminate so the
|
||||
// model loader sees back().pattern == nullptr (GGML_ASSERT at common.cpp:1543)
|
||||
// and so llama_params_fit has the placeholder slots it requires.
|
||||
{
|
||||
const size_t ntbo = llama_max_tensor_buft_overrides();
|
||||
while (params.tensor_buft_overrides.size() < ntbo) {
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
}
|
||||
// The draft tensor_buft_overrides are only populated under the modern
|
||||
// (post-#22838) layout, whose population code is itself gated by
|
||||
// LOCALAI_LEGACY_LLAMA_CPP_SPEC above. The turboquant fork lacks
|
||||
// common_params_speculative::draft entirely, so skip the sentinel there too.
|
||||
#ifndef LOCALAI_LEGACY_LLAMA_CPP_SPEC
|
||||
if (!params.speculative.draft.tensor_buft_overrides.empty()) {
|
||||
params.speculative.draft.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
#endif
|
||||
|
||||
// TODO: Add yarn
|
||||
|
||||
if (!request->tensorsplit().empty()) {
|
||||
@@ -1399,6 +1505,7 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("PredictStream", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
|
||||
@@ -1837,6 +1944,17 @@ public:
|
||||
body_json["chat_template_kwargs"]["enable_thinking"] = (et_it->second == "true");
|
||||
}
|
||||
|
||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
||||
// from enable_thinking which those templates ignore.
|
||||
auto re_it = metadata.find("reasoning_effort");
|
||||
if (re_it != metadata.end() && !re_it->second.empty()) {
|
||||
if (!body_json.contains("chat_template_kwargs")) {
|
||||
body_json["chat_template_kwargs"] = json::object();
|
||||
}
|
||||
body_json["chat_template_kwargs"]["reasoning_effort"] = re_it->second;
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
@@ -2097,7 +2215,15 @@ public:
|
||||
// content element — attaching to both would duplicate the first
|
||||
// token since oaicompat_msg_diffs is the same for both.
|
||||
json first_res_json = first_result->to_json();
|
||||
if (first_res_json.is_array()) {
|
||||
// Upstream llama.cpp (ggml-org/llama.cpp#23884) now emits an initial
|
||||
// "begin" partial whose to_json() returns null, used only to signal the
|
||||
// HTTP layer to flush 200 status headers before any token. gRPC has no
|
||||
// such concept, so there is nothing to emit — the real tokens arrive in
|
||||
// the loop below. Feeding this null into build_reply_from_json would
|
||||
// throw (uncaught) and surface as a generic RPC error.
|
||||
if (first_res_json.is_null()) {
|
||||
// skip the begin-of-stream marker
|
||||
} else if (first_res_json.is_array()) {
|
||||
for (const auto & res : first_res_json) {
|
||||
auto reply = build_reply_from_json(res, first_result.get());
|
||||
// Skip chat deltas for role-init elements (have "role" in
|
||||
@@ -2127,7 +2253,10 @@ public:
|
||||
}
|
||||
|
||||
json res_json = result->to_json();
|
||||
if (res_json.is_array()) {
|
||||
if (res_json.is_null()) {
|
||||
// begin-of-stream marker (see note above) — nothing to emit
|
||||
continue;
|
||||
} else if (res_json.is_array()) {
|
||||
for (const auto & res : res_json) {
|
||||
auto reply = build_reply_from_json(res, result.get());
|
||||
bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
|
||||
@@ -2158,6 +2287,7 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("Predict", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json data = parse_options(true, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
data["stream"] = false;
|
||||
@@ -2618,6 +2748,17 @@ public:
|
||||
body_json["chat_template_kwargs"]["enable_thinking"] = (predict_et_it->second == "true");
|
||||
}
|
||||
|
||||
// Pass reasoning_effort via chat_template_kwargs too: the lever
|
||||
// jinja templates like gpt-oss (Harmony) / LFM2.5 read, distinct
|
||||
// from enable_thinking which those templates ignore.
|
||||
auto predict_re_it = predict_metadata.find("reasoning_effort");
|
||||
if (predict_re_it != predict_metadata.end() && !predict_re_it->second.empty()) {
|
||||
if (!body_json.contains("chat_template_kwargs")) {
|
||||
body_json["chat_template_kwargs"] = json::object();
|
||||
}
|
||||
body_json["chat_template_kwargs"]["reasoning_effort"] = predict_re_it->second;
|
||||
}
|
||||
|
||||
// Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.)
|
||||
SRV_DBG("[CONVERSATION DEBUG] Predict: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());
|
||||
|
||||
@@ -2916,6 +3057,7 @@ public:
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("Embedding", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
|
||||
|
||||
body["stream"] = false;
|
||||
@@ -3023,6 +3165,8 @@ public:
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
|
||||
}
|
||||
|
||||
conflict_guard guard("Rerank", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
|
||||
// Create and queue the task
|
||||
auto rd = ctx_server.get_response_reader();
|
||||
{
|
||||
@@ -3095,12 +3239,218 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
// Score returns the model's joint log-probability of each candidate
|
||||
// continuation given a shared prompt.
|
||||
//
|
||||
// WHY bypass the slot/task queue: upstream server_context exposes
|
||||
// get_llama_context as "main thread only" and the slot loop's
|
||||
// update_slots() owns the context whenever a task is in flight.
|
||||
// No public synchronization primitive is available — so Score is
|
||||
// unsafe to call concurrently with active generation through this
|
||||
// backend. In practice routing-classifier calls happen before the
|
||||
// request is routed to a generation backend, so the model used
|
||||
// for Score is typically idle. Concurrent Score calls are
|
||||
// serialised by a local mutex; KV-cache state is isolated behind
|
||||
// a dedicated sequence ID cleared between candidates.
|
||||
//
|
||||
// A patch to server-context.cpp that adds SERVER_TASK_TYPE_SCORE
|
||||
// and routes scoring through the slot loop would be the correct
|
||||
// long-term fix; tracked as a follow-up.
|
||||
//
|
||||
// Perf TODO (measured: ~450 ms warm for 3 candidates on Arch-
|
||||
// Router-1.5B Q4_K_M + Intel SYCL): the current loop re-decodes
|
||||
// `prompt + candidate` from scratch for every candidate, throwing
|
||||
// away the prompt's KV cache between iterations. A smarter
|
||||
// version would:
|
||||
// 1. Decode just the prompt once into score_seq_id.
|
||||
// 2. Snapshot/cp that sequence (llama_memory_seq_cp) into a
|
||||
// per-candidate sequence id.
|
||||
// 3. For each candidate, decode only its tokens onto the copy
|
||||
// (continuing from the saved prompt state), read logits.
|
||||
// 4. llama_memory_seq_rm the copy.
|
||||
// Estimated speedup: 3-candidate calls 450 ms -> ~150-200 ms,
|
||||
// 6-candidate calls 630 ms -> ~220 ms. Single source-file change,
|
||||
// no proto / Go-side changes needed. Worth doing once routing is
|
||||
// wired into the middleware and Score is on the hot path of every
|
||||
// chat request.
|
||||
grpc::Status Score(ServerContext* context, const backend::ScoreRequest* request, backend::ScoreResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
if (request->candidates_size() == 0) {
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "candidates must be non-empty");
|
||||
}
|
||||
|
||||
// Tripwire against the slot loop. Acquired before score_mutex
|
||||
// so it fires even when this Score is queued behind another.
|
||||
conflict_guard guard("Score", score_inflight, slot_loop_inflight, "slot_loop_inflight");
|
||||
|
||||
// Serialise concurrent Score calls. The slot loop is still
|
||||
// free to race with us — see the class comment above.
|
||||
static std::mutex score_mutex;
|
||||
std::lock_guard<std::mutex> score_lock(score_mutex);
|
||||
|
||||
llama_context * lctx = ctx_server.get_llama_context();
|
||||
if (lctx == nullptr) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "llama context unavailable (sleeping?)");
|
||||
}
|
||||
const llama_vocab * vocab = ctx_server.impl->vocab;
|
||||
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
|
||||
const int32_t n_ctx = llama_n_ctx(lctx);
|
||||
llama_memory_t mem = llama_get_memory(lctx);
|
||||
|
||||
// The KV-cache is sized to seq_to_stream.size() at load
|
||||
// (typically equal to n_slots, often 1). Sequence IDs must
|
||||
// be in [0, n_seq_max), so we can't pick a high-value
|
||||
// "private" ID — we have to share with the slot. We clear
|
||||
// the cache before AND after each candidate to keep
|
||||
// scoring isolated from whatever state the slot held, and
|
||||
// the static mutex above guarantees no other Score call is
|
||||
// racing in the meantime. The slot loop is still free to
|
||||
// race (see comment on this method) — Score must not run
|
||||
// concurrently with generation through this backend.
|
||||
const llama_seq_id score_seq_id = 0;
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
|
||||
// Tokenize the shared prompt once with add_special=true so
|
||||
// BOS is prepended when the model requires it. parse_special
|
||||
// keeps chat-template markers in the prompt intact.
|
||||
const std::string prompt = request->prompt();
|
||||
std::vector<llama_token> prompt_tokens = common_tokenize(vocab, prompt, /*add_special=*/true, /*parse_special=*/true);
|
||||
const int32_t prompt_len = (int32_t) prompt_tokens.size();
|
||||
|
||||
for (int ci = 0; ci < request->candidates_size(); ci++) {
|
||||
const std::string & candidate_text = request->candidates(ci);
|
||||
|
||||
// Re-tokenize prompt + candidate as a single string. BPE
|
||||
// merges across the boundary can shift the tokenization
|
||||
// versus tokenize(prompt) ++ tokenize(candidate), so we
|
||||
// find the divergence point against prompt_tokens.
|
||||
std::vector<llama_token> full_tokens = common_tokenize(vocab, prompt + candidate_text, /*add_special=*/true, /*parse_special=*/true);
|
||||
int32_t divergence = prompt_len;
|
||||
const int32_t min_len = std::min<int32_t>(prompt_len, (int32_t) full_tokens.size());
|
||||
for (int32_t i = 0; i < min_len; i++) {
|
||||
if (prompt_tokens[i] != full_tokens[i]) {
|
||||
divergence = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const int32_t cand_len = (int32_t) full_tokens.size() - divergence;
|
||||
backend::CandidateScore * cs = response->add_candidates();
|
||||
cs->set_num_tokens(cand_len);
|
||||
if (cand_len <= 0) {
|
||||
cs->set_log_prob(0.0);
|
||||
if (request->length_normalize()) {
|
||||
cs->set_length_normalized_log_prob(0.0);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (divergence < 1) {
|
||||
// Need at least one prior token (typically BOS) to
|
||||
// predict the first candidate token's logit. Tokeniser
|
||||
// models without BOS + an empty prompt fall in here.
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
|
||||
"Score: prompt produced no leading tokens; need at least one (e.g. BOS) to predict candidate");
|
||||
}
|
||||
if ((int32_t) full_tokens.size() > n_ctx) {
|
||||
return grpc::Status(grpc::StatusCode::OUT_OF_RANGE,
|
||||
"Score: prompt+candidate exceeds context size (got " +
|
||||
std::to_string(full_tokens.size()) + ", n_ctx=" + std::to_string(n_ctx) + ")");
|
||||
}
|
||||
|
||||
// Build a batch covering the entire prompt+candidate. We
|
||||
// need logits at (divergence-1) onward — those are the
|
||||
// predictions for each candidate token.
|
||||
llama_batch batch = llama_batch_init((int32_t) full_tokens.size(), 0, 1);
|
||||
for (int32_t i = 0; i < (int32_t) full_tokens.size(); i++) {
|
||||
batch.token[i] = full_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = score_seq_id;
|
||||
// logits[i] is "do we want the prediction *for the
|
||||
// next token*, computed from this position?"
|
||||
// We want predictions for candidate tokens at
|
||||
// positions divergence .. full_tokens.size()-1, which
|
||||
// come from logits at positions (divergence-1) ..
|
||||
// (full_tokens.size()-2).
|
||||
bool need_logit = (i >= divergence - 1) && (i < (int32_t) full_tokens.size() - 1);
|
||||
batch.logits[i] = need_logit ? 1 : 0;
|
||||
}
|
||||
batch.n_tokens = (int32_t) full_tokens.size();
|
||||
|
||||
// Decode the batch. If decode fails (e.g. KV slot
|
||||
// exhaustion), surface as INTERNAL — the caller will
|
||||
// typically fall back to a sampling-based classifier.
|
||||
int decode_err = llama_decode(lctx, batch);
|
||||
if (decode_err != 0) {
|
||||
llama_batch_free(batch);
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL,
|
||||
"llama_decode failed during Score: " + std::to_string(decode_err));
|
||||
}
|
||||
|
||||
// Sum log-probabilities of the actual candidate tokens.
|
||||
double total_log_prob = 0.0;
|
||||
for (int32_t k = 0; k < cand_len; k++) {
|
||||
// The k-th candidate token sits at full_tokens index
|
||||
// (divergence + k). Its predicting logit is at batch
|
||||
// position (divergence + k - 1).
|
||||
int32_t logit_pos = divergence + k - 1;
|
||||
const float * logits = llama_get_logits_ith(lctx, logit_pos);
|
||||
if (logits == nullptr) {
|
||||
llama_batch_free(batch);
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL,
|
||||
"llama_get_logits_ith returned null at position " + std::to_string(logit_pos));
|
||||
}
|
||||
llama_token target_token = full_tokens[divergence + k];
|
||||
|
||||
// Compute log_softmax(logits)[target_token] with the
|
||||
// max-subtraction stability trick.
|
||||
float max_logit = logits[0];
|
||||
for (int32_t v = 1; v < n_vocab; v++) {
|
||||
if (logits[v] > max_logit) max_logit = logits[v];
|
||||
}
|
||||
double sum_exp = 0.0;
|
||||
for (int32_t v = 0; v < n_vocab; v++) {
|
||||
sum_exp += std::exp((double)(logits[v] - max_logit));
|
||||
}
|
||||
double token_log_prob = (double)(logits[target_token] - max_logit) - std::log(sum_exp);
|
||||
total_log_prob += token_log_prob;
|
||||
|
||||
if (request->include_token_logprobs()) {
|
||||
backend::TokenLogProb * tlp = cs->add_tokens();
|
||||
std::string piece = common_token_to_piece(lctx, target_token);
|
||||
tlp->set_token(piece);
|
||||
tlp->set_log_prob(token_log_prob);
|
||||
}
|
||||
}
|
||||
|
||||
cs->set_log_prob(total_log_prob);
|
||||
if (request->length_normalize() && cand_len > 0) {
|
||||
cs->set_length_normalized_log_prob(total_log_prob / (double) cand_len);
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
// Drop this candidate's KV-cache contribution so the next
|
||||
// candidate starts from a clean state. Without this, the
|
||||
// next decode would conflict at positions 0..N-1 for our
|
||||
// sequence ID.
|
||||
llama_memory_seq_rm(mem, score_seq_id, -1, -1);
|
||||
}
|
||||
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
conflict_guard guard("TokenizeString", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
|
||||
body["stream"] = false;
|
||||
|
||||
@@ -3122,6 +3472,8 @@ public:
|
||||
|
||||
grpc::Status GetMetrics(ServerContext* /*context*/, const backend::MetricsRequest* /*request*/, backend::MetricsResponse* response) override {
|
||||
|
||||
conflict_guard guard("GetMetrics", slot_loop_inflight, score_inflight, "score_inflight");
|
||||
|
||||
// request slots data using task queue
|
||||
auto rd = ctx_server.get_response_reader();
|
||||
int task_id = rd.queue_tasks.get_new_id();
|
||||
|
||||
@@ -124,8 +124,11 @@ fi
|
||||
# 5. Define LOCALAI_LEGACY_LLAMA_CPP_SPEC at the top of the file so the
|
||||
# grpc-server option parser skips the new option-handler blocks (ngram_mod,
|
||||
# ngram_map_k, ngram_map_k4v, ngram_cache, draft.cache_type_*, draft.cpuparams*,
|
||||
# draft.tensor_buft_overrides) introduced for the post-#22838 layout. Those
|
||||
# blocks reference struct fields that simply do not exist in the fork.
|
||||
# draft.tensor_buft_overrides) introduced for the post-#22838 layout, the
|
||||
# draft.tensor_buft_overrides sentinel termination, and the
|
||||
# common_params::checkpoint_min_step default/option (added with the
|
||||
# 35c9b1f3 bump). Those blocks reference struct fields that simply do not
|
||||
# exist in the fork.
|
||||
if grep -q '^#define LOCALAI_LEGACY_LLAMA_CPP_SPEC' "$SRC"; then
|
||||
echo "==> $SRC already defines LOCALAI_LEGACY_LLAMA_CPP_SPEC, skipping"
|
||||
else
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# acestep.cpp version
|
||||
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
|
||||
ACESTEP_CPP_VERSION?=e0c8d75a672fca5684c88c68dbf6d12f58754258
|
||||
ACESTEP_CPP_VERSION?=ed53caf164e4492a5620b2e3f2264629cf66da24
|
||||
SO_TARGET?=libgoacestepcpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -22,12 +22,11 @@
|
||||
#include <vector>
|
||||
|
||||
// Global model contexts (loaded once, reused across requests)
|
||||
static DiTGGML g_dit = {};
|
||||
static DiTGGMLConfig g_dit_cfg;
|
||||
static VAEGGML g_vae = {};
|
||||
static bool g_dit_loaded = false;
|
||||
static bool g_vae_loaded = false;
|
||||
static bool g_is_turbo = false;
|
||||
static DiTGGML g_dit = {};
|
||||
static VAEGGML g_vae = {};
|
||||
static bool g_dit_loaded = false;
|
||||
static bool g_vae_loaded = false;
|
||||
static bool g_is_turbo = false;
|
||||
|
||||
// Silence latent [15000, 64] — read once from DiT GGUF
|
||||
static std::vector<float> g_silence_full;
|
||||
@@ -72,10 +71,9 @@ int load_model(const char * lm_model_path, const char * text_encoder_path,
|
||||
g_text_enc_path = text_encoder_path;
|
||||
g_dit_path = dit_model_path;
|
||||
|
||||
// Load DiT model
|
||||
// Load DiT model (backend init + config are handled inside dit_ggml_load)
|
||||
fprintf(stderr, "[acestep-cpp] Loading DiT from %s\n", dit_model_path);
|
||||
dit_ggml_init_backend(&g_dit);
|
||||
if (!dit_ggml_load(&g_dit, dit_model_path, g_dit_cfg, nullptr, 0.0f)) {
|
||||
if (!dit_ggml_load(&g_dit, dit_model_path)) {
|
||||
fprintf(stderr, "[acestep-cpp] FATAL: failed to load DiT from %s\n", dit_model_path);
|
||||
return 1;
|
||||
}
|
||||
@@ -149,16 +147,16 @@ int generate_music(const char * caption, const char * lyrics, int bpm,
|
||||
|
||||
// Compute T (latent frames at 25Hz)
|
||||
int T = (int)(duration * FRAMES_PER_SECOND);
|
||||
T = ((T + g_dit_cfg.patch_size - 1) / g_dit_cfg.patch_size) * g_dit_cfg.patch_size;
|
||||
int S = T / g_dit_cfg.patch_size;
|
||||
T = ((T + g_dit.cfg.patch_size - 1) / g_dit.cfg.patch_size) * g_dit.cfg.patch_size;
|
||||
int S = T / g_dit.cfg.patch_size;
|
||||
|
||||
if (T > 15000) {
|
||||
fprintf(stderr, "[acestep-cpp] ERROR: T=%d exceeds max 15000\n", T);
|
||||
return 2;
|
||||
}
|
||||
|
||||
int Oc = g_dit_cfg.out_channels; // 64
|
||||
int ctx_ch = g_dit_cfg.in_channels - Oc; // 128
|
||||
int Oc = g_dit.cfg.out_channels; // 64
|
||||
int ctx_ch = g_dit.cfg.in_channels - Oc; // 128
|
||||
|
||||
fprintf(stderr, "[acestep-cpp] T=%d, S=%d, duration=%.1fs, seed=%d\n", T, S, duration, seed);
|
||||
|
||||
@@ -191,9 +189,8 @@ int generate_music(const char * caption, const char * lyrics, int bpm,
|
||||
|
||||
fprintf(stderr, "[acestep-cpp] caption: %d tokens, lyrics: %d tokens\n", S_text, S_lyric);
|
||||
|
||||
// 4. Text encoder forward
|
||||
// 4. Text encoder forward (backend init handled inside qwen3_load_text_encoder)
|
||||
Qwen3GGML text_enc = {};
|
||||
qwen3_init_backend(&text_enc);
|
||||
if (!qwen3_load_text_encoder(&text_enc, g_text_enc_path.c_str())) {
|
||||
fprintf(stderr, "[acestep-cpp] FATAL: failed to load text encoder\n");
|
||||
return 4;
|
||||
@@ -209,9 +206,8 @@ int generate_music(const char * caption, const char * lyrics, int bpm,
|
||||
std::vector<float> lyric_embed(H_text * S_lyric);
|
||||
qwen3_embed_lookup(&text_enc, lyric_ids.data(), S_lyric, lyric_embed.data());
|
||||
|
||||
// 6. Condition encoder
|
||||
// 6. Condition encoder (backend init handled inside cond_ggml_load)
|
||||
CondGGML cond = {};
|
||||
cond_ggml_init_backend(&cond);
|
||||
if (!cond_ggml_load(&cond, g_dit_path.c_str())) {
|
||||
fprintf(stderr, "[acestep-cpp] FATAL: failed to load condition encoder\n");
|
||||
qwen3_free(&text_enc);
|
||||
|
||||
12
backend/go/cloud-proxy/Makefile
Normal file
12
backend/go/cloud-proxy/Makefile
Normal file
@@ -0,0 +1,12 @@
|
||||
GOCMD=go
|
||||
|
||||
cloud-proxy:
|
||||
CGO_ENABLED=0 $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o cloud-proxy ./
|
||||
|
||||
package:
|
||||
bash package.sh
|
||||
|
||||
build: cloud-proxy package
|
||||
|
||||
clean:
|
||||
rm -f cloud-proxy
|
||||
16
backend/go/cloud-proxy/cloud_proxy_suite_test.go
Normal file
16
backend/go/cloud-proxy/cloud_proxy_suite_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Ginkgo bootstrap. The other Test* functions in this package use
|
||||
// raw testing.T and run independently; they coexist with Ginkgo
|
||||
// specs registered via Describe / Context.
|
||||
func TestCloudProxySpecs(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "cloud-proxy specs")
|
||||
}
|
||||
39
backend/go/cloud-proxy/main.go
Normal file
39
backend/go/cloud-proxy/main.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package main
|
||||
|
||||
// cloud-proxy is a LocalAI backend that forwards request traffic to an
|
||||
// external HTTP provider (OpenAI, Anthropic, etc.). Two modes:
|
||||
//
|
||||
// - passthrough: serves the Forward RPC; the client wire format is
|
||||
// preserved end-to-end, no translation.
|
||||
// - translate: serves Predict/PredictStream; the backend converts
|
||||
// internal proto to the provider's wire format. (Phases 5–6.)
|
||||
//
|
||||
// LoadModel reads UpstreamURL/Mode/Provider/key references from
|
||||
// ProxyOptions and resolves the API key once at load time.
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/xlog"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
var addr = flag.String("addr", "localhost:50051", "the address to listen on")
|
||||
|
||||
func main() {
|
||||
// xlog's default handler emits ANSI color codes; that's fine for an
|
||||
// interactive shell but unreadable when the backend's stdout is
|
||||
// captured by LocalAI and tee'd to a log file. Force plain text when
|
||||
// LOCALAI_LOG_FORMAT is unset and stdout isn't a terminal.
|
||||
format := os.Getenv("LOCALAI_LOG_FORMAT")
|
||||
if format == "" && !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
format = xlog.TextFormat
|
||||
}
|
||||
xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(os.Getenv("LOCALAI_LOG_LEVEL")), format))
|
||||
flag.Parse()
|
||||
if err := grpc.StartServer(*addr, NewCloudProxy()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
13
backend/go/cloud-proxy/package.sh
Executable file
13
backend/go/cloud-proxy/package.sh
Executable file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the cloud-proxy binary into the package dir for the
|
||||
# final Dockerfile stage. Mirrors backend/go/local-store/package.sh —
|
||||
# no extra runtime libs needed since the backend is pure Go.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
mkdir -p $CURDIR/package
|
||||
cp -avf $CURDIR/cloud-proxy $CURDIR/package/
|
||||
cp -rfv $CURDIR/run.sh $CURDIR/package/
|
||||
325
backend/go/cloud-proxy/passthrough_edge_test.go
Normal file
325
backend/go/cloud-proxy/passthrough_edge_test.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("composeURL", func() {
|
||||
// Upstream URL convention: gallery configs put the canonical path
|
||||
// in upstream_url, so per-request Path is ignored. A bare-host
|
||||
// upstream_url accepts the per-request path.
|
||||
DescribeTable("path resolution",
|
||||
func(upstream, reqPath, want string) {
|
||||
got, err := composeURL(upstream, reqPath)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(got).To(Equal(want))
|
||||
},
|
||||
Entry("full path wins", "https://api.openai.com/v1/chat/completions", "/v1/something-else", "https://api.openai.com/v1/chat/completions"),
|
||||
Entry("bare host accepts path", "https://api.openai.com", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"),
|
||||
Entry("root slash treated as bare", "https://api.openai.com/", "/v1/chat/completions", "https://api.openai.com/v1/chat/completions"),
|
||||
Entry("bare host + empty path", "https://api.openai.com", "", "https://api.openai.com"),
|
||||
)
|
||||
|
||||
It("returns an error on invalid upstream URL", func() {
|
||||
_, err := composeURL("://garbage", "")
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("applyAuthHeader", func() {
|
||||
It("sets x-api-key and anthropic-version for Anthropic, no Authorization", func() {
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
applyAuthHeader(req, providerAnthropic, "ant-key")
|
||||
Expect(req.Header.Get("x-api-key")).To(Equal("ant-key"))
|
||||
Expect(req.Header.Get("anthropic-version")).NotTo(BeEmpty())
|
||||
Expect(req.Header.Get("Authorization")).To(BeEmpty(), "Authorization must not leak on Anthropic backend")
|
||||
})
|
||||
|
||||
It("sets Bearer Authorization for OpenAI, no x-api-key", func() {
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
applyAuthHeader(req, providerOpenAI, "sk-key")
|
||||
Expect(req.Header.Get("Authorization")).To(Equal("Bearer sk-key"))
|
||||
Expect(req.Header.Get("x-api-key")).To(BeEmpty(), "x-api-key must not leak on OpenAI backend")
|
||||
})
|
||||
|
||||
It("defaults to Bearer when provider is empty", func() {
|
||||
// Passthrough mode often has provider == "" because the operator
|
||||
// doesn't claim a specific upstream wire format. Most providers
|
||||
// (including OpenAI-compatible ones) accept Bearer, so default to it.
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
applyAuthHeader(req, "", "some-key")
|
||||
Expect(req.Header.Get("Authorization")).To(Equal("Bearer some-key"))
|
||||
})
|
||||
|
||||
It("preserves an existing anthropic-version header", func() {
|
||||
// If the client supplied anthropic-version (rare but legitimate
|
||||
// for an upstream pinned to a specific date), the proxy must not
|
||||
// clobber it.
|
||||
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||
req.Header.Set("anthropic-version", "2024-10-01")
|
||||
applyAuthHeader(req, providerAnthropic, "k")
|
||||
Expect(req.Header.Get("anthropic-version")).To(Equal("2024-10-01"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("isHopByHopHeader", func() {
|
||||
DescribeTable("hop-by-hop classification",
|
||||
func(header string, want bool) {
|
||||
Expect(isHopByHopHeader(header)).To(Equal(want))
|
||||
},
|
||||
Entry("Connection is hop-by-hop", "Connection", true),
|
||||
Entry("Keep-Alive is hop-by-hop", "Keep-Alive", true),
|
||||
Entry("Proxy-Connection is hop-by-hop", "Proxy-Connection", true),
|
||||
Entry("Transfer-Encoding is hop-by-hop", "Transfer-Encoding", true),
|
||||
Entry("TE is hop-by-hop", "TE", true),
|
||||
Entry("Trailer is hop-by-hop", "Trailer", true),
|
||||
Entry("Upgrade is hop-by-hop", "Upgrade", true),
|
||||
Entry("Host is hop-by-hop", "Host", true),
|
||||
Entry("Content-Length is hop-by-hop", "Content-Length", true),
|
||||
// Case-insensitive — RFC 7230 doesn't constrain header case.
|
||||
Entry("lowercase connection is hop-by-hop", "connection", true),
|
||||
Entry("uppercase HOST is hop-by-hop", "HOST", true),
|
||||
// Non hop-by-hop — must NOT be stripped.
|
||||
Entry("Authorization is end-to-end", "Authorization", false),
|
||||
Entry("Content-Type is end-to-end", "Content-Type", false),
|
||||
Entry("Accept is end-to-end", "Accept", false),
|
||||
Entry("X-Custom is end-to-end", "X-Custom", false),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("Forward", func() {
|
||||
It("strips hop-by-hop and Connection headers before upstream, preserves custom headers", func() {
|
||||
gotConnection := make(chan string, 1)
|
||||
gotXCustom := make(chan string, 1)
|
||||
gotHost := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotConnection <- r.Header.Get("Connection")
|
||||
gotXCustom <- r.Header.Get("X-Custom")
|
||||
gotHost <- r.Header.Get("Host")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-hopbyhop"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{
|
||||
{Name: "Connection", Value: "keep-alive"},
|
||||
{Name: "Host", Value: "spoofed.example.com"},
|
||||
{Name: "X-Custom", Value: "preserved"},
|
||||
},
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
_, _ = stream.Recv()
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
Expect(<-gotConnection).To(BeEmpty(), "Connection must not leak to upstream")
|
||||
Expect(<-gotHost).NotTo(Equal("spoofed.example.com"), "Host header must not be spoofed through")
|
||||
Expect(<-gotXCustom).To(Equal("preserved"), "X-Custom header must survive")
|
||||
})
|
||||
|
||||
It("replaces caller-supplied Authorization with the configured key", func() {
|
||||
// The proxy must overwrite a client-supplied Authorization header
|
||||
// so a downstream caller can't smuggle stale or wrong credentials.
|
||||
gotAuth := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuth <- r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
GinkgoT().Setenv("CLOUD_PROXY_AUTH_REPLACE_KEY", "sk-real")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_AUTH_REPLACE_KEY",
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-replaces-auth"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{
|
||||
// Client-supplied Authorization with the wrong scheme / key.
|
||||
{Name: "Authorization", Value: "Basic Zm9vOmJhcg=="},
|
||||
},
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
_, _ = stream.Recv()
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
Expect(<-gotAuth).To(Equal("Bearer sk-real"), "caller-supplied Basic header must be replaced")
|
||||
})
|
||||
|
||||
It("refuses to follow upstream redirects and never leaks the key to the redirect target", func() {
|
||||
// A 3xx from the configured upstream means misconfiguration or a
|
||||
// hijacked/spoofed host. Following it would replay the request —
|
||||
// and the injected API key — to the Location host. Anthropic's
|
||||
// x-api-key is NOT stripped by Go on cross-host redirects, so this
|
||||
// would be a credential leak. The proxy must refuse the redirect.
|
||||
sinkHit := make(chan string, 1)
|
||||
sink := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sinkHit <- r.Header.Get("x-api-key")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer sink.Close()
|
||||
|
||||
redirector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, sink.URL, http.StatusFound)
|
||||
}))
|
||||
defer redirector.Close()
|
||||
|
||||
GinkgoT().Setenv("CLOUD_PROXY_REDIRECT_KEY", "ant-secret")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: redirector.URL,
|
||||
Mode: modePassthrough,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_REDIRECT_KEY",
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-no-redirect"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/messages",
|
||||
Method: "POST",
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
|
||||
// Drain the stream; a refused redirect surfaces as a non-EOF error.
|
||||
var streamErr error
|
||||
for {
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
streamErr = err
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(streamErr).To(HaveOccurred(), "refused redirect must surface as an error")
|
||||
Expect(sinkHit).NotTo(Receive(), "the redirect target must never be contacted")
|
||||
})
|
||||
|
||||
It("handles concurrent calls without interference", func() {
|
||||
// CloudProxy explicitly omits base.SingleThread — independent
|
||||
// Forward streams must not block each other or leak state.
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(body)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
},
|
||||
})).To(Succeed())
|
||||
addr := "test://forward-concurrent"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
|
||||
const N = 8
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
stream, err := c.Forward(context.Background())
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
payload := "request-" + string(rune('A'+idx))
|
||||
if err := stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
BodyChunk: []byte(payload),
|
||||
}); err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
_ = stream.CloseSend()
|
||||
_, _ = stream.Recv()
|
||||
var body []byte
|
||||
for {
|
||||
r, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
body = append(body, r.GetBodyChunk()...)
|
||||
}
|
||||
if string(body) != payload {
|
||||
errs <- &echoMismatch{want: payload, got: string(body)}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
var collected []error
|
||||
for err := range errs {
|
||||
collected = append(collected, err)
|
||||
}
|
||||
Expect(collected).To(BeEmpty(), "no concurrent Forward call should fail")
|
||||
})
|
||||
})
|
||||
|
||||
type echoMismatch struct{ want, got string }
|
||||
|
||||
func (e *echoMismatch) Error() string {
|
||||
return "echo mismatch: want " + strconv.Quote(e.want) + " got " + strconv.Quote(e.got)
|
||||
}
|
||||
508
backend/go/cloud-proxy/provider_anthropic.go
Normal file
508
backend/go/cloud-proxy/provider_anthropic.go
Normal file
@@ -0,0 +1,508 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// Anthropic Messages API wire-format types. Narrowed to what translate
|
||||
// mode preserves through the Reply proto: text + tool_use blocks +
|
||||
// usage tokens. Image blocks, prompt caching, metadata, and stop
|
||||
// sequence metadata are not modelled — passthrough mode covers those.
|
||||
//
|
||||
// Notable differences from OpenAI:
|
||||
// - max_tokens is REQUIRED. Anthropic 400s without it.
|
||||
// - Roles are user/assistant only — system messages move to a
|
||||
// top-level `system` string field.
|
||||
// - Streaming SSE uses event: lines alongside data: lines. The
|
||||
// events we care about: content_block_start (carries tool_use
|
||||
// init: id + name), content_block_delta (text_delta with text;
|
||||
// input_json_delta with partial_json for tool arguments), and
|
||||
// message_stop (terminates the stream). Others are ignored.
|
||||
|
||||
type anthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int32 `json:"max_tokens"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []anthropicMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []anthropicTool `json:"tools,omitempty"`
|
||||
ToolChoice *anthropicToolChoice `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
// Content is `any` because Anthropic accepts a bare string OR a
|
||||
// list of content blocks. Use the string form for plain user/
|
||||
// assistant turns; switch to []anthropicContentBlock when the
|
||||
// turn needs tool_use (assistant) or tool_result (user) blocks.
|
||||
type anthropicMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type anthropicTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"`
|
||||
}
|
||||
|
||||
// anthropicToolChoice mirrors the four shapes Anthropic accepts:
|
||||
// {"type":"auto"} | {"type":"any"} | {"type":"tool","name":"X"} |
|
||||
// {"type":"none"} (newer models). OpenAI's "auto"/"none"/
|
||||
// "required"/{"function":{"name":"X"}} all map here.
|
||||
type anthropicToolChoice struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// anthropicContentBlock is the union shape used both for response
|
||||
// blocks (text/tool_use we read off the wire) and outbound request
|
||||
// blocks (tool_use/tool_result we emit in the conversation history).
|
||||
// Anthropic encodes tool calls inline rather than as a separate field,
|
||||
// so we walk Content[] looking for type=="tool_use" on responses and
|
||||
// produce equivalent blocks when serialising prior-turn tool calls.
|
||||
type anthropicContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
// Tool-result block fields. tool_result uses `content` (not
|
||||
// `text`) and pairs with `tool_use_id`; modelling them as
|
||||
// distinct fields avoids ambiguity at marshal time.
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
ResultContent string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Content []anthropicContentBlock `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Usage *anthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// anthropicStreamEvent is the union shape used for every event type we
|
||||
// process. Type discriminates; only the matching fields are populated.
|
||||
// content_block_start carries ContentBlock (with id/name for tool_use);
|
||||
// content_block_delta carries Delta (text or partial_json).
|
||||
type anthropicStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Index int `json:"index,omitempty"`
|
||||
ContentBlock *anthropicContentBlock `json:"content_block,omitempty"`
|
||||
Delta *anthropicStreamDelta `json:"delta,omitempty"`
|
||||
Message *anthropicResponse `json:"message,omitempty"`
|
||||
Usage *anthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicStreamDelta struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
}
|
||||
|
||||
// Anthropic requires max_tokens. If the caller didn't set it, use a
|
||||
// generous-but-bounded default so the request doesn't 400.
|
||||
const anthropicDefaultMaxTokens int32 = 4096
|
||||
|
||||
const anthropicToolChoiceNone = "none"
|
||||
|
||||
// Reused JSON-Schema defaults for malformed inputs. Anthropic requires
|
||||
// input_schema to be a JSON object and tool_use.input to be a JSON
|
||||
// object; clients that omit them must not 400 the entire request.
|
||||
var (
|
||||
emptyJSONObject = json.RawMessage(`{}`)
|
||||
emptyObjectSchema = json.RawMessage(`{"type":"object","properties":{}}`)
|
||||
)
|
||||
|
||||
func buildAnthropicRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) {
|
||||
req := anthropicRequest{
|
||||
Model: modelName(cfg, opts),
|
||||
MaxTokens: opts.GetTokens(),
|
||||
Stream: stream,
|
||||
StopSequences: opts.GetStopPrompts(),
|
||||
}
|
||||
if req.MaxTokens <= 0 {
|
||||
req.MaxTokens = anthropicDefaultMaxTokens
|
||||
}
|
||||
// Newer Anthropic models 400 when both temperature and top_p are
|
||||
// set ("`temperature` and `top_p` cannot both be specified for
|
||||
// this model. Please use only one.") even though their docs only
|
||||
// "recommend" picking one. The OpenAI-compatible chat UI almost
|
||||
// always sends both with default values, so prefer temperature
|
||||
// and drop top_p when both are present.
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
} else if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
|
||||
req.Tools = convertOpenAITools(opts.GetTools())
|
||||
req.ToolChoice = convertOpenAIToolChoice(opts.GetToolChoice())
|
||||
// Anthropic rejects tool_choice without tools and older models
|
||||
// don't accept {"type":"none"} — collapse to a no-tools request.
|
||||
if req.ToolChoice != nil && req.ToolChoice.Type == anthropicToolChoiceNone {
|
||||
req.Tools, req.ToolChoice = nil, nil
|
||||
}
|
||||
|
||||
var systemParts []string
|
||||
for _, m := range opts.GetMessages() {
|
||||
role := m.GetRole()
|
||||
if role == "system" {
|
||||
if c := m.GetContent(); c != "" {
|
||||
systemParts = append(systemParts, c)
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch role {
|
||||
case "user":
|
||||
req.Messages = append(req.Messages, anthropicMessage{
|
||||
Role: "user",
|
||||
Content: m.GetContent(),
|
||||
})
|
||||
case "assistant":
|
||||
if blocks := assistantBlocks(m); blocks != nil {
|
||||
req.Messages = append(req.Messages, anthropicMessage{Role: "assistant", Content: blocks})
|
||||
continue
|
||||
}
|
||||
req.Messages = append(req.Messages, anthropicMessage{
|
||||
Role: "assistant",
|
||||
Content: m.GetContent(),
|
||||
})
|
||||
case "tool", "function":
|
||||
req.Messages = appendToolResult(req.Messages, anthropicContentBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: m.GetToolCallId(),
|
||||
ResultContent: m.GetContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
req.System = strings.Join(systemParts, "\n\n")
|
||||
|
||||
if len(req.Messages) == 0 && opts.GetPrompt() != "" {
|
||||
req.Messages = []anthropicMessage{{Role: "user", Content: opts.GetPrompt()}}
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
// appendToolResult appends a tool_result block as a user message,
|
||||
// merging into a preceding user message that already carries blocks.
|
||||
// Anthropic concatenates consecutive same-role messages on its end,
|
||||
// but explicit merging keeps the body smaller and the conversation
|
||||
// strictly alternating — which some upstream filters require.
|
||||
func appendToolResult(msgs []anthropicMessage, block anthropicContentBlock) []anthropicMessage {
|
||||
if n := len(msgs); n > 0 && msgs[n-1].Role == "user" {
|
||||
if existing, ok := msgs[n-1].Content.([]anthropicContentBlock); ok {
|
||||
msgs[n-1].Content = append(existing, block)
|
||||
return msgs
|
||||
}
|
||||
}
|
||||
return append(msgs, anthropicMessage{
|
||||
Role: "user",
|
||||
Content: []anthropicContentBlock{block},
|
||||
})
|
||||
}
|
||||
|
||||
func convertOpenAITools(toolsJSON string) []anthropicTool {
|
||||
if toolsJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var raw []openAITool
|
||||
if err := json.Unmarshal([]byte(toolsJSON), &raw); err != nil {
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unparseable tools JSON, dropping", "error", err)
|
||||
return nil
|
||||
}
|
||||
tools := make([]anthropicTool, 0, len(raw))
|
||||
for _, t := range raw {
|
||||
if t.Function.Name == "" {
|
||||
continue
|
||||
}
|
||||
schema := t.Function.Parameters
|
||||
if len(schema) == 0 {
|
||||
schema = emptyObjectSchema
|
||||
}
|
||||
tools = append(tools, anthropicTool{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
InputSchema: schema,
|
||||
})
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// convertOpenAIToolChoice accepts the spec form
|
||||
// ({type:function, function:{name:X}}) and the flat legacy form
|
||||
// ({type:function, name:X}) some clients send. Unknown object shapes
|
||||
// are warned and dropped rather than silently treated as auto.
|
||||
func convertOpenAIToolChoice(toolChoiceJSON string) *anthropicToolChoice {
|
||||
if toolChoiceJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var asString string
|
||||
if err := json.Unmarshal([]byte(toolChoiceJSON), &asString); err == nil {
|
||||
switch asString {
|
||||
case "auto":
|
||||
return &anthropicToolChoice{Type: "auto"}
|
||||
case "none":
|
||||
return &anthropicToolChoice{Type: anthropicToolChoiceNone}
|
||||
case "required":
|
||||
return &anthropicToolChoice{Type: "any"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var asObj struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"function"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(toolChoiceJSON), &asObj); err != nil {
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unparseable tool_choice, dropping", "error", err)
|
||||
return nil
|
||||
}
|
||||
if name := asObj.Function.Name; name != "" {
|
||||
return &anthropicToolChoice{Type: "tool", Name: name}
|
||||
}
|
||||
if asObj.Name != "" {
|
||||
return &anthropicToolChoice{Type: "tool", Name: asObj.Name}
|
||||
}
|
||||
xlog.Warn("cloud-proxy: anthropic translate: unrecognised tool_choice shape, dropping", "shape", toolChoiceJSON)
|
||||
return nil
|
||||
}
|
||||
|
||||
// openAITool mirrors pkg/functions.Tool but keeps Parameters as
|
||||
// json.RawMessage so the input_schema passes through verbatim — no
|
||||
// re-marshal cost, no fidelity loss on exotic schemas.
|
||||
type openAITool struct {
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters json.RawMessage `json:"parameters"`
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
func assistantBlocks(m *pb.Message) []anthropicContentBlock {
|
||||
toolCallsJSON := m.GetToolCalls()
|
||||
if toolCallsJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var toolCalls []openAIToolCall
|
||||
if err := json.Unmarshal([]byte(toolCallsJSON), &toolCalls); err != nil || len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
blocks := make([]anthropicContentBlock, 0, len(toolCalls)+1)
|
||||
if text := m.GetContent(); text != "" {
|
||||
blocks = append(blocks, anthropicContentBlock{Type: "text", Text: text})
|
||||
}
|
||||
for _, tc := range toolCalls {
|
||||
// OpenAI's arguments are a JSON-encoded string; pass through
|
||||
// as RawMessage so a non-JSON string from a poorly-formed
|
||||
// local model doesn't crash the marshaller downstream.
|
||||
args := json.RawMessage(tc.Function.Arguments)
|
||||
if len(args) == 0 {
|
||||
args = emptyJSONObject
|
||||
}
|
||||
blocks = append(blocks, anthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: args,
|
||||
})
|
||||
}
|
||||
return blocks
|
||||
}
|
||||
|
||||
// doAnthropicRequest is the Anthropic counterpart of doOpenAIRequest.
|
||||
// applyAuthHeader sets x-api-key and anthropic-version when provider
|
||||
// is anthropic, so this method doesn't need to duplicate that.
|
||||
func (c *CloudProxy) doAnthropicRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// predictAnthropicRich returns the full Reply: joined text from all
|
||||
// text blocks, tool_use blocks mapped to ToolCallDelta, and usage
|
||||
// tokens.
|
||||
func (c *CloudProxy) predictAnthropicRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
body, err := buildAnthropicRequest(opts, cfg, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doAnthropicRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var parsed anthropicResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: decode response: %w", err)
|
||||
}
|
||||
|
||||
reply := &pb.Reply{}
|
||||
if parsed.Usage != nil {
|
||||
reply.PromptTokens = int32(parsed.Usage.InputTokens)
|
||||
reply.Tokens = int32(parsed.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
var content strings.Builder
|
||||
var toolCalls []*pb.ToolCallDelta
|
||||
toolIdx := 0
|
||||
for _, b := range parsed.Content {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
content.WriteString(b.Text)
|
||||
case "tool_use":
|
||||
// Input is a structured JSON object; we serialise to a
|
||||
// string so it fits the OpenAI-shaped arguments field
|
||||
// downstream consumers expect.
|
||||
args := ""
|
||||
if len(b.Input) > 0 {
|
||||
args = string(b.Input)
|
||||
}
|
||||
toolCalls = append(toolCalls, newToolCallDelta(toolIdx, b.ID, b.Name, args))
|
||||
toolIdx++
|
||||
}
|
||||
}
|
||||
reply.Message = []byte(content.String())
|
||||
if len(toolCalls) > 0 {
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{ToolCalls: toolCalls}}
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// predictAnthropicStreamRich streams Reply chunks from Anthropic's SSE.
|
||||
// Three event types matter: content_block_start (initialises tool_use
|
||||
// id+name), content_block_delta (carries text or input_json_delta),
|
||||
// message_stop (terminates). The block index from the wire feeds
|
||||
// straight into ToolCallDelta.Index so downstream consumers can
|
||||
// reassemble multiple parallel tool calls.
|
||||
func (c *CloudProxy) predictAnthropicStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
body, err := buildAnthropicRequest(opts, cfg, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doAnthropicRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "" {
|
||||
continue
|
||||
}
|
||||
var ev anthropicStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &ev); err != nil {
|
||||
xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err)
|
||||
continue
|
||||
}
|
||||
switch ev.Type {
|
||||
case "content_block_start":
|
||||
// tool_use blocks announce id + name here; arguments arrive
|
||||
// in subsequent input_json_delta events. Emit a Reply with
|
||||
// just the tool_call init fields so consumers can allocate
|
||||
// a slot at this index.
|
||||
if ev.ContentBlock != nil && ev.ContentBlock.Type == "tool_use" {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{
|
||||
newToolCallDelta(ev.Index, ev.ContentBlock.ID, ev.ContentBlock.Name, ""),
|
||||
}}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "content_block_delta":
|
||||
if ev.Delta == nil {
|
||||
continue
|
||||
}
|
||||
switch ev.Delta.Type {
|
||||
case "text_delta":
|
||||
if ev.Delta.Text == "" {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
Message: []byte(ev.Delta.Text),
|
||||
ChatDeltas: []*pb.ChatDelta{{Content: ev.Delta.Text}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
case "input_json_delta":
|
||||
if ev.Delta.PartialJSON == "" {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
ChatDeltas: []*pb.ChatDelta{{ToolCalls: []*pb.ToolCallDelta{
|
||||
newToolCallDelta(ev.Index, "", "", ev.Delta.PartialJSON),
|
||||
}}},
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "message_delta":
|
||||
// Anthropic sends final usage in message_delta.usage. Emit
|
||||
// a usage-only Reply so the consumer can record totals.
|
||||
if ev.Usage != nil {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
Tokens: int32(ev.Usage.OutputTokens),
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
case "message_stop":
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
334
backend/go/cloud-proxy/provider_anthropic_test.go
Normal file
334
backend/go/cloud-proxy/provider_anthropic_test.go
Normal file
@@ -0,0 +1,334 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// fakeAnthropicUpstream mirrors fakeOpenAIUpstream but decodes the
|
||||
// request body as an anthropicRequest so tests can assert on the
|
||||
// translated wire shape (system field, max_tokens, etc.).
|
||||
func fakeAnthropicUpstream(t *testing.T, handler func(req anthropicRequest) (status int, body string, contentType string)) (*httptest.Server, *anthropicRequest) {
|
||||
t.Helper()
|
||||
var captured anthropicRequest
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(raw, &captured)
|
||||
status, body, ct := handler(captured)
|
||||
w.Header().Set("Content-Type", ct)
|
||||
w.WriteHeader(status)
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
return srv, &captured
|
||||
}
|
||||
|
||||
func newAnthropicTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy {
|
||||
t.Helper()
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_ANTHROPIC_FAKE", "sk-ant-fake")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Model: "claude-local",
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstreamURL,
|
||||
Mode: modeTranslate,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_FAKE",
|
||||
UpstreamModel: "claude-3-5-sonnet-20241022",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
return cp
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_BasicMessages(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"text","text":"hi there"}],"model":"claude-3-5-sonnet-20241022","usage":{"input_tokens":5,"output_tokens":2}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{
|
||||
{Role: "system", Content: "be brief"},
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
Temperature: 0.5,
|
||||
TopP: 0.9,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hi there"))
|
||||
|
||||
g.Expect(captured.Model).To(Equal("claude-3-5-sonnet-20241022"))
|
||||
// System message must be hoisted out of Messages into top-level field.
|
||||
g.Expect(captured.System).To(Equal("be brief"))
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.MaxTokens).To(Equal(int32(32)))
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
// Anthropic 400s when both temperature and top_p are set; the
|
||||
// translator must prefer temperature and drop top_p.
|
||||
g.Expect(captured.TopP).To(BeNil())
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
}
|
||||
|
||||
// When only top_p is set, it should be forwarded.
|
||||
func TestPredict_Anthropic_TopPOnly(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hello"}},
|
||||
TopP: 0.9,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Temperature).To(BeNil())
|
||||
// PredictOptions.TopP is float32 on the wire; the translator widens
|
||||
// to float64 so 0.9 round-trips as 0.8999999761581421… — compare
|
||||
// with a small tolerance rather than exact equality.
|
||||
g.Expect(captured.TopP).NotTo(BeNil())
|
||||
g.Expect(math.Abs(*captured.TopP - 0.9)).To(BeNumerically("<=", 1e-6))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_DefaultsMaxTokens(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Anthropic 400s without max_tokens. The translator must default
|
||||
// it when the caller doesn't supply Tokens.
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.MaxTokens).To(Equal(anthropicDefaultMaxTokens))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_PromptFallback(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?", Tokens: 16})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.Messages[0].Content).To(Equal("what time is it?"))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_ConcatenatesContentBlocks(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Anthropic may return multiple text blocks; the translator joins
|
||||
// them so the Predict() string return is the full assistant message.
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"hello "},{"type":"text","text":"world"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hello world"))
|
||||
}
|
||||
|
||||
func TestPredict_Anthropic_UpstreamError(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 401, `{"error":{"type":"authentication_error","message":"bad key"}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}, Tokens: 16})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("401"))
|
||||
}
|
||||
|
||||
func TestPredictStream_Anthropic_StreamsTextDeltas(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Real Anthropic SSE has event: lines + data: lines. The translator
|
||||
// only needs the data: payload; only content_block_delta with
|
||||
// delta.type=text_delta carries content. message_stop ends.
|
||||
frames := []string{
|
||||
"event: message_start\ndata: {\"type\":\"message_start\"}\n\n",
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" \"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"world\"}}\n\n",
|
||||
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
body := strings.Join(frames, "")
|
||||
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan string, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStream(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
Tokens: 16,
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var got []string
|
||||
for s := range results {
|
||||
got = append(got, s)
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(strings.Join(got, "")).To(Equal("hello world"))
|
||||
g.Expect(captured.Stream).To(BeTrue())
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_TranslatesOpenAITools(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
tools := `[{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}}}]`
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "weather in Paris?"}},
|
||||
Tools: tools,
|
||||
ToolChoice: `"auto"`,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Tools).To(HaveLen(1))
|
||||
g.Expect(captured.Tools[0].Name).To(Equal("get_weather"))
|
||||
g.Expect(captured.Tools[0].Description).To(Equal("Get weather"))
|
||||
// input_schema must be the parameters object verbatim.
|
||||
g.Expect(string(captured.Tools[0].InputSchema)).To(ContainSubstring(`"city"`))
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("auto"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_RequiredMapsToAny(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `"required"`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("any"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_NoneDropsTools(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"t","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `"none"`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Tools).To(BeNil())
|
||||
g.Expect(captured.ToolChoice).To(BeNil())
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_ToolChoice_NamedFunction(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: `[{"type":"function","function":{"name":"weather","parameters":{"type":"object"}}}]`,
|
||||
ToolChoice: `{"type":"function","function":{"name":"weather"}}`,
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.ToolChoice).NotTo(BeNil())
|
||||
g.Expect(captured.ToolChoice.Type).To(Equal("tool"))
|
||||
g.Expect(captured.ToolChoice.Name).To(Equal("weather"))
|
||||
}
|
||||
|
||||
func TestBuildAnthropic_RoundTripsAssistantToolCalls(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// LocalAI Assistant's second turn: the LLM previously emitted a
|
||||
// tool_use, the server executed it, and the conversation now
|
||||
// includes the assistant turn (with tool_calls) plus a tool-role
|
||||
// result message. Both must convert to Anthropic block form.
|
||||
srv, captured := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"content":[{"type":"text","text":"ok"}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
tools := `[{"type":"function","function":{"name":"list_models","parameters":{"type":"object"}}}]`
|
||||
toolCallsJSON := `[{"id":"call_abc","type":"function","function":{"name":"list_models","arguments":"{}"}}]`
|
||||
_, err := cp.Predict(&pb.PredictOptions{
|
||||
Tools: tools,
|
||||
Messages: []*pb.Message{
|
||||
{Role: "user", Content: "what models are installed?"},
|
||||
{Role: "assistant", Content: "", ToolCalls: toolCallsJSON},
|
||||
{Role: "tool", Content: `{"models":["a","b"]}`, ToolCallId: "call_abc"},
|
||||
},
|
||||
Tokens: 64,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
g.Expect(captured.Messages).To(HaveLen(3))
|
||||
// 1. user text — bare string
|
||||
s, ok := captured.Messages[0].Content.(string)
|
||||
g.Expect(ok).To(BeTrue())
|
||||
g.Expect(s).To(Equal("what models are installed?"))
|
||||
// 2. assistant — must be a content-block list with one tool_use
|
||||
// json.Unmarshal of `any` produces []any not []anthropicContentBlock.
|
||||
blocks, ok := captured.Messages[1].Content.([]any)
|
||||
g.Expect(ok).To(BeTrue())
|
||||
g.Expect(blocks).To(HaveLen(1))
|
||||
b0, _ := blocks[0].(map[string]any)
|
||||
g.Expect(b0["type"]).To(Equal("tool_use"))
|
||||
g.Expect(b0["id"]).To(Equal("call_abc"))
|
||||
g.Expect(b0["name"]).To(Equal("list_models"))
|
||||
// 3. tool → user with tool_result block
|
||||
g.Expect(captured.Messages[2].Role).To(Equal("user"))
|
||||
resBlocks, _ := captured.Messages[2].Content.([]any)
|
||||
r0, _ := resBlocks[0].(map[string]any)
|
||||
g.Expect(r0["type"]).To(Equal("tool_result"))
|
||||
g.Expect(r0["tool_use_id"]).To(Equal("call_abc"))
|
||||
g.Expect(r0["content"]).To(Equal(`{"models":["a","b"]}`))
|
||||
}
|
||||
119
backend/go/cloud-proxy/provider_edge_test.go
Normal file
119
backend/go/cloud-proxy/provider_edge_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Verify buildOpenAIRequest preserves caller-supplied tools and
|
||||
// tool_choice as opaque JSON. PredictOptions carries them as strings;
|
||||
// they must land in the outbound request body unchanged so the
|
||||
// upstream sees the caller's intent verbatim. A regression here would
|
||||
// silently disable function calling for translate-mode clients.
|
||||
func TestBuildOpenAIRequest_ToolsAndToolChoicePassthrough(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cfg := &proxyConfig{upstreamModel: "gpt-4o"}
|
||||
toolsJSON := `[{"type":"function","function":{"name":"search","parameters":{"type":"object"}}}]`
|
||||
choiceJSON := `{"type":"function","function":{"name":"search"}}`
|
||||
|
||||
body, err := buildOpenAIRequest(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "find x"}},
|
||||
Tools: toolsJSON,
|
||||
ToolChoice: choiceJSON,
|
||||
}, cfg, false)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
var decoded openAIRequest
|
||||
err = json.Unmarshal(body, &decoded)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
// Compare the JSON-canonical form so whitespace differences are ignored.
|
||||
gotTools, _ := json.Marshal(json.RawMessage(decoded.Tools))
|
||||
wantTools, _ := json.Marshal(json.RawMessage(toolsJSON))
|
||||
g.Expect(string(gotTools)).To(Equal(string(wantTools)))
|
||||
gotChoice, _ := json.Marshal(json.RawMessage(decoded.ToolChoice))
|
||||
wantChoice, _ := json.Marshal(json.RawMessage(choiceJSON))
|
||||
g.Expect(string(gotChoice)).To(Equal(string(wantChoice)))
|
||||
}
|
||||
|
||||
// Garbage JSON in tools / tool_choice is silently dropped (omitted)
|
||||
// rather than blowing up the request. Documents the parseRawJSON
|
||||
// behaviour — operators shouldn't see hard failures from an upstream
|
||||
// caller's mis-formatted tools field.
|
||||
func TestBuildOpenAIRequest_InvalidToolsJSONDropped(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cfg := &proxyConfig{upstreamModel: "gpt-4o"}
|
||||
body, err := buildOpenAIRequest(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tools: "this is not json",
|
||||
ToolChoice: "{also bad",
|
||||
}, cfg, false)
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(body)).NotTo(ContainSubstring("this is not json"))
|
||||
g.Expect(string(body)).NotTo(ContainSubstring("{also bad"))
|
||||
}
|
||||
|
||||
// Anthropic empty content array yields an empty Reply (not an error).
|
||||
// Mirrors how an upstream tool_use-only response might arrive — the
|
||||
// content array can legitimately be empty in some edge cases.
|
||||
func TestPredictRich_Anthropic_EmptyContent(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{"id":"m1","type":"message","role":"assistant","content":[],"usage":{"input_tokens":3,"output_tokens":0}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "x"}},
|
||||
Tokens: 16,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal(""))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(0))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(3)))
|
||||
}
|
||||
|
||||
// A truncated / malformed SSE payload mid-stream should be tolerated:
|
||||
// the malformed chunk gets skipped (xlog.Debug logged), valid chunks
|
||||
// before AND after it still reach the channel.
|
||||
func TestPredictStreamRich_OpenAI_TolerantOfBadChunks(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
body := strings.Join([]string{
|
||||
`data: {"choices":[{"index":0,"delta":{"content":"hello"}}]}`,
|
||||
``,
|
||||
`data: this-is-not-json{{`,
|
||||
``,
|
||||
`data: {"choices":[{"index":0,"delta":{"content":" world"}}]}`,
|
||||
``,
|
||||
`data: [DONE]`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var assembled strings.Builder
|
||||
for reply := range results {
|
||||
assembled.Write(reply.GetMessage())
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
// The good chunks before and after the malformed one both made it through.
|
||||
g.Expect(assembled.String()).To(Equal("hello world"))
|
||||
}
|
||||
320
backend/go/cloud-proxy/provider_openai.go
Normal file
320
backend/go/cloud-proxy/provider_openai.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// OpenAI Chat Completions wire-format types. Narrowed to the fields
|
||||
// translate mode needs to preserve through the Reply proto: content,
|
||||
// role, tool_calls (typed so we can map them to pb.ToolCallDelta),
|
||||
// and sampling params copied verbatim from PredictOptions.
|
||||
//
|
||||
// Provider-specific extensions (logit_bias, function calling beyond
|
||||
// tool_calls, etc.) are not modelled — passthrough mode covers callers
|
||||
// that need full upstream fidelity.
|
||||
|
||||
type openAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAIMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
MaxTokens *int32 `json:"max_tokens,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
type openAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// openAIToolCall covers both the non-streaming response shape (full
|
||||
// id+function+arguments) and the streaming-delta shape (sparse fields,
|
||||
// index assignment). The proto's ToolCallDelta absorbs both — name is
|
||||
// set on first appearance, arguments arrive incrementally in streaming.
|
||||
type openAIToolCall struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function openAIFunctionCall `json:"function,omitempty"`
|
||||
}
|
||||
|
||||
type openAIFunctionCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type openAIChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message openAIMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type openAIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Choices []openAIChoice `json:"choices"`
|
||||
Usage *openAIUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type openAIStreamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
|
||||
} `json:"delta"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
type openAIStreamChunk struct {
|
||||
Choices []openAIStreamChoice `json:"choices"`
|
||||
Usage *openAIUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type openAIUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// buildOpenAIRequest converts pb.PredictOptions into the OpenAI Chat
|
||||
// Completions request body. Prefers Messages when non-empty; falls
|
||||
// back to wrapping Prompt as a single user message so plain
|
||||
// /completions-style calls still work in translate mode.
|
||||
func buildOpenAIRequest(opts *pb.PredictOptions, cfg *proxyConfig, stream bool) ([]byte, error) {
|
||||
req := openAIRequest{
|
||||
Model: modelName(cfg, opts),
|
||||
Stream: stream,
|
||||
Stop: opts.GetStopPrompts(),
|
||||
Tools: parseRawJSON(opts.GetTools()),
|
||||
ToolChoice: parseRawJSON(opts.GetToolChoice()),
|
||||
}
|
||||
if t := opts.GetTemperature(); t != 0 {
|
||||
v := float64(t)
|
||||
req.Temperature = &v
|
||||
}
|
||||
if t := opts.GetTopP(); t != 0 {
|
||||
v := float64(t)
|
||||
req.TopP = &v
|
||||
}
|
||||
if n := opts.GetTokens(); n > 0 {
|
||||
req.MaxTokens = &n
|
||||
}
|
||||
if p := opts.GetFrequencyPenalty(); p != 0 {
|
||||
v := float64(p)
|
||||
req.FrequencyPenalty = &v
|
||||
}
|
||||
if p := opts.GetPresencePenalty(); p != 0 {
|
||||
v := float64(p)
|
||||
req.PresencePenalty = &v
|
||||
}
|
||||
|
||||
for _, m := range opts.GetMessages() {
|
||||
msg := openAIMessage{
|
||||
Role: m.GetRole(),
|
||||
Content: m.GetContent(),
|
||||
Name: m.GetName(),
|
||||
ToolCallID: m.GetToolCallId(),
|
||||
}
|
||||
// Pre-existing tool_calls arrive as a JSON string from the
|
||||
// upstream caller's previous assistant turn; pass-through as-is.
|
||||
if tc := m.GetToolCalls(); tc != "" {
|
||||
_ = json.Unmarshal([]byte(tc), &msg.ToolCalls)
|
||||
}
|
||||
req.Messages = append(req.Messages, msg)
|
||||
}
|
||||
// Fallback for plain Prompt requests (no Messages array). LocalAI
|
||||
// templating may have produced a flat prompt; rewrap as a single
|
||||
// user message so the upstream chat endpoint accepts it.
|
||||
if len(req.Messages) == 0 && opts.GetPrompt() != "" {
|
||||
req.Messages = []openAIMessage{{Role: "user", Content: opts.GetPrompt()}}
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
// modelName picks the upstream model: upstream_model from the proxy
|
||||
// config wins (operator override), else the local model name captured
|
||||
// at LoadModel time. Operator sets upstream_model to map LocalAI's
|
||||
// alias (e.g. "claude-strict") to the upstream's canonical name
|
||||
// (e.g. "claude-3-5-sonnet-20241022").
|
||||
func modelName(cfg *proxyConfig, _ *pb.PredictOptions) string {
|
||||
if cfg.upstreamModel != "" {
|
||||
return cfg.upstreamModel
|
||||
}
|
||||
return cfg.localModel
|
||||
}
|
||||
|
||||
// parseRawJSON parses a JSON string into a RawMessage so it round-trips
|
||||
// into the upstream body. Returns nil for empty/invalid input so the
|
||||
// field is omitted (omitempty).
|
||||
func parseRawJSON(s string) json.RawMessage {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
var probe json.RawMessage
|
||||
if err := json.Unmarshal([]byte(s), &probe); err != nil {
|
||||
return nil
|
||||
}
|
||||
return probe
|
||||
}
|
||||
|
||||
// doOpenAIRequest builds + sends the upstream request. Returns the
|
||||
// raw response on success; caller handles status / body.
|
||||
func (c *CloudProxy) doOpenAIRequest(ctx context.Context, cfg *proxyConfig, body []byte) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream request: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// predictOpenAIRich is the non-streaming translate path. Returns a
|
||||
// fully-populated *pb.Reply with assistant content, tool calls, and
|
||||
// token usage. The gRPC server forwards the Reply verbatim.
|
||||
func (c *CloudProxy) predictOpenAIRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions) (*pb.Reply, error) {
|
||||
body, err := buildOpenAIRequest(opts, cfg, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doOpenAIRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return nil, fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var parsed openAIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
return nil, fmt.Errorf("cloud-proxy: decode response: %w", err)
|
||||
}
|
||||
if len(parsed.Choices) == 0 {
|
||||
return nil, errors.New("cloud-proxy: upstream returned no choices")
|
||||
}
|
||||
|
||||
choice := parsed.Choices[0]
|
||||
reply := &pb.Reply{
|
||||
Message: []byte(choice.Message.Content),
|
||||
}
|
||||
if parsed.Usage != nil {
|
||||
reply.PromptTokens = int32(parsed.Usage.PromptTokens)
|
||||
reply.Tokens = int32(parsed.Usage.CompletionTokens)
|
||||
}
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
// Non-streaming: a single ChatDelta carries the full tool-call
|
||||
// set. Index/Name/Arguments are populated together; downstream
|
||||
// consumers don't need to assemble streaming deltas.
|
||||
delta := &pb.ChatDelta{}
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
delta.ToolCalls = append(delta.ToolCalls,
|
||||
newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments))
|
||||
}
|
||||
reply.ChatDeltas = []*pb.ChatDelta{delta}
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// predictOpenAIStreamRich streams *pb.Reply chunks. Each chunk carries
|
||||
// either a content delta (Message + ChatDeltas[].Content) or tool-call
|
||||
// deltas (ChatDeltas[].ToolCalls). The final Reply carries usage tokens
|
||||
// when the upstream sends them (stream_options.include_usage).
|
||||
func (c *CloudProxy) predictOpenAIStreamRich(ctx context.Context, cfg *proxyConfig, opts *pb.PredictOptions, results chan<- *pb.Reply) error {
|
||||
body, err := buildOpenAIRequest(opts, cfg, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cloud-proxy: marshal request: %w", err)
|
||||
}
|
||||
resp, err := c.doOpenAIRequest(ctx, cfg, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return fmt.Errorf("cloud-proxy: upstream %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
return nil
|
||||
}
|
||||
var chunk openAIStreamChunk
|
||||
if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
|
||||
xlog.Debug("cloud-proxy: skip malformed SSE chunk", "error", err)
|
||||
continue
|
||||
}
|
||||
// Usage frames may arrive separately from content frames when
|
||||
// stream_options.include_usage is set; emit a usage-only Reply
|
||||
// in that case so the consumer sees the totals.
|
||||
if chunk.Usage != nil && len(chunk.Choices) == 0 {
|
||||
if !sendReply(ctx, results, &pb.Reply{
|
||||
PromptTokens: int32(chunk.Usage.PromptTokens),
|
||||
Tokens: int32(chunk.Usage.CompletionTokens),
|
||||
}) {
|
||||
return ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, ch := range chunk.Choices {
|
||||
reply := &pb.Reply{}
|
||||
if ch.Delta.Content != "" {
|
||||
reply.Message = []byte(ch.Delta.Content)
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{Content: ch.Delta.Content}}
|
||||
}
|
||||
if len(ch.Delta.ToolCalls) > 0 {
|
||||
if len(reply.ChatDeltas) == 0 {
|
||||
reply.ChatDeltas = []*pb.ChatDelta{{}}
|
||||
}
|
||||
for _, tc := range ch.Delta.ToolCalls {
|
||||
reply.ChatDeltas[0].ToolCalls = append(reply.ChatDeltas[0].ToolCalls,
|
||||
newToolCallDelta(tc.Index, tc.ID, tc.Function.Name, tc.Function.Arguments))
|
||||
}
|
||||
}
|
||||
if reply.Message == nil && len(reply.ChatDeltas) == 0 {
|
||||
continue
|
||||
}
|
||||
if !sendReply(ctx, results, reply) {
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
170
backend/go/cloud-proxy/provider_openai_test.go
Normal file
170
backend/go/cloud-proxy/provider_openai_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// fakeOpenAIUpstream returns an httptest.Server that decodes the
|
||||
// inbound request as an openAIRequest, calls handler with it, and
|
||||
// writes the handler's reply as the response.
|
||||
func fakeOpenAIUpstream(t *testing.T, handler func(req openAIRequest) (status int, body string, contentType string)) (*httptest.Server, *openAIRequest) {
|
||||
t.Helper()
|
||||
var captured openAIRequest
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(raw, &captured)
|
||||
status, body, ct := handler(captured)
|
||||
w.Header().Set("Content-Type", ct)
|
||||
w.WriteHeader(status)
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
return srv, &captured
|
||||
}
|
||||
|
||||
func newTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy {
|
||||
t.Helper()
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_OPENAI_FAKE", "sk-fake-openai")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Model: "gpt-4o-local",
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstreamURL,
|
||||
Mode: modeTranslate,
|
||||
Provider: providerOpenAI,
|
||||
ApiKeyEnv: "CLOUD_PROXY_OPENAI_FAKE",
|
||||
UpstreamModel: "gpt-4o",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
return cp
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_BasicChat(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"id":"resp-1","choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{
|
||||
{Role: "system", Content: "be brief"},
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
Temperature: 0.5,
|
||||
TopP: 0.9,
|
||||
Tokens: 32,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hi there"))
|
||||
|
||||
// Verify the upstream saw a properly-translated request.
|
||||
g.Expect(captured.Model).To(Equal("gpt-4o"))
|
||||
g.Expect(captured.Messages).To(HaveLen(2))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("system"))
|
||||
g.Expect(captured.Messages[1].Role).To(Equal("user"))
|
||||
g.Expect(captured.Temperature).NotTo(BeNil())
|
||||
g.Expect(*captured.Temperature).To(Equal(0.5))
|
||||
g.Expect(captured.MaxTokens).NotTo(BeNil())
|
||||
g.Expect(*captured.MaxTokens).To(Equal(int32(32)))
|
||||
g.Expect(captured.Stream).To(BeFalse())
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_PromptFallback(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// No Messages array — backend should synth a single user message
|
||||
// from Prompt so non-chat clients still route through translate.
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"choices":[{"message":{"role":"assistant","content":"ok"}}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Prompt: "what time is it?"})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(captured.Messages).To(HaveLen(1))
|
||||
g.Expect(captured.Messages[0].Role).To(Equal("user"))
|
||||
g.Expect(captured.Messages[0].Content).To(Equal("what time is it?"))
|
||||
}
|
||||
|
||||
func TestPredict_OpenAI_UpstreamError(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 401, `{"error":{"message":"bad key"}}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
_, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "x"}}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("401"))
|
||||
}
|
||||
|
||||
func TestPredictStream_OpenAI_StreamsContent(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Stream three content deltas then [DONE]. Verify the channel
|
||||
// receives them in order with no missing pieces.
|
||||
chunks := []string{
|
||||
`{"choices":[{"index":0,"delta":{"role":"assistant"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"hello"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":" "}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"world"}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||
}
|
||||
body := ""
|
||||
for _, c := range chunks {
|
||||
body += "data: " + c + "\n\n"
|
||||
}
|
||||
body += "data: [DONE]\n\n"
|
||||
|
||||
srv, captured := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan string, 8)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStream(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "hi"}},
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var got []string
|
||||
for s := range results {
|
||||
got = append(got, s)
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(strings.Join(got, "")).To(Equal("hello world"))
|
||||
g.Expect(captured.Stream).To(BeTrue())
|
||||
}
|
||||
|
||||
func TestPredict_RejectedInPassthroughMode(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
t.Setenv("CLOUD_PROXY_FAKE", "k")
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_FAKE",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_, err = cp.Predict(&pb.PredictOptions{})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("only valid in translate"))
|
||||
}
|
||||
436
backend/go/cloud-proxy/proxy.go
Normal file
436
backend/go/cloud-proxy/proxy.go
Normal file
@@ -0,0 +1,436 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// Mirror of core/config.Proxy{Mode,Provider}* — backends don't
|
||||
// import core to keep the boundary clean.
|
||||
const (
|
||||
modePassthrough = "passthrough"
|
||||
modeTranslate = "translate"
|
||||
|
||||
providerOpenAI = "openai"
|
||||
providerAnthropic = "anthropic"
|
||||
)
|
||||
|
||||
// CloudProxy is the LocalAI backend that proxies model traffic to a
|
||||
// configured upstream HTTP provider. Concurrency: base.SingleThread is
|
||||
// NOT embedded — forward calls are independent and HTTP transport is
|
||||
// goroutine-safe, so multiple Forward streams can run in parallel.
|
||||
// Locking would serialise requests to a chat provider for no benefit.
|
||||
type CloudProxy struct {
|
||||
base.Base
|
||||
|
||||
cfg atomic.Pointer[proxyConfig]
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type proxyConfig struct {
|
||||
upstreamURL string
|
||||
mode string
|
||||
provider string
|
||||
upstreamModel string
|
||||
localModel string // ModelOptions.Model — fallback when upstream_model is unset
|
||||
apiKey string // resolved at Load time
|
||||
}
|
||||
|
||||
func NewCloudProxy() *CloudProxy {
|
||||
// httpclient.New refuses redirects outright: the proxy talks to a
|
||||
// single configured upstream API (OpenAI/Anthropic/...) that answers
|
||||
// directly, so a 3xx means misconfiguration, a hijacked upstream, or
|
||||
// DNS trickery — never normal operation. Following it would replay the
|
||||
// request, including the operator's x-api-key (which Go does NOT strip
|
||||
// on cross-host redirects), to an unvetted host and leak the key
|
||||
// (GHSA-3mj3-57v2-4636). It also imposes no body deadline, so streaming
|
||||
// SSE responses that legitimately last minutes are not truncated.
|
||||
return &CloudProxy{client: httpclient.New()}
|
||||
}
|
||||
|
||||
func (c *CloudProxy) Load(opts *pb.ModelOptions) error {
|
||||
po := opts.GetProxy()
|
||||
if po == nil {
|
||||
return errors.New("cloud-proxy: Load requires ProxyOptions to be set")
|
||||
}
|
||||
if po.GetUpstreamUrl() == "" {
|
||||
return errors.New("cloud-proxy: upstream_url is required")
|
||||
}
|
||||
if _, err := url.ParseRequestURI(po.GetUpstreamUrl()); err != nil {
|
||||
return fmt.Errorf("cloud-proxy: upstream_url %q invalid: %w", po.GetUpstreamUrl(), err)
|
||||
}
|
||||
|
||||
mode := po.GetMode()
|
||||
if mode == "" {
|
||||
mode = modePassthrough
|
||||
}
|
||||
switch mode {
|
||||
case modePassthrough:
|
||||
case modeTranslate:
|
||||
switch po.GetProvider() {
|
||||
case providerOpenAI:
|
||||
// implemented in provider_openai.go
|
||||
case providerAnthropic:
|
||||
// implemented in provider_anthropic.go
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: translate mode requires provider in {%s, %s}, got %q",
|
||||
providerOpenAI, providerAnthropic, po.GetProvider())
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: unknown mode %q", mode)
|
||||
}
|
||||
|
||||
key, err := resolveAPIKey(po.GetApiKeyEnv(), po.GetApiKeyFile())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.cfg.Store(&proxyConfig{
|
||||
upstreamURL: po.GetUpstreamUrl(),
|
||||
mode: mode,
|
||||
provider: po.GetProvider(),
|
||||
upstreamModel: po.GetUpstreamModel(),
|
||||
localModel: opts.GetModel(),
|
||||
apiKey: key,
|
||||
})
|
||||
xlog.Info("cloud-proxy: ready",
|
||||
"upstream", po.GetUpstreamUrl(),
|
||||
"mode", mode,
|
||||
"provider", po.GetProvider(),
|
||||
"has_key", key != "")
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveAPIKey mirrors config.ProxyConfig.ResolveAPIKey. Duplicated
|
||||
// (a few lines) rather than importing core/config from a backend
|
||||
// binary — keeps backends independent of core's package layout.
|
||||
// Mutual-exclusion is enforced upstream in core/config.Validate.
|
||||
func resolveAPIKey(envName, filePath string) (string, error) {
|
||||
if envName != "" {
|
||||
v := os.Getenv(envName)
|
||||
if v == "" {
|
||||
return "", fmt.Errorf("cloud-proxy: api_key_env %q is unset", envName)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
if filePath != "" {
|
||||
b, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cloud-proxy: read api_key_file %q: %w", filePath, err)
|
||||
}
|
||||
return strings.TrimSpace(string(b)), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// PredictRich is the non-streaming translate path. Returns a fully-
|
||||
// populated *pb.Reply: content, tool-call deltas (ChatDeltas), and
|
||||
// usage tokens. Implements the optional grpc.AIModelRich interface;
|
||||
// the gRPC server prefers this path over Predict when present so
|
||||
// tool calls survive the round-trip. Passthrough mode rejects
|
||||
// PredictRich — callers must use Forward.
|
||||
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return nil, grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
|
||||
}
|
||||
xlog.Info("cloud-proxy: predict", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: predict failed", "provider", cfg.provider, "error", err)
|
||||
}
|
||||
}()
|
||||
ctx := context.Background()
|
||||
switch cfg.provider {
|
||||
case providerOpenAI:
|
||||
return c.predictOpenAIRich(ctx, cfg, opts)
|
||||
case providerAnthropic:
|
||||
return c.predictAnthropicRich(ctx, cfg, opts)
|
||||
default:
|
||||
return nil, fmt.Errorf("cloud-proxy: predict not implemented for provider %q", cfg.provider)
|
||||
}
|
||||
}
|
||||
|
||||
// PredictStreamRich is the rich streaming counterpart of PredictRich.
|
||||
// Each emitted Reply carries either a content delta, tool-call deltas,
|
||||
// or usage tokens (the final upstream frame). base.Base.PredictStream
|
||||
// is bypassed when AIModelRich is implemented, so the channel is
|
||||
// closed by the gRPC server pump.
|
||||
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
|
||||
}
|
||||
xlog.Info("cloud-proxy: predict-stream", "provider", cfg.provider, "upstream", cfg.upstreamURL, "upstream_model", cfg.upstreamModel)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: predict-stream failed", "provider", cfg.provider, "error", err)
|
||||
}
|
||||
}()
|
||||
ctx := context.Background()
|
||||
switch cfg.provider {
|
||||
case providerOpenAI:
|
||||
return c.predictOpenAIStreamRich(ctx, cfg, opts, results)
|
||||
case providerAnthropic:
|
||||
return c.predictAnthropicStreamRich(ctx, cfg, opts, results)
|
||||
default:
|
||||
return fmt.Errorf("cloud-proxy: predictStream not implemented for provider %q", cfg.provider)
|
||||
}
|
||||
}
|
||||
|
||||
// Predict is the legacy (string, error) AIModel signature. Used only
|
||||
// if a caller goes through the non-rich path (it shouldn't, since
|
||||
// server.go prefers PredictRich). Provided so the AIModel interface
|
||||
// is satisfied for backends that haven't opted into the rich variant.
|
||||
func (c *CloudProxy) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
reply, err := c.PredictRich(opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(reply.GetMessage()), nil
|
||||
}
|
||||
|
||||
// PredictStream is the legacy chan-string streaming path. Adapts the
|
||||
// rich stream by extracting only content text — tool-call-only chunks
|
||||
// (no Message bytes) and usage-only chunks are silently dropped, since
|
||||
// the legacy chan-string contract cannot represent them. Consumers
|
||||
// that need tool calls must call PredictStreamRich directly.
|
||||
func (c *CloudProxy) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
defer close(results)
|
||||
richCh := make(chan *pb.Reply)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- c.PredictStreamRich(opts, richCh)
|
||||
close(richCh)
|
||||
}()
|
||||
for reply := range richCh {
|
||||
if msg := reply.GetMessage(); len(msg) > 0 {
|
||||
results <- string(msg)
|
||||
}
|
||||
}
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
// sendReply pushes one Reply onto a stream channel honouring ctx
|
||||
// cancellation. Returns false on cancel so the caller can exit with
|
||||
// ctx.Err(). Used by both translate-mode providers.
|
||||
func sendReply(ctx context.Context, results chan<- *pb.Reply, reply *pb.Reply) bool {
|
||||
select {
|
||||
case results <- reply:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// newToolCallDelta is a small constructor for the cross-provider
|
||||
// tool-call delta shape. Centralised so the int32 cast and the four
|
||||
// fields stay consistent across the OpenAI / Anthropic translators.
|
||||
// Empty name/args are valid — Anthropic streaming announces the call
|
||||
// with id+name then sends arguments incrementally; OpenAI's reverse
|
||||
// pattern (args without name) also lands here.
|
||||
func newToolCallDelta(index int, id, name, args string) *pb.ToolCallDelta {
|
||||
return &pb.ToolCallDelta{
|
||||
Index: int32(index),
|
||||
Id: id,
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward shovels bytes between a Forward gRPC stream and an upstream
|
||||
// HTTP request. First request message carries path/method/headers and
|
||||
// the initial body chunk; subsequent messages append body chunks. The
|
||||
// first reply carries upstream status + response headers; subsequent
|
||||
// replies stream body chunks until the upstream connection closes.
|
||||
// Cancellation of ctx (the gRPC stream context) closes the upstream
|
||||
// connection.
|
||||
func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest, out chan<- *pb.ForwardReply) error {
|
||||
defer close(out)
|
||||
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modePassthrough {
|
||||
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
|
||||
}
|
||||
|
||||
first, ok := <-in
|
||||
if !ok {
|
||||
return errors.New("cloud-proxy: Forward stream closed before first request")
|
||||
}
|
||||
|
||||
// Honour the per-request path only when the configured upstream_url
|
||||
// has no path of its own — gallery convention is to put the
|
||||
// canonical path in upstream_url.
|
||||
fullURL, err := composeURL(cfg.upstreamURL, first.GetPath())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
method := first.GetMethod()
|
||||
if method == "" {
|
||||
method = http.MethodPost
|
||||
}
|
||||
|
||||
// Pipe the body in from the gRPC stream so the HTTP request can
|
||||
// start before the client finishes sending. The pipe-reader is
|
||||
// closed via CloseWithError on the error paths so the writer
|
||||
// goroutine doesn't block forever.
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
go func() {
|
||||
var writeErr error
|
||||
defer func() { _ = pw.CloseWithError(writeErr) }()
|
||||
if len(first.GetBodyChunk()) > 0 {
|
||||
if _, writeErr = pw.Write(first.GetBodyChunk()); writeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
for req := range in {
|
||||
if len(req.GetBodyChunk()) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, writeErr = pw.Write(req.GetBodyChunk()); writeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, fullURL, pr)
|
||||
if err != nil {
|
||||
_ = pr.CloseWithError(err) // unblocks the body-pump's pw.Write
|
||||
return fmt.Errorf("cloud-proxy: build request: %w", err)
|
||||
}
|
||||
|
||||
// Apply caller-supplied headers, then override with the
|
||||
// authorization header derived from the resolved key. Caller-
|
||||
// supplied Authorization is always replaced — operators may not
|
||||
// know the backend's auth scheme, and silently leaking through a
|
||||
// client Authorization header to a different upstream would
|
||||
// confuse the upstream and could leak credentials.
|
||||
for _, h := range first.GetHeaders() {
|
||||
if h == nil || h.GetName() == "" {
|
||||
continue
|
||||
}
|
||||
// Strip hop-by-hop headers that aren't meaningful to the
|
||||
// upstream (Host is set by the http client from the URL;
|
||||
// Content-Length is computed from the body).
|
||||
if isHopByHopHeader(h.GetName()) {
|
||||
continue
|
||||
}
|
||||
req.Header.Add(h.GetName(), h.GetValue())
|
||||
}
|
||||
if cfg.apiKey != "" {
|
||||
applyAuthHeader(req, cfg.provider, cfg.apiKey)
|
||||
}
|
||||
|
||||
xlog.Info("cloud-proxy: forward", "method", method, "url", fullURL, "provider", cfg.provider)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
xlog.Warn("cloud-proxy: forward upstream failed", "url", fullURL, "error", err)
|
||||
return fmt.Errorf("cloud-proxy: upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
logFn := xlog.Info
|
||||
if resp.StatusCode >= 400 {
|
||||
logFn = xlog.Warn
|
||||
}
|
||||
logFn("cloud-proxy: forward response", "url", fullURL, "status", resp.StatusCode)
|
||||
|
||||
// First reply: status + response headers, no body.
|
||||
headers := make([]*pb.ForwardHeader, 0, len(resp.Header))
|
||||
for k, vs := range resp.Header {
|
||||
for _, v := range vs {
|
||||
headers = append(headers, &pb.ForwardHeader{Name: k, Value: v})
|
||||
}
|
||||
}
|
||||
out <- &pb.ForwardReply{Status: int32(resp.StatusCode), Headers: headers}
|
||||
|
||||
// Subsequent replies: body chunks. Use a fixed 8KB buffer — small
|
||||
// enough that SSE token frames flush promptly, large enough that
|
||||
// long chunked-transfer bodies aren't death by a thousand reads.
|
||||
buf := make([]byte, 8*1024)
|
||||
for {
|
||||
n, rerr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
out <- &pb.ForwardReply{BodyChunk: chunk}
|
||||
}
|
||||
if rerr != nil {
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("cloud-proxy: upstream body read: %w", rerr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// composeURL combines the configured upstream URL with the per-request
|
||||
// path. The upstream URL typically already includes the canonical path
|
||||
// (e.g. https://api.openai.com/v1/chat/completions) so the per-request
|
||||
// path is ignored in that case. When upstream_url is a bare host
|
||||
// (https://api.openai.com), the request path is appended.
|
||||
func composeURL(upstream, reqPath string) (string, error) {
|
||||
u, err := url.Parse(upstream)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cloud-proxy: parse upstream_url %q: %w", upstream, err)
|
||||
}
|
||||
if u.Path == "" || u.Path == "/" {
|
||||
u.Path = reqPath
|
||||
}
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// applyAuthHeader writes the appropriate authorization header for the
|
||||
// provider. OpenAI/Anthropic/most providers use Bearer; Anthropic
|
||||
// historically uses x-api-key + anthropic-version, but accepts Bearer
|
||||
// too via the OpenAI-compatible path. Default to Bearer when provider
|
||||
// is empty (passthrough mode where the operator doesn't claim a
|
||||
// provider).
|
||||
func applyAuthHeader(req *http.Request, provider, key string) {
|
||||
switch provider {
|
||||
case providerAnthropic:
|
||||
req.Header.Set("x-api-key", key)
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
default:
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
}
|
||||
|
||||
// isHopByHopHeader returns true for headers that should not be
|
||||
// forwarded from the client request to the upstream (RFC 7230 §6.1
|
||||
// hop-by-hop list, plus a few that the http.Client sets itself).
|
||||
func isHopByHopHeader(name string) bool {
|
||||
switch strings.ToLower(name) {
|
||||
case "connection", "proxy-connection", "keep-alive", "transfer-encoding",
|
||||
"te", "trailer", "upgrade", "host", "content-length":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
206
backend/go/cloud-proxy/proxy_test.go
Normal file
206
backend/go/cloud-proxy/proxy_test.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// helper: run a CloudProxy in-process via grpc.Provide so tests can
|
||||
// call Forward through the public Backend interface without listening
|
||||
// on a real socket.
|
||||
func newInProcClient(t *testing.T, proxy *CloudProxy) grpc.Backend {
|
||||
t.Helper()
|
||||
addr := "test://" + t.Name()
|
||||
grpc.Provide(addr, proxy)
|
||||
return grpc.NewClient(addr, true, nil, false)
|
||||
}
|
||||
|
||||
func TestForward_PassthroughEcho(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
// Fake upstream: echoes the request body back, prefixed with a
|
||||
// canary so the test can assert both that the body reached the
|
||||
// upstream and the response made it back to the client.
|
||||
gotBody := make(chan string, 1)
|
||||
gotAuth := make(chan string, 1)
|
||||
gotPath := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
gotBody <- string(body)
|
||||
gotAuth <- r.Header.Get("Authorization")
|
||||
gotPath <- r.URL.Path
|
||||
w.Header().Set("X-Echo", "true")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("echo: " + string(body)))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
t.Setenv("CLOUD_PROXY_FAKE_KEY", "sk-fake")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
ApiKeyEnv: "CLOUD_PROXY_FAKE_KEY",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: []*pb.ForwardHeader{{Name: "Content-Type", Value: "application/json"}},
|
||||
BodyChunk: []byte(`{"prompt":`),
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.Send(&pb.ForwardRequest{BodyChunk: []byte(`"hi"}`)})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.CloseSend()
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// First reply: status + headers.
|
||||
first, err := stream.Recv()
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(first.Status).To(Equal(int32(http.StatusOK)))
|
||||
g.Expect(hasHeader(first.Headers, "X-Echo", "true")).To(BeTrue())
|
||||
|
||||
// Subsequent replies: body.
|
||||
var body []byte
|
||||
for {
|
||||
r, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
body = append(body, r.BodyChunk...)
|
||||
}
|
||||
g.Expect(string(body)).To(Equal(`echo: {"prompt":"hi"}`))
|
||||
|
||||
// Upstream observations.
|
||||
var gotBodyVal, gotAuthVal, gotPathVal string
|
||||
g.Eventually(gotBody).Should(Receive(&gotBodyVal), "upstream never saw body")
|
||||
g.Expect(gotBodyVal).To(Equal(`{"prompt":"hi"}`))
|
||||
g.Eventually(gotAuth).Should(Receive(&gotAuthVal), "upstream never saw auth header")
|
||||
g.Expect(gotAuthVal).To(Equal("Bearer sk-fake"))
|
||||
g.Eventually(gotPath).Should(Receive(&gotPathVal), "upstream never saw path")
|
||||
g.Expect(gotPathVal).To(Equal("/v1/chat/completions"))
|
||||
}
|
||||
|
||||
func TestForward_AnthropicAuthHeader(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
gotXAPIKey := make(chan string, 1)
|
||||
gotVersion := make(chan string, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotXAPIKey <- r.Header.Get("x-api-key")
|
||||
gotVersion <- r.Header.Get("anthropic-version")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
t.Setenv("CLOUD_PROXY_ANTHROPIC_KEY", "sk-ant-fake")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
err := cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: upstream.URL,
|
||||
Mode: modePassthrough,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_ANTHROPIC_KEY",
|
||||
},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
err = stream.Send(&pb.ForwardRequest{Path: "/v1/messages", Method: "POST"})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_ = stream.CloseSend()
|
||||
_, _ = stream.Recv() // drain status
|
||||
for {
|
||||
if _, err := stream.Recv(); errors.Is(err, io.EOF) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
g.Expect(<-gotXAPIKey).To(Equal("sk-ant-fake"))
|
||||
g.Expect(<-gotVersion).NotTo(BeEmpty())
|
||||
}
|
||||
|
||||
func TestLoad_ValidatesConfig(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cp := NewCloudProxy()
|
||||
|
||||
err := cp.Load(&pb.ModelOptions{})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("ProxyOptions"))
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("upstream_url"))
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
Mode: "rewrite",
|
||||
}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("unknown mode"))
|
||||
|
||||
// translate + openai should load successfully (Phase 5).
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com/v1/chat/completions",
|
||||
Mode: modeTranslate,
|
||||
Provider: providerOpenAI,
|
||||
}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// translate + anthropic should load successfully (Phase 6).
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com/v1/messages",
|
||||
Mode: modeTranslate,
|
||||
Provider: providerAnthropic,
|
||||
}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = cp.Load(&pb.ModelOptions{Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: "https://example.com",
|
||||
ApiKeyEnv: "DEFINITELY_UNSET_ENV_VAR_XYZ",
|
||||
}})
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("unset"))
|
||||
}
|
||||
|
||||
func TestForward_RejectsWithoutLoad(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
cp := NewCloudProxy()
|
||||
c := newInProcClient(t, cp)
|
||||
stream, err := c.Forward(context.Background())
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
_ = stream.CloseSend()
|
||||
_, err = stream.Recv()
|
||||
g.Expect(err).To(HaveOccurred())
|
||||
g.Expect(err.Error()).To(ContainSubstring("not loaded"))
|
||||
}
|
||||
|
||||
func hasHeader(hs []*pb.ForwardHeader, name, value string) bool {
|
||||
for _, h := range hs {
|
||||
if strings.EqualFold(h.GetName(), name) && h.GetValue() == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
6
backend/go/cloud-proxy/run.sh
Executable file
6
backend/go/cloud-proxy/run.sh
Executable file
@@ -0,0 +1,6 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
exec $CURDIR/cloud-proxy "$@"
|
||||
232
backend/go/cloud-proxy/toolcalls_test.go
Normal file
232
backend/go/cloud-proxy/toolcalls_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// OpenAI: non-streaming tool call response. Verify the response is
|
||||
// mapped to Reply.ChatDeltas[].ToolCalls with id/name/arguments intact,
|
||||
// and usage tokens land on Reply.PromptTokens / Reply.Tokens.
|
||||
func TestPredictRich_OpenAI_ToolCalls(t *testing.T) {
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{
|
||||
"id":"resp-1",
|
||||
"choices":[{
|
||||
"index":0,
|
||||
"message":{
|
||||
"role":"assistant",
|
||||
"content":"",
|
||||
"tool_calls":[
|
||||
{"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"SF\"}"}},
|
||||
{"id":"call_def","type":"function","function":{"name":"get_time","arguments":"{\"tz\":\"PT\"}"}}
|
||||
]
|
||||
},
|
||||
"finish_reason":"tool_calls"
|
||||
}],
|
||||
"usage":{"prompt_tokens":42,"completion_tokens":18,"total_tokens":60}
|
||||
}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}},
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal(""))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(42)))
|
||||
g.Expect(reply.GetTokens()).To(Equal(int32(18)))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(1))
|
||||
tcs := reply.GetChatDeltas()[0].GetToolCalls()
|
||||
g.Expect(tcs).To(HaveLen(2))
|
||||
g.Expect(tcs[0].GetId()).To(Equal("call_abc"))
|
||||
g.Expect(tcs[0].GetName()).To(Equal("get_weather"))
|
||||
g.Expect(tcs[0].GetArguments()).To(ContainSubstring(`"location":"SF"`))
|
||||
g.Expect(tcs[1].GetId()).To(Equal("call_def"))
|
||||
g.Expect(tcs[1].GetName()).To(Equal("get_time"))
|
||||
}
|
||||
|
||||
// OpenAI: streaming tool call. Arguments arrive as a sequence of
|
||||
// delta chunks; the consumer is expected to concatenate by tool index.
|
||||
// Verify each chunk reaches the channel and the assembled arguments
|
||||
// match the input.
|
||||
func TestPredictStreamRich_OpenAI_ToolCallDeltas(t *testing.T) {
|
||||
chunks := []string{
|
||||
// Frame 0: announce the tool call (id + name, no args yet).
|
||||
`{"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_xyz","type":"function","function":{"name":"search"}}]}}]}`,
|
||||
// Frames 1-3: arguments arrive in fragments.
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"q\":"}}]}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"clo"}}]}}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"uds\"}"}}]}}]}`,
|
||||
// Stop frame.
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
}
|
||||
body := ""
|
||||
for _, c := range chunks {
|
||||
body += "data: " + c + "\n\n"
|
||||
}
|
||||
body += "data: [DONE]\n\n"
|
||||
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 16)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "find something"}},
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var (
|
||||
toolName string
|
||||
toolID string
|
||||
toolIndex int32 = -1
|
||||
argsBuf strings.Builder
|
||||
)
|
||||
for reply := range results {
|
||||
for _, cd := range reply.GetChatDeltas() {
|
||||
for _, tc := range cd.GetToolCalls() {
|
||||
if tc.GetName() != "" {
|
||||
toolName = tc.GetName()
|
||||
}
|
||||
if tc.GetId() != "" {
|
||||
toolID = tc.GetId()
|
||||
}
|
||||
if toolIndex == -1 {
|
||||
toolIndex = tc.GetIndex()
|
||||
}
|
||||
argsBuf.WriteString(tc.GetArguments())
|
||||
}
|
||||
}
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(toolID).To(Equal("call_xyz"))
|
||||
g.Expect(toolName).To(Equal("search"))
|
||||
g.Expect(toolIndex).To(Equal(int32(0)))
|
||||
g.Expect(argsBuf.String()).To(Equal(`{"q":"clouds"}`))
|
||||
}
|
||||
|
||||
// Anthropic: non-streaming tool_use block. The block appears in
|
||||
// Content[] alongside text blocks; the input field is a structured
|
||||
// JSON object. Map to ToolCallDelta with arguments as serialised JSON
|
||||
// so downstream OpenAI-shaped consumers see a familiar format.
|
||||
func TestPredictRich_Anthropic_ToolUse(t *testing.T) {
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, `{
|
||||
"id":"msg_1","type":"message","role":"assistant",
|
||||
"content":[
|
||||
{"type":"text","text":"Let me check that."},
|
||||
{"type":"tool_use","id":"toolu_01","name":"weather","input":{"location":"SF"}}
|
||||
],
|
||||
"model":"claude","usage":{"input_tokens":12,"output_tokens":34}
|
||||
}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
reply, err := cp.PredictRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "what's the weather?"}},
|
||||
Tokens: 64,
|
||||
})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(string(reply.GetMessage())).To(Equal("Let me check that."))
|
||||
g.Expect(reply.GetPromptTokens()).To(Equal(int32(12)))
|
||||
g.Expect(reply.GetTokens()).To(Equal(int32(34)))
|
||||
g.Expect(reply.GetChatDeltas()).To(HaveLen(1))
|
||||
g.Expect(reply.GetChatDeltas()[0].GetToolCalls()).To(HaveLen(1))
|
||||
tc := reply.GetChatDeltas()[0].GetToolCalls()[0]
|
||||
g.Expect(tc.GetId()).To(Equal("toolu_01"))
|
||||
g.Expect(tc.GetName()).To(Equal("weather"))
|
||||
g.Expect(tc.GetArguments()).To(ContainSubstring(`"location":"SF"`))
|
||||
}
|
||||
|
||||
// Anthropic: streaming tool_use. content_block_start announces the
|
||||
// tool's id + name; input_json_delta events carry argument fragments
|
||||
// which the consumer accumulates. message_delta carries final usage.
|
||||
func TestPredictStreamRich_Anthropic_InputJSONDelta(t *testing.T) {
|
||||
frames := []string{
|
||||
"event: message_start\ndata: {\"type\":\"message_start\"}\n\n",
|
||||
// Block 0 is a tool_use; consumer should allocate a slot.
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_42\",\"name\":\"lookup\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"q\\\":\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"\\\"rain\\\"}\"}}\n\n",
|
||||
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_delta\ndata: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n",
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
body := strings.Join(frames, "")
|
||||
|
||||
srv, _ := fakeAnthropicUpstream(t, func(_ anthropicRequest) (int, string, string) {
|
||||
return 200, body, "text/event-stream"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newAnthropicTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
results := make(chan *pb.Reply, 16)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cp.PredictStreamRich(&pb.PredictOptions{
|
||||
Messages: []*pb.Message{{Role: "user", Content: "rain?"}},
|
||||
Tokens: 64,
|
||||
}, results)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var (
|
||||
toolID, toolName string
|
||||
argsBuf strings.Builder
|
||||
finalTokens int32
|
||||
)
|
||||
for reply := range results {
|
||||
if reply.GetTokens() > 0 && len(reply.GetChatDeltas()) == 0 {
|
||||
finalTokens = reply.GetTokens()
|
||||
continue
|
||||
}
|
||||
for _, cd := range reply.GetChatDeltas() {
|
||||
for _, tc := range cd.GetToolCalls() {
|
||||
if tc.GetId() != "" {
|
||||
toolID = tc.GetId()
|
||||
}
|
||||
if tc.GetName() != "" {
|
||||
toolName = tc.GetName()
|
||||
}
|
||||
argsBuf.WriteString(tc.GetArguments())
|
||||
}
|
||||
}
|
||||
}
|
||||
err := <-done
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(toolID).To(Equal("toolu_42"))
|
||||
g.Expect(toolName).To(Equal("lookup"))
|
||||
g.Expect(argsBuf.String()).To(Equal(`{"q":"rain"}`))
|
||||
g.Expect(finalTokens).To(Equal(int32(7)))
|
||||
}
|
||||
|
||||
// Sanity: the legacy Predict() (string, error) signature still works
|
||||
// — it delegates to PredictRich and extracts Message.
|
||||
func TestPredict_LegacyWrapper_OpenAI(t *testing.T) {
|
||||
srv, _ := fakeOpenAIUpstream(t, func(_ openAIRequest) (int, string, string) {
|
||||
return 200, `{"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, "application/json"
|
||||
})
|
||||
defer srv.Close()
|
||||
g := NewWithT(t)
|
||||
cp := newTranslateCloudProxy(t, srv.URL)
|
||||
|
||||
got, err := cp.Predict(&pb.PredictOptions{Messages: []*pb.Message{{Role: "user", Content: "hi"}}})
|
||||
g.Expect(err).NotTo(HaveOccurred())
|
||||
g.Expect(got).To(Equal("hello"))
|
||||
}
|
||||
5
backend/go/crispasr/.gitignore
vendored
Normal file
5
backend/go/crispasr/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
sources
|
||||
build*
|
||||
libgocrispasr*.so
|
||||
crispasr
|
||||
package
|
||||
30
backend/go/crispasr/CMakeLists.txt
Normal file
30
backend/go/crispasr/CMakeLists.txt
Normal file
@@ -0,0 +1,30 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
project(gocrispasr LANGUAGES C CXX)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
add_subdirectory(./sources/CrispASR)
|
||||
|
||||
add_library(gocrispasr MODULE cpp/crispasr_shim.cpp)
|
||||
target_include_directories(gocrispasr PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sources/CrispASR/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sources/CrispASR/ggml/include)
|
||||
# Link the same backend set as crispasr-cli (examples/cli/CMakeLists.txt) so
|
||||
# the session API can dispatch to every compiled-in architecture, not just
|
||||
# whisper. crispasr is the referencer; the backend static libs supply the
|
||||
# per-architecture symbols; ggml is the math/runtime base.
|
||||
target_link_libraries(gocrispasr PRIVATE
|
||||
crispasr
|
||||
parakeet canary canary_ctc cohere granite_speech granite_nle
|
||||
voxtral voxtral4b qwen3_asr qwen3_tts orpheus chatterbox indextts
|
||||
kokoro voxcpm2_tts m2m100 t5_translate wav2vec2-ggml vibevoice
|
||||
silero-lid pyannote-seg funasr paraformer sensevoice
|
||||
crisp_audio
|
||||
ggml)
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
|
||||
target_link_libraries(gocrispasr PRIVATE stdc++fs)
|
||||
endif()
|
||||
|
||||
set_property(TARGET gocrispasr PROPERTY CXX_STANDARD 17)
|
||||
set_target_properties(gocrispasr PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||
132
backend/go/crispasr/Makefile
Normal file
132
backend/go/crispasr/Makefile
Normal file
@@ -0,0 +1,132 @@
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# CrispASR version (release tag)
|
||||
CRISPASR_REPO?=https://github.com/CrispStrobe/CrispASR
|
||||
CRISPASR_VERSION?=13d54e110e1538e0f0bc3af0680b9ab246cfb48d
|
||||
SO_TARGET?=libgocrispasr.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
# Keep the build lean: no tests/examples/server/SDL2/curl/ffmpeg (the FROM scratch
|
||||
# image cannot satisfy those runtime deps). All ASR/TTS model backends stay enabled.
|
||||
CMAKE_ARGS+=-DCRISPASR_BUILD_TESTS=OFF -DCRISPASR_BUILD_EXAMPLES=OFF -DCRISPASR_BUILD_SERVER=OFF
|
||||
CMAKE_ARGS+=-DCRISPASR_SDL2=OFF -DCRISPASR_CURL=OFF -DCRISPASR_FFMPEG=OFF
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DGGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
ifneq ($(BUILD_TYPE),metal)
|
||||
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||
else
|
||||
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
-DGGML_SYCL_F16=ON
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
endif
|
||||
|
||||
sources/CrispASR:
|
||||
mkdir -p sources/CrispASR
|
||||
cd sources/CrispASR && \
|
||||
git init && \
|
||||
git remote add origin $(CRISPASR_REPO) && \
|
||||
git fetch origin && \
|
||||
git checkout $(CRISPASR_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
# CrispASR's src/CMakeLists.txt locates its vendored llama.cpp
|
||||
# (crispasr-llama-core, used by the chat C-ABI) via ${CMAKE_SOURCE_DIR},
|
||||
# which assumes CrispASR is the top-level CMake project. We add_subdirectory
|
||||
# it, so ${CMAKE_SOURCE_DIR} is THIS backend dir and the talk-llama sources
|
||||
# aren't found. Rewrite to ${PROJECT_SOURCE_DIR} (the crispasr project root),
|
||||
# which is correct both standalone and as a subproject. Idempotent.
|
||||
sed -i 's#\$${CMAKE_SOURCE_DIR}/examples/talk-llama#\$${PROJECT_SOURCE_DIR}/examples/talk-llama#' sources/CrispASR/src/CMakeLists.txt
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
VARIANT_TARGETS = libgocrispasr-avx.so libgocrispasr-avx2.so libgocrispasr-avx512.so libgocrispasr-fallback.so
|
||||
else
|
||||
VARIANT_TARGETS = libgocrispasr-fallback.so
|
||||
endif
|
||||
|
||||
crispasr: main.go gocrispasr.go $(VARIANT_TARGETS)
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o crispasr ./
|
||||
|
||||
package: crispasr
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
clean: purge
|
||||
rm -rf libgocrispasr*.so package sources/CrispASR crispasr
|
||||
|
||||
purge:
|
||||
rm -rf build*
|
||||
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
libgocrispasr-avx.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx${RESET})
|
||||
SO_TARGET=libgocrispasr-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-avx2.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx2${RESET})
|
||||
SO_TARGET=libgocrispasr-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-avx512.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:avx512${RESET})
|
||||
SO_TARGET=libgocrispasr-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
endif
|
||||
|
||||
libgocrispasr-fallback.so: sources/CrispASR
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I crispasr build info:fallback${RESET})
|
||||
SO_TARGET=libgocrispasr-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgocrispasr-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgocrispasr-custom: CMakeLists.txt cpp/crispasr_shim.cpp cpp/crispasr_shim.h
|
||||
mkdir -p build-$(SO_TARGET) && \
|
||||
cd build-$(SO_TARGET) && \
|
||||
cmake .. $(CMAKE_ARGS) && \
|
||||
cmake --build . --config Release -j$(JOBS) && \
|
||||
cd .. && \
|
||||
mv build-$(SO_TARGET)/libgocrispasr.so ./$(SO_TARGET)
|
||||
|
||||
test: crispasr
|
||||
CGO_ENABLED=0 $(GOCMD) test -v ./...
|
||||
|
||||
all: crispasr package
|
||||
253
backend/go/crispasr/cpp/crispasr_shim.cpp
Normal file
253
backend/go/crispasr/cpp/crispasr_shim.cpp
Normal file
@@ -0,0 +1,253 @@
|
||||
#include "crispasr_shim.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "crispasr.h"
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
// Opaque session types. crispasr.h declares `struct crispasr_session;` but not
|
||||
// the result type nor the open/transcribe/result accessors — those are
|
||||
// CA_EXPORT extern "C" symbols in src/crispasr_c_api.cpp, so we forward-declare
|
||||
// exactly the ones we use. Signatures verified against
|
||||
// sources/CrispASR/src/crispasr_c_api.cpp.
|
||||
struct crispasr_session_result;
|
||||
extern "C" {
|
||||
crispasr_session *crispasr_session_open(const char *model_path, int n_threads);
|
||||
crispasr_session *crispasr_session_open_explicit(const char *model_path,
|
||||
const char *backend_name,
|
||||
int n_threads);
|
||||
int crispasr_session_set_codec_path(crispasr_session *s, const char *path);
|
||||
void crispasr_session_close(crispasr_session *s);
|
||||
const char *crispasr_session_backend(crispasr_session *s);
|
||||
int crispasr_session_set_translate(crispasr_session *s, int enable);
|
||||
crispasr_session_result *crispasr_session_transcribe_lang(
|
||||
crispasr_session *s, const float *pcm, int n_samples, const char *language);
|
||||
int crispasr_session_result_n_segments(crispasr_session_result *r);
|
||||
const char *crispasr_session_result_segment_text(crispasr_session_result *r,
|
||||
int i);
|
||||
int64_t crispasr_session_result_segment_t0(crispasr_session_result *r, int i);
|
||||
int64_t crispasr_session_result_segment_t1(crispasr_session_result *r, int i);
|
||||
void crispasr_session_result_free(crispasr_session_result *r);
|
||||
float *crispasr_session_synthesize(crispasr_session *s, const char *text,
|
||||
int *out_n_samples);
|
||||
void crispasr_pcm_free(float *pcm);
|
||||
int crispasr_session_set_speaker_name(crispasr_session *s, const char *name);
|
||||
int crispasr_session_set_voice(crispasr_session *s, const char *path,
|
||||
const char *ref_text_or_null);
|
||||
}
|
||||
|
||||
static crispasr_session *g_session = nullptr;
|
||||
static crispasr_session_result *g_result = nullptr;
|
||||
|
||||
static struct whisper_vad_context *vctx;
|
||||
static std::vector<float> flat_segs;
|
||||
|
||||
static std::atomic<int> g_abort{0};
|
||||
|
||||
extern "C" void set_abort(int v) {
|
||||
g_abort.store(v, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
static void ggml_log_cb(enum ggml_log_level level, const char *log,
|
||||
void *data) {
|
||||
const char *level_str;
|
||||
|
||||
if (!log) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (level) {
|
||||
case GGML_LOG_LEVEL_DEBUG:
|
||||
level_str = "DEBUG";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_INFO:
|
||||
level_str = "INFO";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_WARN:
|
||||
level_str = "WARN";
|
||||
break;
|
||||
case GGML_LOG_LEVEL_ERROR:
|
||||
level_str = "ERROR";
|
||||
break;
|
||||
default: /* Potential future-proofing */
|
||||
level_str = "?????";
|
||||
break;
|
||||
}
|
||||
|
||||
fprintf(stderr, "[%-5s] ", level_str);
|
||||
fputs(log, stderr);
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
int load_model(const char *const model_path, int threads,
|
||||
const char *backend_name) {
|
||||
whisper_log_set(ggml_log_cb, nullptr);
|
||||
ggml_backend_load_all();
|
||||
|
||||
if (backend_name && *backend_name) {
|
||||
g_session =
|
||||
crispasr_session_open_explicit(model_path, backend_name, threads);
|
||||
} else {
|
||||
g_session = crispasr_session_open(model_path, threads);
|
||||
}
|
||||
if (g_session == nullptr) {
|
||||
fprintf(stderr, "error: failed to open CrispASR session for model\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "info: CrispASR backend selected: %s\n",
|
||||
crispasr_session_backend(g_session));
|
||||
return 0;
|
||||
}
|
||||
|
||||
// set_codec_path forwards a companion file (qwen3-tts codec, orpheus SNAC,
|
||||
// chatterbox s3gen, or mimo-asr tokenizer) to the active session. Returns 0 on
|
||||
// success or when the active backend needs no companion, negative on failure,
|
||||
// and -1 when no session is open.
|
||||
int set_codec_path(const char *path) {
|
||||
return g_session ? crispasr_session_set_codec_path(g_session, path) : -1;
|
||||
}
|
||||
|
||||
int load_model_vad(const char *const model_path) {
|
||||
whisper_log_set(ggml_log_cb, nullptr);
|
||||
ggml_backend_load_all();
|
||||
|
||||
struct whisper_vad_context_params vcparams =
|
||||
whisper_vad_default_context_params();
|
||||
|
||||
// XXX: Overridden to false in upstream due to performance?
|
||||
// vcparams.use_gpu = true;
|
||||
|
||||
vctx = whisper_vad_init_from_file_with_params(model_path, vcparams);
|
||||
if (vctx == nullptr) {
|
||||
fprintf(stderr, "error: Failed to init model as VAD\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
|
||||
size_t *segs_out_len) {
|
||||
if (!whisper_vad_detect_speech(vctx, pcmf32, pcmf32_len)) {
|
||||
fprintf(stderr, "error: failed to detect speech\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct whisper_vad_params params = whisper_vad_default_params();
|
||||
struct whisper_vad_segments *segs =
|
||||
whisper_vad_segments_from_probs(vctx, params);
|
||||
size_t segn = whisper_vad_segments_n_segments(segs);
|
||||
|
||||
// fprintf(stderr, "Got segments %zd\n", segn);
|
||||
|
||||
flat_segs.clear();
|
||||
|
||||
for (int i = 0; i < segn; i++) {
|
||||
flat_segs.push_back(whisper_vad_segments_get_segment_t0(segs, i));
|
||||
flat_segs.push_back(whisper_vad_segments_get_segment_t1(segs, i));
|
||||
}
|
||||
|
||||
// fprintf(stderr, "setting out variables: %p=%p -> %p, %p=%zx -> %zx\n",
|
||||
// segs_out, *segs_out, flat_segs.data(), segs_out_len, *segs_out_len,
|
||||
// flat_segs.size());
|
||||
*segs_out = flat_segs.data();
|
||||
*segs_out_len = flat_segs.size();
|
||||
|
||||
// fprintf(stderr, "freeing segs\n");
|
||||
whisper_vad_free_segments(segs);
|
||||
|
||||
// fprintf(stderr, "returning\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// threads, diarize and prompt are accepted for Go-side API parity but unused
|
||||
// in Phase 1: the thread count is fixed at session open, and diarization and
|
||||
// the initial prompt are separate CrispASR features not yet wired through the
|
||||
// session ASR path.
|
||||
int transcribe(uint32_t threads, char *lang, bool translate, bool diarize,
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
|
||||
char *prompt) {
|
||||
(void)threads;
|
||||
(void)diarize;
|
||||
(void)prompt;
|
||||
|
||||
if (!g_session) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Reset stale abort flag from any prior cancelled call. set_abort remains
|
||||
// best-effort: the session transcribe call is blocking and exposes no abort
|
||||
// hook, so a mid-decode abort cannot interrupt it.
|
||||
g_abort.store(0, std::memory_order_relaxed);
|
||||
|
||||
crispasr_session_set_translate(g_session, translate ? 1 : 0);
|
||||
|
||||
if (g_result) {
|
||||
crispasr_session_result_free(g_result);
|
||||
g_result = nullptr;
|
||||
}
|
||||
|
||||
const char *language = (lang && *lang) ? lang : nullptr;
|
||||
g_result = crispasr_session_transcribe_lang(g_session, pcmf32, (int)pcmf32_len,
|
||||
language);
|
||||
if (!g_result) {
|
||||
fprintf(stderr, "error: transcription failed\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
*segs_out_len = crispasr_session_result_n_segments(g_result);
|
||||
return 0;
|
||||
}
|
||||
|
||||
const char *get_segment_text(int i) {
|
||||
if (!g_result) {
|
||||
return "";
|
||||
}
|
||||
return crispasr_session_result_segment_text(g_result, i);
|
||||
}
|
||||
|
||||
int64_t get_segment_t0(int i) {
|
||||
if (!g_result) {
|
||||
return 0;
|
||||
}
|
||||
return crispasr_session_result_segment_t0(g_result, i);
|
||||
}
|
||||
|
||||
int64_t get_segment_t1(int i) {
|
||||
if (!g_result) {
|
||||
return 0;
|
||||
}
|
||||
return crispasr_session_result_segment_t1(g_result, i);
|
||||
}
|
||||
|
||||
const char *get_backend(void) {
|
||||
return g_session ? crispasr_session_backend(g_session) : "";
|
||||
}
|
||||
|
||||
// TTS uses the already-open session (crispasr_session_open auto-detects a TTS
|
||||
// model). Output is 24 kHz mono float PCM (upstream CrispASR convention),
|
||||
// malloc'd by the C API; the caller must release it via tts_free.
|
||||
float *tts_synthesize(const char *text, int *out_n_samples) {
|
||||
if (out_n_samples) *out_n_samples = 0;
|
||||
if (!g_session || !text) return nullptr;
|
||||
return crispasr_session_synthesize(g_session, text, out_n_samples);
|
||||
}
|
||||
|
||||
void tts_free(float *pcm) {
|
||||
if (pcm) crispasr_pcm_free(pcm);
|
||||
}
|
||||
|
||||
int tts_set_voice(const char *name) {
|
||||
if (!g_session || !name || !*name) return 0;
|
||||
return crispasr_session_set_speaker_name(g_session, name);
|
||||
}
|
||||
|
||||
// tts_set_voice_file loads a voice from a file: a .gguf path selects a voice
|
||||
// pack, a .wav path with a non-empty ref_text performs zero-shot voice cloning
|
||||
// (the C API returns -2 when ref_text is required but missing). Returns -1 when
|
||||
// no session is open or path is null.
|
||||
int tts_set_voice_file(const char *path, const char *ref_text) {
|
||||
if (!g_session || !path) return -1;
|
||||
const char *ref = (ref_text && *ref_text) ? ref_text : nullptr;
|
||||
return crispasr_session_set_voice(g_session, path, ref);
|
||||
}
|
||||
23
backend/go/crispasr/cpp/crispasr_shim.h
Normal file
23
backend/go/crispasr/cpp/crispasr_shim.h
Normal file
@@ -0,0 +1,23 @@
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
extern "C" {
|
||||
int load_model(const char *const model_path, int threads,
|
||||
const char *backend_name);
|
||||
int set_codec_path(const char *path);
|
||||
int load_model_vad(const char *const model_path);
|
||||
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
|
||||
size_t *segs_out_len);
|
||||
int transcribe(uint32_t threads, char *lang, bool translate, bool diarize,
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
|
||||
char *prompt);
|
||||
const char *get_segment_text(int i);
|
||||
int64_t get_segment_t0(int i);
|
||||
int64_t get_segment_t1(int i);
|
||||
const char *get_backend(void);
|
||||
void set_abort(int v);
|
||||
float *tts_synthesize(const char *text, int *out_n_samples); // 24kHz mono float, malloc'd; NULL on failure
|
||||
void tts_free(float *pcm);
|
||||
int tts_set_voice(const char *name); // best-effort speaker selection; 0 ok
|
||||
int tts_set_voice_file(const char *path, const char *ref_text); // load voice pack (.gguf) or zero-shot clone (.wav + ref_text)
|
||||
}
|
||||
497
backend/go/crispasr/gocrispasr.go
Normal file
497
backend/go/crispasr/gocrispasr.go
Normal file
@@ -0,0 +1,497 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var (
|
||||
CppLoadModel func(modelPath string, threads int, backendName string) int
|
||||
CppSetCodecPath func(path string) int
|
||||
CppLoadModelVAD func(modelPath string) int
|
||||
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
|
||||
CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer, prompt string) int
|
||||
CppGetSegmentText func(i int) string
|
||||
CppGetSegmentStart func(i int) int64
|
||||
CppGetSegmentEnd func(i int) int64
|
||||
CppGetBackend func() string
|
||||
CppSetAbort func(v int)
|
||||
CppTTSSynthesize func(text string, outNSamples unsafe.Pointer) uintptr
|
||||
CppTTSFree func(ptr uintptr)
|
||||
CppTTSSetVoice func(name string) int
|
||||
CppTTSSetVoiceFile func(path string, refText string) int
|
||||
)
|
||||
|
||||
type CrispASR struct {
|
||||
base.SingleThread
|
||||
}
|
||||
|
||||
// splitOption splits a "prefix:value" model option into its key and value,
|
||||
// matching the convention used by other backends (see sherpa-onnx). It returns
|
||||
// ok=false when the option carries no ':' separator.
|
||||
func splitOption(oo string) (key, value string, ok bool) {
|
||||
parts := strings.SplitN(oo, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", "", false
|
||||
}
|
||||
return parts[0], parts[1], true
|
||||
}
|
||||
|
||||
func (w *CrispASR) Load(opts *pb.ModelOptions) error {
|
||||
vadOnly := false
|
||||
backendName := ""
|
||||
codecPath := ""
|
||||
speakerName := ""
|
||||
voicePath := ""
|
||||
voiceRefText := ""
|
||||
|
||||
for _, oo := range opts.Options {
|
||||
if oo == "vad_only" {
|
||||
vadOnly = true
|
||||
continue
|
||||
}
|
||||
switch key, value, ok := splitOption(oo); {
|
||||
case ok && key == "backend":
|
||||
backendName = value
|
||||
case ok && key == "codec":
|
||||
codecPath = value
|
||||
case ok && key == "speaker":
|
||||
speakerName = value
|
||||
case ok && key == "voice":
|
||||
voicePath = value
|
||||
case ok && key == "voice_text":
|
||||
voiceRefText = value
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||
}
|
||||
}
|
||||
|
||||
if vadOnly {
|
||||
if ret := CppLoadModelVAD(opts.ModelFile); ret != 0 {
|
||||
return fmt.Errorf("Failed to load CrispASR VAD model")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve a relative companion path against the model directory so a config
|
||||
// can reference a sibling codec/tokenizer file by name alone.
|
||||
if codecPath != "" && !filepath.IsAbs(codecPath) {
|
||||
codecPath = filepath.Join(filepath.Dir(opts.ModelFile), codecPath)
|
||||
}
|
||||
|
||||
// A voice file (.gguf pack or .wav prompt) is resolved against the model
|
||||
// directory just like the codec, so a config can reference a sibling file.
|
||||
if voicePath != "" && !filepath.IsAbs(voicePath) {
|
||||
voicePath = filepath.Join(filepath.Dir(opts.ModelFile), voicePath)
|
||||
}
|
||||
|
||||
if ret := CppLoadModel(opts.ModelFile, int(opts.Threads), backendName); ret != 0 {
|
||||
return fmt.Errorf("Failed to load CrispASR transcription model")
|
||||
}
|
||||
|
||||
// Load the companion file (codec/tokenizer/s3gen) after the session is open.
|
||||
// rc==0 means success or "not applicable" for the active backend; only a
|
||||
// negative code is fatal.
|
||||
if codecPath != "" {
|
||||
if rc := CppSetCodecPath(codecPath); rc < 0 {
|
||||
return fmt.Errorf("crispasr: failed to load companion file %q (rc=%d)", codecPath, rc)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "CrispASR companion file loaded: %s\n", codecPath)
|
||||
}
|
||||
|
||||
// Apply the Load-time default voice. A baked speaker (speaker:) is selected
|
||||
// by name and is best-effort: a backend that can't honor it is logged, not
|
||||
// fatal. A voice file (voice:) is a hard requirement once configured, so a
|
||||
// negative rc fails Load.
|
||||
if speakerName != "" {
|
||||
if rc := CppTTSSetVoice(speakerName); rc != 0 {
|
||||
fmt.Fprintf(os.Stderr, "crispasr: speaker %q not applied (rc=%d)\n", speakerName, rc)
|
||||
}
|
||||
}
|
||||
if voicePath != "" {
|
||||
if rc := CppTTSSetVoiceFile(voicePath, voiceRefText); rc < 0 {
|
||||
return fmt.Errorf("crispasr: failed to load voice %q (rc=%d)", voicePath, rc)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "CrispASR voice loaded: %s\n", voicePath)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "CrispASR backend selected: %s\n", CppGetBackend())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *CrispASR) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
|
||||
audio := req.Audio
|
||||
// We expect 0xdeadbeef to be overwritten and if we see it in a stack trace we know it wasn't
|
||||
segsPtr, segsLen := uintptr(0xdeadbeef), uintptr(0xdeadbeef)
|
||||
segsPtrPtr, segsLenPtr := unsafe.Pointer(&segsPtr), unsafe.Pointer(&segsLen)
|
||||
|
||||
if ret := CppVAD(audio, uintptr(len(audio)), segsPtrPtr, segsLenPtr); ret != 0 {
|
||||
return pb.VADResponse{}, fmt.Errorf("Failed VAD")
|
||||
}
|
||||
|
||||
// Happens when CPP vector has not had any elements pushed to it
|
||||
if segsPtr == 0 {
|
||||
return pb.VADResponse{
|
||||
Segments: []*pb.VADSegment{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// unsafeptr warning is caused by segsPtr being on the stack and therefor being subject to stack copying AFAICT
|
||||
// however the stack shouldn't have grown between setting segsPtr and now, also the memory pointed to is allocated by C++
|
||||
segs := unsafe.Slice((*float32)(unsafe.Pointer(segsPtr)), segsLen) //nolint:govet // segsPtr addresses C++-owned heap memory passed back through the cgo-free purego boundary; the uintptr->Pointer round-trip is intentional and the buffer outlives this read.
|
||||
|
||||
vadSegments := []*pb.VADSegment{}
|
||||
for i := range len(segs) >> 1 {
|
||||
s := segs[2*i] / 100
|
||||
t := segs[2*i+1] / 100
|
||||
vadSegments = append(vadSegments, &pb.VADSegment{
|
||||
Start: s,
|
||||
End: t,
|
||||
})
|
||||
}
|
||||
|
||||
return pb.VADResponse{
|
||||
Segments: vadSegments,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *CrispASR) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "crispasr")
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
convertedPath := filepath.Join(dir, "converted.wav")
|
||||
|
||||
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
|
||||
// Watcher: flips the C-side abort flag when ctx is cancelled. The
|
||||
// goroutine is joined synchronously (close(done) signals it to exit,
|
||||
// wg.Wait() blocks until it has) so a late CppSetAbort(1) cannot fire
|
||||
// after the function returns and corrupt the next transcription call.
|
||||
done := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
CppSetAbort(1)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||
if ret == 2 {
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if ret != 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
segments := []*pb.TranscriptSegment{}
|
||||
text := ""
|
||||
for i := range int(segsLen) {
|
||||
// segment start/end conversion factor taken from https://github.com/ggml-org/whisper.cpp/blob/master/examples/cli/cli.cpp#L895
|
||||
s := CppGetSegmentStart(i) * (10000000)
|
||||
t := CppGetSegmentEnd(i) * (10000000)
|
||||
// The session result can emit bytes that aren't valid UTF-8 (e.g. a
|
||||
// multibyte codepoint split across token boundaries); protobuf string
|
||||
// fields reject those at marshal time. Scrub before the value escapes
|
||||
// cgo. The session result is segment+word based and exposes no token
|
||||
// IDs, so Tokens is left empty.
|
||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
||||
|
||||
segment := &pb.TranscriptSegment{
|
||||
Id: int32(i),
|
||||
Text: txt,
|
||||
Start: s, End: t,
|
||||
}
|
||||
|
||||
segments = append(segments, segment)
|
||||
|
||||
text += " " + strings.TrimSpace(txt)
|
||||
}
|
||||
|
||||
return pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: strings.TrimSpace(text),
|
||||
Language: opts.Language,
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AudioTranscriptionStream runs the session transcribe to completion and then
|
||||
// emits one delta per non-empty segment, followed by a final TranscriptResult.
|
||||
// Progressive/real-time streaming isn't available via the session API (there
|
||||
// is no per-decode callback), so deltas are emitted per-segment after the
|
||||
// blocking decode returns rather than as segments are produced. The offline
|
||||
// AudioTranscription is unchanged; both paths share the session and the
|
||||
// SingleThread concurrency model.
|
||||
func (w *CrispASR) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
defer close(results)
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "crispasr")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
convertedPath := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
|
||||
// Same abort-watcher pattern as AudioTranscription. Joined synchronously
|
||||
// so a late CppSetAbort(1) cannot fire after this function returns.
|
||||
// Best-effort only: the session transcribe is blocking with no abort hook.
|
||||
done := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
CppSetAbort(1)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||
if ret == 2 {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
// Walk the segments once: emit a delta per non-empty segment and build the
|
||||
// final TranscriptResult.Segments alongside. The first delta has no leading
|
||||
// space and subsequent ones are prefixed with a single space, so
|
||||
// concat(deltas) == final.Text exactly, matching the e2e contract.
|
||||
segments := []*pb.TranscriptSegment{}
|
||||
var assembled strings.Builder
|
||||
for i := range int(segsLen) {
|
||||
s := CppGetSegmentStart(i) * 10000000
|
||||
t := CppGetSegmentEnd(i) * 10000000
|
||||
txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "<22>")
|
||||
segments = append(segments, &pb.TranscriptSegment{
|
||||
Id: int32(i),
|
||||
Text: txt,
|
||||
Start: s, End: t,
|
||||
})
|
||||
|
||||
trimmed := strings.TrimSpace(txt)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
var delta string
|
||||
if assembled.Len() == 0 {
|
||||
delta = trimmed
|
||||
} else {
|
||||
delta = " " + trimmed
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{Delta: delta}
|
||||
assembled.WriteString(delta)
|
||||
}
|
||||
|
||||
final := &pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: assembled.String(),
|
||||
Language: opts.Language,
|
||||
Duration: duration,
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{FinalResult: final}
|
||||
return nil
|
||||
}
|
||||
|
||||
// synthesize returns 24 kHz mono float32 PCM for text via the open session.
|
||||
func (w *CrispASR) synthesize(text string) ([]float32, error) {
|
||||
if text == "" {
|
||||
return nil, fmt.Errorf("crispasr: TTS requires non-empty text")
|
||||
}
|
||||
var n int32
|
||||
ptr := CppTTSSynthesize(text, unsafe.Pointer(&n))
|
||||
if ptr == 0 || n <= 0 {
|
||||
return nil, fmt.Errorf("crispasr: synthesis failed (the loaded model may not be a supported TTS backend, or needs extra config e.g. orpheus SNAC codec)")
|
||||
}
|
||||
defer CppTTSFree(ptr)
|
||||
src := unsafe.Slice((*float32)(unsafe.Pointer(ptr)), int(n)) //nolint:govet // ptr addresses C-allocated PCM returned across the purego boundary; copied out immediately below, before tts_free.
|
||||
out := make([]float32, int(n)) // copy out of C memory before free
|
||||
copy(out, src)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// setVoice applies a per-call speaker/voice override (best effort). CrispASR
|
||||
// returns a negative code when the active backend can't honor the name; we log
|
||||
// it rather than fail, so an unknown voice falls back to the default speaker.
|
||||
func setVoice(voice string) {
|
||||
v := strings.TrimSpace(voice)
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
if rc := CppTTSSetVoice(v); rc != 0 {
|
||||
fmt.Fprintf(os.Stderr, "crispasr: voice %q not applied by the active TTS backend (rc=%d); using default\n", v, rc)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *CrispASR) TTS(req *pb.TTSRequest) error {
|
||||
if req.Dst == "" {
|
||||
return fmt.Errorf("crispasr: TTS requires a destination path")
|
||||
}
|
||||
setVoice(req.Voice)
|
||||
pcm, err := w.synthesize(req.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWAV24k(req.Dst, pcm)
|
||||
}
|
||||
|
||||
// TTSStream is the streaming counterpart to TTS. CrispASR has no progressive
|
||||
// (native streaming) synth, so we synthesize the whole utterance, encode it to
|
||||
// a 24 kHz WAV, and emit the encoded bytes as a single chunk. The gRPC server
|
||||
// wrapper (pkg/grpc/server.go:TTSStream) ranges over the channel until it is
|
||||
// closed, so this method owns the close - mirrors vibevoice-cpp's TTSStream.
|
||||
func (w *CrispASR) TTSStream(req *pb.TTSRequest, results chan []byte) error {
|
||||
defer close(results)
|
||||
|
||||
if req.Text == "" {
|
||||
return fmt.Errorf("crispasr: TTSStream requires text")
|
||||
}
|
||||
setVoice(req.Voice)
|
||||
pcm, err := w.synthesize(req.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmp, err := os.CreateTemp("", "crispasr-tts-stream-*.wav")
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: tempfile: %w", err)
|
||||
}
|
||||
dst := tmp.Name()
|
||||
if err := tmp.Close(); err != nil {
|
||||
return fmt.Errorf("crispasr: close tempfile: %w", err)
|
||||
}
|
||||
defer func() { _ = os.Remove(dst) }()
|
||||
|
||||
if err := writeWAV24k(dst, pcm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encoded, err := os.ReadFile(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: read tempfile: %w", err)
|
||||
}
|
||||
results <- encoded
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeWAV24k writes pcm as a 24000 Hz, mono, 16-bit PCM WAV at dst.
|
||||
func writeWAV24k(dst string, pcm []float32) error {
|
||||
f, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("crispasr: create %q: %w", dst, err)
|
||||
}
|
||||
|
||||
enc := wav.NewEncoder(f, 24000, 16, 1, 1)
|
||||
ints := make([]int, len(pcm))
|
||||
for i, s := range pcm {
|
||||
if s > 1 {
|
||||
s = 1
|
||||
} else if s < -1 {
|
||||
s = -1
|
||||
}
|
||||
ints[i] = int(s * 32767)
|
||||
}
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: 24000},
|
||||
Data: ints,
|
||||
SourceBitDepth: 16,
|
||||
}
|
||||
if err := enc.Write(buf); err != nil {
|
||||
_ = enc.Close()
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("crispasr: encode WAV: %w", err)
|
||||
}
|
||||
if err := enc.Close(); err != nil {
|
||||
_ = f.Close()
|
||||
return fmt.Errorf("crispasr: finalize WAV: %w", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return fmt.Errorf("crispasr: close %q: %w", dst, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
193
backend/go/crispasr/gocrispasr_test.go
Normal file
193
backend/go/crispasr/gocrispasr_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestCrispASR(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "CrispASR Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
|
||||
// bridge without spinning up the gRPC server. Skips the current spec when the
|
||||
// shared library isn't present (e.g. running before `make backends/whisper`).
|
||||
func ensureLibLoaded() {
|
||||
libLoadOnce.Do(func() {
|
||||
libName := os.Getenv("CRISPASR_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgocrispasr-fallback.so"
|
||||
}
|
||||
if _, err := os.Stat(libName); err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&CppLoadModel, gosd, "load_model")
|
||||
purego.RegisterLibFunc(&CppSetCodecPath, gosd, "set_codec_path")
|
||||
purego.RegisterLibFunc(&CppTranscribe, gosd, "transcribe")
|
||||
purego.RegisterLibFunc(&CppGetSegmentText, gosd, "get_segment_text")
|
||||
purego.RegisterLibFunc(&CppGetSegmentStart, gosd, "get_segment_t0")
|
||||
purego.RegisterLibFunc(&CppGetSegmentEnd, gosd, "get_segment_t1")
|
||||
purego.RegisterLibFunc(&CppGetBackend, gosd, "get_backend")
|
||||
purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort")
|
||||
purego.RegisterLibFunc(&CppTTSSynthesize, gosd, "tts_synthesize")
|
||||
purego.RegisterLibFunc(&CppTTSFree, gosd, "tts_free")
|
||||
purego.RegisterLibFunc(&CppTTSSetVoice, gosd, "tts_set_voice")
|
||||
purego.RegisterLibFunc(&CppTTSSetVoiceFile, gosd, "tts_set_voice_file")
|
||||
})
|
||||
if libLoadErr != nil {
|
||||
Skip("whisper library not loadable: " + libLoadErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// fixturesOrSkip returns the model + audio paths or skips the spec if either
|
||||
// env var is unset. The test never runs in default CI — it requires a real
|
||||
// whisper model and a long audio file (~3 minutes) on disk.
|
||||
func fixturesOrSkip() (string, string) {
|
||||
modelPath := os.Getenv("CRISPASR_MODEL_PATH")
|
||||
audioPath := os.Getenv("CRISPASR_AUDIO_PATH")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set CRISPASR_MODEL_PATH and CRISPASR_AUDIO_PATH to run this spec")
|
||||
}
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
// ttsModelOrSkip returns the TTS model path or skips the spec when the env var
|
||||
// is unset. Like the transcription fixtures, this never runs in default CI — it
|
||||
// needs a real TTS model (e.g. a vibevoice GGUF) on disk.
|
||||
func ttsModelOrSkip() string {
|
||||
modelPath := os.Getenv("CRISPASR_TTS_MODEL_PATH")
|
||||
if modelPath == "" {
|
||||
Skip("set CRISPASR_TTS_MODEL_PATH to run this spec")
|
||||
}
|
||||
return modelPath
|
||||
}
|
||||
|
||||
var _ = Describe("CrispASR", func() {
|
||||
Context("AudioTranscription cancellation", func() {
|
||||
It("returns codes.Canceled on a pre-cancelled context and still succeeds afterwards", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
|
||||
// The session transcribe is blocking and exposes no abort hook, so
|
||||
// a mid-decode cancel can't interrupt it. The contract we can rely
|
||||
// on is the pre-call ctx.Err() check: a context cancelled before
|
||||
// the call must yield codes.Canceled without starting a decode.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := w.AudioTranscription(ctx, &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "expected pre-cancelled context to fail")
|
||||
st, ok := status.FromError(err)
|
||||
Expect(ok).To(BeTrue(), "expected gRPC status error, got %v", err)
|
||||
Expect(st.Code()).To(Equal(codes.Canceled), "expected codes.Canceled, got %v", err)
|
||||
|
||||
// Subsequent transcription must succeed — proves g_abort reset.
|
||||
res, err := w.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred(), "post-cancel transcription failed")
|
||||
Expect(res.Text).ToNot(BeEmpty(), "post-cancel transcription returned empty text")
|
||||
})
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("emits multiple deltas progressively for a multi-segment clip", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
|
||||
results := make(chan *pb.TranscriptStreamResponse, 64)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- w.AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
Stream: true,
|
||||
}, results)
|
||||
}()
|
||||
|
||||
var deltas []string
|
||||
var assembled strings.Builder
|
||||
var finalText string
|
||||
var finalSegmentCount int
|
||||
for chunk := range results {
|
||||
if d := chunk.GetDelta(); d != "" {
|
||||
deltas = append(deltas, d)
|
||||
assembled.WriteString(d)
|
||||
}
|
||||
if final := chunk.GetFinalResult(); final != nil {
|
||||
finalText = final.GetText()
|
||||
finalSegmentCount = len(final.GetSegments())
|
||||
}
|
||||
}
|
||||
Expect(<-done).ToNot(HaveOccurred())
|
||||
|
||||
// One delta per non-empty segment is emitted after the blocking
|
||||
// decode returns (the session API has no per-decode callback), so a
|
||||
// multi-segment clip MUST produce >=2 delta events, and
|
||||
// concat(deltas) MUST equal final.Text exactly.
|
||||
Expect(len(deltas)).To(BeNumerically(">=", 2),
|
||||
"expected multiple deltas from a multi-segment clip, got %d (assembled=%q)",
|
||||
len(deltas), assembled.String())
|
||||
Expect(finalSegmentCount).To(BeNumerically(">=", 2),
|
||||
"expected final to carry multiple segments")
|
||||
Expect(assembled.String()).To(Equal(finalText),
|
||||
"concat(deltas) must equal final.Text")
|
||||
})
|
||||
})
|
||||
|
||||
Context("TTS", func() {
|
||||
It("synthesizes a non-empty WAV", func() {
|
||||
ttsModel := ttsModelOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
w := &CrispASR{}
|
||||
Expect(w.Load(&pb.ModelOptions{ModelFile: ttsModel})).To(Succeed())
|
||||
|
||||
dst := filepath.Join(GinkgoT().TempDir(), "out.wav")
|
||||
Expect(w.TTS(&pb.TTSRequest{Text: "Hello from CrispASR.", Dst: dst})).To(Succeed())
|
||||
|
||||
info, err := os.Stat(dst)
|
||||
Expect(err).ToNot(HaveOccurred(), "synthesized WAV should exist at %q", dst)
|
||||
// A real 24 kHz mono WAV is a 44-byte header plus samples; anything
|
||||
// this small would mean an empty/failed synth.
|
||||
Expect(info.Size()).To(BeNumerically(">", 1024),
|
||||
"expected a non-trivial WAV, got %d bytes", info.Size())
|
||||
})
|
||||
})
|
||||
})
|
||||
58
backend/go/crispasr/main.go
Normal file
58
backend/go/crispasr/main.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("CRISPASR_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgocrispasr-fallback.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppLoadModel, "load_model"},
|
||||
{&CppSetCodecPath, "set_codec_path"},
|
||||
{&CppLoadModelVAD, "load_model_vad"},
|
||||
{&CppVAD, "vad"},
|
||||
{&CppTranscribe, "transcribe"},
|
||||
{&CppGetSegmentText, "get_segment_text"},
|
||||
{&CppGetSegmentStart, "get_segment_t0"},
|
||||
{&CppGetSegmentEnd, "get_segment_t1"},
|
||||
{&CppGetBackend, "get_backend"},
|
||||
{&CppSetAbort, "set_abort"},
|
||||
{&CppTTSSynthesize, "tts_synthesize"},
|
||||
{&CppTTSFree, "tts_free"},
|
||||
{&CppTTSSetVoice, "tts_set_voice"},
|
||||
{&CppTTSSetVoiceFile, "tts_set_voice_file"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &CrispASR{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
65
backend/go/crispasr/package.sh
Executable file
65
backend/go/crispasr/package.sh
Executable file
@@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
# This script is used in the final stage of the Dockerfile
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/crispasr $CURDIR/package/
|
||||
cp -fv $CURDIR/libgocrispasr-*.so $CURDIR/package/
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ $(uname -s) = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
# The GPU library packaging script will detect BUILD_TYPE and copy appropriate GPU libraries
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
52
backend/go/crispasr/run.sh
Executable file
52
backend/go/crispasr/run.sh
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
fi
|
||||
|
||||
LIBRARY="$CURDIR/libgocrispasr-fallback.so"
|
||||
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx2.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx2.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/libgocrispasr-avx512.so ]; then
|
||||
LIBRARY="$CURDIR/libgocrispasr-avx512.so"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
export CRISPASR_LIBRARY=$LIBRARY
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/crispasr "$@"
|
||||
fi
|
||||
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/crispasr "$@"
|
||||
@@ -8,6 +8,6 @@ import (
|
||||
|
||||
func assert(cond bool, msg string) {
|
||||
if !cond {
|
||||
xlog.Fatal().Stack().Msg(msg)
|
||||
xlog.Fatal(msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,22 @@
|
||||
package main
|
||||
|
||||
// This is a wrapper to statisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
// LocalAI's in-process vector store, exposed as a gRPC backend. Keep
|
||||
// the implementation here — NOT in a pkg/ library imported by the main
|
||||
// LocalAI process. The whole point of the gRPC surface is that vector
|
||||
// storage is a backend like any other (local-store, qdrant, pinecone,
|
||||
// ...) and can be swapped without changing the routing/recognition
|
||||
// code that consumes it.
|
||||
//
|
||||
// Storage is a sorted parallel-slice (keys [][]float32, values
|
||||
// [][]byte). Set/Delete preserve the sort so Get can binary-search.
|
||||
// Find scans linearly and uses a heap to keep the top-K — fine for
|
||||
// the tens-to-thousands range. The "normalized fast path" (Find when
|
||||
// every stored key has unit magnitude AND the query is normalized)
|
||||
// skips the per-item magnitude calculation.
|
||||
//
|
||||
// Concurrency: base.SingleThread serialises gRPC calls so the
|
||||
// non-thread-safe slice/heap manipulation here is sound.
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
@@ -10,30 +25,27 @@ import (
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/mudler/LocalAI/pkg/store"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
base.SingleThread
|
||||
|
||||
// The sorted keys
|
||||
keys [][]float32
|
||||
// The sorted values
|
||||
keys [][]float32
|
||||
values [][]byte
|
||||
|
||||
// If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions
|
||||
// TODO: Should we normalize incoming keys if they are not instead?
|
||||
// keysAreNormalized stays true until any non-unit-magnitude key
|
||||
// is added; once false, the magnitude-aware fallback path is
|
||||
// used by Find. Re-evaluated only at Set time, never again on
|
||||
// its own — a deletion of the offending key does NOT flip it
|
||||
// back to true (the bookkeeping cost would dominate the gain).
|
||||
keysAreNormalized bool
|
||||
// The first key decides the length of the keys
|
||||
keyLen int
|
||||
}
|
||||
|
||||
// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because
|
||||
// that's theoretically best for memory layout and cache locality, but this isn't optimized yet.
|
||||
type Pair struct {
|
||||
Key []float32
|
||||
Value []byte
|
||||
// keyLen is the dimension of every stored key. -1 means "no
|
||||
// keys yet, dimension is open". Dimension mismatch on Set is
|
||||
// rejected so cosine similarity (which requires equal-length
|
||||
// vectors) doesn't silently mis-match.
|
||||
keyLen int
|
||||
}
|
||||
|
||||
func NewStore() *Store {
|
||||
@@ -45,334 +57,278 @@ func NewStore() *Store {
|
||||
}
|
||||
}
|
||||
|
||||
func compareSlices(k1, k2 []float32) int {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
return slices.Compare(k1, k2)
|
||||
}
|
||||
|
||||
func hasKey(unsortedSlice [][]float32, target []float32) bool {
|
||||
return slices.ContainsFunc(unsortedSlice, func(k []float32) bool {
|
||||
return compareSlices(k, target) == 0
|
||||
})
|
||||
}
|
||||
|
||||
func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) {
|
||||
return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int {
|
||||
return compareSlices(k, t)
|
||||
})
|
||||
}
|
||||
|
||||
func isSortedPairs(kvs []Pair) bool {
|
||||
for i := 1; i < len(kvs); i++ {
|
||||
if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isSortedKeys(keys [][]float32) bool {
|
||||
for i := 1; i < len(keys); i++ {
|
||||
if compareSlices(keys[i-1], keys[i]) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 {
|
||||
ks := make([][]float32, len(keys))
|
||||
|
||||
for i, k := range keys {
|
||||
ks[i] = k.Floats
|
||||
}
|
||||
|
||||
slices.SortFunc(ks, compareSlices)
|
||||
|
||||
assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys)))
|
||||
assert(isSortedKeys(ks), "keys are not sorted")
|
||||
|
||||
return ks
|
||||
}
|
||||
|
||||
// Load is a no-op — local-store has no on-disk artefact. opts.Model is
|
||||
// just a namespace identifier; isolation is already handled upstream
|
||||
// (ModelLoader spawns a fresh local-store process per (backend,
|
||||
// model) tuple, so each namespace is its own Store{} instance).
|
||||
func (s *Store) Load(opts *pb.ModelOptions) error {
|
||||
// local-store is an in-memory vector store with no on-disk artefact to
|
||||
// load — opts.Model is just a namespace identifier. The old `!= ""` guard
|
||||
// rejected any non-empty model name with "not implemented", which broke
|
||||
// callers that pass a namespace to isolate embedding spaces (face vs.
|
||||
// voice biometrics both go through local-store but need distinct stores
|
||||
// so ArcFace 512-D and ECAPA-TDNN 192-D don't collide). Namespace
|
||||
// isolation is already handled upstream: ModelLoader spawns a fresh
|
||||
// local-store process per (backend, model) tuple, so each namespace is
|
||||
// its own Store{} instance. Nothing to do here beyond accepting the load.
|
||||
_ = opts
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort the incoming kvs and merge them with the existing sorted kvs
|
||||
func (s *Store) StoresSet(opts *pb.StoresSetOptions) error {
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to add")
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
values := store.UnwrapValues(opts.Values)
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("local-store: Set: no keys to add")
|
||||
}
|
||||
|
||||
if len(opts.Keys) != len(opts.Values) {
|
||||
return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values))
|
||||
if len(keys) != len(values) {
|
||||
return fmt.Errorf("local-store: Set: len(keys) = %d, len(values) = %d", len(keys), len(values))
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
s.keyLen = len(keys[0])
|
||||
} else if len(keys[0]) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Set: key length %d does not match existing %d", len(keys[0]), s.keyLen)
|
||||
}
|
||||
|
||||
kvs := make([]Pair, len(opts.Keys))
|
||||
|
||||
for i, k := range opts.Keys {
|
||||
if s.keysAreNormalized && !isNormalized(k.Floats) {
|
||||
kvs := make([]incomingPair, len(keys))
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Set: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
if s.keysAreNormalized && !isNormalized(k) {
|
||||
s.keysAreNormalized = false
|
||||
var sample []float32
|
||||
if len(s.keys) > 5 {
|
||||
sample = k.Floats[:5]
|
||||
} else {
|
||||
sample = k.Floats
|
||||
}
|
||||
xlog.Debug("Key is not normalized", "sample", sample)
|
||||
}
|
||||
|
||||
kvs[i] = Pair{
|
||||
Key: k.Floats,
|
||||
Value: opts.Values[i].Bytes,
|
||||
}
|
||||
kvs[i] = incomingPair{key: k, value: values[i]}
|
||||
}
|
||||
|
||||
slices.SortFunc(kvs, func(a, b Pair) int {
|
||||
return compareSlices(a.Key, b.Key)
|
||||
})
|
||||
|
||||
assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys)))
|
||||
assert(isSortedPairs(kvs), "keys are not sorted")
|
||||
|
||||
l := len(kvs) + len(s.keys)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
i, j := 0, 0
|
||||
for {
|
||||
if i+j >= l {
|
||||
break
|
||||
}
|
||||
|
||||
if i >= len(kvs) {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
continue
|
||||
}
|
||||
|
||||
if j >= len(s.keys) {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
c := compareSlices(kvs[i].Key, s.keys[j])
|
||||
if c < 0 {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
} else if c > 0 {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
} else {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l))
|
||||
assert(isSortedKeys(merge_ks), "merge keys are not sorted")
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
slices.SortFunc(kvs, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) })
|
||||
|
||||
merged := mergeSortedPairs(s.keys, s.values, kvs)
|
||||
s.keys = merged.keys
|
||||
s.values = merged.values
|
||||
assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Set: s.keys not sorted post-merge")
|
||||
assert(len(s.keys) == len(s.values), "Set: keys/values length skew")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error {
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to delete")
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("local-store: Delete: no keys to delete")
|
||||
}
|
||||
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to add")
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
ks := sortIntoKeySlicese(opts.Keys)
|
||||
|
||||
l := len(s.keys) - len(ks)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
tail_ks := s.keys
|
||||
tail_vs := s.values
|
||||
for _, k := range ks {
|
||||
j, found := findInSortedSlice(tail_ks, k)
|
||||
|
||||
if found {
|
||||
merge_ks = append(merge_ks, tail_ks[:j]...)
|
||||
merge_vs = append(merge_vs, tail_vs[:j]...)
|
||||
tail_ks = tail_ks[j+1:]
|
||||
tail_vs = tail_vs[j+1:]
|
||||
} else {
|
||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k))
|
||||
}
|
||||
|
||||
xlog.Debug("Delete", "found", found, "tailLen", len(tail_ks), "j", j, "mergeKeysLen", len(merge_ks), "mergeValuesLen", len(merge_vs))
|
||||
}
|
||||
|
||||
merge_ks = append(merge_ks, tail_ks...)
|
||||
merge_vs = append(merge_vs, tail_vs...)
|
||||
|
||||
assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys)))
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
|
||||
assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l))
|
||||
assert(isSortedKeys(s.keys), "keys are not sorted")
|
||||
assert(func() bool {
|
||||
for _, k := range ks {
|
||||
if _, found := findInSortedSlice(s.keys, k); found {
|
||||
return false
|
||||
if s.keyLen != -1 {
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return fmt.Errorf("local-store: Delete: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}(), "Keys to delete still present")
|
||||
|
||||
if len(s.keys) != l {
|
||||
xlog.Debug("Delete: Some keys not found", "keysLen", len(s.keys), "expectedLen", l)
|
||||
}
|
||||
sortedKeys := append([][]float32(nil), keys...)
|
||||
slices.SortFunc(sortedKeys, slices.Compare[[]float32])
|
||||
|
||||
mergedK := make([][]float32, 0, len(s.keys))
|
||||
mergedV := make([][]byte, 0, len(s.keys))
|
||||
tailK := s.keys
|
||||
tailV := s.values
|
||||
for _, k := range sortedKeys {
|
||||
j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32])
|
||||
if ok {
|
||||
mergedK = append(mergedK, tailK[:j]...)
|
||||
mergedV = append(mergedV, tailV[:j]...)
|
||||
tailK = tailK[j+1:]
|
||||
tailV = tailV[j+1:]
|
||||
}
|
||||
}
|
||||
mergedK = append(mergedK, tailK...)
|
||||
mergedV = append(mergedV, tailV...)
|
||||
s.keys = mergedK
|
||||
s.values = mergedV
|
||||
assert(slices.IsSortedFunc(s.keys, slices.Compare[[]float32]), "Delete: s.keys not sorted post-merge")
|
||||
assert(len(s.keys) == len(s.values), "Delete: keys/values length skew")
|
||||
return nil
|
||||
}
|
||||
|
||||
// StoresGet fetches values for the given keys. Missing keys are
|
||||
// omitted from the result rather than reported as an error — callers
|
||||
// compare returned-key length against requested-key length to detect
|
||||
// them. Returned slices are aligned.
|
||||
func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) {
|
||||
pbKeys := make([]*pb.StoresKey, 0, len(opts.Keys))
|
||||
pbValues := make([]*pb.StoresValue, 0, len(opts.Keys))
|
||||
ks := sortIntoKeySlicese(opts.Keys)
|
||||
|
||||
keys := store.UnwrapKeys(opts.Keys)
|
||||
if len(s.keys) == 0 {
|
||||
xlog.Debug("Get: No keys in store")
|
||||
return pb.StoresGetResult{}, nil
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return pb.StoresGetResult{}, fmt.Errorf("Try to get a key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
if s.keyLen != -1 {
|
||||
for i, k := range keys {
|
||||
if len(k) != s.keyLen {
|
||||
return pb.StoresGetResult{}, fmt.Errorf("local-store: Get: key %d length %d does not match existing %d", i, len(k), s.keyLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
sortedKeys := append([][]float32(nil), keys...)
|
||||
slices.SortFunc(sortedKeys, slices.Compare[[]float32])
|
||||
|
||||
tail_k := s.keys
|
||||
tail_v := s.values
|
||||
for i, k := range ks {
|
||||
j, found := findInSortedSlice(tail_k, k)
|
||||
|
||||
if found {
|
||||
pbKeys = append(pbKeys, &pb.StoresKey{
|
||||
Floats: k,
|
||||
})
|
||||
pbValues = append(pbValues, &pb.StoresValue{
|
||||
Bytes: tail_v[j],
|
||||
})
|
||||
|
||||
tail_k = tail_k[j+1:]
|
||||
tail_v = tail_v[j+1:]
|
||||
} else {
|
||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: i=%d, %v", i, k))
|
||||
var foundKeys [][]float32
|
||||
var foundValues [][]byte
|
||||
tailK := s.keys
|
||||
tailV := s.values
|
||||
for _, k := range sortedKeys {
|
||||
j, ok := slices.BinarySearchFunc(tailK, k, slices.Compare[[]float32])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
foundKeys = append(foundKeys, tailK[j])
|
||||
foundValues = append(foundValues, tailV[j])
|
||||
tailK = tailK[j+1:]
|
||||
tailV = tailV[j+1:]
|
||||
}
|
||||
|
||||
if len(pbKeys) != len(opts.Keys) {
|
||||
xlog.Debug("Get: Some keys not found", "pbKeysLen", len(pbKeys), "optsKeysLen", len(opts.Keys), "storeKeysLen", len(s.keys))
|
||||
}
|
||||
|
||||
return pb.StoresGetResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Keys: store.WrapKeys(foundKeys),
|
||||
Values: store.WrapValues(foundValues),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StoresFind returns the topK nearest stored entries by cosine
|
||||
// similarity, ordered most-similar first. An empty store returns
|
||||
// empty slices and no error.
|
||||
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
query := opts.Key.Floats
|
||||
topK := int(opts.TopK)
|
||||
if topK < 1 {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: topK = %d, must be >= 1", topK)
|
||||
}
|
||||
if len(s.keys) == 0 {
|
||||
return pb.StoresFindResult{}, nil
|
||||
}
|
||||
if len(query) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("local-store: Find: query length %d does not match existing %d", len(query), s.keyLen)
|
||||
}
|
||||
|
||||
var keys [][]float32
|
||||
var values [][]byte
|
||||
var sims []float32
|
||||
if s.keysAreNormalized && isNormalized(query) {
|
||||
keys, values, sims = s.findNormalized(query, topK)
|
||||
} else {
|
||||
keys, values, sims = s.findFallback(query, topK)
|
||||
}
|
||||
return pb.StoresFindResult{
|
||||
Keys: store.WrapKeys(keys),
|
||||
Values: store.WrapValues(values),
|
||||
Similarities: sims,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) findNormalized(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
assert(s.keysAreNormalized, "findNormalized: s.keysAreNormalized is false")
|
||||
assert(isNormalized(query), "findNormalized: query is not unit-length")
|
||||
pq := make(priorityQueue, 0, topK)
|
||||
heap.Init(&pq)
|
||||
for i, k := range s.keys {
|
||||
var dot float32
|
||||
for j := range k {
|
||||
dot += query[j] * k[j]
|
||||
}
|
||||
assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("findNormalized: dot %f out of [-1, 1] — keysAreNormalized invariant violated", dot))
|
||||
heap.Push(&pq, &priorityItem{similarity: dot, key: k, value: s.values[i]})
|
||||
if pq.Len() > topK {
|
||||
heap.Pop(&pq)
|
||||
}
|
||||
}
|
||||
return drainPQ(&pq)
|
||||
}
|
||||
|
||||
func (s *Store) findFallback(query []float32, topK int) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
var qmag float64
|
||||
for _, v := range query {
|
||||
qmag += float64(v) * float64(v)
|
||||
}
|
||||
qmag = math.Sqrt(qmag)
|
||||
pq := make(priorityQueue, 0, topK)
|
||||
heap.Init(&pq)
|
||||
for i, k := range s.keys {
|
||||
var dot, kmag float64
|
||||
for j := range k {
|
||||
dot += float64(query[j]) * float64(k[j])
|
||||
kmag += float64(k[j]) * float64(k[j])
|
||||
}
|
||||
denom := qmag * math.Sqrt(kmag)
|
||||
var sim float32
|
||||
if denom > 0 {
|
||||
sim = float32(dot / denom)
|
||||
}
|
||||
heap.Push(&pq, &priorityItem{similarity: sim, key: k, value: s.values[i]})
|
||||
if pq.Len() > topK {
|
||||
heap.Pop(&pq)
|
||||
}
|
||||
}
|
||||
return drainPQ(&pq)
|
||||
}
|
||||
|
||||
func isNormalized(k []float32) bool {
|
||||
var sum float64
|
||||
|
||||
for _, v := range k {
|
||||
v64 := float64(v)
|
||||
sum += v64 * v64
|
||||
sum += float64(v) * float64(v)
|
||||
}
|
||||
|
||||
s := math.Sqrt(sum)
|
||||
|
||||
return s >= 0.99 && s <= 1.01
|
||||
mag := math.Sqrt(sum)
|
||||
return mag >= 0.99 && mag <= 1.01
|
||||
}
|
||||
|
||||
// TODO: This we could replace with handwritten SIMD code
|
||||
func normalizedCosineSimilarity(k1, k2 []float32) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
type incomingPair struct {
|
||||
key []float32
|
||||
value []byte
|
||||
}
|
||||
|
||||
var dot float32
|
||||
for i := range len(k1) {
|
||||
dot += k1[i] * k2[i]
|
||||
type pairs struct {
|
||||
keys [][]float32
|
||||
values [][]byte
|
||||
}
|
||||
|
||||
// mergeSortedPairs merges (existing, incoming) into a fresh sorted
|
||||
// slice. Equal keys take the incoming value — Set is upsert.
|
||||
func mergeSortedPairs(existingK [][]float32, existingV [][]byte, incoming []incomingPair) pairs {
|
||||
assert(slices.IsSortedFunc(existingK, slices.Compare[[]float32]), "mergeSortedPairs: existing not sorted")
|
||||
assert(slices.IsSortedFunc(incoming, func(a, b incomingPair) int { return slices.Compare(a.key, b.key) }), "mergeSortedPairs: incoming not sorted")
|
||||
l := len(existingK) + len(incoming)
|
||||
mk := make([][]float32, 0, l)
|
||||
mv := make([][]byte, 0, l)
|
||||
i, j := 0, 0
|
||||
for i < len(incoming) || j < len(existingK) {
|
||||
switch {
|
||||
case j >= len(existingK):
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
case i >= len(incoming):
|
||||
mk = append(mk, existingK[j])
|
||||
mv = append(mv, existingV[j])
|
||||
j++
|
||||
default:
|
||||
c := slices.Compare(incoming[i].key, existingK[j])
|
||||
switch {
|
||||
case c < 0:
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
case c > 0:
|
||||
mk = append(mk, existingK[j])
|
||||
mv = append(mv, existingV[j])
|
||||
j++
|
||||
default:
|
||||
mk = append(mk, incoming[i].key)
|
||||
mv = append(mv, incoming[i].value)
|
||||
i++
|
||||
j++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("dot = %f", dot))
|
||||
|
||||
// 2.0 * (1.0 - dot) would be the Euclidean distance
|
||||
return dot
|
||||
return pairs{keys: mk, values: mv}
|
||||
}
|
||||
|
||||
type PriorityItem struct {
|
||||
Similarity float32
|
||||
Key []float32
|
||||
Value []byte
|
||||
type priorityItem struct {
|
||||
similarity float32
|
||||
key []float32
|
||||
value []byte
|
||||
}
|
||||
|
||||
type PriorityQueue []*PriorityItem
|
||||
type priorityQueue []*priorityItem
|
||||
|
||||
func (pq PriorityQueue) Len() int { return len(pq) }
|
||||
|
||||
func (pq PriorityQueue) Less(i, j int) bool {
|
||||
// Inverted because the most similar should be at the top
|
||||
return pq[i].Similarity < pq[j].Similarity
|
||||
}
|
||||
|
||||
func (pq PriorityQueue) Swap(i, j int) {
|
||||
pq[i], pq[j] = pq[j], pq[i]
|
||||
}
|
||||
|
||||
func (pq *PriorityQueue) Push(x any) {
|
||||
item := x.(*PriorityItem)
|
||||
*pq = append(*pq, item)
|
||||
}
|
||||
|
||||
func (pq *PriorityQueue) Pop() any {
|
||||
func (pq priorityQueue) Len() int { return len(pq) }
|
||||
func (pq priorityQueue) Less(i, j int) bool { return pq[i].similarity < pq[j].similarity }
|
||||
func (pq priorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] }
|
||||
func (pq *priorityQueue) Push(x any) { *pq = append(*pq, x.(*priorityItem)) }
|
||||
func (pq *priorityQueue) Pop() any {
|
||||
old := *pq
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
@@ -380,142 +336,16 @@ func (pq *PriorityQueue) Pop() any {
|
||||
return item
|
||||
}
|
||||
|
||||
func (s *Store) StoresFindNormalized(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
top_ks := make(PriorityQueue, 0, int(opts.TopK))
|
||||
heap.Init(&top_ks)
|
||||
|
||||
for i, k := range s.keys {
|
||||
sim := normalizedCosineSimilarity(tk, k)
|
||||
heap.Push(&top_ks, &PriorityItem{
|
||||
Similarity: sim,
|
||||
Key: k,
|
||||
Value: s.values[i],
|
||||
})
|
||||
|
||||
if top_ks.Len() > int(opts.TopK) {
|
||||
heap.Pop(&top_ks)
|
||||
}
|
||||
}
|
||||
|
||||
similarities := make([]float32, top_ks.Len())
|
||||
pbKeys := make([]*pb.StoresKey, top_ks.Len())
|
||||
pbValues := make([]*pb.StoresValue, top_ks.Len())
|
||||
|
||||
for i := top_ks.Len() - 1; i >= 0; i-- {
|
||||
item := heap.Pop(&top_ks).(*PriorityItem)
|
||||
|
||||
similarities[i] = item.Similarity
|
||||
pbKeys[i] = &pb.StoresKey{
|
||||
Floats: item.Key,
|
||||
}
|
||||
pbValues[i] = &pb.StoresValue{
|
||||
Bytes: item.Value,
|
||||
}
|
||||
}
|
||||
|
||||
return pb.StoresFindResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Similarities: similarities,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
var dot, mag2 float64
|
||||
for i := range len(k1) {
|
||||
dot += float64(k1[i] * k2[i])
|
||||
mag2 += float64(k2[i] * k2[i])
|
||||
}
|
||||
|
||||
sim := float32(dot / (mag1 * math.Sqrt(mag2)))
|
||||
|
||||
assert(sim >= -1.01 && sim <= 1.01, fmt.Sprintf("sim = %f", sim))
|
||||
|
||||
return sim
|
||||
}
|
||||
|
||||
func (s *Store) StoresFindFallback(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
top_ks := make(PriorityQueue, 0, int(opts.TopK))
|
||||
heap.Init(&top_ks)
|
||||
|
||||
var mag1 float64
|
||||
for _, v := range tk {
|
||||
mag1 += float64(v * v)
|
||||
}
|
||||
mag1 = math.Sqrt(mag1)
|
||||
|
||||
for i, k := range s.keys {
|
||||
dist := cosineSimilarity(tk, k, mag1)
|
||||
heap.Push(&top_ks, &PriorityItem{
|
||||
Similarity: dist,
|
||||
Key: k,
|
||||
Value: s.values[i],
|
||||
})
|
||||
|
||||
if top_ks.Len() > int(opts.TopK) {
|
||||
heap.Pop(&top_ks)
|
||||
}
|
||||
}
|
||||
|
||||
similarities := make([]float32, top_ks.Len())
|
||||
pbKeys := make([]*pb.StoresKey, top_ks.Len())
|
||||
pbValues := make([]*pb.StoresValue, top_ks.Len())
|
||||
|
||||
for i := top_ks.Len() - 1; i >= 0; i-- {
|
||||
item := heap.Pop(&top_ks).(*PriorityItem)
|
||||
|
||||
similarities[i] = item.Similarity
|
||||
pbKeys[i] = &pb.StoresKey{
|
||||
Floats: item.Key,
|
||||
}
|
||||
pbValues[i] = &pb.StoresValue{
|
||||
Bytes: item.Value,
|
||||
}
|
||||
}
|
||||
|
||||
return pb.StoresFindResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Similarities: similarities,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
|
||||
if len(tk) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("Try to find key with length %d when existing length is %d", len(tk), s.keyLen)
|
||||
}
|
||||
|
||||
if opts.TopK < 1 {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("opts.TopK = %d, must be >= 1", opts.TopK)
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Key.Floats)
|
||||
} else {
|
||||
if len(opts.Key.Floats) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Key.Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
if s.keysAreNormalized && isNormalized(tk) {
|
||||
return s.StoresFindNormalized(opts)
|
||||
} else {
|
||||
if s.keysAreNormalized {
|
||||
var sample []float32
|
||||
if len(s.keys) > 5 {
|
||||
sample = tk[:5]
|
||||
} else {
|
||||
sample = tk
|
||||
}
|
||||
xlog.Debug("Trying to compare non-normalized key with normalized keys", "sample", sample)
|
||||
}
|
||||
|
||||
return s.StoresFindFallback(opts)
|
||||
func drainPQ(pq *priorityQueue) (keys [][]float32, values [][]byte, similarities []float32) {
|
||||
n := pq.Len()
|
||||
keys = make([][]float32, n)
|
||||
values = make([][]byte, n)
|
||||
similarities = make([]float32, n)
|
||||
for i := n - 1; i >= 0; i-- {
|
||||
item := heap.Pop(pq).(*priorityItem)
|
||||
keys[i] = item.key
|
||||
values[i] = item.value
|
||||
similarities[i] = item.similarity
|
||||
}
|
||||
return keys, values, similarities
|
||||
}
|
||||
|
||||
13
backend/go/local-store/store_suite_test.go
Normal file
13
backend/go/local-store/store_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestLocalStore(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "local-store test suite")
|
||||
}
|
||||
284
backend/go/local-store/store_test.go
Normal file
284
backend/go/local-store/store_test.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package main
|
||||
|
||||
// Regression suite for the local-store gRPC backend. Exercises the
|
||||
// Stores{Set,Get,Find,Delete} surface — the only public contract.
|
||||
// Callers (face/voice recognition, the routing KNN classifier) reach
|
||||
// this code via grpc.Backend, so testing at the wire-shaped boundary
|
||||
// matches the production import shape.
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("StoresSet", func() {
|
||||
It("rejects empty input", func() {
|
||||
Expect(NewStore().StoresSet(&pb.StoresSetOptions{})).NotTo(Succeed(), "Set with no keys should fail")
|
||||
})
|
||||
|
||||
It("rejects key/value length mismatch", func() {
|
||||
err := NewStore().StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("a"), []byte("b")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "len(keys) != len(values) should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch on later add", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("3d")})
|
||||
err := s.StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("2d")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "dimension mismatch on later Set should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch within batch", func() {
|
||||
err := NewStore().StoresSet(&pb.StoresSetOptions{
|
||||
Keys: wrapKeys([][]float32{{1, 0, 0}, {1, 0}}),
|
||||
Values: wrapValues([][]byte{[]byte("3d"), []byte("2d")}),
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "mixed-dimension within one batch should fail")
|
||||
})
|
||||
|
||||
It("merges sorted and updates existing key", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.3, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("c"), []byte("a")})
|
||||
mustSet(s, [][]float32{{0.2, 0, 0}, {0.1, 0, 0}}, [][]byte{[]byte("b"), []byte("a-updated")})
|
||||
Expect(s.keys).To(HaveLen(3))
|
||||
got := singleGet(s, []float32{0.1, 0, 0})
|
||||
Expect(string(got)).To(Equal("a-updated"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresGet", func() {
|
||||
It("round-trips multi-key", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}},
|
||||
[][]byte{[]byte("a"), []byte("b"), []byte("c")},
|
||||
)
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{
|
||||
Keys: wrapKeys([][]float32{{0.7, 0.8, 0.9}, {0.1, 0.2, 0.3}}),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("omits missing keys rather than erroring", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")})
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{
|
||||
Keys: wrapKeys([][]float32{{0.1, 0, 0}, {0.9, 0, 0}}),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(1))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresDelete", func() {
|
||||
It("removes and preserves sort", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{{0.1, 0, 0}, {0.2, 0, 0}, {0.3, 0, 0}, {0.4, 0, 0}},
|
||||
[][]byte{[]byte("a"), []byte("b"), []byte("c"), []byte("d")},
|
||||
)
|
||||
Expect(s.StoresDelete(&pb.StoresDeleteOptions{
|
||||
Keys: wrapKeys([][]float32{{0.2, 0, 0}, {0.4, 0, 0}}),
|
||||
})).To(Succeed())
|
||||
Expect(s.keys).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("tolerates missing keys", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{0.1, 0, 0}}, [][]byte{[]byte("a")})
|
||||
Expect(s.StoresDelete(&pb.StoresDeleteOptions{
|
||||
Keys: wrapKeys([][]float32{{0.9, 0, 0}}),
|
||||
})).To(Succeed(), "delete of missing key should succeed")
|
||||
Expect(s.keys).To(HaveLen(1))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresFind", func() {
|
||||
It("returns normalized top-K", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{
|
||||
normalizeVec([]float32{1, 0, 0}),
|
||||
normalizeVec([]float32{0, 1, 0}),
|
||||
normalizeVec([]float32{0, 0, 1}),
|
||||
},
|
||||
[][]byte{[]byte("x"), []byte("y"), []byte("z")},
|
||||
)
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: normalizeVec([]float32{0.9, 0.1, 0})},
|
||||
TopK: 2,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
Expect(res.Similarities[0]).To(BeNumerically(">=", res.Similarities[1]), "results not sorted desc by similarity")
|
||||
Expect(string(res.Values[0].Bytes)).To(Equal("x"))
|
||||
})
|
||||
|
||||
It("falls back for non-normalized keys", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{2, 0, 0}, {0, 3, 0}}, [][]byte{[]byte("x"), []byte("y")})
|
||||
Expect(s.keysAreNormalized).To(BeFalse(), "store should report non-normalized after Set with magnitude > 1")
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{4, 0, 0}},
|
||||
TopK: 1,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(string(res.Values[0].Bytes)).To(Equal("x"))
|
||||
Expect(res.Similarities[0]).To(BeNumerically(">=", float32(0.99)))
|
||||
Expect(res.Similarities[0]).To(BeNumerically("<=", float32(1.01)))
|
||||
})
|
||||
|
||||
It("rejects zero topK", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")})
|
||||
_, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0, 0}},
|
||||
TopK: 0,
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "Find with topK=0 should fail")
|
||||
})
|
||||
|
||||
It("rejects dimension mismatch", func() {
|
||||
s := NewStore()
|
||||
mustSet(s, [][]float32{{1, 0, 0}}, [][]byte{[]byte("x")})
|
||||
_, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0}},
|
||||
TopK: 1,
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "Find with mismatched dimension should fail")
|
||||
})
|
||||
|
||||
It("returns empty result on empty store", func() {
|
||||
res, err := NewStore().StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: []float32{1, 0, 0}},
|
||||
TopK: 5,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred(), "Find on empty store should succeed")
|
||||
Expect(res.Keys).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("handles topK larger than store", func() {
|
||||
s := NewStore()
|
||||
mustSet(s,
|
||||
[][]float32{normalizeVec([]float32{1, 0, 0}), normalizeVec([]float32{0, 1, 0})},
|
||||
[][]byte{[]byte("x"), []byte("y")},
|
||||
)
|
||||
res, err := s.StoresFind(&pb.StoresFindOptions{
|
||||
Key: &pb.StoresKey{Floats: normalizeVec([]float32{1, 0, 0})},
|
||||
TopK: 10,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Keys).To(HaveLen(2))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("StoresLoad", func() {
|
||||
It("is a no-op", func() {
|
||||
Expect(NewStore().Load(&pb.ModelOptions{Model: "any-namespace"})).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
func BenchmarkStoresFindNormalized(b *testing.B) {
|
||||
const dim = 768
|
||||
for _, n := range []int{8, 32, 128, 512} {
|
||||
b.Run(fmtN(n), func(b *testing.B) {
|
||||
s := buildStore(b, n, dim)
|
||||
query := normalizeVec(randVec(dim, 42))
|
||||
req := &pb.StoresFindOptions{Key: &pb.StoresKey{Floats: query}, TopK: 1}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := s.StoresFind(req); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- test helpers ---
|
||||
|
||||
func mustSet(s *Store, keys [][]float32, values [][]byte) {
|
||||
ExpectWithOffset(1, s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)})).To(Succeed())
|
||||
}
|
||||
|
||||
func singleGet(s *Store, key []float32) []byte {
|
||||
res, err := s.StoresGet(&pb.StoresGetOptions{Keys: wrapKeys([][]float32{key})})
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred())
|
||||
if len(res.Values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return res.Values[0].Bytes
|
||||
}
|
||||
|
||||
func wrapKeys(in [][]float32) []*pb.StoresKey {
|
||||
out := make([]*pb.StoresKey, len(in))
|
||||
for i, k := range in {
|
||||
out[i] = &pb.StoresKey{Floats: k}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func wrapValues(in [][]byte) []*pb.StoresValue {
|
||||
out := make([]*pb.StoresValue, len(in))
|
||||
for i, v := range in {
|
||||
out[i] = &pb.StoresValue{Bytes: v}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildStore(tb testing.TB, n, dim int) *Store {
|
||||
tb.Helper()
|
||||
s := NewStore()
|
||||
keys := make([][]float32, n)
|
||||
values := make([][]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
keys[i] = normalizeVec(randVec(dim, int64(i)+1))
|
||||
values[i] = []byte{byte(i)}
|
||||
}
|
||||
if err := s.StoresSet(&pb.StoresSetOptions{Keys: wrapKeys(keys), Values: wrapValues(values)}); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func randVec(dim int, seed int64) []float32 {
|
||||
r := rand.New(rand.NewPCG(uint64(seed), 0xabcdef))
|
||||
v := make([]float32, dim)
|
||||
for i := range v {
|
||||
v[i] = float32(r.NormFloat64())
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func normalizeVec(v []float32) []float32 {
|
||||
var sum float64
|
||||
for _, x := range v {
|
||||
sum += float64(x) * float64(x)
|
||||
}
|
||||
mag := math.Sqrt(sum)
|
||||
if mag == 0 {
|
||||
return v
|
||||
}
|
||||
out := make([]float32, len(v))
|
||||
for i, x := range v {
|
||||
out[i] = float32(float64(x) / mag)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func fmtN(n int) string {
|
||||
return map[int]string{8: "n=8", 32: "n=32", 128: "n=128", 512: "n=512"}[n]
|
||||
}
|
||||
@@ -9,7 +9,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
# LocalVQE upstream version pin. Bump to a specific commit when picking up
|
||||
# a new release; `main` works for development but is not reproducible.
|
||||
LOCALVQE_REPO?=https://github.com/localai-org/LocalVQE
|
||||
LOCALVQE_VERSION?=72bfb4c6
|
||||
LOCALVQE_VERSION?=b0f0378a450e87c871b85689554801601ca56d98
|
||||
|
||||
# LocalVQE handles CPU feature selection internally (it ships the multiple
|
||||
# libggml-cpu-*.so variants and its loader picks the best one at runtime
|
||||
@@ -27,7 +27,8 @@ endif
|
||||
|
||||
# LocalVQE upstream supports CPU + Vulkan only. Other BUILD_TYPE values
|
||||
# fall through to the default CPU build — Vulkan is already as fast as the
|
||||
# specialised GPU paths would be on this 1.3 M-parameter model.
|
||||
# specialised GPU paths would be on these small (1.3 M–4.8 M parameter)
|
||||
# models.
|
||||
ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON -DLOCALVQE_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
|
||||
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -11,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -46,24 +46,24 @@ const (
|
||||
// through the options builder (CppOptionsNew + setters + CppNewWithOptions)
|
||||
// — the bare localvqe_new path doesn't expose backend / device selection.
|
||||
var (
|
||||
CppOptionsNew func() uintptr
|
||||
CppOptionsFree func(opts uintptr)
|
||||
CppOptionsSetModelPath func(opts uintptr, modelPath string) int32
|
||||
CppOptionsSetBackend func(opts uintptr, backend string) int32
|
||||
CppOptionsSetDevice func(opts uintptr, device int32) int32
|
||||
CppNewWithOptions func(opts uintptr) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppProcessF32 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessS16 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessFrameF32 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppProcessFrameS16 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppReset func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppSampleRate func(ctx uintptr) int32
|
||||
CppHopLength func(ctx uintptr) int32
|
||||
CppFFTSize func(ctx uintptr) int32
|
||||
CppSetNoiseGate func(ctx uintptr, enabled int32, thresholdDBFS float32) int32
|
||||
CppGetNoiseGate func(ctx uintptr, enabledOut, thresholdDBFSOut uintptr) int32
|
||||
CppOptionsNew func() uintptr
|
||||
CppOptionsFree func(opts uintptr)
|
||||
CppOptionsSetModelPath func(opts uintptr, modelPath string) int32
|
||||
CppOptionsSetBackend func(opts uintptr, backend string) int32
|
||||
CppOptionsSetDevice func(opts uintptr, device int32) int32
|
||||
CppNewWithOptions func(opts uintptr) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppProcessF32 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessS16 func(ctx uintptr, mic, ref uintptr, nSamples int32, out uintptr) int32
|
||||
CppProcessFrameF32 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppProcessFrameS16 func(ctx uintptr, mic, ref uintptr, hopSamples int32, out uintptr) int32
|
||||
CppReset func(ctx uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
CppSampleRate func(ctx uintptr) int32
|
||||
CppHopLength func(ctx uintptr) int32
|
||||
CppFFTSize func(ctx uintptr) int32
|
||||
CppSetNoiseGate func(ctx uintptr, enabled int32, thresholdDBFS float32) int32
|
||||
CppGetNoiseGate func(ctx uintptr, enabledOut, thresholdDBFSOut uintptr) int32
|
||||
)
|
||||
|
||||
// LocalVQE speaks gRPC against LocalVQE's flat C ABI. The streaming
|
||||
@@ -490,11 +490,14 @@ func (v *LocalVQE) applyStreamConfig(cfg *pb.AudioTransformStreamConfig) error {
|
||||
|
||||
// ---- WAV I/O ----------------------------------------------------------
|
||||
//
|
||||
// Minimal mono PCM WAV reader/writer. Only handles the subset LocalVQE
|
||||
// cares about (mono, 16-bit signed, no extensible chunks). For broader
|
||||
// audio support the HTTP layer's `audio.NormalizeAudioFile` already
|
||||
// converts arbitrary input to a canonical WAV before we see it; this
|
||||
// reader just decodes the canonical shape.
|
||||
// Reader/writer for the mono 16-bit PCM shape LocalVQE works with. Decoding
|
||||
// goes through the shared go-audio/wav decoder (as the whisper and parakeet
|
||||
// backends do) so RIFF chunk walking is handled robustly — an 18/40-byte
|
||||
// extensible `fmt ` chunk, or JUNK/bext/LIST metadata before or after `data`
|
||||
// (e.g. ffmpeg's trailing "Lavf" tag), is skipped rather than spliced into
|
||||
// the PCM stream as an audible click. The HTTP layer normalises arbitrary
|
||||
// input to WAV before we see it, but that WAV is ffmpeg output and is not
|
||||
// guaranteed to be the canonical 44-byte layout.
|
||||
|
||||
func readMonoWAVf32(path string) ([]float32, int, error) {
|
||||
f, err := os.Open(path)
|
||||
@@ -502,35 +505,26 @@ func readMonoWAVf32(path string) ([]float32, int, error) {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
header := make([]byte, 44)
|
||||
if _, err := io.ReadFull(f, header); err != nil {
|
||||
return nil, 0, err
|
||||
|
||||
buf, err := wav.NewDecoder(f).FullPCMBuffer()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("decode WAV: %w", err)
|
||||
}
|
||||
if string(header[0:4]) != "RIFF" || string(header[8:12]) != "WAVE" {
|
||||
if buf == nil || buf.Format == nil {
|
||||
return nil, 0, fmt.Errorf("not a WAV file")
|
||||
}
|
||||
channels := binary.LittleEndian.Uint16(header[22:24])
|
||||
sampleRate := binary.LittleEndian.Uint32(header[24:28])
|
||||
bitsPerSample := binary.LittleEndian.Uint16(header[34:36])
|
||||
|
||||
if channels != 1 {
|
||||
return nil, 0, fmt.Errorf("only mono WAV supported (got %d channels)", channels)
|
||||
if buf.Format.NumChannels != 1 {
|
||||
return nil, 0, fmt.Errorf("only mono WAV supported (got %d channels)", buf.Format.NumChannels)
|
||||
}
|
||||
if bitsPerSample != 16 {
|
||||
return nil, 0, fmt.Errorf("only 16-bit PCM supported (got %d bits)", bitsPerSample)
|
||||
if buf.SourceBitDepth != 16 {
|
||||
return nil, 0, fmt.Errorf("only 16-bit PCM supported (got %d bits)", buf.SourceBitDepth)
|
||||
}
|
||||
|
||||
rest, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
if len(buf.Data) == 0 {
|
||||
return nil, 0, fmt.Errorf("WAV has no audio data")
|
||||
}
|
||||
n := len(rest) / 2
|
||||
out := make([]float32, n)
|
||||
for i := 0; i < n; i++ {
|
||||
s := int16(binary.LittleEndian.Uint16(rest[i*2 : i*2+2]))
|
||||
out[i] = float32(s) / 32768.0
|
||||
}
|
||||
return out, int(sampleRate), nil
|
||||
// AsFloat32Buffer normalises by 2^(bitDepth-1) == /32768 for 16-bit,
|
||||
// matching the model's expected [-1, 1) input range.
|
||||
return buf.AsFloat32Buffer().Data, buf.Format.SampleRate, nil
|
||||
}
|
||||
|
||||
func writeMonoWAVf32(path string, samples []float32, sampleRate int) error {
|
||||
@@ -546,13 +540,13 @@ func writeMonoWAVf32(path string, samples []float32, sampleRate int) error {
|
||||
binary.LittleEndian.PutUint32(header[4:8], 36+dataLen)
|
||||
copy(header[8:12], []byte("WAVE"))
|
||||
copy(header[12:16], []byte("fmt "))
|
||||
binary.LittleEndian.PutUint32(header[16:20], 16) // fmt chunk size
|
||||
binary.LittleEndian.PutUint16(header[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(header[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(header[16:20], 16) // fmt chunk size
|
||||
binary.LittleEndian.PutUint16(header[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(header[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(header[24:28], uint32(sampleRate))
|
||||
binary.LittleEndian.PutUint32(header[28:32], uint32(sampleRate*2)) // byte rate
|
||||
binary.LittleEndian.PutUint16(header[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(header[34:36], 16) // bits per sample
|
||||
binary.LittleEndian.PutUint16(header[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(header[34:36], 16) // bits per sample
|
||||
copy(header[36:40], []byte("data"))
|
||||
binary.LittleEndian.PutUint32(header[40:44], dataLen)
|
||||
if _, err := f.Write(header); err != nil {
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -92,6 +94,147 @@ var _ = Describe("LocalVQE-cpp", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("readMonoWAVf32 chunk parsing", func() {
|
||||
// chunk builds a word-aligned RIFF sub-chunk (id + size + body + pad).
|
||||
chunk := func(id string, body []byte) []byte {
|
||||
out := append([]byte(id), 0, 0, 0, 0)
|
||||
binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
|
||||
out = append(out, body...)
|
||||
if len(body)&1 == 1 {
|
||||
out = append(out, 0) // pad byte for odd-sized chunks
|
||||
}
|
||||
return out
|
||||
}
|
||||
// fmtBody returns a PCM `fmt ` chunk body. extra bytes simulate the
|
||||
// 18/40-byte extensible form (cbSize + extension).
|
||||
fmtBody := func(channels, bits uint16, rate uint32, extra int) []byte {
|
||||
b := make([]byte, 16+extra)
|
||||
binary.LittleEndian.PutUint16(b[0:2], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(b[2:4], channels)
|
||||
binary.LittleEndian.PutUint32(b[4:8], rate)
|
||||
binary.LittleEndian.PutUint32(b[8:12], rate*uint32(channels)*uint32(bits)/8)
|
||||
binary.LittleEndian.PutUint16(b[12:14], channels*bits/8)
|
||||
binary.LittleEndian.PutUint16(b[14:16], bits)
|
||||
if extra >= 2 {
|
||||
binary.LittleEndian.PutUint16(b[16:18], uint16(extra-2)) // cbSize
|
||||
}
|
||||
return b
|
||||
}
|
||||
// pcm encodes int16 samples little-endian.
|
||||
pcm := func(samples ...int16) []byte {
|
||||
b := make([]byte, len(samples)*2)
|
||||
for i, s := range samples {
|
||||
binary.LittleEndian.PutUint16(b[i*2:i*2+2], uint16(s))
|
||||
}
|
||||
return b
|
||||
}
|
||||
riff := func(chunks ...[]byte) []byte {
|
||||
body := []byte("WAVE")
|
||||
for _, c := range chunks {
|
||||
body = append(body, c...)
|
||||
}
|
||||
out := append([]byte("RIFF"), 0, 0, 0, 0)
|
||||
binary.LittleEndian.PutUint32(out[4:8], uint32(len(body)))
|
||||
return append(out, body...)
|
||||
}
|
||||
writeWAV := func(b []byte) string {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "in.wav")
|
||||
Expect(os.WriteFile(p, b, 0o600)).To(Succeed())
|
||||
return p
|
||||
}
|
||||
// A canonical sample run with distinct values so any off-by-one /
|
||||
// misalignment shows up as wrong numbers, not just wrong length.
|
||||
samples := []int16{1000, -2000, 3000, -4000, 5000, -6000}
|
||||
expectSamples := func(got []float32) {
|
||||
Expect(got).To(HaveLen(len(samples)))
|
||||
for i, s := range samples {
|
||||
Expect(got[i]).To(BeNumerically("~", float32(s)/32768.0, 1e-6))
|
||||
}
|
||||
}
|
||||
|
||||
It("reads a canonical 44-byte WAV", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out)
|
||||
})
|
||||
|
||||
It("ignores a LIST/JUNK chunk placed before data (no leading-impulse splice)", func() {
|
||||
p := writeWAV(riff(
|
||||
chunk("fmt ", fmtBody(1, 16, 16000, 0)),
|
||||
chunk("JUNK", []byte("padding-bytes-here!")), // odd length → exercises pad
|
||||
chunk("LIST", []byte("INFOISFTLavf60.0")),
|
||||
chunk("data", pcm(samples...)),
|
||||
))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out) // not corrupted by the preceding chunks
|
||||
})
|
||||
|
||||
It("honours the data chunk size and drops a trailing metadata chunk", func() {
|
||||
p := writeWAV(riff(
|
||||
chunk("fmt ", fmtBody(1, 16, 16000, 0)),
|
||||
chunk("data", pcm(samples...)),
|
||||
chunk("LIST", []byte("INFOISFTLavf60.16.100")), // ffmpeg trailer tag
|
||||
))
|
||||
out, _, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectSamples(out) // trailing LIST bytes not decoded as PCM
|
||||
})
|
||||
|
||||
It("handles the 18-byte extensible fmt chunk", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 2)), chunk("data", pcm(samples...))))
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
expectSamples(out)
|
||||
})
|
||||
|
||||
It("rejects non-mono input", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(2, 16, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("mono"))
|
||||
})
|
||||
|
||||
It("rejects non-16-bit input", func() {
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 8, 16000, 0)), chunk("data", pcm(samples...))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("16-bit"))
|
||||
})
|
||||
|
||||
It("rejects a non-WAV file", func() {
|
||||
p := writeWAV([]byte("not a riff file at all"))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors when the data chunk is missing", func() {
|
||||
// fmt but no data: the decoder must fail rather than return an
|
||||
// empty (or garbage) sample slice. The exact message is the
|
||||
// decoder's, so just assert it errors.
|
||||
p := writeWAV(riff(chunk("fmt ", fmtBody(1, 16, 16000, 0))))
|
||||
_, _, err := readMonoWAVf32(p)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("round-trips through writeMonoWAVf32", func() {
|
||||
p := filepath.Join(GinkgoT().TempDir(), "rt.wav")
|
||||
in := []float32{0.1, -0.2, 0.3, -0.4}
|
||||
Expect(writeMonoWAVf32(p, in, 16000)).To(Succeed())
|
||||
out, sr, err := readMonoWAVf32(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sr).To(Equal(16000))
|
||||
Expect(out).To(HaveLen(len(in)))
|
||||
for i := range in {
|
||||
Expect(out[i]).To(BeNumerically("~", in[i], 1e-4))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("model-gated integration (LOCALVQE_MODEL_PATH)", func() {
|
||||
It("load + sample rate + hop + fft", func() {
|
||||
path := modelPathOrSkip()
|
||||
|
||||
11
backend/go/parakeet-cpp/.gitignore
vendored
Normal file
11
backend/go/parakeet-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
.cache/
|
||||
sources/
|
||||
build/
|
||||
package/
|
||||
parakeet-cpp-grpc
|
||||
# build artifacts staged in-tree by the Makefile (cp from sources/) or
|
||||
# symlinked for local dev; the real sources live in parakeet.cpp upstream.
|
||||
*.so
|
||||
*.so.*
|
||||
parakeet_capi.h
|
||||
compile_commands.json
|
||||
93
backend/go/parakeet-cpp/Makefile
Normal file
93
backend/go/parakeet-cpp/Makefile
Normal file
@@ -0,0 +1,93 @@
|
||||
# parakeet-cpp backend Makefile.
|
||||
#
|
||||
# Upstream pin lives below as PARAKEET_VERSION?=b11fe5bca78ad8b342dd559a43d76df3984bb447
|
||||
# (.github/bump_deps.sh) can find and update it - matches the
|
||||
# whisper.cpp / ds4 / vibevoice-cpp convention.
|
||||
#
|
||||
# Local dev shortcut: if you already have an out-of-tree parakeet.cpp
|
||||
# build, you can symlink the .so + header into this directory and skip
|
||||
# the clone/cmake steps entirely, e.g.:
|
||||
#
|
||||
# ln -sf /path/to/parakeet.cpp/build-shared/libparakeet.so .
|
||||
# ln -sf /path/to/parakeet.cpp/include/parakeet_capi.h .
|
||||
# go build -o parakeet-cpp-grpc .
|
||||
#
|
||||
# That's what the L0 smoke test uses. The default target below does the
|
||||
# proper clone-at-pin + cmake build so CI doesn't need a side-checkout.
|
||||
|
||||
PARAKEET_VERSION?=b11fe5bca78ad8b342dd559a43d76df3984bb447
|
||||
PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
|
||||
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
# Build ggml statically into libparakeet.so (PIC) so the shared lib is
|
||||
# self-contained: dlopen needs no libggml*.so alongside it, only system libs
|
||||
# (libstdc++/libgomp/libc) that the runtime image already provides.
|
||||
CMAKE_ARGS?=-DCMAKE_BUILD_TYPE=Release -DPARAKEET_SHARED=ON -DPARAKEET_BUILD_CLI=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# parakeet.cpp gates its GGML backends behind PARAKEET_GGML_* options and does
|
||||
# set(GGML_CUDA ${PARAKEET_GGML_CUDA} CACHE BOOL "" FORCE), so a bare -DGGML_CUDA=ON
|
||||
# is overwritten back to OFF and the build silently falls back to CPU. Forward the
|
||||
# PARAKEET_GGML_* options instead. (openblas is not gated, so -DGGML_BLAS passes through.)
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_HIP=ON
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DPARAKEET_GGML_VULKAN=ON
|
||||
endif
|
||||
|
||||
.PHONY: parakeet-cpp-grpc package build clean purge test all
|
||||
|
||||
all: parakeet-cpp-grpc
|
||||
|
||||
# Clone the upstream parakeet.cpp source at the pinned commit. Directory
|
||||
# acts as the target so make only re-clones when missing. After a
|
||||
# PARAKEET_VERSION bump, run 'make purge && make' to refetch.
|
||||
sources/parakeet.cpp:
|
||||
mkdir -p sources/parakeet.cpp
|
||||
cd sources/parakeet.cpp && \
|
||||
git init -q && \
|
||||
git remote add origin $(PARAKEET_REPO) && \
|
||||
git fetch --depth 1 origin $(PARAKEET_VERSION) && \
|
||||
git checkout FETCH_HEAD && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Build the shared lib + header out-of-tree, then stage them next to the
|
||||
# Go sources so purego.Dlopen("libparakeet.so") and the cgo-less build
|
||||
# both pick them up.
|
||||
libparakeet.so: sources/parakeet.cpp
|
||||
cmake -B sources/parakeet.cpp/build-shared -S sources/parakeet.cpp $(CMAKE_ARGS)
|
||||
cmake --build sources/parakeet.cpp/build-shared --config Release -j$(JOBS)
|
||||
cp -fv sources/parakeet.cpp/build-shared/libparakeet.so* ./ 2>/dev/null || true
|
||||
cp -fv sources/parakeet.cpp/include/parakeet_capi.h ./
|
||||
|
||||
parakeet-cpp-grpc: libparakeet.so main.go goparakeetcpp.go
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o parakeet-cpp-grpc .
|
||||
|
||||
package: parakeet-cpp-grpc
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
# Test target. Smoke test is gated on PARAKEET_BACKEND_TEST_MODEL +
|
||||
# PARAKEET_BACKEND_TEST_WAV; without them the spec auto-skips.
|
||||
test:
|
||||
LD_LIBRARY_PATH=$(CURDIR):$$LD_LIBRARY_PATH $(GOCMD) test ./... -count=1
|
||||
|
||||
clean: purge
|
||||
rm -rf libparakeet.so* parakeet_capi.h package parakeet-cpp-grpc
|
||||
|
||||
purge:
|
||||
rm -rf sources/parakeet.cpp
|
||||
79
backend/go/parakeet-cpp/batcher.go
Normal file
79
backend/go/parakeet-cpp/batcher.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package main
|
||||
|
||||
import "time"
|
||||
|
||||
// batchRequest is one in-flight unary transcription waiting to be batched.
|
||||
// In production pcm/decoder are set; tag is an opaque marker used by tests.
|
||||
type batchRequest struct {
|
||||
pcm []float32
|
||||
decoder int32
|
||||
tag string
|
||||
reply chan batchReply
|
||||
}
|
||||
|
||||
// batchReply carries one per-item JSON object string (an element of the C-API's
|
||||
// JSON array) or an error back to the waiting handler goroutine.
|
||||
type batchReply struct {
|
||||
json string
|
||||
err error
|
||||
}
|
||||
|
||||
// batcher coalesces concurrent batchRequests into batched runBatch calls. A
|
||||
// single run() goroutine is the sole caller of runBatch, so runBatch (which in
|
||||
// production calls the thread-unsafe C engine) is never entered concurrently.
|
||||
type batcher struct {
|
||||
submit chan *batchRequest
|
||||
maxSize int
|
||||
maxWait time.Duration
|
||||
runBatch func(reqs []*batchRequest) // must deliver a reply to every req
|
||||
}
|
||||
|
||||
func newBatcher(maxSize int, maxWait time.Duration, runBatch func([]*batchRequest)) *batcher {
|
||||
if maxSize < 1 {
|
||||
maxSize = 1
|
||||
}
|
||||
return &batcher{
|
||||
submit: make(chan *batchRequest),
|
||||
maxSize: maxSize,
|
||||
maxWait: maxWait,
|
||||
runBatch: runBatch,
|
||||
}
|
||||
}
|
||||
|
||||
// run is the dispatcher loop: accumulate submitted requests until either maxSize
|
||||
// is reached or maxWait elapses since the first queued request, then dispatch.
|
||||
// Exits when stop is closed (draining any partially-filled batch first).
|
||||
func (b *batcher) run(stop <-chan struct{}) {
|
||||
for {
|
||||
var first *batchRequest
|
||||
select {
|
||||
case first = <-b.submit:
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
batch := []*batchRequest{first}
|
||||
|
||||
// maxSize==1 disables batching: dispatch immediately (passthrough).
|
||||
if b.maxSize == 1 {
|
||||
b.runBatch(batch)
|
||||
continue
|
||||
}
|
||||
|
||||
timer := time.NewTimer(b.maxWait)
|
||||
fill:
|
||||
for len(batch) < b.maxSize {
|
||||
select {
|
||||
case r := <-b.submit:
|
||||
batch = append(batch, r)
|
||||
case <-timer.C:
|
||||
break fill
|
||||
case <-stop:
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
return
|
||||
}
|
||||
}
|
||||
timer.Stop()
|
||||
b.runBatch(batch)
|
||||
}
|
||||
}
|
||||
108
backend/go/parakeet-cpp/batcher_test.go
Normal file
108
backend/go/parakeet-cpp/batcher_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("batcher", func() {
|
||||
echoReply := func(reqs []*batchRequest) {
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{json: r.tag}
|
||||
}
|
||||
}
|
||||
|
||||
It("coalesces concurrent submits into batches", func() {
|
||||
var mu sync.Mutex
|
||||
var sizes []int
|
||||
run := func(reqs []*batchRequest) {
|
||||
mu.Lock()
|
||||
sizes = append(sizes, len(reqs))
|
||||
mu.Unlock()
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(4, 50*time.Millisecond, run)
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
|
||||
const N = 4
|
||||
var wg sync.WaitGroup
|
||||
got := make([]string, N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: string(rune('a' + i)), reply: rep}
|
||||
got[i] = (<-rep).json
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
total, maxBatch := 0, 0
|
||||
for _, s := range sizes {
|
||||
total += s
|
||||
if s > maxBatch {
|
||||
maxBatch = s
|
||||
}
|
||||
}
|
||||
Expect(total).To(Equal(N))
|
||||
Expect(maxBatch).To(BeNumerically(">=", 2), "expected at least one batch to coalesce >1 request")
|
||||
})
|
||||
|
||||
It("dispatches when max size is reached", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(2, time.Hour, run) // huge window: only size can trigger
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
for i := 0; i < 2; i++ {
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func(rep chan batchReply) { <-rep }(rep)
|
||||
}
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(2)))
|
||||
})
|
||||
|
||||
It("dispatches when the wait window elapses", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(8, 20*time.Millisecond, run) // size unreachable; window fires
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
|
||||
It("bypasses batching when max size is 1", func() {
|
||||
dispatched := make(chan int, 8)
|
||||
run := func(reqs []*batchRequest) {
|
||||
dispatched <- len(reqs)
|
||||
echoReply(reqs)
|
||||
}
|
||||
b := newBatcher(1, time.Hour, run) // size 1 => immediate dispatch
|
||||
stop := make(chan struct{})
|
||||
go b.run(stop)
|
||||
defer close(stop)
|
||||
rep := make(chan batchReply, 1)
|
||||
b.submit <- &batchRequest{tag: "x", reply: rep}
|
||||
go func() { <-rep }()
|
||||
Eventually(dispatched, "2s").Should(Receive(Equal(1)))
|
||||
})
|
||||
})
|
||||
556
backend/go/parakeet-cpp/goparakeetcpp.go
Normal file
556
backend/go/parakeet-cpp/goparakeetcpp.go
Normal file
@@ -0,0 +1,556 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// purego-bound entry points from libparakeet.so. Names match
|
||||
// parakeet_capi.h exactly so a `nm libparakeet.so | grep parakeet_capi`
|
||||
// is enough to spot drift.
|
||||
//
|
||||
// Functions that return char* are declared as uintptr so we can call
|
||||
// parakeet_capi_free_string on the same pointer after copying, the
|
||||
// C-API contract is "caller owns and must free the returned buffer".
|
||||
var (
|
||||
CppAbiVersion func() int32
|
||||
CppLoad func(ggufPath string) uintptr
|
||||
CppFree func(ctx uintptr)
|
||||
CppTranscribePath func(ctx uintptr, wavPath string, decoder int32) uintptr
|
||||
CppTranscribePathJSON func(ctx uintptr, wavPath string, decoder int32) uintptr
|
||||
CppFreeString func(s uintptr)
|
||||
CppLastError func(ctx uintptr) string
|
||||
|
||||
// Batched JSON transcription: takes a concatenated float buffer of clips
|
||||
// plus their per-clip sample counts (sum(nSamples)==len(samplesConcat))
|
||||
// and returns a malloc'd char* JSON ARRAY of per-clip {"text","words",
|
||||
// "tokens"} objects (uintptr, freed via CppFreeString). purego passes the
|
||||
// Go slices as the base pointer of their backing array (kept alive for the
|
||||
// call), matching the CppStreamFeed pcm []float32 binding pattern; the C
|
||||
// side reads them as const float*/const int*.
|
||||
CppTranscribePcmBatchJSON func(ctx uintptr, samplesConcat []float32, nSamples []int32, nClips int32, sampleRate int32, decoder int32) uintptr
|
||||
|
||||
// Cache-aware streaming (RNN-T) entry points. stream_begin returns 0 for
|
||||
// non-streaming models. feed/finalize return a malloc'd char* (uintptr,
|
||||
// freed via CppFreeString); feed writes 1 to *eouOut on an <EOU>/<EOB>.
|
||||
CppStreamBegin func(ctx uintptr) uintptr
|
||||
CppStreamFeed func(s uintptr, pcm []float32, nSamples int32, eouOut unsafe.Pointer) uintptr
|
||||
CppStreamFinalize func(s uintptr) uintptr
|
||||
CppStreamFree func(s uintptr)
|
||||
)
|
||||
|
||||
// streamChunkSamples is how much 16 kHz mono PCM we hand to stream_feed per
|
||||
// call (1 s). The session buffers internally and decodes once a full
|
||||
// cache-aware encoder chunk is available, so this only bounds how often we
|
||||
// poll for newly-finalized text, not the model's actual chunk size.
|
||||
const streamChunkSamples = 16000
|
||||
|
||||
// transcriptJSON mirrors the document returned by
|
||||
// parakeet_capi_transcribe_path_json (see parakeet_capi.h):
|
||||
//
|
||||
// {"text":"...",
|
||||
// "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...],
|
||||
// "tokens":[{"id":123,"t":0.480,"conf":0.9100}, ...]}
|
||||
//
|
||||
// "start"/"end"/"t" are seconds; "conf" is confidence in (0,1].
|
||||
type transcriptJSON struct {
|
||||
Text string `json:"text"`
|
||||
Words []transcriptWord `json:"words"`
|
||||
Tokens []transcriptToken `json:"tokens"`
|
||||
}
|
||||
|
||||
type transcriptWord struct {
|
||||
W string `json:"w"`
|
||||
Start float64 `json:"start"`
|
||||
End float64 `json:"end"`
|
||||
Conf float64 `json:"conf"`
|
||||
}
|
||||
|
||||
type transcriptToken struct {
|
||||
ID int32 `json:"id"`
|
||||
T float64 `json:"t"`
|
||||
Conf float64 `json:"conf"`
|
||||
}
|
||||
|
||||
// ParakeetCpp owns a single loaded parakeet_ctx. The C engine is a
|
||||
// thread-unsafe singleton (mirrors whisper.cpp / vibevoice.cpp). Rather than
|
||||
// serialize every call through base.SingleThread, we route unary
|
||||
// transcription through an in-process batcher (its sole dispatcher goroutine
|
||||
// is the only caller of the engine on that path) and guard the shared engine
|
||||
// with engineMu so a streaming session and a batched-unary dispatch never
|
||||
// touch it concurrently.
|
||||
type ParakeetCpp struct {
|
||||
base.Base
|
||||
ctxPtr uintptr
|
||||
engineMu sync.Mutex // sole guard of the one C engine (dispatcher + streaming)
|
||||
bat *batcher
|
||||
batStop chan struct{}
|
||||
}
|
||||
|
||||
// Load is the LocalAI gRPC entry point for LoadModel: it calls
|
||||
// parakeet_capi_load with the GGUF path and stashes the resulting
|
||||
// opaque context pointer for AudioTranscription.
|
||||
func (p *ParakeetCpp) Load(opts *pb.ModelOptions) error {
|
||||
if opts.ModelFile == "" {
|
||||
return errors.New("parakeet-cpp: ModelFile is required")
|
||||
}
|
||||
|
||||
ctx := CppLoad(opts.ModelFile)
|
||||
if ctx == 0 {
|
||||
// No ctx to ask for last_error (the C-API's last-error buffer
|
||||
// lives on the ctx that was never returned). Surface the path
|
||||
// so the operator at least knows which load failed.
|
||||
return fmt.Errorf("parakeet-cpp: parakeet_capi_load failed for %q", opts.ModelFile)
|
||||
}
|
||||
p.ctxPtr = ctx
|
||||
|
||||
// Dynamic batching knobs (model YAML options:, key:value form). Batching is
|
||||
// OFF by default (batch_max_size:1): each request runs on its own. On GPU,
|
||||
// raising batch_max_size coalesces concurrent requests into one batched
|
||||
// engine call and improves throughput under load; leave it at 1 on CPU and
|
||||
// for low-concurrency setups, where batching only adds latency.
|
||||
maxSize := optInt(opts, "batch_max_size", 1)
|
||||
maxWaitMs := optInt(opts, "batch_max_wait_ms", 15)
|
||||
if maxWaitMs < 0 {
|
||||
maxWaitMs = 0
|
||||
}
|
||||
if CppTranscribePcmBatchJSON != nil {
|
||||
p.batStop = make(chan struct{})
|
||||
p.bat = newBatcher(maxSize, time.Duration(maxWaitMs)*time.Millisecond, p.runBatch)
|
||||
go p.bat.run(p.batStop) // dispatcher runs until Free closes batStop
|
||||
if maxSize > 1 {
|
||||
xlog.Info("parakeet-cpp: dynamic batching enabled",
|
||||
"batch_max_size", maxSize, "batch_max_wait_ms", maxWaitMs)
|
||||
} else {
|
||||
xlog.Info("parakeet-cpp: dynamic batching off (batch_max_size=1); " +
|
||||
"set batch_max_size>1 to coalesce concurrent requests on GPU")
|
||||
}
|
||||
} else {
|
||||
xlog.Info("parakeet-cpp: batched C-API not present in libparakeet.so; " +
|
||||
"batching disabled, using per-request transcription")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// optInt reads an integer model option (key:value form) from ModelOptions,
|
||||
// returning def when absent or unparseable. The options array carries the
|
||||
// model YAML's options: entries (see core/config; siblings such as
|
||||
// acestep-cpp parse the same key:value form via strings.Cut on ":").
|
||||
func optInt(opts *pb.ModelOptions, key string, def int) int {
|
||||
for _, o := range opts.GetOptions() {
|
||||
k, v, ok := strings.Cut(o, ":")
|
||||
if ok && strings.TrimSpace(k) == key {
|
||||
if n, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// runBatch is the dispatcher's batch handler and the ONLY caller of the C
|
||||
// engine on the unary path. It concatenates the batch PCM, calls the batched
|
||||
// JSON C-API under engineMu, splits the JSON array, and replies to each request.
|
||||
func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
// Observability: the actual coalesced batch size per engine call. Debug-level
|
||||
// so it stays silent in normal operation but lets operators confirm/tune batching.
|
||||
xlog.Debug("parakeet-cpp: dispatching batch", "size", len(reqs))
|
||||
nSamples := make([]int32, len(reqs))
|
||||
total := 0
|
||||
for i, r := range reqs {
|
||||
nSamples[i] = int32(len(r.pcm))
|
||||
total += len(r.pcm)
|
||||
}
|
||||
concat := make([]float32, 0, total)
|
||||
for _, r := range reqs {
|
||||
concat = append(concat, r.pcm...)
|
||||
}
|
||||
var dec int32
|
||||
if len(reqs) > 0 {
|
||||
dec = reqs[0].decoder
|
||||
}
|
||||
p.engineMu.Lock()
|
||||
cstr := CppTranscribePcmBatchJSON(p.ctxPtr, concat, nSamples, int32(len(reqs)), 16000, dec)
|
||||
p.engineMu.Unlock()
|
||||
if cstr == 0 {
|
||||
err := fmt.Errorf("parakeet-cpp: batch transcribe failed: %s", CppLastError(p.ctxPtr))
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{err: err}
|
||||
}
|
||||
return
|
||||
}
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
var docs []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(raw), &docs); err != nil || len(docs) != len(reqs) {
|
||||
e := fmt.Errorf("parakeet-cpp: batch json: got %d results for %d reqs (%v)", len(docs), len(reqs), err)
|
||||
for _, r := range reqs {
|
||||
r.reply <- batchReply{err: e}
|
||||
}
|
||||
return
|
||||
}
|
||||
for i, r := range reqs {
|
||||
r.reply <- batchReply{json: string(docs[i])}
|
||||
}
|
||||
}
|
||||
|
||||
// AudioTranscription decodes the wav at opts.Dst to 16 kHz mono PCM and
|
||||
// submits it to the in-process batcher, which coalesces concurrent requests
|
||||
// into a single batched engine call (parakeet_capi_transcribe_pcm_batch_json)
|
||||
// with the default decoder (decoder=0, which selects the right head per
|
||||
// architecture: transducer for tdt/rnnt/hybrid, CTC for ctc) and shapes the
|
||||
// per-word timestamps into a LocalAI TranscriptResult.
|
||||
//
|
||||
// Parakeet emits word- and token-level timestamps but no native segment
|
||||
// boundaries, so we synthesise a single whole-clip segment spanning the first
|
||||
// word start to the last word end. Word-level timings are attached only when
|
||||
// the caller opts in via timestamp_granularities=["word"] (matching the
|
||||
// OpenAI API, whose default is segment-level); token ids always populate
|
||||
// Segment.Tokens.
|
||||
//
|
||||
// translate/diarize/prompt/temperature/language/threads are not applicable to
|
||||
// parakeet and are ignored; streaming is handled by AudioTranscriptionStream
|
||||
// (L2).
|
||||
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if p.ctxPtr == 0 {
|
||||
return pb.TranscriptResult{}, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
|
||||
// Fallback when the batched C-API is unavailable: transcribe from a file
|
||||
// path (original behavior, no batching). The C library's audio loader only
|
||||
// understands 16 kHz mono WAV/PCM, so convert the input first - otherwise
|
||||
// any non-WAV upload (MP3, etc.) fails with "failed to load audio". This
|
||||
// mirrors what every other audio backend (whisper, crispasr) does via
|
||||
// utils.AudioToWav before handing the file to the engine.
|
||||
if p.bat == nil {
|
||||
converted, cleanup, err := convertToWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer cleanup()
|
||||
cstr := CppTranscribePathJSON(p.ctxPtr, converted, 0)
|
||||
if cstr == 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: transcribe_path_json failed: %s", CppLastError(p.ctxPtr))
|
||||
}
|
||||
raw := goStringFromCPtr(cstr)
|
||||
CppFreeString(cstr)
|
||||
var doc transcriptJSON
|
||||
if err := json.Unmarshal([]byte(raw), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
}
|
||||
|
||||
// Batched path: decode to PCM, submit to the batcher, wait for this request's
|
||||
// JSON element. The dispatcher is the sole engine caller on this path; both
|
||||
// sends honour ctx cancellation.
|
||||
pcm, _, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
rep := make(chan batchReply, 1)
|
||||
select {
|
||||
case p.bat.submit <- &batchRequest{pcm: pcm, decoder: 0, reply: rep}:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
var res batchReply
|
||||
select {
|
||||
case res = <-rep:
|
||||
case <-ctx.Done():
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if res.err != nil {
|
||||
return pb.TranscriptResult{}, res.err
|
||||
}
|
||||
var doc transcriptJSON
|
||||
if err := json.Unmarshal([]byte(res.json), &doc); err != nil {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("parakeet-cpp: decode transcript json: %w", err)
|
||||
}
|
||||
return transcriptResultFromDoc(doc, opts), nil
|
||||
}
|
||||
|
||||
// transcriptResultFromDoc maps a decoded transcriptJSON to a TranscriptResult,
|
||||
// synthesising a single whole-clip segment and attaching word timings only when
|
||||
// the caller requested word granularity. Shared by the batched and direct paths.
|
||||
func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest) pb.TranscriptResult {
|
||||
text := strings.TrimSpace(doc.Text)
|
||||
words := make([]*pb.TranscriptWord, 0, len(doc.Words))
|
||||
for _, w := range doc.Words {
|
||||
words = append(words, &pb.TranscriptWord{Start: secondsToNanos(w.Start), End: secondsToNanos(w.End), Text: w.W})
|
||||
}
|
||||
tokens := make([]int32, 0, len(doc.Tokens))
|
||||
for _, t := range doc.Tokens {
|
||||
tokens = append(tokens, t.ID)
|
||||
}
|
||||
var segStart, segEnd int64
|
||||
if len(words) > 0 {
|
||||
segStart = words[0].Start
|
||||
segEnd = words[len(words)-1].End
|
||||
}
|
||||
seg := &pb.TranscriptSegment{Id: 0, Start: segStart, End: segEnd, Text: text, Tokens: tokens}
|
||||
if wordsRequested(opts.TimestampGranularities) {
|
||||
seg.Words = words
|
||||
}
|
||||
return pb.TranscriptResult{Text: text, Segments: []*pb.TranscriptSegment{seg}}
|
||||
}
|
||||
|
||||
// wordsRequested reports whether the caller asked for word-level timestamps.
|
||||
// The OpenAI transcription API gates word timings behind
|
||||
// timestamp_granularities[] containing "word" and defaults to segment-level
|
||||
// otherwise; we follow that contract.
|
||||
func wordsRequested(granularities []string) bool {
|
||||
for _, g := range granularities {
|
||||
if strings.EqualFold(strings.TrimSpace(g), "word") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// secondsToNanos converts the C-API's fractional-second timestamps into the
|
||||
// int64 nanoseconds LocalAI carries on TranscriptSegment/TranscriptWord, the
|
||||
// same nanosecond convention the whisper backend uses.
|
||||
func secondsToNanos(sec float64) int64 {
|
||||
return int64(sec * 1e9)
|
||||
}
|
||||
|
||||
// AudioTranscriptionStream drives the cache-aware streaming RNN-T over the
|
||||
// audio at opts.Dst: it decodes the file to 16 kHz mono PCM, feeds it in
|
||||
// chunks to parakeet_capi_stream_feed, and emits each newly-finalized text
|
||||
// run as a TranscriptStreamResponse delta. <EOU>/<EOB> events close the
|
||||
// current segment; a closing FinalResult carries the full transcript and the
|
||||
// per-utterance segments.
|
||||
//
|
||||
// stream_begin returns 0 for models that are not cache-aware streaming models
|
||||
// (only e.g. nvidia/parakeet_realtime_eou_120m-v1 qualifies). For those we fall
|
||||
// back to a single offline transcription emitted as one delta plus a closing
|
||||
// FinalResult, matching LocalAI's non-streaming streaming contract (and the
|
||||
// whisper backend), so the streaming endpoint works for every model.
|
||||
func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
defer close(results)
|
||||
|
||||
if p.ctxPtr == 0 {
|
||||
return grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
|
||||
stream := CppStreamBegin(p.ctxPtr)
|
||||
if stream == 0 {
|
||||
// Not a cache-aware streaming model: run a normal offline
|
||||
// transcription and emit it as one delta + a closing final result.
|
||||
res, err := p.AudioTranscription(ctx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t := strings.TrimSpace(res.Text); t != "" {
|
||||
results <- &pb.TranscriptStreamResponse{Delta: t}
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{FinalResult: &res}
|
||||
return nil
|
||||
}
|
||||
defer CppStreamFree(stream)
|
||||
// The C engine is a single shared context: a streaming session and a batched
|
||||
// unary dispatch must never touch it at once, so hold engineMu for the whole
|
||||
// stream. This lock is intentionally taken AFTER the non-streaming fallback
|
||||
// above returns: that fallback goes through AudioTranscription -> the batcher
|
||||
// -> runBatch, which itself acquires engineMu, so locking here first would
|
||||
// deadlock. Do not hoist this lock above the fallback.
|
||||
p.engineMu.Lock()
|
||||
defer p.engineMu.Unlock()
|
||||
|
||||
data, duration, err := decodeWavMono16k(opts.Dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
full strings.Builder
|
||||
segText strings.Builder
|
||||
segments []*pb.TranscriptSegment
|
||||
segID int32
|
||||
)
|
||||
|
||||
flushSegment := func() {
|
||||
t := strings.TrimSpace(segText.String())
|
||||
segText.Reset()
|
||||
if t == "" {
|
||||
return
|
||||
}
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: segID, Text: t})
|
||||
segID++
|
||||
}
|
||||
|
||||
// emitDelta consumes the malloc'd char* returned by feed/finalize: frees
|
||||
// it, accumulates the text, and sends a delta when non-empty. A 0 return
|
||||
// is an error (vs the "" empty-but-non-NULL no-new-text case).
|
||||
emitDelta := func(ret uintptr) error {
|
||||
if ret == 0 {
|
||||
msg := CppLastError(p.ctxPtr)
|
||||
if msg == "" {
|
||||
msg = "unknown error"
|
||||
}
|
||||
return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg)
|
||||
}
|
||||
delta := goStringFromCPtr(ret)
|
||||
CppFreeString(ret)
|
||||
if delta == "" {
|
||||
return nil
|
||||
}
|
||||
full.WriteString(delta)
|
||||
segText.WriteString(delta)
|
||||
results <- &pb.TranscriptStreamResponse{Delta: delta}
|
||||
return nil
|
||||
}
|
||||
|
||||
for off := 0; off < len(data); off += streamChunkSamples {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
end := min(off+streamChunkSamples, len(data))
|
||||
chunk := data[off:end]
|
||||
|
||||
var eou int32
|
||||
ret := CppStreamFeed(stream, chunk, int32(len(chunk)), unsafe.Pointer(&eou))
|
||||
if err := emitDelta(ret); err != nil {
|
||||
return err
|
||||
}
|
||||
if eou != 0 {
|
||||
flushSegment()
|
||||
}
|
||||
}
|
||||
|
||||
// Flush the streaming tail (final encoder chunk).
|
||||
if err := emitDelta(CppStreamFinalize(stream)); err != nil {
|
||||
return err
|
||||
}
|
||||
flushSegment()
|
||||
|
||||
text := strings.TrimSpace(full.String())
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text})
|
||||
}
|
||||
results <- &pb.TranscriptStreamResponse{
|
||||
FinalResult: &pb.TranscriptResult{
|
||||
Text: text,
|
||||
Segments: segments,
|
||||
Duration: duration,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeWavMono16k converts any input audio to 16 kHz mono PCM and returns the
|
||||
// float samples plus the clip duration in seconds. Mirrors the whisper
|
||||
// backend: utils.AudioToWav (ffmpeg) normalises rate/channels, go-audio
|
||||
// decodes the PCM.
|
||||
// convertToWavMono16k converts an arbitrary audio file to a 16 kHz mono WAV in
|
||||
// a fresh temp dir and returns the path together with a cleanup func the caller
|
||||
// must defer. WAV inputs already at 16 kHz/mono/16-bit are passed through by
|
||||
// utils.AudioToWav (hardlink/copy), everything else is transcoded via ffmpeg.
|
||||
// Used by the direct (non-batched) transcription path, which hands a file path
|
||||
// to the C library's WAV-only audio loader.
|
||||
func convertToWavMono16k(path string) (string, func(), error) {
|
||||
dir, err := os.MkdirTemp("", "parakeet")
|
||||
if err != nil {
|
||||
return "", func() {}, err
|
||||
}
|
||||
cleanup := func() { _ = os.RemoveAll(dir) }
|
||||
|
||||
converted := filepath.Join(dir, "converted.wav")
|
||||
if err := utils.AudioToWav(path, converted); err != nil {
|
||||
cleanup()
|
||||
return "", func() {}, err
|
||||
}
|
||||
return converted, cleanup, nil
|
||||
}
|
||||
|
||||
func decodeWavMono16k(path string) ([]float32, float32, error) {
|
||||
converted, cleanup, err := convertToWavMono16k(path)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
fh, err := os.Open(converted)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
buf, err := wav.NewDecoder(fh).FullPCMBuffer()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
return data, duration, nil
|
||||
}
|
||||
|
||||
// Free releases the underlying parakeet_ctx. Called by LocalAI when the
|
||||
// model is unloaded.
|
||||
func (p *ParakeetCpp) Free() error {
|
||||
// Stop the dispatcher before releasing the engine so no in-flight runBatch
|
||||
// can touch a freed ctx (close leak / use-after-free on reload).
|
||||
if p.batStop != nil {
|
||||
close(p.batStop)
|
||||
p.batStop = nil
|
||||
}
|
||||
if p.ctxPtr != 0 {
|
||||
CppFree(p.ctxPtr)
|
||||
p.ctxPtr = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// goStringFromCPtr copies a NUL-terminated C string into Go memory.
|
||||
// cptr is the raw pointer returned by purego from the C-API (a malloc'd
|
||||
// buffer the caller owns); callers must free it via CppFreeString after
|
||||
// the copy lands.
|
||||
//
|
||||
// The uintptr->unsafe.Pointer conversion below trips go vet's unsafeptr
|
||||
// check, which can't distinguish a C-owned heap pointer from Go-managed
|
||||
// memory. It is safe here: the pointer addresses a malloc'd C buffer the
|
||||
// Go GC neither tracks nor moves, and we dereference it immediately to
|
||||
// copy the bytes out, the same pattern (and the same tolerated warning)
|
||||
// as the whisper backend's unsafe.Slice over segsPtr.
|
||||
func goStringFromCPtr(cptr uintptr) string {
|
||||
if cptr == 0 {
|
||||
return ""
|
||||
}
|
||||
p := unsafe.Pointer(cptr) //nolint:govet // C-owned malloc'd buffer, not Go-GC memory (see doc above)
|
||||
n := 0
|
||||
for *(*byte)(unsafe.Add(p, n)) != 0 {
|
||||
n++
|
||||
}
|
||||
return string(unsafe.Slice((*byte)(p), n))
|
||||
}
|
||||
221
backend/go/parakeet-cpp/goparakeetcpp_test.go
Normal file
221
backend/go/parakeet-cpp/goparakeetcpp_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/go-audio/wav"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestParakeetCpp(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "parakeet-cpp Backend Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive
|
||||
// the C-API bridge without spinning up the gRPC server. Skips the
|
||||
// current spec when libparakeet.so isn't loadable from cwd
|
||||
// ($LD_LIBRARY_PATH or a symlink in ./).
|
||||
func ensureLibLoaded() {
|
||||
libLoadOnce.Do(func() {
|
||||
libName := os.Getenv("PARAKEET_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libparakeet.so"
|
||||
}
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&CppAbiVersion, lib, "parakeet_capi_abi_version")
|
||||
purego.RegisterLibFunc(&CppLoad, lib, "parakeet_capi_load")
|
||||
purego.RegisterLibFunc(&CppFree, lib, "parakeet_capi_free")
|
||||
purego.RegisterLibFunc(&CppTranscribePath, lib, "parakeet_capi_transcribe_path")
|
||||
purego.RegisterLibFunc(&CppTranscribePathJSON, lib, "parakeet_capi_transcribe_path_json")
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
purego.RegisterLibFunc(&CppStreamBegin, lib, "parakeet_capi_stream_begin")
|
||||
purego.RegisterLibFunc(&CppStreamFeed, lib, "parakeet_capi_stream_feed")
|
||||
purego.RegisterLibFunc(&CppStreamFinalize, lib, "parakeet_capi_stream_finalize")
|
||||
purego.RegisterLibFunc(&CppStreamFree, lib, "parakeet_capi_stream_free")
|
||||
purego.RegisterLibFunc(&CppFreeString, lib, "parakeet_capi_free_string")
|
||||
purego.RegisterLibFunc(&CppLastError, lib, "parakeet_capi_last_error")
|
||||
})
|
||||
if libLoadErr != nil {
|
||||
Skip("libparakeet.so not loadable: " + libLoadErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// fixturesOrSkip returns the model + audio paths or skips the spec if
|
||||
// either env var is unset. The smoke test never runs in default CI; it
|
||||
// needs a real parakeet GGUF and a 16 kHz mono WAV on disk.
|
||||
func fixturesOrSkip() (string, string) {
|
||||
modelPath := os.Getenv("PARAKEET_BACKEND_TEST_MODEL")
|
||||
audioPath := os.Getenv("PARAKEET_BACKEND_TEST_WAV")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set PARAKEET_BACKEND_TEST_MODEL and PARAKEET_BACKEND_TEST_WAV to run this spec")
|
||||
}
|
||||
return modelPath, audioPath
|
||||
}
|
||||
|
||||
// writeMono16kWav writes `samples` frames of 16 kHz mono 16-bit silence to
|
||||
// path. The result is already in AudioToWav's target format, so the conversion
|
||||
// helper copies it through without invoking ffmpeg.
|
||||
func writeMono16kWav(path string, samples int) {
|
||||
GinkgoHelper()
|
||||
f, err := os.Create(path)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
enc := wav.NewEncoder(f, 16000, 16, 1, 1)
|
||||
buf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: 1, SampleRate: 16000},
|
||||
SourceBitDepth: 16,
|
||||
Data: make([]int, samples),
|
||||
}
|
||||
Expect(enc.Write(buf)).To(Succeed())
|
||||
Expect(enc.Close()).To(Succeed())
|
||||
Expect(f.Close()).To(Succeed())
|
||||
}
|
||||
|
||||
var _ = Describe("ParakeetCpp", func() {
|
||||
Context("AudioTranscription", func() {
|
||||
It("transcribes a WAV via the parakeet C-API", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
p := &ParakeetCpp{}
|
||||
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
defer func() { _ = p.Free() }()
|
||||
|
||||
res, err := p.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(strings.TrimSpace(res.Text)).ToNot(BeEmpty(),
|
||||
"expected non-empty transcript for %s", audioPath)
|
||||
Expect(res.Segments).To(HaveLen(1),
|
||||
"synthesises a single whole-clip segment")
|
||||
Expect(res.Segments[0].Text).To(Equal(res.Text),
|
||||
"single segment text must equal the top-level text")
|
||||
// Default (no granularities) is segment-level: no per-word timings.
|
||||
Expect(res.Segments[0].Words).To(BeEmpty(),
|
||||
"word timings are opt-in via timestamp_granularities")
|
||||
})
|
||||
|
||||
It("emits word-level timestamps when granularity=word", func() {
|
||||
modelPath, audioPath := fixturesOrSkip()
|
||||
ensureLibLoaded()
|
||||
|
||||
p := &ParakeetCpp{}
|
||||
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
defer func() { _ = p.Free() }()
|
||||
|
||||
res, err := p.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
TimestampGranularities: []string{"word"},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.Segments).To(HaveLen(1))
|
||||
seg := res.Segments[0]
|
||||
Expect(seg.Words).ToNot(BeEmpty(),
|
||||
"expected per-word timestamps with granularity=word")
|
||||
// Monotonic, non-negative timings spanning the segment.
|
||||
Expect(seg.Words[0].Start).To(BeNumerically(">=", int64(0)))
|
||||
Expect(seg.End).To(BeNumerically(">=", seg.Start))
|
||||
Expect(seg.Words[len(seg.Words)-1].End).To(Equal(seg.End),
|
||||
"segment end tracks the last word")
|
||||
})
|
||||
})
|
||||
|
||||
Context("convertToWavMono16k", func() {
|
||||
// The non-batched transcription path hands a file path to the C
|
||||
// library's WAV-only audio loader, so it must convert first.
|
||||
// utils.AudioToWav passes an already-16kHz/mono/16-bit WAV through
|
||||
// without ffmpeg, which lets us exercise the helper (and the
|
||||
// regression: the direct path used to skip conversion entirely)
|
||||
// without a model, the C library, or ffmpeg.
|
||||
It("returns a decodable 16kHz mono WAV copy and cleans it up", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
src := filepath.Join(dir, "input.wav")
|
||||
writeMono16kWav(src, 16000) // 1s of silence at 16 kHz
|
||||
|
||||
converted, cleanup, err := convertToWavMono16k(src)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// It must produce a fresh temp file, not return the original path.
|
||||
Expect(converted).ToNot(Equal(src))
|
||||
Expect(converted).To(BeAnExistingFile())
|
||||
|
||||
pcm, _, err := decodeWavMono16k(converted)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(pcm).To(HaveLen(16000), "round-trips the sample count")
|
||||
|
||||
cleanup()
|
||||
Expect(converted).ToNot(BeAnExistingFile(), "cleanup removes the temp dir")
|
||||
})
|
||||
|
||||
It("errors on a non-existent input rather than passing the path through", func() {
|
||||
_, _, err := convertToWavMono16k(filepath.Join(GinkgoT().TempDir(), "missing.mp3"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("AudioTranscriptionStream", func() {
|
||||
It("streams deltas and a closing FinalResult from a cache-aware model", func() {
|
||||
// Streaming needs a cache-aware streaming model (e.g.
|
||||
// realtime_eou); the offline test model would fail stream_begin.
|
||||
modelPath := os.Getenv("PARAKEET_BACKEND_TEST_STREAM_MODEL")
|
||||
audioPath := os.Getenv("PARAKEET_BACKEND_TEST_WAV")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
Skip("set PARAKEET_BACKEND_TEST_STREAM_MODEL (cache-aware streaming model) and PARAKEET_BACKEND_TEST_WAV")
|
||||
}
|
||||
ensureLibLoaded()
|
||||
|
||||
p := &ParakeetCpp{}
|
||||
Expect(p.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
|
||||
defer func() { _ = p.Free() }()
|
||||
|
||||
results := make(chan *pb.TranscriptStreamResponse, 64)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- p.AudioTranscriptionStream(context.Background(),
|
||||
&pb.TranscriptRequest{Dst: audioPath}, results)
|
||||
}()
|
||||
|
||||
var deltas []string
|
||||
var final *pb.TranscriptResult
|
||||
for r := range results {
|
||||
if r.Delta != "" {
|
||||
deltas = append(deltas, r.Delta)
|
||||
}
|
||||
if r.FinalResult != nil {
|
||||
final = r.FinalResult
|
||||
}
|
||||
}
|
||||
Expect(<-errCh).ToNot(HaveOccurred())
|
||||
|
||||
Expect(final).ToNot(BeNil(), "expected a closing FinalResult")
|
||||
Expect(strings.TrimSpace(final.Text)).ToNot(BeEmpty(),
|
||||
"expected a non-empty streamed transcript")
|
||||
Expect(final.Segments).ToNot(BeEmpty(),
|
||||
"FinalResult always carries at least one segment")
|
||||
// The concatenated deltas reconstruct the final transcript.
|
||||
Expect(strings.TrimSpace(strings.Join(deltas, ""))).To(Equal(strings.TrimSpace(final.Text)),
|
||||
"deltas should reconstruct the final text")
|
||||
})
|
||||
})
|
||||
})
|
||||
75
backend/go/parakeet-cpp/main.go
Normal file
75
backend/go/parakeet-cpp/main.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package main
|
||||
|
||||
// Started internally by LocalAI - one gRPC server per loaded model.
|
||||
//
|
||||
// Loads libparakeet.so via purego and registers the flat C-API entry
|
||||
// points declared in parakeet_capi.h. The library name can be overridden
|
||||
// with PARAKEET_LIBRARY (mirrors the WHISPER_LIBRARY / VIBEVOICECPP_LIBRARY
|
||||
// convention in the sibling backends); the default looks for the .so next
|
||||
// to this binary.
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
libName := os.Getenv("PARAKEET_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "libparakeet.so"
|
||||
}
|
||||
|
||||
lib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("parakeet-cpp: dlopen %q: %w", libName, err))
|
||||
}
|
||||
|
||||
// Bound 1:1 to parakeet_capi.h. The C-API returns malloc'd char*
|
||||
// buffers from transcribe_*; we register those as uintptr so we get
|
||||
// the raw pointer back and can call parakeet_capi_free_string on it
|
||||
// (purego's string return would copy and forget the original pointer,
|
||||
// leaking it on every call).
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppAbiVersion, "parakeet_capi_abi_version"},
|
||||
{&CppLoad, "parakeet_capi_load"},
|
||||
{&CppFree, "parakeet_capi_free"},
|
||||
{&CppTranscribePath, "parakeet_capi_transcribe_path"},
|
||||
{&CppTranscribePathJSON, "parakeet_capi_transcribe_path_json"},
|
||||
{&CppStreamBegin, "parakeet_capi_stream_begin"},
|
||||
{&CppStreamFeed, "parakeet_capi_stream_feed"},
|
||||
{&CppStreamFinalize, "parakeet_capi_stream_finalize"},
|
||||
{&CppStreamFree, "parakeet_capi_stream_free"},
|
||||
{&CppFreeString, "parakeet_capi_free_string"},
|
||||
{&CppLastError, "parakeet_capi_last_error"},
|
||||
}
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, lib, lf.Name)
|
||||
}
|
||||
|
||||
// The batched-JSON entry point exists only in newer libparakeet.so (ABI >= 2).
|
||||
// Probe with Dlsym and register only if present, so the backend still loads
|
||||
// against an older library (it falls back to per-request transcription).
|
||||
if sym, err := purego.Dlsym(lib, "parakeet_capi_transcribe_pcm_batch_json"); err == nil && sym != 0 {
|
||||
purego.RegisterLibFunc(&CppTranscribePcmBatchJSON, lib, "parakeet_capi_transcribe_pcm_batch_json")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[parakeet-cpp] ABI=%d\n", CppAbiVersion())
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &ParakeetCpp{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
backend/go/parakeet-cpp/package.sh
Executable file
23
backend/go/parakeet-cpp/package.sh
Executable file
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# L0 packaging stub: copy the binary, run.sh and libparakeet.so* into
|
||||
# package/. The full ldd walk (libc, libstdc++, libgomp, GPU runtimes,
|
||||
# arch detection) lands in L3, mirroring backend/go/whisper/package.sh.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
mkdir -p "$CURDIR/package/lib"
|
||||
|
||||
cp -avf "$CURDIR/parakeet-cpp-grpc" "$CURDIR/package/"
|
||||
cp -avf "$CURDIR/run.sh" "$CURDIR/package/"
|
||||
|
||||
# libparakeet.so + any soname symlinks (libparakeet.so.X, libparakeet.so.X.Y).
|
||||
cp -avf "$CURDIR"/libparakeet.so* "$CURDIR/package/lib/" 2>/dev/null || {
|
||||
echo "ERROR: libparakeet.so not found in $CURDIR, run 'make' first" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
echo "L0 package layout (full ldd walk lands in L3):"
|
||||
ls -liah "$CURDIR/package/" "$CURDIR/package/lib/"
|
||||
16
backend/go/parakeet-cpp/run.sh
Executable file
16
backend/go/parakeet-cpp/run.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
|
||||
export LD_LIBRARY_PATH="$CURDIR/lib:$CURDIR:${LD_LIBRARY_PATH:-}"
|
||||
|
||||
# If a self-contained ld.so was packaged, route through it so the
|
||||
# packaged libc / libstdc++ are used instead of the host's (matches the
|
||||
# whisper backend's runtime layout).
|
||||
if [ -f "$CURDIR/lib/ld.so" ]; then
|
||||
echo "Using lib/ld.so"
|
||||
exec "$CURDIR/lib/ld.so" "$CURDIR/parakeet-cpp-grpc" "$@"
|
||||
fi
|
||||
|
||||
exec "$CURDIR/parakeet-cpp-grpc" "$@"
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# qwen3-tts.cpp version
|
||||
QWEN3TTS_REPO?=https://github.com/predict-woo/qwen3-tts.cpp
|
||||
QWEN3TTS_CPP_VERSION?=7a762e2ad4bacc6fdda81d81bf10a09ffb546f29
|
||||
QWEN3TTS_CPP_VERSION?=136e5d36c17083da0321fd96512dc7b263f94a44
|
||||
SO_TARGET?=libgoqwen3ttscpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -21,6 +22,43 @@ type Qwen3TtsCpp struct {
|
||||
threads int
|
||||
}
|
||||
|
||||
// languageNameAliases maps common full language names to the canonical
|
||||
// two-letter code understood by the C++ language_to_id table.
|
||||
var languageNameAliases = map[string]string{
|
||||
"english": "en",
|
||||
"russian": "ru",
|
||||
"chinese": "zh",
|
||||
"japanese": "ja",
|
||||
"korean": "ko",
|
||||
"german": "de",
|
||||
"french": "fr",
|
||||
"spanish": "es",
|
||||
"italian": "it",
|
||||
"portuguese": "pt",
|
||||
}
|
||||
|
||||
// normalizeLanguage coerces a caller-supplied language into the canonical code
|
||||
// the model expects. It lowercases, trims, strips any region/locale suffix
|
||||
// (en-US, en_US, ja.JP -> en/ja), and resolves common full names (english -> en).
|
||||
// An empty input stays empty so the C++ side applies its English default; an
|
||||
// unrecognized value is returned normalized so C++ can log it and default.
|
||||
func normalizeLanguage(lang string) string {
|
||||
lang = strings.ToLower(strings.TrimSpace(lang))
|
||||
if lang == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strip region/locale suffix: keep the segment before the first separator.
|
||||
if i := strings.IndexAny(lang, "-_."); i >= 0 {
|
||||
lang = lang[:i]
|
||||
}
|
||||
|
||||
if code, ok := languageNameAliases[lang]; ok {
|
||||
return code
|
||||
}
|
||||
return lang
|
||||
}
|
||||
|
||||
func (q *Qwen3TtsCpp) Load(opts *pb.ModelOptions) error {
|
||||
// ModelFile is the model directory path (containing GGUF files)
|
||||
modelDir := opts.ModelFile
|
||||
@@ -54,7 +92,7 @@ func (q *Qwen3TtsCpp) TTS(req *pb.TTSRequest) error {
|
||||
dst := req.Dst
|
||||
language := ""
|
||||
if req.Language != nil {
|
||||
language = *req.Language
|
||||
language = normalizeLanguage(*req.Language)
|
||||
}
|
||||
|
||||
// Synthesis parameters with sensible defaults
|
||||
|
||||
53
backend/go/qwen3-tts-cpp/language_test.go
Normal file
53
backend/go/qwen3-tts-cpp/language_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestLanguageNormalization(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "qwen3-tts-cpp language normalization")
|
||||
}
|
||||
|
||||
var _ = Describe("normalizeLanguage", func() {
|
||||
DescribeTable("maps caller input to the canonical model language code",
|
||||
func(input, expected string) {
|
||||
Expect(normalizeLanguage(input)).To(Equal(expected))
|
||||
},
|
||||
// Canonical codes pass through unchanged
|
||||
Entry("canonical en", "en", "en"),
|
||||
Entry("canonical zh", "zh", "zh"),
|
||||
Entry("canonical pt", "pt", "pt"),
|
||||
|
||||
// Case-insensitive
|
||||
Entry("uppercase", "EN", "en"),
|
||||
Entry("mixed case", "Ja", "ja"),
|
||||
|
||||
// Surrounding whitespace
|
||||
Entry("trims whitespace", " en ", "en"),
|
||||
|
||||
// Region/locale stripping
|
||||
Entry("BCP-47 region", "en-US", "en"),
|
||||
Entry("underscore region", "en_US", "en"),
|
||||
Entry("dotted locale", "ja.JP", "ja"),
|
||||
Entry("region + case", "ZH-CN", "zh"),
|
||||
|
||||
// Full-name aliases
|
||||
Entry("english name", "english", "en"),
|
||||
Entry("chinese name cased", "Chinese", "zh"),
|
||||
Entry("japanese name", "japanese", "ja"),
|
||||
Entry("russian name", "russian", "ru"),
|
||||
Entry("portuguese name", "portuguese", "pt"),
|
||||
|
||||
// Empty stays empty (C++ applies the English default)
|
||||
Entry("empty", "", ""),
|
||||
Entry("whitespace only", " ", ""),
|
||||
|
||||
// Unknown values pass through normalized so C++ can log + default
|
||||
Entry("unknown code", "klingon", "klingon"),
|
||||
Entry("unknown with region", "xx-YY", "xx"),
|
||||
)
|
||||
})
|
||||
7
backend/go/rfdetr-cpp/.gitignore
vendored
Normal file
7
backend/go/rfdetr-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
sources/
|
||||
build*/
|
||||
package/
|
||||
librfdetrcpp*.so
|
||||
rfdetr-cpp
|
||||
test-models/
|
||||
test-data/
|
||||
79
backend/go/rfdetr-cpp/CMakeLists.txt
Normal file
79
backend/go/rfdetr-cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,79 @@
|
||||
cmake_minimum_required(VERSION 3.18)
|
||||
project(librfdetrcpp LANGUAGES C CXX)
|
||||
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
# Static-link ggml + rfdetr so the resulting .so has no runtime dependency on
|
||||
# extra ggml/rfdetr shared libraries — only on libc/libstdc++/libgomp, which
|
||||
# the LocalAI package step bundles into the docker image.
|
||||
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build static libraries" FORCE)
|
||||
|
||||
# rfdetr.cpp build switches: skip CLI/tests, keep static lib.
|
||||
set(RFDETR_BUILD_CLI OFF CACHE BOOL "Disable rfdetr CLI" FORCE)
|
||||
set(RFDETR_BUILD_TESTS OFF CACHE BOOL "Disable rfdetr tests" FORCE)
|
||||
set(RFDETR_SHARED OFF CACHE BOOL "Build rfdetr as static lib" FORCE)
|
||||
|
||||
# rt-detr.cpp's top-level CMakeLists invokes
|
||||
# `bash ${CMAKE_SOURCE_DIR}/scripts/apply_ggml_patches.sh` to apply its
|
||||
# in-tree ggml patches before descending into the submodule. When we
|
||||
# `add_subdirectory` it from a parent project, `CMAKE_SOURCE_DIR` points
|
||||
# at *our* directory, not theirs, so the script path resolves wrong.
|
||||
#
|
||||
# Run the patches script ourselves up front (it's idempotent — re-running
|
||||
# is a no-op once patches are applied) so the rt-detr.cpp configure step
|
||||
# is essentially a no-op for the patch hook.
|
||||
set(RFDETR_CPP_SRC ${CMAKE_CURRENT_SOURCE_DIR}/sources/rt-detr.cpp)
|
||||
if(EXISTS ${RFDETR_CPP_SRC}/scripts/apply_ggml_patches.sh)
|
||||
execute_process(
|
||||
COMMAND bash ${RFDETR_CPP_SRC}/scripts/apply_ggml_patches.sh
|
||||
RESULT_VARIABLE _rfdetr_patch_result
|
||||
OUTPUT_VARIABLE _rfdetr_patch_output
|
||||
ERROR_VARIABLE _rfdetr_patch_error
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
ERROR_STRIP_TRAILING_WHITESPACE)
|
||||
if(NOT _rfdetr_patch_result EQUAL 0)
|
||||
message(FATAL_ERROR
|
||||
"Failed to apply ggml patches (exit ${_rfdetr_patch_result}):\n"
|
||||
"stdout:\n${_rfdetr_patch_output}\n"
|
||||
"stderr:\n${_rfdetr_patch_error}")
|
||||
endif()
|
||||
message(STATUS "${_rfdetr_patch_output}")
|
||||
endif()
|
||||
|
||||
# Stage a shim 'scripts/apply_ggml_patches.sh' under our source dir so that
|
||||
# rt-detr.cpp's CMakeLists — which calls
|
||||
# bash ${CMAKE_SOURCE_DIR}/scripts/apply_ggml_patches.sh
|
||||
# — finds an idempotent no-op there. The real patches have already been
|
||||
# applied above; this just satisfies the path lookup.
|
||||
file(MAKE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/scripts)
|
||||
file(WRITE ${CMAKE_CURRENT_SOURCE_DIR}/scripts/apply_ggml_patches.sh
|
||||
"#!/usr/bin/env bash
|
||||
# Shim - patches were already applied by the parent CMakeLists.
|
||||
exit 0
|
||||
")
|
||||
execute_process(COMMAND chmod +x ${CMAKE_CURRENT_SOURCE_DIR}/scripts/apply_ggml_patches.sh)
|
||||
|
||||
add_subdirectory(./sources/rt-detr.cpp)
|
||||
|
||||
# rfdetr.cpp's C-API symbols already live inside librfdetr (src/rfdetr_capi.cpp
|
||||
# is compiled into the lib). We re-export them via a MODULE library that
|
||||
# whole-archive-links rfdetr so the symbols are visible at dlopen time.
|
||||
add_library(rfdetrcpp MODULE
|
||||
sources/rt-detr.cpp/src/rfdetr_capi.cpp)
|
||||
|
||||
target_include_directories(rfdetrcpp PRIVATE
|
||||
sources/rt-detr.cpp/include
|
||||
sources/rt-detr.cpp/src
|
||||
sources/rt-detr.cpp/third_party/stb
|
||||
)
|
||||
|
||||
target_link_libraries(rfdetrcpp PRIVATE rfdetr ggml)
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
|
||||
target_link_libraries(rfdetrcpp PRIVATE stdc++fs)
|
||||
endif()
|
||||
|
||||
set_property(TARGET rfdetrcpp PROPERTY CXX_STANDARD 17)
|
||||
set_target_properties(rfdetrcpp PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||
135
backend/go/rfdetr-cpp/Makefile
Normal file
135
backend/go/rfdetr-cpp/Makefile
Normal file
@@ -0,0 +1,135 @@
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# rt-detr.cpp (GitHub redirects the historical mudler/rt-detr.cpp to the new
|
||||
# mudler/rf-detr.cpp slug). Pin to a specific commit if you need a stable
|
||||
# build; leaving this on `master` always picks up the latest C-API surface
|
||||
# (incl. the per-detection accessor functions used by gorfdetrcpp.go).
|
||||
RFDETR_REPO?=https://github.com/mudler/rf-detr.cpp.git
|
||||
RFDETR_VERSION?=65c0ffcc9a9bc9dae38252f63d0417c9845a6cf7
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# Forward LocalAI's BUILD_TYPE to the matching ggml backend switch.
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DGGML_CUDA=ON -DRFDETR_GGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
ROCM_HOME ?= /opt/rocm
|
||||
ROCM_PATH ?= /opt/rocm
|
||||
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
||||
export CC=$(ROCM_HOME)/llvm/bin/clang
|
||||
AMDGPU_TARGETS?=gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON -DRFDETR_GGML_HIPBLAS=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON -DRFDETR_GGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
ifneq ($(BUILD_TYPE),metal)
|
||||
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||
else
|
||||
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||
CMAKE_ARGS+=-DRFDETR_GGML_METAL=ON
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
-DGGML_SYCL_F16=ON
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
endif
|
||||
|
||||
sources/rt-detr.cpp:
|
||||
mkdir -p sources && \
|
||||
git clone --recursive $(RFDETR_REPO) sources/rt-detr.cpp && \
|
||||
cd sources/rt-detr.cpp && \
|
||||
git checkout $(RFDETR_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
# Only build CPU variants on Linux
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
VARIANT_TARGETS = librfdetrcpp-avx.so librfdetrcpp-avx2.so librfdetrcpp-avx512.so librfdetrcpp-fallback.so
|
||||
else
|
||||
# On non-Linux (e.g., Darwin), build only fallback variant
|
||||
VARIANT_TARGETS = librfdetrcpp-fallback.so
|
||||
endif
|
||||
|
||||
rfdetr-cpp: main.go gorfdetrcpp.go $(VARIANT_TARGETS)
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o rfdetr-cpp ./
|
||||
|
||||
package: rfdetr-cpp
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
clean: purge
|
||||
rm -rf librfdetrcpp*.so rfdetr-cpp package sources
|
||||
|
||||
purge:
|
||||
rm -rf build*
|
||||
|
||||
# Build all variants (Linux only)
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
librfdetrcpp-avx.so: sources/rt-detr.cpp
|
||||
rm -rfv build-$@
|
||||
$(info ${GREEN}I rfdetr-cpp build info:avx${RESET})
|
||||
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) librfdetrcpp-custom
|
||||
rm -rfv build-$@
|
||||
|
||||
librfdetrcpp-avx2.so: sources/rt-detr.cpp
|
||||
rm -rfv build-$@
|
||||
$(info ${GREEN}I rfdetr-cpp build info:avx2${RESET})
|
||||
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) librfdetrcpp-custom
|
||||
rm -rfv build-$@
|
||||
|
||||
librfdetrcpp-avx512.so: sources/rt-detr.cpp
|
||||
rm -rfv build-$@
|
||||
$(info ${GREEN}I rfdetr-cpp build info:avx512${RESET})
|
||||
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) librfdetrcpp-custom
|
||||
rm -rfv build-$@
|
||||
endif
|
||||
|
||||
# Build fallback variant (all platforms)
|
||||
librfdetrcpp-fallback.so: sources/rt-detr.cpp
|
||||
rm -rfv build-$@
|
||||
$(info ${GREEN}I rfdetr-cpp build info:fallback${RESET})
|
||||
SO_TARGET=$@ CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) librfdetrcpp-custom
|
||||
rm -rfv build-$@
|
||||
|
||||
librfdetrcpp-custom: CMakeLists.txt
|
||||
mkdir -p build-$(SO_TARGET) && \
|
||||
cd build-$(SO_TARGET) && \
|
||||
cmake .. $(CMAKE_ARGS) && \
|
||||
cmake --build . --config Release -j$(JOBS) && \
|
||||
cd .. && \
|
||||
mv build-$(SO_TARGET)/librfdetrcpp.so ./$(SO_TARGET)
|
||||
|
||||
all: rfdetr-cpp package
|
||||
|
||||
# `test` is invoked by the top-level Makefile's `test-extra` target. It builds
|
||||
# the backend binary + the fallback shared library (needed for dlopen at
|
||||
# runtime), then runs test.sh which downloads the test models + COCO image
|
||||
# and exercises the gRPC Load/Detect wire path via the Go smoke test in
|
||||
# main_test.go for both the detection and segmentation models.
|
||||
test: rfdetr-cpp librfdetrcpp-fallback.so
|
||||
bash test.sh
|
||||
195
backend/go/rfdetr-cpp/gorfdetrcpp.go
Normal file
195
backend/go/rfdetr-cpp/gorfdetrcpp.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package main
|
||||
|
||||
// gorfdetrcpp.go - gRPC handlers (Load, Detect) for the rfdetr-cpp backend.
|
||||
//
|
||||
// Embeds base.SingleThread to default unimplemented RPCs to "not supported"
|
||||
// while we only implement object detection.
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// Default upper bound on detections returned per image. RF-DETR's decoder
|
||||
// queries are limited to a few hundred; 300 is a safe ceiling.
|
||||
const defaultTopK = 300
|
||||
|
||||
// rfdetr_handle_t is a uintptr-typed opaque handle (see include/rfdetr_capi.h).
|
||||
var (
|
||||
// rfdetr_capi_load(const char* model_path, int n_threads, rfdetr_handle_t* out_handle) -> int
|
||||
CapiLoad func(modelPath string, nThreads int32, outHandle *uintptr) int32
|
||||
// rfdetr_capi_unload(rfdetr_handle_t handle) -> int
|
||||
CapiUnload func(handle uintptr) int32
|
||||
// rfdetr_capi_detect_path(handle, image_path, threshold, top_k, out_json) -> int
|
||||
CapiDetectPath func(handle uintptr, imagePath string, threshold float32, topK uint32, outJSON *uintptr) int32
|
||||
// rfdetr_capi_detect_buffer(handle, bytes, len, threshold, top_k, out_json) -> int
|
||||
CapiDetectBuffer func(handle uintptr, bytes uintptr, length uintptr, threshold float32, topK uint32, outJSON *uintptr) int32
|
||||
// rfdetr_capi_free_string(char* s)
|
||||
CapiFreeString func(s uintptr)
|
||||
// rfdetr_capi_get_n_detections(handle) -> int
|
||||
CapiGetNDetections func(handle uintptr) int32
|
||||
// rfdetr_capi_get_detection_class_id(handle, i) -> int
|
||||
CapiGetDetectionClassID func(handle uintptr, i int32) int32
|
||||
// rfdetr_capi_get_detection_box(handle, i, out_xyxy[4]) -> int (0 on success)
|
||||
CapiGetDetectionBox func(handle uintptr, i int32, outXYXY uintptr) int32
|
||||
// rfdetr_capi_get_detection_score(handle, i) -> float
|
||||
CapiGetDetectionScore func(handle uintptr, i int32) float32
|
||||
// rfdetr_capi_get_detection_class_name(handle, i, buf, buf_size) -> int (needed/written; two-call sizing)
|
||||
CapiGetDetectionClassName func(handle uintptr, i int32, buf uintptr, bufSize int32) int32
|
||||
// rfdetr_capi_get_detection_mask_png(handle, i, buf, buf_size) -> int (needed/written; 0 means no mask)
|
||||
CapiGetDetectionMaskPNG func(handle uintptr, i int32, buf uintptr, bufSize int32) int32
|
||||
)
|
||||
|
||||
type RFDetrCpp struct {
|
||||
base.SingleThread
|
||||
handle uintptr
|
||||
}
|
||||
|
||||
// Load loads the GGUF model at opts.ModelFile (joined with opts.ModelPath if relative)
|
||||
// and stores the handle for later Detect calls.
|
||||
func (r *RFDetrCpp) Load(opts *pb.ModelOptions) error {
|
||||
modelFile := opts.ModelFile
|
||||
if modelFile == "" {
|
||||
modelFile = opts.Model
|
||||
}
|
||||
if modelFile == "" {
|
||||
return fmt.Errorf("rfdetr-cpp: ModelFile is empty")
|
||||
}
|
||||
|
||||
var modelPath string
|
||||
if filepath.IsAbs(modelFile) {
|
||||
modelPath = modelFile
|
||||
} else {
|
||||
modelPath = filepath.Join(opts.ModelPath, modelFile)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(modelPath); err != nil {
|
||||
return fmt.Errorf("rfdetr-cpp: model file not found: %s: %w", modelPath, err)
|
||||
}
|
||||
|
||||
threads := opts.Threads
|
||||
if threads <= 0 {
|
||||
threads = 4
|
||||
}
|
||||
|
||||
// Release previous model if any (re-Load).
|
||||
if r.handle != 0 {
|
||||
CapiUnload(r.handle)
|
||||
r.handle = 0
|
||||
}
|
||||
|
||||
var h uintptr
|
||||
rc := CapiLoad(modelPath, threads, &h)
|
||||
if rc != 0 || h == 0 {
|
||||
return fmt.Errorf("rfdetr-cpp: rfdetr_capi_load failed with rc=%d for %s", rc, modelPath)
|
||||
}
|
||||
r.handle = h
|
||||
return nil
|
||||
}
|
||||
|
||||
// Detect runs object detection on the base64-encoded image in opts.Src at
|
||||
// opts.Threshold, returning one pb.Detection per result. Seg models also
|
||||
// populate Detection.Mask with PNG-encoded mask bytes.
|
||||
func (r *RFDetrCpp) Detect(opts *pb.DetectOptions) (pb.DetectResponse, error) {
|
||||
if r.handle == 0 {
|
||||
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: model not loaded")
|
||||
}
|
||||
|
||||
// Decode base64 image and write to temp file.
|
||||
imgData, err := base64.StdEncoding.DecodeString(opts.Src)
|
||||
if err != nil {
|
||||
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: failed to decode base64 image: %w", err)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "rfdetr-*.img")
|
||||
if err != nil {
|
||||
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: failed to create temp file: %w", err)
|
||||
}
|
||||
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||
|
||||
if _, err := tmpFile.Write(imgData); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: failed to write temp file: %w", err)
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: failed to close temp file: %w", err)
|
||||
}
|
||||
|
||||
threshold := opts.Threshold
|
||||
if threshold <= 0 {
|
||||
threshold = 0.5
|
||||
}
|
||||
|
||||
// JSON output from detect_path is unused: we read structured detections via
|
||||
// the accessor functions. Still must free the returned string.
|
||||
var jsonPtr uintptr
|
||||
rc := CapiDetectPath(r.handle, tmpFile.Name(), threshold, uint32(defaultTopK), &jsonPtr)
|
||||
if jsonPtr != 0 {
|
||||
CapiFreeString(jsonPtr)
|
||||
}
|
||||
if rc != 0 {
|
||||
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: detect failed with rc=%d", rc)
|
||||
}
|
||||
|
||||
n := CapiGetNDetections(r.handle)
|
||||
if n < 0 {
|
||||
return pb.DetectResponse{}, fmt.Errorf("rfdetr-cpp: invalid n_detections=%d", n)
|
||||
}
|
||||
|
||||
detections := make([]*pb.Detection, 0, n)
|
||||
for i := int32(0); i < n; i++ {
|
||||
var bbox [4]float32 // x1, y1, x2, y2
|
||||
if rc := CapiGetDetectionBox(r.handle, i, uintptr(unsafe.Pointer(&bbox[0]))); rc != 0 {
|
||||
continue
|
||||
}
|
||||
cid := CapiGetDetectionClassID(r.handle, i)
|
||||
score := CapiGetDetectionScore(r.handle, i)
|
||||
|
||||
// Two-call sizing for class_name.
|
||||
var className string
|
||||
nameSize := CapiGetDetectionClassName(r.handle, i, 0, 0)
|
||||
if nameSize > 1 {
|
||||
buf := make([]byte, nameSize)
|
||||
written := CapiGetDetectionClassName(r.handle, i, uintptr(unsafe.Pointer(&buf[0])), nameSize)
|
||||
// `written` is the same number (needed bytes including NUL); strip NUL.
|
||||
if written > 0 && int(written) <= len(buf) {
|
||||
className = string(buf[:written-1])
|
||||
} else {
|
||||
className = string(buf[:len(buf)-1])
|
||||
}
|
||||
}
|
||||
if className == "" {
|
||||
className = strconv.Itoa(int(cid))
|
||||
}
|
||||
|
||||
// Two-call sizing for mask PNG (returns 0 when no mask).
|
||||
var mask []byte
|
||||
maskSize := CapiGetDetectionMaskPNG(r.handle, i, 0, 0)
|
||||
if maskSize > 0 {
|
||||
maskBuf := make([]byte, maskSize)
|
||||
CapiGetDetectionMaskPNG(r.handle, i, uintptr(unsafe.Pointer(&maskBuf[0])), maskSize)
|
||||
mask = maskBuf
|
||||
}
|
||||
|
||||
detections = append(detections, &pb.Detection{
|
||||
X: bbox[0],
|
||||
Y: bbox[1],
|
||||
Width: bbox[2] - bbox[0],
|
||||
Height: bbox[3] - bbox[1],
|
||||
Confidence: score,
|
||||
ClassName: className,
|
||||
Mask: mask,
|
||||
})
|
||||
}
|
||||
|
||||
return pb.DetectResponse{
|
||||
Detections: detections,
|
||||
}, nil
|
||||
}
|
||||
61
backend/go/rfdetr-cpp/main.go
Normal file
61
backend/go/rfdetr-cpp/main.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package main
|
||||
|
||||
// main.go - entry point for the rfdetr-cpp gRPC backend.
|
||||
//
|
||||
// Dlopens librfdetrcpp-<variant>.so via purego at the path in
|
||||
// RFDETR_LIBRARY (set by run.sh based on /proc/cpuinfo), registers the
|
||||
// rfdetr_capi_* C ABI symbols, then starts the gRPC server.
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Get library name from environment variable, default to fallback
|
||||
libName := os.Getenv("RFDETR_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./librfdetrcpp-fallback.so"
|
||||
}
|
||||
|
||||
rfdetrLib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CapiLoad, "rfdetr_capi_load"},
|
||||
{&CapiUnload, "rfdetr_capi_unload"},
|
||||
{&CapiDetectPath, "rfdetr_capi_detect_path"},
|
||||
{&CapiDetectBuffer, "rfdetr_capi_detect_buffer"},
|
||||
{&CapiFreeString, "rfdetr_capi_free_string"},
|
||||
{&CapiGetNDetections, "rfdetr_capi_get_n_detections"},
|
||||
{&CapiGetDetectionClassID, "rfdetr_capi_get_detection_class_id"},
|
||||
{&CapiGetDetectionBox, "rfdetr_capi_get_detection_box"},
|
||||
{&CapiGetDetectionScore, "rfdetr_capi_get_detection_score"},
|
||||
{&CapiGetDetectionClassName, "rfdetr_capi_get_detection_class_name"},
|
||||
{&CapiGetDetectionMaskPNG, "rfdetr_capi_get_detection_mask_png"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, rfdetrLib, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &RFDetrCpp{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
220
backend/go/rfdetr-cpp/main_test.go
Normal file
220
backend/go/rfdetr-cpp/main_test.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package main
|
||||
|
||||
// main_test.go - end-to-end smoke test for the rfdetr-cpp gRPC backend.
|
||||
//
|
||||
// Spawns the compiled rfdetr-cpp binary on a free local port, dials it via
|
||||
// gRPC, and exercises LoadModel + Detect against the test fixtures
|
||||
// downloaded by test.sh. Two scenarios:
|
||||
//
|
||||
// 1. detection — loads rfdetr-nano-q8_0.gguf and asserts at least one
|
||||
// detection comes back with a non-empty class name and a bounding box
|
||||
// of non-zero size.
|
||||
// 2. segmentation — loads rfdetr-seg-nano-q8_0.gguf and additionally
|
||||
// asserts that at least one detection carries a PNG-encoded mask blob
|
||||
// (verified by PNG magic bytes).
|
||||
//
|
||||
// Both specs Skip cleanly if their fixtures are missing so the test target
|
||||
// stays usable on a fresh checkout where models haven't been downloaded.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
func TestRFDetrCpp(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "rfdetr-cpp backend smoke suite")
|
||||
}
|
||||
|
||||
// freePort grabs an ephemeral TCP port and immediately releases it so the
|
||||
// spawned backend can bind to it. There is a tiny TOCTOU window here but in
|
||||
// practice it's adequate for a smoke test on a quiet runner.
|
||||
func freePort() int {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
Expect(err).ToNot(HaveOccurred(), "freePort listen")
|
||||
port := l.Addr().(*net.TCPAddr).Port
|
||||
Expect(l.Close()).To(Succeed())
|
||||
return port
|
||||
}
|
||||
|
||||
// startBackend spawns the rfdetr-cpp binary on the given port and waits
|
||||
// until it accepts TCP connections (up to 10s). The returned cleanup func
|
||||
// kills the process and reaps it.
|
||||
func startBackend(port int) func() {
|
||||
binary, err := filepath.Abs("./rfdetr-cpp")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if _, err := os.Stat(binary); err != nil {
|
||||
Skip(fmt.Sprintf("backend binary not built: %s (run `make rfdetr-cpp` first)", binary))
|
||||
}
|
||||
|
||||
libPath, err := filepath.Abs("./librfdetrcpp-fallback.so")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if _, err := os.Stat(libPath); err != nil {
|
||||
Skip(fmt.Sprintf("fallback library not built: %s (run `make librfdetrcpp-fallback.so` first)", libPath))
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
cmd := exec.Command(binary, "--addr", addr)
|
||||
cmd.Env = append(os.Environ(), "RFDETR_LIBRARY="+libPath)
|
||||
cmd.Stdout = os.Stderr
|
||||
cmd.Stderr = os.Stderr
|
||||
Expect(cmd.Start()).To(Succeed())
|
||||
|
||||
cleanup := func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
_, _ = cmd.Process.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(10 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
c, err := net.DialTimeout("tcp", addr, 200*time.Millisecond)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
return cleanup
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
cleanup()
|
||||
Fail(fmt.Sprintf("backend did not become ready on %s within 10s", addr))
|
||||
return func() {}
|
||||
}
|
||||
|
||||
// loadTestImage reads the COCO test image downloaded by test.sh and returns
|
||||
// its base64-encoded content (the wire format accepted by the Detect RPC).
|
||||
func loadTestImage() string {
|
||||
imgPath, err := filepath.Abs("test-data/test.jpg")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
imgBytes, err := os.ReadFile(imgPath)
|
||||
if err != nil {
|
||||
Skip(fmt.Sprintf("test image not present: %s (run test.sh first)", imgPath))
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(imgBytes)
|
||||
}
|
||||
|
||||
// dialBackend opens a gRPC client connection to the spawned backend.
|
||||
func dialBackend(port int) (pb.BackendClient, func()) {
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return pb.NewBackendClient(conn), func() { _ = conn.Close() }
|
||||
}
|
||||
|
||||
// modelPathOrSkip resolves a model file under ./test-models/ and Skip()s
|
||||
// the current spec if it's missing.
|
||||
func modelPathOrSkip(name string) string {
|
||||
modelDir, err := filepath.Abs("test-models")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
modelPath := filepath.Join(modelDir, name)
|
||||
if _, err := os.Stat(modelPath); err != nil {
|
||||
Skip(fmt.Sprintf("model not present: %s (run test.sh first)", modelPath))
|
||||
}
|
||||
return modelPath
|
||||
}
|
||||
|
||||
var _ = Describe("rfdetr-cpp backend", func() {
|
||||
It("runs object detection against a known-good COCO image", func() {
|
||||
modelPath := modelPathOrSkip("rfdetr-nano-q8_0.gguf")
|
||||
imgB64 := loadTestImage()
|
||||
|
||||
port := freePort()
|
||||
cleanup := startBackend(port)
|
||||
defer cleanup()
|
||||
|
||||
client, closeConn := dialBackend(port)
|
||||
defer closeConn()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
loadResp, err := client.LoadModel(ctx, &pb.ModelOptions{
|
||||
Model: "rfdetr-nano-q8_0.gguf",
|
||||
ModelFile: modelPath,
|
||||
Threads: 2,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred(), "LoadModel")
|
||||
Expect(loadResp.GetSuccess()).To(BeTrue(), "LoadModel reported failure: %s", loadResp.GetMessage())
|
||||
|
||||
detResp, err := client.Detect(ctx, &pb.DetectOptions{
|
||||
Src: imgB64,
|
||||
Threshold: 0.5,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred(), "Detect")
|
||||
Expect(detResp.GetDetections()).ToNot(BeEmpty(), "no detections returned on a known-good COCO image")
|
||||
|
||||
_, _ = fmt.Fprintf(GinkgoWriter, "detection OK: %d detections\n", len(detResp.GetDetections()))
|
||||
for i, d := range detResp.GetDetections() {
|
||||
Expect(d.GetClassName()).ToNot(BeEmpty(), "detection %d has empty class_name", i)
|
||||
Expect(d.GetConfidence()).To(BeNumerically(">=", float32(0.5)),
|
||||
"detection %d below threshold", i)
|
||||
Expect(d.GetWidth()).To(BeNumerically(">", float32(0)),
|
||||
"detection %d has non-positive width", i)
|
||||
Expect(d.GetHeight()).To(BeNumerically(">", float32(0)),
|
||||
"detection %d has non-positive height", i)
|
||||
}
|
||||
})
|
||||
|
||||
It("runs segmentation and returns PNG-encoded masks", func() {
|
||||
modelPath := modelPathOrSkip("rfdetr-seg-nano-q8_0.gguf")
|
||||
imgB64 := loadTestImage()
|
||||
|
||||
port := freePort()
|
||||
cleanup := startBackend(port)
|
||||
defer cleanup()
|
||||
|
||||
client, closeConn := dialBackend(port)
|
||||
defer closeConn()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
loadResp, err := client.LoadModel(ctx, &pb.ModelOptions{
|
||||
Model: "rfdetr-seg-nano-q8_0.gguf",
|
||||
ModelFile: modelPath,
|
||||
Threads: 2,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred(), "LoadModel")
|
||||
Expect(loadResp.GetSuccess()).To(BeTrue(), "LoadModel reported failure: %s", loadResp.GetMessage())
|
||||
|
||||
detResp, err := client.Detect(ctx, &pb.DetectOptions{
|
||||
Src: imgB64,
|
||||
Threshold: 0.5,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred(), "Detect")
|
||||
Expect(detResp.GetDetections()).ToNot(BeEmpty(), "no detections returned from segmentation model")
|
||||
|
||||
haveMask := false
|
||||
for i, d := range detResp.GetDetections() {
|
||||
m := d.GetMask()
|
||||
if len(m) == 0 {
|
||||
continue
|
||||
}
|
||||
haveMask = true
|
||||
// Verify PNG magic: 89 50 4E 47 ("\x89PNG").
|
||||
Expect(len(m)).To(BeNumerically(">=", 4), "detection %d mask too short", i)
|
||||
Expect([]byte{m[0], m[1], m[2], m[3]}).To(Equal([]byte{0x89, 'P', 'N', 'G'}),
|
||||
"detection %d mask is not a PNG", i)
|
||||
}
|
||||
Expect(haveMask).To(BeTrue(),
|
||||
"segmentation model returned %d detections but none carried a mask",
|
||||
len(detResp.GetDetections()))
|
||||
|
||||
_, _ = fmt.Fprintf(GinkgoWriter, "segmentation OK: %d detections, at least one with PNG mask\n",
|
||||
len(detResp.GetDetections()))
|
||||
})
|
||||
})
|
||||
59
backend/go/rfdetr-cpp/package.sh
Executable file
59
backend/go/rfdetr-cpp/package.sh
Executable file
@@ -0,0 +1,59 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/librfdetrcpp-*.so $CURDIR/package/
|
||||
cp -avf $CURDIR/rfdetr-cpp $CURDIR/package/
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ $(uname -s) = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
52
backend/go/rfdetr-cpp/run.sh
Executable file
52
backend/go/rfdetr-cpp/run.sh
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
fi
|
||||
|
||||
LIBRARY="$CURDIR/librfdetrcpp-fallback.so"
|
||||
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/librfdetrcpp-avx.so ]; then
|
||||
LIBRARY="$CURDIR/librfdetrcpp-avx.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/librfdetrcpp-avx2.so ]; then
|
||||
LIBRARY="$CURDIR/librfdetrcpp-avx2.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/librfdetrcpp-avx512.so ]; then
|
||||
LIBRARY="$CURDIR/librfdetrcpp-avx512.so"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
export RFDETR_LIBRARY=$LIBRARY
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/rfdetr-cpp "$@"
|
||||
fi
|
||||
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/rfdetr-cpp "$@"
|
||||
55
backend/go/rfdetr-cpp/test.sh
Executable file
55
backend/go/rfdetr-cpp/test.sh
Executable file
@@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
echo "Running rfdetr-cpp backend tests..."
|
||||
|
||||
# Test models from the mudler/rfdetr-cpp-* HuggingFace repos. Both the
|
||||
# detection (nano-q8_0, ~36 MB) and segmentation (seg-nano-q8_0, ~40 MB)
|
||||
# variants are downloaded so the Go smoke test exercises both code paths.
|
||||
RFDETR_MODEL_DIR="${RFDETR_MODEL_DIR:-$CURDIR/test-models}"
|
||||
|
||||
RFDETR_DET_FILE="${RFDETR_DET_FILE:-rfdetr-nano-q8_0.gguf}"
|
||||
RFDETR_DET_URL="${RFDETR_DET_URL:-https://huggingface.co/mudler/rfdetr-cpp-nano/resolve/main/rfdetr-nano-q8_0.gguf}"
|
||||
|
||||
RFDETR_SEG_FILE="${RFDETR_SEG_FILE:-rfdetr-seg-nano-q8_0.gguf}"
|
||||
RFDETR_SEG_URL="${RFDETR_SEG_URL:-https://huggingface.co/mudler/rfdetr-cpp-seg-nano/resolve/main/rfdetr-seg-nano-q8_0.gguf}"
|
||||
|
||||
mkdir -p "$RFDETR_MODEL_DIR"
|
||||
|
||||
if [ ! -f "$RFDETR_MODEL_DIR/$RFDETR_DET_FILE" ]; then
|
||||
echo "Downloading rfdetr nano-q8_0 detection model..."
|
||||
curl -L -o "$RFDETR_MODEL_DIR/$RFDETR_DET_FILE" "$RFDETR_DET_URL" --progress-bar
|
||||
fi
|
||||
|
||||
if [ ! -f "$RFDETR_MODEL_DIR/$RFDETR_SEG_FILE" ]; then
|
||||
echo "Downloading rfdetr seg-nano-q8_0 segmentation model..."
|
||||
curl -L -o "$RFDETR_MODEL_DIR/$RFDETR_SEG_FILE" "$RFDETR_SEG_URL" --progress-bar
|
||||
fi
|
||||
|
||||
# Use a real COCO test image from the upstream rf-detr.cpp repo (~46 KB).
|
||||
# A synthetic 64x64 red PNG was too synthetic to elicit detections from a
|
||||
# real model — the smoke test would always trivially pass with zero
|
||||
# detections.
|
||||
TEST_IMAGE_DIR="$CURDIR/test-data"
|
||||
TEST_IMAGE_FILE="$TEST_IMAGE_DIR/test.jpg"
|
||||
TEST_IMAGE_URL="${TEST_IMAGE_URL:-https://raw.githubusercontent.com/mudler/rf-detr.cpp/main/tests/fixtures/ci/test_image.jpg}"
|
||||
|
||||
mkdir -p "$TEST_IMAGE_DIR"
|
||||
if [ ! -f "$TEST_IMAGE_FILE" ]; then
|
||||
echo "Downloading COCO test image..."
|
||||
curl -L -o "$TEST_IMAGE_FILE" "$TEST_IMAGE_URL" --progress-bar
|
||||
fi
|
||||
|
||||
echo "rfdetr-cpp test setup complete."
|
||||
echo " detection model: $RFDETR_MODEL_DIR/$RFDETR_DET_FILE"
|
||||
echo " segmentation model: $RFDETR_MODEL_DIR/$RFDETR_SEG_FILE"
|
||||
echo " test image: $TEST_IMAGE_FILE"
|
||||
|
||||
# Run the Go smoke test: spawns the backend binary on a free port, calls
|
||||
# LoadModel + Detect via gRPC for both detection and segmentation models.
|
||||
echo ""
|
||||
echo "Running Go smoke test..."
|
||||
cd "$CURDIR"
|
||||
go test -v -timeout 5m ./...
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=bd17f53b7386fb5f60e8587b75e73c4b2fed3426
|
||||
STABLEDIFFUSION_GGML_VERSION?=1f9ee88e09c258053fa59d5e05e23dfb10fa0b13
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include <stdlib.h>
|
||||
#include <regex>
|
||||
#include <errno.h>
|
||||
#include <inttypes.h>
|
||||
#include <signal.h>
|
||||
#include <unistd.h>
|
||||
#include <sys/wait.h>
|
||||
@@ -376,6 +377,8 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
const char *clip_g_path = "";
|
||||
const char *t5xxl_path = "";
|
||||
const char *vae_path = "";
|
||||
const char *audio_vae_path = "";
|
||||
const char *embeddings_connectors_path = "";
|
||||
const char *scheduler_str = "";
|
||||
const char *sampler = "";
|
||||
const char *clip_vision_path = "";
|
||||
@@ -431,6 +434,12 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
if (!strcmp(optname, "vae_path")) {
|
||||
vae_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "audio_vae_path")) {
|
||||
audio_vae_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "embeddings_connectors_path")) {
|
||||
embeddings_connectors_path = strdup(optval);
|
||||
}
|
||||
if (!strcmp(optname, "scheduler")) {
|
||||
scheduler_str = optval;
|
||||
}
|
||||
@@ -563,6 +572,8 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
ctx_params.diffusion_model_path = diffusion_model_path;
|
||||
ctx_params.high_noise_diffusion_model_path = high_noise_diffusion_model_path;
|
||||
ctx_params.vae_path = vae_path;
|
||||
ctx_params.audio_vae_path = audio_vae_path;
|
||||
ctx_params.embeddings_connectors_path = embeddings_connectors_path;
|
||||
ctx_params.taesd_path = taesd_path;
|
||||
ctx_params.control_net_path = control_net_path;
|
||||
if (lora_dir && strlen(lora_dir) > 0) {
|
||||
@@ -1065,9 +1076,71 @@ static uint8_t* load_and_resize_image(const char* path, int target_width, int ta
|
||||
return buf;
|
||||
}
|
||||
|
||||
// Write sd.cpp's audio buffer to a temp WAV file (IEEE float, interleaved).
|
||||
// sd_audio_t.data is planar (all channel 0 samples, then channel 1, etc.) — we
|
||||
// interleave on the fly so ffmpeg's standard wav demuxer can read it directly.
|
||||
// Returns 0 on success and fills wav_path (must be at least 64 bytes).
|
||||
static int write_planar_float_wav(const sd_audio_t* a, char* wav_path, size_t wav_path_sz) {
|
||||
if (!a || !a->data || a->sample_count == 0 || a->channels == 0 || a->sample_rate == 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
snprintf(wav_path, wav_path_sz, "/tmp/gosd-audio-XXXXXX.wav");
|
||||
int fd = mkstemps(wav_path, 4);
|
||||
if (fd < 0) { perror("mkstemps wav"); return -1; }
|
||||
FILE* f = fdopen(fd, "wb");
|
||||
if (!f) { perror("fdopen wav"); close(fd); return -1; }
|
||||
|
||||
uint64_t frames = a->sample_count;
|
||||
uint32_t channels = a->channels;
|
||||
uint32_t sample_rate = a->sample_rate;
|
||||
uint64_t total_samples64 = frames * (uint64_t)channels;
|
||||
uint64_t data_bytes64 = total_samples64 * sizeof(float);
|
||||
if (data_bytes64 > 0xFFFFFFFFull - 44) {
|
||||
fprintf(stderr, "audio too large for 32-bit WAV (%" PRIu64 " bytes)\n", data_bytes64);
|
||||
fclose(f);
|
||||
unlink(wav_path);
|
||||
return -1;
|
||||
}
|
||||
uint32_t data_bytes = (uint32_t)data_bytes64;
|
||||
uint32_t riff_size = 36 + data_bytes;
|
||||
uint16_t fmt_code = 3; // WAVE_FORMAT_IEEE_FLOAT
|
||||
uint16_t bits_per_sample = 32;
|
||||
uint16_t block_align = (uint16_t)(channels * sizeof(float));
|
||||
uint32_t byte_rate = sample_rate * block_align;
|
||||
uint16_t ch16 = (uint16_t)channels;
|
||||
uint32_t fmt_size = 16;
|
||||
|
||||
fwrite("RIFF", 1, 4, f);
|
||||
fwrite(&riff_size, 4, 1, f);
|
||||
fwrite("WAVEfmt ", 1, 8, f);
|
||||
fwrite(&fmt_size, 4, 1, f);
|
||||
fwrite(&fmt_code, 2, 1, f);
|
||||
fwrite(&ch16, 2, 1, f);
|
||||
fwrite(&sample_rate, 4, 1, f);
|
||||
fwrite(&byte_rate, 4, 1, f);
|
||||
fwrite(&block_align, 2, 1, f);
|
||||
fwrite(&bits_per_sample, 2, 1, f);
|
||||
fwrite("data", 1, 4, f);
|
||||
fwrite(&data_bytes, 4, 1, f);
|
||||
|
||||
// Interleave planar [ch0_samples..., ch1_samples...] → [ch0_s0, ch1_s0, ...]
|
||||
for (uint64_t s = 0; s < frames; s++) {
|
||||
for (uint32_t c = 0; c < channels; c++) {
|
||||
float v = a->data[(size_t)c * frames + s];
|
||||
fwrite(&v, sizeof(float), 1, f);
|
||||
}
|
||||
}
|
||||
fclose(f);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Pipe raw RGB/RGBA frames to ffmpeg stdin and let it produce an MP4 at dst.
|
||||
// Uses fork+execvp to avoid shell interpretation of dst.
|
||||
static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps, const char* dst) {
|
||||
// Uses fork+execvp to avoid shell interpretation of dst. When `audio` is
|
||||
// non-null, the audio waveform is staged to a temp WAV and added as a second
|
||||
// ffmpeg input so the final MP4 contains both video and AAC audio.
|
||||
static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps,
|
||||
const sd_audio_t* audio, const char* dst) {
|
||||
if (num_frames <= 0 || !frames || !frames[0].data) {
|
||||
fprintf(stderr, "ffmpeg_mux: empty frames\n");
|
||||
return 1;
|
||||
@@ -1082,38 +1155,87 @@ static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps, co
|
||||
snprintf(size_str, sizeof(size_str), "%dx%d", width, height);
|
||||
snprintf(fps_str, sizeof(fps_str), "%d", fps);
|
||||
|
||||
// Optional audio: write a temp WAV file if the model produced audio.
|
||||
char wav_path[64] = {0};
|
||||
bool have_audio = false;
|
||||
if (audio && audio->data && audio->sample_count > 0 && audio->channels > 0 && audio->sample_rate > 0) {
|
||||
if (write_planar_float_wav(audio, wav_path, sizeof(wav_path)) == 0) {
|
||||
have_audio = true;
|
||||
fprintf(stderr, "ffmpeg_mux: audio %u Hz × %u ch × %" PRIu64 " frames → %s\n",
|
||||
audio->sample_rate, audio->channels, audio->sample_count, wav_path);
|
||||
} else {
|
||||
fprintf(stderr, "ffmpeg_mux: failed to stage audio; producing silent video\n");
|
||||
}
|
||||
}
|
||||
|
||||
int pipefd[2];
|
||||
if (pipe(pipefd) != 0) { perror("pipe"); return 1; }
|
||||
if (pipe(pipefd) != 0) {
|
||||
perror("pipe");
|
||||
if (have_audio) unlink(wav_path);
|
||||
return 1;
|
||||
}
|
||||
|
||||
pid_t pid = fork();
|
||||
if (pid < 0) { perror("fork"); close(pipefd[0]); close(pipefd[1]); return 1; }
|
||||
if (pid < 0) {
|
||||
perror("fork");
|
||||
close(pipefd[0]); close(pipefd[1]);
|
||||
if (have_audio) unlink(wav_path);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (pid == 0) {
|
||||
// child
|
||||
close(pipefd[1]);
|
||||
if (dup2(pipefd[0], STDIN_FILENO) < 0) { perror("dup2"); _exit(127); }
|
||||
close(pipefd[0]);
|
||||
std::vector<char*> argv = {
|
||||
const_cast<char*>("ffmpeg"),
|
||||
const_cast<char*>("-y"),
|
||||
const_cast<char*>("-hide_banner"),
|
||||
const_cast<char*>("-loglevel"), const_cast<char*>("warning"),
|
||||
const_cast<char*>("-f"), const_cast<char*>("rawvideo"),
|
||||
const_cast<char*>("-pix_fmt"), const_cast<char*>(pix_fmt_in),
|
||||
const_cast<char*>("-s"), size_str,
|
||||
const_cast<char*>("-framerate"), fps_str,
|
||||
const_cast<char*>("-i"), const_cast<char*>("-"),
|
||||
const_cast<char*>("-c:v"), const_cast<char*>("libx264"),
|
||||
const_cast<char*>("-pix_fmt"), const_cast<char*>("yuv420p"),
|
||||
const_cast<char*>("-movflags"), const_cast<char*>("+faststart"),
|
||||
// Force MP4 container. Distributed LocalAI hands us a staging
|
||||
// path (e.g. /staging/localai-output-NNN.tmp) with a non-standard
|
||||
// extension; relying on filename suffix makes ffmpeg bail with
|
||||
// "Unable to choose an output format".
|
||||
const_cast<char*>("-f"), const_cast<char*>("mp4"),
|
||||
const_cast<char*>(dst),
|
||||
nullptr
|
||||
};
|
||||
std::vector<char*> argv;
|
||||
argv.push_back(const_cast<char*>("ffmpeg"));
|
||||
argv.push_back(const_cast<char*>("-y"));
|
||||
argv.push_back(const_cast<char*>("-hide_banner"));
|
||||
argv.push_back(const_cast<char*>("-loglevel"));
|
||||
argv.push_back(const_cast<char*>("warning"));
|
||||
// Input 0: raw video from stdin
|
||||
argv.push_back(const_cast<char*>("-f"));
|
||||
argv.push_back(const_cast<char*>("rawvideo"));
|
||||
argv.push_back(const_cast<char*>("-pix_fmt"));
|
||||
argv.push_back(const_cast<char*>(pix_fmt_in));
|
||||
argv.push_back(const_cast<char*>("-s"));
|
||||
argv.push_back(size_str);
|
||||
argv.push_back(const_cast<char*>("-framerate"));
|
||||
argv.push_back(fps_str);
|
||||
argv.push_back(const_cast<char*>("-i"));
|
||||
argv.push_back(const_cast<char*>("-"));
|
||||
// Input 1: optional audio WAV
|
||||
if (have_audio) {
|
||||
argv.push_back(const_cast<char*>("-i"));
|
||||
argv.push_back(wav_path);
|
||||
argv.push_back(const_cast<char*>("-map"));
|
||||
argv.push_back(const_cast<char*>("0:v:0"));
|
||||
argv.push_back(const_cast<char*>("-map"));
|
||||
argv.push_back(const_cast<char*>("1:a:0"));
|
||||
argv.push_back(const_cast<char*>("-c:a"));
|
||||
argv.push_back(const_cast<char*>("aac"));
|
||||
argv.push_back(const_cast<char*>("-b:a"));
|
||||
argv.push_back(const_cast<char*>("192k"));
|
||||
// -shortest so the final clip ends with the shorter of the two
|
||||
// streams — guards against an audio buffer that overshoots the
|
||||
// video duration (or vice versa) on certain LTX variants.
|
||||
argv.push_back(const_cast<char*>("-shortest"));
|
||||
}
|
||||
argv.push_back(const_cast<char*>("-c:v"));
|
||||
argv.push_back(const_cast<char*>("libx264"));
|
||||
argv.push_back(const_cast<char*>("-pix_fmt"));
|
||||
argv.push_back(const_cast<char*>("yuv420p"));
|
||||
argv.push_back(const_cast<char*>("-movflags"));
|
||||
argv.push_back(const_cast<char*>("+faststart"));
|
||||
// Force MP4 container. Distributed LocalAI hands us a staging
|
||||
// path (e.g. /staging/localai-output-NNN.tmp) with a non-standard
|
||||
// extension; relying on filename suffix makes ffmpeg bail with
|
||||
// "Unable to choose an output format".
|
||||
argv.push_back(const_cast<char*>("-f"));
|
||||
argv.push_back(const_cast<char*>("mp4"));
|
||||
argv.push_back(const_cast<char*>(dst));
|
||||
argv.push_back(nullptr);
|
||||
execvp(argv[0], argv.data());
|
||||
perror("execvp ffmpeg");
|
||||
_exit(127);
|
||||
@@ -1138,6 +1260,7 @@ static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps, co
|
||||
close(pipefd[1]);
|
||||
int status;
|
||||
waitpid(pid, &status, 0);
|
||||
if (have_audio) unlink(wav_path);
|
||||
return 1;
|
||||
}
|
||||
p += n;
|
||||
@@ -1148,8 +1271,13 @@ static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps, co
|
||||
|
||||
int status = 0;
|
||||
while (waitpid(pid, &status, 0) < 0) {
|
||||
if (errno != EINTR) { perror("waitpid"); return 1; }
|
||||
if (errno != EINTR) {
|
||||
perror("waitpid");
|
||||
if (have_audio) unlink(wav_path);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
if (have_audio) unlink(wav_path);
|
||||
if (!WIFEXITED(status) || WEXITSTATUS(status) != 0) {
|
||||
fprintf(stderr, "ffmpeg exited with status %d\n", status);
|
||||
return 1;
|
||||
@@ -1188,6 +1316,9 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
|
||||
p->high_noise_sample_params.scheduler = scheduler;
|
||||
p->high_noise_sample_params.flow_shift = flow_shift;
|
||||
|
||||
// Pin output fps in params; upstream uses it for audio sync (and we also mux at this rate).
|
||||
p->fps = fps;
|
||||
|
||||
// Load init/end reference images if provided (resized to output dims).
|
||||
uint8_t* init_buf = nullptr;
|
||||
uint8_t* end_buf = nullptr;
|
||||
@@ -1206,11 +1337,14 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
|
||||
|
||||
// Generate
|
||||
int num_frames_out = 0;
|
||||
sd_image_t* frames = generate_video(sd_c, p, &num_frames_out);
|
||||
sd_image_t* frames = nullptr;
|
||||
sd_audio_t* audio = nullptr;
|
||||
bool ok = generate_video(sd_c, p, &frames, &num_frames_out, &audio);
|
||||
std::free(p);
|
||||
|
||||
if (!frames || num_frames_out == 0) {
|
||||
if (!ok || !frames || num_frames_out == 0) {
|
||||
fprintf(stderr, "generate_video produced no frames\n");
|
||||
if (audio) free_sd_audio(audio);
|
||||
if (init_buf) free(init_buf);
|
||||
if (end_buf) free(end_buf);
|
||||
return 1;
|
||||
@@ -1218,12 +1352,13 @@ int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int
|
||||
|
||||
fprintf(stderr, "Generated %d frames, muxing to %s via ffmpeg\n", num_frames_out, dst);
|
||||
|
||||
int rc = ffmpeg_mux_raw_to_mp4(frames, num_frames_out, fps, dst);
|
||||
int rc = ffmpeg_mux_raw_to_mp4(frames, num_frames_out, fps, audio, dst);
|
||||
|
||||
for (int i = 0; i < num_frames_out; i++) {
|
||||
if (frames[i].data) free(frames[i].data);
|
||||
}
|
||||
free(frames);
|
||||
if (audio) free_sd_audio(audio);
|
||||
if (init_buf) free(init_buf);
|
||||
if (end_buf) free(end_buf);
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=968eebe77225d25e57a3f981da7c696310f0e881
|
||||
WHISPER_CPP_VERSION?=99613cb720b65036237d44b52f753b51f75c2797
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -122,6 +122,62 @@
|
||||
nvidia-cuda-12: "cuda12-whisper"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-whisper"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-whisper"
|
||||
- &crispasr
|
||||
name: "crispasr"
|
||||
alias: "crispasr"
|
||||
license: mit
|
||||
icon: https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg
|
||||
description: |
|
||||
CrispASR unified speech engine (whisper.cpp fork on ggml) supporting many ASR architectures (Parakeet, Canary, Voxtral, Qwen3-ASR, Granite, Wav2Vec2, Moonshine, OmniASR, FireRedASR, and more).
|
||||
urls:
|
||||
- https://github.com/CrispStrobe/CrispASR
|
||||
tags:
|
||||
- audio-transcription
|
||||
- CPU
|
||||
- GPU
|
||||
- CUDA
|
||||
- HIP
|
||||
capabilities:
|
||||
default: "cpu-crispasr"
|
||||
nvidia: "cuda12-crispasr"
|
||||
intel: "intel-sycl-f16-crispasr"
|
||||
metal: "metal-crispasr"
|
||||
amd: "rocm-crispasr"
|
||||
vulkan: "vulkan-crispasr"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-crispasr"
|
||||
nvidia-cuda-13: "cuda13-crispasr"
|
||||
nvidia-cuda-12: "cuda12-crispasr"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-crispasr"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-crispasr"
|
||||
- ¶keetcpp
|
||||
name: "parakeet-cpp"
|
||||
alias: "parakeet-cpp"
|
||||
license: mit
|
||||
icon: https://avatars.githubusercontent.com/u/95302084
|
||||
description: |
|
||||
parakeet.cpp is a C++/ggml port of NVIDIA NeMo Parakeet automatic speech recognition (ASR) models.
|
||||
It supports the tdt, ctc, rnnt and hybrid decoder families as well as cache-aware streaming transcription,
|
||||
and runs on CPU, NVIDIA CUDA, AMD ROCm/HIP, Intel SYCL and NVIDIA Jetson (L4T) targets.
|
||||
urls:
|
||||
- https://github.com/mudler/parakeet.cpp
|
||||
tags:
|
||||
- audio-transcription
|
||||
- CPU
|
||||
- GPU
|
||||
- CUDA
|
||||
- HIP
|
||||
capabilities:
|
||||
default: "cpu-parakeet-cpp"
|
||||
nvidia: "cuda12-parakeet-cpp"
|
||||
intel: "intel-sycl-f16-parakeet-cpp"
|
||||
metal: "metal-parakeet-cpp"
|
||||
amd: "rocm-parakeet-cpp"
|
||||
vulkan: "vulkan-parakeet-cpp"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-parakeet-cpp"
|
||||
nvidia-cuda-13: "cuda13-parakeet-cpp"
|
||||
nvidia-cuda-12: "cuda12-parakeet-cpp"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-parakeet-cpp"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-parakeet-cpp"
|
||||
- &voxtral
|
||||
name: "voxtral"
|
||||
alias: "voxtral"
|
||||
@@ -253,6 +309,34 @@
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-sam3-cpp"
|
||||
intel: "intel-sycl-f32-sam3-cpp"
|
||||
vulkan: "vulkan-sam3-cpp"
|
||||
- &rfdetrcpp
|
||||
name: "rfdetr-cpp"
|
||||
alias: "rfdetr-cpp"
|
||||
license: apache-2.0
|
||||
description: |
|
||||
Native RF-DETR object detection and instance segmentation in C/C++
|
||||
using GGML. Loads pre-built GGUF weights from the mudler/rfdetr-cpp-*
|
||||
family (Nano/Small/Base/Medium/Large + SegNano/SegSmall/SegMedium)
|
||||
and returns bounding boxes, class labels, confidence scores, and
|
||||
(for segmentation variants) PNG-encoded per-detection masks.
|
||||
urls:
|
||||
- https://github.com/mudler/rf-detr.cpp
|
||||
tags:
|
||||
- object-detection
|
||||
- image-segmentation
|
||||
- rfdetr
|
||||
- gpu
|
||||
- cpu
|
||||
capabilities:
|
||||
default: "cpu-rfdetr-cpp"
|
||||
nvidia: "cuda12-rfdetr-cpp"
|
||||
nvidia-cuda-12: "cuda12-rfdetr-cpp"
|
||||
nvidia-cuda-13: "cuda13-rfdetr-cpp"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-rfdetr-cpp"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-rfdetr-cpp"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-rfdetr-cpp"
|
||||
intel: "intel-sycl-f32-rfdetr-cpp"
|
||||
vulkan: "vulkan-rfdetr-cpp"
|
||||
- &vllm
|
||||
name: "vllm"
|
||||
license: apache-2.0
|
||||
@@ -1900,6 +1984,246 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-whisper
|
||||
## crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "crispasr-development"
|
||||
capabilities:
|
||||
default: "cpu-crispasr-development"
|
||||
nvidia: "cuda12-crispasr-development"
|
||||
intel: "intel-sycl-f16-crispasr-development"
|
||||
metal: "metal-crispasr-development"
|
||||
amd: "rocm-crispasr-development"
|
||||
vulkan: "vulkan-crispasr-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-crispasr-development"
|
||||
nvidia-cuda-13: "cuda13-crispasr-development"
|
||||
nvidia-cuda-12: "cuda12-crispasr-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-crispasr-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-crispasr-development"
|
||||
- !!merge <<: *crispasr
|
||||
name: "nvidia-l4t-arm64-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "nvidia-l4t-arm64-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-nvidia-l4t-arm64-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-nvidia-l4t-arm64-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cpu-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cpu-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda12-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "rocm-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f32-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f16-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f16-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "vulkan-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "vulkan-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "metal-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda12-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "rocm-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f32-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "intel-sycl-f16-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f16-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-crispasr"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-crispasr
|
||||
- !!merge <<: *crispasr
|
||||
name: "cuda13-crispasr-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-crispasr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-crispasr
|
||||
## parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "parakeet-cpp-development"
|
||||
capabilities:
|
||||
default: "cpu-parakeet-cpp-development"
|
||||
nvidia: "cuda12-parakeet-cpp-development"
|
||||
intel: "intel-sycl-f16-parakeet-cpp-development"
|
||||
metal: "metal-parakeet-cpp-development"
|
||||
amd: "rocm-parakeet-cpp-development"
|
||||
vulkan: "vulkan-parakeet-cpp-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
nvidia-cuda-13: "cuda13-parakeet-cpp-development"
|
||||
nvidia-cuda-12: "cuda12-parakeet-cpp-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "nvidia-l4t-arm64-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda13-nvidia-l4t-arm64-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda13-nvidia-l4t-arm64-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cpu-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cpu-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "metal-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "metal-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda12-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda12-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "rocm-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "rocm-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "intel-sycl-f32-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "intel-sycl-f32-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "intel-sycl-f16-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f16-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "intel-sycl-f16-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f16-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "vulkan-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "vulkan-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda13-parakeet-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-parakeet-cpp
|
||||
- !!merge <<: *parakeetcpp
|
||||
name: "cuda13-parakeet-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-parakeet-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-parakeet-cpp
|
||||
## stablediffusion-ggml
|
||||
- !!merge <<: *stablediffusionggml
|
||||
name: "cpu-stablediffusion-ggml"
|
||||
@@ -2349,6 +2673,99 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-sam3-cpp
|
||||
## rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "rfdetr-cpp-development"
|
||||
capabilities:
|
||||
default: "cpu-rfdetr-cpp-development"
|
||||
nvidia: "cuda12-rfdetr-cpp-development"
|
||||
nvidia-cuda-12: "cuda12-rfdetr-cpp-development"
|
||||
nvidia-cuda-13: "cuda13-rfdetr-cpp-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-rfdetr-cpp-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-rfdetr-cpp-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-rfdetr-cpp-development"
|
||||
intel: "intel-sycl-f32-rfdetr-cpp-development"
|
||||
vulkan: "vulkan-rfdetr-cpp-development"
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "cpu-rfdetr-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "cpu-rfdetr-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "cuda12-rfdetr-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "cuda12-rfdetr-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "cuda13-rfdetr-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "cuda13-rfdetr-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "nvidia-l4t-arm64-rfdetr-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "nvidia-l4t-arm64-rfdetr-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "cuda13-nvidia-l4t-arm64-rfdetr-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "cuda13-nvidia-l4t-arm64-rfdetr-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "intel-sycl-f32-rfdetr-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "intel-sycl-f32-rfdetr-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "intel-sycl-f16-rfdetr-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f16-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "intel-sycl-f16-rfdetr-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f16-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "vulkan-rfdetr-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-rfdetr-cpp
|
||||
- !!merge <<: *rfdetrcpp
|
||||
name: "vulkan-rfdetr-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-rfdetr-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-rfdetr-cpp
|
||||
## Rerankers
|
||||
- !!merge <<: *rerankers
|
||||
name: "rerankers-development"
|
||||
|
||||
@@ -37,6 +37,20 @@ def is_int(s):
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def coerce_param_value(value):
|
||||
"""Coerce a TTSRequest.params value (string on the wire) to the type the
|
||||
Chatterbox generate() kwargs expect (float/int/bool), matching how static
|
||||
YAML options are coerced at load time. Non-string values pass through."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
if is_float(value):
|
||||
return float(value)
|
||||
if is_int(value):
|
||||
return int(value)
|
||||
if value.lower() in ["true", "false"]:
|
||||
return value.lower() == "true"
|
||||
return value
|
||||
|
||||
def split_text_at_word_boundary(text, max_length=250):
|
||||
"""
|
||||
Split text at word boundaries without truncating words.
|
||||
@@ -191,6 +205,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# add options to kwargs
|
||||
kwargs.update(self.options)
|
||||
|
||||
# Merge per-request params (TTSRequest.params), overriding the static
|
||||
# YAML options. This exposes Chatterbox generation knobs (e.g.
|
||||
# exaggeration, cfg_weight, temperature) per request. Values arrive as
|
||||
# strings on the wire and are coerced to float/int/bool.
|
||||
if hasattr(request, "params") and request.params:
|
||||
for key, value in request.params.items():
|
||||
kwargs[key] = coerce_param_value(value)
|
||||
|
||||
# Check if text exceeds 250 characters
|
||||
# (chatterbox does not support long text)
|
||||
# https://github.com/resemble-ai/chatterbox/issues/60
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
transformers==4.48.3
|
||||
transformers==5.0.0rc3
|
||||
accelerate
|
||||
torch==2.4.1
|
||||
torch==2.7.1+cpu
|
||||
torchaudio==2.4.1
|
||||
coqui-tts
|
||||
@@ -1,5 +1,5 @@
|
||||
torch==2.4.1
|
||||
torch==2.7.1+cpu
|
||||
torchaudio==2.4.1
|
||||
transformers==4.48.3
|
||||
transformers==5.0.0rc3
|
||||
accelerate
|
||||
coqui-tts
|
||||
@@ -1,6 +1,6 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm7.0
|
||||
torch==2.10.0+rocm7.0
|
||||
torch==2.7.1+cpu
|
||||
torchaudio==2.10.0+rocm7.0
|
||||
transformers==4.48.3
|
||||
transformers==5.0.0rc3
|
||||
accelerate
|
||||
coqui-tts
|
||||
@@ -1,8 +1,8 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch==2.8.0+xpu
|
||||
torch==2.7.1+cpu
|
||||
torchaudio==2.8.0+xpu
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
transformers==4.48.3
|
||||
transformers==5.0.0rc3
|
||||
accelerate
|
||||
coqui-tts
|
||||
@@ -1,4 +1,4 @@
|
||||
torch==2.7.1
|
||||
transformers==4.48.3
|
||||
torch==2.7.1+cpu
|
||||
transformers==5.0.0rc3
|
||||
accelerate
|
||||
coqui-tts
|
||||
|
||||
@@ -99,8 +99,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if not results or len(results) == 0:
|
||||
return backend_pb2.TranscriptResult(segments=[], text="")
|
||||
|
||||
# Get the transcript text from the first result
|
||||
text = results[0]
|
||||
# Get the transcript text from the first result.
|
||||
# CTC models return List[str], TDT/RNNT models return List[Hypothesis]
|
||||
# where the actual text lives in Hypothesis.text.
|
||||
result = results[0]
|
||||
if isinstance(result, str):
|
||||
text = result
|
||||
else:
|
||||
text = getattr(result, 'text', None) or ""
|
||||
|
||||
if text:
|
||||
# Create a single segment with the full transcription
|
||||
result_segments.append(backend_pb2.TranscriptSegment(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user