Compare commits

...

129 Commits

Author SHA1 Message Date
LocalAI [bot]
e9f10f2f50 chore(model gallery): 🤖 add 1 new models via gallery agent (#9202)
chore(model gallery): 🤖 add new models via gallery agent

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-04-02 21:22:19 +02:00
Ettore Di Giacinto
b95b0b72ff chore(ci): fix gallery agent
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-04-02 18:02:18 +00:00
LocalAI [bot]
26f1b94f4d chore: ⬆️ Update ggml-org/llama.cpp to 95a6ebabb277c4cc18247e7bc2a5502133caca63 (#9199)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-04-02 08:53:16 +02:00
LocalAI [bot]
2d40725ca2 chore: ⬆️ Update leejet/stable-diffusion.cpp to 87ecb95cbc65dc8e58e3d88f4f4a59a0939796f5 (#9200)
⬆️ Update leejet/stable-diffusion.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-04-02 08:53:04 +02:00
Ettore Di Giacinto
6c635e8353 feat: add resume endpoint to undrain nodes (#9197)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-04-01 18:21:43 +02:00
LocalAI [bot]
cc5f33ce95 chore: ⬆️ Update ggml-org/llama.cpp to 0fcb3760b2b9a3a496ef14621a7e4dad7a8df90f (#9196)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-04-01 00:48:40 +02:00
LocalAI [bot]
ba7cdd532a chore: ⬆️ Update leejet/stable-diffusion.cpp to 09b12d5f6d51d862749e8e0ee8baac8f012089e2 (#9195)
⬆️ Update leejet/stable-diffusion.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-04-01 00:48:25 +02:00
Ettore Di Giacinto
6b6c136210 fix(inflight): count inflight from load model, but release afterwards (#9194)
This should fix the count of 1 in flight always showing in the node list

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-31 23:24:45 +02:00
Ettore Di Giacinto
e587ecc485 chore(ui): allow to unload forcefully
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-31 17:20:53 +00:00
Ettore Di Giacinto
f259036a27 feat(gpu): add jetson/tegra detection
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-31 15:45:07 +00:00
Ettore Di Giacinto
221ff0f28f feat(ui): show cluster status in home in distributed mode
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-31 15:37:58 +00:00
Ettore Di Giacinto
16d5cb00bd chore: css cleanups
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-31 16:37:38 +02:00
Richard Palethorpe
952635fba6 feat(distributed): Avoid resending models to backend nodes (#9193)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2026-03-31 16:28:13 +02:00
Ettore Di Giacinto
3cc05af2e5 chore(nodes): restore offline nodes too
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-31 14:22:18 +00:00
Copilot
87a63316c7 stablediffusion-ggml: replace hand-maintained enum string arrays with upstream API calls (#9192)
* Initial plan

* Remove hand-maintained enum string arrays in gosd.cpp, use upstream API functions

Agent-Logs-Url: https://github.com/mudler/LocalAI/sessions/561fb489-89ed-4588-8f1e-7b967d91ba37

Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-31 14:53:38 +02:00
Richard Palethorpe
efdcbbe332 feat(api): Return 404 when model is not found except for model names in HF format (#9133)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-31 10:48:21 +02:00
Ettore Di Giacinto
b4fff9293d chore: small ui improvements in the node page
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-31 08:41:40 +00:00
dependabot[bot]
8180221b7e chore(deps): bump grpcio from 1.78.1 to 1.80.0 in /backend/python/common/template (#9176)
chore(deps): bump grpcio in /backend/python/common/template

Bumps [grpcio](https://github.com/grpc/grpc) from 1.78.1 to 1.80.0.
- [Release notes](https://github.com/grpc/grpc/releases)
- [Commits](https://github.com/grpc/grpc/compare/v1.78.1...v1.80.0)

---
updated-dependencies:
- dependency-name: grpcio
  dependency-version: 1.80.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-31 10:11:04 +02:00
dependabot[bot]
52a9755e08 chore(deps): bump grpcio from 1.78.1 to 1.80.0 in /backend/python/rerankers (#9181)
chore(deps): bump grpcio in /backend/python/rerankers

Bumps [grpcio](https://github.com/grpc/grpc) from 1.78.1 to 1.80.0.
- [Release notes](https://github.com/grpc/grpc/releases)
- [Commits](https://github.com/grpc/grpc/compare/v1.78.1...v1.80.0)

---
updated-dependencies:
- dependency-name: grpcio
  dependency-version: 1.80.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-31 10:10:50 +02:00
dependabot[bot]
a2a1d919f9 chore(deps): bump grpcio from 1.78.1 to 1.80.0 in /backend/python/coqui (#9182)
Bumps [grpcio](https://github.com/grpc/grpc) from 1.78.1 to 1.80.0.
- [Release notes](https://github.com/grpc/grpc/releases)
- [Commits](https://github.com/grpc/grpc/compare/v1.78.1...v1.80.0)

---
updated-dependencies:
- dependency-name: grpcio
  dependency-version: 1.80.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-31 10:10:35 +02:00
dependabot[bot]
a3d37931ec chore(deps): bump grpcio from 1.78.1 to 1.80.0 in /backend/python/vllm (#9177)
Bumps [grpcio](https://github.com/grpc/grpc) from 1.78.1 to 1.80.0.
- [Release notes](https://github.com/grpc/grpc/releases)
- [Commits](https://github.com/grpc/grpc/compare/v1.78.1...v1.80.0)

---
updated-dependencies:
- dependency-name: grpcio
  dependency-version: 1.80.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-31 10:10:17 +02:00
dependabot[bot]
5b2e25ebb0 chore(deps): bump grpcio from 1.78.1 to 1.80.0 in /backend/python/transformers (#9180)
chore(deps): bump grpcio in /backend/python/transformers

Bumps [grpcio](https://github.com/grpc/grpc) from 1.78.1 to 1.80.0.
- [Release notes](https://github.com/grpc/grpc/releases)
- [Commits](https://github.com/grpc/grpc/compare/v1.78.1...v1.80.0)

---
updated-dependencies:
- dependency-name: grpcio
  dependency-version: 1.80.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-31 10:10:03 +02:00
LocalAI [bot]
b0b37a472f chore: ⬆️ Update ggml-org/llama.cpp to 08f21453aec846867b39878500d725a05bd32683 (#9190)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-31 09:27:08 +02:00
Ettore Di Giacinto
3db12eaa7a fix(oauth/invite): do not register user (prending approval) without correct invite (#9189)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-31 08:29:07 +02:00
Ettore Di Giacinto
8862e3ce60 feat: add node reconciler, allow to schedule to group of nodes, min/max autoscaler (#9186)
* always enable parallel requests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: add node reconciler, allow to schedule to group of nodes, min/max autoscaler

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: move tests to ginkgo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore(smart router): order by available vram

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-31 08:28:56 +02:00
LocalAI [bot]
80699a3f70 feat(swagger): update swagger (#9187)
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-30 23:48:06 +02:00
dependabot[bot]
309a59f61e chore(deps): bump actions/upload-pages-artifact from 3 to 4 (#9179)
Bumps [actions/upload-pages-artifact](https://github.com/actions/upload-pages-artifact) from 3 to 4.
- [Release notes](https://github.com/actions/upload-pages-artifact/releases)
- [Commits](https://github.com/actions/upload-pages-artifact/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/upload-pages-artifact
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 23:16:23 +02:00
dependabot[bot]
65c9380389 chore(deps): bump go.opentelemetry.io/otel/exporters/prometheus from 0.62.0 to 0.64.0 (#9178)
chore(deps): bump go.opentelemetry.io/otel/exporters/prometheus

Bumps [go.opentelemetry.io/otel/exporters/prometheus](https://github.com/open-telemetry/opentelemetry-go) from 0.62.0 to 0.64.0.
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/exporters/prometheus/v0.62.0...exporters/prometheus/v0.64.0)

---
updated-dependencies:
- dependency-name: go.opentelemetry.io/otel/exporters/prometheus
  dependency-version: 0.64.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 23:16:09 +02:00
dependabot[bot]
79963c56bf chore(deps): bump github.com/pion/webrtc/v4 from 4.2.9 to 4.2.11 (#9185)
Bumps [github.com/pion/webrtc/v4](https://github.com/pion/webrtc) from 4.2.9 to 4.2.11.
- [Release notes](https://github.com/pion/webrtc/releases)
- [Commits](https://github.com/pion/webrtc/compare/v4.2.9...v4.2.11)

---
updated-dependencies:
- dependency-name: github.com/pion/webrtc/v4
  dependency-version: 4.2.11
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 23:15:54 +02:00
dependabot[bot]
7004ce0b78 chore(deps): bump github.com/nats-io/nats.go from 1.49.0 to 1.50.0 (#9183)
Bumps [github.com/nats-io/nats.go](https://github.com/nats-io/nats.go) from 1.49.0 to 1.50.0.
- [Release notes](https://github.com/nats-io/nats.go/releases)
- [Commits](https://github.com/nats-io/nats.go/compare/v1.49.0...v1.50.0)

---
updated-dependencies:
- dependency-name: github.com/nats-io/nats.go
  dependency-version: 1.50.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 23:15:39 +02:00
dependabot[bot]
702d0e0e4d chore(deps): bump google.golang.org/grpc from 1.79.1 to 1.79.3 (#9175)
Bumps [google.golang.org/grpc](https://github.com/grpc/grpc-go) from 1.79.1 to 1.79.3.
- [Release notes](https://github.com/grpc/grpc-go/releases)
- [Commits](https://github.com/grpc/grpc-go/compare/v1.79.1...v1.79.3)

---
updated-dependencies:
- dependency-name: google.golang.org/grpc
  dependency-version: 1.79.3
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 23:15:25 +02:00
dependabot[bot]
d6de208d6c chore(deps): bump actions/configure-pages from 5 to 6 (#9174)
Bumps [actions/configure-pages](https://github.com/actions/configure-pages) from 5 to 6.
- [Release notes](https://github.com/actions/configure-pages/releases)
- [Commits](https://github.com/actions/configure-pages/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/configure-pages
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 23:15:10 +02:00
dependabot[bot]
7451145e0c chore(deps): bump actions/deploy-pages from 4 to 5 (#9172)
Bumps [actions/deploy-pages](https://github.com/actions/deploy-pages) from 4 to 5.
- [Release notes](https://github.com/actions/deploy-pages/releases)
- [Commits](https://github.com/actions/deploy-pages/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/deploy-pages
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 23:14:57 +02:00
dependabot[bot]
cfda3dd0df chore(deps): bump actions/checkout from 4 to 6 (#9173)
Bumps [actions/checkout](https://github.com/actions/checkout) from 4 to 6.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v6)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 23:14:43 +02:00
Richard Palethorpe
e0eb2fd734 chore(ci): Scope tests extras backend tests (#9170)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-30 17:46:07 +00:00
Ettore Di Giacinto
dd3376e0a9 chore(workers): improve logging, set header timeouts (#9171)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-30 17:26:55 +02:00
Richard Palethorpe
520e1ce3cd fix(kokoro): Download phonemization model during installation (#9165)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-30 15:08:48 +02:00
LocalAI [bot]
3d738164b7 chore: ⬆️ Update ggml-org/llama.cpp to 7c203670f8d746382247ed369fea7fbf10df8ae0 (#9160)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-30 08:27:26 +02:00
LocalAI [bot]
56db76599a chore: ⬆️ Update ggml-org/whisper.cpp to 95ea8f9bfb03a15db08a8989966fd1ae3361e20d (#9168)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-30 08:27:11 +02:00
LocalAI [bot]
ad57cdfefe chore: ⬆️ Update leejet/stable-diffusion.cpp to f16a110f8776398ef23a2a6b7b57522c2471637a (#9167)
⬆️ Update leejet/stable-diffusion.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-30 08:26:45 +02:00
Richard Palethorpe
c2f7d1c18b feat(ui): Add media history to studio pages (e.g. past images) (#9151)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-30 00:49:55 +02:00
ER-EPR
afe79568d6 fix: huggingface repo change the file name so Update index.yaml is needed (#9163)
* Update index.yaml

Signed-off-by: ER-EPR <38782737+ER-EPR@users.noreply.github.com>

* Add mmproj files for Qwen3.5 models

Signed-off-by: ER-EPR <38782737+ER-EPR@users.noreply.github.com>

* Update file paths for Qwen models in index.yaml

Signed-off-by: ER-EPR <38782737+ER-EPR@users.noreply.github.com>

---------

Signed-off-by: ER-EPR <38782737+ER-EPR@users.noreply.github.com>
2026-03-30 00:48:17 +02:00
Ettore Di Giacinto
59108fbe32 feat: add distributed mode (#9124)
* feat: add distributed mode (experimental)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix data races, mutexes, transactions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix events and tool stream in agent chat

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* use ginkgo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cron): compute correctly time boundaries avoiding re-triggering

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not flood of healthy checks

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not list obvious backends as text backends

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* tests fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Drop redundant healthcheck

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-30 00:47:27 +02:00
LocalAI [bot]
4c870288d9 chore: ⬆️ Update ggml-org/llama.cpp to 59d840209a5195c2f6e2e81b5f8339a0637b59d9 (#9144)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-28 18:18:06 +01:00
Ettore Di Giacinto
8da7212763 fix(ci): checkout submodules
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-28 00:33:31 +01:00
Ettore Di Giacinto
6e76052f9d ci: set gh-pages
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-27 21:26:55 +00:00
Richard Palethorpe
cf84db36ec fix(voxcpm): Force using a recent voxcpm version to kick the dependency solver (#9150)
fix(voxcpm): Allow packages to be fetched from all indexes

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-27 15:38:51 +01:00
Richard Palethorpe
d3f629f183 feat: Merge repeated log lines in the terminal (#9141)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-26 22:16:13 +01:00
Richard Palethorpe
b1aa707a92 fix(coqui,nemo,voxcpm): Add dependencies to allow CI to progress (#9142)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-26 18:03:56 +01:00
LocalAI [bot]
731176ce3a feat(swagger): update swagger (#9136)
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-26 07:58:11 +01:00
LocalAI [bot]
b86fa63f70 chore: ⬆️ Update ggml-org/llama.cpp to a970515bdb0b1d09519106847660b0d0c84d2472 (#9137)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-26 07:56:41 +01:00
walcz-de
00fcf6936c fix: implement encoding_format=base64 for embeddings endpoint (#9135)
The OpenAI Node.js SDK v4+ sends encoding_format=base64 by default.
LocalAI previously ignored this parameter and always returned a float
JSON array, causing a silent data corruption bug in any Node.js client
(AnythingLLM Desktop, LangChain.js, LlamaIndex.TS, …):

  // What the client does when it expects base64 but receives a float array:
  Buffer.from(floatArray, 'base64')

Node.js treats a non-string first argument as a byte array — each
float32 value is truncated to a single byte — and Float32Array then
reads those bytes as floats, yielding dims/4 values.  Vector databases
(Qdrant, pgvector, …) then create collections with the wrong dimension,
causing all similarity searches to fail silently.

  e.g. granite-embedding-107m (384 dims) → 96 stored in Qdrant
       jina-embeddings-v3      (1024 dims) → 256 stored in Qdrant

Changes:
- core/schema/prediction.go: add EncodingFormat string field to
  PredictionOptions so the request parameter is parsed and available
  throughout the request pipeline
- core/schema/openai.go: add EmbeddingBase64 string field to Item;
  add MarshalJSON so the "embedding" JSON key emits either []float32
  or a base64 string depending on which field is populated — all other
  Item consumers (image, video endpoints) are unaffected
- core/http/endpoints/openai/embeddings.go: add floatsToBase64()
  which packs a float32 slice as little-endian bytes and base64-encodes
  it; add embeddingItem() helper; both InputToken and InputStrings loops
  now honour encoding_format=base64

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-25 17:38:07 +01:00
Richard Palethorpe
26384c5c70 fix(docs): Use notice instead of alert (#9134)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-25 13:55:48 +01:00
LocalAI [bot]
7209457f53 chore: ⬆️ Update ace-step/acestep.cpp to 6f35c874ee11e86d511b860019b84976f5b52d3a (#9128)
⬆️ Update ace-step/acestep.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-25 07:52:31 +01:00
LocalAI [bot]
9bc68b2721 chore: ⬆️ Update ggml-org/llama.cpp to 9f102a1407ed5d73b8c954f32edab50f8dfa3f58 (#9127)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-25 07:52:14 +01:00
Richard Palethorpe
7bdd198fd3 fix(downloader): Rewrite full https HF URI with HF_ENDPOINT (#9107)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-24 18:32:52 +01:00
dependabot[bot]
b296e3d94b chore(deps): bump github.com/mudler/skillserver from 0.0.5 to 0.0.6 (#9116)
Bumps [github.com/mudler/skillserver](https://github.com/mudler/skillserver) from 0.0.5 to 0.0.6.
- [Release notes](https://github.com/mudler/skillserver/releases)
- [Commits](https://github.com/mudler/skillserver/compare/v0.0.5...v0.0.6)

---
updated-dependencies:
- dependency-name: github.com/mudler/skillserver
  dependency-version: 0.0.6
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-24 08:51:02 +01:00
dependabot[bot]
c91855a9b2 chore(deps): bump peter-evans/create-pull-request from 7 to 8 (#9114)
Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 7 to 8.
- [Release notes](https://github.com/peter-evans/create-pull-request/releases)
- [Commits](https://github.com/peter-evans/create-pull-request/compare/v7...v8)

---
updated-dependencies:
- dependency-name: peter-evans/create-pull-request
  dependency-version: '8'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-24 08:50:50 +01:00
dependabot[bot]
e8e445cd43 chore(deps): bump actions/checkout from 4 to 6 (#9110)
Bumps [actions/checkout](https://github.com/actions/checkout) from 4 to 6.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v6)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-24 08:50:36 +01:00
dependabot[bot]
735c426072 chore(deps): bump github.com/modelcontextprotocol/go-sdk from 1.4.0 to 1.4.1 (#9118)
chore(deps): bump github.com/modelcontextprotocol/go-sdk

Bumps [github.com/modelcontextprotocol/go-sdk](https://github.com/modelcontextprotocol/go-sdk) from 1.4.0 to 1.4.1.
- [Release notes](https://github.com/modelcontextprotocol/go-sdk/releases)
- [Commits](https://github.com/modelcontextprotocol/go-sdk/compare/v1.4.0...v1.4.1)

---
updated-dependencies:
- dependency-name: github.com/modelcontextprotocol/go-sdk
  dependency-version: 1.4.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-24 00:36:23 +01:00
dependabot[bot]
0976b8a17b chore(deps): bump github.com/google/go-containerregistry from 0.21.2 to 0.21.3 (#9121)
chore(deps): bump github.com/google/go-containerregistry

Bumps [github.com/google/go-containerregistry](https://github.com/google/go-containerregistry) from 0.21.2 to 0.21.3.
- [Release notes](https://github.com/google/go-containerregistry/releases)
- [Commits](https://github.com/google/go-containerregistry/compare/v0.21.2...v0.21.3)

---
updated-dependencies:
- dependency-name: github.com/google/go-containerregistry
  dependency-version: 0.21.3
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-24 00:35:55 +01:00
LocalAI [bot]
2ad8c149e0 chore: ⬆️ Update ggml-org/llama.cpp to 1772701f99dd3fc13f5783b282c2361eda8ca47c (#9123)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-24 00:35:40 +01:00
LocalAI [bot]
31fcb1425d chore: ⬆️ Update ggml-org/llama.cpp to 49bfddeca18e62fa3d39114a23e9fcbdf8a22388 (#9102)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-23 01:11:18 +01:00
LocalAI [bot]
470d5e506f feat(swagger): update swagger (#9103)
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-22 21:28:20 +01:00
Ettore Di Giacinto
0ee49cf42e Fix formatting in LocalAI description
Updated the description formatting for LocalAI.

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2026-03-22 21:28:07 +01:00
Ettore Di Giacinto
cecd8d6aa5 chore(docs): simplify
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-22 20:24:44 +00:00
Ettore Di Giacinto
15935e9d5f fix(auth): do not allow to register in invite mode (#9101)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-22 20:44:03 +01:00
Ettore Di Giacinto
5d410e5a03 fix(download): do not remove dst dir until we try all fallbacks (#9100)
This actually caused fallbacks to be compeletely no-op as we were
removing the destination dir before calling containerd.Apply

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-22 10:29:57 +01:00
Ettore Di Giacinto
5df77d7e8c Remove how-tos section link from README
Removed outdated link to community curated how-tos section.

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2026-03-22 10:05:13 +01:00
Ettore Di Giacinto
f891d60d26 fix(llama.cpp): bundle libdl, librt, libpthread in llama-cpp backend (#9099)
chore(llama.cpp): bundle libdl, librt, libpthread in llama-cpp backend

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-22 00:58:14 +01:00
Ettore Di Giacinto
be25217955 chore(transformers): bump to >5.0 and generically load models (#9097)
* chore(transformers): bump to >5.0

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: refactor to use generic model loading

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-22 00:57:54 +01:00
LocalAI [bot]
b74111feed chore: ⬆️ Update ggml-org/llama.cpp to 990e4d96980d0b016a2b07049cc9031642fb9903 (#9095)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-22 00:57:39 +01:00
LocalAI [bot]
bf92117259 chore: ⬆️ Update ggml-org/whisper.cpp to 76684141a5d059be71cbe23dc2f0ed552213ba2d (#9094)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-22 00:57:28 +01:00
Ettore Di Giacinto
031a36c995 feat: inferencing default, automatic tool parsing fallback and wire min_p (#9092)
* feat: wire min_p

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: inferencing defaults

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore(refactor): re-use iterative parser

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: generate automatically inference defaults from unsloth

Instead of trying to re-invent the wheel and maintain here the inference
defaults, prefer to consume unsloth ones, and contribute there as
necessary.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: apply defaults also to models installed via gallery

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: be consistent and apply fallback to all endpoint

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-22 00:57:15 +01:00
LocalAI [bot]
8036d22ec6 chore: ⬆️ Update ace-step/acestep.cpp to 7326a7bea0c2037982ec924f7364e998df70450c (#9086)
⬆️ Update ace-step/acestep.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-22 00:56:52 +01:00
Ettore Di Giacinto
f7e8d9e791 feat(quantization): add quantization backend (#9096)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-22 00:56:34 +01:00
Ettore Di Giacinto
4b183b7bb6 feat: add quota system (#9090)
* feat: add quota system

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Fix tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-21 10:09:49 +01:00
Ettore Di Giacinto
f38e91d80b feat(ui): add predictor for usage, user-breakdown statistics (#9091)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-21 10:09:36 +01:00
LocalAI [bot]
aa3e82976e chore: ⬆️ Update ggml-org/llama.cpp to 4cb7e0bd61e7e1101e8ab10db5dee70c5717a386 (#9087)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-21 09:41:11 +01:00
Ettore Di Giacinto
d9c1db2b87 feat: add (experimental) fine-tuning support with TRL (#9088)
* feat: add fine-tuning endpoint

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(experimental): add fine-tuning endpoint and TRL support

This changeset defines new GRPC signatues for Fine tuning backends, and
add TRL backend as initial fine-tuning engine. This implementation also
supports exporting to GGUF and automatically importing it to LocalAI
after fine-tuning.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* commit TRL backend, stop by killing process

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* move fine-tune to generic features

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* add evals, reorder menu

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Fix tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-21 02:08:02 +01:00
LocalAI [bot]
f7e3aab4fc feat(swagger): update swagger (#9085)
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-20 21:45:03 +01:00
Richard Palethorpe
73bdc3b50d fix(realtime): Set the alias for opus so the development backend can be selected (#9083)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-20 15:08:07 +01:00
Richard Palethorpe
cb63bdb9e4 feat(ui): Add model pipeline editor (#9070)
This creates a new model config page. Presently just allows configuring
pipelines, but can be extending the future to other types of models.
However pipelines are quite easy to create a form for and require
editing to create.

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-20 15:07:34 +01:00
Richard Palethorpe
8cd3f9fc47 feat(ui, openai): Structured errors and link to traces in error toast (#9068)
First when sending errors over SSE we now clearly identify them as such
instead of just sending the error string as a chat completion message.

We use this in the UI to identify errors and link to them to the traces.

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-20 15:06:07 +01:00
lif
e0ab1a8b43 fix: use exact tag matching for model gallery tag filtering (#9041)
The Search() method uses strings.Contains() on comma-joined tags,
causing substring false positives (e.g., "asr" matching "image-diffusers").

Add FilterByTag() method that checks each tag with strings.EqualFold()
for exact, case-insensitive matching. Add 'tag' query parameter to
/api/models and /api/backends endpoints. Update the React frontend to
send filter selections as 'tag' instead of 'term'.

Closes #8775

Signed-off-by: majiayu000 <1835304752@qq.com>
2026-03-20 08:37:45 +01:00
Ettore Di Giacinto
c3174f9543 chore(deps): bump llama-cpp to 'a0bbcdd9b6b83eeeda6f1216088f42c33d464e38' (#9079)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-20 08:12:21 +01:00
LocalAI [bot]
2b12875302 fix: Add tracing settings loading from runtime_settings.json (#9081)
Tracing settings (EnableTracing and TracingMaxItems) were not being
loaded from runtime_settings.json on startup, causing tracing settings
configured via WebUI to be lost after service restart.

This fix adds proper loading of tracing settings in
loadRuntimeSettingsFromFile function in core/application/startup.go.

Fixes #9072

Co-authored-by: localai-bot <localai-bot@localai.io>
2026-03-20 00:58:52 +01:00
Ettore Di Giacinto
9cdbd89c1f chore(agents.md): update with auth/feature gating instructions
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-19 22:52:28 +00:00
LocalAI [bot]
7d81bf0aa3 chore: ⬆️ Update ggml-org/whisper.cpp to 9386f239401074690479731c1e41683fbbeac557 (#9077)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-19 23:27:35 +01:00
Tv
a6d0e29eba fix(openresponses): do not omit required field ORItemParam.Arguments (#9074)
See #9047
2026-03-19 22:04:45 +01:00
LocalAI [bot]
6054d2a91b feat(swagger): update swagger (#9075)
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-19 21:49:40 +01:00
Ettore Di Giacinto
aea21951a2 feat: add users and authentication support (#9061)
* feat(ui): add users and authentication support

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: allow the admin user to impersonificate users

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: ui improvements, disable 'Users' button in navbar when no auth is configured

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: add OIDC support

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix: gate models

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: cache requests to optimize speed

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* small UI enhancements

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore(ui): style improvements

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix: cover other paths by auth

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: separate local auth, refactor

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* security hardening, approval mode

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix: fix tests and expectations

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: update localagi/localrecall

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-19 21:40:51 +01:00
LocalAI [bot]
bbe9067227 docs: Add troubleshooting guide for embedding models (fixes #9064) (#9065)
docs: Add troubleshooting guide for embedding models (#9064)

- Add section on using gallery models for embeddings
- Document common issues with embedding model configuration
- Add troubleshooting guide for Qwen3 embedding models
- Include correct configuration examples for Qwen3-Embedding-4B
- Document context size limits and dimension parameters
- Add table of Qwen3 embedding model specifications

Fixes #9064

Signed-off-by: localai-bot <localai-bot@localai.io>
Co-authored-by: localai-bot <localai-bot@localai.io>
2026-03-19 19:41:12 +01:00
LocalAI [bot]
9a9da062e1 chore: ⬆️ Update ggml-org/llama.cpp to 5744d7ec430e2f875a393770195fda530560773f (#9063)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-19 07:58:30 +01:00
LocalAI [bot]
dd1a8b174f chore: ⬆️ Update ggml-org/whisper.cpp to ef3463bb29ef90d25dfabfd1e75993111c52412d (#9062)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-19 07:58:11 +01:00
Richard Palethorpe
cfb7641eea feat(ui, gallery): Show model backends and add searchable model/backend selector (#9060)
* feat(ui, gallery): Display and filter by the backend models use

Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(ui): Add searchable model backend/model selector and prevent delete models being selected

Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-18 21:14:41 +01:00
Richard Palethorpe
e832efeb9e fix(ui): Refresh model list on deletion (#9059)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-18 14:07:45 +01:00
dependabot[bot]
a42548e9d1 chore(deps): bump playwright from 1.52.0 to 1.58.2 in /core/http/react-ui in the npm_and_yarn group across 1 directory (#9055)
chore(deps): bump playwright

Bumps the npm_and_yarn group with 1 update in the /core/http/react-ui directory: [playwright](https://github.com/microsoft/playwright).


Updates `playwright` from 1.52.0 to 1.58.2
- [Release notes](https://github.com/microsoft/playwright/releases)
- [Commits](https://github.com/microsoft/playwright/compare/v1.52.0...v1.58.2)

---
updated-dependencies:
- dependency-name: playwright
  dependency-version: 1.58.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-18 14:05:59 +01:00
LocalAI [bot]
8615ce28a8 feat: Add standalone agent run mode inspired by LocalAGI (#9056)
- Add 'agent' subcommand with 'run' and 'list' sub-commands
- Support running agents by name from pool.json registry
- Support running agents from JSON config files
- Implement foreground mode with --prompt flag for single-turn interactions
- Reuse AgentPoolService for consistent agent initialization
- Add comprehensive unit tests for config loading and overrides

Fixes #8960

Signed-off-by: localai-bot <localai-bot@users.noreply.github.com>
Co-authored-by: localai-bot <localai-bot@noreply.github.com>
2026-03-18 14:04:20 +01:00
Ettore Di Giacinto
8336efec41 fix(ui): correctly display backend if specified in the model config, re-order MCP buttons (#9053)
fix(ui): correctly display backend if specified in the model config

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-18 09:58:25 +01:00
LocalAI [bot]
8560a1e571 chore: ⬆️ Update ace-step/acestep.cpp to ab020a9aefcd364423e0665da12babc6b0c7b507 (#9046)
⬆️ Update ace-step/acestep.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-18 08:54:15 +01:00
LocalAI [bot]
29c33e6a6a chore: ⬆️ Update ggml-org/whisper.cpp to dc9611662265870df22a7230b7586176a99c1955 (#9045)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-18 08:46:35 +01:00
LocalAI [bot]
a58475dbef chore: ⬆️ Update ggml-org/llama.cpp to ee4801e5a6ee7ee4063144ab44ab4e127f76fba8 (#9044)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-18 08:46:12 +01:00
Tv
8a0edd0809 Always populate ORItemParam.Summary (#9049)
* fix(openresponses): do not omit required fields summary and id

* fix(openresponses): ensure ORItemParam.Summary is never null

Normalize Summary to an empty slice at serialization chokepoints
(sendSSEEvent, bufferEvent, buildORResponse) so it always serializes
as [] instead of null.

Closes #9047
2026-03-18 08:45:46 +01:00
Richard Palethorpe
35d509d8e7 feat(ui): Per model backend logs and various fixes (#9028)
* feat(gallery): Switch to expandable box instead of pop-over and display model files

Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(ui, backends): Add individual backend logging

Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(ui): Set the context settings from the model config

Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-18 08:31:26 +01:00
Ettore Di Giacinto
eef808d921 fix: call .String() on AllowedTools
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-17 23:16:19 +00:00
Ettore Di Giacinto
9d9ea5c1a0 chore(deps): bump skillserver
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-17 21:04:52 +00:00
LocalAI [bot]
e21ad5cfaa chore: ⬆️ Update leejet/stable-diffusion.cpp to 545fac4f3fb0117a4e962b1a04cf933a7e635933 (#9036)
⬆️ Update leejet/stable-diffusion.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-17 18:07:30 +01:00
dependabot[bot]
05ab0c0aa2 chore(deps): bump github.com/google/go-containerregistry from 0.21.1 to 0.21.2 (#9033)
chore(deps): bump github.com/google/go-containerregistry

Bumps [github.com/google/go-containerregistry](https://github.com/google/go-containerregistry) from 0.21.1 to 0.21.2.
- [Release notes](https://github.com/google/go-containerregistry/releases)
- [Commits](https://github.com/google/go-containerregistry/compare/v0.21.1...v0.21.2)

---
updated-dependencies:
- dependency-name: github.com/google/go-containerregistry
  dependency-version: 0.21.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-17 11:43:25 +01:00
dependabot[bot]
e2b6233570 chore(deps): bump actions/upload-artifact from 4 to 7 (#9030)
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4 to 7.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](https://github.com/actions/upload-artifact/compare/v4...v7)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-version: '7'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-17 11:42:49 +01:00
dependabot[bot]
19f995f38f chore(deps): bump github.com/ebitengine/purego from 0.9.1 to 0.10.0 (#9034)
Bumps [github.com/ebitengine/purego](https://github.com/ebitengine/purego) from 0.9.1 to 0.10.0.
- [Release notes](https://github.com/ebitengine/purego/releases)
- [Commits](https://github.com/ebitengine/purego/compare/v0.9.1...v0.10.0)

---
updated-dependencies:
- dependency-name: github.com/ebitengine/purego
  dependency-version: 0.10.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-17 11:42:06 +01:00
dependabot[bot]
ac168bbc60 chore(deps): bump github.com/anthropics/anthropic-sdk-go from 1.26.0 to 1.27.0 (#9035)
chore(deps): bump github.com/anthropics/anthropic-sdk-go

Bumps [github.com/anthropics/anthropic-sdk-go](https://github.com/anthropics/anthropic-sdk-go) from 1.26.0 to 1.27.0.
- [Release notes](https://github.com/anthropics/anthropic-sdk-go/releases)
- [Changelog](https://github.com/anthropics/anthropic-sdk-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/anthropics/anthropic-sdk-go/compare/v1.26.0...v1.27.0)

---
updated-dependencies:
- dependency-name: github.com/anthropics/anthropic-sdk-go
  dependency-version: 1.27.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-17 11:41:36 +01:00
LocalAI [bot]
5c5e537b31 chore: ⬆️ Update ace-step/acestep.cpp to 15740f4301b3ec3020875f1fb975a6cfdb2f6767 (#9038)
⬆️ Update ace-step/acestep.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-17 10:22:53 +01:00
LocalAI [bot]
118bcee196 chore: ⬆️ Update ggml-org/llama.cpp to 9b342d0a9f2f4892daec065491583ec2be129685 (#9039)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-17 10:22:42 +01:00
LocalAI [bot]
3eabd6d1d0 chore: ⬆️ Update ggml-org/whisper.cpp to 79218f51d02ffe70575ef7fba3496dfc7adda027 (#9037)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-17 08:25:31 +01:00
Ettore Di Giacinto
ee96e5e08d chore: refactor endpoints to use same inferencing path, add automatic retrial mechanism in case of errors (#9029)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-16 21:31:02 +01:00
Richard Palethorpe
3d9ccd1ddc fix(ui): Add tracing inline settings back and create UI tests (#9027)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-16 17:51:06 +01:00
Ettore Di Giacinto
d8161bfe57 fix(api): unescape model names (#9024)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-16 00:56:37 +01:00
Ettore Di Giacinto
5fd42399d4 feat: support streaming mode for tool calls in agent mode, fix interleaved thinking stream (#9023)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-16 00:50:19 +01:00
LocalAI [bot]
b2030255ca chore: ⬆️ Update ggml-org/llama.cpp to 88915cb55c14769738fcab7f1c6eaa6dcc9c2b0c (#9020)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-16 00:10:11 +01:00
LocalAI [bot]
9f903ec06e chore: ⬆️ Update leejet/stable-diffusion.cpp to 862a6586cb6fcec037c14f9ed902329ecec7d990 (#9019)
⬆️ Update leejet/stable-diffusion.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-16 00:09:59 +01:00
Ettore Di Giacinto
4ea461c330 fix(ui): correctly map watchdog fields (#9022)
Fixes: https://github.com/mudler/LocalAI/issues/9018

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-15 22:12:24 +01:00
Ettore Di Giacinto
042a9b8ef6 Remove Table of Contents from README
Removed the Table of Contents section from the README.

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2026-03-15 21:49:13 +01:00
Ettore Di Giacinto
65f1a4154a Revise README structure and content
Updated the README to include new sections and remove outdated content.

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2026-03-15 21:47:08 +01:00
LocalAI [bot]
c6a51289b0 fix: Automatically disable mmap for Intel SYCL backends (#9012) (#9015)
* fix: Automatically disable mmap for Intel SYCL backends

Fixes issue #9012 where Qwen3.5 models fail to load on Intel Arc GPU
with RPC EOF error.

The Intel SYCL backend has a known issue where mmap enabled causes
the backend to hang. This change automatically disables mmap when
detecting Intel or SYCL backends.

References:
- https://github.com/mudler/LocalAI/issues/9012
- Documentation mentions: SYCL hangs when mmap: true is set

* feat: Add logging for mmap auto-disable on Intel SYCL backends

As requested in PR review, add xlog.Info call to log when mmap
is automatically disabled for Intel SYCL backends. This helps
with debugging and confirms the auto-disable logic is working.

---------

Co-authored-by: localai-bot <localai-bot@users.noreply.github.com>
2026-03-15 21:06:35 +01:00
LocalAI [bot]
87525109f1 chore: ⬆️ Update ggml-org/llama.cpp to 3a6f059909ed5dab8587df5df4120315053d57a4 (#9009)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-15 09:46:45 +01:00
Ettore Di Giacinto
c596d8a5d9 fix: Change baseDir assignment to use ModelPath (#9010)
Fixes: https://github.com/mudler/LocalAI/issues/9005

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2026-03-15 09:45:58 +01:00
LocalAI [bot]
d79ad76e48 docs: ⬆️ update docs version mudler/LocalAI (#9008)
⬆️ Update docs version mudler/LocalAI

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-15 08:39:49 +01:00
Ettore Di Giacinto
dde0353432 chore(api): add path to expose collection raw files
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-14 22:31:45 +00:00
602 changed files with 309807 additions and 250970 deletions

View File

@@ -0,0 +1,259 @@
# API Endpoints and Authentication
This guide covers how to add new API endpoints and properly integrate them with the auth/permissions system.
## Architecture overview
Authentication and authorization flow through three layers:
1. **Global auth middleware** (`core/http/auth/middleware.go``auth.Middleware`) — applied to every request in `core/http/app.go`. Handles session cookies, Bearer tokens, API keys, and legacy API keys. Populates `auth_user` and `auth_role` in the Echo context.
2. **Feature middleware** (`auth.RequireFeature`) — per-feature access control applied to route groups or individual routes. Checks if the authenticated user has the specific feature enabled.
3. **Admin middleware** (`auth.RequireAdmin`) — restricts endpoints to admin users only.
When auth is disabled (no auth DB, no legacy API keys), all middleware becomes pass-through (`auth.NoopMiddleware`).
## Adding a new API endpoint
### Step 1: Create the handler
Write the endpoint handler in the appropriate package under `core/http/endpoints/`. Follow existing patterns:
```go
// core/http/endpoints/localai/my_feature.go
func MyFeatureEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
// Use auth.GetUser(c) to get the authenticated user (may be nil if auth is disabled)
user := auth.GetUser(c)
// Your logic here
return c.JSON(http.StatusOK, result)
}
}
```
### Step 2: Register routes
Add routes in the appropriate file under `core/http/routes/`. The file you use depends on the endpoint category:
| File | Category |
|------|----------|
| `routes/openai.go` | OpenAI-compatible API endpoints (`/v1/...`) |
| `routes/localai.go` | LocalAI-specific endpoints (`/api/...`, `/models/...`, `/backends/...`) |
| `routes/agents.go` | Agent pool endpoints (`/api/agents/...`) |
| `routes/auth.go` | Auth endpoints (`/api/auth/...`) |
| `routes/ui_api.go` | UI backend API endpoints |
### Step 3: Apply the right middleware
Choose the appropriate protection level:
#### No auth required (public)
Exempt paths bypass auth entirely. Add to `isExemptPath()` in `middleware.go` or use the `/api/auth/` prefix (always exempt). Use sparingly — most endpoints should require auth.
#### Standard auth (any authenticated user)
The global middleware already handles this. API paths (`/api/`, `/v1/`, etc.) automatically require authentication when auth is enabled. You don't need to add any extra middleware.
```go
router.GET("/v1/my-endpoint", myHandler) // auth enforced by global middleware
```
#### Admin only
Pass `adminMiddleware` to the route. This is set up in `app.go` and passed to `Register*Routes` functions:
```go
// In the Register function signature, accept the middleware:
func RegisterMyRoutes(router *echo.Echo, app *application.Application, adminMiddleware echo.MiddlewareFunc) {
router.POST("/models/apply", myHandler, adminMiddleware)
}
```
#### Feature-gated
For endpoints that should be toggleable per-user, use feature middleware. There are two approaches:
**Approach A: Route-level middleware** (preferred for groups of related endpoints)
```go
// In app.go, create the feature middleware:
myFeatureMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMyFeature)
// Pass it to the route registration function:
routes.RegisterMyRoutes(e, app, myFeatureMw)
// In the routes file, apply to a group:
g := e.Group("/api/my-feature", myFeatureMw)
g.GET("", listHandler)
g.POST("", createHandler)
```
**Approach B: RouteFeatureRegistry** (preferred for individual OpenAI-compatible endpoints)
Add an entry to `RouteFeatureRegistry` in `core/http/auth/features.go`. The `RequireRouteFeature` global middleware will automatically enforce it:
```go
var RouteFeatureRegistry = []RouteFeature{
// ... existing entries ...
{"POST", "/v1/my-endpoint", FeatureMyFeature},
}
```
## Adding a new feature
When you need a new toggleable feature (not just a new endpoint under an existing feature):
### 1. Define the feature constant
Add to `core/http/auth/permissions.go`:
```go
const (
// Add to the appropriate group:
// Agent features (default OFF for new users)
FeatureMyFeature = "my_feature"
// OR API features (default ON for new users)
FeatureMyFeature = "my_feature"
)
```
Then add it to the appropriate slice:
```go
// Default OFF — user must be explicitly granted access:
var AgentFeatures = []string{..., FeatureMyFeature}
// Default ON — user has access unless explicitly revoked:
var APIFeatures = []string{..., FeatureMyFeature}
```
### 2. Add feature metadata
In `core/http/auth/features.go`, add to the appropriate `FeatureMetas` function so the admin UI can display it:
```go
func AgentFeatureMetas() []FeatureMeta {
return []FeatureMeta{
// ... existing ...
{FeatureMyFeature, "My Feature", false}, // false = default OFF
}
}
```
### 3. Wire up the middleware
In `core/http/app.go`:
```go
myFeatureMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMyFeature)
```
Then pass it to the route registration function.
### 4. Register route-feature mappings (if applicable)
If your feature gates standard API endpoints (like `/v1/...`), add entries to `RouteFeatureRegistry` in `features.go` instead of using per-route middleware.
## Accessing the authenticated user in handlers
```go
import "github.com/mudler/LocalAI/core/http/auth"
func MyHandler(c echo.Context) error {
// Get the user (nil when auth is disabled or unauthenticated)
user := auth.GetUser(c)
if user == nil {
// Handle unauthenticated — or let middleware handle it
}
// Check role
if user.Role == auth.RoleAdmin {
// admin-specific logic
}
// Check feature access programmatically (when you need conditional behavior, not full blocking)
if auth.HasFeatureAccess(db, user, auth.FeatureMyFeature) {
// feature-specific logic
}
// Check model access
if !auth.IsModelAllowed(db, user, modelName) {
return c.JSON(http.StatusForbidden, ...)
}
}
```
## Middleware composition patterns
Middleware can be composed at different levels. Here are the patterns used in the codebase:
### Group-level middleware (agents pattern)
```go
// All routes in the group share the middleware
g := e.Group("/api/agents", poolReadyMw, agentsMw)
g.GET("", listHandler)
g.POST("", createHandler)
```
### Per-route middleware (localai pattern)
```go
// Individual routes get middleware as extra arguments
router.POST("/models/apply", applyHandler, adminMiddleware)
router.GET("/metrics", metricsHandler, adminMiddleware)
```
### Middleware slice (openai pattern)
```go
// Build a middleware chain for a handler
chatMiddleware := []echo.MiddlewareFunc{
usageMiddleware,
traceMiddleware,
modelFilterMiddleware,
}
app.POST("/v1/chat/completions", chatHandler, chatMiddleware...)
```
## Error response format
Always use `schema.ErrorResponse` for auth/permission errors to stay consistent with the OpenAI-compatible API:
```go
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
Error: &schema.APIError{
Message: "feature not enabled for your account",
Code: http.StatusForbidden,
Type: "authorization_error",
},
})
```
Use these HTTP status codes:
- `401 Unauthorized` — no valid credentials provided
- `403 Forbidden` — authenticated but lacking permission
- `429 Too Many Requests` — rate limited (auth endpoints)
## Usage tracking
If your endpoint should be tracked for usage (token counts, request counts), add the `usageMiddleware` to its middleware chain. See `core/http/middleware/usage.go` and how it's applied in `routes/openai.go`.
## Path protection rules
The global auth middleware classifies paths as API paths or non-API paths:
- **API paths** (always require auth when auth is enabled): `/api/`, `/v1/`, `/models/`, `/backends/`, `/backend/`, `/tts`, `/vad`, `/video`, `/stores/`, `/system`, `/ws/`, `/metrics`
- **Exempt paths** (never require auth): `/api/auth/` prefix, anything in `appConfig.PathWithoutAuth`
- **Non-API paths** (UI, static assets): pass through without auth — the React UI handles login redirects client-side
If you add endpoints under a new top-level path prefix, add it to `isAPIPath()` in `middleware.go` to ensure it requires authentication.
## Checklist
When adding a new endpoint:
- [ ] Handler in `core/http/endpoints/`
- [ ] Route registered in appropriate `core/http/routes/` file
- [ ] Auth level chosen: public / standard / admin / feature-gated
- [ ] If feature-gated: constant in `permissions.go`, metadata in `features.go`, middleware in `app.go`
- [ ] If new path prefix: added to `isAPIPath()` in `middleware.go`
- [ ] If OpenAI-compatible: entry in `RouteFeatureRegistry`
- [ ] If token-counting: `usageMiddleware` added to middleware chain
- [ ] Error responses use `schema.ErrorResponse` format
- [ ] Tests cover both authenticated and unauthenticated access

View File

@@ -49,3 +49,4 @@ The project documentation is located in `docs/content`. When adding new features
- **Feature Documentation**: If you add a new feature (like a new backend or API endpoint), create a new markdown file in `docs/content/features/` explaining what it is, how to configure it, and how to use it. - **Feature Documentation**: If you add a new feature (like a new backend or API endpoint), create a new markdown file in `docs/content/features/` explaining what it is, how to configure it, and how to use it.
- **Configuration**: If you modify configuration options, update the relevant sections in `docs/content/`. - **Configuration**: If you modify configuration options, update the relevant sections in `docs/content/`.
- **Examples**: providing concrete examples (like YAML configuration blocks) is highly encouraged to help users get started quickly. - **Examples**: providing concrete examples (like YAML configuration blocks) is highly encouraged to help users get started quickly.
- **Shortcodes**: Use `{{% notice note %}}`, `{{% notice tip %}}`, or `{{% notice warning %}}` for callout boxes. Do **not** use `{{% alert %}}` — that shortcode does not exist in this project's Hugo theme and will break the docs build.

View File

@@ -0,0 +1,141 @@
# Debugging and Rebuilding Backends
When a backend fails at runtime (e.g. a gRPC method error, a Python import error, or a dependency conflict), use this guide to diagnose, fix, and rebuild.
## Architecture Overview
- **Source directory**: `backend/python/<name>/` (or `backend/go/<name>/`, `backend/cpp/<name>/`)
- **Installed directory**: `backends/<name>/` — this is what LocalAI actually runs. It is populated by `make backends/<name>` which builds a Docker image, exports it, and installs it via `local-ai backends install`.
- **Virtual environment**: `backends/<name>/venv/` — the installed Python venv (for Python backends). The Python binary is at `backends/<name>/venv/bin/python`.
Editing files in `backend/python/<name>/` does **not** affect the running backend until you rebuild with `make backends/<name>`.
## Diagnosing Failures
### 1. Check the logs
Backend gRPC processes log to LocalAI's stdout/stderr. Look for lines tagged with the backend's model ID:
```
GRPC stderr id="trl-finetune-127.0.0.1:37335" line="..."
```
Common error patterns:
- **"Method not implemented"** — the backend is missing a gRPC method that the Go side calls. The model loader (`pkg/model/initializers.go`) always calls `LoadModel` after `Health`; fine-tuning backends must implement it even as a no-op stub.
- **Python import errors / `AttributeError`** — usually a dependency version mismatch (e.g. `pyarrow` removing `PyExtensionType`).
- **"failed to load backend"** — the gRPC process crashed or never started. Check stderr lines for the traceback.
### 2. Test the Python environment directly
You can run the installed venv's Python to check imports without starting the full server:
```bash
backends/<name>/venv/bin/python -c "import datasets; print(datasets.__version__)"
```
If `pip` is missing from the venv, bootstrap it:
```bash
backends/<name>/venv/bin/python -m ensurepip
```
Then use `backends/<name>/venv/bin/python -m pip install ...` to test fixes in the installed venv before committing them to the source requirements.
### 3. Check upstream dependency constraints
When you hit a dependency conflict, check what the main library expects. For example, TRL's upstream `requirements.txt`:
```
https://github.com/huggingface/trl/blob/main/requirements.txt
```
Pin minimum versions in the backend's requirements files to match upstream.
## Common Fixes
### Missing gRPC methods
If the Go side calls a method the backend doesn't implement (e.g. `LoadModel`), add a no-op stub in `backend.py`:
```python
def LoadModel(self, request, context):
"""No-op — actual loading happens elsewhere."""
return backend_pb2.Result(success=True, message="OK")
```
The gRPC contract requires `LoadModel` to succeed for the model loader to return a usable client, even if the backend doesn't need upfront model loading.
### Dependency version conflicts
Python backends often break when a transitive dependency releases a breaking change (e.g. `pyarrow` removing `PyExtensionType`). Steps:
1. Identify the broken import in the logs
2. Test in the installed venv: `backends/<name>/venv/bin/python -c "import <module>"`
3. Check upstream requirements for version constraints
4. Update **all** requirements files in `backend/python/<name>/`:
- `requirements.txt` — base deps (grpcio, protobuf)
- `requirements-cpu.txt` — CPU-specific (includes PyTorch CPU index)
- `requirements-cublas12.txt` — CUDA 12
- `requirements-cublas13.txt` — CUDA 13
5. Rebuild: `make backends/<name>`
### PyTorch index conflicts (uv resolver)
The Docker build uses `uv` for pip installs. When `--extra-index-url` points to the PyTorch wheel index, `uv` may refuse to fetch packages like `requests` from PyPI if it finds a different version on the PyTorch index first. Fix this by adding `--index-strategy=unsafe-first-match` to `install.sh`:
```bash
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
installRequirements
```
Most Python backends already do this — check `backend/python/transformers/install.sh` or similar for reference.
## Rebuilding
### Rebuild a single backend
```bash
make backends/<name>
```
This runs the Docker build (`Dockerfile.python`), exports the image to `backend-images/<name>.tar`, and installs it into `backends/<name>/`. It also rebuilds the `local-ai` Go binary (without extra tags).
**Important**: If you were previously running with `GO_TAGS=auth`, the `make backends/<name>` step will overwrite your binary without that tag. Rebuild the Go binary afterward:
```bash
GO_TAGS=auth make build
```
### Rebuild and restart
After rebuilding a backend, you must restart LocalAI for it to pick up the new backend files. The backend gRPC process is spawned on demand when the model is first loaded.
```bash
# Kill existing process
kill <pid>
# Restart
./local-ai run --debug [your flags]
```
### Quick iteration (skip Docker rebuild)
For fast iteration on a Python backend's `backend.py` without a full Docker rebuild, you can edit the installed copy directly:
```bash
# Edit the installed copy
vim backends/<name>/backend.py
# Restart LocalAI to respawn the gRPC process
```
This is useful for testing but **does not persist** — the next `make backends/<name>` will overwrite it. Always commit fixes to the source in `backend/python/<name>/`.
## Verification
After fixing and rebuilding:
1. Start LocalAI and confirm the backend registers: look for `Registering backend name="<name>"` in the logs
2. Trigger the operation that failed (e.g. start a fine-tuning job)
3. Watch the GRPC stderr/stdout lines for the backend's model ID
4. Confirm no errors in the traceback

View File

@@ -133,6 +133,7 @@ func getRealReadme(ctx context.Context, repository string) (string, error) {
result, err := cogito.ExecuteTools(llm, fragment, result, err := cogito.ExecuteTools(llm, fragment,
cogito.WithIterations(3), cogito.WithIterations(3),
cogito.WithMaxAttempts(3), cogito.WithMaxAttempts(3),
cogito.DisableSinkState,
cogito.WithTools(&HFReadmeTool{client: hfapi.NewClient()})) cogito.WithTools(&HFReadmeTool{client: hfapi.NewClient()}))
if err != nil { if err != nil {
return "", err return "", err
@@ -406,7 +407,7 @@ func getHuggingFaceAvatarURL(author string) string {
} }
// Parse the response to get avatar URL // Parse the response to get avatar URL
var userInfo map[string]interface{} var userInfo map[string]any
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "" return ""

View File

@@ -79,7 +79,20 @@ func generateYAMLEntry(model ProcessedModel, quantization string) string {
description = cleanTextContent(description) description = cleanTextContent(description)
formattedDescription := formatTextContent(description) formattedDescription := formatTextContent(description)
configFile := formatTextContent(modelConfig.ConfigFile) // Strip name and description from config file since they are
// already present at the gallery entry level and should not
// appear under overrides.
configFileContent := modelConfig.ConfigFile
var cfgMap map[string]any
if err := yaml.Unmarshal([]byte(configFileContent), &cfgMap); err == nil {
delete(cfgMap, "name")
delete(cfgMap, "description")
if cleaned, err := yaml.Marshal(cfgMap); err == nil {
configFileContent = string(cleaned)
}
}
configFile := formatTextContent(configFileContent)
filesYAML, _ := yaml.Marshal(modelConfig.Files) filesYAML, _ := yaml.Marshal(modelConfig.Files)

View File

@@ -3,7 +3,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"math/rand" "math/rand/v2"
"strings" "strings"
"time" "time"
) )
@@ -13,11 +13,11 @@ func runSyntheticMode() error {
generator := NewSyntheticDataGenerator() generator := NewSyntheticDataGenerator()
// Generate a random number of synthetic models (1-3) // Generate a random number of synthetic models (1-3)
numModels := generator.rand.Intn(3) + 1 numModels := generator.rand.IntN(3) + 1
fmt.Printf("Generating %d synthetic models for testing...\n", numModels) fmt.Printf("Generating %d synthetic models for testing...\n", numModels)
var models []ProcessedModel var models []ProcessedModel
for i := 0; i < numModels; i++ { for range numModels {
model := generator.GenerateProcessedModel() model := generator.GenerateProcessedModel()
models = append(models, model) models = append(models, model)
fmt.Printf("Generated synthetic model: %s\n", model.ModelID) fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
@@ -42,14 +42,14 @@ type SyntheticDataGenerator struct {
// NewSyntheticDataGenerator creates a new synthetic data generator // NewSyntheticDataGenerator creates a new synthetic data generator
func NewSyntheticDataGenerator() *SyntheticDataGenerator { func NewSyntheticDataGenerator() *SyntheticDataGenerator {
return &SyntheticDataGenerator{ return &SyntheticDataGenerator{
rand: rand.New(rand.NewSource(time.Now().UnixNano())), rand: rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0)),
} }
} }
// GenerateProcessedModelFile creates a synthetic ProcessedModelFile // GenerateProcessedModelFile creates a synthetic ProcessedModelFile
func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile { func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile {
fileTypes := []string{"model", "readme", "other"} fileTypes := []string{"model", "readme", "other"}
fileType := fileTypes[g.rand.Intn(len(fileTypes))] fileType := fileTypes[g.rand.IntN(len(fileTypes))]
var path string var path string
var isReadme bool var isReadme bool
@@ -68,7 +68,7 @@ func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile
return ProcessedModelFile{ return ProcessedModelFile{
Path: path, Path: path,
Size: int64(g.rand.Intn(1000000000) + 1000000), // 1MB to 1GB Size: int64(g.rand.IntN(1000000000) + 1000000), // 1MB to 1GB
SHA256: g.randomSHA256(), SHA256: g.randomSHA256(),
IsReadme: isReadme, IsReadme: isReadme,
FileType: fileType, FileType: fileType,
@@ -80,19 +80,19 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"} authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"}
modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"} modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"}
author := authors[g.rand.Intn(len(authors))] author := authors[g.rand.IntN(len(authors))]
modelName := modelNames[g.rand.Intn(len(modelNames))] modelName := modelNames[g.rand.IntN(len(modelNames))]
modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6)) modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6))
// Generate files // Generate files
numFiles := g.rand.Intn(5) + 2 // 2-6 files numFiles := g.rand.IntN(5) + 2 // 2-6 files
files := make([]ProcessedModelFile, numFiles) files := make([]ProcessedModelFile, numFiles)
// Ensure at least one model file and one readme // Ensure at least one model file and one readme
hasModelFile := false hasModelFile := false
hasReadme := false hasReadme := false
for i := 0; i < numFiles; i++ { for i := range numFiles {
files[i] = g.GenerateProcessedModelFile() files[i] = g.GenerateProcessedModelFile()
if files[i].FileType == "model" { if files[i].FileType == "model" {
hasModelFile = true hasModelFile = true
@@ -140,27 +140,27 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
// Generate sample metadata // Generate sample metadata
licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""} licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""}
license := licenses[g.rand.Intn(len(licenses))] license := licenses[g.rand.IntN(len(licenses))]
sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"} sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"}
numTags := g.rand.Intn(4) + 3 // 3-6 tags numTags := g.rand.IntN(4) + 3 // 3-6 tags
tags := make([]string, numTags) tags := make([]string, numTags)
for i := 0; i < numTags; i++ { for i := range numTags {
tags[i] = sampleTags[g.rand.Intn(len(sampleTags))] tags[i] = sampleTags[g.rand.IntN(len(sampleTags))]
} }
// Remove duplicates // Remove duplicates
tags = g.removeDuplicates(tags) tags = g.removeDuplicates(tags)
// Optionally include icon (50% chance) // Optionally include icon (50% chance)
icon := "" icon := ""
if g.rand.Intn(2) == 0 { if g.rand.IntN(2) == 0 {
icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24)) icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24))
} }
return ProcessedModel{ return ProcessedModel{
ModelID: modelID, ModelID: modelID,
Author: author, Author: author,
Downloads: g.rand.Intn(1000000) + 1000, Downloads: g.rand.IntN(1000000) + 1000,
LastModified: g.randomDate(), LastModified: g.randomDate(),
Files: files, Files: files,
PreferredModelFile: preferredModelFile, PreferredModelFile: preferredModelFile,
@@ -180,7 +180,7 @@ func (g *SyntheticDataGenerator) randomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789" const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, length) b := make([]byte, length)
for i := range b { for i := range b {
b[i] = charset[g.rand.Intn(len(charset))] b[i] = charset[g.rand.IntN(len(charset))]
} }
return string(b) return string(b)
} }
@@ -189,14 +189,14 @@ func (g *SyntheticDataGenerator) randomSHA256() string {
const charset = "0123456789abcdef" const charset = "0123456789abcdef"
b := make([]byte, 64) b := make([]byte, 64)
for i := range b { for i := range b {
b[i] = charset[g.rand.Intn(len(charset))] b[i] = charset[g.rand.IntN(len(charset))]
} }
return string(b) return string(b)
} }
func (g *SyntheticDataGenerator) randomDate() string { func (g *SyntheticDataGenerator) randomDate() string {
now := time.Now() now := time.Now()
daysAgo := g.rand.Intn(365) // Random date within last year daysAgo := g.rand.IntN(365) // Random date within last year
pastDate := now.AddDate(0, 0, -daysAgo) pastDate := now.AddDate(0, 0, -daysAgo)
return pastDate.Format("2006-01-02T15:04:05.000Z") return pastDate.Format("2006-01-02T15:04:05.000Z")
} }
@@ -220,5 +220,5 @@ func (g *SyntheticDataGenerator) generateReadmeContent(modelName, author string)
fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author), fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author),
} }
return templates[g.rand.Intn(len(templates))] return templates[g.rand.IntN(len(templates))]
} }

View File

@@ -118,6 +118,32 @@ jobs:
dockerfile: "./backend/Dockerfile.python" dockerfile: "./backend/Dockerfile.python"
context: "./" context: "./"
ubuntu-version: '2404' ubuntu-version: '2404'
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-cpu-trl'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'true'
backend: "trl"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64,linux/arm64'
tag-latest: 'auto'
tag-suffix: '-cpu-llama-cpp-quantization'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'true'
backend: "llama-cpp-quantization"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: '' - build-type: ''
cuda-major-version: "" cuda-major-version: ""
cuda-minor-version: "" cuda-minor-version: ""
@@ -366,6 +392,19 @@ jobs:
dockerfile: "./backend/Dockerfile.python" dockerfile: "./backend/Dockerfile.python"
context: "./" context: "./"
ubuntu-version: '2404' ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-trl'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "trl"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas' - build-type: 'cublas'
cuda-major-version: "12" cuda-major-version: "12"
cuda-minor-version: "8" cuda-minor-version: "8"
@@ -757,6 +796,19 @@ jobs:
dockerfile: "./backend/Dockerfile.python" dockerfile: "./backend/Dockerfile.python"
context: "./" context: "./"
ubuntu-version: '2404' ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-13-trl'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "trl"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'l4t' - build-type: 'l4t'
cuda-major-version: "13" cuda-major-version: "13"
cuda-minor-version: "0" cuda-minor-version: "0"
@@ -2373,6 +2425,9 @@ jobs:
tag-suffix: "-metal-darwin-arm64-local-store" tag-suffix: "-metal-darwin-arm64-local-store"
build-type: "metal" build-type: "metal"
lang: "go" lang: "go"
- backend: "llama-cpp-quantization"
tag-suffix: "-metal-darwin-arm64-llama-cpp-quantization"
build-type: "mps"
with: with:
backend: ${{ matrix.backend }} backend: ${{ matrix.backend }}
build-type: ${{ matrix.build-type }} build-type: ${{ matrix.build-type }}

View File

@@ -0,0 +1,48 @@
name: Bump inference defaults
on:
schedule:
# Run daily at 06:00 UTC
- cron: '0 6 * * *'
workflow_dispatch: # Allow manual trigger
permissions:
contents: write
pull-requests: write
jobs:
bump:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Re-fetch inference defaults
run: make generate-force
- name: Check for changes
id: diff
run: |
if git diff --quiet core/config/inference_defaults.json; then
echo "changed=false" >> "$GITHUB_OUTPUT"
else
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
- name: Create Pull Request
if: steps.diff.outputs.changed == 'true'
uses: peter-evans/create-pull-request@v8
with:
commit-message: "chore: bump inference defaults from unsloth"
title: "chore: bump inference defaults from unsloth"
body: |
Auto-generated update of `core/config/inference_defaults.json` from
[unsloth's inference_defaults.json](https://github.com/unslothai/unsloth/blob/main/studio/backend/assets/configs/inference_defaults.json).
This PR was created automatically by the `bump-inference-defaults` workflow.
branch: chore/bump-inference-defaults
delete-branch: true
labels: automated

View File

@@ -55,7 +55,7 @@ jobs:
- name: Run gallery agent - name: Run gallery agent
env: env:
#OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }} #OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }}
OPENAI_MODE: Qwen3.5-2B-GGUF OPENAI_MODEL: Qwen3.5-2B-GGUF
OPENAI_BASE_URL: "http://localhost:8080" OPENAI_BASE_URL: "http://localhost:8080"
OPENAI_KEY: ${{ secrets.OPENAI_KEY }} OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
#OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }} #OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }}

75
.github/workflows/gh-pages.yml vendored Normal file
View File

@@ -0,0 +1,75 @@
name: Deploy docs to GitHub Pages
on:
push:
branches:
- master
paths:
- 'docs/**'
- 'gallery/**'
- 'images/**'
- '.github/ci/modelslist.go'
- '.github/workflows/gh-pages.yml'
workflow_dispatch:
permissions:
contents: read
pages: write
id-token: write
concurrency:
group: pages
cancel-in-progress: false
jobs:
build:
runs-on: ubuntu-latest
env:
HUGO_VERSION: "0.146.3"
steps:
- name: Checkout
uses: actions/checkout@v6
with:
fetch-depth: 0 # needed for enableGitInfo
submodules: true
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: '1.22'
cache: false
- name: Setup Hugo
uses: peaceiris/actions-hugo@v3
with:
hugo-version: ${{ env.HUGO_VERSION }}
extended: true
- name: Setup Pages
id: pages
uses: actions/configure-pages@v6
- name: Generate gallery
run: go run ./.github/ci/modelslist.go ./gallery/index.yaml > docs/static/gallery.html
- name: Build site
working-directory: docs
run: |
mkdir -p layouts/_default
hugo --minify --baseURL "${{ steps.pages.outputs.base_url }}/"
- name: Upload artifact
uses: actions/upload-pages-artifact@v4
with:
path: docs/public
deploy:
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
runs-on: ubuntu-latest
needs: build
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v5

View File

@@ -14,6 +14,37 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
detect-changes:
runs-on: ubuntu-latest
outputs:
run-all: ${{ steps.detect.outputs.run-all }}
transformers: ${{ steps.detect.outputs.transformers }}
rerankers: ${{ steps.detect.outputs.rerankers }}
diffusers: ${{ steps.detect.outputs.diffusers }}
coqui: ${{ steps.detect.outputs.coqui }}
moonshine: ${{ steps.detect.outputs.moonshine }}
pocket-tts: ${{ steps.detect.outputs.pocket-tts }}
qwen-tts: ${{ steps.detect.outputs.qwen-tts }}
qwen-asr: ${{ steps.detect.outputs.qwen-asr }}
nemo: ${{ steps.detect.outputs.nemo }}
voxcpm: ${{ steps.detect.outputs.voxcpm }}
llama-cpp-quantization: ${{ steps.detect.outputs.llama-cpp-quantization }}
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
voxtral: ${{ steps.detect.outputs.voxtral }}
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Setup Bun
uses: oven-sh/setup-bun@v2
- name: Install dependencies
run: bun add js-yaml @octokit/core
- name: Detect changed backends
id: detect
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_EVENT_PATH: ${{ github.event_path }}
run: bun run scripts/changed-backends.js
# Requires CUDA # Requires CUDA
# tests-chatterbox-tts: # tests-chatterbox-tts:
# runs-on: ubuntu-latest # runs-on: ubuntu-latest
@@ -37,6 +68,8 @@ jobs:
# make --jobs=5 --output-sync=target -C backend/python/chatterbox # make --jobs=5 --output-sync=target -C backend/python/chatterbox
# make --jobs=5 --output-sync=target -C backend/python/chatterbox test # make --jobs=5 --output-sync=target -C backend/python/chatterbox test
tests-transformers: tests-transformers:
needs: detect-changes
if: needs.detect-changes.outputs.transformers == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Clone - name: Clone
@@ -58,6 +91,8 @@ jobs:
make --jobs=5 --output-sync=target -C backend/python/transformers make --jobs=5 --output-sync=target -C backend/python/transformers
make --jobs=5 --output-sync=target -C backend/python/transformers test make --jobs=5 --output-sync=target -C backend/python/transformers test
tests-rerankers: tests-rerankers:
needs: detect-changes
if: needs.detect-changes.outputs.rerankers == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Clone - name: Clone
@@ -80,6 +115,8 @@ jobs:
make --jobs=5 --output-sync=target -C backend/python/rerankers test make --jobs=5 --output-sync=target -C backend/python/rerankers test
tests-diffusers: tests-diffusers:
needs: detect-changes
if: needs.detect-changes.outputs.diffusers == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Clone - name: Clone
@@ -229,6 +266,8 @@ jobs:
# make --jobs=5 --output-sync=target -C backend/python/vllm test # make --jobs=5 --output-sync=target -C backend/python/vllm test
tests-coqui: tests-coqui:
needs: detect-changes
if: needs.detect-changes.outputs.coqui == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Clone - name: Clone
@@ -248,6 +287,8 @@ jobs:
make --jobs=5 --output-sync=target -C backend/python/coqui make --jobs=5 --output-sync=target -C backend/python/coqui
make --jobs=5 --output-sync=target -C backend/python/coqui test make --jobs=5 --output-sync=target -C backend/python/coqui test
tests-moonshine: tests-moonshine:
needs: detect-changes
if: needs.detect-changes.outputs.moonshine == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Clone - name: Clone
@@ -267,6 +308,8 @@ jobs:
make --jobs=5 --output-sync=target -C backend/python/moonshine make --jobs=5 --output-sync=target -C backend/python/moonshine
make --jobs=5 --output-sync=target -C backend/python/moonshine test make --jobs=5 --output-sync=target -C backend/python/moonshine test
tests-pocket-tts: tests-pocket-tts:
needs: detect-changes
if: needs.detect-changes.outputs.pocket-tts == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Clone - name: Clone
@@ -286,6 +329,8 @@ jobs:
make --jobs=5 --output-sync=target -C backend/python/pocket-tts make --jobs=5 --output-sync=target -C backend/python/pocket-tts
make --jobs=5 --output-sync=target -C backend/python/pocket-tts test make --jobs=5 --output-sync=target -C backend/python/pocket-tts test
tests-qwen-tts: tests-qwen-tts:
needs: detect-changes
if: needs.detect-changes.outputs.qwen-tts == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Clone - name: Clone
@@ -327,6 +372,8 @@ jobs:
# make --jobs=5 --output-sync=target -C backend/python/fish-speech # make --jobs=5 --output-sync=target -C backend/python/fish-speech
# make --jobs=5 --output-sync=target -C backend/python/fish-speech test # make --jobs=5 --output-sync=target -C backend/python/fish-speech test
tests-qwen-asr: tests-qwen-asr:
needs: detect-changes
if: needs.detect-changes.outputs.qwen-asr == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Clone - name: Clone
@@ -346,6 +393,8 @@ jobs:
make --jobs=5 --output-sync=target -C backend/python/qwen-asr make --jobs=5 --output-sync=target -C backend/python/qwen-asr
make --jobs=5 --output-sync=target -C backend/python/qwen-asr test make --jobs=5 --output-sync=target -C backend/python/qwen-asr test
tests-nemo: tests-nemo:
needs: detect-changes
if: needs.detect-changes.outputs.nemo == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Clone - name: Clone
@@ -365,6 +414,8 @@ jobs:
make --jobs=5 --output-sync=target -C backend/python/nemo make --jobs=5 --output-sync=target -C backend/python/nemo
make --jobs=5 --output-sync=target -C backend/python/nemo test make --jobs=5 --output-sync=target -C backend/python/nemo test
tests-voxcpm: tests-voxcpm:
needs: detect-changes
if: needs.detect-changes.outputs.voxcpm == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Clone - name: Clone
@@ -383,7 +434,38 @@ jobs:
run: | run: |
make --jobs=5 --output-sync=target -C backend/python/voxcpm make --jobs=5 --output-sync=target -C backend/python/voxcpm
make --jobs=5 --output-sync=target -C backend/python/voxcpm test make --jobs=5 --output-sync=target -C backend/python/voxcpm test
tests-llama-cpp-quantization:
needs: detect-changes
if: needs.detect-changes.outputs.llama-cpp-quantization == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- name: Clone
uses: actions/checkout@v6
with:
submodules: true
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install -y build-essential cmake curl git python3-pip
# Install UV
curl -LsSf https://astral.sh/uv/install.sh | sh
pip install --user --no-cache-dir grpcio-tools==1.64.1
- name: Build llama-quantize from llama.cpp
run: |
git clone --depth 1 https://github.com/ggml-org/llama.cpp.git /tmp/llama.cpp
cmake -B /tmp/llama.cpp/build -S /tmp/llama.cpp -DGGML_NATIVE=OFF
cmake --build /tmp/llama.cpp/build --target llama-quantize -j$(nproc)
sudo cp /tmp/llama.cpp/build/bin/llama-quantize /usr/local/bin/
- name: Install backend
run: |
make --jobs=5 --output-sync=target -C backend/python/llama-cpp-quantization
- name: Test llama-cpp-quantization
run: |
make --jobs=5 --output-sync=target -C backend/python/llama-cpp-quantization test
tests-acestep-cpp: tests-acestep-cpp:
needs: detect-changes
if: needs.detect-changes.outputs.acestep-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Clone - name: Clone
@@ -414,6 +496,8 @@ jobs:
run: | run: |
make --jobs=5 --output-sync=target -C backend/go/acestep-cpp test make --jobs=5 --output-sync=target -C backend/go/acestep-cpp test
tests-voxtral: tests-voxtral:
needs: detect-changes
if: needs.detect-changes.outputs.voxtral == 'true' || needs.detect-changes.outputs.run-all == 'true'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Clone - name: Clone

View File

@@ -21,7 +21,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
go-version: ['1.25.x'] go-version: ['1.26.x']
steps: steps:
- name: Free Disk Space (Ubuntu) - name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main uses: jlumbroso/free-disk-space@main
@@ -179,7 +179,7 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
strategy: strategy:
matrix: matrix:
go-version: ['1.25.x'] go-version: ['1.26.x']
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v6 uses: actions/checkout@v6

72
.github/workflows/tests-ui-e2e.yml vendored Normal file
View File

@@ -0,0 +1,72 @@
---
name: 'UI E2E Tests'
on:
pull_request:
paths:
- 'core/http/**'
- 'tests/e2e-ui/**'
- 'tests/e2e/mock-backend/**'
push:
branches:
- master
concurrency:
group: ci-tests-ui-e2e-${{ github.head_ref || github.ref }}-${{ github.repository }}
cancel-in-progress: true
jobs:
tests-ui-e2e:
runs-on: ubuntu-latest
strategy:
matrix:
go-version: ['1.26.x']
steps:
- name: Clone
uses: actions/checkout@v6
with:
submodules: true
- name: Setup Go ${{ matrix.go-version }}
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
cache: false
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: '22'
- name: Proto Dependencies
run: |
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
rm protoc.zip
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
- name: System Dependencies
run: |
sudo apt-get update
sudo apt-get install -y build-essential libopus-dev
- name: Build UI test server
run: PATH="$PATH:$HOME/go/bin" make build-ui-test-server
- name: Install Playwright
working-directory: core/http/react-ui
run: |
npm install
npx playwright install --with-deps chromium
- name: Run Playwright tests
working-directory: core/http/react-ui
run: npx playwright test
- name: Upload Playwright report
if: ${{ failure() }}
uses: actions/upload-artifact@v7
with:
name: playwright-report
path: core/http/react-ui/playwright-report/
retention-days: 7
- name: Setup tmate session if tests fail
if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.23
with:
detached: true
connect-timeout-seconds: 180
limit-access-to-actor: true

5
.gitignore vendored
View File

@@ -72,3 +72,8 @@ core/http/react-ui/dist
# Extracted backend binaries for container-based testing # Extracted backend binaries for container-based testing
local-backends/ local-backends/
# UI E2E test artifacts
tests/e2e-ui/ui-test-server
core/http/react-ui/playwright-report/
core/http/react-ui/test-results/

View File

@@ -11,6 +11,8 @@ This file is an index to detailed topic guides in the `.agents/` directory. Read
| [.agents/coding-style.md](.agents/coding-style.md) | Code style, editorconfig, logging, documentation conventions | | [.agents/coding-style.md](.agents/coding-style.md) | Code style, editorconfig, logging, documentation conventions |
| [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing | | [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing |
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI | | [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
## Quick Reference ## Quick Reference

View File

@@ -176,7 +176,7 @@ ENV PATH=/opt/rocm/bin:${PATH}
# The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it. # The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it.
FROM requirements-drivers AS build-requirements FROM requirements-drivers AS build-requirements
ARG GO_VERSION=1.25.4 ARG GO_VERSION=1.26.0
ARG CMAKE_VERSION=3.31.10 ARG CMAKE_VERSION=3.31.10
ARG CMAKE_FROM_SOURCE=false ARG CMAKE_FROM_SOURCE=false
ARG TARGETARCH ARG TARGETARCH
@@ -256,7 +256,7 @@ RUN apt-get update && \
FROM build-requirements AS builder-base FROM build-requirements AS builder-base
ARG GO_TAGS="" ARG GO_TAGS="auth"
ARG GRPC_BACKENDS ARG GRPC_BACKENDS
ARG MAKEFLAGS ARG MAKEFLAGS
ARG LD_FLAGS="-s -w" ARG LD_FLAGS="-s -w"
@@ -319,7 +319,6 @@ COPY ./.git ./.git
# Some of the Go backends use libs from the main src, we could further optimize the caching by building the CPP backends before here # Some of the Go backends use libs from the main src, we could further optimize the caching by building the CPP backends before here
COPY ./pkg/grpc ./pkg/grpc COPY ./pkg/grpc ./pkg/grpc
COPY ./pkg/utils ./pkg/utils COPY ./pkg/utils ./pkg/utils
COPY ./pkg/langchain ./pkg/langchain
RUN ls -l ./ RUN ls -l ./
RUN make protogen-go RUN make protogen-go

View File

@@ -1,5 +1,5 @@
# Disable parallel execution for backend builds # Disable parallel execution for backend builds
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus .NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization
GOCMD=go GOCMD=go
GOTEST=$(GOCMD) test GOTEST=$(GOCMD) test
@@ -107,7 +107,7 @@ core/http/react-ui/dist: react-ui
## Build: ## Build:
build: protogen-go install-go-tools core/http/react-ui/dist ## Build the project build: protogen-go generate install-go-tools core/http/react-ui/dist ## Build the project
$(info ${GREEN}I local-ai build info:${RESET}) $(info ${GREEN}I local-ai build info:${RESET})
$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET}) $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
$(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET}) $(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET})
@@ -398,6 +398,16 @@ protogen-go: protoc install-go-tools
./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \ ./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \
backend/backend.proto backend/backend.proto
core/config/inference_defaults.json: ## Fetch inference defaults from unsloth (only if missing)
$(GOCMD) generate ./core/config/...
.PHONY: generate
generate: core/config/inference_defaults.json ## Ensure inference defaults exist
.PHONY: generate-force
generate-force: ## Re-fetch inference defaults from unsloth (always)
$(GOCMD) generate ./core/config/...
.PHONY: protogen-go-clean .PHONY: protogen-go-clean
protogen-go-clean: protogen-go-clean:
$(RM) pkg/grpc/proto/backend.pb.go pkg/grpc/proto/backend_grpc.pb.go $(RM) pkg/grpc/proto/backend.pb.go pkg/grpc/proto/backend_grpc.pb.go
@@ -421,6 +431,7 @@ prepare-test-extra: protogen-python
$(MAKE) -C backend/python/voxcpm $(MAKE) -C backend/python/voxcpm
$(MAKE) -C backend/python/whisperx $(MAKE) -C backend/python/whisperx
$(MAKE) -C backend/python/ace-step $(MAKE) -C backend/python/ace-step
$(MAKE) -C backend/python/trl
test-extra: prepare-test-extra test-extra: prepare-test-extra
$(MAKE) -C backend/python/transformers test $(MAKE) -C backend/python/transformers test
@@ -440,6 +451,7 @@ test-extra: prepare-test-extra
$(MAKE) -C backend/python/voxcpm test $(MAKE) -C backend/python/voxcpm test
$(MAKE) -C backend/python/whisperx test $(MAKE) -C backend/python/whisperx test
$(MAKE) -C backend/python/ace-step test $(MAKE) -C backend/python/ace-step test
$(MAKE) -C backend/python/trl test
DOCKER_IMAGE?=local-ai DOCKER_IMAGE?=local-ai
IMAGE_TYPE?=core IMAGE_TYPE?=core
@@ -572,6 +584,8 @@ BACKEND_VOXCPM = voxcpm|python|.|false|true
BACKEND_WHISPERX = whisperx|python|.|false|true BACKEND_WHISPERX = whisperx|python|.|false|true
BACKEND_ACE_STEP = ace-step|python|.|false|true BACKEND_ACE_STEP = ace-step|python|.|false|true
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
BACKEND_TRL = trl|python|.|false|true
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
# Helper function to build docker image for a backend # Helper function to build docker image for a backend
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG) # Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
@@ -629,12 +643,14 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP))) $(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP))) $(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED))) $(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
# Pattern rule for docker-save targets # Pattern rule for docker-save targets
docker-save-%: backend-images docker-save-%: backend-images
docker save local-ai-backend:$* -o backend-images/$*.tar docker save local-ai-backend:$* -o backend-images/$*.tar
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization
######################################################## ########################################################
### Mock Backend for E2E Tests ### Mock Backend for E2E Tests
@@ -646,6 +662,23 @@ build-mock-backend: protogen-go
clean-mock-backend: clean-mock-backend:
rm -f tests/e2e/mock-backend/mock-backend rm -f tests/e2e/mock-backend/mock-backend
########################################################
### UI E2E Test Server
########################################################
build-ui-test-server: build-mock-backend react-ui protogen-go
$(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui
test-ui-e2e: build-ui-test-server
cd core/http/react-ui && npm install && npx playwright install --with-deps chromium && npx playwright test
test-ui-e2e-docker:
docker build -t localai-ui-e2e -f tests/e2e-ui/Dockerfile .
docker run --rm localai-ui-e2e
clean-ui-test-server:
rm -f tests/e2e-ui/ui-test-server
######################################################## ########################################################
### END Backends ### END Backends
######################################################## ########################################################

398
README.md
View File

@@ -5,35 +5,17 @@
</h1> </h1>
<p align="center"> <p align="center">
<a href="https://github.com/go-skynet/LocalAI/fork" target="blank">
<img src="https://img.shields.io/github/forks/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI forks"/>
</a>
<a href="https://github.com/go-skynet/LocalAI/stargazers" target="blank"> <a href="https://github.com/go-skynet/LocalAI/stargazers" target="blank">
<img src="https://img.shields.io/github/stars/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI stars"/> <img src="https://img.shields.io/github/stars/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI stars"/>
</a> </a>
<a href="https://github.com/go-skynet/LocalAI/pulls" target="blank">
<img src="https://img.shields.io/github/issues-pr/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI pull-requests"/>
</a>
<a href='https://github.com/go-skynet/LocalAI/releases'> <a href='https://github.com/go-skynet/LocalAI/releases'>
<img src='https://img.shields.io/github/release/go-skynet/LocalAI?&label=Latest&style=for-the-badge'> <img src='https://img.shields.io/github/release/go-skynet/LocalAI?&label=Latest&style=for-the-badge'>
</a> </a>
</p>
<p align="center">
<a href="LICENSE" target="blank"> <a href="LICENSE" target="blank">
<img src="https://img.shields.io/badge/License-MIT-yellow.svg?style=for-the-badge" alt="LocalAI License"/> <img src="https://img.shields.io/badge/License-MIT-yellow.svg?style=for-the-badge" alt="LocalAI License"/>
</a> </a>
</p> </p>
<p align="center">
<a href="https://hub.docker.com/r/localai/localai" target="blank">
<img src="https://img.shields.io/badge/dockerhub-images-important.svg?logo=Docker" alt="LocalAI Docker hub"/>
</a>
<a href="https://quay.io/repository/go-skynet/local-ai?tab=tags&tag=latest" target="blank">
<img src="https://img.shields.io/badge/quay.io-images-important.svg?" alt="LocalAI Quay.io"/>
</a>
</p>
<p align="center"> <p align="center">
<a href="https://twitter.com/LocalAI_API" target="blank"> <a href="https://twitter.com/LocalAI_API" target="blank">
<img src="https://img.shields.io/badge/X-%23000000.svg?style=for-the-badge&logo=X&logoColor=white&label=LocalAI_API" alt="Follow LocalAI_API"/> <img src="https://img.shields.io/badge/X-%23000000.svg?style=for-the-badge&logo=X&logoColor=white&label=LocalAI_API" alt="Follow LocalAI_API"/>
@@ -47,363 +29,161 @@
<a href="https://trendshift.io/repositories/5539" target="_blank"><img src="https://trendshift.io/api/badge/repositories/5539" alt="mudler%2FLocalAI | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> <a href="https://trendshift.io/repositories/5539" target="_blank"><img src="https://trendshift.io/api/badge/repositories/5539" alt="mudler%2FLocalAI | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</p> </p>
> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/) **LocalAI** is the open-source AI engine. Run any model - LLMs, vision, voice, image, video - on any hardware. No GPU required.
>
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
[![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white)](https://t.me/localaiofficial_bot)
[![tests](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[![Build and Release](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[![build container images](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[![Bump dependencies](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/localai)](https://artifacthub.io/packages/search?repo=localai) - **Drop-in API compatibility** — OpenAI, Anthropic, ElevenLabs APIs
- **35+ backends** — llama.cpp, vLLM, transformers, whisper, diffusers, MLX...
- **Any hardware** — NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
- **Multi-user ready** — API key auth, user quotas, role-based access
- **Built-in AI agents** — autonomous agents with tool use, RAG, MCP, and skills
- **Privacy-first** — your data never leaves your infrastructure
<p align="center"> Created and maintained by [Ettore Di Giacinto](https://github.com/mudler).
<a href="https://github.com/mudler/LocalAI-examples" target="blank">
<img src="https://img.shields.io/badge/📦_Examples_Repository-Browse_Ready--to--Run_Examples-blue?style=for-the-badge" alt="LocalAI Examples Repository"/>
</a>
</p>
**LocalAI** is the free, Open Source OpenAI alternative. LocalAI act as a drop-in replacement REST API that's compatible with OpenAI (Elevenlabs, Anthropic... ) API specifications for local AI inferencing. It allows you to run LLMs, generate images, audio (and not only) locally or on-prem with consumer grade hardware, supporting multiple model families. Does not require GPU. It is created and maintained by [Ettore Di Giacinto](https://github.com/mudler). > [:book: Documentation](https://localai.io/) | [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) | [💻 Quickstart](https://localai.io/basics/getting_started/) | [🖼️ Models](https://models.localai.io/) | [❓FAQ](https://localai.io/faq/)
<details> ## Screenshots
<summary><strong>Table of Contents</strong></summary>
- [Local Stack Family](#local-stack-family) ### Chat, Model gallery
- [Screenshots / Video](#screenshots--video)
- [Quickstart](#-quickstart)
- [macOS Download](#macos-download)
- [Containers (Docker, podman, ...)](#containers-docker-podman-)
- [Latest project news](#-latest-project-news)
- [Features](#-features)
- [Supported Backends & Acceleration](#-supported-backends--acceleration)
- [Text Generation & Language Models](#text-generation--language-models)
- [Audio & Speech Processing](#audio--speech-processing)
- [Image & Video Generation](#image--video-generation)
- [Specialized AI Tasks](#specialized-ai-tasks)
- [Hardware Acceleration Matrix](#hardware-acceleration-matrix)
- [Community and integrations](#-community-and-integrations)
- [Resources](#-resources)
- [Media, Blogs, Social](#book--media-blogs-social)
- [Autonomous Development Team](#-autonomous-development-team)
- [Citation](#citation)
- [Sponsors](#-sponsors)
- [Individual sponsors](#individual-sponsors)
- [Star history](#-star-history)
- [License](#-license)
- [Acknowledgements](#-acknowledgements)
- [Contributors](#-contributors)
</details> https://github.com/user-attachments/assets/08cbb692-57da-48f7-963d-2e7b43883c18
## Local Stack Family ### Agents
Liking LocalAI? LocalAI is part of an integrated suite of AI infrastructure tools, you might also like: https://github.com/user-attachments/assets/6270b331-e21d-4087-a540-6290006b381a
- **[LocalAGI](https://github.com/mudler/LocalAGI)** - AI agent orchestration platform with OpenAI Responses API compatibility and advanced agentic capabilities ## Quickstart
- **[LocalRecall](https://github.com/mudler/LocalRecall)** - MCP/REST API knowledge base system providing persistent memory and storage for AI agents
- 🆕 **[Cogito](https://github.com/mudler/cogito)** - Go library for building intelligent, co-operative agentic software and LLM-powered workflows, focusing on improving results for small, open source language models that scales to any LLM. Powers LocalAGI and LocalAI MCP/Agentic capabilities
- 🆕 **[Wiz](https://github.com/mudler/wiz)** - Terminal-based AI agent accessible via Ctrl+Space keybinding. Portable, local-LLM friendly shell assistant with TUI/CLI modes, tool execution with approval, MCP protocol support, and multi-shell compatibility (zsh, bash, fish)
- 🆕 **[SkillServer](https://github.com/mudler/skillserver)** - Simple, centralized skills database for AI agents via MCP. Manages skills as Markdown files with MCP server integration, web UI for editing, Git synchronization, and full-text search capabilities
### macOS
## Screenshots / Video
### Youtube video
<h1 align="center">
<br>
<a href="https://www.youtube.com/watch?v=PDqYhB9nNHA" target="_blank"> <img width="300" src="https://img.youtube.com/vi/PDqYhB9nNHA/0.jpg"> </a><br>
<br>
</h1>
### Screenshots
| Talk Interface | Generate Audio |
| --- | --- |
| ![LocalAI - Talk](./docs/assets/images/screenshots/screenshot_talk.png) | ![LocalAI - Generate Audio](./docs/assets/images/screenshots/screenshot_tts.png) |
| Models Overview | Generate Images |
| --- | --- |
| ![LocalAI - Models](./docs/assets/images/screenshots/screenshot_gallery.png) | ![LocalAI - Generate Images](./docs/assets/images/screenshots/screenshot_image.png) |
| Chat Interface | Home |
| --- | --- |
| ![LocalAI - Chat](./docs/assets/images/screenshots/screenshot_chat.png) | ![LocalAI - Home](./docs/assets/images/screenshots/screenshot_home.png) |
| Login | Swarm |
| --- | --- |
| ![LocalAI - Login](./docs/assets/images/screenshots/screenshot_login.png) | ![LocalAI - P2P Dashboard](./docs/assets/images/screenshots/screenshot_p2p.png) |
## 💻 Quickstart
### macOS Download:
<a href="https://github.com/mudler/LocalAI/releases/latest/download/LocalAI.dmg"> <a href="https://github.com/mudler/LocalAI/releases/latest/download/LocalAI.dmg">
<img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/> <img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/>
</a> </a>
> Note: the DMGs are not signed by Apple as quarantined. See https://github.com/mudler/LocalAI/issues/6268 for a workaround, fix is tracked here: https://github.com/mudler/LocalAI/issues/6244 > **Note:** The DMG is not signed by Apple. After installing, run: `sudo xattr -d com.apple.quarantine /Applications/LocalAI.app`. See [#6268](https://github.com/mudler/LocalAI/issues/6268) for details.
### Containers (Docker, podman, ...) ### Containers (Docker, podman, ...)
> **💡 Docker Run vs Docker Start** > Already ran LocalAI before? Use `docker start -i local-ai` to restart an existing container.
>
> - `docker run` creates and starts a new container. If a container with the same name already exists, this command will fail.
> - `docker start` starts an existing container that was previously created with `docker run`.
>
> If you've already run LocalAI before and want to start it again, use: `docker start -i local-ai`
#### CPU only image: #### CPU only:
```bash ```bash
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest docker run -ti --name local-ai -p 8080:8080 localai/localai:latest
``` ```
#### NVIDIA GPU Images: #### NVIDIA GPU:
```bash ```bash
# CUDA 13.0 # CUDA 13
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-13 docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-13
# CUDA 12.0 # CUDA 12
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-12 docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-12
# NVIDIA Jetson (L4T) ARM64 # NVIDIA Jetson ARM64 (CUDA 12, for AGX Orin and similar)
# CUDA 12 (for Nvidia AGX Orin and similar platforms)
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64 docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64
# CUDA 13 (for Nvidia DGX Spark) # NVIDIA Jetson ARM64 (CUDA 13, for DGX Spark)
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64-cuda-13 docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64-cuda-13
``` ```
#### AMD GPU Images (ROCm): #### AMD GPU (ROCm):
```bash ```bash
docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri --group-add=video localai/localai:latest-gpu-hipblas docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri --group-add=video localai/localai:latest-gpu-hipblas
``` ```
#### Intel GPU Images (oneAPI): #### Intel GPU (oneAPI):
```bash ```bash
docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel
``` ```
#### Vulkan GPU Images: #### Vulkan GPU:
```bash ```bash
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-gpu-vulkan docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-gpu-vulkan
``` ```
To load models: ### Loading models
```bash ```bash
# From the model gallery (see available models with `local-ai models list`, in the WebUI from the model tab, or visiting https://models.localai.io) # From the model gallery (see available models with `local-ai models list` or at https://models.localai.io)
local-ai run llama-3.2-1b-instruct:q4_k_m local-ai run llama-3.2-1b-instruct:q4_k_m
# Start LocalAI with the phi-2 model directly from huggingface # From Huggingface
local-ai run huggingface://TheBloke/phi-2-GGUF/phi-2.Q8_0.gguf local-ai run huggingface://TheBloke/phi-2-GGUF/phi-2.Q8_0.gguf
# Install and run a model from the Ollama OCI registry # From the Ollama OCI registry
local-ai run ollama://gemma:2b local-ai run ollama://gemma:2b
# Run a model from a configuration file # From a YAML config
local-ai run https://gist.githubusercontent.com/.../phi-2.yaml local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
# Install and run a model from a standard OCI registry (e.g., Docker Hub) # From a standard OCI registry (e.g., Docker Hub)
local-ai run oci://localai/phi-2:latest local-ai run oci://localai/phi-2:latest
``` ```
> **Automatic Backend Detection**: When you install models from the gallery or YAML files, LocalAI automatically detects your system's GPU capabilities (NVIDIA, AMD, Intel) and downloads the appropriate backend. For advanced configuration options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/#automatic-backend-detection). > **Automatic Backend Detection**: LocalAI automatically detects your GPU capabilities and downloads the appropriate backend. For advanced options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/).
For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html), if you are interested in our roadmap items and future enhancements, you can see the [Issues labeled as Roadmap here](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) For more details, see the [Getting Started guide](https://localai.io/basics/getting_started/).
## 📰 Latest project news ## Latest News
- March 2026: [Agent management](https://github.com/mudler/LocalAI/pull/8820), [New React UI](https://github.com/mudler/LocalAI/pull/8772), [WebRTC](https://github.com/mudler/LocalAI/pull/8790),[MLX-distributed via P2P and RDMA](https://github.com/mudler/LocalAI/pull/8801), [MCP Apps, MCP Client-side](https://github.com/mudler/LocalAI/pull/8947)
- February 2026: [Realtime API for audio-to-audio with tool calling](https://github.com/mudler/LocalAI/pull/6245), [ACE-Step 1.5 support](https://github.com/mudler/LocalAI/pull/8396)
- January 2026: **LocalAI 3.10.0** - Major release with Anthropic API support, Open Responses API for stateful agents, video & image generation suite (LTX-2), unified GPU backends, tool streaming & XML parsing, system-aware backend gallery, crash fixes for AVX-only CPUs and AMD VRAM reporting, request tracing, and new backends: **Moonshine** (ultra-fast transcription), **Pocket-TTS** (lightweight TTS). Vulkan arm64 builds now available. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v3.10.0).
- December 2025: [Dynamic Memory Resource reclaimer](https://github.com/mudler/LocalAI/pull/7583), [Automatic fitting of models to multiple GPUS(llama.cpp)](https://github.com/mudler/LocalAI/pull/7584), [Added Vibevoice backend](https://github.com/mudler/LocalAI/pull/7494)
- November 2025: Major improvements to the UX. Among these: [Import models via URL](https://github.com/mudler/LocalAI/pull/7245) and [Multiple chats and history](https://github.com/mudler/LocalAI/pull/7325)
- October 2025: 🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) support added for agentic capabilities with external tools
- September 2025: New Launcher application for MacOS and Linux, extended support to many backends for Mac and Nvidia L4T devices. Models: Added MLX-Audio, WAN 2.2. WebUI improvements and Python-based backends now ships portable python environments.
- August 2025: MLX, MLX-VLM, Diffusers and llama.cpp are now supported on Mac M1/M2/M3+ chips ( with `development` suffix in the gallery ): https://github.com/mudler/LocalAI/pull/6049 https://github.com/mudler/LocalAI/pull/6119 https://github.com/mudler/LocalAI/pull/6121 https://github.com/mudler/LocalAI/pull/6060
- July/August 2025: 🔍 [Object Detection](https://localai.io/features/object-detection/) added to the API featuring [rf-detr](https://github.com/roboflow/rf-detr)
- July 2025: All backends migrated outside of the main binary. LocalAI is now more lightweight, small, and automatically downloads the required backend to run the model. [Read the release notes](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
- June 2025: [Backend management](https://github.com/mudler/LocalAI/pull/5607) has been added. Attention: extras images are going to be deprecated from the next release! Read [the backend management PR](https://github.com/mudler/LocalAI/pull/5607).
- May 2025: [Audio input](https://github.com/mudler/LocalAI/pull/5466) and [Reranking](https://github.com/mudler/LocalAI/pull/5396) in llama.cpp backend, [Realtime API](https://github.com/mudler/LocalAI/pull/5392), Support to Gemma, SmollVLM, and more multimodal models (available in the gallery).
- May 2025: Important: image name changes [See release](https://github.com/mudler/LocalAI/releases/tag/v2.29.0)
- Apr 2025: Rebrand, WebUI enhancements
- Apr 2025: [LocalAGI](https://github.com/mudler/LocalAGI) and [LocalRecall](https://github.com/mudler/LocalRecall) join the LocalAI family stack.
- Apr 2025: WebUI overhaul
- Feb 2025: Backend cleanup, Breaking changes, new backends (kokoro, OutelTTS, faster-whisper), Nvidia L4T images
- Jan 2025: LocalAI model release: https://huggingface.co/mudler/LocalAI-functioncall-phi-4-v0.3, SANA support in diffusers: https://github.com/mudler/LocalAI/pull/4603
- Dec 2024: stablediffusion.cpp backend (ggml) added ( https://github.com/mudler/LocalAI/pull/4289 )
- Nov 2024: Bark.cpp backend added ( https://github.com/mudler/LocalAI/pull/4287 )
- Nov 2024: Voice activity detection models (**VAD**) added to the API: https://github.com/mudler/LocalAI/pull/4204
- Oct 2024: examples moved to [LocalAI-examples](https://github.com/mudler/LocalAI-examples)
- Aug 2024: 🆕 FLUX-1, [P2P Explorer](https://explorer.localai.io)
- July 2024: 🔥🔥 🆕 P2P Dashboard, LocalAI Federated mode and AI Swarms: https://github.com/mudler/LocalAI/pull/2723. P2P Global community pools: https://github.com/mudler/LocalAI/issues/3113
- May 2024: 🔥🔥 Decentralized P2P llama.cpp: https://github.com/mudler/LocalAI/pull/2343 (peer2peer llama.cpp!) 👉 Docs https://localai.io/features/distribute/
- May 2024: 🔥🔥 Distributed inferencing: https://github.com/mudler/LocalAI/pull/2324
- April 2024: Reranker API: https://github.com/mudler/LocalAI/pull/2121
Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) - **March 2026**: [Agent management](https://github.com/mudler/LocalAI/pull/8820), [New React UI](https://github.com/mudler/LocalAI/pull/8772), [WebRTC](https://github.com/mudler/LocalAI/pull/8790), [MLX-distributed via P2P and RDMA](https://github.com/mudler/LocalAI/pull/8801), [MCP Apps, MCP Client-side](https://github.com/mudler/LocalAI/pull/8947)
- **February 2026**: [Realtime API for audio-to-audio with tool calling](https://github.com/mudler/LocalAI/pull/6245), [ACE-Step 1.5 support](https://github.com/mudler/LocalAI/pull/8396)
- **January 2026**: **LocalAI 3.10.0** — Anthropic API support, Open Responses API, video & image generation (LTX-2), unified GPU backends, tool streaming, Moonshine, Pocket-TTS. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v3.10.0)
- **December 2025**: [Dynamic Memory Resource reclaimer](https://github.com/mudler/LocalAI/pull/7583), [Automatic multi-GPU model fitting (llama.cpp)](https://github.com/mudler/LocalAI/pull/7584), [Vibevoice backend](https://github.com/mudler/LocalAI/pull/7494)
- **November 2025**: [Import models via URL](https://github.com/mudler/LocalAI/pull/7245), [Multiple chats and history](https://github.com/mudler/LocalAI/pull/7325)
- **October 2025**: [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) support for agentic capabilities
- **September 2025**: New Launcher for macOS and Linux, extended backend support for Mac and Nvidia L4T, MLX-Audio, WAN 2.2
- **August 2025**: MLX, MLX-VLM, Diffusers, llama.cpp now supported on Apple Silicon
- **July 2025**: All backends migrated outside the main binary — [lightweight, modular architecture](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
## 🚀 [Features](https://localai.io/features/) For older news and full release notes, see [GitHub Releases](https://github.com/mudler/LocalAI/releases) and the [News page](https://localai.io/basics/news/).
- 🧩 [Backend Gallery](https://localai.io/backends/): Install/remove backends on the fly, powered by OCI images — fully customizable and API-driven. ## Features
- 📖 [Text generation with GPTs](https://localai.io/features/text-generation/) (`llama.cpp`, `transformers`, `vllm` ... [:book: and more](https://localai.io/model-compatibility/index.html#model-compatibility-table))
- 🗣 [Text to Audio](https://localai.io/features/text-to-audio/)
- 🔈 [Audio to Text](https://localai.io/features/audio-to-text/)
- 🎨 [Image generation](https://localai.io/features/image-generation)
- 🔥 [OpenAI-alike tools API](https://localai.io/features/openai-functions/)
- ⚡ [Realtime API](https://localai.io/features/openai-realtime/) (Speech-to-speech)
- 🧠 [Embeddings generation for vector databases](https://localai.io/features/embeddings/)
- ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/)
- 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/)
- 🥽 [Vision API](https://localai.io/features/gpt-vision/)
- 🔍 [Object Detection](https://localai.io/features/object-detection/)
- 📈 [Reranker API](https://localai.io/features/reranker/)
- 🆕🖧 [P2P Inferencing](https://localai.io/features/distribute/)
- 🆕🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) - Agentic capabilities with external tools and [LocalAGI's Agentic capabilities](https://github.com/mudler/LocalAGI)
- 🆕🤖 [Built-in Agents](https://localai.io/features/agents/) - Autonomous AI agents with tool use, knowledge base (RAG), skills, SSE streaming, import/export, and [Agent Hub](https://agenthub.localai.io) — powered by [LocalAGI](https://github.com/mudler/LocalAGI)
- 🔊 Voice activity detection (Silero-VAD support)
- 🌍 Integrated WebUI!
## 🧩 Supported Backends & Acceleration - [Text generation](https://localai.io/features/text-generation/) (`llama.cpp`, `transformers`, `vllm` ... [and more](https://localai.io/model-compatibility/))
- [Text to Audio](https://localai.io/features/text-to-audio/)
- [Audio to Text](https://localai.io/features/audio-to-text/)
- [Image generation](https://localai.io/features/image-generation)
- [OpenAI-compatible tools API](https://localai.io/features/openai-functions/)
- [Realtime API](https://localai.io/features/openai-realtime/) (Speech-to-speech)
- [Embeddings generation](https://localai.io/features/embeddings/)
- [Constrained grammars](https://localai.io/features/constrained_grammars/)
- [Download models from Huggingface](https://localai.io/models/)
- [Vision API](https://localai.io/features/gpt-vision/)
- [Object Detection](https://localai.io/features/object-detection/)
- [Reranker API](https://localai.io/features/reranker/)
- [P2P Inferencing](https://localai.io/features/distribute/)
- [Distributed Mode](https://localai.io/features/distributed-mode/) — Horizontal scaling with PostgreSQL + NATS
- [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/)
- [Built-in Agents](https://localai.io/features/agents/) — Autonomous AI agents with tool use, RAG, skills, SSE streaming, and [Agent Hub](https://agenthub.localai.io)
- [Backend Gallery](https://localai.io/backends/) — Install/remove backends on the fly via OCI images
- Voice Activity Detection (Silero-VAD)
- Integrated WebUI
LocalAI supports a comprehensive range of AI backends with multiple acceleration options: ## Supported Backends & Acceleration
### Text Generation & Language Models LocalAI supports **35+ backends** including llama.cpp, vLLM, transformers, whisper.cpp, diffusers, MLX, MLX-VLM, and many more. Hardware acceleration is available for **NVIDIA** (CUDA 12/13), **AMD** (ROCm), **Intel** (oneAPI/SYCL), **Apple Silicon** (Metal), **Vulkan**, and **NVIDIA Jetson** (L4T). All backends can be installed on-the-fly from the [Backend Gallery](https://localai.io/backends/).
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **llama.cpp** | LLM inference in C/C++ | CUDA 12/13, ROCm, Intel SYCL, Vulkan, Metal, CPU |
| **vLLM** | Fast LLM inference with PagedAttention | CUDA 12/13, ROCm, Intel |
| **transformers** | HuggingFace transformers framework | CUDA 12/13, ROCm, Intel, CPU |
| **MLX** | Apple Silicon LLM inference | Metal (M1/M2/M3+) |
| **MLX-VLM** | Apple Silicon Vision-Language Models | Metal (M1/M2/M3+) |
| **vLLM Omni** | Multimodal vLLM with vision and audio | CUDA 12/13, ROCm, Intel |
### Audio & Speech Processing See the full [Backend & Model Compatibility Table](https://localai.io/model-compatibility/) and [GPU Acceleration guide](https://localai.io/features/gpu-acceleration/).
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **whisper.cpp** | OpenAI Whisper in C/C++ | CUDA 12/13, ROCm, Intel SYCL, Vulkan, CPU |
| **faster-whisper** | Fast Whisper with CTranslate2 | CUDA 12/13, ROCm, Intel, CPU |
| **moonshine** | Ultra-fast transcription engine for low-end devices | CUDA 12/13, Metal, CPU |
| **coqui** | Advanced TTS with 1100+ languages | CUDA 12/13, ROCm, Intel, CPU |
| **kokoro** | Lightweight TTS model | CUDA 12/13, ROCm, Intel, CPU |
| **chatterbox** | Production-grade TTS | CUDA 12/13, CPU |
| **piper** | Fast neural TTS system | CPU |
| **kitten-tts** | Kitten TTS models | CPU |
| **silero-vad** | Voice Activity Detection | CPU |
| **neutts** | Text-to-speech with voice cloning | CUDA 12/13, ROCm, CPU |
| **vibevoice** | Real-time TTS with voice cloning | CUDA 12/13, ROCm, Intel, CPU |
| **pocket-tts** | Lightweight CPU-based TTS | CUDA 12/13, ROCm, Intel, CPU |
| **qwen-tts** | High-quality TTS with custom voice, voice design, and voice cloning | CUDA 12/13, ROCm, Intel, CPU |
| **nemo** | NVIDIA NeMo framework for speech models | CUDA 12/13, ROCm, Intel, CPU |
| **outetts** | OuteTTS with voice cloning | CUDA 12/13, CPU |
| **faster-qwen3-tts** | Faster Qwen3 TTS | CUDA 12/13, ROCm, Intel, CPU |
| **qwen-asr** | Qwen ASR speech recognition | CUDA 12/13, ROCm, Intel, CPU |
| **voxcpm** | VoxCPM speech understanding | CUDA 12/13, Metal, CPU |
| **whisperx** | Enhanced Whisper transcription | CUDA 12/13, ROCm, Intel, CPU |
| **ace-step** | Music generation from text descriptions, lyrics, or audio samples | CUDA 12/13, ROCm, Intel, Metal, CPU |
### Image & Video Generation ## Resources
| Backend | Description | Acceleration Support |
|---------|-------------|---------------------|
| **stablediffusion.cpp** | Stable Diffusion in C/C++ | CUDA 12/13, Intel SYCL, Vulkan, CPU |
| **diffusers** | HuggingFace diffusion models | CUDA 12/13, ROCm, Intel, Metal, CPU |
### Specialized AI Tasks - [Documentation](https://localai.io/)
| Backend | Description | Acceleration Support | - [LLM fine-tuning guide](https://localai.io/docs/advanced/fine-tuning/)
|---------|-------------|---------------------| - [Build from source](https://localai.io/basics/build/)
| **rfdetr** | Real-time object detection | CUDA 12/13, Intel, CPU | - [Kubernetes installation](https://localai.io/basics/getting_started/#run-localai-in-kubernetes)
| **rerankers** | Document reranking API | CUDA 12/13, ROCm, Intel, CPU | - [Integrations & community projects](https://localai.io/docs/integrations/)
| **local-store** | Vector database | CPU | - [Media & blog posts](https://localai.io/basics/news/#media-blogs-social)
| **huggingface** | HuggingFace API integration | API-based | - [Examples](https://github.com/mudler/LocalAI-examples)
### Hardware Acceleration Matrix ## Autonomous Development Team
| Acceleration Type | Supported Backends | Hardware Support | LocalAI is helped being maintained by a team of autonomous AI agents led by an AI Scrum Master.
|-------------------|-------------------|------------------|
| **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware |
| **NVIDIA CUDA 13** | All CUDA-compatible backends | Nvidia hardware |
| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, neutts, vibevoice, pocket-tts, qwen-tts, ace-step | AMD Graphics |
| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, coqui, kokoro, vibevoice, pocket-tts, qwen-tts, ace-step | Intel Arc, Intel iGPUs |
| **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, moonshine, ace-step | Apple M1/M2/M3+ |
| **Vulkan** | llama.cpp, whisper, stablediffusion | Cross-platform GPUs |
| **NVIDIA Jetson (CUDA 12)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr, ace-step | ARM64 embedded AI (AGX Orin, etc.) |
| **NVIDIA Jetson (CUDA 13)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr | ARM64 embedded AI (DGX Spark) |
| **CPU Optimized** | All backends | AVX/AVX2/AVX512, quantization support |
### 🔗 Community and integrations - **Live Reports**: [reports.localai.io](http://reports.localai.io)
- **Project Board**: [Agent task tracking](https://github.com/users/mudler/projects/6)
Build and deploy custom containers: - **Blog Post**: [Learn about the experiment](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/)
- https://github.com/sozercan/aikit
WebUIs:
- https://github.com/Jirubizu/localai-admin
- https://github.com/go-skynet/LocalAI-frontend
- QA-Pilot(An interactive chat project that leverages LocalAI LLMs for rapid understanding and navigation of GitHub code repository) https://github.com/reid41/QA-Pilot
Agentic Libraries:
- https://github.com/mudler/cogito
MCPs:
- https://github.com/mudler/MCPs
OS Assistant:
- https://github.com/mudler/Keygeist - Keygeist is an AI-powered keyboard operator that listens for key combinations and responds with AI-generated text typed directly into your Linux box.
Model galleries
- https://github.com/go-skynet/model-gallery
Voice:
- https://github.com/richiejp/VoxInput
Other:
- Helm chart https://github.com/go-skynet/helm-charts
- VSCode extension https://github.com/badgooooor/localai-vscode-plugin
- Langchain: https://python.langchain.com/docs/integrations/providers/localai/
- Terminal utility https://github.com/djcopley/ShellOracle
- Local Smart assistant https://github.com/mudler/LocalAGI
- Home Assistant https://github.com/drndos/hass-openai-custom-conversation / https://github.com/valentinfrlch/ha-llmvision / https://github.com/loryanstrant/HA-LocalAI-Monitor
- Discord bot https://github.com/mudler/LocalAGI/tree/main/examples/discord
- Slack bot https://github.com/mudler/LocalAGI/tree/main/examples/slack
- Shell-Pilot(Interact with LLM using LocalAI models via pure shell scripts on your Linux or MacOS system) https://github.com/reid41/shell-pilot
- Telegram bot https://github.com/mudler/LocalAI/tree/master/examples/telegram-bot
- Another Telegram Bot https://github.com/JackBekket/Hellper
- Auto-documentation https://github.com/JackBekket/Reflexia
- Github bot which answer on issues, with code and documentation as context https://github.com/JackBekket/GitHelper
- Github Actions: https://github.com/marketplace/actions/start-localai
- Examples: https://github.com/mudler/LocalAI/tree/master/examples/
### 🔗 Resources
- [LLM finetuning guide](https://localai.io/docs/advanced/fine-tuning/)
- [How to build locally](https://localai.io/basics/build/index.html)
- [How to install in Kubernetes](https://localai.io/basics/getting_started/index.html#run-localai-in-kubernetes)
- [Projects integrating LocalAI](https://localai.io/docs/integrations/)
- [How tos section](https://io.midori-ai.xyz/howtos/) (curated by our community)
## :book: 🎥 [Media, Blogs, Social](https://localai.io/basics/news/#media-blogs-social)
- 🆕 [LocalAI Autonomous Dev Team Blog Post](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/)
- [Run Visual studio code with LocalAI (SUSE)](https://www.suse.com/c/running-ai-locally/)
- 🆕 [Run LocalAI on Jetson Nano Devkit](https://mudler.pm/posts/local-ai-jetson-nano-devkit/)
- [Run LocalAI on AWS EKS with Pulumi](https://www.pulumi.com/blog/low-code-llm-apps-with-local-ai-flowise-and-pulumi/)
- [Run LocalAI on AWS](https://staleks.hashnode.dev/installing-localai-on-aws-ec2-instance)
- [Create a slackbot for teams and OSS projects that answer to documentation](https://mudler.pm/posts/smart-slackbot-for-teams/)
- [LocalAI meets k8sgpt](https://www.youtube.com/watch?v=PKrDNuJ_dfE)
- [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/)
- [Tutorial to use k8sgpt with LocalAI](https://medium.com/@tyler_97636/k8sgpt-localai-unlock-kubernetes-superpowers-for-free-584790de9b65)
## 🤖 Autonomous Development Team
LocalAI is now helped being maintained (for small tasks!) by a full team of autonomous AI agents led by an AI Scrum Master! This experiment demonstrates how open source projects can leverage AI agents for sustainable, long-term maintenance.
- **📊 Live Reports**: [Automatically generated reports](http://reports.localai.io)
- **📋 Project Board**: [Agent task tracking](https://github.com/users/mudler/projects/6)
- **📝 Blog Post**: [Learn about the autonomous dev team experiment](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/)
## Citation ## Citation
@@ -419,7 +199,7 @@ If you utilize this repository, data in a downstream project, please consider ci
howpublished = {\url{https://github.com/go-skynet/LocalAI}}, howpublished = {\url{https://github.com/go-skynet/LocalAI}},
``` ```
## ❤️ Sponsors ## Sponsors
> Do you find LocalAI useful? > Do you find LocalAI useful?
@@ -438,19 +218,19 @@ A huge thank you to our generous sponsors who support this project covering CI e
### Individual sponsors ### Individual sponsors
A special thanks to individual sponsors that contributed to the project, a full list is in [Github](https://github.com/sponsors/mudler) and [buymeacoffee](https://buymeacoffee.com/mudler), a special shout out goes to [drikster80](https://github.com/drikster80) for being generous. Thank you everyone! A special thanks to individual sponsors, a full list is on [GitHub](https://github.com/sponsors/mudler) and [buymeacoffee](https://buymeacoffee.com/mudler). Special shout out to [drikster80](https://github.com/drikster80) for being generous. Thank you everyone!
## 🌟 Star history ## Star history
[![LocalAI Star history Chart](https://api.star-history.com/svg?repos=go-skynet/LocalAI&type=Date)](https://star-history.com/#go-skynet/LocalAI&Date) [![LocalAI Star history Chart](https://api.star-history.com/svg?repos=go-skynet/LocalAI&type=Date)](https://star-history.com/#go-skynet/LocalAI&Date)
## 📖 License ## License
LocalAI is a community-driven project created by [Ettore Di Giacinto](https://github.com/mudler/). LocalAI is a community-driven project created by [Ettore Di Giacinto](https://github.com/mudler/).
MIT - Author Ettore Di Giacinto <mudler@localai.io> MIT - Author Ettore Di Giacinto <mudler@localai.io>
## 🙇 Acknowledgements ## Acknowledgements
LocalAI couldn't have been built without the help of great software already available from the community. Thank you! LocalAI couldn't have been built without the help of great software already available from the community. Thank you!
@@ -463,9 +243,9 @@ LocalAI couldn't have been built without the help of great software already avai
- https://github.com/rhasspy/piper - https://github.com/rhasspy/piper
- [exo](https://github.com/exo-explore/exo) for the MLX distributed auto-parallel sharding implementation - [exo](https://github.com/exo-explore/exo) for the MLX distributed auto-parallel sharding implementation
## 🤗 Contributors ## Contributors
This is a community project, a special thanks to our contributors! 🤗 This is a community project, a special thanks to our contributors!
<a href="https://github.com/go-skynet/LocalAI/graphs/contributors"> <a href="https://github.com/go-skynet/LocalAI/graphs/contributors">
<img src="https://contrib.rocks/image?repo=go-skynet/LocalAI" /> <img src="https://contrib.rocks/image?repo=go-skynet/LocalAI" />
</a> </a>

View File

@@ -39,6 +39,19 @@ service Backend {
rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {} rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {}
rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {} rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
// Fine-tuning RPCs
rpc StartFineTune(FineTuneRequest) returns (FineTuneJobResult) {}
rpc FineTuneProgress(FineTuneProgressRequest) returns (stream FineTuneProgressUpdate) {}
rpc StopFineTune(FineTuneStopRequest) returns (Result) {}
rpc ListCheckpoints(ListCheckpointsRequest) returns (ListCheckpointsResponse) {}
rpc ExportModel(ExportModelRequest) returns (Result) {}
// Quantization RPCs
rpc StartQuantization(QuantizationRequest) returns (QuantizationJobResult) {}
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
} }
// Define the empty request // Define the empty request
@@ -166,6 +179,7 @@ message PredictOptions {
int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter) int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter) int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
map<string, string> Metadata = 52; // Generic per-request metadata (e.g., enable_thinking) map<string, string> Metadata = 52; // Generic per-request metadata (e.g., enable_thinking)
float MinP = 53; // Minimum probability sampling threshold (0.0 = disabled)
} }
// ToolCallDelta represents an incremental tool call update from the C++ parser. // ToolCallDelta represents an incremental tool call update from the C++ parser.
@@ -472,7 +486,7 @@ message ToolFormatMarkers {
string id_field = 16; // e.g., "id" string id_field = 16; // e.g., "id"
bool fun_name_is_key = 17; bool fun_name_is_key = 17;
bool tools_array_wrapped = 18; bool tools_array_wrapped = 18;
bool uses_python_dicts = 19; reserved 19;
// Reasoning markers // Reasoning markers
string reasoning_start = 20; // e.g., "<think>" string reasoning_start = 20; // e.g., "<think>"
@@ -528,3 +542,139 @@ message ModelMetadataResponse {
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable) string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis
} }
// Fine-tuning messages
message FineTuneRequest {
// Model identification
string model = 1; // HF model name or local path
string training_type = 2; // "lora", "loha", "lokr", "full" — what parameters to train
string training_method = 3; // "sft", "dpo", "grpo", "rloo", "reward", "kto", "orpo", "network_training"
// Adapter config (universal across LoRA/LoHa/LoKr for LLM + diffusion)
int32 adapter_rank = 10; // LoRA rank (r), default 16
int32 adapter_alpha = 11; // scaling factor, default 16
float adapter_dropout = 12; // default 0.0
repeated string target_modules = 13; // layer names to adapt
// Universal training hyperparameters
float learning_rate = 20; // default 2e-4
int32 num_epochs = 21; // default 3
int32 batch_size = 22; // default 2
int32 gradient_accumulation_steps = 23; // default 4
int32 warmup_steps = 24; // default 5
int32 max_steps = 25; // 0 = use epochs
int32 save_steps = 26; // 0 = only save final
float weight_decay = 27; // default 0.01
bool gradient_checkpointing = 28;
string optimizer = 29; // adamw_8bit, adamw, sgd, adafactor, prodigy
int32 seed = 30; // default 3407
string mixed_precision = 31; // fp16, bf16, fp8, no
// Dataset
string dataset_source = 40; // HF dataset ID, local file/dir path
string dataset_split = 41; // train, test, etc.
// Output
string output_dir = 50;
string job_id = 51; // client-assigned or auto-generated
// Resume training from a checkpoint
string resume_from_checkpoint = 55; // path to checkpoint dir to resume from
// Backend-specific AND method-specific extensibility
map<string, string> extra_options = 60;
}
message FineTuneJobResult {
string job_id = 1;
bool success = 2;
string message = 3;
}
message FineTuneProgressRequest {
string job_id = 1;
}
message FineTuneProgressUpdate {
string job_id = 1;
int32 current_step = 2;
int32 total_steps = 3;
float current_epoch = 4;
float total_epochs = 5;
float loss = 6;
float learning_rate = 7;
float grad_norm = 8;
float eval_loss = 9;
float eta_seconds = 10;
float progress_percent = 11;
string status = 12; // queued, caching, loading_model, loading_dataset, training, saving, completed, failed, stopped
string message = 13;
string checkpoint_path = 14; // set when a checkpoint is saved
string sample_path = 15; // set when a sample is generated (video/image backends)
map<string, float> extra_metrics = 16; // method-specific metrics
}
message FineTuneStopRequest {
string job_id = 1;
bool save_checkpoint = 2;
}
message ListCheckpointsRequest {
string output_dir = 1;
}
message ListCheckpointsResponse {
repeated CheckpointInfo checkpoints = 1;
}
message CheckpointInfo {
string path = 1;
int32 step = 2;
float epoch = 3;
float loss = 4;
string created_at = 5;
}
message ExportModelRequest {
string checkpoint_path = 1;
string output_path = 2;
string export_format = 3; // lora, loha, lokr, merged_16bit, merged_4bit, gguf, diffusers
string quantization_method = 4; // for GGUF: q4_k_m, q5_k_m, q8_0, f16, etc.
string model = 5; // base model name (for merge operations)
map<string, string> extra_options = 6;
}
// Quantization messages
message QuantizationRequest {
string model = 1; // HF model name or local path
string quantization_type = 2; // q4_k_m, q5_k_m, q8_0, f16, etc.
string output_dir = 3; // where to write output files
string job_id = 4; // client-assigned job ID
map<string, string> extra_options = 5; // hf_token, custom flags, etc.
}
message QuantizationJobResult {
string job_id = 1;
bool success = 2;
string message = 3;
}
message QuantizationProgressRequest {
string job_id = 1;
}
message QuantizationProgressUpdate {
string job_id = 1;
float progress_percent = 2;
string status = 3; // queued, downloading, converting, quantizing, completed, failed, stopped
string message = 4;
string output_file = 5; // set when completed — path to the output GGUF file
map<string, float> extra_metrics = 6; // e.g. file_size_mb, compression_ratio
}
message QuantizationStopRequest {
string job_id = 1;
}

View File

@@ -1,5 +1,5 @@
LLAMA_VERSION?=e30f1fdf74ea9238ff562901aa974c75aab6619b LLAMA_VERSION?=95a6ebabb277c4cc18247e7bc2a5502133caca63
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
CMAKE_ARGS?= CMAKE_ARGS?=

View File

@@ -22,8 +22,10 @@
#include <grpcpp/ext/proto_server_reflection_plugin.h> #include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <grpcpp/grpcpp.h> #include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h> #include <grpcpp/health_check_service_interface.h>
#include <grpcpp/security/server_credentials.h>
#include <regex> #include <regex>
#include <atomic> #include <atomic>
#include <cstdlib>
#include <mutex> #include <mutex>
#include <signal.h> #include <signal.h>
#include <thread> #include <thread>
@@ -37,6 +39,47 @@ using grpc::Server;
using grpc::ServerBuilder; using grpc::ServerBuilder;
using grpc::ServerContext; using grpc::ServerContext;
using grpc::Status; using grpc::Status;
// gRPC bearer token auth via AuthMetadataProcessor for distributed mode.
// Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects
// requests without a matching "authorization: Bearer <token>" metadata header.
class TokenAuthMetadataProcessor : public grpc::AuthMetadataProcessor {
public:
explicit TokenAuthMetadataProcessor(const std::string& token) : token_(token) {}
bool IsBlocking() const override { return false; }
grpc::Status Process(const InputMetadata& auth_metadata,
grpc::AuthContext* /*context*/,
OutputMetadata* /*consumed_auth_metadata*/,
OutputMetadata* /*response_metadata*/) override {
auto it = auth_metadata.find("authorization");
if (it != auth_metadata.end()) {
std::string expected = "Bearer " + token_;
std::string got(it->second.data(), it->second.size());
// Constant-time comparison
if (expected.size() == got.size() && ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
return grpc::Status::OK;
}
}
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
}
private:
std::string token_;
// Minimal constant-time comparison (avoids OpenSSL dependency)
static int ct_memcmp(const void* a, const void* b, size_t n) {
const unsigned char* pa = static_cast<const unsigned char*>(a);
const unsigned char* pb = static_cast<const unsigned char*>(b);
unsigned char result = 0;
for (size_t i = 0; i < n; i++) {
result |= pa[i] ^ pb[i];
}
return result;
}
};
// END LocalAI // END LocalAI
@@ -136,6 +179,7 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
data["mirostat_eta"] = predict->mirostateta(); data["mirostat_eta"] = predict->mirostateta();
data["n_keep"] = predict->nkeep(); data["n_keep"] = predict->nkeep();
data["seed"] = predict->seed(); data["seed"] = predict->seed();
data["min_p"] = predict->minp();
std::string grammar_str = predict->grammar(); std::string grammar_str = predict->grammar();
@@ -2687,7 +2731,6 @@ public:
tf->set_id_field(ap.tools.format.id_field); tf->set_id_field(ap.tools.format.id_field);
tf->set_fun_name_is_key(ap.tools.format.fun_name_is_key); tf->set_fun_name_is_key(ap.tools.format.fun_name_is_key);
tf->set_tools_array_wrapped(ap.tools.format.tools_array_wrapped); tf->set_tools_array_wrapped(ap.tools.format.tools_array_wrapped);
tf->set_uses_python_dicts(ap.tools.format.uses_python_dicts);
tf->set_function_field(ap.tools.format.function_field); tf->set_function_field(ap.tools.format.function_field);
tf->set_gen_id_field(ap.tools.format.gen_id_field); tf->set_gen_id_field(ap.tools.format.gen_id_field);
@@ -2760,11 +2803,24 @@ int main(int argc, char** argv) {
BackendServiceImpl service(ctx_server); BackendServiceImpl service(ctx_server);
ServerBuilder builder; ServerBuilder builder;
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); // Add bearer token auth via AuthMetadataProcessor if LOCALAI_GRPC_AUTH_TOKEN is set
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
std::shared_ptr<grpc::ServerCredentials> creds;
if (auth_token != nullptr && auth_token[0] != '\0') {
creds = grpc::InsecureServerCredentials();
creds->SetAuthMetadataProcessor(
std::make_shared<TokenAuthMetadataProcessor>(auth_token));
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
} else {
creds = grpc::InsecureServerCredentials();
}
builder.AddListeningPort(server_address, creds);
builder.RegisterService(&service); builder.RegisterService(&service);
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
std::unique_ptr<Server> server(builder.BuildAndStart()); std::unique_ptr<Server> server(builder.BuildAndStart());
// run the HTTP server in a thread - see comment below // run the HTTP server in a thread - see comment below
std::thread t([&]() std::thread t([&]()

View File

@@ -24,6 +24,9 @@ if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6 cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1 cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
# ARM64 architecture # ARM64 architecture
echo "Detected ARM64 architecture, copying ARM64 libraries..." echo "Detected ARM64 architecture, copying ARM64 libraries..."
@@ -33,6 +36,9 @@ elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6 cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1 cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
else else
echo "Error: Could not detect architecture" echo "Error: Could not detect architecture"
exit 1 exit 1

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# acestep.cpp version # acestep.cpp version
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
ACESTEP_CPP_VERSION?=5aa065445541094cba934299cd498bbb9fa5c434 ACESTEP_CPP_VERSION?=6f35c874ee11e86d511b860019b84976f5b52d3a
SO_TARGET?=libgoacestepcpp.so SO_TARGET?=libgoacestepcpp.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -106,12 +106,13 @@ func TestLoadModel(t *testing.T) {
defer conn.Close() defer conn.Close()
client := pb.NewBackendClient(conn) client := pb.NewBackendClient(conn)
// Get base directory from main model file for relative paths // Get base directory from main model file for relative paths
mainModelPath := filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf") mainModelPath := filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf")
resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{ resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
ModelFile: mainModelPath, ModelFile: mainModelPath,
ModelPath: modelDir,
Options: []string{ Options: []string{
"text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf", "text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
"dit_model:acestep-v15-turbo-Q8_0.gguf", "dit_model:acestep-v15-turbo-Q8_0.gguf",
@@ -133,7 +134,7 @@ func TestSoundGeneration(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer os.RemoveAll(tmpDir) t.Cleanup(func() { os.RemoveAll(tmpDir) })
outputFile := filepath.Join(tmpDir, "output.wav") outputFile := filepath.Join(tmpDir, "output.wav")
@@ -151,6 +152,7 @@ func TestSoundGeneration(t *testing.T) {
// Load models // Load models
loadResp, err := client.LoadModel(context.Background(), &pb.ModelOptions{ loadResp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
ModelFile: mainModelPath, ModelFile: mainModelPath,
ModelPath: modelDir,
Options: []string{ Options: []string{
"text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf", "text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
"dit_model:acestep-v15-turbo-Q8_0.gguf", "dit_model:acestep-v15-turbo-Q8_0.gguf",

View File

@@ -11,7 +11,7 @@ import (
) )
var ( var (
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
) )
@@ -24,23 +24,23 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
lmModel := opts.ModelFile lmModel := opts.ModelFile
// Get the base directory from ModelFile for resolving relative paths // Get the base directory from ModelFile for resolving relative paths
baseDir := filepath.Dir(lmModel) baseDir := opts.ModelPath
var textEncoderModel, ditModel, vaeModel string var textEncoderModel, ditModel, vaeModel string
for _, oo := range opts.Options { for _, oo := range opts.Options {
parts := strings.SplitN(oo, ":", 2) key, value, found := strings.Cut(oo, ":")
if len(parts) != 2 { if !found {
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo) fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
continue continue
} }
switch parts[0] { switch key {
case "text_encoder_model": case "text_encoder_model":
textEncoderModel = parts[1] textEncoderModel = value
case "dit_model": case "dit_model":
ditModel = parts[1] ditModel = value
case "vae_model": case "vae_model":
vaeModel = parts[1] vaeModel = value
default: default:
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo) fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
} }

View File

@@ -18,7 +18,6 @@ type LLM struct {
draftModel *llama.LLama draftModel *llama.LLama
} }
// Free releases GPU resources and frees the llama model // Free releases GPU resources and frees the llama model
// This should be called when the model is being unloaded to properly release VRAM // This should be called when the model is being unloaded to properly release VRAM
func (llm *LLM) Free() error { func (llm *LLM) Free() error {

View File

@@ -1,5 +1,4 @@
//go:build debug //go:build debug
// +build debug
package main package main

View File

@@ -1,5 +1,4 @@
//go:build !debug //go:build !debug
// +build !debug
package main package main

View File

@@ -332,7 +332,7 @@ func normalizedCosineSimilarity(k1, k2 []float32) float32 {
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
var dot float32 var dot float32
for i := 0; i < len(k1); i++ { for i := range len(k1) {
dot += k1[i] * k2[i] dot += k1[i] * k2[i]
} }
@@ -419,7 +419,7 @@ func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
var dot, mag2 float64 var dot, mag2 float64
for i := 0; i < len(k1); i++ { for i := range len(k1) {
dot += float64(k1[i] * k2[i]) dot += float64(k1[i] * k2[i])
mag2 += float64(k2[i] * k2[i]) mag2 += float64(k2[i] * k2[i])
} }

View File

@@ -701,7 +701,7 @@ var _ = Describe("Opus", func() {
// to one-shot (only difference is resampler batch boundaries). // to one-shot (only difference is resampler batch boundaries).
var maxDiff float64 var maxDiff float64
var sumDiffSq float64 var sumDiffSq float64
for i := 0; i < minLen; i++ { for i := range minLen {
diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i])) diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i]))
if diff > maxDiff { if diff > maxDiff {
maxDiff = diff maxDiff = diff
@@ -774,7 +774,7 @@ var _ = Describe("Opus", func() {
minLen := min(len(refTail), min(len(persistentTail), len(freshTail))) minLen := min(len(refTail), min(len(persistentTail), len(freshTail)))
var persistentMaxDiff, freshMaxDiff float64 var persistentMaxDiff, freshMaxDiff float64
for i := 0; i < minLen; i++ { for i := range minLen {
pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i])) pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i]))
fd := math.Abs(float64(refTail[i]) - float64(freshTail[i])) fd := math.Abs(float64(refTail[i]) - float64(freshTail[i]))
if pd > persistentMaxDiff { if pd > persistentMaxDiff {
@@ -932,7 +932,7 @@ var _ = Describe("Opus", func() {
GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n", GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n",
mean, stddev, stddev/mean, 16000.0/440.0/2.0) mean, stddev, stddev/mean, 16000.0/440.0/2.0)
Expect(stddev / mean).To(BeNumerically("<", 0.15), Expect(stddev/mean).To(BeNumerically("<", 0.15),
fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean)) fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean))
// Also check frequency is correct // Also check frequency is correct
@@ -978,7 +978,7 @@ var _ = Describe("Opus", func() {
// Every sample must be identical — the resampler is deterministic // Every sample must be identical — the resampler is deterministic
var maxDiff float64 var maxDiff float64
for i := 0; i < len(oneShot); i++ { for i := range len(oneShot) {
diff := math.Abs(float64(oneShot[i]) - float64(batched[i])) diff := math.Abs(float64(oneShot[i]) - float64(batched[i]))
if diff > maxDiff { if diff > maxDiff {
maxDiff = diff maxDiff = diff
@@ -1037,13 +1037,13 @@ var _ = Describe("Opus", func() {
binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen)) binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen))
copy(hdr[8:12], "WAVE") copy(hdr[8:12], "WAVE")
copy(hdr[12:16], "fmt ") copy(hdr[12:16], "fmt ")
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
copy(hdr[36:40], "data") copy(hdr[36:40], "data")
binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen)) binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen))
@@ -1126,7 +1126,7 @@ var _ = Describe("Opus", func() {
) )
pcm := make([]byte, toneNumSamples*2) pcm := make([]byte, toneNumSamples*2)
for i := 0; i < toneNumSamples; i++ { for i := range toneNumSamples {
sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate))) sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate)))
binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample)) binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample))
} }

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# stablediffusion.cpp (ggml) # stablediffusion.cpp (ggml)
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
STABLEDIFFUSION_GGML_VERSION?=d6dd6d7b555c233bb9bc9f20b4751eb8c9269743 STABLEDIFFUSION_GGML_VERSION?=87ecb95cbc65dc8e58e3d88f4f4a59a0939796f5
CMAKE_ARGS+=-DGGML_MAX_NAME=128 CMAKE_ARGS+=-DGGML_MAX_NAME=128

View File

@@ -27,107 +27,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <regex> #include <regex>
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
const char* sample_method_str[] = {
"euler",
"euler_a",
"heun",
"dpm2",
"dpm++2s_a",
"dpm++2m",
"dpm++2mv2",
"ipndm",
"ipndm_v",
"lcm",
"ddim_trailing",
"tcd",
"res_multistep",
"res_2s",
};
static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch");
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
const char* schedulers[] = {
"discrete",
"karras",
"exponential",
"ays",
"gits",
"sgm_uniform",
"simple",
"smoothstep",
"kl_optimal",
"lcm",
"bong_tangent",
};
static_assert(std::size(schedulers) == SCHEDULER_COUNT, "schedulers mismatch");
// New enum string arrays
const char* rng_type_str[] = {
"std_default",
"cuda",
"cpu",
};
static_assert(std::size(rng_type_str) == RNG_TYPE_COUNT, "rng type mismatch");
const char* prediction_str[] = {
"epsilon",
"v",
"edm_v",
"flow",
"flux_flow",
"flux2_flow",
};
static_assert(std::size(prediction_str) == PREDICTION_COUNT, "prediction mismatch");
const char* lora_apply_mode_str[] = {
"auto",
"immediately",
"at_runtime",
};
static_assert(std::size(lora_apply_mode_str) == LORA_APPLY_MODE_COUNT, "lora apply mode mismatch");
constexpr const char* sd_type_str[] = {
"f32", // 0
"f16", // 1
"q4_0", // 2
"q4_1", // 3
nullptr, // 4
nullptr, // 5
"q5_0", // 6
"q5_1", // 7
"q8_0", // 8
"q8_1", // 9
"q2_k", // 10
"q3_k", // 11
"q4_k", // 12
"q5_k", // 13
"q6_k", // 14
"q8_k", // 15
"iq2_xxs", // 16
"iq2_xs", // 17
"iq3_xxs", // 18
"iq1_s", // 19
"iq4_nl", // 20
"iq3_s", // 21
"iq2_s", // 22
"iq4_xs", // 23
"i8", // 24
"i16", // 25
"i32", // 26
"i64", // 27
"f64", // 28
"iq1_m", // 29
"bf16", // 30
nullptr, nullptr, nullptr, // 31-33
"tq1_0", // 34
"tq2_0", // 35
nullptr, nullptr, nullptr, // 36-38
"mxfp4" // 39
};
static_assert(std::size(sd_type_str) == SD_TYPE_COUNT, "sd type mismatch");
sd_ctx_params_t ctx_params; sd_ctx_params_t ctx_params;
sd_ctx_t* sd_c; sd_ctx_t* sd_c;
@@ -596,75 +496,45 @@ int load_model(const char *model, char *model_path, char* options[], int threads
if (!strcmp(optname, "flow_shift")) flow_shift = atof(optval); if (!strcmp(optname, "flow_shift")) flow_shift = atof(optval);
if (!strcmp(optname, "rng_type")) { if (!strcmp(optname, "rng_type")) {
int found = -1; rng_type_t parsed = str_to_rng_type(optval);
for (int m = 0; m < RNG_TYPE_COUNT; m++) { if (parsed != RNG_TYPE_COUNT) {
if (!strcmp(optval, rng_type_str[m])) { rng_type = parsed;
found = m;
break;
}
}
if (found != -1) {
rng_type = (rng_type_t)found;
fprintf(stderr, "Found rng_type: %s\n", optval); fprintf(stderr, "Found rng_type: %s\n", optval);
} else { } else {
fprintf(stderr, "Invalid rng_type: %s, using default\n", optval); fprintf(stderr, "Invalid rng_type: %s, using default\n", optval);
} }
} }
if (!strcmp(optname, "sampler_rng_type")) { if (!strcmp(optname, "sampler_rng_type")) {
int found = -1; rng_type_t parsed = str_to_rng_type(optval);
for (int m = 0; m < RNG_TYPE_COUNT; m++) { if (parsed != RNG_TYPE_COUNT) {
if (!strcmp(optval, rng_type_str[m])) { sampler_rng_type = parsed;
found = m;
break;
}
}
if (found != -1) {
sampler_rng_type = (rng_type_t)found;
fprintf(stderr, "Found sampler_rng_type: %s\n", optval); fprintf(stderr, "Found sampler_rng_type: %s\n", optval);
} else { } else {
fprintf(stderr, "Invalid sampler_rng_type: %s, using default\n", optval); fprintf(stderr, "Invalid sampler_rng_type: %s, using default\n", optval);
} }
} }
if (!strcmp(optname, "prediction")) { if (!strcmp(optname, "prediction")) {
int found = -1; prediction_t parsed = str_to_prediction(optval);
for (int m = 0; m < PREDICTION_COUNT; m++) { if (parsed != PREDICTION_COUNT) {
if (!strcmp(optval, prediction_str[m])) { prediction = parsed;
found = m;
break;
}
}
if (found != -1) {
prediction = (prediction_t)found;
fprintf(stderr, "Found prediction: %s\n", optval); fprintf(stderr, "Found prediction: %s\n", optval);
} else { } else {
fprintf(stderr, "Invalid prediction: %s, using default\n", optval); fprintf(stderr, "Invalid prediction: %s, using default\n", optval);
} }
} }
if (!strcmp(optname, "lora_apply_mode")) { if (!strcmp(optname, "lora_apply_mode")) {
int found = -1; lora_apply_mode_t parsed = str_to_lora_apply_mode(optval);
for (int m = 0; m < LORA_APPLY_MODE_COUNT; m++) { if (parsed != LORA_APPLY_MODE_COUNT) {
if (!strcmp(optval, lora_apply_mode_str[m])) { lora_apply_mode = parsed;
found = m;
break;
}
}
if (found != -1) {
lora_apply_mode = (lora_apply_mode_t)found;
fprintf(stderr, "Found lora_apply_mode: %s\n", optval); fprintf(stderr, "Found lora_apply_mode: %s\n", optval);
} else { } else {
fprintf(stderr, "Invalid lora_apply_mode: %s, using default\n", optval); fprintf(stderr, "Invalid lora_apply_mode: %s, using default\n", optval);
} }
} }
if (!strcmp(optname, "wtype")) { if (!strcmp(optname, "wtype")) {
int found = -1; sd_type_t parsed = str_to_sd_type(optval);
for (int m = 0; m < SD_TYPE_COUNT; m++) { if (parsed != SD_TYPE_COUNT) {
if (sd_type_str[m] && !strcmp(optval, sd_type_str[m])) { wtype = parsed;
found = m;
break;
}
}
if (found != -1) {
wtype = (sd_type_t)found;
fprintf(stderr, "Found wtype: %s\n", optval); fprintf(stderr, "Found wtype: %s\n", optval);
} else { } else {
fprintf(stderr, "Invalid wtype: %s, using default\n", optval); fprintf(stderr, "Invalid wtype: %s, using default\n", optval);
@@ -735,27 +605,25 @@ int load_model(const char *model, char *model_path, char* options[], int threads
fprintf (stderr, "Created context: OK\n"); fprintf (stderr, "Created context: OK\n");
int sample_method_found = -1; int sample_method_found = -1;
for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) { sample_method_t sm = str_to_sample_method(sampler);
if (!strcmp(sampler, sample_method_str[m])) { if (sm != SAMPLE_METHOD_COUNT) {
sample_method_found = m; sample_method_found = (int)sm;
fprintf(stderr, "Found sampler: %s\n", sampler); fprintf(stderr, "Found sampler: %s\n", sampler);
}
} }
if (sample_method_found == -1) { if (sample_method_found == -1) {
sample_method_found = sd_get_default_sample_method(sd_ctx); sample_method_found = sd_get_default_sample_method(sd_ctx);
fprintf(stderr, "Invalid sample method, using default: %s\n", sample_method_str[sample_method_found]); fprintf(stderr, "Invalid sample method, using default: %s\n", sd_sample_method_name((sample_method_t)sample_method_found));
} }
sample_method = (sample_method_t)sample_method_found; sample_method = (sample_method_t)sample_method_found;
for (int d = 0; d < SCHEDULER_COUNT; d++) { scheduler_t sched = str_to_scheduler(scheduler_str);
if (!strcmp(scheduler_str, schedulers[d])) { if (sched != SCHEDULER_COUNT) {
scheduler = (scheduler_t)d; scheduler = sched;
fprintf (stderr, "Found scheduler: %s\n", scheduler_str); fprintf(stderr, "Found scheduler: %s\n", scheduler_str);
}
} }
if (scheduler == SCHEDULER_COUNT) { if (scheduler == SCHEDULER_COUNT) {
scheduler = sd_get_default_scheduler(sd_ctx, sample_method); scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
fprintf(stderr, "Invalid scheduler, using default: %s\n", schedulers[scheduler]); fprintf(stderr, "Invalid scheduler, using default: %s\n", sd_scheduler_name(scheduler));
} }
sd_c = sd_ctx; sd_c = sd_ctx;

View File

@@ -138,7 +138,7 @@ func TestAudioTranscription(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer os.RemoveAll(tmpDir) t.Cleanup(func() { os.RemoveAll(tmpDir) })
// Download sample audio — JFK "ask not what your country can do for you" clip // Download sample audio — JFK "ask not what your country can do for you" clip
audioFile := filepath.Join(tmpDir, "sample.wav") audioFile := filepath.Join(tmpDir, "sample.wav")

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# whisper.cpp version # whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
WHISPER_CPP_VERSION?=30c5194c9691e4e9a98b3dea9f19727397d3f46e WHISPER_CPP_VERSION?=95ea8f9bfb03a15db08a8989966fd1ae3361e20d
SO_TARGET?=libgowhisper.so SO_TARGET?=libgowhisper.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -726,6 +726,7 @@
- TTS - TTS
- &opus - &opus
name: "opus" name: "opus"
alias: "opus"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-opus" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-opus"
urls: urls:
- https://opus-codec.org/ - https://opus-codec.org/
@@ -3029,3 +3030,82 @@
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral"
mirrors: mirrors:
- localai/localai-backends:master-metal-darwin-arm64-voxtral - localai/localai-backends:master-metal-darwin-arm64-voxtral
- &trl
name: "trl"
alias: "trl"
license: apache-2.0
description: |
HuggingFace TRL fine-tuning backend. Supports SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO training methods.
Works on CPU and GPU.
urls:
- https://github.com/huggingface/trl
tags:
- fine-tuning
- LLM
- CPU
- GPU
- CUDA
capabilities:
default: "cpu-trl"
nvidia: "cuda12-trl"
nvidia-cuda-12: "cuda12-trl"
nvidia-cuda-13: "cuda13-trl"
## TRL backend images
- !!merge <<: *trl
name: "cpu-trl"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-trl"
mirrors:
- localai/localai-backends:latest-cpu-trl
- !!merge <<: *trl
name: "cpu-trl-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-trl"
mirrors:
- localai/localai-backends:master-cpu-trl
- !!merge <<: *trl
name: "cuda12-trl"
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda12-trl"
mirrors:
- localai/localai-backends:latest-cublas-cuda12-trl
- !!merge <<: *trl
name: "cuda12-trl-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda12-trl"
mirrors:
- localai/localai-backends:master-cublas-cuda12-trl
- !!merge <<: *trl
name: "cuda13-trl"
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda13-trl"
mirrors:
- localai/localai-backends:latest-cublas-cuda13-trl
- !!merge <<: *trl
name: "cuda13-trl-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda13-trl"
mirrors:
- localai/localai-backends:master-cublas-cuda13-trl
## llama.cpp quantization backend
- &llama-cpp-quantization
name: "llama-cpp-quantization"
alias: "llama-cpp-quantization"
license: mit
icon: https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png
description: |
Model quantization backend using llama.cpp. Downloads HuggingFace models, converts them to GGUF format,
and quantizes them to various formats (q4_k_m, q5_k_m, q8_0, f16, etc.).
urls:
- https://github.com/ggml-org/llama.cpp
tags:
- quantization
- GGUF
- CPU
capabilities:
default: "cpu-llama-cpp-quantization"
metal: "metal-darwin-arm64-llama-cpp-quantization"
- !!merge <<: *llama-cpp-quantization
name: "cpu-llama-cpp-quantization"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-llama-cpp-quantization"
mirrors:
- localai/localai-backends:latest-cpu-llama-cpp-quantization
- !!merge <<: *llama-cpp-quantization
name: "metal-darwin-arm64-llama-cpp-quantization"
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp-quantization"
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-llama-cpp-quantization

View File

@@ -19,6 +19,10 @@ import tempfile
import backend_pb2 import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from acestep.inference import ( from acestep.inference import (
GenerationParams, GenerationParams,
GenerationConfig, GenerationConfig,
@@ -444,6 +448,8 @@ def serve(address):
("grpc.max_send_message_length", 50 * 1024 * 1024), ("grpc.max_send_message_length", 50 * 1024 * 1024),
("grpc.max_receive_message_length", 50 * 1024 * 1024), ("grpc.max_receive_message_length", 50 * 1024 * 1024),
], ],
interceptors=get_auth_interceptors(),
) )
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)

View File

@@ -16,6 +16,10 @@ import torchaudio as ta
from chatterbox.tts import ChatterboxTTS from chatterbox.tts import ChatterboxTTS
from chatterbox.mtl_tts import ChatterboxMultilingualTTS from chatterbox.mtl_tts import ChatterboxMultilingualTTS
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import tempfile import tempfile
def is_float(s): def is_float(s):
@@ -225,7 +229,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -0,0 +1,78 @@
"""Shared gRPC bearer token authentication interceptor for LocalAI Python backends.
When the environment variable LOCALAI_GRPC_AUTH_TOKEN is set, requests without
a valid Bearer token in the 'authorization' metadata header are rejected with
UNAUTHENTICATED. When the variable is empty or unset, no authentication is
performed (backward compatible).
"""
import hmac
import os
import grpc
class _AbortHandler(grpc.RpcMethodHandler):
"""A method handler that immediately aborts with UNAUTHENTICATED."""
def __init__(self):
self.request_streaming = False
self.response_streaming = False
self.request_deserializer = None
self.response_serializer = None
self.unary_unary = self._abort
self.unary_stream = None
self.stream_unary = None
self.stream_stream = None
@staticmethod
def _abort(request, context):
context.abort(grpc.StatusCode.UNAUTHENTICATED, "invalid token")
class TokenAuthInterceptor(grpc.ServerInterceptor):
"""Sync gRPC server interceptor that validates a bearer token."""
def __init__(self, token: str):
self._token = token
self._abort_handler = _AbortHandler()
def intercept_service(self, continuation, handler_call_details):
metadata = dict(handler_call_details.invocation_metadata)
auth = metadata.get("authorization", "")
expected = "Bearer " + self._token
if not hmac.compare_digest(auth, expected):
return self._abort_handler
return continuation(handler_call_details)
class AsyncTokenAuthInterceptor(grpc.aio.ServerInterceptor):
"""Async gRPC server interceptor that validates a bearer token."""
def __init__(self, token: str):
self._token = token
async def intercept_service(self, continuation, handler_call_details):
metadata = dict(handler_call_details.invocation_metadata)
auth = metadata.get("authorization", "")
expected = "Bearer " + self._token
if not hmac.compare_digest(auth, expected):
return _AbortHandler()
return await continuation(handler_call_details)
def get_auth_interceptors(*, aio: bool = False):
"""Return a list of gRPC interceptors for bearer token auth.
Args:
aio: If True, return async-compatible interceptors for grpc.aio.server().
If False (default), return sync interceptors for grpc.server().
Returns an empty list when LOCALAI_GRPC_AUTH_TOKEN is not set.
"""
token = os.environ.get("LOCALAI_GRPC_AUTH_TOKEN", "")
if not token:
return []
if aio:
return [AsyncTokenAuthInterceptor(token)]
return [TokenAuthInterceptor(token)]

View File

@@ -1,3 +1,3 @@
grpcio==1.78.1 grpcio==1.80.0
protobuf protobuf
grpcio-tools grpcio-tools

View File

@@ -15,6 +15,10 @@ import torch
from TTS.api import TTS from TTS.api import TTS
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -93,7 +97,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -1,4 +1,6 @@
--extra-index-url https://download.pytorch.org/whl/cpu
transformers==4.48.3 transformers==4.48.3
accelerate accelerate
torch==2.4.1 torch==2.4.1
torchaudio==2.4.1
coqui-tts coqui-tts

View File

@@ -1,4 +1,4 @@
grpcio==1.78.1 grpcio==1.80.0
protobuf protobuf
certifi certifi
packaging==24.1 packaging==24.1

View File

@@ -22,6 +22,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
# Import dynamic loader for pipeline discovery # Import dynamic loader for pipeline discovery
from diffusers_dynamic_loader import ( from diffusers_dynamic_loader import (
@@ -1042,7 +1046,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -15,6 +15,10 @@ import torch
import soundfile as sf import soundfile as sf
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
@@ -165,6 +169,8 @@ def serve(address):
('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024),
] ]
,
interceptors=get_auth_interceptors(),
) )
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)

View File

@@ -14,6 +14,10 @@ import torch
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -70,7 +74,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -19,6 +19,10 @@ import numpy as np
import json import json
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
@@ -424,6 +428,8 @@ def serve(address):
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
], ],
interceptors=get_auth_interceptors(),
) )
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)

View File

@@ -16,6 +16,10 @@ from kittentts import KittenTTS
import soundfile as sf import soundfile as sf
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -77,7 +81,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -16,6 +16,10 @@ from kokoro import KPipeline
import soundfile as sf import soundfile as sf
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -84,7 +88,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -21,3 +21,8 @@ if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
fi fi
installRequirements installRequirements
# spaCy is a dependency of misaki (used by kokoro for English phonemization).
# Pre-download the model here because at runtime the portable Python environment
# has no pip/uv, so spacy's auto-download would fail.
python -m spacy download en_core_web_sm

View File

@@ -0,0 +1,26 @@
# Version of llama.cpp to fetch convert_hf_to_gguf.py from
LLAMA_CPP_CONVERT_VERSION ?= master
.PHONY: llama-cpp-quantization
llama-cpp-quantization:
LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
.PHONY: run
run: llama-cpp-quantization
@echo "Running llama-cpp-quantization..."
bash run.sh
@echo "llama-cpp-quantization run."
.PHONY: test
test: llama-cpp-quantization
@echo "Testing llama-cpp-quantization..."
bash test.sh
@echo "llama-cpp-quantization tested."
.PHONY: protogen-clean
protogen-clean:
$(RM) backend_pb2_grpc.py backend_pb2.py
.PHONY: clean
clean: protogen-clean
rm -rf venv __pycache__

View File

@@ -0,0 +1,426 @@
#!/usr/bin/env python3
"""
llama.cpp quantization backend for LocalAI.
Downloads HuggingFace models, converts them to GGUF format using
convert_hf_to_gguf.py, and quantizes using llama-quantize.
"""
import argparse
import os
import queue
import re
import signal
import subprocess
import sys
import threading
import time
from concurrent import futures
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import backend_pb2
import backend_pb2_grpc
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
class ActiveJob:
"""Tracks a running quantization job."""
def __init__(self, job_id):
self.job_id = job_id
self.progress_queue = queue.Queue()
self.stop_event = threading.Event()
self.thread = None
self.process = None # subprocess handle for killing
class BackendServicer(backend_pb2_grpc.BackendServicer):
def __init__(self):
self.jobs = {} # job_id -> ActiveJob
def Health(self, request, context):
return backend_pb2.Reply(message=b"OK")
def LoadModel(self, request, context):
"""Accept LoadModel — actual work happens in StartQuantization."""
return backend_pb2.Result(success=True, message="OK")
def StartQuantization(self, request, context):
job_id = request.job_id
if job_id in self.jobs:
return backend_pb2.QuantizationJobResult(
job_id=job_id,
success=False,
message=f"Job {job_id} already exists",
)
job = ActiveJob(job_id)
self.jobs[job_id] = job
job.thread = threading.Thread(
target=self._do_quantization,
args=(job, request),
daemon=True,
)
job.thread.start()
return backend_pb2.QuantizationJobResult(
job_id=job_id,
success=True,
message="Quantization job started",
)
def _send_progress(self, job, status, message, progress_percent=0.0, output_file="", extra_metrics=None):
update = backend_pb2.QuantizationProgressUpdate(
job_id=job.job_id,
progress_percent=progress_percent,
status=status,
message=message,
output_file=output_file,
extra_metrics=extra_metrics or {},
)
job.progress_queue.put(update)
def _do_quantization(self, job, request):
try:
model = request.model
quant_type = request.quantization_type or "q4_k_m"
output_dir = request.output_dir
extra_options = dict(request.extra_options) if request.extra_options else {}
os.makedirs(output_dir, exist_ok=True)
if job.stop_event.is_set():
self._send_progress(job, "stopped", "Job stopped before starting")
return
# Step 1: Download / resolve model
self._send_progress(job, "downloading", f"Resolving model: {model}", progress_percent=0.0)
model_path = self._resolve_model(job, model, output_dir, extra_options)
if model_path is None:
return # error already sent
if job.stop_event.is_set():
self._send_progress(job, "stopped", "Job stopped during download")
return
# Step 2: Convert to f16 GGUF
self._send_progress(job, "converting", "Converting model to GGUF (f16)...", progress_percent=30.0)
f16_gguf_path = os.path.join(output_dir, "model-f16.gguf")
if not self._convert_to_gguf(job, model_path, f16_gguf_path, extra_options):
return # error already sent
if job.stop_event.is_set():
self._send_progress(job, "stopped", "Job stopped during conversion")
return
# Step 3: Quantize
# If the user requested f16, skip quantization — the f16 GGUF is the final output
if quant_type.lower() in ("f16", "fp16"):
output_file = f16_gguf_path
self._send_progress(
job, "completed",
f"Model converted to f16 GGUF: {output_file}",
progress_percent=100.0,
output_file=output_file,
extra_metrics=self._file_metrics(output_file),
)
return
output_file = os.path.join(output_dir, f"model-{quant_type}.gguf")
self._send_progress(job, "quantizing", f"Quantizing to {quant_type}...", progress_percent=50.0)
if not self._quantize(job, f16_gguf_path, output_file, quant_type):
return # error already sent
# Clean up f16 intermediate file to save disk space
try:
os.remove(f16_gguf_path)
except OSError:
pass
self._send_progress(
job, "completed",
f"Quantization complete: {quant_type}",
progress_percent=100.0,
output_file=output_file,
extra_metrics=self._file_metrics(output_file),
)
except Exception as e:
self._send_progress(job, "failed", f"Quantization failed: {str(e)}")
def _resolve_model(self, job, model, output_dir, extra_options):
"""Download model from HuggingFace or return local path."""
# If it's a local path that exists, use it directly
if os.path.isdir(model):
return model
# If it looks like a GGUF file path, use it directly
if os.path.isfile(model) and model.endswith(".gguf"):
return model
# Download from HuggingFace
try:
from huggingface_hub import snapshot_download
hf_token = extra_options.get("hf_token") or os.environ.get("HF_TOKEN")
cache_dir = os.path.join(output_dir, "hf_cache")
self._send_progress(job, "downloading", f"Downloading {model} from HuggingFace...", progress_percent=5.0)
local_path = snapshot_download(
repo_id=model,
cache_dir=cache_dir,
token=hf_token,
ignore_patterns=["*.md", "*.txt", "LICENSE*", ".gitattributes"],
)
self._send_progress(job, "downloading", f"Downloaded {model}", progress_percent=25.0)
return local_path
except Exception as e:
error_msg = str(e)
if "gated" in error_msg.lower() or "access" in error_msg.lower():
self._send_progress(
job, "failed",
f"Access denied for {model}. This model may be gated — "
f"please accept the license at https://huggingface.co/{model} "
f"and provide your HF token in extra_options.",
)
else:
self._send_progress(job, "failed", f"Failed to download model: {error_msg}")
return None
def _convert_to_gguf(self, job, model_path, output_path, extra_options):
"""Convert HF model to f16 GGUF using convert_hf_to_gguf.py."""
# If the model_path is already a GGUF file, just use it as-is
if isinstance(model_path, str) and model_path.endswith(".gguf"):
# Copy or symlink the GGUF file
import shutil
shutil.copy2(model_path, output_path)
return True
# Find convert_hf_to_gguf.py
convert_script = self._find_convert_script()
if convert_script is None:
self._send_progress(job, "failed", "convert_hf_to_gguf.py not found. Install it via the backend's install.sh.")
return False
cmd = [
sys.executable, convert_script,
model_path,
"--outfile", output_path,
"--outtype", "f16",
]
self._send_progress(job, "converting", "Running convert_hf_to_gguf.py...", progress_percent=35.0)
try:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
)
job.process = process
for line in process.stdout:
line = line.strip()
if line:
self._send_progress(job, "converting", line, progress_percent=40.0)
if job.stop_event.is_set():
process.kill()
self._send_progress(job, "stopped", "Job stopped during conversion")
return False
process.wait()
job.process = None
if process.returncode != 0:
self._send_progress(job, "failed", f"convert_hf_to_gguf.py failed with exit code {process.returncode}")
return False
return True
except Exception as e:
self._send_progress(job, "failed", f"Conversion failed: {str(e)}")
return False
def _quantize(self, job, input_path, output_path, quant_type):
"""Quantize a GGUF file using llama-quantize."""
quantize_bin = self._find_quantize_binary()
if quantize_bin is None:
self._send_progress(job, "failed", "llama-quantize binary not found. Ensure it is installed and in PATH.")
return False
cmd = [quantize_bin, input_path, output_path, quant_type]
self._send_progress(job, "quantizing", f"Running llama-quantize ({quant_type})...", progress_percent=55.0)
try:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
)
job.process = process
for line in process.stdout:
line = line.strip()
if line:
# Try to parse progress from llama-quantize output
progress = self._parse_quantize_progress(line)
pct = 55.0 + (progress * 0.40) if progress else 60.0
self._send_progress(job, "quantizing", line, progress_percent=pct)
if job.stop_event.is_set():
process.kill()
self._send_progress(job, "stopped", "Job stopped during quantization")
return False
process.wait()
job.process = None
if process.returncode != 0:
self._send_progress(job, "failed", f"llama-quantize failed with exit code {process.returncode}")
return False
return True
except Exception as e:
self._send_progress(job, "failed", f"Quantization failed: {str(e)}")
return False
def _parse_quantize_progress(self, line):
"""Try to parse a progress percentage from llama-quantize output."""
# llama-quantize typically outputs lines like:
# [ 123/ 1234] quantizing blk.0.attn_k.weight ...
match = re.search(r'\[\s*(\d+)\s*/\s*(\d+)\]', line)
if match:
current = int(match.group(1))
total = int(match.group(2))
if total > 0:
return current / total
return None
def _find_convert_script(self):
"""Find convert_hf_to_gguf.py in known locations."""
candidates = [
# Same directory as this backend
os.path.join(os.path.dirname(__file__), "convert_hf_to_gguf.py"),
# Installed via install.sh
os.path.join(os.path.dirname(os.path.abspath(__file__)), "convert_hf_to_gguf.py"),
]
# Also check if it's on PATH
import shutil
path_script = shutil.which("convert_hf_to_gguf.py")
if path_script:
candidates.append(path_script)
for candidate in candidates:
if os.path.isfile(candidate):
return candidate
return None
def _find_quantize_binary(self):
"""Find llama-quantize binary."""
import shutil
# Check common names on PATH
for name in ["llama-quantize", "quantize"]:
path = shutil.which(name)
if path:
return path
# Check in the backend directory (built by install.sh)
backend_dir = os.path.dirname(os.path.abspath(__file__))
for name in ["llama-quantize", "quantize"]:
candidate = os.path.join(backend_dir, name)
if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
return candidate
return None
def _file_metrics(self, filepath):
"""Return file size metrics."""
try:
size_bytes = os.path.getsize(filepath)
return {"file_size_mb": size_bytes / (1024 * 1024)}
except OSError:
return {}
def QuantizationProgress(self, request, context):
job_id = request.job_id
job = self.jobs.get(job_id)
if job is None:
context.abort(grpc.StatusCode.NOT_FOUND, f"Job {job_id} not found")
return
while True:
try:
update = job.progress_queue.get(timeout=1.0)
yield update
# If this is a terminal status, stop streaming
if update.status in ("completed", "failed", "stopped"):
break
except queue.Empty:
# Check if the thread is still alive
if job.thread and not job.thread.is_alive():
# Thread finished but no terminal update — drain queue
while not job.progress_queue.empty():
update = job.progress_queue.get_nowait()
yield update
break
# Check if client disconnected
if context.is_active() is False:
break
def StopQuantization(self, request, context):
job_id = request.job_id
job = self.jobs.get(job_id)
if job is None:
return backend_pb2.Result(success=False, message=f"Job {job_id} not found")
job.stop_event.set()
if job.process:
try:
job.process.kill()
except OSError:
pass
return backend_pb2.Result(success=True, message="Stop signal sent")
def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()
print(f"Quantization backend listening on {address}", file=sys.stderr, flush=True)
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="llama.cpp quantization gRPC backend")
parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
args = parser.parse_args()
signal.signal(signal.SIGINT, lambda sig, frame: sys.exit(0))
serve(args.addr)

View File

@@ -0,0 +1,58 @@
#!/bin/bash
set -e
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade "
installRequirements
# Fetch convert_hf_to_gguf.py from llama.cpp
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
if [ ! -f "${CONVERT_SCRIPT}" ]; then
echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
curl -L --fail --retry 3 \
"https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
-o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py."
fi
# Install gguf package from the same llama.cpp commit to keep them in sync
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
if [ "x${USE_PIP:-}" == "xtrue" ]; then
pip install "${GGUF_PIP_SPEC}" || {
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
pip install "gguf>=0.16.0"
}
else
uv pip install "${GGUF_PIP_SPEC}" || {
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
uv pip install "gguf>=0.16.0"
}
fi
# Build llama-quantize from llama.cpp if not already present
QUANTIZE_BIN="${EDIR}/llama-quantize"
if [ ! -x "${QUANTIZE_BIN}" ] && ! command -v llama-quantize &>/dev/null; then
if command -v cmake &>/dev/null; then
echo "Building llama-quantize from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
LLAMA_CPP_SRC="${EDIR}/llama.cpp"
if [ ! -d "${LLAMA_CPP_SRC}" ]; then
git clone --depth 1 --branch "${LLAMA_CPP_CONVERT_VERSION}" \
https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}" 2>/dev/null || \
git clone --depth 1 https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}"
fi
cmake -B "${LLAMA_CPP_SRC}/build" -S "${LLAMA_CPP_SRC}" -DGGML_NATIVE=OFF -DBUILD_SHARED_LIBS=OFF
cmake --build "${LLAMA_CPP_SRC}/build" --target llama-quantize -j"$(nproc 2>/dev/null || echo 2)"
cp "${LLAMA_CPP_SRC}/build/bin/llama-quantize" "${QUANTIZE_BIN}"
chmod +x "${QUANTIZE_BIN}"
echo "Built llama-quantize at ${QUANTIZE_BIN}"
else
echo "Warning: cmake not found — llama-quantize will not be available. Install cmake or provide llama-quantize on PATH."
fi
fi

View File

@@ -0,0 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.10.0
transformers>=4.56.2
huggingface-hub>=1.3.0
sentencepiece

View File

@@ -0,0 +1,4 @@
torch==2.10.0
transformers>=4.56.2
huggingface-hub>=1.3.0
sentencepiece

View File

@@ -0,0 +1,3 @@
grpcio==1.78.1
protobuf
certifi

View File

@@ -0,0 +1,10 @@
#!/bin/bash
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
startBackend $@

View File

@@ -0,0 +1,99 @@
"""
Test script for the llama-cpp-quantization gRPC backend.
Downloads a small model (functiongemma-270m-it), converts it to GGUF,
and quantizes it to q4_k_m.
"""
import os
import shutil
import subprocess
import tempfile
import time
import unittest
import grpc
import backend_pb2
import backend_pb2_grpc
SERVER_ADDR = "localhost:50051"
# Small model for CI testing (~540MB)
TEST_MODEL = "unsloth/functiongemma-270m-it"
class TestQuantizationBackend(unittest.TestCase):
"""Tests for the llama-cpp-quantization gRPC service."""
@classmethod
def setUpClass(cls):
cls.service = subprocess.Popen(
["python3", "backend.py", "--addr", SERVER_ADDR]
)
time.sleep(5)
cls.output_dir = tempfile.mkdtemp(prefix="quantize-test-")
@classmethod
def tearDownClass(cls):
cls.service.kill()
cls.service.wait()
# Clean up output directory
if os.path.isdir(cls.output_dir):
shutil.rmtree(cls.output_dir, ignore_errors=True)
def _channel(self):
return grpc.insecure_channel(SERVER_ADDR)
def test_01_health(self):
"""Test that the server starts and responds to health checks."""
with self._channel() as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.Health(backend_pb2.HealthMessage())
self.assertEqual(response.message, b"OK")
def test_02_quantize_small_model(self):
"""Download, convert, and quantize functiongemma-270m-it to q4_k_m."""
with self._channel() as channel:
stub = backend_pb2_grpc.BackendStub(channel)
job_id = "test-quantize-001"
# Start quantization
result = stub.StartQuantization(
backend_pb2.QuantizationRequest(
model=TEST_MODEL,
quantization_type="q4_k_m",
output_dir=self.output_dir,
job_id=job_id,
)
)
self.assertTrue(result.success, f"StartQuantization failed: {result.message}")
self.assertEqual(result.job_id, job_id)
# Stream progress until completion
final_status = None
output_file = None
for update in stub.QuantizationProgress(
backend_pb2.QuantizationProgressRequest(job_id=job_id)
):
print(f" [{update.status}] {update.progress_percent:.1f}% - {update.message}")
final_status = update.status
if update.output_file:
output_file = update.output_file
self.assertEqual(final_status, "completed", f"Expected completed, got {final_status}")
self.assertIsNotNone(output_file, "No output_file in progress updates")
self.assertTrue(os.path.isfile(output_file), f"Output file not found: {output_file}")
# Verify the output is a valid GGUF file (starts with "GGUF" magic)
with open(output_file, "rb") as f:
magic = f.read(4)
self.assertEqual(magic, b"GGUF", f"Output file does not have GGUF magic: {magic!r}")
# Verify reasonable file size (q4_k_m of 270M model should be ~150-400MB)
size_mb = os.path.getsize(output_file) / (1024 * 1024)
print(f" Output file size: {size_mb:.1f} MB")
self.assertGreater(size_mb, 10, "Output file suspiciously small")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,11 @@
#!/bin/bash
set -e
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
runUnittests

View File

@@ -15,6 +15,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from mlx_audio.tts.utils import load_model from mlx_audio.tts.utils import load_model
import soundfile as sf import soundfile as sf
import numpy as np import numpy as np
@@ -436,7 +440,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server # Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address # Bind the server to the address

View File

@@ -23,6 +23,10 @@ import tempfile
from typing import List from typing import List
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import backend_pb2 import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
@@ -468,6 +472,8 @@ async def serve(address):
('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024),
], ],
interceptors=get_auth_interceptors(aio=True),
) )
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)

View File

@@ -12,6 +12,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from mlx_vlm import load, generate, stream_generate from mlx_vlm import load, generate, stream_generate
from mlx_vlm.prompt_utils import apply_chat_template from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config, load_image from mlx_vlm.utils import load_config, load_image
@@ -446,7 +450,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server # Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address # Bind the server to the address

View File

@@ -12,6 +12,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from mlx_lm import load, generate, stream_generate from mlx_lm import load, generate, stream_generate
from mlx_lm.sample_utils import make_sampler from mlx_lm.sample_utils import make_sampler
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
@@ -421,7 +425,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server # Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address # Bind the server to the address

View File

@@ -17,6 +17,10 @@ from moonshine_voice import (
) )
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -128,7 +132,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -14,6 +14,10 @@ import torch
import nemo.collections.asr as nemo_asr import nemo.collections.asr as nemo_asr
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
@@ -119,7 +123,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), ('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024),
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -4,3 +4,4 @@ certifi
packaging==24.1 packaging==24.1
setuptools setuptools
pyarrow==20.0.0 pyarrow==20.0.0
pybind11

View File

@@ -15,6 +15,10 @@ from neuttsair.neutts import NeuTTSAir
import soundfile as sf import soundfile as sf
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
"""Check if a string can be converted to float.""" """Check if a string can be converted to float."""
@@ -130,7 +134,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -14,6 +14,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import outetts import outetts
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -116,7 +120,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), ('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024),
]) ],
interceptors=get_auth_interceptors(aio=True),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)

View File

@@ -16,6 +16,10 @@ import torch
from pocket_tts import TTSModel from pocket_tts import TTSModel
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
"""Check if a string can be converted to float.""" """Check if a string can be converted to float."""
@@ -225,7 +229,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -14,6 +14,10 @@ import torch
from qwen_asr import Qwen3ASRModel from qwen_asr import Qwen3ASRModel
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
@@ -184,7 +188,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), ('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024),
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -23,6 +23,10 @@ import hashlib
import pickle import pickle
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
@@ -900,6 +904,8 @@ def serve(address):
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
], ],
interceptors=get_auth_interceptors(),
) )
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)

View File

@@ -14,6 +14,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from rerankers import Reranker from rerankers import Reranker
@@ -97,7 +101,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -1,3 +1,3 @@
grpcio==1.78.1 grpcio==1.80.0
protobuf protobuf
certifi certifi

View File

@@ -13,6 +13,10 @@ import base64
import backend_pb2 import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import requests import requests
@@ -139,7 +143,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -16,16 +16,22 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import torch import torch
import torch.cuda import torch.cuda
XPU=os.environ.get("XPU", "0") == "1" XPU=os.environ.get("XPU", "0") == "1"
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM import transformers as transformers_module
from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
from scipy.io import wavfile from scipy.io import wavfile
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
# Backward-compat aliases for model types
TYPE_ALIASES = {"Mamba": "MambaForCausalLM"}
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -52,32 +58,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding. This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
""" """
def Health(self, request, context): def Health(self, request, context):
"""
A gRPC method that returns the health status of the backend service.
Args:
request: A HealthRequest object that contains the request parameters.
context: A grpc.ServicerContext object that provides information about the RPC.
Returns:
A Reply object that contains the health status of the backend service.
"""
return backend_pb2.Reply(message=bytes("OK", 'utf-8')) return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
def LoadModel(self, request, context): def LoadModel(self, request, context):
"""
A gRPC method that loads a model into memory.
Args:
request: A LoadModelRequest object that contains the request parameters.
context: A grpc.ServicerContext object that provides information about the RPC.
Returns:
A Result object that contains the result of the LoadModel operation.
"""
model_name = request.Model model_name = request.Model
# Check to see if the Model exists in the filesystem already. # Check to see if the Model exists in the filesystem already.
if os.path.exists(request.ModelFile): if os.path.exists(request.ModelFile):
model_name = request.ModelFile model_name = request.ModelFile
@@ -88,8 +73,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.CUDA = torch.cuda.is_available() self.CUDA = torch.cuda.is_available()
self.OV=False self.OV=False
self.DiaTTS=False self.GenericTTS=False
self.SentenceTransformer = False self.SentenceTransformer = False
self.processor = None
device_map="cpu" device_map="cpu"
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -101,7 +87,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# Parse options from request.Options # Parse options from request.Options
self.options = {} self.options = {}
options = request.Options options = request.Options
# The options are a list of strings in this form optname:optvalue # The options are a list of strings in this form optname:optvalue
# We are storing all the options in a dict so we can use it later when generating # We are storing all the options in a dict so we can use it later when generating
# Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"] # Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
@@ -123,7 +109,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
print(f"Parsed options: {self.options}", file=sys.stderr) print(f"Parsed options: {self.options}", file=sys.stderr)
if self.CUDA: if self.CUDA:
from transformers import BitsAndBytesConfig, AutoModelForCausalLM from transformers import BitsAndBytesConfig
if request.MainGPU: if request.MainGPU:
device_map=request.MainGPU device_map=request.MainGPU
else: else:
@@ -140,40 +126,31 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
quantization = BitsAndBytesConfig( quantization = BitsAndBytesConfig(
load_in_4bit=False, load_in_4bit=False,
bnb_4bit_compute_dtype = None, bnb_4bit_compute_dtype = None,
load_in_8bit=True, load_in_8bit=True,
) )
try: try:
if request.Type == "AutoModelForCausalLM": if XPU and request.Type == "AutoModelForCausalLM":
if XPU: import intel_extension_for_pytorch as ipex
import intel_extension_for_pytorch as ipex from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
device_map="xpu" device_map="xpu"
compute=torch.float16 compute=torch.float16
if request.Quantization == "xpu_4bit": if request.Quantization == "xpu_4bit":
xpu_4bit = True xpu_4bit = True
xpu_8bit = False xpu_8bit = False
elif request.Quantization == "xpu_8bit": elif request.Quantization == "xpu_8bit":
xpu_4bit = False xpu_4bit = False
xpu_8bit = True xpu_8bit = True
else:
xpu_4bit = False
xpu_8bit = False
self.model = AutoModelForCausalLM.from_pretrained(model_name,
trust_remote_code=request.TrustRemoteCode,
use_safetensors=True,
device_map=device_map,
load_in_4bit=xpu_4bit,
load_in_8bit=xpu_8bit,
torch_dtype=compute)
else: else:
self.model = AutoModelForCausalLM.from_pretrained(model_name, xpu_4bit = False
trust_remote_code=request.TrustRemoteCode, xpu_8bit = False
use_safetensors=True, self.model = AutoModelForCausalLM.from_pretrained(model_name,
quantization_config=quantization, trust_remote_code=request.TrustRemoteCode,
device_map=device_map, device_map=device_map,
torch_dtype=compute) load_in_4bit=xpu_4bit,
load_in_8bit=xpu_8bit,
torch_dtype=compute)
elif request.Type == "OVModelForCausalLM": elif request.Type == "OVModelForCausalLM":
from optimum.intel.openvino import OVModelForCausalLM from optimum.intel.openvino import OVModelForCausalLM
from openvino.runtime import Core from openvino.runtime import Core
@@ -185,14 +162,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
devices = Core().available_devices devices = Core().available_devices
if "GPU" in " ".join(devices): if "GPU" in " ".join(devices):
device_map="AUTO:GPU" device_map="AUTO:GPU"
# While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected.
# https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html
if "CPU" or "NPU" in device_map: if "CPU" or "NPU" in device_map:
if "-CPU" or "-NPU" not in device_map: if "-CPU" or "-NPU" not in device_map:
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"} ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
else: else:
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"} ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
self.model = OVModelForCausalLM.from_pretrained(model_name, self.model = OVModelForCausalLM.from_pretrained(model_name,
compile=True, compile=True,
trust_remote_code=request.TrustRemoteCode, trust_remote_code=request.TrustRemoteCode,
ov_config=ovconfig, ov_config=ovconfig,
@@ -209,59 +184,60 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
devices = Core().available_devices devices = Core().available_devices
if "GPU" in " ".join(devices): if "GPU" in " ".join(devices):
device_map="AUTO:GPU" device_map="AUTO:GPU"
# While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected.
# https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html
if "CPU" or "NPU" in device_map: if "CPU" or "NPU" in device_map:
if "-CPU" or "-NPU" not in device_map: if "-CPU" or "-NPU" not in device_map:
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"} ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
else: else:
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"} ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
self.model = OVModelForFeatureExtraction.from_pretrained(model_name, self.model = OVModelForFeatureExtraction.from_pretrained(model_name,
compile=True, compile=True,
trust_remote_code=request.TrustRemoteCode, trust_remote_code=request.TrustRemoteCode,
ov_config=ovconfig, ov_config=ovconfig,
export=True, export=True,
device=device_map) device=device_map)
self.OV = True self.OV = True
elif request.Type == "MusicgenForConditionalGeneration":
autoTokenizer = False
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
elif request.Type == "DiaForConditionalGeneration":
autoTokenizer = False
print("DiaForConditionalGeneration", file=sys.stderr)
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = DiaForConditionalGeneration.from_pretrained(model_name)
if self.CUDA:
self.model = self.model.to("cuda")
self.processor = self.processor.to("cuda")
print("DiaForConditionalGeneration loaded", file=sys.stderr)
self.DiaTTS = True
elif request.Type == "SentenceTransformer": elif request.Type == "SentenceTransformer":
autoTokenizer = False autoTokenizer = False
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode) self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
self.SentenceTransformer = True self.SentenceTransformer = True
elif request.Type == "Mamba":
autoTokenizer = False
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = MambaForCausalLM.from_pretrained(model_name)
else: else:
print("Automodel", file=sys.stderr) # Generic: dynamically resolve model class from transformers
self.model = AutoModel.from_pretrained(model_name, model_type = TYPE_ALIASES.get(request.Type, request.Type)
trust_remote_code=request.TrustRemoteCode, ModelClass = AutoModel # default
use_safetensors=True, if model_type and hasattr(transformers_module, model_type):
quantization_config=quantization, ModelClass = getattr(transformers_module, model_type)
device_map=device_map, print(f"Using model class: {model_type}", file=sys.stderr)
torch_dtype=compute) else:
print(f"Using default AutoModel (type={request.Type!r})", file=sys.stderr)
self.model = ModelClass.from_pretrained(
model_name,
trust_remote_code=request.TrustRemoteCode,
quantization_config=quantization,
device_map=device_map,
torch_dtype=compute,
)
# Try to load a processor (needed for TTS/audio models)
try:
self.processor = AutoProcessor.from_pretrained(
model_name,
trust_remote_code=request.TrustRemoteCode,
)
self.GenericTTS = True
print(f"Loaded processor for {model_name}", file=sys.stderr)
except Exception:
self.processor = None
if request.ContextSize > 0: if request.ContextSize > 0:
self.max_tokens = request.ContextSize self.max_tokens = request.ContextSize
elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'): elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
self.max_tokens = self.model.config.max_position_embeddings self.max_tokens = self.model.config.max_position_embeddings
else: else:
self.max_tokens = self.options.get("max_new_tokens", 512) self.max_tokens = self.options.get("max_new_tokens", 512)
if autoTokenizer: if autoTokenizer:
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True) self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.XPU = False self.XPU = False
if XPU and self.OV == False: if XPU and self.OV == False:
@@ -275,22 +251,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
except Exception as err: except Exception as err:
print("Error:", err, file=sys.stderr) print("Error:", err, file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
# Implement your logic here for the LoadModel service
# Replace this with your desired response
return backend_pb2.Result(message="Model loaded successfully", success=True) return backend_pb2.Result(message="Model loaded successfully", success=True)
def Embedding(self, request, context): def Embedding(self, request, context):
"""
A gRPC method that calculates embeddings for a given sentence.
Args:
request: An EmbeddingRequest object that contains the request parameters.
context: A grpc.ServicerContext object that provides information about the RPC.
Returns:
An EmbeddingResult object that contains the calculated embeddings.
"""
set_seed(request.Seed) set_seed(request.Seed)
# Tokenize input # Tokenize input
max_length = 512 max_length = 512
@@ -303,13 +266,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr) print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
embeds = self.model.encode(request.Embeddings) embeds = self.model.encode(request.Embeddings)
else: else:
encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt") encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
# Create word embeddings # Create word embeddings
if self.CUDA: if self.CUDA:
encoded_input = encoded_input.to("cuda") encoded_input = encoded_input.to("cuda")
with torch.no_grad(): with torch.no_grad():
model_output = self.model(**encoded_input) model_output = self.model(**encoded_input)
# Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence # Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
@@ -317,11 +280,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
embeds = sentence_embeddings[0] embeds = sentence_embeddings[0]
return backend_pb2.EmbeddingResult(embeddings=embeds) return backend_pb2.EmbeddingResult(embeddings=embeds)
async def _predict(self, request, context, streaming=False): async def _predict(self, request, context, streaming=False):
set_seed(request.Seed) set_seed(request.Seed)
if request.TopP < 0 or request.TopP > 1: if request.TopP < 0 or request.TopP > 1:
request.TopP = 1 request.TopP = 1
if request.TopK <= 0: if request.TopK <= 0:
request.TopK = 50 request.TopK = 50
@@ -334,7 +297,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
request.Temperature == None request.Temperature == None
prompt = request.Prompt prompt = request.Prompt
if not request.Prompt and request.UseTokenizerTemplate and request.Messages: if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True) prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
inputs = self.tokenizer(prompt, return_tensors="pt") inputs = self.tokenizer(prompt, return_tensors="pt")
@@ -363,10 +326,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
skip_prompt=True, skip_prompt=True,
skip_special_tokens=True) skip_special_tokens=True)
config=dict(inputs, config=dict(inputs,
max_new_tokens=max_tokens, max_new_tokens=max_tokens,
temperature=request.Temperature, temperature=request.Temperature,
top_p=request.TopP, top_p=request.TopP,
top_k=request.TopK, top_k=request.TopK,
do_sample=sample, do_sample=sample,
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
eos_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
@@ -387,18 +350,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
else: else:
if XPU and self.OV == False: if XPU and self.OV == False:
outputs = self.model.generate(inputs["input_ids"], outputs = self.model.generate(inputs["input_ids"],
max_new_tokens=max_tokens, max_new_tokens=max_tokens,
temperature=request.Temperature, temperature=request.Temperature,
top_p=request.TopP, top_p=request.TopP,
top_k=request.TopK, top_k=request.TopK,
do_sample=sample, do_sample=sample,
pad_token=self.tokenizer.eos_token_id) pad_token=self.tokenizer.eos_token_id)
else: else:
outputs = self.model.generate(**inputs, outputs = self.model.generate(**inputs,
max_new_tokens=max_tokens, max_new_tokens=max_tokens,
temperature=request.Temperature, temperature=request.Temperature,
top_p=request.TopP, top_p=request.TopP,
top_k=request.TopK, top_k=request.TopK,
do_sample=sample, do_sample=sample,
eos_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.eos_token_id,
@@ -413,31 +376,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8')) yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
async def Predict(self, request, context): async def Predict(self, request, context):
"""
Generates text based on the given prompt and sampling parameters.
Args:
request: The predict request.
context: The gRPC context.
Returns:
backend_pb2.Reply: The predict result.
"""
gen = self._predict(request, context, streaming=False) gen = self._predict(request, context, streaming=False)
res = await gen.__anext__() res = await gen.__anext__()
return res return res
async def PredictStream(self, request, context): async def PredictStream(self, request, context):
"""
Generates text based on the given prompt and sampling parameters, and streams the results.
Args:
request: The predict stream request.
context: The gRPC context.
Returns:
backend_pb2.Result: The predict stream result.
"""
iterations = self._predict(request, context, streaming=True) iterations = self._predict(request, context, streaming=True)
try: try:
async for iteration in iterations: async for iteration in iterations:
@@ -455,18 +398,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if self.model is None: if self.model is None:
if model_name == "": if model_name == "":
return backend_pb2.Result(success=False, message="request.model is required") return backend_pb2.Result(success=False, message="request.model is required")
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name) # Dynamically resolve model class if configured, otherwise default to MusicgenForConditionalGeneration
model_type = self.options.get("model_type", "MusicgenForConditionalGeneration")
ModelClass = getattr(transformers_module, model_type)
self.model = ModelClass.from_pretrained(model_name)
inputs = None inputs = None
if request.text == "": if request.text == "":
inputs = self.model.get_unconditional_inputs(num_samples=1) inputs = self.model.get_unconditional_inputs(num_samples=1)
elif request.HasField('src'): elif request.HasField('src'):
# TODO SECURITY CODE GOES HERE LOL
# WHO KNOWS IF THIS WORKS???
sample_rate, wsamples = wavfile.read('path_to_your_file.wav') sample_rate, wsamples = wavfile.read('path_to_your_file.wav')
if request.HasField('src_divisor'): if request.HasField('src_divisor'):
wsamples = wsamples[: len(wsamples) // request.src_divisor] wsamples = wsamples[: len(wsamples) // request.src_divisor]
inputs = self.processor( inputs = self.processor(
audio=wsamples, audio=wsamples,
sampling_rate=sample_rate, sampling_rate=sample_rate,
@@ -480,7 +424,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
padding=True, padding=True,
return_tensors="pt", return_tensors="pt",
) )
if request.HasField('duration'): if request.HasField('duration'):
tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
guidance = self.options.get("guidance_scale", 3.0) guidance = self.options.get("guidance_scale", 3.0)
@@ -490,92 +434,97 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if request.HasField('sample'): if request.HasField('sample'):
dosample = request.sample dosample = request.sample
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens) audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens)
print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr) print("[transformers] SoundGeneration generated!", file=sys.stderr)
sampling_rate = self.model.config.audio_encoder.sampling_rate
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy()) # Save audio output
print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr) if hasattr(self.processor, 'save_audio'):
print("[transformers-musicgen] SoundGeneration for", file=sys.stderr) if hasattr(self.processor, 'batch_decode'):
print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr) try:
audio_values = self.processor.batch_decode(audio_values)
except Exception:
pass
self.processor.save_audio(audio_values, request.dst)
else:
sampling_rate = self.model.config.audio_encoder.sampling_rate
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
print("[transformers] SoundGeneration saved to", request.dst, file=sys.stderr)
print(request, file=sys.stderr) print(request, file=sys.stderr)
except Exception as err: except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(success=True) return backend_pb2.Result(success=True)
def CallDiaTTS(self, request, context):
"""
Generates dialogue audio using the Dia model.
Args:
request: A TTSRequest containing text dialogue and generation parameters
context: The gRPC context
Returns:
A Result object indicating success or failure
"""
try:
print("[DiaTTS] generating dialogue audio", file=sys.stderr)
# Prepare text input - expect dialogue format like [S1] ... [S2] ...
text = [request.text]
# Process the input
inputs = self.processor(text=text, padding=True, return_tensors="pt")
# Generate audio with parameters from options or defaults
generation_params = {
**inputs,
"max_new_tokens": self.max_tokens,
"guidance_scale": self.options.get("guidance_scale", 3.0),
"temperature": self.options.get("temperature", 1.8),
"top_p": self.options.get("top_p", 0.90),
"top_k": self.options.get("top_k", 45)
}
outputs = self.model.generate(**generation_params)
# Decode and save audio
outputs = self.processor.batch_decode(outputs)
self.processor.save_audio(outputs, request.dst)
print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
print("[DiaTTS] Dialogue generation done", file=sys.stderr)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(success=True)
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
def TTS(self, request, context): def TTS(self, request, context):
if self.DiaTTS:
print("DiaTTS", file=sys.stderr)
return self.CallDiaTTS(request, context)
model_name = request.model
try: try:
if self.processor is None: text = request.text
if model_name == "": print(f"[transformers] TTS generating for text: {text[:100]}...", file=sys.stderr)
return backend_pb2.Result(success=False, message="request.model is required")
self.processor = AutoProcessor.from_pretrained(model_name) # Build inputs based on processor capabilities
if self.model is None: if request.voice and os.path.exists(request.voice):
if model_name == "": # Voice cloning: use chat template with reference audio
return backend_pb2.Result(success=False, message="request.model is required") chat_template = [{
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name) "role": "0",
inputs = self.processor( "content": [
text=[request.text], {"type": "text", "text": text},
padding=True, {"type": "audio", "path": request.voice},
return_tensors="pt", ],
) }]
tokens = self.max_tokens # No good place to set the "length" in TTS, so use 10s as a sane default inputs = self.processor.apply_chat_template(
audio_values = self.model.generate(**inputs, max_new_tokens=tokens) chat_template, tokenize=True, return_dict=True,
print("[transformers-musicgen] TTS generated!", file=sys.stderr) ).to(self.model.device, self.model.dtype)
sampling_rate = self.model.config.audio_encoder.sampling_rate elif hasattr(self.processor, 'apply_chat_template'):
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy()) # Models that use chat template format (VibeVoice, CSM, etc.)
print("[transformers-musicgen] TTS saved to", request.dst, file=sys.stderr) chat_template = [{"role": "0", "content": [{"type": "text", "text": text}]}]
print("[transformers-musicgen] TTS for", file=sys.stderr) try:
print(request, file=sys.stderr) inputs = self.processor.apply_chat_template(
chat_template, tokenize=True, return_dict=True,
).to(self.model.device, self.model.dtype)
except Exception:
# Fallback if chat template fails (not all processors support it)
inputs = self.processor(text=[text], padding=True, return_tensors="pt")
if self.CUDA:
inputs = inputs.to("cuda")
else:
# Direct processor call (Musicgen, etc.)
inputs = self.processor(text=[text], padding=True, return_tensors="pt")
if self.CUDA:
inputs = inputs.to("cuda")
# Build generation kwargs from self.options
gen_kwargs = {**inputs, "max_new_tokens": self.max_tokens}
for key in ["guidance_scale", "temperature", "top_p", "top_k", "do_sample"]:
if key in self.options:
gen_kwargs[key] = self.options[key]
# Add noise scheduler if configured (e.g., for VibeVoice)
noise_scheduler_type = self.options.get("noise_scheduler", None)
if noise_scheduler_type:
import diffusers
SchedulerClass = getattr(diffusers, noise_scheduler_type)
scheduler_kwargs = {}
for key in ["beta_schedule", "prediction_type"]:
if key in self.options:
scheduler_kwargs[key] = self.options[key]
gen_kwargs["noise_scheduler"] = SchedulerClass(**scheduler_kwargs)
# Generate audio
audio = self.model.generate(**gen_kwargs)
print("[transformers] TTS generated!", file=sys.stderr)
# Save audio output
if hasattr(self.processor, 'save_audio'):
if hasattr(self.processor, 'batch_decode'):
try:
audio = self.processor.batch_decode(audio)
except Exception:
pass
self.processor.save_audio(audio, request.dst)
else:
sampling_rate = self.model.config.audio_encoder.sampling_rate
wavfile.write(request.dst, rate=sampling_rate, data=audio[0, 0].numpy())
print("[transformers] TTS saved to", request.dst, file=sys.stderr)
except Exception as err: except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(success=True) return backend_pb2.Result(success=True)
@@ -587,7 +536,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server # Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address # Bind the server to the address

View File

@@ -2,7 +2,9 @@ torch==2.7.1
llvmlite==0.43.0 llvmlite==0.43.0
numba==0.60.0 numba==0.60.0
accelerate accelerate
transformers transformers>=5.0.0
bitsandbytes bitsandbytes
sentence-transformers==5.2.3 sentence-transformers==5.2.3
diffusers
soundfile
protobuf==6.33.5 protobuf==6.33.5

View File

@@ -2,7 +2,9 @@ torch==2.7.1
accelerate accelerate
llvmlite==0.43.0 llvmlite==0.43.0
numba==0.60.0 numba==0.60.0
transformers transformers>=5.0.0
bitsandbytes bitsandbytes
sentence-transformers==5.2.3 sentence-transformers==5.2.3
diffusers
soundfile
protobuf==6.33.5 protobuf==6.33.5

View File

@@ -2,7 +2,9 @@
torch==2.9.0 torch==2.9.0
llvmlite==0.43.0 llvmlite==0.43.0
numba==0.60.0 numba==0.60.0
transformers transformers>=5.0.0
bitsandbytes bitsandbytes
sentence-transformers==5.2.3 sentence-transformers==5.2.3
diffusers
soundfile
protobuf==6.33.5 protobuf==6.33.5

View File

@@ -1,9 +1,11 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.4 --extra-index-url https://download.pytorch.org/whl/rocm6.4
torch==2.8.0+rocm6.4 torch==2.8.0+rocm6.4
accelerate accelerate
transformers transformers>=5.0.0
llvmlite==0.43.0 llvmlite==0.43.0
numba==0.60.0 numba==0.60.0
bitsandbytes bitsandbytes
sentence-transformers==5.2.3 sentence-transformers==5.2.3
diffusers
soundfile
protobuf==6.33.5 protobuf==6.33.5

View File

@@ -3,7 +3,9 @@ torch
optimum[openvino] optimum[openvino]
llvmlite==0.43.0 llvmlite==0.43.0
numba==0.60.0 numba==0.60.0
transformers transformers>=5.0.0
bitsandbytes bitsandbytes
sentence-transformers==5.2.3 sentence-transformers==5.2.3
diffusers
soundfile
protobuf==6.33.5 protobuf==6.33.5

View File

@@ -2,7 +2,9 @@ torch==2.7.1
llvmlite==0.43.0 llvmlite==0.43.0
numba==0.60.0 numba==0.60.0
accelerate accelerate
transformers transformers>=5.0.0
bitsandbytes bitsandbytes
sentence-transformers==5.2.3 sentence-transformers==5.2.3
diffusers
soundfile
protobuf==6.33.5 protobuf==6.33.5

View File

@@ -1,4 +1,4 @@
grpcio==1.78.1 grpcio==1.80.0
protobuf==6.33.5 protobuf==6.33.5
certifi certifi
setuptools setuptools

View File

@@ -0,0 +1,26 @@
# Version of llama.cpp to fetch convert_hf_to_gguf.py from (for GGUF export)
LLAMA_CPP_CONVERT_VERSION ?= master
.PHONY: trl
trl:
LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
.PHONY: run
run: trl
@echo "Running trl..."
bash run.sh
@echo "trl run."
.PHONY: test
test: trl
@echo "Testing trl..."
bash test.sh
@echo "trl tested."
.PHONY: protogen-clean
protogen-clean:
$(RM) backend_pb2_grpc.py backend_pb2.py
.PHONY: clean
clean: protogen-clean
rm -rf venv __pycache__

View File

@@ -0,0 +1,866 @@
#!/usr/bin/env python3
"""
TRL fine-tuning backend for LocalAI.
Supports all TRL training methods (SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO)
using standard HuggingFace transformers + PEFT. Works on both CPU and GPU.
"""
import argparse
import json
import os
import queue
import signal
import sys
import threading
import time
import uuid
from concurrent import futures
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
import backend_pb2
import backend_pb2_grpc
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
class ProgressCallback:
"""HuggingFace TrainerCallback that pushes progress updates to a queue."""
def __init__(self, job_id, progress_queue, total_epochs):
self.job_id = job_id
self.progress_queue = progress_queue
self.total_epochs = total_epochs
def get_callback(self):
from transformers import TrainerCallback
parent = self
class _Callback(TrainerCallback):
def __init__(self):
self._train_start_time = None
def on_train_begin(self, args, state, control, **kwargs):
self._train_start_time = time.time()
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None:
return
total_steps = state.max_steps if state.max_steps > 0 else 0
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
eta = 0.0
if state.global_step > 0 and total_steps > 0 and self._train_start_time:
elapsed = time.time() - self._train_start_time
remaining_steps = total_steps - state.global_step
if state.global_step > 0:
eta = remaining_steps * (elapsed / state.global_step)
extra_metrics = {}
for k, v in logs.items():
if isinstance(v, (int, float)) and k not in ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss'):
extra_metrics[k] = float(v)
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=total_steps,
current_epoch=float(logs.get('epoch', 0)),
total_epochs=float(parent.total_epochs),
loss=float(logs.get('loss', 0)),
learning_rate=float(logs.get('learning_rate', 0)),
grad_norm=float(logs.get('grad_norm', 0)),
eval_loss=float(logs.get('eval_loss', 0)),
eta_seconds=float(eta),
progress_percent=float(progress),
status="training",
extra_metrics=extra_metrics,
)
parent.progress_queue.put(update)
def on_prediction_step(self, args, state, control, **kwargs):
"""Send periodic updates during evaluation so the UI doesn't freeze."""
if not hasattr(self, '_eval_update_counter'):
self._eval_update_counter = 0
self._eval_update_counter += 1
# Throttle: send an update every 10 prediction steps
if self._eval_update_counter % 10 != 0:
return
total_steps = state.max_steps if state.max_steps > 0 else 0
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=total_steps,
current_epoch=float(state.epoch or 0),
total_epochs=float(parent.total_epochs),
progress_percent=float(progress),
status="training",
message=f"Evaluating... (batch {self._eval_update_counter})",
)
parent.progress_queue.put(update)
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
"""Report eval results once evaluation is done."""
# Reset prediction counter for next eval round
self._eval_update_counter = 0
total_steps = state.max_steps if state.max_steps > 0 else 0
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
eval_loss = 0.0
extra_metrics = {}
if metrics:
eval_loss = float(metrics.get('eval_loss', 0))
for k, v in metrics.items():
if isinstance(v, (int, float)) and k not in ('eval_loss', 'epoch'):
extra_metrics[k] = float(v)
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=total_steps,
current_epoch=float(state.epoch or 0),
total_epochs=float(parent.total_epochs),
eval_loss=eval_loss,
progress_percent=float(progress),
status="training",
message=f"Evaluation complete at step {state.global_step}",
extra_metrics=extra_metrics,
)
parent.progress_queue.put(update)
def on_save(self, args, state, control, **kwargs):
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
status="saving",
message=f"Checkpoint saved at step {state.global_step}",
checkpoint_path=checkpoint_path,
)
parent.progress_queue.put(update)
def on_train_end(self, args, state, control, **kwargs):
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=state.max_steps,
progress_percent=100.0,
status="completed",
message="Training completed",
)
parent.progress_queue.put(update)
return _Callback()
class ActiveJob:
"""Represents an active fine-tuning job."""
def __init__(self, job_id):
self.job_id = job_id
self.progress_queue = queue.Queue()
self.trainer = None
self.thread = None
self.model = None
self.tokenizer = None
self.error = None
self.completed = False
self.stopped = False
def _is_gated_repo_error(exc):
"""Check if an exception is caused by a gated HuggingFace repo requiring authentication."""
try:
from huggingface_hub.utils import GatedRepoError
if isinstance(exc, GatedRepoError):
return True
except ImportError:
pass
msg = str(exc).lower()
if "gated repo" in msg or "access to model" in msg:
return True
if hasattr(exc, 'response') and hasattr(exc.response, 'status_code'):
if exc.response.status_code in (401, 403):
return True
return False
class BackendServicer(backend_pb2_grpc.BackendServicer):
def __init__(self):
self.active_job = None
def Health(self, request, context):
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
def LoadModel(self, request, context):
"""Accept LoadModel — actual model loading happens in StartFineTune."""
return backend_pb2.Result(success=True, message="OK")
def StartFineTune(self, request, context):
if self.active_job is not None and not self.active_job.completed:
return backend_pb2.FineTuneJobResult(
job_id="",
success=False,
message="A fine-tuning job is already running",
)
job_id = request.job_id if request.job_id else str(uuid.uuid4())
job = ActiveJob(job_id)
self.active_job = job
# Start training in background thread
thread = threading.Thread(target=self._run_training, args=(request, job), daemon=True)
job.thread = thread
thread.start()
return backend_pb2.FineTuneJobResult(
job_id=job_id,
success=True,
message="Fine-tuning job started",
)
def _run_training(self, request, job):
try:
self._do_training(request, job)
except Exception as e:
if _is_gated_repo_error(e):
msg = (f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
else:
msg = f"Training failed: {e}"
job.error = msg
job.completed = True
update = backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id,
status="failed",
message=msg,
)
job.progress_queue.put(update)
# Send sentinel
job.progress_queue.put(None)
def _do_training(self, request, job):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, Dataset
extra = dict(request.extra_options)
training_method = request.training_method or "sft"
training_type = request.training_type or "lora"
# Send loading status
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}",
))
# Determine device and dtype
device_map = "auto" if torch.cuda.is_available() else "cpu"
dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16
# HuggingFace token for gated repos (from extra_options or HF_TOKEN env)
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
# Load model
model_kwargs = {"device_map": device_map, "torch_dtype": dtype}
if hf_token:
model_kwargs["token"] = hf_token
if extra.get("trust_remote_code", "false").lower() == "true":
model_kwargs["trust_remote_code"] = True
if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available():
from transformers import BitsAndBytesConfig
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
job.model = model
job.tokenizer = tokenizer
# Apply LoRA if requested
if training_type == "lora":
from peft import LoraConfig, get_peft_model
lora_r = request.adapter_rank if request.adapter_rank > 0 else 16
lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16
lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0
target_modules = list(request.target_modules) if request.target_modules else None
peft_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=target_modules or "all-linear",
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
# Load dataset
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id, status="loading_dataset", message="Loading dataset",
))
dataset_split = request.dataset_split or "train"
if os.path.exists(request.dataset_source):
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split)
elif request.dataset_source.endswith('.csv'):
dataset = load_dataset("csv", data_files=request.dataset_source, split=dataset_split)
else:
dataset = load_dataset(request.dataset_source, split=dataset_split)
else:
dataset = load_dataset(request.dataset_source, split=dataset_split)
# Eval dataset setup
eval_dataset = None
eval_strategy = extra.get("eval_strategy", "steps")
eval_steps = int(extra.get("eval_steps", str(request.save_steps if request.save_steps > 0 else 500)))
if eval_strategy != "no":
eval_split = extra.get("eval_split")
eval_dataset_source = extra.get("eval_dataset_source")
if eval_split:
# Load a specific split as eval dataset
if os.path.exists(request.dataset_source):
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
eval_dataset = load_dataset("json", data_files=request.dataset_source, split=eval_split)
elif request.dataset_source.endswith('.csv'):
eval_dataset = load_dataset("csv", data_files=request.dataset_source, split=eval_split)
else:
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
else:
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
elif eval_dataset_source:
# Load eval dataset from a separate source
eval_dataset = load_dataset(eval_dataset_source, split="train")
else:
# Auto-split from training set
eval_split_ratio = float(extra.get("eval_split_ratio", "0.1"))
split = dataset.train_test_split(test_size=eval_split_ratio)
dataset = split["train"]
eval_dataset = split["test"]
if eval_strategy == "no":
eval_dataset = None
# Training config
output_dir = request.output_dir or f"./output-{job.job_id}"
num_epochs = request.num_epochs if request.num_epochs > 0 else 3
batch_size = request.batch_size if request.batch_size > 0 else 2
lr = request.learning_rate if request.learning_rate > 0 else 2e-4
grad_accum = request.gradient_accumulation_steps if request.gradient_accumulation_steps > 0 else 4
warmup_steps = request.warmup_steps if request.warmup_steps > 0 else 5
weight_decay = request.weight_decay if request.weight_decay > 0 else 0.01
max_steps = request.max_steps if request.max_steps > 0 else -1
save_steps = request.save_steps if request.save_steps > 0 else 500
seed = request.seed if request.seed > 0 else 3407
optimizer = request.optimizer or "adamw_torch"
# Checkpoint save controls
save_total_limit = int(extra.get("save_total_limit", "0")) or None # 0 = unlimited
save_strategy = extra.get("save_strategy", "steps") # steps, epoch, no
# CPU vs GPU training args (can be overridden via extra_options)
use_cpu = not torch.cuda.is_available()
common_train_kwargs = {}
if use_cpu:
common_train_kwargs["use_cpu"] = True
common_train_kwargs["fp16"] = False
common_train_kwargs["bf16"] = False
common_train_kwargs["gradient_checkpointing"] = False
else:
common_train_kwargs["bf16"] = True
common_train_kwargs["gradient_checkpointing"] = request.gradient_checkpointing
# Allow extra_options to override training kwargs
for flag in ("use_cpu", "bf16", "fp16", "gradient_checkpointing"):
if flag in extra:
common_train_kwargs[flag] = extra[flag].lower() == "true"
# Create progress callback
progress_cb = ProgressCallback(job.job_id, job.progress_queue, num_epochs)
# Build save kwargs (shared across all methods)
_save_kwargs = {}
if save_strategy == "steps" and save_steps > 0:
_save_kwargs["save_steps"] = save_steps
_save_kwargs["save_strategy"] = "steps"
elif save_strategy == "epoch":
_save_kwargs["save_strategy"] = "epoch"
elif save_strategy == "no":
_save_kwargs["save_strategy"] = "no"
else:
_save_kwargs["save_steps"] = save_steps
_save_kwargs["save_strategy"] = "steps"
if save_total_limit:
_save_kwargs["save_total_limit"] = save_total_limit
# Eval kwargs
_eval_kwargs = {}
if eval_dataset is not None:
_eval_kwargs["eval_strategy"] = eval_strategy
_eval_kwargs["eval_steps"] = eval_steps
# Common training arguments shared by all methods
_common_args = dict(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
learning_rate=lr,
gradient_accumulation_steps=grad_accum,
warmup_steps=warmup_steps,
weight_decay=weight_decay,
max_steps=max_steps,
seed=seed,
optim=optimizer,
logging_steps=1,
report_to="none",
**_save_kwargs,
**common_train_kwargs,
**_eval_kwargs,
)
# Select trainer based on training method
if training_method == "sft":
from trl import SFTTrainer, SFTConfig
max_length = int(extra.get("max_seq_length", "512"))
packing = extra.get("packing", "false").lower() == "true"
training_args = SFTConfig(
max_length=max_length,
packing=packing,
**_common_args,
)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "dpo":
from trl import DPOTrainer, DPOConfig
beta = float(extra.get("beta", "0.1"))
loss_type = extra.get("loss_type", "sigmoid")
max_length = int(extra.get("max_length", "512"))
training_args = DPOConfig(
beta=beta,
loss_type=loss_type,
max_length=max_length,
**_common_args,
)
trainer = DPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "grpo":
from trl import GRPOTrainer, GRPOConfig
num_generations = int(extra.get("num_generations", "4"))
max_completion_length = int(extra.get("max_completion_length", "256"))
training_args = GRPOConfig(
num_generations=num_generations,
max_completion_length=max_completion_length,
**_common_args,
)
# GRPO requires reward functions passed via extra_options as a JSON list
from reward_functions import build_reward_functions
reward_funcs = []
if extra.get("reward_funcs"):
reward_funcs = build_reward_functions(extra["reward_funcs"])
if not reward_funcs:
raise ValueError(
"GRPO requires at least one reward function. "
"Specify reward_functions in the request or "
"reward_funcs in extra_options."
)
trainer = GRPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
reward_funcs=reward_funcs,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "orpo":
from trl import ORPOTrainer, ORPOConfig
beta = float(extra.get("beta", "0.1"))
max_length = int(extra.get("max_length", "512"))
training_args = ORPOConfig(
beta=beta,
max_length=max_length,
**_common_args,
)
trainer = ORPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "kto":
from trl import KTOTrainer, KTOConfig
beta = float(extra.get("beta", "0.1"))
max_length = int(extra.get("max_length", "512"))
training_args = KTOConfig(
beta=beta,
max_length=max_length,
**_common_args,
)
trainer = KTOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "rloo":
from trl import RLOOTrainer, RLOOConfig
num_generations = int(extra.get("num_generations", "4"))
max_completion_length = int(extra.get("max_completion_length", "256"))
training_args = RLOOConfig(
num_generations=num_generations,
max_new_tokens=max_completion_length,
**_common_args,
)
trainer = RLOOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "reward":
from trl import RewardTrainer, RewardConfig
max_length = int(extra.get("max_length", "512"))
training_args = RewardConfig(
max_length=max_length,
**_common_args,
)
trainer = RewardTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
else:
raise ValueError(f"Unsupported training method: {training_method}. "
"Supported: sft, dpo, grpo, orpo, kto, rloo, reward")
job.trainer = trainer
# Start training
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id, status="training", message="Training started",
))
resume_ckpt = request.resume_from_checkpoint if request.resume_from_checkpoint else None
trainer.train(resume_from_checkpoint=resume_ckpt)
# Save final model
trainer.save_model(output_dir)
if tokenizer:
tokenizer.save_pretrained(output_dir)
job.completed = True
# Sentinel to signal stream end
job.progress_queue.put(None)
def FineTuneProgress(self, request, context):
if self.active_job is None or self.active_job.job_id != request.job_id:
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(f"Job {request.job_id} not found")
return
job = self.active_job
while True:
try:
update = job.progress_queue.get(timeout=1.0)
if update is None:
break
yield update
if update.status in ("completed", "failed", "stopped"):
break
except queue.Empty:
if job.completed or job.stopped:
break
if not context.is_active():
break
continue
def StopFineTune(self, request, context):
# Stopping is handled by killing the process from Go via ShutdownModel.
return backend_pb2.Result(success=True, message="OK")
def ListCheckpoints(self, request, context):
output_dir = request.output_dir
if not os.path.isdir(output_dir):
return backend_pb2.ListCheckpointsResponse(checkpoints=[])
checkpoints = []
for entry in sorted(os.listdir(output_dir)):
if entry.startswith("checkpoint-"):
ckpt_path = os.path.join(output_dir, entry)
if not os.path.isdir(ckpt_path):
continue
step = 0
try:
step = int(entry.split("-")[1])
except (IndexError, ValueError):
pass
# Try to read trainer_state.json for metadata
loss = 0.0
epoch = 0.0
state_file = os.path.join(ckpt_path, "trainer_state.json")
if os.path.exists(state_file):
try:
with open(state_file) as f:
state = json.load(f)
if state.get("log_history"):
last_log = state["log_history"][-1]
loss = last_log.get("loss", 0.0)
epoch = last_log.get("epoch", 0.0)
except Exception:
pass
created_at = time.strftime(
"%Y-%m-%dT%H:%M:%SZ",
time.gmtime(os.path.getmtime(ckpt_path)),
)
checkpoints.append(backend_pb2.CheckpointInfo(
path=ckpt_path,
step=step,
epoch=float(epoch),
loss=float(loss),
created_at=created_at,
))
return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints)
def ExportModel(self, request, context):
export_format = request.export_format or "lora"
output_path = request.output_path
checkpoint_path = request.checkpoint_path
# Extract HF token for gated model access
extra = dict(request.extra_options) if request.extra_options else {}
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
if not checkpoint_path or not os.path.isdir(checkpoint_path):
return backend_pb2.Result(success=False, message=f"Checkpoint not found: {checkpoint_path}")
os.makedirs(output_path, exist_ok=True)
try:
if export_format == "lora":
# Just copy the adapter files
import shutil
for f in os.listdir(checkpoint_path):
src = os.path.join(checkpoint_path, f)
dst = os.path.join(output_path, f)
if os.path.isfile(src):
shutil.copy2(src, dst)
elif export_format in ("merged_16bit", "merged_4bit"):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
base_model_name = request.model
if not base_model_name:
return backend_pb2.Result(success=False, message="Base model name required for merge export")
dtype = torch.float16 if export_format == "merged_16bit" else torch.float32
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, token=hf_token)
model = PeftModel.from_pretrained(base_model, checkpoint_path)
merged = model.merge_and_unload()
merged.save_pretrained(output_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
tokenizer.save_pretrained(output_path)
elif export_format == "gguf":
import torch
import subprocess
import shutil
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
base_model_name = request.model
if not base_model_name:
return backend_pb2.Result(success=False, message="Base model name required for GGUF export")
# Step 1: Merge LoRA into base model
merge_dir = os.path.join(output_path, "_hf_merged")
os.makedirs(merge_dir, exist_ok=True)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, token=hf_token)
model = PeftModel.from_pretrained(base_model, checkpoint_path)
merged = model.merge_and_unload()
merged.save_pretrained(merge_dir)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
tokenizer.save_pretrained(merge_dir)
# Ensure tokenizer.model (SentencePiece) is present in merge_dir.
# Gemma models need this file for GGUF conversion to use the
# SentencePiece path; without it, the script falls back to BPE
# handling which fails on unrecognized pre-tokenizer hashes.
sp_model_path = os.path.join(merge_dir, "tokenizer.model")
if not os.path.exists(sp_model_path):
sp_copied = False
# Method 1: Load the slow tokenizer which keeps the SP model file
try:
slow_tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=False, token=hf_token)
if hasattr(slow_tok, 'vocab_file') and slow_tok.vocab_file and os.path.exists(slow_tok.vocab_file):
import shutil as _shutil
_shutil.copy2(slow_tok.vocab_file, sp_model_path)
sp_copied = True
print(f"Copied tokenizer.model from slow tokenizer cache")
except Exception as e:
print(f"Slow tokenizer method failed: {e}")
# Method 2: Download from HF hub
if not sp_copied:
try:
from huggingface_hub import hf_hub_download
cached_sp = hf_hub_download(repo_id=base_model_name, filename="tokenizer.model", token=hf_token)
import shutil as _shutil
_shutil.copy2(cached_sp, sp_model_path)
sp_copied = True
print(f"Copied tokenizer.model from HF hub")
except Exception as e:
print(f"HF hub download method failed: {e}")
if not sp_copied:
print(f"WARNING: Could not obtain tokenizer.model for {base_model_name}. "
"GGUF conversion may fail for SentencePiece models.")
# Free GPU memory before conversion
del merged, model, base_model
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Step 2: Convert to GGUF using convert_hf_to_gguf.py
quant = request.quantization_method or "auto"
outtype_map = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"}
outtype = outtype_map.get(quant, "f16")
gguf_filename = f"{os.path.basename(output_path)}-{outtype}.gguf"
gguf_path = os.path.join(output_path, gguf_filename)
script_dir = os.path.dirname(os.path.abspath(__file__))
convert_script = os.path.join(script_dir, "convert_hf_to_gguf.py")
if not os.path.exists(convert_script):
return backend_pb2.Result(success=False,
message="convert_hf_to_gguf.py not found. Install the GGUF conversion tools.")
# Log merge_dir contents for debugging conversion issues
merge_files = os.listdir(merge_dir) if os.path.isdir(merge_dir) else []
print(f"Merge dir contents: {merge_files}", flush=True)
env = os.environ.copy()
env["NO_LOCAL_GGUF"] = "1"
cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path]
conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env)
if conv_result.returncode != 0:
diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}"
return backend_pb2.Result(success=False,
message=f"GGUF conversion failed: {diag}")
# Clean up intermediate merged model
shutil.rmtree(merge_dir, ignore_errors=True)
else:
return backend_pb2.Result(success=False, message=f"Unsupported export format: {export_format}")
except Exception as e:
if _is_gated_repo_error(e):
return backend_pb2.Result(success=False,
message=f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
return backend_pb2.Result(success=False, message=f"Export failed: {e}")
return backend_pb2.Result(success=True, message=f"Model exported to {output_path}")
def serve(address):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
options=[
('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()
print(f"TRL fine-tuning backend listening on {address}", file=sys.stderr, flush=True)
# Handle graceful shutdown
def stop(signum, frame):
server.stop(0)
sys.exit(0)
signal.signal(signal.SIGTERM, stop)
signal.signal(signal.SIGINT, stop)
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="TRL fine-tuning gRPC backend")
parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
args = parser.parse_args()
serve(args.addr)

View File

@@ -0,0 +1,37 @@
#!/bin/bash
set -e
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
installRequirements
# Fetch convert_hf_to_gguf.py and gguf package from the same llama.cpp version
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
if [ ! -f "${CONVERT_SCRIPT}" ]; then
echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
curl -L --fail --retry 3 \
"https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
-o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py. GGUF export will not be available."
fi
# Install gguf package from the same llama.cpp commit to keep them in sync
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
if [ "x${USE_PIP:-}" == "xtrue" ]; then
pip install "${GGUF_PIP_SPEC}" || {
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
pip install "gguf>=0.16.0"
}
else
uv pip install "${GGUF_PIP_SPEC}" || {
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
uv pip install "gguf>=0.16.0"
}
fi

View File

@@ -0,0 +1,9 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece

View File

@@ -0,0 +1,9 @@
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece
bitsandbytes

View File

@@ -0,0 +1,9 @@
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece
bitsandbytes

View File

@@ -0,0 +1,3 @@
grpcio==1.78.1
protobuf
certifi

View File

@@ -0,0 +1,236 @@
"""
Built-in reward functions and inline function compiler for GRPO training.
All reward functions follow TRL's signature: (completions, **kwargs) -> list[float]
"""
import json
import re
import math
import string
import functools
# ---------------------------------------------------------------------------
# Built-in reward functions
# ---------------------------------------------------------------------------
def format_reward(completions, **kwargs):
"""Checks for <think>...</think> followed by an answer. Returns 1.0 or 0.0."""
pattern = re.compile(r"<think>.*?</think>\s*\S", re.DOTALL)
return [1.0 if pattern.search(c) else 0.0 for c in completions]
def reasoning_accuracy_reward(completions, **kwargs):
"""Extracts <answer>...</answer> content and compares to the expected answer."""
answers = kwargs.get("answer", [])
if not answers:
return [0.0] * len(completions)
pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
scores = []
for i, c in enumerate(completions):
expected = answers[i] if i < len(answers) else ""
match = pattern.search(c)
if match:
extracted = match.group(1).strip()
scores.append(1.0 if extracted.lower() == str(expected).strip().lower() else 0.0)
else:
scores.append(0.0)
return scores
def length_reward(completions, target_length=200, **kwargs):
"""Score based on proximity to target_length. Returns [0, 1]."""
scores = []
for c in completions:
length = len(c)
if target_length <= 0:
scores.append(0.0)
else:
diff = abs(length - target_length) / target_length
scores.append(max(0.0, 1.0 - diff))
return scores
def xml_tag_reward(completions, **kwargs):
"""Scores properly opened/closed XML tags (<think>, <answer>)."""
tags = ["think", "answer"]
scores = []
for c in completions:
tag_score = 0.0
for tag in tags:
if f"<{tag}>" in c and f"</{tag}>" in c:
tag_score += 0.5
scores.append(min(tag_score, 1.0))
return scores
def no_repetition_reward(completions, n=4, **kwargs):
"""Penalizes n-gram repetition. Returns [0, 1]."""
scores = []
for c in completions:
words = c.split()
if len(words) < n:
scores.append(1.0)
continue
ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)]
unique = len(set(ngrams))
total = len(ngrams)
scores.append(unique / total if total > 0 else 1.0)
return scores
def code_execution_reward(completions, **kwargs):
"""Checks Python code block syntax validity via compile(). Returns 1.0 or 0.0."""
pattern = re.compile(r"```python\s*\n(.*?)```", re.DOTALL)
scores = []
for c in completions:
match = pattern.search(c)
if not match:
scores.append(0.0)
continue
code = match.group(1)
try:
compile(code, "<inline>", "exec")
scores.append(1.0)
except SyntaxError:
scores.append(0.0)
return scores
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
BUILTIN_REGISTRY = {
"format_reward": format_reward,
"reasoning_accuracy_reward": reasoning_accuracy_reward,
"length_reward": length_reward,
"xml_tag_reward": xml_tag_reward,
"no_repetition_reward": no_repetition_reward,
"code_execution_reward": code_execution_reward,
}
# ---------------------------------------------------------------------------
# Inline function compiler
# ---------------------------------------------------------------------------
_SAFE_BUILTINS = {
"len": len, "int": int, "float": float, "str": str, "bool": bool,
"list": list, "dict": dict, "tuple": tuple, "set": set,
"range": range, "enumerate": enumerate, "zip": zip,
"map": map, "filter": filter, "sorted": sorted,
"min": min, "max": max, "sum": sum, "abs": abs, "round": round,
"any": any, "all": all, "isinstance": isinstance, "type": type,
"print": print, "True": True, "False": False, "None": None,
"ValueError": ValueError, "TypeError": TypeError,
"KeyError": KeyError, "IndexError": IndexError,
}
def compile_inline_reward(name, code):
"""Compile user-provided code into a reward function.
The code should be the body of a function that receives
`completions` (list[str]) and `**kwargs`, and returns list[float].
Available modules: re, math, json, string.
"""
func_source = (
f"def _user_reward_{name}(completions, **kwargs):\n"
+ "\n".join(f" {line}" for line in code.splitlines())
)
restricted_globals = {
"__builtins__": _SAFE_BUILTINS,
"re": re,
"math": math,
"json": json,
"string": string,
}
try:
compiled = compile(func_source, f"<inline-reward-{name}>", "exec")
except SyntaxError as e:
raise ValueError(f"Syntax error in inline reward function '{name}': {e}")
exec(compiled, restricted_globals)
func = restricted_globals[f"_user_reward_{name}"]
# Validate with a quick smoke test
try:
result = func(["test"], answer=["test"])
if not isinstance(result, list):
raise ValueError(
f"Inline reward function '{name}' must return a list, got {type(result).__name__}"
)
except Exception as e:
if "must return a list" in str(e):
raise
# Other errors during smoke test are acceptable (e.g. missing kwargs)
pass
return func
# ---------------------------------------------------------------------------
# Dispatcher
# ---------------------------------------------------------------------------
def build_reward_functions(specs_json):
"""Parse a JSON list of reward function specs and return a list of callables.
Each spec is a dict with:
- type: "builtin" or "inline"
- name: function name
- code: (inline only) Python function body
- params: (optional) dict of string params applied via functools.partial
"""
if isinstance(specs_json, str):
specs = json.loads(specs_json)
else:
specs = specs_json
if not isinstance(specs, list):
raise ValueError("reward_funcs must be a JSON array of reward function specs")
reward_funcs = []
for spec in specs:
spec_type = spec.get("type", "builtin")
name = spec.get("name", "")
params = spec.get("params", {})
if spec_type == "builtin":
if name not in BUILTIN_REGISTRY:
available = ", ".join(sorted(BUILTIN_REGISTRY.keys()))
raise ValueError(
f"Unknown builtin reward function '{name}'. Available: {available}"
)
func = BUILTIN_REGISTRY[name]
if params:
# Convert string params to appropriate types
typed_params = {}
for k, v in params.items():
try:
typed_params[k] = int(v)
except (ValueError, TypeError):
try:
typed_params[k] = float(v)
except (ValueError, TypeError):
typed_params[k] = v
func = functools.partial(func, **typed_params)
reward_funcs.append(func)
elif spec_type == "inline":
code = spec.get("code", "")
if not code.strip():
raise ValueError(f"Inline reward function '{name}' has no code")
func = compile_inline_reward(name, code)
reward_funcs.append(func)
else:
raise ValueError(f"Unknown reward function type '{spec_type}'. Use 'builtin' or 'inline'")
return reward_funcs

10
backend/python/trl/run.sh Normal file
View File

@@ -0,0 +1,10 @@
#!/bin/bash
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
startBackend $@

View File

@@ -0,0 +1,58 @@
"""
Test script for the TRL fine-tuning gRPC backend.
"""
import unittest
import subprocess
import time
import grpc
import backend_pb2
import backend_pb2_grpc
class TestBackendServicer(unittest.TestCase):
"""Tests for the TRL fine-tuning gRPC service."""
def setUp(self):
self.service = subprocess.Popen(
["python3", "backend.py", "--addr", "localhost:50051"]
)
time.sleep(10)
def tearDown(self):
self.service.kill()
self.service.wait()
def test_server_startup(self):
"""Test that the server starts and responds to health checks."""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.Health(backend_pb2.HealthMessage())
self.assertEqual(response.message, b'OK')
except Exception as err:
print(err)
self.fail("Server failed to start")
finally:
self.tearDown()
def test_list_checkpoints_empty(self):
"""Test listing checkpoints on a non-existent directory."""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.ListCheckpoints(
backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent")
)
self.assertEqual(len(response.checkpoints), 0)
except Exception as err:
print(err)
self.fail("ListCheckpoints service failed")
finally:
self.tearDown()
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,11 @@
#!/bin/bash
set -e
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
runUnittests

View File

@@ -20,6 +20,10 @@ from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalG
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
"""Check if a string can be converted to float.""" """Check if a string can be converted to float."""
@@ -724,7 +728,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -27,6 +27,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from vllm_omni.entrypoints.omni import Omni from vllm_omni.entrypoints.omni import Omni
from vllm_omni.outputs import OmniRequestOutput from vllm_omni.outputs import OmniRequestOutput
@@ -650,7 +654,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024),
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -12,6 +12,10 @@ import backend_pb2
import backend_pb2_grpc import backend_pb2_grpc
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
@@ -338,7 +342,9 @@ async def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(aio=True),
)
# Add the servicer to the server # Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address # Bind the server to the address

View File

@@ -1,4 +1,4 @@
grpcio==1.78.1 grpcio==1.80.0
protobuf protobuf
certifi certifi
setuptools setuptools

View File

@@ -18,6 +18,10 @@ import backend_pb2_grpc
import torch import torch
import grpc import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s): def is_float(s):
"""Check if a string can be converted to float.""" """Check if a string can be converted to float."""
@@ -297,7 +301,9 @@ def serve(address):
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
]) ],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@@ -8,6 +8,15 @@ else
source $backend_dir/../common/libbackend.sh source $backend_dir/../common/libbackend.sh
fi fi
# The PyTorch CPU/CUDA indexes mirror common packages (e.g. requests) with
# limited, often outdated version sets. uv's default "first-index" strategy
# locks to the first index that carries a package, so it can pick e.g.
# requests==2.28.1 from the PyTorch index instead of a newer version from
# PyPI. voxcpm's transitive deps (datasets>=3 → requests>=2.32.2) need the
# PyPI versions. "unsafe-best-match" is safe here because we control both
# indexes and there is no dependency confusion risk.
export UV_INDEX_STRATEGY=unsafe-best-match
installRequirements installRequirements
if [ "x${USE_PIP}" == "xtrue" ]; then if [ "x${USE_PIP}" == "xtrue" ]; then

View File

@@ -1,6 +1,7 @@
--extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://download.pytorch.org/whl/cpu
torch torch
torchaudio
soundfile soundfile
numpy numpy
voxcpm voxcpm>=1.5.0
torchcodec torchcodec

View File

@@ -5,4 +5,3 @@ certifi
packaging==24.1 packaging==24.1
soundfile soundfile
numpy numpy
voxcpm

Some files were not shown because too many files have changed in this diff Show More