Compare commits

48 Commits

Author SHA1 Message Date
Ettore Di Giacinto
8997ff6042 Fix tests
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-21 01:04:28 +00:00
Ettore Di Giacinto
f1223b45b2 add evals, reorder menu
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-21 00:59:17 +00:00
Ettore Di Giacinto
fa8b1a8673 move fine-tune to generic features
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-20 23:44:28 +00:00
Ettore Di Giacinto
3451dbdccd commit TRL backend, stop by killing process
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-20 20:42:25 +00:00
Ettore Di Giacinto
7b8afc9609 feat(experimental): add fine-tuning endpoint and TRL support
This changeset defines new gRPC signatures for fine-tuning backends and
adds the TRL backend as the initial fine-tuning engine. This implementation also
supports exporting to GGUF and automatically importing it to LocalAI
after fine-tuning.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-20 20:31:45 +00:00
Ettore Di Giacinto
ae4b758a5a feat: add fine-tuning endpoint
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-20 17:09:33 +00:00
Ettore Di Giacinto
9cdbd89c1f chore(agents.md): update with auth/feature gating instructions
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-19 22:52:28 +00:00
LocalAI [bot]
7d81bf0aa3 chore: ⬆️ Update ggml-org/whisper.cpp to 9386f239401074690479731c1e41683fbbeac557 (#9077)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-19 23:27:35 +01:00
Tv
a6d0e29eba fix(openresponses): do not omit required field ORItemParam.Arguments (#9074)
See #9047
2026-03-19 22:04:45 +01:00
LocalAI [bot]
6054d2a91b feat(swagger): update swagger (#9075)
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-19 21:49:40 +01:00
Ettore Di Giacinto
aea21951a2 feat: add users and authentication support (#9061)
* feat(ui): add users and authentication support

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: allow the admin user to impersonate users

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: ui improvements, disable 'Users' button in navbar when no auth is configured

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: add OIDC support

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix: gate models

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: cache requests to optimize speed

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* small UI enhancements

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore(ui): style improvements

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix: cover other paths by auth

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: separate local auth, refactor

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* security hardening, approval mode

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix: fix tests and expectations

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: update localagi/localrecall

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-19 21:40:51 +01:00
LocalAI [bot]
bbe9067227 docs: Add troubleshooting guide for embedding models (fixes #9064) (#9065)
docs: Add troubleshooting guide for embedding models (#9064)

- Add section on using gallery models for embeddings
- Document common issues with embedding model configuration
- Add troubleshooting guide for Qwen3 embedding models
- Include correct configuration examples for Qwen3-Embedding-4B
- Document context size limits and dimension parameters
- Add table of Qwen3 embedding model specifications

Fixes #9064

Signed-off-by: localai-bot <localai-bot@localai.io>
Co-authored-by: localai-bot <localai-bot@localai.io>
2026-03-19 19:41:12 +01:00
LocalAI [bot]
9a9da062e1 chore: ⬆️ Update ggml-org/llama.cpp to 5744d7ec430e2f875a393770195fda530560773f (#9063)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-19 07:58:30 +01:00
LocalAI [bot]
dd1a8b174f chore: ⬆️ Update ggml-org/whisper.cpp to ef3463bb29ef90d25dfabfd1e75993111c52412d (#9062)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-19 07:58:11 +01:00
Richard Palethorpe
cfb7641eea feat(ui, gallery): Show model backends and add searchable model/backend selector (#9060)
* feat(ui, gallery): Display and filter by the backend models use

Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(ui): Add searchable model/backend selector and prevent deleted models from being selected

Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-18 21:14:41 +01:00
Richard Palethorpe
e832efeb9e fix(ui): Refresh model list on deletion (#9059)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-18 14:07:45 +01:00
dependabot[bot]
a42548e9d1 chore(deps): bump playwright from 1.52.0 to 1.58.2 in /core/http/react-ui in the npm_and_yarn group across 1 directory (#9055)
chore(deps): bump playwright

Bumps the npm_and_yarn group with 1 update in the /core/http/react-ui directory: [playwright](https://github.com/microsoft/playwright).


Updates `playwright` from 1.52.0 to 1.58.2
- [Release notes](https://github.com/microsoft/playwright/releases)
- [Commits](https://github.com/microsoft/playwright/compare/v1.52.0...v1.58.2)

---
updated-dependencies:
- dependency-name: playwright
  dependency-version: 1.58.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-18 14:05:59 +01:00
LocalAI [bot]
8615ce28a8 feat: Add standalone agent run mode inspired by LocalAGI (#9056)
- Add 'agent' subcommand with 'run' and 'list' sub-commands
- Support running agents by name from pool.json registry
- Support running agents from JSON config files
- Implement foreground mode with --prompt flag for single-turn interactions
- Reuse AgentPoolService for consistent agent initialization
- Add comprehensive unit tests for config loading and overrides

Fixes #8960

Signed-off-by: localai-bot <localai-bot@users.noreply.github.com>
Co-authored-by: localai-bot <localai-bot@noreply.github.com>
2026-03-18 14:04:20 +01:00
Ettore Di Giacinto
8336efec41 fix(ui): correctly display backend if specified in the model config, re-order MCP buttons (#9053)
fix(ui): correctly display backend if specified in the model config

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-18 09:58:25 +01:00
LocalAI [bot]
8560a1e571 chore: ⬆️ Update ace-step/acestep.cpp to ab020a9aefcd364423e0665da12babc6b0c7b507 (#9046)
⬆️ Update ace-step/acestep.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-18 08:54:15 +01:00
LocalAI [bot]
29c33e6a6a chore: ⬆️ Update ggml-org/whisper.cpp to dc9611662265870df22a7230b7586176a99c1955 (#9045)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-18 08:46:35 +01:00
LocalAI [bot]
a58475dbef chore: ⬆️ Update ggml-org/llama.cpp to ee4801e5a6ee7ee4063144ab44ab4e127f76fba8 (#9044)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-18 08:46:12 +01:00
Tv
8a0edd0809 Always populate ORItemParam.Summary (#9049)
* fix(openresponses): do not omit required fields summary and id

* fix(openresponses): ensure ORItemParam.Summary is never null

Normalize Summary to an empty slice at serialization chokepoints
(sendSSEEvent, bufferEvent, buildORResponse) so it always serializes
as [] instead of null.

Closes #9047
2026-03-18 08:45:46 +01:00
Richard Palethorpe
35d509d8e7 feat(ui): Per model backend logs and various fixes (#9028)
* feat(gallery): Switch to expandable box instead of pop-over and display model files

Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(ui, backends): Add individual backend logging

Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(ui): Set the context settings from the model config

Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-18 08:31:26 +01:00
Ettore Di Giacinto
eef808d921 fix: call .String() on AllowedTools
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-17 23:16:19 +00:00
Ettore Di Giacinto
9d9ea5c1a0 chore(deps): bump skillserver
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-17 21:04:52 +00:00
LocalAI [bot]
e21ad5cfaa chore: ⬆️ Update leejet/stable-diffusion.cpp to 545fac4f3fb0117a4e962b1a04cf933a7e635933 (#9036)
⬆️ Update leejet/stable-diffusion.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-17 18:07:30 +01:00
dependabot[bot]
05ab0c0aa2 chore(deps): bump github.com/google/go-containerregistry from 0.21.1 to 0.21.2 (#9033)
chore(deps): bump github.com/google/go-containerregistry

Bumps [github.com/google/go-containerregistry](https://github.com/google/go-containerregistry) from 0.21.1 to 0.21.2.
- [Release notes](https://github.com/google/go-containerregistry/releases)
- [Commits](https://github.com/google/go-containerregistry/compare/v0.21.1...v0.21.2)

---
updated-dependencies:
- dependency-name: github.com/google/go-containerregistry
  dependency-version: 0.21.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-17 11:43:25 +01:00
dependabot[bot]
e2b6233570 chore(deps): bump actions/upload-artifact from 4 to 7 (#9030)
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4 to 7.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](https://github.com/actions/upload-artifact/compare/v4...v7)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-version: '7'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-17 11:42:49 +01:00
dependabot[bot]
19f995f38f chore(deps): bump github.com/ebitengine/purego from 0.9.1 to 0.10.0 (#9034)
Bumps [github.com/ebitengine/purego](https://github.com/ebitengine/purego) from 0.9.1 to 0.10.0.
- [Release notes](https://github.com/ebitengine/purego/releases)
- [Commits](https://github.com/ebitengine/purego/compare/v0.9.1...v0.10.0)

---
updated-dependencies:
- dependency-name: github.com/ebitengine/purego
  dependency-version: 0.10.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-17 11:42:06 +01:00
dependabot[bot]
ac168bbc60 chore(deps): bump github.com/anthropics/anthropic-sdk-go from 1.26.0 to 1.27.0 (#9035)
chore(deps): bump github.com/anthropics/anthropic-sdk-go

Bumps [github.com/anthropics/anthropic-sdk-go](https://github.com/anthropics/anthropic-sdk-go) from 1.26.0 to 1.27.0.
- [Release notes](https://github.com/anthropics/anthropic-sdk-go/releases)
- [Changelog](https://github.com/anthropics/anthropic-sdk-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/anthropics/anthropic-sdk-go/compare/v1.26.0...v1.27.0)

---
updated-dependencies:
- dependency-name: github.com/anthropics/anthropic-sdk-go
  dependency-version: 1.27.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-17 11:41:36 +01:00
LocalAI [bot]
5c5e537b31 chore: ⬆️ Update ace-step/acestep.cpp to 15740f4301b3ec3020875f1fb975a6cfdb2f6767 (#9038)
⬆️ Update ace-step/acestep.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-17 10:22:53 +01:00
LocalAI [bot]
118bcee196 chore: ⬆️ Update ggml-org/llama.cpp to 9b342d0a9f2f4892daec065491583ec2be129685 (#9039)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-17 10:22:42 +01:00
LocalAI [bot]
3eabd6d1d0 chore: ⬆️ Update ggml-org/whisper.cpp to 79218f51d02ffe70575ef7fba3496dfc7adda027 (#9037)
⬆️ Update ggml-org/whisper.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-17 08:25:31 +01:00
Ettore Di Giacinto
ee96e5e08d chore: refactor endpoints to use same inferencing path, add automatic retrial mechanism in case of errors (#9029)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-16 21:31:02 +01:00
Richard Palethorpe
3d9ccd1ddc fix(ui): Add tracing inline settings back and create UI tests (#9027)
Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-03-16 17:51:06 +01:00
Ettore Di Giacinto
d8161bfe57 fix(api): unescape model names (#9024)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-16 00:56:37 +01:00
Ettore Di Giacinto
5fd42399d4 feat: support streaming mode for tool calls in agent mode, fix interleaved thinking stream (#9023)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-16 00:50:19 +01:00
LocalAI [bot]
b2030255ca chore: ⬆️ Update ggml-org/llama.cpp to 88915cb55c14769738fcab7f1c6eaa6dcc9c2b0c (#9020)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-16 00:10:11 +01:00
LocalAI [bot]
9f903ec06e chore: ⬆️ Update leejet/stable-diffusion.cpp to 862a6586cb6fcec037c14f9ed902329ecec7d990 (#9019)
⬆️ Update leejet/stable-diffusion.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-16 00:09:59 +01:00
Ettore Di Giacinto
4ea461c330 fix(ui): correctly map watchdog fields (#9022)
Fixes: https://github.com/mudler/LocalAI/issues/9018

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-15 22:12:24 +01:00
Ettore Di Giacinto
042a9b8ef6 Remove Table of Contents from README
Removed the Table of Contents section from the README.

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2026-03-15 21:49:13 +01:00
Ettore Di Giacinto
65f1a4154a Revise README structure and content
Updated the README to include new sections and remove outdated content.

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2026-03-15 21:47:08 +01:00
LocalAI [bot]
c6a51289b0 fix: Automatically disable mmap for Intel SYCL backends (#9012) (#9015)
* fix: Automatically disable mmap for Intel SYCL backends

Fixes issue #9012 where Qwen3.5 models fail to load on Intel Arc GPU
with RPC EOF error.

The Intel SYCL backend has a known issue where mmap enabled causes
the backend to hang. This change automatically disables mmap when
detecting Intel or SYCL backends.

References:
- https://github.com/mudler/LocalAI/issues/9012
- Documentation mentions: SYCL hangs when mmap: true is set

* feat: Add logging for mmap auto-disable on Intel SYCL backends

As requested in PR review, add xlog.Info call to log when mmap
is automatically disabled for Intel SYCL backends. This helps
with debugging and confirms the auto-disable logic is working.

---------

Co-authored-by: localai-bot <localai-bot@users.noreply.github.com>
2026-03-15 21:06:35 +01:00
LocalAI [bot]
87525109f1 chore: ⬆️ Update ggml-org/llama.cpp to 3a6f059909ed5dab8587df5df4120315053d57a4 (#9009)
⬆️ Update ggml-org/llama.cpp

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-15 09:46:45 +01:00
Ettore Di Giacinto
c596d8a5d9 fix: Change baseDir assignment to use ModelPath (#9010)
Fixes: https://github.com/mudler/LocalAI/issues/9005

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2026-03-15 09:45:58 +01:00
LocalAI [bot]
d79ad76e48 docs: ⬆️ update docs version mudler/LocalAI (#9008)
⬆️ Update docs version mudler/LocalAI

Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
2026-03-15 08:39:49 +01:00
Ettore Di Giacinto
dde0353432 chore(api): add path to expose collection raw files
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-14 22:31:45 +00:00
209 changed files with 24271 additions and 2894 deletions


@@ -0,0 +1,259 @@
# API Endpoints and Authentication
This guide covers how to add new API endpoints and properly integrate them with the auth/permissions system.
## Architecture overview
Authentication and authorization flow through three layers:
1. **Global auth middleware** (`auth.Middleware` in `core/http/auth/middleware.go`), applied to every request in `core/http/app.go`. Handles session cookies, Bearer tokens, API keys, and legacy API keys. Populates `auth_user` and `auth_role` in the Echo context.
2. **Feature middleware** (`auth.RequireFeature`) — per-feature access control applied to route groups or individual routes. Checks if the authenticated user has the specific feature enabled.
3. **Admin middleware** (`auth.RequireAdmin`) — restricts endpoints to admin users only.
When auth is disabled (no auth DB, no legacy API keys), all middleware becomes pass-through (`auth.NoopMiddleware`).
## Adding a new API endpoint
### Step 1: Create the handler
Write the endpoint handler in the appropriate package under `core/http/endpoints/`. Follow existing patterns:
```go
// core/http/endpoints/localai/my_feature.go
func MyFeatureEndpoint(app *application.Application) echo.HandlerFunc {
	return func(c echo.Context) error {
		// Use auth.GetUser(c) to get the authenticated user (may be nil if auth is disabled)
		user := auth.GetUser(c)
		// Your logic here
		return c.JSON(http.StatusOK, result)
	}
}
```
### Step 2: Register routes
Add routes in the appropriate file under `core/http/routes/`. The file you use depends on the endpoint category:
| File | Category |
|------|----------|
| `routes/openai.go` | OpenAI-compatible API endpoints (`/v1/...`) |
| `routes/localai.go` | LocalAI-specific endpoints (`/api/...`, `/models/...`, `/backends/...`) |
| `routes/agents.go` | Agent pool endpoints (`/api/agents/...`) |
| `routes/auth.go` | Auth endpoints (`/api/auth/...`) |
| `routes/ui_api.go` | UI backend API endpoints |
### Step 3: Apply the right middleware
Choose the appropriate protection level:
#### No auth required (public)
Exempt paths bypass auth entirely. Add to `isExemptPath()` in `middleware.go` or use the `/api/auth/` prefix (always exempt). Use sparingly — most endpoints should require auth.
#### Standard auth (any authenticated user)
The global middleware already handles this. API paths (`/api/`, `/v1/`, etc.) automatically require authentication when auth is enabled. You don't need to add any extra middleware.
```go
router.GET("/v1/my-endpoint", myHandler) // auth enforced by global middleware
```
#### Admin only
Pass `adminMiddleware` to the route. This is set up in `app.go` and passed to `Register*Routes` functions:
```go
// In the Register function signature, accept the middleware:
func RegisterMyRoutes(router *echo.Echo, app *application.Application, adminMiddleware echo.MiddlewareFunc) {
	router.POST("/models/apply", myHandler, adminMiddleware)
}
```
#### Feature-gated
For endpoints that should be toggleable per-user, use feature middleware. There are two approaches:
**Approach A: Route-level middleware** (preferred for groups of related endpoints)
```go
// In app.go, create the feature middleware:
myFeatureMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMyFeature)
// Pass it to the route registration function:
routes.RegisterMyRoutes(e, app, myFeatureMw)
// In the routes file, apply to a group:
g := e.Group("/api/my-feature", myFeatureMw)
g.GET("", listHandler)
g.POST("", createHandler)
```
**Approach B: RouteFeatureRegistry** (preferred for individual OpenAI-compatible endpoints)
Add an entry to `RouteFeatureRegistry` in `core/http/auth/features.go`. The `RequireRouteFeature` global middleware will automatically enforce it:
```go
var RouteFeatureRegistry = []RouteFeature{
	// ... existing entries ...
	{"POST", "/v1/my-endpoint", FeatureMyFeature},
}
```
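To make the registry idea concrete, here is a small sketch of how a method-and-path lookup against such a table might work. It is an assumption-laden illustration: the `routeFeature` field names, the exact-match rule, and the `"chat"` feature name are hypothetical, and the real `RequireRouteFeature` middleware may match routes differently.

```go
package main

import "fmt"

// routeFeature mirrors the three positional fields of the registry entries
// shown above; the field names are assumptions made for this sketch.
type routeFeature struct {
	Method  string
	Pattern string
	Feature string
}

// registry is a hypothetical table; "chat" is an invented feature name.
var registry = []routeFeature{
	{"POST", "/v1/chat/completions", "chat"},
	{"POST", "/v1/my-endpoint", "my_feature"},
}

// featureFor returns the feature that gates a method and path, if any.
// It uses exact matching, which is a simplification.
func featureFor(method, path string) (string, bool) {
	for _, rf := range registry {
		if rf.Method == method && rf.Pattern == path {
			return rf.Feature, true
		}
	}
	return "", false
}

func main() {
	feature, gated := featureFor("POST", "/v1/my-endpoint")
	fmt.Println(feature, gated)

	_, gated = featureFor("GET", "/v1/my-endpoint")
	fmt.Println(gated)
}
```

A table like this keeps the gating rules in one place, so a global middleware can consult it instead of every route wiring its own check.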
## Adding a new feature
When you need a new toggleable feature (not just a new endpoint under an existing feature):
### 1. Define the feature constant
Add to `core/http/auth/permissions.go`:
```go
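const (
	// Add the constant to ONE of the groups below. It is shown twice only to
	// illustrate both placements; declaring it in both would not compile.

	// Agent features (default OFF for new users)
	FeatureMyFeature = "my_feature"

	// OR API features (default ON for new users)
	// FeatureMyFeature = "my_feature"
)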
```
Then add it to the appropriate slice:
```go
// Default OFF — user must be explicitly granted access:
var AgentFeatures = []string{..., FeatureMyFeature}
// Default ON — user has access unless explicitly revoked:
var APIFeatures = []string{..., FeatureMyFeature}
```
### 2. Add feature metadata
In `core/http/auth/features.go`, add to the appropriate `FeatureMetas` function so the admin UI can display it:
```go
func AgentFeatureMetas() []FeatureMeta {
	return []FeatureMeta{
		// ... existing ...
		{FeatureMyFeature, "My Feature", false}, // false = default OFF
	}
}
```
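The default ON and default OFF semantics described above can be sketched as a lookup that prefers an explicit per-user setting and falls back to the feature's default. This is only an illustration under assumptions: the feature names, the `defaults` map, and `hasFeature` are hypothetical and do not reflect how LocalAI stores permissions.

```go
package main

import "fmt"

// defaults holds the per-feature default for users with no explicit setting.
// Both feature names are hypothetical.
var defaults = map[string]bool{
	"my_agent_feature": false, // agent-style feature: default OFF
	"my_api_feature":   true,  // API-style feature: default ON
}

// hasFeature returns the explicit per-user setting when one exists and the
// feature default otherwise. Unknown features resolve to false.
func hasFeature(overrides map[string]bool, feature string) bool {
	if enabled, ok := overrides[feature]; ok {
		return enabled
	}
	return defaults[feature]
}

func main() {
	newUser := map[string]bool{}
	fmt.Println(hasFeature(newUser, "my_agent_feature"), hasFeature(newUser, "my_api_feature"))

	// An admin grants the agent feature and revokes the API feature.
	tuned := map[string]bool{"my_agent_feature": true, "my_api_feature": false}
	fmt.Println(hasFeature(tuned, "my_agent_feature"), hasFeature(tuned, "my_api_feature"))
}
```

Storing only the deviations from the default keeps new users consistent with whatever the defaults are at the time they are evaluated.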
### 3. Wire up the middleware
In `core/http/app.go`:
```go
myFeatureMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMyFeature)
```
Then pass it to the route registration function.
### 4. Register route-feature mappings (if applicable)
If your feature gates standard API endpoints (like `/v1/...`), add entries to `RouteFeatureRegistry` in `features.go` instead of using per-route middleware.
## Accessing the authenticated user in handlers
```go
import "github.com/mudler/LocalAI/core/http/auth"
func MyHandler(c echo.Context) error {
	// Get the user (nil when auth is disabled or unauthenticated)
	user := auth.GetUser(c)
	if user == nil {
		// Handle the unauthenticated case here, or let the middleware handle it
	}
	// Check role (guard against a nil user before dereferencing it)
	if user != nil && user.Role == auth.RoleAdmin {
		// admin-specific logic
	}
	// Check feature access programmatically (when you need conditional behavior, not full blocking)
	if auth.HasFeatureAccess(db, user, auth.FeatureMyFeature) {
		// feature-specific logic
	}
	// Check model access
	if !auth.IsModelAllowed(db, user, modelName) {
		return c.JSON(http.StatusForbidden, ...)
	}
	return nil
}
```
## Middleware composition patterns
Middleware can be composed at different levels. Here are the patterns used in the codebase:
### Group-level middleware (agents pattern)
```go
// All routes in the group share the middleware
g := e.Group("/api/agents", poolReadyMw, agentsMw)
g.GET("", listHandler)
g.POST("", createHandler)
```
### Per-route middleware (localai pattern)
```go
// Individual routes get middleware as extra arguments
router.POST("/models/apply", applyHandler, adminMiddleware)
router.GET("/metrics", metricsHandler, adminMiddleware)
```
### Middleware slice (openai pattern)
```go
// Build a middleware chain for a handler
chatMiddleware := []echo.MiddlewareFunc{
	usageMiddleware,
	traceMiddleware,
	modelFilterMiddleware,
}
app.POST("/v1/chat/completions", chatHandler, chatMiddleware...)
```
## Error response format
Always use `schema.ErrorResponse` for auth/permission errors to stay consistent with the OpenAI-compatible API:
```go
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
	Error: &schema.APIError{
		Message: "feature not enabled for your account",
		Code:    http.StatusForbidden,
		Type:    "authorization_error",
	},
})
Use these HTTP status codes:
- `401 Unauthorized` — no valid credentials provided
- `403 Forbidden` — authenticated but lacking permission
- `429 Too Many Requests` — rate limited (auth endpoints)
## Usage tracking
If your endpoint should be tracked for usage (token counts, request counts), add the `usageMiddleware` to its middleware chain. See `core/http/middleware/usage.go` and how it's applied in `routes/openai.go`.
## Path protection rules
The global auth middleware classifies paths as API paths or non-API paths:
- **API paths** (always require auth when auth is enabled): `/api/`, `/v1/`, `/models/`, `/backends/`, `/backend/`, `/tts`, `/vad`, `/video`, `/stores/`, `/system`, `/ws/`, `/metrics`
- **Exempt paths** (never require auth): `/api/auth/` prefix, anything in `appConfig.PathWithoutAuth`
- **Non-API paths** (UI, static assets): pass through without auth — the React UI handles login redirects client-side
If you add endpoints under a new top-level path prefix, add it to `isAPIPath()` in `middleware.go` to ensure it requires authentication.
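The classification rules above can be sketched as a prefix check in which exemptions are tested before the API prefixes. This is a hedged illustration rather than the code in `middleware.go`: `requiresAuth` and `hasAnyPrefix` are hypothetical helpers, prefix matching for the extra exempt paths is an assumption, and `/api/auth/login` and `/healthz` are invented example paths.

```go
package main

import (
	"fmt"
	"strings"
)

// apiPrefixes lists the API path prefixes named in this guide.
var apiPrefixes = []string{
	"/api/", "/v1/", "/models/", "/backends/", "/backend/", "/tts",
	"/vad", "/video", "/stores/", "/system", "/ws/", "/metrics",
}

// hasAnyPrefix reports whether path starts with any of the given prefixes.
func hasAnyPrefix(path string, prefixes []string) bool {
	for _, p := range prefixes {
		if strings.HasPrefix(path, p) {
			return true
		}
	}
	return false
}

// requiresAuth applies the rules in order: nothing is protected when auth is
// disabled, exempt paths win over API prefixes, and everything else (UI and
// static assets) passes through.
func requiresAuth(path string, extraExempt []string, authEnabled bool) bool {
	if !authEnabled {
		return false
	}
	if strings.HasPrefix(path, "/api/auth/") || hasAnyPrefix(path, extraExempt) {
		return false
	}
	return hasAnyPrefix(path, apiPrefixes)
}

func main() {
	exempt := []string{"/healthz"} // hypothetical PathWithoutAuth entry
	for _, p := range []string{"/v1/chat/completions", "/api/auth/login", "/healthz", "/app/index.html"} {
		fmt.Println(p, requiresAuth(p, exempt, true))
	}
}
```

Checking the exemptions first is what lets `/api/auth/` stay public even though it sits under the protected `/api/` prefix.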
## Checklist
When adding a new endpoint:
- [ ] Handler in `core/http/endpoints/`
- [ ] Route registered in appropriate `core/http/routes/` file
- [ ] Auth level chosen: public / standard / admin / feature-gated
- [ ] If feature-gated: constant in `permissions.go`, metadata in `features.go`, middleware in `app.go`
- [ ] If new path prefix: added to `isAPIPath()` in `middleware.go`
- [ ] If OpenAI-compatible: entry in `RouteFeatureRegistry`
- [ ] If token-counting: `usageMiddleware` added to middleware chain
- [ ] Error responses use `schema.ErrorResponse` format
- [ ] Tests cover both authenticated and unauthenticated access


@@ -0,0 +1,141 @@
# Debugging and Rebuilding Backends
When a backend fails at runtime (e.g. a gRPC method error, a Python import error, or a dependency conflict), use this guide to diagnose, fix, and rebuild.
## Architecture Overview
- **Source directory**: `backend/python/<name>/` (or `backend/go/<name>/`, `backend/cpp/<name>/`)
- **Installed directory**: `backends/<name>/` — this is what LocalAI actually runs. It is populated by `make backends/<name>` which builds a Docker image, exports it, and installs it via `local-ai backends install`.
- **Virtual environment**: `backends/<name>/venv/` — the installed Python venv (for Python backends). The Python binary is at `backends/<name>/venv/bin/python`.
Editing files in `backend/python/<name>/` does **not** affect the running backend until you rebuild with `make backends/<name>`.
## Diagnosing Failures
### 1. Check the logs
Backend gRPC processes log to LocalAI's stdout/stderr. Look for lines tagged with the backend's model ID:
```
GRPC stderr id="trl-finetune-127.0.0.1:37335" line="..."
```
Common error patterns:
- **"Method not implemented"** — the backend is missing a gRPC method that the Go side calls. The model loader (`pkg/model/initializers.go`) always calls `LoadModel` after `Health`; fine-tuning backends must implement it even as a no-op stub.
- **Python import errors / `AttributeError`** — usually a dependency version mismatch (e.g. `pyarrow` removing `PyExtensionType`).
- **"failed to load backend"** — the gRPC process crashed or never started. Check stderr lines for the traceback.
### 2. Test the Python environment directly
You can run the installed venv's Python to check imports without starting the full server:
```bash
backends/<name>/venv/bin/python -c "import datasets; print(datasets.__version__)"
```
If `pip` is missing from the venv, bootstrap it:
```bash
backends/<name>/venv/bin/python -m ensurepip
```
Then use `backends/<name>/venv/bin/python -m pip install ...` to test fixes in the installed venv before committing them to the source requirements.
### 3. Check upstream dependency constraints
When you hit a dependency conflict, check what the main library expects. For example, TRL's upstream `requirements.txt`:
```
https://github.com/huggingface/trl/blob/main/requirements.txt
```
Pin minimum versions in the backend's requirements files to match upstream.
## Common Fixes
### Missing gRPC methods
If the Go side calls a method the backend doesn't implement (e.g. `LoadModel`), add a no-op stub in `backend.py`:
```python
def LoadModel(self, request, context):
    """No-op: actual loading happens elsewhere."""
    return backend_pb2.Result(success=True, message="OK")
```
The gRPC contract requires `LoadModel` to succeed for the model loader to return a usable client, even if the backend doesn't need upfront model loading.
### Dependency version conflicts
Python backends often break when a transitive dependency releases a breaking change (e.g. `pyarrow` removing `PyExtensionType`). Steps:
1. Identify the broken import in the logs
2. Test in the installed venv: `backends/<name>/venv/bin/python -c "import <module>"`
3. Check upstream requirements for version constraints
4. Update **all** requirements files in `backend/python/<name>/`:
- `requirements.txt` — base deps (grpcio, protobuf)
- `requirements-cpu.txt` — CPU-specific (includes PyTorch CPU index)
- `requirements-cublas12.txt` — CUDA 12
- `requirements-cublas13.txt` — CUDA 13
5. Rebuild: `make backends/<name>`
### PyTorch index conflicts (uv resolver)
The Docker build uses `uv` for pip installs. When `--extra-index-url` points to the PyTorch wheel index, `uv` may refuse to fetch packages like `requests` from PyPI if it finds a different version on the PyTorch index first. Fix this by adding `--index-strategy=unsafe-first-match` to `install.sh`:
```bash
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
installRequirements
```
Most Python backends already do this — check `backend/python/transformers/install.sh` or similar for reference.
## Rebuilding
### Rebuild a single backend
```bash
make backends/<name>
```
This runs the Docker build (`Dockerfile.python`), exports the image to `backend-images/<name>.tar`, and installs it into `backends/<name>/`. It also rebuilds the `local-ai` Go binary (without extra tags).
**Important**: If you were previously running with `GO_TAGS=auth`, the `make backends/<name>` step will overwrite your binary without that tag. Rebuild the Go binary afterward:
```bash
GO_TAGS=auth make build
```
### Rebuild and restart
After rebuilding a backend, you must restart LocalAI for it to pick up the new backend files. The backend gRPC process is spawned on demand when the model is first loaded.
```bash
# Kill existing process
kill <pid>
# Restart
./local-ai run --debug [your flags]
```
### Quick iteration (skip Docker rebuild)
For fast iteration on a Python backend's `backend.py` without a full Docker rebuild, you can edit the installed copy directly:
```bash
# Edit the installed copy
vim backends/<name>/backend.py
# Restart LocalAI to respawn the gRPC process
```
This is useful for testing but **does not persist** — the next `make backends/<name>` will overwrite it. Always commit fixes to the source in `backend/python/<name>/`.
## Verification
After fixing and rebuilding:
1. Start LocalAI and confirm the backend registers: look for `Registering backend name="<name>"` in the logs
2. Trigger the operation that failed (e.g. start a fine-tuning job)
3. Watch the GRPC stderr/stdout lines for the backend's model ID
4. Confirm no errors in the traceback
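
Steps 1 and 4 can be partly automated when the logs are captured to a file. The following is a minimal sketch, not part of LocalAI: the helper name `check_backend_log` and the sample log text are hypothetical, the registration string is the one quoted in step 1, and the traceback check assumes the backend is a Python process that prints the standard CPython traceback header.

```python
def check_backend_log(log_text: str, name: str) -> dict:
    """Summarize a captured LocalAI log for one backend (hypothetical helper)."""
    # Registration line format as quoted in step 1 of the checklist.
    registered = f'Registering backend name="{name}"' in log_text
    # Header CPython prints before the traceback of an uncaught exception.
    has_traceback = "Traceback (most recent call last)" in log_text
    return {"registered": registered, "has_traceback": has_traceback}


if __name__ == "__main__":
    # Made-up sample log lines, for illustration only.
    sample = 'Registering backend name="trl"\nGRPC stderr: loading model\n'
    print(check_backend_log(sample, "trl"))
```

A result with `registered` true and `has_traceback` false is only a first signal; the failing operation from step 2 still needs to be triggered and observed.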

@@ -118,6 +118,19 @@ jobs:
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-cpu-trl'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'true'
backend: "trl"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
@@ -366,6 +379,19 @@ jobs:
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-trl'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "trl"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
@@ -757,6 +783,19 @@ jobs:
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-13-trl'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "trl"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'l4t'
cuda-major-version: "13"
cuda-minor-version: "0"

.github/workflows/tests-ui-e2e.yml

@@ -0,0 +1,72 @@
---
name: 'UI E2E Tests'
on:
pull_request:
paths:
- 'core/http/**'
- 'tests/e2e-ui/**'
- 'tests/e2e/mock-backend/**'
push:
branches:
- master
concurrency:
group: ci-tests-ui-e2e-${{ github.head_ref || github.ref }}-${{ github.repository }}
cancel-in-progress: true
jobs:
tests-ui-e2e:
runs-on: ubuntu-latest
strategy:
matrix:
go-version: ['1.26.x']
steps:
- name: Clone
uses: actions/checkout@v6
with:
submodules: true
- name: Setup Go ${{ matrix.go-version }}
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
cache: false
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: '22'
- name: Proto Dependencies
run: |
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
rm protoc.zip
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
- name: System Dependencies
run: |
sudo apt-get update
sudo apt-get install -y build-essential libopus-dev
- name: Build UI test server
run: PATH="$PATH:$HOME/go/bin" make build-ui-test-server
- name: Install Playwright
working-directory: core/http/react-ui
run: |
npm install
npx playwright install --with-deps chromium
- name: Run Playwright tests
working-directory: core/http/react-ui
run: npx playwright test
- name: Upload Playwright report
if: ${{ failure() }}
uses: actions/upload-artifact@v7
with:
name: playwright-report
path: core/http/react-ui/playwright-report/
retention-days: 7
- name: Setup tmate session if tests fail
if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.23
with:
detached: true
connect-timeout-seconds: 180
limit-access-to-actor: true

.gitignore

@@ -72,3 +72,8 @@ core/http/react-ui/dist
# Extracted backend binaries for container-based testing
local-backends/
# UI E2E test artifacts
tests/e2e-ui/ui-test-server
core/http/react-ui/playwright-report/
core/http/react-ui/test-results/

@@ -11,6 +11,8 @@ This file is an index to detailed topic guides in the `.agents/` directory. Read
| [.agents/coding-style.md](.agents/coding-style.md) | Code style, editorconfig, logging, documentation conventions |
| [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing |
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
## Quick Reference

@@ -256,7 +256,7 @@ RUN apt-get update && \
FROM build-requirements AS builder-base
ARG GO_TAGS=""
ARG GO_TAGS="auth"
ARG GRPC_BACKENDS
ARG MAKEFLAGS
ARG LD_FLAGS="-s -w"

@@ -1,5 +1,5 @@
# Disable parallel execution for backend builds
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl
GOCMD=go
GOTEST=$(GOCMD) test
@@ -421,6 +421,7 @@ prepare-test-extra: protogen-python
$(MAKE) -C backend/python/voxcpm
$(MAKE) -C backend/python/whisperx
$(MAKE) -C backend/python/ace-step
$(MAKE) -C backend/python/trl
test-extra: prepare-test-extra
$(MAKE) -C backend/python/transformers test
@@ -440,6 +441,7 @@ test-extra: prepare-test-extra
$(MAKE) -C backend/python/voxcpm test
$(MAKE) -C backend/python/whisperx test
$(MAKE) -C backend/python/ace-step test
$(MAKE) -C backend/python/trl test
DOCKER_IMAGE?=local-ai
IMAGE_TYPE?=core
@@ -572,6 +574,7 @@ BACKEND_VOXCPM = voxcpm|python|.|false|true
BACKEND_WHISPERX = whisperx|python|.|false|true
BACKEND_ACE_STEP = ace-step|python|.|false|true
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
BACKEND_TRL = trl|python|.|false|true
# Helper function to build docker image for a backend
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
@@ -629,12 +632,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
# Pattern rule for docker-save targets
docker-save-%: backend-images
docker save local-ai-backend:$* -o backend-images/$*.tar
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl
########################################################
### Mock Backend for E2E Tests
@@ -646,6 +650,23 @@ build-mock-backend: protogen-go
clean-mock-backend:
rm -f tests/e2e/mock-backend/mock-backend
########################################################
### UI E2E Test Server
########################################################
build-ui-test-server: build-mock-backend react-ui protogen-go
$(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui
test-ui-e2e: build-ui-test-server
cd core/http/react-ui && npm install && npx playwright install --with-deps chromium && npx playwright test
test-ui-e2e-docker:
docker build -t localai-ui-e2e -f tests/e2e-ui/Dockerfile .
docker run --rm localai-ui-e2e
clean-ui-test-server:
rm -f tests/e2e-ui/ui-test-server
########################################################
### END Backends
########################################################

@@ -62,49 +62,16 @@
**LocalAI** is the free, Open Source OpenAI alternative. LocalAI acts as a drop-in replacement REST API that's compatible with OpenAI (Elevenlabs, Anthropic... ) API specifications for local AI inferencing. It allows you to run LLMs and generate images, audio, and more, locally or on-prem with consumer grade hardware, supporting multiple model families. It does not require a GPU. It is created and maintained by [Ettore Di Giacinto](https://github.com/mudler).
<details>
<summary><strong>Table of Contents</strong></summary>
- [Local Stack Family](#local-stack-family)
- [Screenshots / Video](#screenshots--video)
- [Quickstart](#-quickstart)
- [macOS Download](#macos-download)
- [Containers (Docker, podman, ...)](#containers-docker-podman-)
- [Latest project news](#-latest-project-news)
- [Features](#-features)
- [Supported Backends & Acceleration](#-supported-backends--acceleration)
- [Text Generation & Language Models](#text-generation--language-models)
- [Audio & Speech Processing](#audio--speech-processing)
- [Image & Video Generation](#image--video-generation)
- [Specialized AI Tasks](#specialized-ai-tasks)
- [Hardware Acceleration Matrix](#hardware-acceleration-matrix)
- [Community and integrations](#-community-and-integrations)
- [Resources](#-resources)
- [Media, Blogs, Social](#book--media-blogs-social)
- [Autonomous Development Team](#-autonomous-development-team)
- [Citation](#citation)
- [Sponsors](#-sponsors)
- [Individual sponsors](#individual-sponsors)
- [Star history](#-star-history)
- [License](#-license)
- [Acknowledgements](#-acknowledgements)
- [Contributors](#-contributors)
</details>
## Local Stack Family
Liking LocalAI? LocalAI is part of an integrated suite of AI infrastructure tools, you might also like:
- **[LocalAGI](https://github.com/mudler/LocalAGI)** - AI agent orchestration platform with OpenAI Responses API compatibility and advanced agentic capabilities
- **[LocalRecall](https://github.com/mudler/LocalRecall)** - MCP/REST API knowledge base system providing persistent memory and storage for AI agents
- 🆕 **[Cogito](https://github.com/mudler/cogito)** - Go library for building intelligent, co-operative agentic software and LLM-powered workflows, focusing on improving results for small, open source language models that scales to any LLM. Powers LocalAGI and LocalAI MCP/Agentic capabilities
- 🆕 **[Wiz](https://github.com/mudler/wiz)** - Terminal-based AI agent accessible via Ctrl+Space keybinding. Portable, local-LLM friendly shell assistant with TUI/CLI modes, tool execution with approval, MCP protocol support, and multi-shell compatibility (zsh, bash, fish)
- 🆕 **[SkillServer](https://github.com/mudler/skillserver)** - Simple, centralized skills database for AI agents via MCP. Manages skills as Markdown files with MCP server integration, web UI for editing, Git synchronization, and full-text search capabilities
## Screenshots / Video
### Chat, Model gallery
https://github.com/user-attachments/assets/08cbb692-57da-48f7-963d-2e7b43883c18
### Agents
https://github.com/user-attachments/assets/6270b331-e21d-4087-a540-6290006b381a
### Youtube video
<h1 align="center">
@@ -113,29 +80,8 @@ Liking LocalAI? LocalAI is part of an integrated suite of AI infrastructure tool
<br>
</h1>
### Screenshots
| Talk Interface | Generate Audio |
| --- | --- |
| ![LocalAI - Talk](./docs/assets/images/screenshots/screenshot_talk.png) | ![LocalAI - Generate Audio](./docs/assets/images/screenshots/screenshot_tts.png) |
| Models Overview | Generate Images |
| --- | --- |
| ![LocalAI - Models](./docs/assets/images/screenshots/screenshot_gallery.png) | ![LocalAI - Generate Images](./docs/assets/images/screenshots/screenshot_image.png) |
| Chat Interface | Home |
| --- | --- |
| ![LocalAI - Chat](./docs/assets/images/screenshots/screenshot_chat.png) | ![LocalAI - Home](./docs/assets/images/screenshots/screenshot_home.png) |
| Login | Swarm |
| --- | --- |
| ![LocalAI - Login](./docs/assets/images/screenshots/screenshot_login.png) | ![LocalAI - P2P Dashboard](./docs/assets/images/screenshots/screenshot_p2p.png) |
## 💻 Quickstart
### macOS Download:
<a href="https://github.com/mudler/LocalAI/releases/latest/download/LocalAI.dmg">
@@ -143,6 +89,7 @@ Liking LocalAI? LocalAI is part of an integrated suite of AI infrastructure tool
</a>
> Note: the DMGs are not signed by Apple and are flagged as quarantined. See https://github.com/mudler/LocalAI/issues/6268 for a workaround, fix is tracked here: https://github.com/mudler/LocalAI/issues/6244
> Install the DMG and paste this code into terminal: `sudo xattr -d com.apple.quarantine /Applications/LocalAI.app`
### Containers (Docker, podman, ...)

@@ -39,6 +39,13 @@ service Backend {
rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {}
rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
// Fine-tuning RPCs
rpc StartFineTune(FineTuneRequest) returns (FineTuneJobResult) {}
rpc FineTuneProgress(FineTuneProgressRequest) returns (stream FineTuneProgressUpdate) {}
rpc StopFineTune(FineTuneStopRequest) returns (Result) {}
rpc ListCheckpoints(ListCheckpointsRequest) returns (ListCheckpointsResponse) {}
rpc ExportModel(ExportModelRequest) returns (Result) {}
}
// Define the empty request
@@ -528,3 +535,105 @@ message ModelMetadataResponse {
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis
}
// Fine-tuning messages
message FineTuneRequest {
// Model identification
string model = 1; // HF model name or local path
string training_type = 2; // "lora", "loha", "lokr", "full" — what parameters to train
string training_method = 3; // "sft", "dpo", "grpo", "rloo", "reward", "kto", "orpo", "network_training"
// Adapter config (universal across LoRA/LoHa/LoKr for LLM + diffusion)
int32 adapter_rank = 10; // LoRA rank (r), default 16
int32 adapter_alpha = 11; // scaling factor, default 16
float adapter_dropout = 12; // default 0.0
repeated string target_modules = 13; // layer names to adapt
// Universal training hyperparameters
float learning_rate = 20; // default 2e-4
int32 num_epochs = 21; // default 3
int32 batch_size = 22; // default 2
int32 gradient_accumulation_steps = 23; // default 4
int32 warmup_steps = 24; // default 5
int32 max_steps = 25; // 0 = use epochs
int32 save_steps = 26; // 0 = only save final
float weight_decay = 27; // default 0.01
bool gradient_checkpointing = 28;
string optimizer = 29; // adamw_8bit, adamw, sgd, adafactor, prodigy
int32 seed = 30; // default 3407
string mixed_precision = 31; // fp16, bf16, fp8, no
// Dataset
string dataset_source = 40; // HF dataset ID, local file/dir path
string dataset_split = 41; // train, test, etc.
// Output
string output_dir = 50;
string job_id = 51; // client-assigned or auto-generated
// Resume training from a checkpoint
string resume_from_checkpoint = 55; // path to checkpoint dir to resume from
// Backend-specific AND method-specific extensibility
map<string, string> extra_options = 60;
}
message FineTuneJobResult {
string job_id = 1;
bool success = 2;
string message = 3;
}
message FineTuneProgressRequest {
string job_id = 1;
}
message FineTuneProgressUpdate {
string job_id = 1;
int32 current_step = 2;
int32 total_steps = 3;
float current_epoch = 4;
float total_epochs = 5;
float loss = 6;
float learning_rate = 7;
float grad_norm = 8;
float eval_loss = 9;
float eta_seconds = 10;
float progress_percent = 11;
string status = 12; // queued, caching, loading_model, loading_dataset, training, saving, completed, failed, stopped
string message = 13;
string checkpoint_path = 14; // set when a checkpoint is saved
string sample_path = 15; // set when a sample is generated (video/image backends)
map<string, float> extra_metrics = 16; // method-specific metrics
}
message FineTuneStopRequest {
string job_id = 1;
bool save_checkpoint = 2;
}
message ListCheckpointsRequest {
string output_dir = 1;
}
message ListCheckpointsResponse {
repeated CheckpointInfo checkpoints = 1;
}
message CheckpointInfo {
string path = 1;
int32 step = 2;
float epoch = 3;
float loss = 4;
string created_at = 5;
}
message ExportModelRequest {
string checkpoint_path = 1;
string output_path = 2;
string export_format = 3; // lora, loha, lokr, merged_16bit, merged_4bit, gguf, diffusers
string quantization_method = 4; // for GGUF: q4_k_m, q5_k_m, q8_0, f16, etc.
string model = 5; // base model name (for merge operations)
map<string, string> extra_options = 6;
}

@@ -1,5 +1,5 @@
LLAMA_VERSION?=e30f1fdf74ea9238ff562901aa974c75aab6619b
LLAMA_VERSION?=5744d7ec430e2f875a393770195fda530560773f
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
CMAKE_ARGS?=

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# acestep.cpp version
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
ACESTEP_CPP_VERSION?=5aa065445541094cba934299cd498bbb9fa5c434
ACESTEP_CPP_VERSION?=ab020a9aefcd364423e0665da12babc6b0c7b507
SO_TARGET?=libgoacestepcpp.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

@@ -112,6 +112,7 @@ func TestLoadModel(t *testing.T) {
resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
ModelFile: mainModelPath,
ModelPath: modelDir,
Options: []string{
"text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
"dit_model:acestep-v15-turbo-Q8_0.gguf",
@@ -151,6 +152,7 @@ func TestSoundGeneration(t *testing.T) {
// Load models
loadResp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
ModelFile: mainModelPath,
ModelPath: modelDir,
Options: []string{
"text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
"dit_model:acestep-v15-turbo-Q8_0.gguf",

@@ -24,7 +24,7 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
lmModel := opts.ModelFile
// Get the base directory from ModelFile for resolving relative paths
baseDir := filepath.Dir(lmModel)
baseDir := opts.ModelPath
var textEncoderModel, ditModel, vaeModel string

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# stablediffusion.cpp (ggml)
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
STABLEDIFFUSION_GGML_VERSION?=d6dd6d7b555c233bb9bc9f20b4751eb8c9269743
STABLEDIFFUSION_GGML_VERSION?=545fac4f3fb0117a4e962b1a04cf933a7e635933
CMAKE_ARGS+=-DGGML_MAX_NAME=128

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
WHISPER_CPP_VERSION?=30c5194c9691e4e9a98b3dea9f19727397d3f46e
WHISPER_CPP_VERSION?=9386f239401074690479731c1e41683fbbeac557
SO_TARGET?=libgowhisper.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

@@ -3029,3 +3029,54 @@
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-voxtral
- &trl
name: "trl"
alias: "trl"
license: apache-2.0
description: |
HuggingFace TRL fine-tuning backend. Supports SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO training methods.
Works on CPU and GPU.
urls:
- https://github.com/huggingface/trl
tags:
- fine-tuning
- LLM
- CPU
- GPU
- CUDA
capabilities:
default: "cpu-trl"
nvidia: "cuda12-trl"
nvidia-cuda-12: "cuda12-trl"
nvidia-cuda-13: "cuda13-trl"
## TRL backend images
- !!merge <<: *trl
name: "cpu-trl"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-trl"
mirrors:
- localai/localai-backends:latest-cpu-trl
- !!merge <<: *trl
name: "cpu-trl-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-trl"
mirrors:
- localai/localai-backends:master-cpu-trl
- !!merge <<: *trl
name: "cuda12-trl"
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda12-trl"
mirrors:
- localai/localai-backends:latest-cublas-cuda12-trl
- !!merge <<: *trl
name: "cuda12-trl-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda12-trl"
mirrors:
- localai/localai-backends:master-cublas-cuda12-trl
- !!merge <<: *trl
name: "cuda13-trl"
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda13-trl"
mirrors:
- localai/localai-backends:latest-cublas-cuda13-trl
- !!merge <<: *trl
name: "cuda13-trl-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda13-trl"
mirrors:
- localai/localai-backends:master-cublas-cuda13-trl

@@ -0,0 +1,26 @@
# Version of llama.cpp to fetch convert_hf_to_gguf.py from (for GGUF export)
LLAMA_CPP_CONVERT_VERSION ?= master
.PHONY: trl
trl:
LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
.PHONY: run
run: trl
@echo "Running trl..."
bash run.sh
@echo "trl run."
.PHONY: test
test: trl
@echo "Testing trl..."
bash test.sh
@echo "trl tested."
.PHONY: protogen-clean
protogen-clean:
$(RM) backend_pb2_grpc.py backend_pb2.py
.PHONY: clean
clean: protogen-clean
rm -rf venv __pycache__

@@ -0,0 +1,860 @@
#!/usr/bin/env python3
"""
TRL fine-tuning backend for LocalAI.
Supports all TRL training methods (SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO)
using standard HuggingFace transformers + PEFT. Works on both CPU and GPU.
"""
import argparse
import json
import os
import queue
import signal
import sys
import threading
import time
import uuid
from concurrent import futures
import grpc
import backend_pb2
import backend_pb2_grpc
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
class ProgressCallback:
"""HuggingFace TrainerCallback that pushes progress updates to a queue."""
def __init__(self, job_id, progress_queue, total_epochs):
self.job_id = job_id
self.progress_queue = progress_queue
self.total_epochs = total_epochs
def get_callback(self):
from transformers import TrainerCallback
parent = self
class _Callback(TrainerCallback):
def __init__(self):
self._train_start_time = None
def on_train_begin(self, args, state, control, **kwargs):
self._train_start_time = time.time()
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None:
return
total_steps = state.max_steps if state.max_steps > 0 else 0
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
eta = 0.0
if state.global_step > 0 and total_steps > 0 and self._train_start_time:
elapsed = time.time() - self._train_start_time
remaining_steps = total_steps - state.global_step
if state.global_step > 0:
eta = remaining_steps * (elapsed / state.global_step)
extra_metrics = {}
for k, v in logs.items():
if isinstance(v, (int, float)) and k not in ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss'):
extra_metrics[k] = float(v)
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=total_steps,
current_epoch=float(logs.get('epoch', 0)),
total_epochs=float(parent.total_epochs),
loss=float(logs.get('loss', 0)),
learning_rate=float(logs.get('learning_rate', 0)),
grad_norm=float(logs.get('grad_norm', 0)),
eval_loss=float(logs.get('eval_loss', 0)),
eta_seconds=float(eta),
progress_percent=float(progress),
status="training",
extra_metrics=extra_metrics,
)
parent.progress_queue.put(update)
def on_prediction_step(self, args, state, control, **kwargs):
"""Send periodic updates during evaluation so the UI doesn't freeze."""
if not hasattr(self, '_eval_update_counter'):
self._eval_update_counter = 0
self._eval_update_counter += 1
# Throttle: send an update every 10 prediction steps
if self._eval_update_counter % 10 != 0:
return
total_steps = state.max_steps if state.max_steps > 0 else 0
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=total_steps,
current_epoch=float(state.epoch or 0),
total_epochs=float(parent.total_epochs),
progress_percent=float(progress),
status="training",
message=f"Evaluating... (batch {self._eval_update_counter})",
)
parent.progress_queue.put(update)
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
"""Report eval results once evaluation is done."""
# Reset prediction counter for next eval round
self._eval_update_counter = 0
total_steps = state.max_steps if state.max_steps > 0 else 0
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
eval_loss = 0.0
extra_metrics = {}
if metrics:
eval_loss = float(metrics.get('eval_loss', 0))
for k, v in metrics.items():
if isinstance(v, (int, float)) and k not in ('eval_loss', 'epoch'):
extra_metrics[k] = float(v)
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=total_steps,
current_epoch=float(state.epoch or 0),
total_epochs=float(parent.total_epochs),
eval_loss=eval_loss,
progress_percent=float(progress),
status="training",
message=f"Evaluation complete at step {state.global_step}",
extra_metrics=extra_metrics,
)
parent.progress_queue.put(update)
def on_save(self, args, state, control, **kwargs):
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
status="saving",
message=f"Checkpoint saved at step {state.global_step}",
checkpoint_path=checkpoint_path,
)
parent.progress_queue.put(update)
def on_train_end(self, args, state, control, **kwargs):
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=state.max_steps,
progress_percent=100.0,
status="completed",
message="Training completed",
)
parent.progress_queue.put(update)
return _Callback()
class ActiveJob:
"""Represents an active fine-tuning job."""
def __init__(self, job_id):
self.job_id = job_id
self.progress_queue = queue.Queue()
self.trainer = None
self.thread = None
self.model = None
self.tokenizer = None
self.error = None
self.completed = False
self.stopped = False
def _is_gated_repo_error(exc):
"""Check if an exception is caused by a gated HuggingFace repo requiring authentication."""
try:
from huggingface_hub.utils import GatedRepoError
if isinstance(exc, GatedRepoError):
return True
except ImportError:
pass
msg = str(exc).lower()
if "gated repo" in msg or "access to model" in msg:
return True
if hasattr(exc, 'response') and hasattr(exc.response, 'status_code'):
if exc.response.status_code in (401, 403):
return True
return False
class BackendServicer(backend_pb2_grpc.BackendServicer):
def __init__(self):
self.active_job = None
def Health(self, request, context):
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
def LoadModel(self, request, context):
"""Accept LoadModel — actual model loading happens in StartFineTune."""
return backend_pb2.Result(success=True, message="OK")
def StartFineTune(self, request, context):
if self.active_job is not None and not self.active_job.completed:
return backend_pb2.FineTuneJobResult(
job_id="",
success=False,
message="A fine-tuning job is already running",
)
job_id = request.job_id if request.job_id else str(uuid.uuid4())
job = ActiveJob(job_id)
self.active_job = job
# Start training in background thread
thread = threading.Thread(target=self._run_training, args=(request, job), daemon=True)
job.thread = thread
thread.start()
return backend_pb2.FineTuneJobResult(
job_id=job_id,
success=True,
message="Fine-tuning job started",
)
def _run_training(self, request, job):
try:
self._do_training(request, job)
except Exception as e:
if _is_gated_repo_error(e):
msg = (f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
else:
msg = f"Training failed: {e}"
job.error = msg
job.completed = True
update = backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id,
status="failed",
message=msg,
)
job.progress_queue.put(update)
# Send sentinel
job.progress_queue.put(None)
def _do_training(self, request, job):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, Dataset
extra = dict(request.extra_options)
training_method = request.training_method or "sft"
training_type = request.training_type or "lora"
# Send loading status
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}",
))
# Determine device and dtype
device_map = "auto" if torch.cuda.is_available() else "cpu"
dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16
# HuggingFace token for gated repos (from extra_options or HF_TOKEN env)
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
# Load model
model_kwargs = {"device_map": device_map, "torch_dtype": dtype}
if hf_token:
model_kwargs["token"] = hf_token
if extra.get("trust_remote_code", "false").lower() == "true":
model_kwargs["trust_remote_code"] = True
if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available():
from transformers import BitsAndBytesConfig
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
job.model = model
job.tokenizer = tokenizer
# Apply LoRA if requested
if training_type == "lora":
from peft import LoraConfig, get_peft_model
lora_r = request.adapter_rank if request.adapter_rank > 0 else 16
lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16
lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0
target_modules = list(request.target_modules) if request.target_modules else None
peft_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=target_modules or "all-linear",
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
# Load dataset
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id, status="loading_dataset", message="Loading dataset",
))
dataset_split = request.dataset_split or "train"
if os.path.exists(request.dataset_source):
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split)
elif request.dataset_source.endswith('.csv'):
dataset = load_dataset("csv", data_files=request.dataset_source, split=dataset_split)
else:
dataset = load_dataset(request.dataset_source, split=dataset_split)
else:
dataset = load_dataset(request.dataset_source, split=dataset_split)
# Eval dataset setup
eval_dataset = None
eval_strategy = extra.get("eval_strategy", "steps")
eval_steps = int(extra.get("eval_steps", str(request.save_steps if request.save_steps > 0 else 500)))
if eval_strategy != "no":
eval_split = extra.get("eval_split")
eval_dataset_source = extra.get("eval_dataset_source")
if eval_split:
# Load a specific split as eval dataset
if os.path.exists(request.dataset_source):
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
eval_dataset = load_dataset("json", data_files=request.dataset_source, split=eval_split)
elif request.dataset_source.endswith('.csv'):
eval_dataset = load_dataset("csv", data_files=request.dataset_source, split=eval_split)
else:
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
else:
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
elif eval_dataset_source:
# Load eval dataset from a separate source
eval_dataset = load_dataset(eval_dataset_source, split="train")
else:
# Auto-split from training set
eval_split_ratio = float(extra.get("eval_split_ratio", "0.1"))
split = dataset.train_test_split(test_size=eval_split_ratio)
dataset = split["train"]
eval_dataset = split["test"]
if eval_strategy == "no":
eval_dataset = None
# Training config
output_dir = request.output_dir or f"./output-{job.job_id}"
num_epochs = request.num_epochs if request.num_epochs > 0 else 3
batch_size = request.batch_size if request.batch_size > 0 else 2
lr = request.learning_rate if request.learning_rate > 0 else 2e-4
grad_accum = request.gradient_accumulation_steps if request.gradient_accumulation_steps > 0 else 4
warmup_steps = request.warmup_steps if request.warmup_steps > 0 else 5
weight_decay = request.weight_decay if request.weight_decay > 0 else 0.01
max_steps = request.max_steps if request.max_steps > 0 else -1
save_steps = request.save_steps if request.save_steps > 0 else 500
seed = request.seed if request.seed > 0 else 3407
optimizer = request.optimizer or "adamw_torch"
# Checkpoint save controls
save_total_limit = int(extra.get("save_total_limit", "0")) or None # 0 = unlimited
save_strategy = extra.get("save_strategy", "steps") # steps, epoch, no
# CPU vs GPU training args (can be overridden via extra_options)
use_cpu = not torch.cuda.is_available()
common_train_kwargs = {}
if use_cpu:
common_train_kwargs["use_cpu"] = True
common_train_kwargs["fp16"] = False
common_train_kwargs["bf16"] = False
common_train_kwargs["gradient_checkpointing"] = False
else:
common_train_kwargs["bf16"] = True
common_train_kwargs["gradient_checkpointing"] = request.gradient_checkpointing
# Allow extra_options to override training kwargs
for flag in ("use_cpu", "bf16", "fp16", "gradient_checkpointing"):
if flag in extra:
common_train_kwargs[flag] = extra[flag].lower() == "true"
# Create progress callback
progress_cb = ProgressCallback(job.job_id, job.progress_queue, num_epochs)
# Build save kwargs (shared across all methods)
_save_kwargs = {}
if save_strategy == "steps" and save_steps > 0:
_save_kwargs["save_steps"] = save_steps
_save_kwargs["save_strategy"] = "steps"
elif save_strategy == "epoch":
_save_kwargs["save_strategy"] = "epoch"
elif save_strategy == "no":
_save_kwargs["save_strategy"] = "no"
else:
_save_kwargs["save_steps"] = save_steps
_save_kwargs["save_strategy"] = "steps"
if save_total_limit:
_save_kwargs["save_total_limit"] = save_total_limit
# Eval kwargs
_eval_kwargs = {}
if eval_dataset is not None:
_eval_kwargs["eval_strategy"] = eval_strategy
_eval_kwargs["eval_steps"] = eval_steps
# Common training arguments shared by all methods
_common_args = dict(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
learning_rate=lr,
gradient_accumulation_steps=grad_accum,
warmup_steps=warmup_steps,
weight_decay=weight_decay,
max_steps=max_steps,
seed=seed,
optim=optimizer,
logging_steps=1,
report_to="none",
**_save_kwargs,
**common_train_kwargs,
**_eval_kwargs,
)
# Select trainer based on training method
if training_method == "sft":
from trl import SFTTrainer, SFTConfig
max_length = int(extra.get("max_seq_length", "512"))
packing = extra.get("packing", "false").lower() == "true"
training_args = SFTConfig(
max_length=max_length,
packing=packing,
**_common_args,
)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "dpo":
from trl import DPOTrainer, DPOConfig
beta = float(extra.get("beta", "0.1"))
loss_type = extra.get("loss_type", "sigmoid")
max_length = int(extra.get("max_length", "512"))
training_args = DPOConfig(
beta=beta,
loss_type=loss_type,
max_length=max_length,
**_common_args,
)
trainer = DPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "grpo":
from trl import GRPOTrainer, GRPOConfig
num_generations = int(extra.get("num_generations", "4"))
max_completion_length = int(extra.get("max_completion_length", "256"))
training_args = GRPOConfig(
num_generations=num_generations,
max_completion_length=max_completion_length,
**_common_args,
)
# GRPO requires reward functions passed via extra_options as a JSON list
from reward_functions import build_reward_functions
reward_funcs = []
if extra.get("reward_funcs"):
reward_funcs = build_reward_functions(extra["reward_funcs"])
if not reward_funcs:
raise ValueError(
"GRPO requires at least one reward function. "
"Specify reward_functions in the request or "
"reward_funcs in extra_options."
)
trainer = GRPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
reward_funcs=reward_funcs,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "orpo":
from trl import ORPOTrainer, ORPOConfig
beta = float(extra.get("beta", "0.1"))
max_length = int(extra.get("max_length", "512"))
training_args = ORPOConfig(
beta=beta,
max_length=max_length,
**_common_args,
)
trainer = ORPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "kto":
from trl import KTOTrainer, KTOConfig
beta = float(extra.get("beta", "0.1"))
max_length = int(extra.get("max_length", "512"))
training_args = KTOConfig(
beta=beta,
max_length=max_length,
**_common_args,
)
trainer = KTOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "rloo":
from trl import RLOOTrainer, RLOOConfig
num_generations = int(extra.get("num_generations", "4"))
max_completion_length = int(extra.get("max_completion_length", "256"))
training_args = RLOOConfig(
num_generations=num_generations,
max_new_tokens=max_completion_length,
**_common_args,
)
trainer = RLOOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "reward":
from trl import RewardTrainer, RewardConfig
max_length = int(extra.get("max_length", "512"))
training_args = RewardConfig(
max_length=max_length,
**_common_args,
)
trainer = RewardTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
else:
raise ValueError(f"Unsupported training method: {training_method}. "
"Supported: sft, dpo, grpo, orpo, kto, rloo, reward")
job.trainer = trainer
# Start training
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id, status="training", message="Training started",
))
resume_ckpt = request.resume_from_checkpoint if request.resume_from_checkpoint else None
trainer.train(resume_from_checkpoint=resume_ckpt)
# Save final model
trainer.save_model(output_dir)
if tokenizer:
tokenizer.save_pretrained(output_dir)
job.completed = True
# Sentinel to signal stream end
job.progress_queue.put(None)
def FineTuneProgress(self, request, context):
if self.active_job is None or self.active_job.job_id != request.job_id:
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(f"Job {request.job_id} not found")
return
job = self.active_job
while True:
try:
update = job.progress_queue.get(timeout=1.0)
if update is None:
break
yield update
if update.status in ("completed", "failed", "stopped"):
break
except queue.Empty:
if job.completed or job.stopped:
break
if not context.is_active():
break
continue
def StopFineTune(self, request, context):
# Stopping is handled by killing the process from Go via ShutdownModel.
return backend_pb2.Result(success=True, message="OK")
def ListCheckpoints(self, request, context):
output_dir = request.output_dir
if not os.path.isdir(output_dir):
return backend_pb2.ListCheckpointsResponse(checkpoints=[])
checkpoints = []
for entry in sorted(os.listdir(output_dir)):
if entry.startswith("checkpoint-"):
ckpt_path = os.path.join(output_dir, entry)
if not os.path.isdir(ckpt_path):
continue
step = 0
try:
step = int(entry.split("-")[1])
except (IndexError, ValueError):
pass
# Try to read trainer_state.json for metadata
loss = 0.0
epoch = 0.0
state_file = os.path.join(ckpt_path, "trainer_state.json")
if os.path.exists(state_file):
try:
with open(state_file) as f:
state = json.load(f)
if state.get("log_history"):
last_log = state["log_history"][-1]
loss = last_log.get("loss", 0.0)
epoch = last_log.get("epoch", 0.0)
except Exception:
pass
created_at = time.strftime(
"%Y-%m-%dT%H:%M:%SZ",
time.gmtime(os.path.getmtime(ckpt_path)),
)
checkpoints.append(backend_pb2.CheckpointInfo(
path=ckpt_path,
step=step,
epoch=float(epoch),
loss=float(loss),
created_at=created_at,
))
return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints)
def ExportModel(self, request, context):
export_format = request.export_format or "lora"
output_path = request.output_path
checkpoint_path = request.checkpoint_path
# Extract HF token for gated model access
extra = dict(request.extra_options) if request.extra_options else {}
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
if not checkpoint_path or not os.path.isdir(checkpoint_path):
return backend_pb2.Result(success=False, message=f"Checkpoint not found: {checkpoint_path}")
os.makedirs(output_path, exist_ok=True)
try:
if export_format == "lora":
# Just copy the adapter files
import shutil
for f in os.listdir(checkpoint_path):
src = os.path.join(checkpoint_path, f)
dst = os.path.join(output_path, f)
if os.path.isfile(src):
shutil.copy2(src, dst)
elif export_format in ("merged_16bit", "merged_4bit"):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
base_model_name = request.model
if not base_model_name:
return backend_pb2.Result(success=False, message="Base model name required for merge export")
dtype = torch.float16 if export_format == "merged_16bit" else torch.float32
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, token=hf_token)
model = PeftModel.from_pretrained(base_model, checkpoint_path)
merged = model.merge_and_unload()
merged.save_pretrained(output_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
tokenizer.save_pretrained(output_path)
elif export_format == "gguf":
import torch
import subprocess
import shutil
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
base_model_name = request.model
if not base_model_name:
return backend_pb2.Result(success=False, message="Base model name required for GGUF export")
# Step 1: Merge LoRA into base model
merge_dir = os.path.join(output_path, "_hf_merged")
os.makedirs(merge_dir, exist_ok=True)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, token=hf_token)
model = PeftModel.from_pretrained(base_model, checkpoint_path)
merged = model.merge_and_unload()
merged.save_pretrained(merge_dir)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
tokenizer.save_pretrained(merge_dir)
# Ensure tokenizer.model (SentencePiece) is present in merge_dir.
# Gemma models need this file for GGUF conversion to use the
# SentencePiece path; without it, the script falls back to BPE
# handling which fails on unrecognized pre-tokenizer hashes.
sp_model_path = os.path.join(merge_dir, "tokenizer.model")
if not os.path.exists(sp_model_path):
sp_copied = False
# Method 1: Load the slow tokenizer which keeps the SP model file
try:
slow_tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=False, token=hf_token)
if hasattr(slow_tok, 'vocab_file') and slow_tok.vocab_file and os.path.exists(slow_tok.vocab_file):
shutil.copy2(slow_tok.vocab_file, sp_model_path)
sp_copied = True
print("Copied tokenizer.model from slow tokenizer cache")
except Exception as e:
print(f"Slow tokenizer method failed: {e}")
# Method 2: Download from HF hub
if not sp_copied:
try:
from huggingface_hub import hf_hub_download
cached_sp = hf_hub_download(repo_id=base_model_name, filename="tokenizer.model", token=hf_token)
shutil.copy2(cached_sp, sp_model_path)
sp_copied = True
print("Copied tokenizer.model from HF hub")
except Exception as e:
print(f"HF hub download method failed: {e}")
if not sp_copied:
print(f"WARNING: Could not obtain tokenizer.model for {base_model_name}. "
"GGUF conversion may fail for SentencePiece models.")
# Free GPU memory before conversion
del merged, model, base_model
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Step 2: Convert to GGUF using convert_hf_to_gguf.py
quant = request.quantization_method or "auto"
outtype_map = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"}
outtype = outtype_map.get(quant, "f16")
gguf_filename = f"{os.path.basename(output_path)}-{outtype}.gguf"
gguf_path = os.path.join(output_path, gguf_filename)
script_dir = os.path.dirname(os.path.abspath(__file__))
convert_script = os.path.join(script_dir, "convert_hf_to_gguf.py")
if not os.path.exists(convert_script):
return backend_pb2.Result(success=False,
message="convert_hf_to_gguf.py not found. Install the GGUF conversion tools.")
# Log merge_dir contents for debugging conversion issues
merge_files = os.listdir(merge_dir) if os.path.isdir(merge_dir) else []
print(f"Merge dir contents: {merge_files}", flush=True)
env = os.environ.copy()
env["NO_LOCAL_GGUF"] = "1"
cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path]
conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env)
if conv_result.returncode != 0:
diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}"
return backend_pb2.Result(success=False,
message=f"GGUF conversion failed: {diag}")
# Clean up intermediate merged model
shutil.rmtree(merge_dir, ignore_errors=True)
else:
return backend_pb2.Result(success=False, message=f"Unsupported export format: {export_format}")
except Exception as e:
if _is_gated_repo_error(e):
return backend_pb2.Result(success=False,
message=f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
return backend_pb2.Result(success=False, message=f"Export failed: {e}")
return backend_pb2.Result(success=True, message=f"Model exported to {output_path}")
def serve(address):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
options=[
('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
],
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()
print(f"TRL fine-tuning backend listening on {address}", file=sys.stderr, flush=True)
# Handle graceful shutdown
def stop(signum, frame):
server.stop(0)
sys.exit(0)
signal.signal(signal.SIGTERM, stop)
signal.signal(signal.SIGINT, stop)
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="TRL fine-tuning gRPC backend")
parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
args = parser.parse_args()
serve(args.addr)
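
The `FineTuneProgress` handler above drains `job.progress_queue` until it sees a `None` sentinel, and falls back to the job's `completed` flag when the queue stays empty. The following is a minimal sketch of that producer/consumer pattern using only the standard library. `Job`, `run_training`, and `stream_progress` are hypothetical stand-ins for illustration, not the backend's own classes, and plain dicts take the place of `FineTuneProgressUpdate` messages.

```python
import queue
import threading
from dataclasses import dataclass, field


@dataclass
class Job:
    """Hypothetical stand-in for the backend's job object."""
    progress_queue: queue.Queue = field(default_factory=queue.Queue)
    completed: bool = False


def run_training(job: Job, steps: int) -> None:
    """Producer: push one update per step, then a None sentinel."""
    for step in range(1, steps + 1):
        job.progress_queue.put({"status": "training", "step": step})
    job.completed = True
    job.progress_queue.put(None)  # sentinel: the stream is over


def stream_progress(job: Job, poll_timeout: float = 0.1):
    """Consumer: yield updates until the sentinel arrives."""
    while True:
        try:
            update = job.progress_queue.get(timeout=poll_timeout)
        except queue.Empty:
            # Nothing queued; stop only if the producer has already finished.
            if job.completed:
                break
            continue
        if update is None:
            break
        yield update


job = Job()
worker = threading.Thread(target=run_training, args=(job, 3))
worker.start()
updates = list(stream_progress(job))
worker.join()
```

The short `get` timeout is what lets a real handler also check whether the client is still connected between updates, rather than blocking indefinitely on an idle queue.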


@@ -0,0 +1,37 @@
#!/bin/bash
set -e
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
installRequirements
# Fetch convert_hf_to_gguf.py and gguf package from the same llama.cpp version
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
if [ ! -f "${CONVERT_SCRIPT}" ]; then
echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
curl -L --fail --retry 3 \
"https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
-o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py. GGUF export will not be available."
fi
# Install gguf package from the same llama.cpp commit to keep them in sync
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
if [ "x${USE_PIP:-}" == "xtrue" ]; then
pip install "${GGUF_PIP_SPEC}" || {
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
pip install "gguf>=0.16.0"
}
else
uv pip install "${GGUF_PIP_SPEC}" || {
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
uv pip install "gguf>=0.16.0"
}
fi


@@ -0,0 +1,9 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece


@@ -0,0 +1,9 @@
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece
bitsandbytes


@@ -0,0 +1,9 @@
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece
bitsandbytes


@@ -0,0 +1,3 @@
grpcio==1.78.1
protobuf
certifi


@@ -0,0 +1,236 @@
"""
Built-in reward functions and inline function compiler for GRPO training.
All reward functions follow TRL's signature: (completions, **kwargs) -> list[float]
"""
import json
import re
import math
import string
import functools
# ---------------------------------------------------------------------------
# Built-in reward functions
# ---------------------------------------------------------------------------
def format_reward(completions, **kwargs):
    """Checks for <think>...</think> followed by an answer. Returns 1.0 or 0.0."""
    pattern = re.compile(r"<think>.*?</think>\s*\S", re.DOTALL)
    return [1.0 if pattern.search(c) else 0.0 for c in completions]


def reasoning_accuracy_reward(completions, **kwargs):
    """Extracts <answer>...</answer> content and compares to the expected answer."""
    answers = kwargs.get("answer", [])
    if not answers:
        return [0.0] * len(completions)
    pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
    scores = []
    for i, c in enumerate(completions):
        expected = answers[i] if i < len(answers) else ""
        match = pattern.search(c)
        if match:
            extracted = match.group(1).strip()
            scores.append(1.0 if extracted.lower() == str(expected).strip().lower() else 0.0)
        else:
            scores.append(0.0)
    return scores


def length_reward(completions, target_length=200, **kwargs):
    """Score based on proximity to target_length. Returns [0, 1]."""
    scores = []
    for c in completions:
        length = len(c)
        if target_length <= 0:
            scores.append(0.0)
        else:
            diff = abs(length - target_length) / target_length
            scores.append(max(0.0, 1.0 - diff))
    return scores


def xml_tag_reward(completions, **kwargs):
    """Scores properly opened/closed XML tags (<think>, <answer>)."""
    tags = ["think", "answer"]
    scores = []
    for c in completions:
        tag_score = 0.0
        for tag in tags:
            if f"<{tag}>" in c and f"</{tag}>" in c:
                tag_score += 0.5
        scores.append(min(tag_score, 1.0))
    return scores


def no_repetition_reward(completions, n=4, **kwargs):
    """Penalizes n-gram repetition. Returns [0, 1]."""
    scores = []
    for c in completions:
        words = c.split()
        if len(words) < n:
            scores.append(1.0)
            continue
        ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)]
        unique = len(set(ngrams))
        total = len(ngrams)
        scores.append(unique / total if total > 0 else 1.0)
    return scores


def code_execution_reward(completions, **kwargs):
    """Checks Python code block syntax validity via compile(). Returns 1.0 or 0.0."""
    pattern = re.compile(r"```python\s*\n(.*?)```", re.DOTALL)
    scores = []
    for c in completions:
        match = pattern.search(c)
        if not match:
            scores.append(0.0)
            continue
        code = match.group(1)
        try:
            compile(code, "<inline>", "exec")
            scores.append(1.0)
        except SyntaxError:
            scores.append(0.0)
    return scores
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
BUILTIN_REGISTRY = {
"format_reward": format_reward,
"reasoning_accuracy_reward": reasoning_accuracy_reward,
"length_reward": length_reward,
"xml_tag_reward": xml_tag_reward,
"no_repetition_reward": no_repetition_reward,
"code_execution_reward": code_execution_reward,
}
# ---------------------------------------------------------------------------
# Inline function compiler
# ---------------------------------------------------------------------------
_SAFE_BUILTINS = {
"len": len, "int": int, "float": float, "str": str, "bool": bool,
"list": list, "dict": dict, "tuple": tuple, "set": set,
"range": range, "enumerate": enumerate, "zip": zip,
"map": map, "filter": filter, "sorted": sorted,
"min": min, "max": max, "sum": sum, "abs": abs, "round": round,
"any": any, "all": all, "isinstance": isinstance, "type": type,
"print": print, "True": True, "False": False, "None": None,
"ValueError": ValueError, "TypeError": TypeError,
"KeyError": KeyError, "IndexError": IndexError,
}
def compile_inline_reward(name, code):
    """Compile user-provided code into a reward function.

    The code should be the body of a function that receives
    `completions` (list[str]) and `**kwargs`, and returns list[float].
    Available modules: re, math, json, string.
    """
    func_source = (
        f"def _user_reward_{name}(completions, **kwargs):\n"
        + "\n".join(f"    {line}" for line in code.splitlines())
    )
    restricted_globals = {
        "__builtins__": _SAFE_BUILTINS,
        "re": re,
        "math": math,
        "json": json,
        "string": string,
    }
    try:
        compiled = compile(func_source, f"<inline-reward-{name}>", "exec")
    except SyntaxError as e:
        raise ValueError(f"Syntax error in inline reward function '{name}': {e}")
    exec(compiled, restricted_globals)
    func = restricted_globals[f"_user_reward_{name}"]

    # Validate with a quick smoke test. Errors raised by the user code itself
    # are acceptable here (e.g. missing kwargs); only the return type is checked.
    try:
        result = func(["test"], answer=["test"])
    except Exception:
        return func
    if not isinstance(result, list):
        raise ValueError(
            f"Inline reward function '{name}' must return a list, got {type(result).__name__}"
        )
    return func
# ---------------------------------------------------------------------------
# Dispatcher
# ---------------------------------------------------------------------------
def build_reward_functions(specs_json):
    """Parse a JSON list of reward function specs and return a list of callables.

    Each spec is a dict with:
      - type: "builtin" or "inline"
      - name: function name
      - code: (inline only) Python function body
      - params: (optional) dict of string params applied via functools.partial
    """
    if isinstance(specs_json, str):
        specs = json.loads(specs_json)
    else:
        specs = specs_json
    if not isinstance(specs, list):
        raise ValueError("reward_funcs must be a JSON array of reward function specs")

    reward_funcs = []
    for spec in specs:
        spec_type = spec.get("type", "builtin")
        name = spec.get("name", "")
        params = spec.get("params", {})
        if spec_type == "builtin":
            if name not in BUILTIN_REGISTRY:
                available = ", ".join(sorted(BUILTIN_REGISTRY.keys()))
                raise ValueError(
                    f"Unknown builtin reward function '{name}'. Available: {available}"
                )
            func = BUILTIN_REGISTRY[name]
            if params:
                # Convert string params to appropriate types
                typed_params = {}
                for k, v in params.items():
                    try:
                        typed_params[k] = int(v)
                    except (ValueError, TypeError):
                        try:
                            typed_params[k] = float(v)
                        except (ValueError, TypeError):
                            typed_params[k] = v
                func = functools.partial(func, **typed_params)
            reward_funcs.append(func)
        elif spec_type == "inline":
            code = spec.get("code", "")
            if not code.strip():
                raise ValueError(f"Inline reward function '{name}' has no code")
            func = compile_inline_reward(name, code)
            reward_funcs.append(func)
        else:
            raise ValueError(f"Unknown reward function type '{spec_type}'. Use 'builtin' or 'inline'")
    return reward_funcs
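
`build_reward_functions` takes a JSON array of specs and binds each spec's `params` to the chosen function with `functools.partial`. The sketch below mirrors that flow for a single built-in so that it runs without the rest of the backend. The `coerce_param` helper, the one-entry `registry`, and the sample spec are illustrative assumptions rather than code from this changeset. Unlike the dispatcher above, the helper converts only string values: `params` is documented as a dict of strings, and calling `int()` on a JSON number such as `0.5` would truncate it to `0`.

```python
import functools
import json


def length_reward(completions, target_length=200, **kwargs):
    """Same scoring rule as the built-in: 1.0 at target_length, falling linearly to 0."""
    if target_length <= 0:
        return [0.0] * len(completions)
    return [max(0.0, 1.0 - abs(len(c) - target_length) / target_length) for c in completions]


def coerce_param(value):
    """Hypothetical helper: turn "10" into 10 and "0.5" into 0.5; leave other values alone."""
    if not isinstance(value, str):
        return value
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            continue
    return value


# A spec list as it might arrive in extra_options["reward_funcs"] (assumed shape).
specs_json = '[{"type": "builtin", "name": "length_reward", "params": {"target_length": "10"}}]'

registry = {"length_reward": length_reward}
reward_funcs = []
for spec in json.loads(specs_json):
    func = registry[spec["name"]]
    params = {k: coerce_param(v) for k, v in spec.get("params", {}).items()}
    if params:
        func = functools.partial(func, **params)
    reward_funcs.append(func)

# Score two completions against a target length of 10 characters.
scores = reward_funcs[0](["0123456789", "01234"])
```

Binding parameters with `functools.partial` keeps every reward function on the `(completions, **kwargs)` signature the module docstring describes, so the trainer can call built-in and inline functions the same way.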

backend/python/trl/run.sh Normal file

@@ -0,0 +1,10 @@
#!/bin/bash
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi
startBackend "$@"


@@ -0,0 +1,58 @@
"""
Test script for the TRL fine-tuning gRPC backend.
"""
import unittest
import subprocess
import time
import grpc
import backend_pb2
import backend_pb2_grpc
class TestBackendServicer(unittest.TestCase):
    """Tests for the TRL fine-tuning gRPC service."""

    def setUp(self):
        # unittest runs setUp/tearDown around every test, so the tests below
        # must not call them again (that would start a second server process
        # on the same port and leak the first one).
        self.service = subprocess.Popen(
            ["python3", "backend.py", "--addr", "localhost:50051"]
        )
        time.sleep(10)

    def tearDown(self):
        self.service.kill()
        self.service.wait()

    def test_server_startup(self):
        """Test that the server starts and responds to health checks."""
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")

    def test_list_checkpoints_empty(self):
        """Test listing checkpoints on a non-existent directory."""
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.ListCheckpoints(
                    backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent")
                )
                self.assertEqual(len(response.checkpoints), 0)
        except Exception as err:
            print(err)
            self.fail("ListCheckpoints service failed")


if __name__ == '__main__':
unittest.main()
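
The tests above wait a fixed ten seconds for the server to come up. One alternative, sketched here with only the standard library, is to poll the port until it accepts connections. `wait_for_port` is a hypothetical helper, not part of this backend or of LocalAI's test tooling, and an open TCP port only suggests that the gRPC service is ready, so a real test would probably still follow up with a `Health` call.

```python
import socket
import time


def wait_for_port(host: str, port: int, timeout: float = 10.0, interval: float = 0.1) -> bool:
    """Return True once host:port accepts a TCP connection, False if timeout elapses first."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Not listening yet; back off briefly and retry.
            time.sleep(interval)
    return False


# Demo against a throwaway listener on an OS-assigned port.
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listener.bind(("127.0.0.1", 0))
listener.listen(1)
open_port = listener.getsockname()[1]
ready = wait_for_port("127.0.0.1", open_port, timeout=2.0)
listener.close()
```

In `setUp`, a call such as `wait_for_port("localhost", 50051)` could replace `time.sleep(10)`, returning as soon as the server is reachable instead of always paying the full delay.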


@@ -0,0 +1,11 @@
#!/bin/bash
set -e
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi
runUnittests


@@ -11,6 +11,7 @@ import (
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/xlog"
"gorm.io/gorm"
)
type Application struct {
@@ -22,6 +23,7 @@ type Application struct {
galleryService *services.GalleryService
agentJobService *services.AgentJobService
agentPoolService atomic.Pointer[services.AgentPoolService]
authDB *gorm.DB
watchdogMutex sync.Mutex
watchdogStop chan bool
p2pMutex sync.Mutex
@@ -74,6 +76,11 @@ func (a *Application) AgentPoolService() *services.AgentPoolService {
return a.agentPoolService.Load()
}
// AuthDB returns the auth database connection, or nil if auth is not enabled.
func (a *Application) AuthDB() *gorm.DB {
return a.authDB
}
// StartupConfig returns the original startup configuration (from env vars, before file loading)
func (a *Application) StartupConfig() *config.ApplicationConfig {
return a.startupConfig
@@ -118,9 +125,23 @@ func (a *Application) StartAgentPool() {
xlog.Error("Failed to create agent pool service", "error", err)
return
}
if a.authDB != nil {
aps.SetAuthDB(a.authDB)
}
if err := aps.Start(a.applicationConfig.Context); err != nil {
xlog.Error("Failed to start agent pool", "error", err)
return
}
// Wire per-user scoped services so collections, skills, and jobs are isolated per user
usm := services.NewUserServicesManager(
aps.UserStorage(),
a.applicationConfig,
a.modelLoader,
a.backendLoader,
a.templatesEvaluator,
)
aps.SetUserServicesManager(usm)
a.agentPoolService.Store(aps)
}


@@ -207,7 +207,7 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
envF16 := appConfig.F16 == startupAppConfig.F16
envDebug := appConfig.Debug == startupAppConfig.Debug
envCORS := appConfig.CORS == startupAppConfig.CORS
-envCSRF := appConfig.CSRF == startupAppConfig.CSRF
+envCSRF := appConfig.DisableCSRF == startupAppConfig.DisableCSRF
envCORSAllowOrigins := appConfig.CORSAllowOrigins == startupAppConfig.CORSAllowOrigins
envP2PToken := appConfig.P2PToken == startupAppConfig.P2PToken
envP2PNetworkID := appConfig.P2PNetworkID == startupAppConfig.P2PNetworkID
@@ -313,7 +313,7 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
appConfig.CORS = *settings.CORS
}
if settings.CSRF != nil && !envCSRF {
-appConfig.CSRF = *settings.CSRF
+appConfig.DisableCSRF = *settings.CSRF
}
if settings.CORSAllowOrigins != nil && !envCORSAllowOrigins {
appConfig.CORSAllowOrigins = *settings.CORSAllowOrigins


@@ -1,6 +1,8 @@
package application
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"os"
@@ -10,6 +12,7 @@ import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/services"
coreStartup "github.com/mudler/LocalAI/core/startup"
"github.com/mudler/LocalAI/internal"
@@ -81,6 +84,45 @@ func New(opts ...config.AppOption) (*Application, error) {
}
}
// Initialize auth database if auth is enabled
if options.Auth.Enabled {
// Auto-generate HMAC secret if not provided
if options.Auth.APIKeyHMACSecret == "" {
secretFile := filepath.Join(options.DataPath, ".hmac_secret")
secret, err := loadOrGenerateHMACSecret(secretFile)
if err != nil {
return nil, fmt.Errorf("failed to initialize HMAC secret: %w", err)
}
options.Auth.APIKeyHMACSecret = secret
}
authDB, err := auth.InitDB(options.Auth.DatabaseURL)
if err != nil {
return nil, fmt.Errorf("failed to initialize auth database: %w", err)
}
application.authDB = authDB
xlog.Info("Auth enabled", "database", options.Auth.DatabaseURL)
// Start session and expired API key cleanup goroutine
go func() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-options.Context.Done():
return
case <-ticker.C:
if err := auth.CleanExpiredSessions(authDB); err != nil {
xlog.Error("failed to clean expired sessions", "error", err)
}
if err := auth.CleanExpiredAPIKeys(authDB); err != nil {
xlog.Error("failed to clean expired API keys", "error", err)
}
}
}
}()
}
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
xlog.Error("error installing models", "error", err)
}
@@ -136,6 +178,8 @@ func New(opts ...config.AppOption) (*Application, error) {
loadRuntimeSettingsFromFile(options)
}
application.ModelLoader().SetBackendLoggingEnabled(options.EnableBackendLogging)
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
@@ -382,6 +426,12 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
}
}
if settings.EnableBackendLogging != nil {
if !options.EnableBackendLogging {
options.EnableBackendLogging = *settings.EnableBackendLogging
}
}
xlog.Debug("Runtime settings loaded from runtime_settings.json")
}
@@ -426,6 +476,31 @@ func initializeWatchdog(application *Application, options *config.ApplicationCon
}
}
// loadOrGenerateHMACSecret loads an HMAC secret from the given file path,
// or generates a random 32-byte secret and persists it if the file doesn't exist.
func loadOrGenerateHMACSecret(path string) (string, error) {
data, err := os.ReadFile(path)
if err == nil {
secret := string(data)
if len(secret) >= 32 {
return secret, nil
}
}
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("failed to generate HMAC secret: %w", err)
}
secret := hex.EncodeToString(b)
if err := os.WriteFile(path, []byte(secret), 0600); err != nil {
return "", fmt.Errorf("failed to persist HMAC secret: %w", err)
}
xlog.Info("Generated new HMAC secret for API key hashing", "path", path)
return secret, nil
}
// migrateDataFiles moves persistent data files from the old config directory
// to the new data directory. Only moves files that exist in src but not in dst.
func migrateDataFiles(srcDir, dstDir string) {


@@ -3,8 +3,10 @@ package backend
import (
"context"
"fmt"
"time"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/trace"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
)
@@ -18,6 +20,7 @@ func Detection(
opts := ModelOptions(modelConfig, appConfig)
detectionModel, err := loader.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
return nil, err
}
@@ -25,9 +28,35 @@ func Detection(
return nil, fmt.Errorf("could not load detection model")
}
var startTime time.Time
if appConfig.EnableTracing {
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
startTime = time.Now()
}
res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{
Src: sourceFile,
})
if appConfig.EnableTracing {
errStr := ""
if err != nil {
errStr = err.Error()
}
trace.RecordBackendTrace(trace.BackendTrace{
Timestamp: startTime,
Duration: time.Since(startTime),
Type: trace.BackendTraceDetection,
ModelName: modelConfig.Name,
Backend: modelConfig.Backend,
Summary: trace.TruncateString(sourceFile, 200),
Error: errStr,
Data: map[string]any{
"source_file": sourceFile,
},
})
}
return res, err
}


@@ -17,6 +17,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConf
inferenceModel, err := loader.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
return nil, err
}


@@ -17,6 +17,7 @@ func ImageGeneration(height, width, step, seed int, positive_prompt, negative_pr
opts...,
)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
return nil, err
}


@@ -38,6 +38,10 @@ type TokenUsage struct {
TimingTokenGeneration float64
}
// ModelInferenceFunc is a test-friendly indirection to call model inference logic.
// Tests can override this variable to provide a stub implementation.
var ModelInferenceFunc = ModelInference
func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64, metadata map[string]string) (func() (LLMResponse, error), error) {
modelFile := c.Model
@@ -61,6 +65,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
opts := ModelOptions(*c, o)
inferenceModel, err := loader.Load(opts...)
if err != nil {
recordModelLoadFailure(o, c.Name, c.Backend, err, map[string]any{"model_file": modelFile})
return nil, err
}


@@ -4,13 +4,33 @@ import (
"math/rand"
"os"
"path/filepath"
"strings"
"time"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/trace"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/xlog"
)
// recordModelLoadFailure records a backend trace when model loading fails.
func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, backend string, err error, data map[string]any) {
if !appConfig.EnableTracing {
return
}
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
trace.RecordBackendTrace(trace.BackendTrace{
Timestamp: time.Now(),
Type: trace.BackendTraceModelLoad,
ModelName: modelName,
Backend: backend,
Summary: "Model load failed",
Error: err.Error(),
Data: data,
})
}
func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
name := c.Name
if name == "" {
@@ -109,6 +129,16 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
mmap = *c.MMap
}
// Intel SYCL backend has issues with mmap enabled
// See: https://github.com/mudler/LocalAI/issues/9012
// Automatically disable mmap for Intel SYCL backends
if c.Backend != "" {
if strings.Contains(strings.ToLower(c.Backend), "intel") || strings.Contains(strings.ToLower(c.Backend), "sycl") {
mmap = false
xlog.Info("Auto-disabling mmap for Intel SYCL backend", "backend", c.Backend)
}
}
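The check above is a case-insensitive substring match on the backend name. The sketch below isolates that decision in two small functions so it can be tested separately. The function names and the example backend names are hypothetical; only the matching rule ("intel" or "sycl" anywhere in the name) comes from the hunk.

```go
package main

import (
	"fmt"
	"strings"
)

// isIntelSYCLBackend reports whether a backend name contains "intel" or
// "sycl", ignoring case. This mirrors the substring check in the hunk above.
func isIntelSYCLBackend(backend string) bool {
	name := strings.ToLower(backend)
	return strings.Contains(name, "intel") || strings.Contains(name, "sycl")
}

// effectiveMMap returns the mmap setting to use for backend: the requested
// value, except that it is forced to false for Intel SYCL backends.
func effectiveMMap(backend string, requested bool) bool {
	if isIntelSYCLBackend(backend) {
		return false
	}
	return requested
}

func main() {
	// Illustrative backend names, not taken from the changeset.
	for _, b := range []string{"", "llama-cpp", "intel-sycl-f16-llama-cpp", "SYCL-f32"} {
		fmt.Printf("%q -> mmap=%v\n", b, effectiveMMap(b, true))
	}
}
```

An empty backend name never matches either substring, so the explicit `c.Backend != ""` guard in the hunk does not change the result.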
ctxSize := 4096
if c.ContextSize != nil {
ctxSize = *c.ContextSize


@@ -15,6 +15,7 @@ func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *
opts := ModelOptions(modelConfig, appConfig)
rerankModel, err := loader.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
return nil, err
}


@@ -37,6 +37,7 @@ func SoundGeneration(
opts := ModelOptions(modelConfig, appConfig)
soundGenModel, err := loader.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
return "", nil, err
}


@@ -18,6 +18,7 @@ func TokenMetrics(
opts := ModelOptions(modelConfig, appConfig, model.WithModel(modelFile))
model, err := loader.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
return nil, err
}


@@ -18,6 +18,7 @@ func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.Model
opts := ModelOptions(modelConfig, appConfig)
inferenceModel, err = loader.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
return schema.TokenizeResponse{}, err
}


@@ -23,6 +23,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
transcriptionModel, err := ml.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
return nil, err
}


@@ -31,6 +31,7 @@ func ModelTTS(
opts := ModelOptions(modelConfig, appConfig)
ttsModel, err := loader.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
return "", nil, err
}
@@ -131,6 +132,7 @@ func ModelTTSStream(
opts := ModelOptions(modelConfig, appConfig)
ttsModel, err := loader.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
return err
}


@@ -17,6 +17,7 @@ func VAD(request *schema.VADRequest,
opts := ModelOptions(modelConfig, appConfig)
vadModel, err := ml.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
return nil, err
}


@@ -17,6 +17,7 @@ func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, en
opts...,
)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
return nil, err
}

235
core/cli/agent.go Normal file

@@ -0,0 +1,235 @@
package cli
import (
"context"
"encoding/json"
"fmt"
"os"
"os/signal"
"syscall"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAGI/core/state"
coreTypes "github.com/mudler/LocalAGI/core/types"
"github.com/mudler/xlog"
)
type AgentCMD struct {
Run AgentRunCMD `cmd:"" help:"Run an agent standalone (without the full LocalAI server)"`
List AgentListCMD `cmd:"" help:"List agents in the pool registry"`
}
type AgentRunCMD struct {
Name string `arg:"" optional:"" help:"Agent name to run from the pool registry (pool.json)"`
Config string `short:"c" help:"Path to a JSON agent config file (alternative to loading by name)" type:"path"`
Prompt string `short:"p" help:"Run in foreground mode: send a single prompt and print the response"`
// Agent pool settings (mirrors RunCMD agent flags)
APIURL string `env:"LOCALAI_AGENT_POOL_API_URL" help:"API URL for the agent to call (e.g. http://127.0.0.1:8080)" group:"agents"`
APIKey string `env:"LOCALAI_AGENT_POOL_API_KEY" help:"API key for the agent" group:"agents"`
DefaultModel string `env:"LOCALAI_AGENT_POOL_DEFAULT_MODEL" help:"Default model for the agent" group:"agents"`
MultimodalModel string `env:"LOCALAI_AGENT_POOL_MULTIMODAL_MODEL" help:"Multimodal model for the agent" group:"agents"`
TranscriptionModel string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_MODEL" help:"Transcription model for the agent" group:"agents"`
TranscriptionLanguage string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_LANGUAGE" help:"Transcription language for the agent" group:"agents"`
TTSModel string `env:"LOCALAI_AGENT_POOL_TTS_MODEL" help:"TTS model for the agent" group:"agents"`
StateDir string `env:"LOCALAI_AGENT_POOL_STATE_DIR" default:"agents" help:"State directory containing pool.json" type:"path" group:"agents"`
Timeout string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Agent timeout" group:"agents"`
EnableSkills bool `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service" group:"agents"`
EnableLogs bool `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"`
CustomActionsDir string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory" group:"agents"`
}
func (r *AgentRunCMD) Run(ctx *cliContext.Context) error {
if r.Name == "" && r.Config == "" {
return fmt.Errorf("either an agent name or --config must be provided")
}
agentConfig, err := r.loadAgentConfig()
if err != nil {
return err
}
// Override agent config fields from CLI flags when provided
r.applyOverrides(agentConfig)
xlog.Info("Starting standalone agent", "name", agentConfig.Name)
appConfig := r.buildAppConfig()
poolService, err := services.NewAgentPoolService(appConfig)
if err != nil {
return fmt.Errorf("failed to create agent pool service: %w", err)
}
if err := poolService.Start(appConfig.Context); err != nil {
return fmt.Errorf("failed to start agent pool service: %w", err)
}
defer poolService.Stop()
pool := poolService.Pool()
// Start the agent standalone (does not persist to pool.json)
if err := pool.StartAgentStandalone(agentConfig.Name, agentConfig); err != nil {
return fmt.Errorf("failed to start agent %q: %w", agentConfig.Name, err)
}
ag := pool.GetAgent(agentConfig.Name)
if ag == nil {
return fmt.Errorf("agent %q not found after start", agentConfig.Name)
}
// Foreground mode: send a single prompt and exit
if r.Prompt != "" {
xlog.Info("Sending prompt to agent", "agent", agentConfig.Name)
result := ag.Ask(coreTypes.WithText(r.Prompt))
if result == nil {
return fmt.Errorf("agent returned no result")
}
if result.Error != nil {
return fmt.Errorf("agent error: %w", result.Error)
}
fmt.Println(result.Response)
return nil
}
// Background mode: run until interrupted
xlog.Info("Agent running in background mode. Press Ctrl+C to stop.", "agent", agentConfig.Name)
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
xlog.Info("Shutting down agent", "agent", agentConfig.Name)
return nil
}
func (r *AgentRunCMD) loadAgentConfig() (*state.AgentConfig, error) {
// Load from JSON config file
if r.Config != "" {
data, err := os.ReadFile(r.Config)
if err != nil {
return nil, fmt.Errorf("failed to read config file %q: %w", r.Config, err)
}
var cfg state.AgentConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse config file %q: %w", r.Config, err)
}
if cfg.Name == "" {
return nil, fmt.Errorf("agent config must have a name")
}
return &cfg, nil
}
// Load from pool.json by name
poolFile := r.StateDir + "/pool.json"
data, err := os.ReadFile(poolFile)
if err != nil {
return nil, fmt.Errorf("failed to read pool registry %q: %w", poolFile, err)
}
var pool map[string]state.AgentConfig
if err := json.Unmarshal(data, &pool); err != nil {
return nil, fmt.Errorf("failed to parse pool registry %q: %w", poolFile, err)
}
cfg, ok := pool[r.Name]
if !ok {
available := make([]string, 0, len(pool))
for name := range pool {
available = append(available, name)
}
return nil, fmt.Errorf("agent %q not found in pool registry. Available agents: %v", r.Name, available)
}
cfg.Name = r.Name
return &cfg, nil
}
func (r *AgentRunCMD) applyOverrides(cfg *state.AgentConfig) {
if r.APIURL != "" {
cfg.APIURL = r.APIURL
}
if r.APIKey != "" {
cfg.APIKey = r.APIKey
}
if r.DefaultModel != "" && cfg.Model == "" {
cfg.Model = r.DefaultModel
}
if r.MultimodalModel != "" && cfg.MultimodalModel == "" {
cfg.MultimodalModel = r.MultimodalModel
}
if r.TranscriptionModel != "" && cfg.TranscriptionModel == "" {
cfg.TranscriptionModel = r.TranscriptionModel
}
if r.TranscriptionLanguage != "" && cfg.TranscriptionLanguage == "" {
cfg.TranscriptionLanguage = r.TranscriptionLanguage
}
if r.TTSModel != "" && cfg.TTSModel == "" {
cfg.TTSModel = r.TTSModel
}
}
func (r *AgentRunCMD) buildAppConfig() *config.ApplicationConfig {
appConfig := &config.ApplicationConfig{
Context: context.Background(),
}
appConfig.AgentPool = config.AgentPoolConfig{
Enabled: true,
APIURL: r.APIURL,
APIKey: r.APIKey,
DefaultModel: r.DefaultModel,
MultimodalModel: r.MultimodalModel,
TranscriptionModel: r.TranscriptionModel,
TranscriptionLanguage: r.TranscriptionLanguage,
TTSModel: r.TTSModel,
StateDir: r.StateDir,
Timeout: r.Timeout,
EnableSkills: r.EnableSkills,
EnableLogs: r.EnableLogs,
CustomActionsDir: r.CustomActionsDir,
}
return appConfig
}
type AgentListCMD struct {
StateDir string `env:"LOCALAI_AGENT_POOL_STATE_DIR" default:"agents" help:"State directory containing pool.json" type:"path" group:"agents"`
}
func (r *AgentListCMD) Run(ctx *cliContext.Context) error {
poolFile := r.StateDir + "/pool.json"
data, err := os.ReadFile(poolFile)
if err != nil {
if os.IsNotExist(err) {
fmt.Println("No agents found (pool.json does not exist)")
return nil
}
return fmt.Errorf("failed to read pool registry %q: %w", poolFile, err)
}
var pool map[string]state.AgentConfig
if err := json.Unmarshal(data, &pool); err != nil {
return fmt.Errorf("failed to parse pool registry %q: %w", poolFile, err)
}
if len(pool) == 0 {
fmt.Println("No agents found in pool registry")
return nil
}
fmt.Printf("Agents in %s:\n", poolFile)
for name, cfg := range pool {
model := cfg.Model
if model == "" {
model = "(default)"
}
desc := cfg.Description
if desc == "" {
desc = "(no description)"
}
fmt.Printf(" - %s [model: %s] %s\n", name, model, desc)
}
return nil
}
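The listing loop in `AgentListCMD.Run` ranges over a map, and Go does not define map iteration order, so the agents may print in a different order on each run. A minimal sketch of a sorted listing is shown below. The `agentInfo` type is a hypothetical stand-in for the two `state.AgentConfig` fields the listing prints, and `formatAgents` is not part of the changeset.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// agentInfo is a hypothetical stand-in for the fields of state.AgentConfig
// that the listing prints.
type agentInfo struct {
	Model       string
	Description string
}

// formatAgents renders one line per agent, sorted by agent name, using the
// same placeholders as the listing in the diff for empty fields.
func formatAgents(pool map[string]agentInfo) string {
	names := make([]string, 0, len(pool))
	for name := range pool {
		names = append(names, name)
	}
	sort.Strings(names)

	var sb strings.Builder
	for _, name := range names {
		cfg := pool[name]
		model := cfg.Model
		if model == "" {
			model = "(default)"
		}
		desc := cfg.Description
		if desc == "" {
			desc = "(no description)"
		}
		fmt.Fprintf(&sb, "  - %s [model: %s] %s\n", name, model, desc)
	}
	return sb.String()
}

func main() {
	fmt.Print(formatAgents(map[string]agentInfo{
		"agent-b": {Model: "gpt-4"},
		"agent-a": {Model: "llama3", Description: "First agent"},
	}))
}
```

Returning a string instead of printing directly also makes the output easy to assert on in a test.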

214
core/cli/agent_test.go Normal file

@@ -0,0 +1,214 @@
package cli
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/mudler/LocalAGI/core/state"
)
func TestAgentRunCMD_LoadAgentConfigFromFile(t *testing.T) {
// Create a temporary agent config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "agent.json")
cfg := state.AgentConfig{
Name: "test-agent",
Model: "llama3",
SystemPrompt: "You are a helpful assistant",
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(configFile, data, 0644); err != nil {
t.Fatal(err)
}
cmd := &AgentRunCMD{
Config: configFile,
StateDir: tmpDir,
}
loaded, err := cmd.loadAgentConfig()
if err != nil {
t.Fatalf("loadAgentConfig() error: %v", err)
}
if loaded.Name != "test-agent" {
t.Errorf("expected name %q, got %q", "test-agent", loaded.Name)
}
if loaded.Model != "llama3" {
t.Errorf("expected model %q, got %q", "llama3", loaded.Model)
}
}
func TestAgentRunCMD_LoadAgentConfigFromPool(t *testing.T) {
tmpDir := t.TempDir()
pool := map[string]state.AgentConfig{
"my-agent": {
Model: "gpt-4",
Description: "A test agent",
SystemPrompt: "Hello",
},
"other-agent": {
Model: "llama3",
},
}
data, err := json.MarshalIndent(pool, "", " ")
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil {
t.Fatal(err)
}
cmd := &AgentRunCMD{
Name: "my-agent",
StateDir: tmpDir,
}
loaded, err := cmd.loadAgentConfig()
if err != nil {
t.Fatalf("loadAgentConfig() error: %v", err)
}
if loaded.Name != "my-agent" {
t.Errorf("expected name %q, got %q", "my-agent", loaded.Name)
}
if loaded.Model != "gpt-4" {
t.Errorf("expected model %q, got %q", "gpt-4", loaded.Model)
}
}
func TestAgentRunCMD_LoadAgentConfigFromPool_NotFound(t *testing.T) {
tmpDir := t.TempDir()
pool := map[string]state.AgentConfig{
"existing-agent": {Model: "llama3"},
}
data, err := json.MarshalIndent(pool, "", " ")
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil {
t.Fatal(err)
}
cmd := &AgentRunCMD{
Name: "nonexistent",
StateDir: tmpDir,
}
_, err = cmd.loadAgentConfig()
if err == nil {
t.Fatal("expected error for missing agent, got nil")
}
}
func TestAgentRunCMD_LoadAgentConfigNoNameOrConfig(t *testing.T) {
cmd := &AgentRunCMD{
StateDir: t.TempDir(),
}
_, err := cmd.loadAgentConfig()
if err == nil {
t.Fatal("expected error when no pool.json exists, got nil")
}
}
func TestAgentRunCMD_ApplyOverrides(t *testing.T) {
cfg := &state.AgentConfig{
Name: "test",
}
cmd := &AgentRunCMD{
APIURL: "http://localhost:9090",
APIKey: "secret",
DefaultModel: "my-model",
}
cmd.applyOverrides(cfg)
if cfg.APIURL != "http://localhost:9090" {
t.Errorf("expected APIURL %q, got %q", "http://localhost:9090", cfg.APIURL)
}
if cfg.APIKey != "secret" {
t.Errorf("expected APIKey %q, got %q", "secret", cfg.APIKey)
}
if cfg.Model != "my-model" {
t.Errorf("expected Model %q, got %q", "my-model", cfg.Model)
}
}
func TestAgentRunCMD_ApplyOverridesDoesNotOverwriteExisting(t *testing.T) {
cfg := &state.AgentConfig{
Name: "test",
Model: "existing-model",
}
cmd := &AgentRunCMD{
DefaultModel: "override-model",
}
cmd.applyOverrides(cfg)
if cfg.Model != "existing-model" {
t.Errorf("expected Model to remain %q, got %q", "existing-model", cfg.Model)
}
}
func TestAgentRunCMD_LoadConfigMissingName(t *testing.T) {
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "agent.json")
// Agent config with no name
cfg := state.AgentConfig{
Model: "llama3",
}
data, _ := json.MarshalIndent(cfg, "", " ")
os.WriteFile(configFile, data, 0644)
cmd := &AgentRunCMD{
Config: configFile,
StateDir: tmpDir,
}
_, err := cmd.loadAgentConfig()
if err == nil {
t.Fatal("expected error for config with no name, got nil")
}
}
func TestAgentListCMD_NoPoolFile(t *testing.T) {
cmd := &AgentListCMD{
StateDir: t.TempDir(),
}
// Should not error, just print "no agents found"
err := cmd.Run(nil)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
}
func TestAgentListCMD_WithAgents(t *testing.T) {
tmpDir := t.TempDir()
pool := map[string]state.AgentConfig{
"agent-a": {Model: "llama3", Description: "First agent"},
"agent-b": {Model: "gpt-4"},
}
data, _ := json.MarshalIndent(pool, "", " ")
os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)
cmd := &AgentListCMD{
StateDir: tmpDir,
}
err := cmd.Run(nil)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
}


@@ -17,6 +17,7 @@ var CLI struct {
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
Util UtilCMD `cmd:"" help:"Utility commands"`
Agent AgentCMD `cmd:"" help:"Run agents standalone without the full LocalAI server"`
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
Completion CompletionCMD `cmd:"" help:"Generate shell completion scripts for bash, zsh, or fish"`
}


@@ -58,7 +58,7 @@ type RunCMD struct {
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
DisableCSRF bool `env:"LOCALAI_DISABLE_CSRF" help:"Disable CSRF middleware (enabled by default)" group:"api"`
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disables the web user interface. When set to true, the server will only expose API endpoints without serving the web interface" group:"api"`
@@ -121,6 +121,24 @@ type RunCMD struct {
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`
// Fine-tuning
EnableFineTuning bool `env:"LOCALAI_ENABLE_FINETUNING" default:"false" help:"Enable fine-tuning support" group:"finetuning"`
// Authentication
AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"`
AuthDatabaseURL string `env:"LOCALAI_AUTH_DATABASE_URL,DATABASE_URL" help:"Database URL for auth (postgres:// or file path for SQLite). Defaults to {DataPath}/database.db" group:"auth"`
GitHubClientID string `env:"GITHUB_CLIENT_ID" help:"GitHub OAuth App Client ID (auto-enables auth when set)" group:"auth"`
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET" help:"GitHub OAuth App Client Secret" group:"auth"`
OIDCIssuer string `env:"LOCALAI_OIDC_ISSUER" help:"OIDC issuer URL for auto-discovery" group:"auth"`
OIDCClientID string `env:"LOCALAI_OIDC_CLIENT_ID" help:"OIDC Client ID (auto-enables auth)" group:"auth"`
OIDCClientSecret string `env:"LOCALAI_OIDC_CLIENT_SECRET" help:"OIDC Client Secret" group:"auth"`
AuthBaseURL string `env:"LOCALAI_BASE_URL" help:"Base URL for OAuth callbacks (e.g. http://localhost:8080)" group:"auth"`
AuthAdminEmail string `env:"LOCALAI_ADMIN_EMAIL" help:"Email address to auto-promote to admin role" group:"auth"`
AuthRegistrationMode string `env:"LOCALAI_REGISTRATION_MODE" default:"open" help:"Registration mode: 'open' (default), 'approval', or 'invite' (invite code required)" group:"auth"`
DisableLocalAuth bool `env:"LOCALAI_DISABLE_LOCAL_AUTH" default:"false" help:"Disable local email/password registration and login (use with OAuth/OIDC-only setups)" group:"auth"`
AuthAPIKeyHMACSecret string `env:"LOCALAI_AUTH_HMAC_SECRET" help:"HMAC secret for API key hashing (auto-generated if empty)" group:"auth"`
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
Version bool
}
@@ -165,7 +183,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithBackendGalleries(r.BackendGalleries),
config.WithCors(r.CORS),
config.WithCorsAllowOrigins(r.CORSAllowOrigins),
config.WithCsrf(r.CSRF),
config.WithDisableCSRF(r.DisableCSRF),
config.WithThreads(r.Threads),
config.WithUploadLimitMB(r.UploadLimit),
config.WithApiKeys(r.APIKeys),
@@ -311,6 +329,51 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
opts = append(opts, config.WithAgentHubURL(r.AgentHubURL))
}
// Fine-tuning
if r.EnableFineTuning {
opts = append(opts, config.EnableFineTuning)
}
// Authentication
authEnabled := r.AuthEnabled || r.GitHubClientID != "" || r.OIDCClientID != ""
if authEnabled {
opts = append(opts, config.WithAuthEnabled(true))
dbURL := r.AuthDatabaseURL
if dbURL == "" {
dbURL = filepath.Join(r.DataPath, "database.db")
}
opts = append(opts, config.WithAuthDatabaseURL(dbURL))
if r.GitHubClientID != "" {
opts = append(opts, config.WithAuthGitHubClientID(r.GitHubClientID))
opts = append(opts, config.WithAuthGitHubClientSecret(r.GitHubClientSecret))
}
if r.OIDCClientID != "" {
opts = append(opts, config.WithAuthOIDCIssuer(r.OIDCIssuer))
opts = append(opts, config.WithAuthOIDCClientID(r.OIDCClientID))
opts = append(opts, config.WithAuthOIDCClientSecret(r.OIDCClientSecret))
}
if r.AuthBaseURL != "" {
opts = append(opts, config.WithAuthBaseURL(r.AuthBaseURL))
}
if r.AuthAdminEmail != "" {
opts = append(opts, config.WithAuthAdminEmail(r.AuthAdminEmail))
}
if r.AuthRegistrationMode != "" {
opts = append(opts, config.WithAuthRegistrationMode(r.AuthRegistrationMode))
}
if r.DisableLocalAuth {
opts = append(opts, config.WithAuthDisableLocalAuth(true))
}
if r.AuthAPIKeyHMACSecret != "" {
opts = append(opts, config.WithAuthAPIKeyHMACSecret(r.AuthAPIKeyHMACSecret))
}
if r.DefaultAPIKeyExpiry != "" {
opts = append(opts, config.WithAuthDefaultAPIKeyExpiry(r.DefaultAPIKeyExpiry))
}
}
if idleWatchDog || busyWatchDog {
opts = append(opts, config.EnableWatchDog)
if idleWatchDog {


@@ -21,6 +21,7 @@ type ApplicationConfig struct {
Debug bool
EnableTracing bool
TracingMaxItems int
EnableBackendLogging bool
GeneratedContentDir string
UploadDir string
@@ -29,7 +30,7 @@ type ApplicationConfig struct {
DynamicConfigsDir string
DynamicConfigsDirPollInterval time.Duration
CORS bool
CSRF bool
DisableCSRF bool
PreloadJSONModels string
PreloadModelsFromPath string
CORSAllowOrigins string
@@ -95,6 +96,29 @@ type ApplicationConfig struct {
// Agent Pool (LocalAGI integration)
AgentPool AgentPoolConfig
// Fine-tuning
FineTuning FineTuningConfig
// Authentication & Authorization
Auth AuthConfig
}
// AuthConfig holds configuration for user authentication and authorization.
type AuthConfig struct {
Enabled bool
DatabaseURL string // "postgres://..." or file path for SQLite
GitHubClientID string
GitHubClientSecret string
OIDCIssuer string // OIDC issuer URL for auto-discovery (e.g. https://accounts.google.com)
OIDCClientID string
OIDCClientSecret string
BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080")
AdminEmail string // auto-promote to admin on login
RegistrationMode string // "open", "approval" (default when empty), "invite"
DisableLocalAuth bool // disable local email/password registration and login
APIKeyHMACSecret string // HMAC secret for API key hashing; auto-generated if empty
DefaultAPIKeyExpiry string // default expiry duration for API keys (e.g. "90d"); empty = no expiry
}
// AgentPoolConfig holds configuration for the LocalAGI agent pool integration.
@@ -121,6 +145,11 @@ type AgentPoolConfig struct {
AgentHubURL string // default: "https://agenthub.localai.io"
}
// FineTuningConfig holds configuration for fine-tuning support.
type FineTuningConfig struct {
Enabled bool
}
type AppOption func(*ApplicationConfig)
func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
@@ -149,6 +178,8 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
"/favicon.svg",
"/readyz",
"/healthz",
"/api/auth/",
"/assets/",
},
}
for _, oo := range o {
@@ -193,9 +224,9 @@ func WithP2PNetworkID(s string) AppOption {
}
}
func WithCsrf(b bool) AppOption {
func WithDisableCSRF(b bool) AppOption {
return func(o *ApplicationConfig) {
o.CSRF = b
o.DisableCSRF = b
}
}
@@ -213,6 +244,10 @@ var EnableTracing = func(o *ApplicationConfig) {
o.EnableTracing = true
}
var EnableBackendLogging = func(o *ApplicationConfig) {
o.EnableBackendLogging = true
}
var EnableWatchDogIdleCheck = func(o *ApplicationConfig) {
o.WatchDog = true
o.WatchDogIdle = true
@@ -706,6 +741,92 @@ func WithAgentHubURL(url string) AppOption {
}
}
// Fine-tuning options
var EnableFineTuning = func(o *ApplicationConfig) {
o.FineTuning.Enabled = true
}
// Auth options
func WithAuthEnabled(enabled bool) AppOption {
return func(o *ApplicationConfig) {
o.Auth.Enabled = enabled
}
}
func WithAuthDatabaseURL(url string) AppOption {
return func(o *ApplicationConfig) {
o.Auth.DatabaseURL = url
}
}
func WithAuthGitHubClientID(clientID string) AppOption {
return func(o *ApplicationConfig) {
o.Auth.GitHubClientID = clientID
}
}
func WithAuthGitHubClientSecret(clientSecret string) AppOption {
return func(o *ApplicationConfig) {
o.Auth.GitHubClientSecret = clientSecret
}
}
func WithAuthBaseURL(baseURL string) AppOption {
return func(o *ApplicationConfig) {
o.Auth.BaseURL = baseURL
}
}
func WithAuthAdminEmail(email string) AppOption {
return func(o *ApplicationConfig) {
o.Auth.AdminEmail = email
}
}
func WithAuthRegistrationMode(mode string) AppOption {
return func(o *ApplicationConfig) {
o.Auth.RegistrationMode = mode
}
}
func WithAuthDisableLocalAuth(disable bool) AppOption {
return func(o *ApplicationConfig) {
o.Auth.DisableLocalAuth = disable
}
}
func WithAuthOIDCIssuer(issuer string) AppOption {
return func(o *ApplicationConfig) {
o.Auth.OIDCIssuer = issuer
}
}
func WithAuthOIDCClientID(clientID string) AppOption {
return func(o *ApplicationConfig) {
o.Auth.OIDCClientID = clientID
}
}
func WithAuthOIDCClientSecret(clientSecret string) AppOption {
return func(o *ApplicationConfig) {
o.Auth.OIDCClientSecret = clientSecret
}
}
func WithAuthAPIKeyHMACSecret(secret string) AppOption {
return func(o *ApplicationConfig) {
o.Auth.APIKeyHMACSecret = secret
}
}
func WithAuthDefaultAPIKeyExpiry(expiry string) AppOption {
return func(o *ApplicationConfig) {
o.Auth.DefaultAPIKeyExpiry = expiry
}
}
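The `WithAuth...` functions above all follow the functional options pattern: each returns a closure that sets one field, and the constructor applies the closures in order. A minimal self-contained sketch of the pattern is shown below. The lowercase types and functions are simplified, hypothetical stand-ins for `ApplicationConfig`, `AuthConfig`, and `AppOption`, reduced to two fields.

```go
package main

import "fmt"

// authConfig and appConfig are hypothetical, reduced stand-ins for the
// AuthConfig and ApplicationConfig types in the diff.
type authConfig struct {
	Enabled     bool
	DatabaseURL string
}

type appConfig struct {
	Auth authConfig
}

// appOption mutates an appConfig; it plays the role of AppOption.
type appOption func(*appConfig)

func withAuthEnabled(enabled bool) appOption {
	return func(o *appConfig) { o.Auth.Enabled = enabled }
}

func withAuthDatabaseURL(url string) appOption {
	return func(o *appConfig) { o.Auth.DatabaseURL = url }
}

// newAppConfig applies the options in order, so a later option overrides an
// earlier one that sets the same field.
func newAppConfig(opts ...appOption) *appConfig {
	cfg := &appConfig{}
	for _, opt := range opts {
		opt(cfg)
	}
	return cfg
}

func main() {
	cfg := newAppConfig(
		withAuthEnabled(true),
		withAuthDatabaseURL("data/database.db"),
	)
	fmt.Printf("%+v\n", cfg.Auth)
}
```

Because options are applied in order, a caller can append an override after a default without the constructor needing to know about either.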
// ToConfigLoaderOptions returns a slice of ConfigLoader Option.
// Some options defined at the application level are going to be passed as defaults for
// all the configuration for the models.
@@ -743,8 +864,9 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
debug := o.Debug
tracingMaxItems := o.TracingMaxItems
enableTracing := o.EnableTracing
enableBackendLogging := o.EnableBackendLogging
cors := o.CORS
csrf := o.CSRF
csrf := o.DisableCSRF
corsAllowOrigins := o.CORSAllowOrigins
p2pToken := o.P2PToken
p2pNetworkID := o.P2PNetworkID
@@ -816,6 +938,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
Debug: &debug,
TracingMaxItems: &tracingMaxItems,
EnableTracing: &enableTracing,
EnableBackendLogging: &enableBackendLogging,
CORS: &cors,
CSRF: &csrf,
CORSAllowOrigins: &corsAllowOrigins,
@@ -944,11 +1067,14 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
if settings.TracingMaxItems != nil {
o.TracingMaxItems = *settings.TracingMaxItems
}
if settings.EnableBackendLogging != nil {
o.EnableBackendLogging = *settings.EnableBackendLogging
}
if settings.CORS != nil {
o.CORS = *settings.CORS
}
if settings.CSRF != nil {
o.CSRF = *settings.CSRF
o.DisableCSRF = *settings.CSRF
}
if settings.CORSAllowOrigins != nil {
o.CORSAllowOrigins = *settings.CORSAllowOrigins


@@ -26,7 +26,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
F16: true,
Debug: true,
CORS: true,
CSRF: true,
DisableCSRF: true,
CORSAllowOrigins: "https://example.com",
P2PToken: "test-token",
P2PNetworkID: "test-network",
@@ -377,7 +377,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
appConfig.ApplyRuntimeSettings(rs)
Expect(appConfig.CORS).To(BeTrue())
Expect(appConfig.CSRF).To(BeTrue())
Expect(appConfig.DisableCSRF).To(BeTrue())
Expect(appConfig.CORSAllowOrigins).To(Equal("https://example.com,https://other.com"))
})
@@ -463,7 +463,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
F16: true,
Debug: false,
CORS: true,
CSRF: false,
DisableCSRF: false,
CORSAllowOrigins: "https://test.com",
P2PToken: "round-trip-token",
P2PNetworkID: "round-trip-network",
@@ -495,7 +495,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
Expect(target.F16).To(Equal(original.F16))
Expect(target.Debug).To(Equal(original.Debug))
Expect(target.CORS).To(Equal(original.CORS))
Expect(target.CSRF).To(Equal(original.CSRF))
Expect(target.DisableCSRF).To(Equal(original.DisableCSRF))
Expect(target.CORSAllowOrigins).To(Equal(original.CORSAllowOrigins))
Expect(target.P2PToken).To(Equal(original.P2PToken))
Expect(target.P2PNetworkID).To(Equal(original.P2PNetworkID))


@@ -36,8 +36,9 @@ type RuntimeSettings struct {
ContextSize *int `json:"context_size,omitempty"`
F16 *bool `json:"f16,omitempty"`
Debug *bool `json:"debug,omitempty"`
EnableTracing *bool `json:"enable_tracing,omitempty"`
TracingMaxItems *int `json:"tracing_max_items,omitempty"`
EnableTracing *bool `json:"enable_tracing,omitempty"`
TracingMaxItems *int `json:"tracing_max_items,omitempty"`
EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"`
// Security/CORS settings
CORS *bool `json:"cors,omitempty"`


@@ -0,0 +1,173 @@
package gallery
import (
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/xsync"
"github.com/mudler/xlog"
"gopkg.in/yaml.v3"
)
// modelConfigCacheEntry holds a cached parsed config_file map from a URL-referenced model config.
type modelConfigCacheEntry struct {
configMap map[string]interface{}
lastUpdated time.Time
}
func (e modelConfigCacheEntry) hasExpired() bool {
return e.lastUpdated.Before(time.Now().Add(-1 * time.Hour))
}
// modelConfigCache caches parsed model config maps keyed by URL.
var modelConfigCache = xsync.NewSyncedMap[string, modelConfigCacheEntry]()
// resolveBackend determines the backend for a GalleryModel by checking (in priority order):
// 1. Overrides["backend"] — highest priority, same as install-time merge
// 2. Inline ConfigFile["backend"] — for models with inline config maps
// 3. URL-referenced config file — fetched, parsed, and cached
//
// The model's URL should already be resolved (local override applied) before calling this.
func resolveBackend(m *GalleryModel, basePath string) string {
// 1. Overrides take priority (matches install-time mergo.WithOverride behavior)
if b, ok := m.Overrides["backend"].(string); ok && b != "" {
return b
}
// 2. Inline config_file map
if b, ok := m.ConfigFile["backend"].(string); ok && b != "" {
return b
}
// 3. Fetch and parse the URL-referenced config
if m.URL != "" {
configMap := fetchModelConfigMap(m.URL, basePath)
if b, ok := configMap["backend"].(string); ok && b != "" {
return b
}
}
return ""
}
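The priority order in resolveBackend (overrides, then inline config, then the fetched config) can be sketched as a "first non-empty string wins" lookup. This is a minimal illustration, not the code from this change: `firstNonEmptyString` is a hypothetical helper, and unlike resolveBackend it takes all three maps up front instead of fetching the URL-referenced config only when the first two sources yield nothing.

```go
package main

import "fmt"

// firstNonEmptyString is a hypothetical helper (not in the diff). It returns
// the first non-empty string stored under key, checking sources in order.
func firstNonEmptyString(key string, sources ...map[string]interface{}) string {
	for _, src := range sources {
		// Reading from a nil map is safe in Go; the type assertion then fails
		// and the loop moves on to the next source.
		if s, ok := src[key].(string); ok && s != "" {
			return s
		}
	}
	return ""
}

func main() {
	overrides := map[string]interface{}{"backend": ""}            // empty: ignored
	inline := map[string]interface{}{"backend": "llama-cpp"}      // first non-empty value
	fetched := map[string]interface{}{"backend": "transformers"} // not consulted
	fmt.Println(firstNonEmptyString("backend", overrides, inline, fetched))
}
```

An empty string or a non-string value under `backend` is treated the same as a missing key, which matches the `ok && b != ""` checks above.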
// fetchModelConfigMap fetches a model config URL, parses the config_file YAML string
// inside it, and returns the result as a map. Results are cached for 1 hour.
// Local file:// URLs skip the cache so edits are picked up immediately.
func fetchModelConfigMap(modelURL, basePath string) map[string]interface{} {
// Check cache (skip for file:// URLs so local edits are picked up immediately)
isLocal := strings.HasPrefix(modelURL, downloader.LocalPrefix)
if !isLocal && modelConfigCache.Exists(modelURL) {
entry := modelConfigCache.Get(modelURL)
if !entry.hasExpired() {
return entry.configMap
}
modelConfigCache.Delete(modelURL)
}
// Reuse existing gallery config fetcher
modelConfig, err := GetGalleryConfigFromURL[ModelConfig](modelURL, basePath)
if err != nil {
xlog.Debug("Failed to fetch model config for backend resolution", "url", modelURL, "error", err)
// Cache the failure for remote URLs to avoid repeated fetch attempts
if !isLocal {
modelConfigCache.Set(modelURL, modelConfigCacheEntry{
configMap: map[string]interface{}{},
lastUpdated: time.Now(),
})
}
return map[string]interface{}{}
}
// Parse the config_file YAML string into a map
configMap := make(map[string]interface{})
if modelConfig.ConfigFile != "" {
if err := yaml.Unmarshal([]byte(modelConfig.ConfigFile), &configMap); err != nil {
xlog.Debug("Failed to parse config_file for backend resolution", "url", modelURL, "error", err)
}
}
// Cache for remote URLs
if !isLocal {
modelConfigCache.Set(modelURL, modelConfigCacheEntry{
configMap: configMap,
lastUpdated: time.Now(),
})
}
return configMap
}
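The caching behaviour of fetchModelConfigMap (one-hour expiry, with failed fetches remembered as an empty map) can be sketched with a plain mutex-guarded map. This is an assumption-laden stand-in: `ttlCache`, `cacheEntry`, and `fakeClock` are hypothetical names, and the real code uses `xsync.NewSyncedMap`, whose internals are not shown in this diff.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

const configTTL = time.Hour // mirrors the 1-hour expiry in the diff

type cacheEntry struct {
	value    map[string]interface{}
	storedAt time.Time
}

// ttlCache is a hypothetical stand-in for the xsync-based cache in the diff.
type ttlCache struct {
	mu      sync.Mutex
	ttl     time.Duration
	now     func() time.Time
	entries map[string]cacheEntry
}

func newTTLCache(ttl time.Duration, now func() time.Time) *ttlCache {
	return &ttlCache{ttl: ttl, now: now, entries: map[string]cacheEntry{}}
}

// get returns the cached value, evicting it first if it is older than ttl.
func (c *ttlCache) get(key string) (map[string]interface{}, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	e, ok := c.entries[key]
	if !ok {
		return nil, false
	}
	if c.now().Sub(e.storedAt) > c.ttl {
		delete(c.entries, key)
		return nil, false
	}
	return e.value, true
}

// set stores value, including an empty map used to remember a failed fetch.
func (c *ttlCache) set(key string, value map[string]interface{}) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.entries[key] = cacheEntry{value: value, storedAt: c.now()}
}

// fakeClock lets the example move time forward without sleeping.
type fakeClock struct{ t time.Time }

func (f *fakeClock) now() time.Time          { return f.t }
func (f *fakeClock) advance(d time.Duration) { f.t = f.t.Add(d) }

func main() {
	clk := &fakeClock{t: time.Unix(0, 0)}
	cache := newTTLCache(configTTL, clk.now)

	// Negative caching: a failed fetch is stored as an empty map.
	cache.set("https://example.invalid/model.yaml", map[string]interface{}{})
	_, hit := cache.get("https://example.invalid/model.yaml")
	fmt.Println("hit within TTL:", hit)

	clk.advance(2 * configTTL)
	_, hit = cache.get("https://example.invalid/model.yaml")
	fmt.Println("hit after TTL:", hit)
}
```

One consequence of caching failures this way: a transient network error hides a model's backend for up to an hour, which is the trade-off for not retrying on every listing.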
// prefetchModelConfigs fetches model config URLs in parallel to warm the cache.
// This avoids sequential HTTP requests on cold start (~50 unique gallery files).
func prefetchModelConfigs(urls []string, basePath string) {
const maxConcurrency = 10
sem := make(chan struct{}, maxConcurrency)
var wg sync.WaitGroup
for _, url := range urls {
wg.Add(1)
go func(u string) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
fetchModelConfigMap(u, basePath)
}(url)
}
wg.Wait()
}
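The semaphore pattern used by prefetchModelConfigs (a buffered channel that caps in-flight work, plus a WaitGroup) can be exercised in isolation. The sketch below is illustrative only: `runBounded` is a hypothetical helper that additionally records peak concurrency and call count so the bound can be observed.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// runBounded calls fn once per item with at most limit calls running at the
// same time. It returns the peak concurrency observed and the number of calls
// made. Hypothetical helper; the function in the diff returns nothing.
func runBounded(items []string, limit int, fn func(string)) (peak, calls int64) {
	sem := make(chan struct{}, limit)
	var wg sync.WaitGroup
	var inFlight int64

	for _, item := range items {
		wg.Add(1)
		go func(it string) {
			defer wg.Done()
			sem <- struct{}{}        // blocks while limit calls are running
			defer func() { <-sem }() // frees the slot when this call returns

			n := atomic.AddInt64(&inFlight, 1)
			for { // raise peak to n if n is a new maximum
				p := atomic.LoadInt64(&peak)
				if n <= p || atomic.CompareAndSwapInt64(&peak, p, n) {
					break
				}
			}
			fn(it)
			atomic.AddInt64(&calls, 1)
			atomic.AddInt64(&inFlight, -1)
		}(item)
	}
	wg.Wait()
	return peak, calls
}

func main() {
	urls := make([]string, 50) // stand-ins for ~50 gallery config URLs
	peak, calls := runBounded(urls, 10, func(string) {})
	fmt.Println(calls, peak <= 10)
}
```

As in the diff, one goroutine is started per URL and the limit applies only to the work inside it. For a few dozen URLs that is cheap. A fixed worker pool would be the alternative if the list grew large.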
// resolveModelURLLocally attempts to resolve a github: model URL to a local file://
// path when the gallery itself was loaded from a local path. This supports development
// workflows where new model files are added locally before being pushed to GitHub.
//
// For example, if the gallery was loaded from file:///path/to/gallery/index.yaml
// and a model references github:mudler/LocalAI/gallery/foo.yaml@master, this will
// check if /path/to/gallery/foo.yaml exists locally and return file:///path/to/gallery/foo.yaml.
//
// This is applied to model.URL in AvailableGalleryModels so that both listing (backend
// resolution) and installation use the same resolved URL.
func resolveModelURLLocally(modelURL, galleryURL string) string {
galleryDir := localGalleryDir(galleryURL)
if galleryDir == "" {
return modelURL
}
// Only handle github: URLs
if !strings.HasPrefix(modelURL, downloader.GithubURI) && !strings.HasPrefix(modelURL, downloader.GithubURI2) {
return modelURL
}
// Extract the filename from the github URL
// Format: github:org/repo/path/to/file.yaml@branch
raw := strings.TrimPrefix(modelURL, downloader.GithubURI2)
raw = strings.TrimPrefix(raw, downloader.GithubURI)
// Remove @branch suffix
if idx := strings.LastIndex(raw, "@"); idx >= 0 {
raw = raw[:idx]
}
filename := filepath.Base(raw)
localPath := filepath.Join(galleryDir, filename)
if _, err := os.Stat(localPath); err == nil {
return downloader.LocalPrefix + localPath
}
return modelURL
}
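The string handling in resolveModelURLLocally can be checked without touching the filesystem. The sketch below extracts only the file name from a `github:` reference. The prefix values `github:` and `github://` are assumptions standing in for `downloader.GithubURI` and `downloader.GithubURI2`, whose definitions are not part of this diff, and `githubURLBaseName` is a hypothetical name.

```go
package main

import (
	"fmt"
	"path"
	"strings"
)

// Assumed values for downloader.GithubURI and downloader.GithubURI2.
const (
	githubPrefix     = "github:"
	githubPrefixLong = "github://"
)

// githubURLBaseName returns the file name referenced by a github: URL of the
// form github:org/repo/path/to/file.yaml@branch. The second result is false
// for any other kind of URL.
func githubURLBaseName(modelURL string) (string, bool) {
	if !strings.HasPrefix(modelURL, githubPrefix) {
		return "", false
	}
	// Trim the longer prefix first so "github://" is removed completely.
	raw := strings.TrimPrefix(modelURL, githubPrefixLong)
	raw = strings.TrimPrefix(raw, githubPrefix)
	// Drop the @branch suffix, if any.
	if idx := strings.LastIndex(raw, "@"); idx >= 0 {
		raw = raw[:idx]
	}
	return path.Base(raw), true
}

func main() {
	name, ok := githubURLBaseName("github:mudler/LocalAI/gallery/foo.yaml@master")
	fmt.Println(name, ok)
}
```

Because only the base name is kept, two gallery entries that live in different directories upstream but share a file name would map to the same local file.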
// localGalleryDir returns the directory of a gallery URL if it's local, or "" if remote.
func localGalleryDir(galleryURL string) string {
if strings.HasPrefix(galleryURL, downloader.LocalPrefix) {
return filepath.Dir(strings.TrimPrefix(galleryURL, downloader.LocalPrefix))
}
// Plain path (no scheme) that exists on disk
if !strings.Contains(galleryURL, "://") && !strings.HasPrefix(galleryURL, downloader.GithubURI) {
if info, err := os.Stat(galleryURL); err == nil && !info.IsDir() {
return filepath.Dir(galleryURL)
}
}
return ""
}


@@ -218,6 +218,36 @@ func AvailableGalleryModels(galleries []config.Gallery, systemState *system.Syst
if err != nil {
return nil, err
}
// Resolve model URLs locally (for local galleries) and collect unique
// URLs that need fetching for backend resolution.
uniqueURLs := map[string]struct{}{}
for _, m := range galleryModels {
if m.URL != "" {
m.URL = resolveModelURLLocally(m.URL, gallery.URL)
}
if m.Backend == "" && m.URL != "" {
uniqueURLs[m.URL] = struct{}{}
}
}
// Pre-warm cache with parallel fetches to avoid sequential HTTP
// requests on cold start (~50 unique gallery config files).
if len(uniqueURLs) > 0 {
urls := make([]string, 0, len(uniqueURLs))
for u := range uniqueURLs {
urls = append(urls, u)
}
prefetchModelConfigs(urls, systemState.Model.ModelsPath)
}
// Resolve backends from warm cache.
for _, m := range galleryModels {
if m.Backend == "" {
m.Backend = resolveBackend(m, systemState.Model.ModelsPath)
}
}
models = append(models, galleryModels...)
}


@@ -0,0 +1,205 @@
package importers
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/xlog"
)
// ImportLocalPath scans a local directory for exported model files and produces
// a config.ModelConfig with the correct backend, model path, and options.
// Paths in the returned config are relative to modelsPath when possible so that
// the YAML config remains portable.
//
// Detection order:
// 1. GGUF files (*.gguf) — uses llama-cpp backend
// 2. LoRA adapter (adapter_config.json) — uses transformers backend with lora_adapter
// 3. Merged model (*.safetensors or pytorch_model*.bin + config.json) — uses transformers backend
func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
// Make paths relative to the models directory (parent of dirPath)
// so config YAML stays portable.
modelsDir := filepath.Dir(dirPath)
relPath := func(absPath string) string {
if rel, err := filepath.Rel(modelsDir, absPath); err == nil {
return rel
}
return absPath
}
// 1. GGUF: check dirPath and dirPath_gguf/ (Unsloth convention)
ggufFile := findGGUF(dirPath)
if ggufFile == "" {
ggufSubdir := dirPath + "_gguf"
ggufFile = findGGUF(ggufSubdir)
}
if ggufFile != "" {
xlog.Info("ImportLocalPath: detected GGUF model", "path", ggufFile)
cfg := &config.ModelConfig{
Name: name,
Backend: "llama-cpp",
KnownUsecaseStrings: []string{"chat"},
Options: []string{"use_jinja:true"},
}
cfg.Model = relPath(ggufFile)
cfg.TemplateConfig.UseTokenizerTemplate = true
cfg.Description = buildDescription(dirPath, "GGUF")
return cfg, nil
}
// 2. LoRA adapter: look for adapter_config.json
adapterConfigPath := filepath.Join(dirPath, "adapter_config.json")
if fileExists(adapterConfigPath) {
xlog.Info("ImportLocalPath: detected LoRA adapter", "path", dirPath)
baseModel := readBaseModel(dirPath)
cfg := &config.ModelConfig{
Name: name,
Backend: "transformers",
KnownUsecaseStrings: []string{"chat"},
}
cfg.Model = baseModel
cfg.TemplateConfig.UseTokenizerTemplate = true
cfg.LLMConfig.LoraAdapter = relPath(dirPath)
cfg.Description = buildDescription(dirPath, "LoRA adapter")
return cfg, nil
}
// Also check for adapter_model.safetensors or adapter_model.bin without adapter_config.json
if fileExists(filepath.Join(dirPath, "adapter_model.safetensors")) || fileExists(filepath.Join(dirPath, "adapter_model.bin")) {
xlog.Info("ImportLocalPath: detected LoRA adapter (by model files)", "path", dirPath)
baseModel := readBaseModel(dirPath)
cfg := &config.ModelConfig{
Name: name,
Backend: "transformers",
KnownUsecaseStrings: []string{"chat"},
}
cfg.Model = baseModel
cfg.TemplateConfig.UseTokenizerTemplate = true
cfg.LLMConfig.LoraAdapter = relPath(dirPath)
cfg.Description = buildDescription(dirPath, "LoRA adapter")
return cfg, nil
}
// 3. Merged model: *.safetensors or pytorch_model*.bin + config.json
if fileExists(filepath.Join(dirPath, "config.json")) && (hasFileWithSuffix(dirPath, ".safetensors") || hasFileWithPrefix(dirPath, "pytorch_model")) {
xlog.Info("ImportLocalPath: detected merged model", "path", dirPath)
cfg := &config.ModelConfig{
Name: name,
Backend: "transformers",
KnownUsecaseStrings: []string{"chat"},
}
cfg.Model = relPath(dirPath)
cfg.TemplateConfig.UseTokenizerTemplate = true
cfg.Description = buildDescription(dirPath, "merged model")
return cfg, nil
}
return nil, fmt.Errorf("could not detect model format in directory %s", dirPath)
}
// findGGUF returns the path to the first .gguf file found in dir, or "".
func findGGUF(dir string) string {
entries, err := os.ReadDir(dir)
if err != nil {
return ""
}
for _, e := range entries {
if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), ".gguf") {
return filepath.Join(dir, e.Name())
}
}
return ""
}
// readBaseModel reads the base model name from adapter_config.json or export_metadata.json.
func readBaseModel(dirPath string) string {
// Try adapter_config.json → base_model_name_or_path (TRL writes this)
if data, err := os.ReadFile(filepath.Join(dirPath, "adapter_config.json")); err == nil {
var ac map[string]any
if json.Unmarshal(data, &ac) == nil {
if bm, ok := ac["base_model_name_or_path"].(string); ok && bm != "" {
return bm
}
}
}
// Try export_metadata.json → base_model (Unsloth writes this)
if data, err := os.ReadFile(filepath.Join(dirPath, "export_metadata.json")); err == nil {
var meta map[string]any
if json.Unmarshal(data, &meta) == nil {
if bm, ok := meta["base_model"].(string); ok && bm != "" {
return bm
}
}
}
return ""
}
// buildDescription creates a human-readable description using available metadata.
func buildDescription(dirPath, formatLabel string) string {
base := ""
// Try adapter_config.json
if data, err := os.ReadFile(filepath.Join(dirPath, "adapter_config.json")); err == nil {
var ac map[string]any
if json.Unmarshal(data, &ac) == nil {
if bm, ok := ac["base_model_name_or_path"].(string); ok && bm != "" {
base = bm
}
}
}
// Try export_metadata.json
if base == "" {
if data, err := os.ReadFile(filepath.Join(dirPath, "export_metadata.json")); err == nil {
var meta map[string]any
if json.Unmarshal(data, &meta) == nil {
if bm, ok := meta["base_model"].(string); ok && bm != "" {
base = bm
}
}
}
}
if base != "" {
return fmt.Sprintf("Fine-tuned from %s (%s)", base, formatLabel)
}
return fmt.Sprintf("Fine-tuned model (%s)", formatLabel)
}
func fileExists(path string) bool {
info, err := os.Stat(path)
return err == nil && !info.IsDir()
}
func hasFileWithSuffix(dir, suffix string) bool {
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, e := range entries {
if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), suffix) {
return true
}
}
return false
}
func hasFileWithPrefix(dir, prefix string) bool {
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, e := range entries {
if !e.IsDir() && strings.HasPrefix(e.Name(), prefix) {
return true
}
}
return false
}


@@ -0,0 +1,148 @@
package importers_test
import (
"encoding/json"
"os"
"path/filepath"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/gallery/importers"
)
var _ = Describe("ImportLocalPath", func() {
var tmpDir string
BeforeEach(func() {
var err error
tmpDir, err = os.MkdirTemp("", "importers-local-test")
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
os.RemoveAll(tmpDir)
})
Context("GGUF detection", func() {
It("detects a GGUF file in the directory", func() {
modelDir := filepath.Join(tmpDir, "my-model")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
Expect(os.WriteFile(filepath.Join(modelDir, "model-q4_k_m.gguf"), []byte("fake"), 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "my-model")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Backend).To(Equal("llama-cpp"))
Expect(cfg.Model).To(ContainSubstring(".gguf"))
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
Expect(cfg.KnownUsecaseStrings).To(ContainElement("chat"))
Expect(cfg.Options).To(ContainElement("use_jinja:true"))
})
It("detects GGUF in _gguf subdirectory", func() {
modelDir := filepath.Join(tmpDir, "my-model")
ggufDir := modelDir + "_gguf"
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
Expect(os.MkdirAll(ggufDir, 0755)).To(Succeed())
Expect(os.WriteFile(filepath.Join(ggufDir, "model.gguf"), []byte("fake"), 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "my-model")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Backend).To(Equal("llama-cpp"))
})
})
Context("LoRA adapter detection", func() {
It("detects LoRA adapter via adapter_config.json", func() {
modelDir := filepath.Join(tmpDir, "lora-model")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
adapterConfig := map[string]any{
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
"peft_type": "LORA",
}
data, _ := json.Marshal(adapterConfig)
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "lora-model")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Backend).To(Equal("transformers"))
Expect(cfg.Model).To(Equal("meta-llama/Llama-2-7b-hf"))
Expect(cfg.LLMConfig.LoraAdapter).To(Equal("lora-model"))
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
})
It("reads base model from export_metadata.json as fallback", func() {
modelDir := filepath.Join(tmpDir, "lora-unsloth")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
adapterConfig := map[string]any{"peft_type": "LORA"}
data, _ := json.Marshal(adapterConfig)
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
metadata := map[string]any{"base_model": "unsloth/tinyllama-bnb-4bit"}
data, _ = json.Marshal(metadata)
Expect(os.WriteFile(filepath.Join(modelDir, "export_metadata.json"), data, 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "lora-unsloth")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Model).To(Equal("unsloth/tinyllama-bnb-4bit"))
})
})
Context("Merged model detection", func() {
It("detects merged model with safetensors + config.json", func() {
modelDir := filepath.Join(tmpDir, "merged")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
Expect(os.WriteFile(filepath.Join(modelDir, "model.safetensors"), []byte("fake"), 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "merged")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Backend).To(Equal("transformers"))
Expect(cfg.Model).To(Equal("merged"))
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
})
It("detects merged model with pytorch_model files", func() {
modelDir := filepath.Join(tmpDir, "merged-pt")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
Expect(os.WriteFile(filepath.Join(modelDir, "pytorch_model-00001-of-00002.bin"), []byte("fake"), 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "merged-pt")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Backend).To(Equal("transformers"))
Expect(cfg.Model).To(Equal("merged-pt"))
})
})
Context("fallback", func() {
It("returns error for empty directory", func() {
modelDir := filepath.Join(tmpDir, "empty")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
_, err := importers.ImportLocalPath(modelDir, "empty")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("could not detect model format"))
})
})
Context("description", func() {
It("includes base model name in description", func() {
modelDir := filepath.Join(tmpDir, "desc-test")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
adapterConfig := map[string]any{
"base_model_name_or_path": "TinyLlama/TinyLlama-1.1B",
}
data, _ := json.Marshal(adapterConfig)
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "desc-test")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Description).To(ContainSubstring("TinyLlama/TinyLlama-1.1B"))
Expect(cfg.Description).To(ContainSubstring("Fine-tuned from"))
})
})
})


@@ -19,4 +19,7 @@ type Metadata struct {
Gallery config.Gallery `json:"gallery,omitempty" yaml:"gallery,omitempty"`
// Installed is used to indicate if the model is installed or not
Installed bool `json:"installed,omitempty" yaml:"installed,omitempty"`
// Backend is the resolved backend engine for this model (e.g. "llama-cpp").
// Populated at load time from overrides, inline config, or the URL-referenced config file.
Backend string `json:"backend,omitempty" yaml:"backend,omitempty"`
}


@@ -14,6 +14,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"
@@ -170,11 +171,9 @@ func API(application *application.Application) (*echo.Echo, error) {
// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(e)
// Get key auth middleware
keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig())
if err != nil {
return nil, fmt.Errorf("failed to create key auth config: %w", err)
}
// Build the auth middleware. auth.Middleware is the unified replacement for
// the legacy key-auth middleware and is installed whether or not an auth DB is configured.
authMiddleware := auth.Middleware(application.AuthDB(), application.ApplicationConfig())
// Favicon handler
e.GET("/favicon.svg", func(c echo.Context) error {
@@ -209,8 +208,20 @@ func API(application *application.Application) (*echo.Echo, error) {
e.Static("/generated-videos", videoPath)
}
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration
e.Use(keyAuthMiddleware)
// Initialize usage recording when auth DB is available
if application.AuthDB() != nil {
httpMiddleware.InitUsageRecorder(application.AuthDB())
}
// Auth is applied to _all_ endpoints. Filtering out endpoints to bypass is
// the role of the exempt-path logic inside the middleware.
e.Use(authMiddleware)
// Feature and model access control (after auth middleware, before routes)
if application.AuthDB() != nil {
e.Use(auth.RequireRouteFeature(application.AuthDB()))
e.Use(auth.RequireModelAccess(application.AuthDB()))
}
// CORS middleware
if application.ApplicationConfig().CORS {
@@ -223,14 +234,63 @@ func API(application *application.Application) (*echo.Echo, error) {
e.Use(middleware.CORS())
}
// CSRF middleware
if application.ApplicationConfig().CSRF {
xlog.Debug("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
e.Use(middleware.CSRF())
// CSRF middleware (enabled by default, disable with LOCALAI_DISABLE_CSRF=true)
//
// Protection relies on Echo's Sec-Fetch-Site header check (supported by all
// modern browsers). The legacy cookie+token approach is removed because
// Echo's Sec-Fetch-Site short-circuit never sets the cookie, so the frontend
// could never read a token to send back.
if !application.ApplicationConfig().DisableCSRF {
xlog.Debug("Enabling CSRF middleware (Sec-Fetch-Site mode)")
e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
Skipper: func(c echo.Context) bool {
// Skip CSRF for API clients using auth headers (may be cross-origin)
if c.Request().Header.Get("Authorization") != "" {
return true
}
if c.Request().Header.Get("x-api-key") != "" || c.Request().Header.Get("xi-api-key") != "" {
return true
}
// Skip when Sec-Fetch-Site header is absent (older browsers, reverse
// proxies that strip the header). The SameSite=Lax cookie attribute
// provides baseline CSRF protection for these clients.
if c.Request().Header.Get("Sec-Fetch-Site") == "" {
return true
}
return false
},
// Allow same-site requests (subdomains / different ports) in addition
// to same-origin which Echo already permits by default.
AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) {
secFetchSite := c.Request().Header.Get("Sec-Fetch-Site")
if secFetchSite == "same-site" {
return true, nil
}
// cross-site: block
return false, nil
},
}))
}
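The combined effect of the Skipper and AllowSecFetchSiteFunc above can be summarised as a pure function over request headers. This is a simplified model for reasoning about the configuration, not Echo's implementation: `csrfDecision` and `headers` are hypothetical names, HTTP method handling is left out, and `Sec-Fetch-Site` values the comments do not mention (such as `none`) are reported as `unknown` because their handling by Echo is not shown here.

```go
package main

import (
	"fmt"
	"net/http"
)

// csrfDecision models the outcome for a request: "skip" (CSRF check bypassed),
// "allow", "block", or "unknown" for cases the diff does not describe.
func csrfDecision(h http.Header) string {
	// API clients that send credentials in headers bypass the check.
	if h.Get("Authorization") != "" || h.Get("x-api-key") != "" || h.Get("xi-api-key") != "" {
		return "skip"
	}
	switch h.Get("Sec-Fetch-Site") {
	case "":
		return "skip" // header absent: older browsers or stripping proxies
	case "same-origin":
		return "allow" // permitted by Echo by default, per the comment above
	case "same-site":
		return "allow" // permitted by AllowSecFetchSiteFunc
	case "cross-site":
		return "block"
	default:
		return "unknown"
	}
}

// headers builds an http.Header from alternating key/value arguments.
func headers(kv ...string) http.Header {
	h := http.Header{}
	for i := 0; i+1 < len(kv); i += 2 {
		h.Set(kv[i], kv[i+1])
	}
	return h
}

func main() {
	fmt.Println(csrfDecision(headers("Authorization", "Bearer lai-example")))
	fmt.Println(csrfDecision(headers("Sec-Fetch-Site", "same-site")))
	fmt.Println(csrfDecision(headers("Sec-Fetch-Site", "cross-site")))
}
```

The model makes one property easy to see: any request that carries an `Authorization` or API-key header skips the check even when it is cross-site, so protection for those requests rests entirely on the credential itself.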
// Admin middleware: enforces admin role when auth is enabled, no-op otherwise
var adminMiddleware echo.MiddlewareFunc
if application.AuthDB() != nil {
adminMiddleware = auth.RequireAdmin()
} else {
adminMiddleware = auth.NoopMiddleware()
}
// Feature middlewares: per-feature access control
agentsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureAgents)
skillsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureSkills)
collectionsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureCollections)
mcpJobsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCPJobs)
requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
// Register auth routes (login, callback, API keys, user management)
routes.RegisterAuthRoutes(e, application)
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
@@ -239,14 +299,26 @@ func API(application *application.Application) (*echo.Echo, error) {
opcache = services.NewOpCache(application.GalleryService())
}
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application)
routes.RegisterAgentPoolRoutes(e, application)
mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP)
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application, adminMiddleware, mcpJobsMw, mcpMw)
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
// Fine-tuning routes
if application.ApplicationConfig().FineTuning.Enabled {
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
ftService := services.NewFineTuneService(
application.ApplicationConfig(),
application.ModelLoader(),
application.ModelConfigLoader(),
)
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
}
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
if !application.ApplicationConfig().DisableWebUI {
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application)
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application, adminMiddleware)
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), adminMiddleware)
// Serve React SPA from / with SPA fallback via 404 handler
reactFS, fsErr := fs.Sub(reactUI, "react-ui/dist")


@@ -428,8 +428,10 @@ var _ = Describe("API test", func() {
"X-Forwarded-Prefix": {"/myprefix/"},
})
Expect(err).To(BeNil(), "error")
Expect(sc).To(Equal(401), "status code")
Expect(sc).To(Equal(200), "status code")
// Non-API paths pass through to the React SPA (which handles login client-side)
Expect(string(body)).To(ContainSubstring(`<base href="https://example.org/myprefix/" />`), "body")
Expect(string(body)).To(ContainSubstring(`<div id="root">`), "should serve React SPA")
})
It("Should support reverse-proxy when authenticated", func() {

core/http/auth/apikeys.go

@@ -0,0 +1,121 @@
package auth
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
const (
apiKeyPrefix = "lai-"
apiKeyRandBytes = 32 // 32 bytes = 64 hex chars
keyPrefixLen = 8 // display prefix length (from the random part)
)
// GenerateAPIKey generates a new API key. Returns the plaintext key,
// its HMAC-SHA256 hash, and a display prefix.
func GenerateAPIKey(hmacSecret string) (plaintext, hash, prefix string, err error) {
b := make([]byte, apiKeyRandBytes)
if _, err := rand.Read(b); err != nil {
return "", "", "", fmt.Errorf("failed to generate API key: %w", err)
}
randHex := hex.EncodeToString(b)
plaintext = apiKeyPrefix + randHex
hash = HashAPIKey(plaintext, hmacSecret)
prefix = plaintext[:len(apiKeyPrefix)+keyPrefixLen]
return plaintext, hash, prefix, nil
}
// HashAPIKey returns the HMAC-SHA256 hex digest of the given plaintext key.
// If hmacSecret is empty, falls back to plain SHA-256 for backward compatibility.
func HashAPIKey(plaintext, hmacSecret string) string {
if hmacSecret == "" {
h := sha256.Sum256([]byte(plaintext))
return hex.EncodeToString(h[:])
}
mac := hmac.New(sha256.New, []byte(hmacSecret))
mac.Write([]byte(plaintext))
return hex.EncodeToString(mac.Sum(nil))
}
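The two hashing modes of HashAPIKey can be compared side by side. The sketch below copies the same standard-library calls into a standalone `hashKey` function (a hypothetical name) and uses a made-up key and secret. It shows that both modes yield a 64-character hex digest and that the digest depends on the secret.

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// hashKey mirrors the logic shown above: plain SHA-256 when secret is empty,
// HMAC-SHA256 keyed with secret otherwise.
func hashKey(plaintext, secret string) string {
	if secret == "" {
		sum := sha256.Sum256([]byte(plaintext))
		return hex.EncodeToString(sum[:])
	}
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(plaintext))
	return hex.EncodeToString(mac.Sum(nil))
}

func main() {
	key := "lai-example" // illustrative value, not a real key
	plain := hashKey(key, "")
	keyed := hashKey(key, "server-secret")
	fmt.Println(len(plain), len(keyed), plain == keyed)
}
```

Since ValidateAPIKey looks keys up by this digest, changing the HMAC secret (or switching between an empty and a non-empty secret) means previously stored hashes no longer match, and existing keys would need to be reissued.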
// CreateAPIKey generates and stores a new API key for the given user.
// Returns the plaintext key (shown once) and the database record.
func CreateAPIKey(db *gorm.DB, userID, name, role, hmacSecret string, expiresAt *time.Time) (string, *UserAPIKey, error) {
plaintext, hash, prefix, err := GenerateAPIKey(hmacSecret)
if err != nil {
return "", nil, err
}
record := &UserAPIKey{
ID: uuid.New().String(),
UserID: userID,
Name: name,
KeyHash: hash,
KeyPrefix: prefix,
Role: role,
ExpiresAt: expiresAt,
}
if err := db.Create(record).Error; err != nil {
return "", nil, fmt.Errorf("failed to store API key: %w", err)
}
return plaintext, record, nil
}
// ValidateAPIKey looks up an API key by hashing the plaintext and searching
// the database. Returns the key record if found, or an error.
// Updates LastUsed on successful validation.
func ValidateAPIKey(db *gorm.DB, plaintext, hmacSecret string) (*UserAPIKey, error) {
hash := HashAPIKey(plaintext, hmacSecret)
var key UserAPIKey
if err := db.Preload("User").Where("key_hash = ?", hash).First(&key).Error; err != nil {
return nil, fmt.Errorf("invalid API key")
}
if key.ExpiresAt != nil && time.Now().After(*key.ExpiresAt) {
return nil, fmt.Errorf("API key expired")
}
if key.User.Status != StatusActive {
return nil, fmt.Errorf("user account is not active")
}
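// Update LastUsed (best effort: the result is not checked, so a failed
// update does not reject an otherwise valid key)
now := time.Now()
db.Model(&key).Update("last_used", now)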
return &key, nil
}
// ListAPIKeys returns all API keys for the given user (without plaintext).
func ListAPIKeys(db *gorm.DB, userID string) ([]UserAPIKey, error) {
var keys []UserAPIKey
if err := db.Where("user_id = ?", userID).Order("created_at DESC").Find(&keys).Error; err != nil {
return nil, err
}
return keys, nil
}
// RevokeAPIKey deletes an API key. Only the owner can revoke their own key.
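func RevokeAPIKey(db *gorm.DB, keyID, userID string) error {
result := db.Where("id = ? AND user_id = ?", keyID, userID).Delete(&UserAPIKey{})
// Check the database error first so it is not masked by the
// "not found" message when RowsAffected is zero.
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("API key not found or not owned by user")
}
return nil
}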
// CleanExpiredAPIKeys removes all API keys that have passed their expiry time.
func CleanExpiredAPIKeys(db *gorm.DB) error {
return db.Where("expires_at IS NOT NULL AND expires_at < ?", time.Now()).Delete(&UserAPIKey{}).Error
}


@@ -0,0 +1,212 @@
//go:build auth
package auth_test
import (
"strings"
"github.com/mudler/LocalAI/core/http/auth"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gorm.io/gorm"
)
var _ = Describe("API Keys", func() {
var (
db *gorm.DB
user *auth.User
)
// Use empty HMAC secret for tests (falls back to plain SHA-256)
hmacSecret := ""
BeforeEach(func() {
db = testDB()
user = createTestUser(db, "apikey@example.com", auth.RoleUser, auth.ProviderGitHub)
})
Describe("GenerateAPIKey", func() {
It("returns key with 'lai-' prefix", func() {
plaintext, _, _, err := auth.GenerateAPIKey(hmacSecret)
Expect(err).ToNot(HaveOccurred())
Expect(plaintext).To(HavePrefix("lai-"))
})
It("returns consistent hash for same plaintext", func() {
plaintext, hash, _, err := auth.GenerateAPIKey(hmacSecret)
Expect(err).ToNot(HaveOccurred())
Expect(auth.HashAPIKey(plaintext, hmacSecret)).To(Equal(hash))
})
It("returns prefix for display", func() {
_, _, prefix, err := auth.GenerateAPIKey(hmacSecret)
Expect(err).ToNot(HaveOccurred())
Expect(prefix).To(HavePrefix("lai-"))
Expect(len(prefix)).To(Equal(12)) // "lai-" + 8 chars
})
It("generates unique keys", func() {
key1, _, _, err := auth.GenerateAPIKey(hmacSecret)
Expect(err).ToNot(HaveOccurred())
key2, _, _, err := auth.GenerateAPIKey(hmacSecret)
Expect(err).ToNot(HaveOccurred())
Expect(key1).ToNot(Equal(key2))
})
})
Describe("CreateAPIKey", func() {
It("stores hashed key in DB", func() {
plaintext, record, err := auth.CreateAPIKey(db, user.ID, "test key", auth.RoleUser, hmacSecret, nil)
Expect(err).ToNot(HaveOccurred())
Expect(plaintext).To(HavePrefix("lai-"))
Expect(record.KeyHash).To(Equal(auth.HashAPIKey(plaintext, hmacSecret)))
})
It("does not store plaintext in DB", func() {
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "test key", auth.RoleUser, hmacSecret, nil)
Expect(err).ToNot(HaveOccurred())
var keys []auth.UserAPIKey
db.Find(&keys)
for _, k := range keys {
Expect(k.KeyHash).ToNot(Equal(plaintext))
Expect(strings.Contains(k.KeyHash, "lai-")).To(BeFalse())
}
})
It("inherits role from parameter", func() {
_, record, err := auth.CreateAPIKey(db, user.ID, "admin key", auth.RoleAdmin, hmacSecret, nil)
Expect(err).ToNot(HaveOccurred())
Expect(record.Role).To(Equal(auth.RoleAdmin))
})
})
Describe("ValidateAPIKey", func() {
It("returns UserAPIKey for valid key", func() {
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "valid key", auth.RoleUser, hmacSecret, nil)
Expect(err).ToNot(HaveOccurred())
found, err := auth.ValidateAPIKey(db, plaintext, hmacSecret)
Expect(err).ToNot(HaveOccurred())
Expect(found).ToNot(BeNil())
Expect(found.UserID).To(Equal(user.ID))
})
It("returns error for invalid key", func() {
_, err := auth.ValidateAPIKey(db, "lai-invalidkey12345678901234567890", hmacSecret)
Expect(err).To(HaveOccurred())
})
It("updates LastUsed timestamp", func() {
plaintext, record, err := auth.CreateAPIKey(db, user.ID, "used key", auth.RoleUser, hmacSecret, nil)
Expect(err).ToNot(HaveOccurred())
Expect(record.LastUsed).To(BeNil())
_, err = auth.ValidateAPIKey(db, plaintext, hmacSecret)
Expect(err).ToNot(HaveOccurred())
var updated auth.UserAPIKey
db.First(&updated, "id = ?", record.ID)
Expect(updated.LastUsed).ToNot(BeNil())
})
It("loads associated user", func() {
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "with user", auth.RoleUser, hmacSecret, nil)
Expect(err).ToNot(HaveOccurred())
found, err := auth.ValidateAPIKey(db, plaintext, hmacSecret)
Expect(err).ToNot(HaveOccurred())
Expect(found.User.ID).To(Equal(user.ID))
Expect(found.User.Email).To(Equal("apikey@example.com"))
})
})
Describe("ListAPIKeys", func() {
It("returns all keys for the user", func() {
auth.CreateAPIKey(db, user.ID, "key1", auth.RoleUser, hmacSecret, nil)
auth.CreateAPIKey(db, user.ID, "key2", auth.RoleUser, hmacSecret, nil)
keys, err := auth.ListAPIKeys(db, user.ID)
Expect(err).ToNot(HaveOccurred())
Expect(keys).To(HaveLen(2))
})
It("does not return other users' keys", func() {
other := createTestUser(db, "other@example.com", auth.RoleUser, auth.ProviderGitHub)
auth.CreateAPIKey(db, user.ID, "my key", auth.RoleUser, hmacSecret, nil)
auth.CreateAPIKey(db, other.ID, "other key", auth.RoleUser, hmacSecret, nil)
keys, err := auth.ListAPIKeys(db, user.ID)
Expect(err).ToNot(HaveOccurred())
Expect(keys).To(HaveLen(1))
Expect(keys[0].Name).To(Equal("my key"))
})
})
Context("with HMAC secret", func() {
hmacSecretVal := "test-hmac-secret-456"
It("generates different hash than empty secret", func() {
plaintext, _, _, err := auth.GenerateAPIKey("")
Expect(err).ToNot(HaveOccurred())
hashEmpty := auth.HashAPIKey(plaintext, "")
hashHMAC := auth.HashAPIKey(plaintext, hmacSecretVal)
Expect(hashEmpty).ToNot(Equal(hashHMAC))
})
It("round-trips CreateAPIKey and ValidateAPIKey with HMAC secret", func() {
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "hmac key", auth.RoleUser, hmacSecretVal, nil)
Expect(err).ToNot(HaveOccurred())
found, err := auth.ValidateAPIKey(db, plaintext, hmacSecretVal)
Expect(err).ToNot(HaveOccurred())
Expect(found).ToNot(BeNil())
Expect(found.UserID).To(Equal(user.ID))
})
It("does not validate with wrong HMAC secret", func() {
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "hmac key2", auth.RoleUser, hmacSecretVal, nil)
Expect(err).ToNot(HaveOccurred())
_, err = auth.ValidateAPIKey(db, plaintext, "wrong-secret")
Expect(err).To(HaveOccurred())
})
It("does not validate key created with empty secret using non-empty secret", func() {
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "empty-secret key", auth.RoleUser, "", nil)
Expect(err).ToNot(HaveOccurred())
_, err = auth.ValidateAPIKey(db, plaintext, hmacSecretVal)
Expect(err).To(HaveOccurred())
})
It("does not validate key created with non-empty secret using empty secret", func() {
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "nonempty-secret key", auth.RoleUser, hmacSecretVal, nil)
Expect(err).ToNot(HaveOccurred())
_, err = auth.ValidateAPIKey(db, plaintext, "")
Expect(err).To(HaveOccurred())
})
})
Describe("RevokeAPIKey", func() {
It("deletes the key record", func() {
plaintext, record, err := auth.CreateAPIKey(db, user.ID, "to revoke", auth.RoleUser, hmacSecret, nil)
Expect(err).ToNot(HaveOccurred())
err = auth.RevokeAPIKey(db, record.ID, user.ID)
Expect(err).ToNot(HaveOccurred())
_, err = auth.ValidateAPIKey(db, plaintext, hmacSecret)
Expect(err).To(HaveOccurred())
})
It("only allows owner to revoke their own key", func() {
_, record, err := auth.CreateAPIKey(db, user.ID, "mine", auth.RoleUser, hmacSecret, nil)
Expect(err).ToNot(HaveOccurred())
other := createTestUser(db, "attacker@example.com", auth.RoleUser, auth.ProviderGitHub)
err = auth.RevokeAPIKey(db, record.ID, other.ID)
Expect(err).To(HaveOccurred())
})
})
})
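The HMAC tests above depend on two properties of `auth.HashAPIKey`: the same plaintext and secret always give the same hash, and an empty secret falls back to plain SHA-256, so keys hashed under one secret never validate under another. A minimal sketch of that behaviour follows. It assumes HMAC-SHA256 for the keyed case and hex encoding of the digest; neither detail is confirmed by this diff, and `hashKey` is a hypothetical stand-in, not the real function.

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// hashKey is a hypothetical stand-in for auth.HashAPIKey.
// Assumption: the keyed case uses HMAC-SHA256 and digests are hex-encoded.
func hashKey(plaintext, secret string) string {
	if secret == "" {
		// Empty secret: fall back to an unkeyed SHA-256 digest.
		sum := sha256.Sum256([]byte(plaintext))
		return hex.EncodeToString(sum[:])
	}
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(plaintext))
	return hex.EncodeToString(mac.Sum(nil))
}

func main() {
	key := "lai-example"
	// Same plaintext and secret: the hashes match.
	fmt.Println(hashKey(key, "") == hashKey(key, ""))
	// Different secret: the hashes differ, so validation would fail.
	fmt.Println(hashKey(key, "") == hashKey(key, "test-hmac-secret-456"))
}
```

Under these assumptions, changing the configured secret makes every previously issued key unverifiable, which is what the "wrong HMAC secret" and mixed empty/non-empty cases assert.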


@@ -0,0 +1,15 @@
//go:build auth
package auth_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestAuth(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Auth Suite")
}

49
core/http/auth/db.go Normal file

@@ -0,0 +1,49 @@
package auth
import (
"fmt"
"strings"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// InitDB initializes the auth database. If databaseURL starts with "postgres://"
// or "postgresql://", it connects to PostgreSQL; otherwise it treats the value
// as a SQLite file path (use ":memory:" for in-memory).
// SQLite support requires building with the "auth" build tag (CGO).
func InitDB(databaseURL string) (*gorm.DB, error) {
var dialector gorm.Dialector
if strings.HasPrefix(databaseURL, "postgres://") || strings.HasPrefix(databaseURL, "postgresql://") {
dialector = postgres.Open(databaseURL)
} else {
d, err := openSQLiteDialector(databaseURL)
if err != nil {
return nil, err
}
dialector = d
}
db, err := gorm.Open(dialector, &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
return nil, fmt.Errorf("failed to open auth database: %w", err)
}
if err := db.AutoMigrate(&User{}, &Session{}, &UserAPIKey{}, &UsageRecord{}, &UserPermission{}, &InviteCode{}); err != nil {
return nil, fmt.Errorf("failed to migrate auth tables: %w", err)
}
// Create composite index on users(provider, subject) for fast OAuth lookups
if err := db.Exec("CREATE INDEX IF NOT EXISTS idx_users_provider_subject ON users(provider, subject)").Error; err != nil {
// Ignore error on postgres if index already exists
if !strings.Contains(err.Error(), "already exists") {
return nil, fmt.Errorf("failed to create composite index: %w", err)
}
}
return db, nil
}
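The backend choice in `InitDB` is a plain prefix check on the connection string. The sketch below isolates that decision so its edge cases are easy to see; `backendFor` and its string return values are illustrative names, not part of the package.

```go
package main

import (
	"fmt"
	"strings"
)

// backendFor mirrors the prefix check in InitDB. The function name and the
// returned labels are hypothetical and exist only for illustration.
func backendFor(databaseURL string) string {
	if strings.HasPrefix(databaseURL, "postgres://") || strings.HasPrefix(databaseURL, "postgresql://") {
		return "postgres"
	}
	// Anything else is treated as a SQLite file path.
	return "sqlite"
}

func main() {
	for _, u := range []string{
		"postgres://user:pass@localhost:5432/localai",
		"postgresql://localhost/localai",
		"/var/lib/localai/auth.db",
		":memory:",
	} {
		fmt.Printf("%-45s -> %s\n", u, backendFor(u))
	}
}
```

Because the check is case-sensitive and URL-based, a value such as `Postgres://...` or a key/value DSN (`host=... user=...`) would be handed to the SQLite dialector as a file path.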


@@ -0,0 +1,13 @@
//go:build !auth
package auth
import (
"fmt"
"gorm.io/gorm"
)
func openSQLiteDialector(path string) (gorm.Dialector, error) {
return nil, fmt.Errorf("SQLite auth database requires building with -tags auth (CGO); use DATABASE_URL with PostgreSQL instead")
}


@@ -0,0 +1,12 @@
//go:build auth
package auth
import (
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func openSQLiteDialector(path string) (gorm.Dialector, error) {
return sqlite.Open(path), nil
}

53
core/http/auth/db_test.go Normal file

@@ -0,0 +1,53 @@
//go:build auth
package auth_test
import (
"github.com/mudler/LocalAI/core/http/auth"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("InitDB", func() {
Context("SQLite", func() {
It("creates all tables with in-memory SQLite", func() {
db, err := auth.InitDB(":memory:")
Expect(err).ToNot(HaveOccurred())
Expect(db).ToNot(BeNil())
// Verify tables exist
Expect(db.Migrator().HasTable(&auth.User{})).To(BeTrue())
Expect(db.Migrator().HasTable(&auth.Session{})).To(BeTrue())
Expect(db.Migrator().HasTable(&auth.UserAPIKey{})).To(BeTrue())
})
It("is idempotent - running twice does not error", func() {
db, err := auth.InitDB(":memory:")
Expect(err).ToNot(HaveOccurred())
// Re-migrate on same DB should succeed
err = db.AutoMigrate(&auth.User{}, &auth.Session{}, &auth.UserAPIKey{})
Expect(err).ToNot(HaveOccurred())
})
It("creates composite index on users(provider, subject)", func() {
db, err := auth.InitDB(":memory:")
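Expect(err).ToNot(HaveOccurred())
// Verify the index was actually created
Expect(db.Migrator().HasIndex(&auth.User{}, "idx_users_provider_subject")).To(BeTrue())
// Insert a user to verify the index doesn't prevent normal operations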
user := &auth.User{
ID: "test-1",
Provider: auth.ProviderGitHub,
Subject: "12345",
Role: "admin",
Status: auth.StatusActive,
}
Expect(db.Create(user).Error).ToNot(HaveOccurred())
// Query using the indexed columns should work
var found auth.User
Expect(db.Where("provider = ? AND subject = ?", auth.ProviderGitHub, "12345").First(&found).Error).ToNot(HaveOccurred())
Expect(found.ID).To(Equal("test-1"))
})
})
})

144
core/http/auth/features.go Normal file

@@ -0,0 +1,144 @@
package auth
// RouteFeature maps a route pattern + HTTP method to a required feature.
type RouteFeature struct {
Method string // "POST", "GET", "*" (any)
Pattern string // Echo route pattern, e.g. "/v1/chat/completions"
Feature string // Feature constant, e.g. FeatureChat
}
// RouteFeatureRegistry is the single source of truth for endpoint -> feature mappings.
// To gate a new endpoint, add an entry here -- no other file changes needed.
var RouteFeatureRegistry = []RouteFeature{
// Chat / Completions
{"POST", "/v1/chat/completions", FeatureChat},
{"POST", "/chat/completions", FeatureChat},
{"POST", "/v1/completions", FeatureChat},
{"POST", "/completions", FeatureChat},
{"POST", "/v1/engines/:model/completions", FeatureChat},
{"POST", "/v1/edits", FeatureChat},
{"POST", "/edits", FeatureChat},
// Anthropic
{"POST", "/v1/messages", FeatureChat},
{"POST", "/messages", FeatureChat},
// Open Responses
{"POST", "/v1/responses", FeatureChat},
{"POST", "/responses", FeatureChat},
{"GET", "/v1/responses", FeatureChat},
{"GET", "/responses", FeatureChat},
// Embeddings
{"POST", "/v1/embeddings", FeatureEmbeddings},
{"POST", "/embeddings", FeatureEmbeddings},
{"POST", "/v1/engines/:model/embeddings", FeatureEmbeddings},
// Images
{"POST", "/v1/images/generations", FeatureImages},
{"POST", "/images/generations", FeatureImages},
{"POST", "/v1/images/inpainting", FeatureImages},
{"POST", "/images/inpainting", FeatureImages},
// Audio transcription
{"POST", "/v1/audio/transcriptions", FeatureAudioTranscription},
{"POST", "/audio/transcriptions", FeatureAudioTranscription},
// Audio speech / TTS
{"POST", "/v1/audio/speech", FeatureAudioSpeech},
{"POST", "/audio/speech", FeatureAudioSpeech},
{"POST", "/tts", FeatureAudioSpeech},
{"POST", "/v1/text-to-speech/:voice-id", FeatureAudioSpeech},
// VAD
{"POST", "/vad", FeatureVAD},
{"POST", "/v1/vad", FeatureVAD},
// Detection
{"POST", "/v1/detection", FeatureDetection},
// Video
{"POST", "/video", FeatureVideo},
// Sound generation
{"POST", "/v1/sound-generation", FeatureSound},
// Realtime
{"GET", "/v1/realtime", FeatureRealtime},
{"POST", "/v1/realtime/sessions", FeatureRealtime},
{"POST", "/v1/realtime/transcription_session", FeatureRealtime},
{"POST", "/v1/realtime/calls", FeatureRealtime},
// MCP
{"POST", "/v1/mcp/chat/completions", FeatureMCP},
{"POST", "/mcp/v1/chat/completions", FeatureMCP},
{"POST", "/mcp/chat/completions", FeatureMCP},
// Tokenize
{"POST", "/v1/tokenize", FeatureTokenize},
// Rerank
{"POST", "/v1/rerank", FeatureRerank},
// Stores
{"POST", "/stores/set", FeatureStores},
{"POST", "/stores/delete", FeatureStores},
{"POST", "/stores/get", FeatureStores},
{"POST", "/stores/find", FeatureStores},
// Fine-tuning
{"POST", "/api/fine-tuning/jobs", FeatureFineTuning},
{"GET", "/api/fine-tuning/jobs", FeatureFineTuning},
{"GET", "/api/fine-tuning/jobs/:id", FeatureFineTuning},
{"POST", "/api/fine-tuning/jobs/:id/stop", FeatureFineTuning},
{"DELETE", "/api/fine-tuning/jobs/:id", FeatureFineTuning},
{"GET", "/api/fine-tuning/jobs/:id/progress", FeatureFineTuning},
{"GET", "/api/fine-tuning/jobs/:id/checkpoints", FeatureFineTuning},
{"POST", "/api/fine-tuning/jobs/:id/export", FeatureFineTuning},
{"GET", "/api/fine-tuning/jobs/:id/download", FeatureFineTuning},
{"POST", "/api/fine-tuning/datasets", FeatureFineTuning},
}
// FeatureMeta describes a feature for the admin API/UI.
type FeatureMeta struct {
Key string `json:"key"`
Label string `json:"label"`
DefaultValue bool `json:"default"`
}
// AgentFeatureMetas returns metadata for agent features.
func AgentFeatureMetas() []FeatureMeta {
return []FeatureMeta{
{FeatureAgents, "Agents", false},
{FeatureSkills, "Skills", false},
{FeatureCollections, "Collections", false},
{FeatureMCPJobs, "MCP CI Jobs", false},
}
}
// GeneralFeatureMetas returns metadata for general features.
func GeneralFeatureMetas() []FeatureMeta {
return []FeatureMeta{
{FeatureFineTuning, "Fine-Tuning", false},
}
}
// APIFeatureMetas returns metadata for API endpoint features.
func APIFeatureMetas() []FeatureMeta {
return []FeatureMeta{
{FeatureChat, "Chat Completions", true},
{FeatureImages, "Image Generation", true},
{FeatureAudioSpeech, "Audio Speech / TTS", true},
{FeatureAudioTranscription, "Audio Transcription", true},
{FeatureVAD, "Voice Activity Detection", true},
{FeatureDetection, "Detection", true},
{FeatureVideo, "Video Generation", true},
{FeatureEmbeddings, "Embeddings", true},
{FeatureSound, "Sound Generation", true},
{FeatureRealtime, "Realtime", true},
{FeatureRerank, "Rerank", true},
{FeatureTokenize, "Tokenize", true},
{FeatureMCP, "MCP", true},
{FeatureStores, "Stores", true},
}
}
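The registry is consumed as a `"METHOD:pattern"` map, with a `"*"` method acting as a fallback. The sketch below reproduces that lookup on a three-entry registry. The feature strings (`"chat"`, `"embeddings"`, `"example"`) and the `/v1/example` route are placeholders: the actual values of the `Feature*` constants are not shown in this diff.

```go
package main

import "fmt"

// routeFeature mirrors the shape of RouteFeature.
type routeFeature struct {
	Method, Pattern, Feature string
}

// registry is a tiny illustrative registry. The feature strings and the
// "/v1/example" route are hypothetical placeholders.
var registry = []routeFeature{
	{"POST", "/v1/chat/completions", "chat"},
	{"POST", "/v1/engines/:model/embeddings", "embeddings"},
	{"*", "/v1/example", "example"},
}

// buildLookup flattens the registry into a "METHOD:pattern" -> feature map.
func buildLookup(rs []routeFeature) map[string]string {
	m := make(map[string]string, len(rs))
	for _, r := range rs {
		m[r.Method+":"+r.Pattern] = r.Feature
	}
	return m
}

// featureFor tries an exact method match first, then the "*" wildcard.
// An empty result means the route is not gated.
func featureFor(lookup map[string]string, method, pattern string) string {
	if f := lookup[method+":"+pattern]; f != "" {
		return f
	}
	return lookup["*:"+pattern]
}

func main() {
	lookup := buildLookup(registry)
	fmt.Println(featureFor(lookup, "POST", "/v1/chat/completions"))
	fmt.Println(featureFor(lookup, "GET", "/v1/example"))
	fmt.Println(featureFor(lookup, "GET", "/v1/chat/completions") == "")
}
```

The key is the route pattern, not the concrete URL, so a parameterised entry such as `/v1/engines/:model/embeddings` has to be written exactly as it is registered with the router.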


@@ -0,0 +1,155 @@
//go:build auth
package auth_test
import (
"net/http"
"net/http/httptest"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/auth"
. "github.com/onsi/gomega"
"gorm.io/gorm"
)
// testDB creates an in-memory SQLite GORM instance with auto-migration.
func testDB() *gorm.DB {
db, err := auth.InitDB(":memory:")
Expect(err).ToNot(HaveOccurred())
return db
}
// createTestUser inserts a user directly into the DB for test setup.
func createTestUser(db *gorm.DB, email, role, provider string) *auth.User {
user := &auth.User{
ID: generateTestID(),
Email: email,
Name: "Test User",
Provider: provider,
Subject: generateTestID(),
Role: role,
Status: auth.StatusActive,
}
err := db.Create(user).Error
Expect(err).ToNot(HaveOccurred())
return user
}
// createTestSession creates a session for a user, returns plaintext session token.
func createTestSession(db *gorm.DB, userID string) string {
sessionID, err := auth.CreateSession(db, userID, "")
Expect(err).ToNot(HaveOccurred())
return sessionID
}
var testIDCounter int
func generateTestID() string {
testIDCounter++
return "test-id-" + string(rune('a'+testIDCounter))
}
// ok is a simple handler that returns 200 OK.
func ok(c echo.Context) error {
return c.String(http.StatusOK, "ok")
}
// newAuthTestApp creates a minimal Echo app with the new auth middleware.
func newAuthTestApp(db *gorm.DB, appConfig *config.ApplicationConfig) *echo.Echo {
e := echo.New()
e.Use(auth.Middleware(db, appConfig))
// API routes (require auth)
e.GET("/v1/models", ok)
e.POST("/v1/chat/completions", ok)
e.GET("/api/settings", ok)
e.POST("/api/settings", ok)
// Auth routes (exempt)
e.GET("/api/auth/status", ok)
e.GET("/api/auth/github/login", ok)
// Static routes
e.GET("/app", ok)
e.GET("/app/*", ok)
return e
}
// newAdminTestApp creates an Echo app with admin-protected routes.
func newAdminTestApp(db *gorm.DB, appConfig *config.ApplicationConfig) *echo.Echo {
e := echo.New()
e.Use(auth.Middleware(db, appConfig))
// Regular routes
e.GET("/v1/models", ok)
e.POST("/v1/chat/completions", ok)
// Admin-only routes
adminMw := auth.RequireAdmin()
e.POST("/api/settings", ok, adminMw)
e.POST("/models/apply", ok, adminMw)
e.POST("/backends/apply", ok, adminMw)
e.GET("/api/agents", ok, adminMw)
// Trace/log endpoints (admin only)
e.GET("/api/traces", ok, adminMw)
e.POST("/api/traces/clear", ok, adminMw)
e.GET("/api/backend-logs", ok, adminMw)
e.GET("/api/backend-logs/:modelId", ok, adminMw)
// Gallery/management reads (admin only)
e.GET("/api/operations", ok, adminMw)
e.GET("/api/models", ok, adminMw)
e.GET("/api/backends", ok, adminMw)
e.GET("/api/resources", ok, adminMw)
e.GET("/api/p2p/workers", ok, adminMw)
// Agent task/job routes (admin only)
e.POST("/api/agent/tasks", ok, adminMw)
e.GET("/api/agent/tasks", ok, adminMw)
e.GET("/api/agent/jobs", ok, adminMw)
// System info (admin only)
e.GET("/system", ok, adminMw)
e.GET("/backend/monitor", ok, adminMw)
return e
}
// doRequest performs an HTTP request against the given Echo app and returns the recorder.
func doRequest(e *echo.Echo, method, path string, opts ...func(*http.Request)) *httptest.ResponseRecorder {
req := httptest.NewRequest(method, path, nil)
req.Header.Set("Content-Type", "application/json")
for _, opt := range opts {
opt(req)
}
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
return rec
}
func withBearerToken(token string) func(*http.Request) {
return func(req *http.Request) {
req.Header.Set("Authorization", "Bearer "+token)
}
}
func withXApiKey(key string) func(*http.Request) {
return func(req *http.Request) {
req.Header.Set("x-api-key", key)
}
}
func withSessionCookie(sessionID string) func(*http.Request) {
return func(req *http.Request) {
req.AddCookie(&http.Cookie{Name: "session", Value: sessionID})
}
}
func withTokenCookie(token string) func(*http.Request) {
return func(req *http.Request) {
req.AddCookie(&http.Cookie{Name: "token", Value: token})
}
}


@@ -0,0 +1,522 @@
package auth
import (
"bytes"
"crypto/subtle"
"encoding/json"
"io"
"net/http"
"strings"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"gorm.io/gorm"
)
const (
contextKeyUser = "auth_user"
contextKeyRole = "auth_role"
)
// Middleware returns an Echo middleware that handles authentication.
//
// Resolution order:
// 1. If auth is not enabled AND no legacy API keys are configured → pass through
// 2. If auth is enabled (db != nil), try to authenticate, even on exempt paths:
// a. Try "session" cookie → DB lookup
// b. Try Authorization: Bearer → session ID, then user API key
// c. Try x-api-key / xi-api-key → user API key
// d. Try "token" cookie → user API key
// 3. If still unauthenticated and legacy API keys are configured, check the
// extracted key against ApiKeys → synthetic admin user
// 4. If authenticated, or the path is exempt (PathWithoutAuth + /api/auth/) → proceed
// 5. If the path is an API path (see isAPIPath) → 401, unless a legacy
// HTTP GET exemption matches
// 6. Otherwise pass through (static assets, UI pages, etc.)
func Middleware(db *gorm.DB, appConfig *config.ApplicationConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
authEnabled := db != nil
hasLegacyKeys := len(appConfig.ApiKeys) > 0
// 1. No auth at all
if !authEnabled && !hasLegacyKeys {
return next(c)
}
path := c.Request().URL.Path
exempt := isExemptPath(path, appConfig)
authenticated := false
// 2. Try to authenticate (populates user in context if possible)
if authEnabled {
user := tryAuthenticate(c, db, appConfig)
if user != nil {
c.Set(contextKeyUser, user)
c.Set(contextKeyRole, user.Role)
authenticated = true
// Session rotation for cookie-based sessions
if session, ok := c.Get("_auth_session").(*Session); ok {
MaybeRotateSession(c, db, session, appConfig.Auth.APIKeyHMACSecret)
}
}
}
// 3. Legacy API key validation (works whether auth is enabled or not)
if !authenticated && hasLegacyKeys {
key := extractKey(c)
if key != "" && isValidLegacyKey(key, appConfig) {
syntheticUser := &User{
ID: "legacy-api-key",
Name: "API Key User",
Role: RoleAdmin,
}
c.Set(contextKeyUser, syntheticUser)
c.Set(contextKeyRole, RoleAdmin)
authenticated = true
}
}
// 4. If authenticated or exempt path, proceed
if authenticated || exempt {
return next(c)
}
// 5. Require auth for API paths
if isAPIPath(path) {
// Check GET exemptions for legacy keys
if hasLegacyKeys && appConfig.DisableApiKeyRequirementForHttpGet && c.Request().Method == http.MethodGet {
for _, rx := range appConfig.HttpGetExemptedEndpoints {
if rx.MatchString(c.Path()) {
return next(c)
}
}
}
return authError(c, appConfig)
}
// 6. Non-API paths (UI, static assets) pass through.
// The React UI handles login redirects client-side.
return next(c)
}
}
}
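The control flow of `Middleware` reduces to a small decision table once the facts about a request are known. The sketch below models only that final decision, with no Echo or database involved; the `request` struct and `passes` function are hypothetical names introduced for illustration.

```go
package main

import "fmt"

// request holds the facts the middleware derives before deciding.
// This type is a hypothetical simplification, not part of the package.
type request struct {
	authConfigured bool // auth DB present or legacy API keys configured
	authenticated  bool // a session, user API key, or legacy key matched
	exempt         bool // path is under /api/auth/ or PathWithoutAuth
	apiPath        bool // path matches isAPIPath
	getExempted    bool // legacy HTTP GET exemption applies
}

// passes reports whether the request reaches the next handler (true)
// or is rejected with 401 (false).
func passes(r request) bool {
	switch {
	case !r.authConfigured:
		return true // step 1: nothing to enforce
	case r.authenticated || r.exempt:
		return true // step 4
	case r.apiPath && !r.getExempted:
		return false // step 5: API path without credentials
	default:
		return true // step 6: UI pages, static assets
	}
}

func main() {
	fmt.Println(passes(request{}))
	fmt.Println(passes(request{authConfigured: true, apiPath: true}))
	fmt.Println(passes(request{authConfigured: true, apiPath: true, exempt: true}))
	fmt.Println(passes(request{authConfigured: true}))
}
```

One consequence visible in the table: an exempt path still runs authentication first, so a handler behind `/api/auth/` can see the current user when a valid session is present.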
// RequireAdmin returns middleware that checks the user has admin role.
func RequireAdmin() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
user := GetUser(c)
if user == nil {
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
Error: &schema.APIError{
Message: "Authentication required",
Code: http.StatusUnauthorized,
Type: "authentication_error",
},
})
}
if user.Role != RoleAdmin {
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
Error: &schema.APIError{
Message: "Admin access required",
Code: http.StatusForbidden,
Type: "authorization_error",
},
})
}
return next(c)
}
}
}
// NoopMiddleware returns a middleware that does nothing (pass-through).
// Used when auth is disabled to satisfy route registration that expects
// an admin middleware parameter.
func NoopMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return next
}
}
// RequireFeature returns middleware that checks the user has access to the given feature.
// If no auth DB is provided, it passes through (backward compat).
// Admins always pass. Regular users must have the feature enabled in their permissions.
func RequireFeature(db *gorm.DB, feature string) echo.MiddlewareFunc {
if db == nil {
return NoopMiddleware()
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
user := GetUser(c)
if user == nil {
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
Error: &schema.APIError{
Message: "Authentication required",
Code: http.StatusUnauthorized,
Type: "authentication_error",
},
})
}
if user.Role == RoleAdmin {
return next(c)
}
perm, err := GetCachedUserPermissions(c, db, user.ID)
if err != nil {
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
Error: &schema.APIError{
Message: "feature not enabled for your account",
Code: http.StatusForbidden,
Type: "authorization_error",
},
})
}
val, exists := perm.Permissions[feature]
if !exists {
if !isDefaultOnFeature(feature) {
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
Error: &schema.APIError{
Message: "feature not enabled for your account",
Code: http.StatusForbidden,
Type: "authorization_error",
},
})
}
} else if !val {
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
Error: &schema.APIError{
Message: "feature not enabled for your account",
Code: http.StatusForbidden,
Type: "authorization_error",
},
})
}
return next(c)
}
}
}
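For non-admin users, the permission check in `RequireFeature` comes down to one rule: an explicit entry in the user's permissions always wins, and a missing entry falls back to the feature's default. A compact sketch of that rule is below. The `defaults` map stands in for `isDefaultOnFeature`, which is not shown in this diff, and the feature keys are placeholders.

```go
package main

import "fmt"

// allowed reports whether a non-admin user may use a feature.
// perms holds the user's explicit settings; defaults is a hypothetical
// stand-in for isDefaultOnFeature.
func allowed(perms, defaults map[string]bool, feature string) bool {
	if val, ok := perms[feature]; ok {
		return val // explicit grant or explicit denial
	}
	return defaults[feature] // no entry: use the feature default
}

func main() {
	// Placeholder keys; defaults follow the FeatureMeta tables
	// (API features default on, fine-tuning defaults off).
	defaults := map[string]bool{"chat": true, "fine_tuning": false}

	perms := map[string]bool{"chat": false}
	fmt.Println(allowed(perms, defaults, "chat"))        // explicitly disabled
	fmt.Println(allowed(perms, defaults, "fine_tuning")) // default off, no entry

	perms["fine_tuning"] = true
	fmt.Println(allowed(perms, defaults, "fine_tuning")) // explicitly enabled
}
```

This is why a default-on feature can still be switched off per user, and a default-off feature needs an explicit `true`.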
// GetUser returns the authenticated user from the echo context, or nil.
func GetUser(c echo.Context) *User {
u, ok := c.Get(contextKeyUser).(*User)
if !ok {
return nil
}
return u
}
// GetUserRole returns the role of the authenticated user, or empty string.
func GetUserRole(c echo.Context) string {
role, _ := c.Get(contextKeyRole).(string)
return role
}
// RequireRouteFeature returns a global middleware that checks the user has access
// to the feature required by the matched route. It uses the RouteFeatureRegistry
// to look up the required feature for each route pattern + HTTP method.
// If no entry matches, the request passes through (no restriction).
func RequireRouteFeature(db *gorm.DB) echo.MiddlewareFunc {
if db == nil {
return NoopMiddleware()
}
// Pre-build lookup map: "METHOD:pattern" -> feature
lookup := map[string]string{}
for _, rf := range RouteFeatureRegistry {
lookup[rf.Method+":"+rf.Pattern] = rf.Feature
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
path := c.Path() // Echo route pattern (e.g. "/v1/engines/:model/completions")
method := c.Request().Method
feature := lookup[method+":"+path]
if feature == "" {
feature = lookup["*:"+path]
}
if feature == "" {
return next(c) // no restriction for this route
}
user := GetUser(c)
if user == nil {
return next(c) // auth middleware handles unauthenticated
}
if user.Role == RoleAdmin {
return next(c)
}
perm, err := GetCachedUserPermissions(c, db, user.ID)
if err != nil {
return c.JSON(http.StatusInternalServerError, schema.ErrorResponse{
Error: &schema.APIError{
Message: "failed to check permissions",
Code: http.StatusInternalServerError,
Type: "server_error",
},
})
}
val, exists := perm.Permissions[feature]
if !exists {
if !isDefaultOnFeature(feature) {
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
Error: &schema.APIError{
Message: "feature not enabled for your account: " + feature,
Code: http.StatusForbidden,
Type: "authorization_error",
},
})
}
} else if !val {
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
Error: &schema.APIError{
Message: "feature not enabled for your account: " + feature,
Code: http.StatusForbidden,
Type: "authorization_error",
},
})
}
return next(c)
}
}
}
// RequireModelAccess returns a global middleware that checks the user is allowed
// to use the resolved model. It extracts the model name directly from the request
// (path param, query param, JSON body, or form value) rather than relying on a
// context key set by downstream route-specific middleware.
func RequireModelAccess(db *gorm.DB) echo.MiddlewareFunc {
if db == nil {
return NoopMiddleware()
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
user := GetUser(c)
if user == nil {
return next(c)
}
if user.Role == RoleAdmin {
return next(c)
}
// Check if this user even has a model allowlist enabled before
// doing the expensive body read. Most users won't have restrictions.
// Uses request-scoped cache to avoid duplicate DB hit when
// RequireRouteFeature already fetched permissions.
perm, err := GetCachedUserPermissions(c, db, user.ID)
if err != nil {
return c.JSON(http.StatusInternalServerError, schema.ErrorResponse{
Error: &schema.APIError{
Message: "failed to check permissions",
Code: http.StatusInternalServerError,
Type: "server_error",
},
})
}
allowlist := perm.AllowedModels
if !allowlist.Enabled {
return next(c)
}
modelName := extractModelFromRequest(c)
if modelName == "" {
return next(c)
}
for _, m := range allowlist.Models {
if m == modelName {
return next(c)
}
}
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
Error: &schema.APIError{
Message: "access denied to model: " + modelName,
Code: http.StatusForbidden,
Type: "authorization_error",
},
})
}
}
}
// extractModelFromRequest extracts the model name from various request sources.
// It checks URL path params, query params, JSON body, and form values.
// For JSON bodies, it peeks at the body and resets it so downstream handlers
// can still read it.
func extractModelFromRequest(c echo.Context) string {
// 1. URL path param (e.g. /v1/engines/:model/completions)
if model := c.Param("model"); model != "" {
return model
}
// 2. Query param
if model := c.QueryParam("model"); model != "" {
return model
}
// 3. Peek at JSON body
if strings.HasPrefix(c.Request().Header.Get("Content-Type"), "application/json") {
body, err := io.ReadAll(c.Request().Body)
c.Request().Body = io.NopCloser(bytes.NewReader(body)) // always reset
if err == nil && len(body) > 0 {
var m struct {
Model string `json:"model"`
}
if json.Unmarshal(body, &m) == nil && m.Model != "" {
return m.Model
}
}
}
// 4. Form value (multipart/form-data)
if model := c.FormValue("model"); model != "" {
return model
}
return ""
}
// tryAuthenticate attempts to authenticate the request using the database.
func tryAuthenticate(c echo.Context, db *gorm.DB, appConfig *config.ApplicationConfig) *User {
hmacSecret := appConfig.Auth.APIKeyHMACSecret
// a. Session cookie
if cookie, err := c.Cookie(sessionCookie); err == nil && cookie.Value != "" {
if user, session := ValidateSession(db, cookie.Value, hmacSecret); user != nil {
// Store session for rotation check in middleware
c.Set("_auth_session", session)
return user
}
}
// b. Authorization: Bearer token
authHeader := c.Request().Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
token := strings.TrimPrefix(authHeader, "Bearer ")
// Try as session ID first
if user, _ := ValidateSession(db, token, hmacSecret); user != nil {
return user
}
// Try as user API key
if key, err := ValidateAPIKey(db, token, hmacSecret); err == nil {
return &key.User
}
}
// c. x-api-key / xi-api-key headers
for _, header := range []string{"x-api-key", "xi-api-key"} {
if key := c.Request().Header.Get(header); key != "" {
if apiKey, err := ValidateAPIKey(db, key, hmacSecret); err == nil {
return &apiKey.User
}
}
}
// d. token cookie (legacy)
if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" {
// Try as user API key
if key, err := ValidateAPIKey(db, cookie.Value, hmacSecret); err == nil {
return &key.User
}
}
return nil
}
// extractKey extracts an API key from the request (all sources).
func extractKey(c echo.Context) string {
// Authorization header
auth := c.Request().Header.Get("Authorization")
if strings.HasPrefix(auth, "Bearer ") {
return strings.TrimPrefix(auth, "Bearer ")
}
if auth != "" {
return auth
}
// x-api-key
if key := c.Request().Header.Get("x-api-key"); key != "" {
return key
}
// xi-api-key
if key := c.Request().Header.Get("xi-api-key"); key != "" {
return key
}
// token cookie
if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" {
return cookie.Value
}
return ""
}
// isValidLegacyKey checks if the key matches any configured API key
// using constant-time comparison to prevent timing attacks.
func isValidLegacyKey(key string, appConfig *config.ApplicationConfig) bool {
for _, validKey := range appConfig.ApiKeys {
if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 {
return true
}
}
return false
}
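The legacy key check uses `crypto/subtle` so that comparing a candidate against a configured key does not stop at the first differing byte. A self-contained version of the same loop is sketched below; `matchesAny` is an illustrative name and the keys are made up.

```go
package main

import (
	"crypto/subtle"
	"fmt"
)

// matchesAny reports whether candidate equals one of the valid keys,
// comparing each pair in constant time with respect to content.
func matchesAny(candidate string, valid []string) bool {
	for _, v := range valid {
		if subtle.ConstantTimeCompare([]byte(candidate), []byte(v)) == 1 {
			return true
		}
	}
	return false
}

func main() {
	keys := []string{"sk-test-key-123", "sk-test-key-456"}
	fmt.Println(matchesAny("sk-test-key-456", keys))
	fmt.Println(matchesAny("wrong-key", keys))
	fmt.Println(matchesAny("", keys))
}
```

Two caveats apply to this pattern in general: `ConstantTimeCompare` returns immediately when the lengths differ, so key length is not hidden, and the loop returns on the first match, so the position of a matching key in the list can affect timing.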
// isExemptPath returns true if the path should skip authentication.
func isExemptPath(path string, appConfig *config.ApplicationConfig) bool {
// Auth endpoints are always public
if strings.HasPrefix(path, "/api/auth/") {
return true
}
// Check configured exempt paths
for _, p := range appConfig.PathWithoutAuth {
if strings.HasPrefix(path, p) {
return true
}
}
return false
}
// isAPIPath returns true for paths that always require authentication.
func isAPIPath(path string) bool {
return strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/models/") ||
strings.HasPrefix(path, "/backends/") ||
strings.HasPrefix(path, "/backend/") ||
strings.HasPrefix(path, "/tts") ||
strings.HasPrefix(path, "/vad") ||
strings.HasPrefix(path, "/video") ||
strings.HasPrefix(path, "/stores/") ||
strings.HasPrefix(path, "/system") ||
strings.HasPrefix(path, "/ws/") ||
strings.HasPrefix(path, "/generated-") ||
path == "/metrics"
}
// authError returns an appropriate error response.
func authError(c echo.Context, appConfig *config.ApplicationConfig) error {
c.Response().Header().Set("WWW-Authenticate", "Bearer")
if appConfig.OpaqueErrors {
return c.NoContent(http.StatusUnauthorized)
}
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
Error: &schema.APIError{
Message: "An authentication key is required",
Code: http.StatusUnauthorized,
Type: "invalid_request_error",
},
})
}
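
The path helpers above (`isExemptPath`, `isAPIPath`) classify requests purely by string prefix. The following is a minimal, self-contained sketch of that classification. It assumes exemptions are evaluated before the API check, which this part of the diff does not show; `classify`, `hasAnyPrefix`, and the values in `exemptPrefixes` are hypothetical names used only for illustration.

```go
package main

import (
	"fmt"
	"strings"
)

// exemptPrefixes stands in for appConfig.PathWithoutAuth; the values are
// hypothetical and chosen only for this sketch.
var exemptPrefixes = []string{"/healthz", "/readyz"}

// apiPrefixes is a subset of the prefixes checked by isAPIPath.
var apiPrefixes = []string{"/api/", "/v1/", "/models/", "/tts", "/ws/"}

// hasAnyPrefix reports whether path starts with any of the given prefixes.
func hasAnyPrefix(path string, prefixes []string) bool {
	for _, p := range prefixes {
		if strings.HasPrefix(path, p) {
			return true
		}
	}
	return false
}

// classify is a hypothetical helper that reports which bucket a path falls in.
// Exemptions are checked first (an assumed ordering), so /api/auth/ stays
// public even though it also matches the /api/ prefix.
func classify(path string) string {
	switch {
	case strings.HasPrefix(path, "/api/auth/") || hasAnyPrefix(path, exemptPrefixes):
		return "exempt"
	case hasAnyPrefix(path, apiPrefixes) || path == "/metrics":
		return "api"
	default:
		return "other"
	}
}

func main() {
	for _, p := range []string{"/api/auth/status", "/healthz", "/v1/models", "/metrics", "/app"} {
		fmt.Printf("%-18s %s\n", p, classify(p))
	}
}
```

Because matching is by prefix, an entry without a trailing slash such as `/tts` also matches longer paths like `/ttsx`. The same holds for any exempt prefix configured without a trailing slash.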


@@ -0,0 +1,306 @@
//go:build auth
package auth_test
import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/auth"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gorm.io/gorm"
)
var _ = Describe("Auth Middleware", func() {
Context("auth disabled, no API keys", func() {
var app *echo.Echo
BeforeEach(func() {
appConfig := config.NewApplicationConfig()
app = newAuthTestApp(nil, appConfig)
})
It("passes through all requests", func() {
rec := doRequest(app, http.MethodGet, "/v1/models")
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("passes through POST requests", func() {
rec := doRequest(app, http.MethodPost, "/v1/chat/completions")
Expect(rec.Code).To(Equal(http.StatusOK))
})
})
Context("auth disabled, API keys configured", func() {
var app *echo.Echo
const validKey = "sk-test-key-123"
BeforeEach(func() {
appConfig := config.NewApplicationConfig()
appConfig.ApiKeys = []string{validKey}
app = newAuthTestApp(nil, appConfig)
})
It("returns 401 for request without key", func() {
rec := doRequest(app, http.MethodGet, "/v1/models")
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
})
It("passes with valid Bearer token", func() {
rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(validKey))
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("passes with valid x-api-key header", func() {
rec := doRequest(app, http.MethodGet, "/v1/models", withXApiKey(validKey))
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("passes with valid token cookie", func() {
rec := doRequest(app, http.MethodGet, "/v1/models", withTokenCookie(validKey))
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("returns 401 for invalid key", func() {
rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken("wrong-key"))
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
})
})
Context("auth enabled with database", func() {
var (
db *gorm.DB
app *echo.Echo
appConfig *config.ApplicationConfig
user *auth.User
)
BeforeEach(func() {
db = testDB()
appConfig = config.NewApplicationConfig()
app = newAuthTestApp(db, appConfig)
user = createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub)
})
It("allows requests with valid session cookie", func() {
sessionID := createTestSession(db, user.ID)
rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("allows requests with valid session as Bearer token", func() {
sessionID := createTestSession(db, user.ID)
rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(sessionID))
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("allows requests with valid user API key as Bearer token", func() {
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "test", auth.RoleUser, "", nil)
Expect(err).ToNot(HaveOccurred())
rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(plaintext))
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("allows requests with legacy API_KEY as admin bypass", func() {
appConfig.ApiKeys = []string{"legacy-key-123"}
app = newAuthTestApp(db, appConfig)
rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken("legacy-key-123"))
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("returns 401 for expired session", func() {
sessionID := createTestSession(db, user.ID)
// Manually expire (session ID in DB is the hash)
hash := auth.HashAPIKey(sessionID, "")
db.Model(&auth.Session{}).Where("id = ?", hash).
Update("expires_at", "2020-01-01")
rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
})
It("returns 401 for invalid session ID", func() {
rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie("invalid-session-id"))
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
})
It("returns 401 for revoked API key", func() {
plaintext, record, err := auth.CreateAPIKey(db, user.ID, "to revoke", auth.RoleUser, "", nil)
Expect(err).ToNot(HaveOccurred())
err = auth.RevokeAPIKey(db, record.ID, user.ID)
Expect(err).ToNot(HaveOccurred())
rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(plaintext))
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
})
It("skips auth for /api/auth/* paths", func() {
rec := doRequest(app, http.MethodGet, "/api/auth/status")
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("skips auth for PathWithoutAuth paths", func() {
rec := doRequest(app, http.MethodGet, "/healthz")
// healthz is not registered in our test app, so it'll be 404/405 but NOT 401
Expect(rec.Code).ToNot(Equal(http.StatusUnauthorized))
})
It("returns 401 for unauthenticated API requests", func() {
rec := doRequest(app, http.MethodGet, "/v1/models")
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
})
It("allows unauthenticated access to non-API paths when no legacy keys", func() {
rec := doRequest(app, http.MethodGet, "/app")
Expect(rec.Code).To(Equal(http.StatusOK))
})
})
Describe("RequireAdmin", func() {
var (
db *gorm.DB
appConfig *config.ApplicationConfig
)
BeforeEach(func() {
db = testDB()
appConfig = config.NewApplicationConfig()
})
It("passes for admin user", func() {
admin := createTestUser(db, "admin@example.com", auth.RoleAdmin, auth.ProviderGitHub)
sessionID := createTestSession(db, admin.ID)
app := newAdminTestApp(db, appConfig)
rec := doRequest(app, http.MethodPost, "/api/settings", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("returns 403 for user role", func() {
user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub)
sessionID := createTestSession(db, user.ID)
app := newAdminTestApp(db, appConfig)
rec := doRequest(app, http.MethodPost, "/api/settings", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusForbidden))
})
It("returns 401 when no user in context", func() {
app := newAdminTestApp(db, appConfig)
rec := doRequest(app, http.MethodPost, "/api/settings")
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
})
It("allows admin to access model management", func() {
admin := createTestUser(db, "admin@example.com", auth.RoleAdmin, auth.ProviderGitHub)
sessionID := createTestSession(db, admin.ID)
app := newAdminTestApp(db, appConfig)
rec := doRequest(app, http.MethodPost, "/models/apply", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("blocks user from model management", func() {
user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub)
sessionID := createTestSession(db, user.ID)
app := newAdminTestApp(db, appConfig)
rec := doRequest(app, http.MethodPost, "/models/apply", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusForbidden))
})
It("allows user to access regular inference endpoints", func() {
user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub)
sessionID := createTestSession(db, user.ID)
app := newAdminTestApp(db, appConfig)
rec := doRequest(app, http.MethodPost, "/v1/chat/completions", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("allows legacy API key (admin bypass) on admin routes", func() {
appConfig.ApiKeys = []string{"admin-key"}
app := newAdminTestApp(db, appConfig)
rec := doRequest(app, http.MethodPost, "/api/settings", withBearerToken("admin-key"))
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("allows admin to access trace endpoints", func() {
admin := createTestUser(db, "admin2@example.com", auth.RoleAdmin, auth.ProviderGitHub)
sessionID := createTestSession(db, admin.ID)
app := newAdminTestApp(db, appConfig)
rec := doRequest(app, http.MethodGet, "/api/traces", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusOK))
rec = doRequest(app, http.MethodGet, "/api/backend-logs", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("blocks non-admin from trace endpoints", func() {
user := createTestUser(db, "user2@example.com", auth.RoleUser, auth.ProviderGitHub)
sessionID := createTestSession(db, user.ID)
app := newAdminTestApp(db, appConfig)
rec := doRequest(app, http.MethodGet, "/api/traces", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusForbidden))
rec = doRequest(app, http.MethodGet, "/api/backend-logs", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusForbidden))
})
It("allows admin to access agent job endpoints", func() {
admin := createTestUser(db, "admin3@example.com", auth.RoleAdmin, auth.ProviderGitHub)
sessionID := createTestSession(db, admin.ID)
app := newAdminTestApp(db, appConfig)
rec := doRequest(app, http.MethodGet, "/api/agent/tasks", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusOK))
rec = doRequest(app, http.MethodGet, "/api/agent/jobs", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusOK))
})
It("blocks non-admin from agent job endpoints", func() {
user := createTestUser(db, "user3@example.com", auth.RoleUser, auth.ProviderGitHub)
sessionID := createTestSession(db, user.ID)
app := newAdminTestApp(db, appConfig)
rec := doRequest(app, http.MethodGet, "/api/agent/tasks", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusForbidden))
rec = doRequest(app, http.MethodGet, "/api/agent/jobs", withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusForbidden))
})
It("blocks non-admin from system/management endpoints", func() {
user := createTestUser(db, "user4@example.com", auth.RoleUser, auth.ProviderGitHub)
sessionID := createTestSession(db, user.ID)
app := newAdminTestApp(db, appConfig)
for _, path := range []string{"/api/operations", "/api/models", "/api/backends", "/api/resources", "/api/p2p/workers", "/system", "/backend/monitor"} {
rec := doRequest(app, http.MethodGet, path, withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusForbidden), "expected 403 for path: "+path)
}
})
It("allows admin to access system/management endpoints", func() {
admin := createTestUser(db, "admin4@example.com", auth.RoleAdmin, auth.ProviderGitHub)
sessionID := createTestSession(db, admin.ID)
app := newAdminTestApp(db, appConfig)
for _, path := range []string{"/api/operations", "/api/models", "/api/backends", "/api/resources", "/api/p2p/workers", "/system", "/backend/monitor"} {
rec := doRequest(app, http.MethodGet, path, withSessionCookie(sessionID))
Expect(rec.Code).To(Equal(http.StatusOK), "expected 200 for path: "+path)
}
})
})
})

148
core/http/auth/models.go Normal file

@@ -0,0 +1,148 @@
package auth
import (
"database/sql/driver"
"encoding/json"
"fmt"
"time"
)
// Auth provider constants.
const (
ProviderLocal = "local"
ProviderGitHub = "github"
ProviderOIDC = "oidc"
)
// User represents an authenticated user.
type User struct {
ID string `gorm:"primaryKey;size:36"`
Email string `gorm:"size:255;index"`
Name string `gorm:"size:255"`
AvatarURL string `gorm:"size:512"`
Provider string `gorm:"size:50"` // ProviderLocal, ProviderGitHub, ProviderOIDC
Subject string `gorm:"size:255"` // provider-specific user ID
PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users
Role string `gorm:"size:20;default:user"`
Status string `gorm:"size:20;default:active"` // StatusActive, StatusPending, StatusDisabled
CreatedAt time.Time
UpdatedAt time.Time
}
// Session represents a user login session.
type Session struct {
ID string `gorm:"primaryKey;size:64"` // HMAC-SHA256 hash of session token
UserID string `gorm:"size:36;index"`
ExpiresAt time.Time
RotatedAt time.Time
CreatedAt time.Time
User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
}
// UserAPIKey represents a user-generated API key for programmatic access.
type UserAPIKey struct {
ID string `gorm:"primaryKey;size:36"`
UserID string `gorm:"size:36;index"`
Name string `gorm:"size:255"` // user-provided label
KeyHash string `gorm:"size:64;uniqueIndex"`
KeyPrefix string `gorm:"size:12"` // first 8 chars of key for display
Role string `gorm:"size:20"`
CreatedAt time.Time
ExpiresAt *time.Time `gorm:"index"`
LastUsed *time.Time
User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
}
// PermissionMap is a flexible map of feature -> enabled, stored as JSON text.
// Keys are the Feature* constants, e.g. "agents", "skills", "collections", "mcp_jobs".
// New features can be added without schema changes.
type PermissionMap map[string]bool
// Value implements driver.Valuer for GORM JSON serialization.
func (p PermissionMap) Value() (driver.Value, error) {
if p == nil {
return "{}", nil
}
b, err := json.Marshal(p)
if err != nil {
return nil, fmt.Errorf("failed to marshal PermissionMap: %w", err)
}
return string(b), nil
}
// Scan implements sql.Scanner for GORM JSON deserialization.
func (p *PermissionMap) Scan(value any) error {
if value == nil {
*p = PermissionMap{}
return nil
}
var bytes []byte
switch v := value.(type) {
case string:
bytes = []byte(v)
case []byte:
bytes = v
default:
return fmt.Errorf("cannot scan %T into PermissionMap", value)
}
return json.Unmarshal(bytes, p)
}
// InviteCode represents an admin-generated invitation for user registration.
type InviteCode struct {
ID string `gorm:"primaryKey;size:36"`
Code string `gorm:"uniqueIndex;not null;size:64"` // HMAC-SHA256 hash of invite code
CodePrefix string `gorm:"size:12"` // first 8 chars for admin display
CreatedBy string `gorm:"size:36;not null"`
UsedBy *string `gorm:"size:36"`
UsedAt *time.Time
ExpiresAt time.Time `gorm:"not null;index"`
CreatedAt time.Time
Creator User `gorm:"foreignKey:CreatedBy"`
Consumer *User `gorm:"foreignKey:UsedBy"`
}
// ModelAllowlist controls which models a user can access.
// When Enabled is false (default), all models are allowed.
type ModelAllowlist struct {
Enabled bool `json:"enabled"`
Models []string `json:"models,omitempty"`
}
// Value implements driver.Valuer for GORM JSON serialization.
func (m ModelAllowlist) Value() (driver.Value, error) {
b, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("failed to marshal ModelAllowlist: %w", err)
}
return string(b), nil
}
// Scan implements sql.Scanner for GORM JSON deserialization.
func (m *ModelAllowlist) Scan(value any) error {
if value == nil {
*m = ModelAllowlist{}
return nil
}
var bytes []byte
switch v := value.(type) {
case string:
bytes = []byte(v)
case []byte:
bytes = v
default:
return fmt.Errorf("cannot scan %T into ModelAllowlist", value)
}
return json.Unmarshal(bytes, m)
}
// UserPermission stores per-user feature permissions.
type UserPermission struct {
ID string `gorm:"primaryKey;size:36"`
UserID string `gorm:"size:36;uniqueIndex"`
Permissions PermissionMap `gorm:"type:text"`
AllowedModels ModelAllowlist `gorm:"type:text"`
CreatedAt time.Time
UpdatedAt time.Time
User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
}
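
The `Value`/`Scan` pair on `PermissionMap` is what lets GORM store the map in a text column. The sketch below copies that pair into a standalone program and round-trips a value through it without a database. `roundTrip` is a hypothetical helper added for the demonstration, and the error wrapping is shortened.

```go
package main

import (
	"database/sql/driver"
	"encoding/json"
	"fmt"
)

// PermissionMap mirrors the type in models.go.
type PermissionMap map[string]bool

// Value serializes the map to a JSON string. A nil map becomes "{}"
// because json.Marshal would otherwise produce "null".
func (p PermissionMap) Value() (driver.Value, error) {
	if p == nil {
		return "{}", nil
	}
	b, err := json.Marshal(p)
	if err != nil {
		return nil, err
	}
	return string(b), nil
}

// Scan accepts the string or []byte a SQL driver may hand back.
func (p *PermissionMap) Scan(value any) error {
	if value == nil {
		*p = PermissionMap{}
		return nil
	}
	var raw []byte
	switch v := value.(type) {
	case string:
		raw = []byte(v)
	case []byte:
		raw = v
	default:
		return fmt.Errorf("cannot scan %T into PermissionMap", value)
	}
	return json.Unmarshal(raw, p)
}

// roundTrip is a hypothetical helper: it serializes and deserializes p the
// way a write followed by a read would.
func roundTrip(p PermissionMap) (PermissionMap, error) {
	v, err := p.Value()
	if err != nil {
		return nil, err
	}
	var out PermissionMap
	if err := out.Scan(v); err != nil {
		return nil, err
	}
	return out, nil
}

func main() {
	out, err := roundTrip(PermissionMap{"chat": true, "agents": false})
	fmt.Println(out, err)

	var empty PermissionMap
	v, _ := empty.Value()
	fmt.Println(v)
}
```

An explicit `false` entry survives the round trip, so it stays distinct from an absent key. That distinction matters for features that default to ON.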

439
core/http/auth/oauth.go Normal file

@@ -0,0 +1,439 @@
package auth
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/xlog"
"golang.org/x/oauth2"
githubOAuth "golang.org/x/oauth2/github"
"gorm.io/gorm"
)
// providerEntry holds the OAuth2/OIDC config for a single provider.
type providerEntry struct {
oauth2Config oauth2.Config
oidcVerifier *oidc.IDTokenVerifier // nil for GitHub (API-based user info)
name string
userInfoURL string // only used for GitHub
}
// oauthUserInfo is a provider-agnostic representation of an authenticated user.
type oauthUserInfo struct {
Subject string
Email string
Name string
AvatarURL string
}
// OAuthManager manages multiple OAuth/OIDC providers.
type OAuthManager struct {
providers map[string]*providerEntry
}
// OAuthParams groups the parameters needed to create an OAuthManager.
type OAuthParams struct {
GitHubClientID string
GitHubClientSecret string
OIDCIssuer string
OIDCClientID string
OIDCClientSecret string
}
// NewOAuthManager creates an OAuthManager from the given params.
func NewOAuthManager(baseURL string, params OAuthParams) (*OAuthManager, error) {
m := &OAuthManager{providers: make(map[string]*providerEntry)}
if params.GitHubClientID != "" {
m.providers[ProviderGitHub] = &providerEntry{
name: ProviderGitHub,
oauth2Config: oauth2.Config{
ClientID: params.GitHubClientID,
ClientSecret: params.GitHubClientSecret,
Endpoint: githubOAuth.Endpoint,
RedirectURL: baseURL + "/api/auth/github/callback",
Scopes: []string{"user:email", "read:user"},
},
userInfoURL: "https://api.github.com/user",
}
}
if params.OIDCClientID != "" && params.OIDCIssuer != "" {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
provider, err := oidc.NewProvider(ctx, params.OIDCIssuer)
if err != nil {
return nil, fmt.Errorf("OIDC discovery failed for %s: %w", params.OIDCIssuer, err)
}
verifier := provider.Verifier(&oidc.Config{ClientID: params.OIDCClientID})
m.providers[ProviderOIDC] = &providerEntry{
name: ProviderOIDC,
oauth2Config: oauth2.Config{
ClientID: params.OIDCClientID,
ClientSecret: params.OIDCClientSecret,
Endpoint: provider.Endpoint(),
RedirectURL: baseURL + "/api/auth/oidc/callback",
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
},
oidcVerifier: verifier,
}
}
return m, nil
}
// Providers returns the list of configured provider names.
func (m *OAuthManager) Providers() []string {
names := make([]string, 0, len(m.providers))
for name := range m.providers {
names = append(names, name)
}
return names
}
// LoginHandler redirects the user to the OAuth provider's login page.
func (m *OAuthManager) LoginHandler(providerName string) echo.HandlerFunc {
return func(c echo.Context) error {
provider, ok := m.providers[providerName]
if !ok {
return c.JSON(http.StatusNotFound, map[string]string{"error": "unknown provider"})
}
state, err := generateState()
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to generate state"})
}
secure := isSecure(c)
c.SetCookie(&http.Cookie{
Name: "oauth_state",
Value: state,
Path: "/",
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
MaxAge: 600, // 10 minutes
})
// Store invite code in cookie if provided
if inviteCode := c.QueryParam("invite_code"); inviteCode != "" {
c.SetCookie(&http.Cookie{
Name: "invite_code",
Value: inviteCode,
Path: "/",
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
MaxAge: 600,
})
}
url := provider.oauth2Config.AuthCodeURL(state)
return c.Redirect(http.StatusTemporaryRedirect, url)
}
}
// CallbackHandler handles the OAuth callback, creates/updates the user, and
// creates a session.
func (m *OAuthManager) CallbackHandler(providerName string, db *gorm.DB, adminEmail, registrationMode, hmacSecret string) echo.HandlerFunc {
return func(c echo.Context) error {
provider, ok := m.providers[providerName]
if !ok {
return c.JSON(http.StatusNotFound, map[string]string{"error": "unknown provider"})
}
// Validate state
stateCookie, err := c.Cookie("oauth_state")
if err != nil || stateCookie.Value == "" || subtle.ConstantTimeCompare([]byte(stateCookie.Value), []byte(c.QueryParam("state"))) != 1 {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid OAuth state"})
}
// Clear state cookie
c.SetCookie(&http.Cookie{
Name: "oauth_state",
Value: "",
Path: "/",
HttpOnly: true,
Secure: isSecure(c),
MaxAge: -1,
})
// Exchange code for token
code := c.QueryParam("code")
if code == "" {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing authorization code"})
}
ctx, cancel := context.WithTimeout(c.Request().Context(), 30*time.Second)
defer cancel()
token, err := provider.oauth2Config.Exchange(ctx, code)
if err != nil {
xlog.Error("OAuth code exchange failed", "provider", providerName, "error", err)
return c.JSON(http.StatusBadRequest, map[string]string{"error": "OAuth authentication failed"})
}
// Fetch user info — branch based on provider type
var userInfo *oauthUserInfo
if provider.oidcVerifier != nil {
userInfo, err = extractOIDCUserInfo(ctx, provider.oidcVerifier, token)
} else {
userInfo, err = fetchGitHubUserInfoAsOAuth(ctx, token.AccessToken)
}
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to fetch user info"})
}
// Retrieve invite code from cookie if present
var inviteCode string
if ic, err := c.Cookie("invite_code"); err == nil && ic.Value != "" {
inviteCode = ic.Value
// Clear the invite code cookie
c.SetCookie(&http.Cookie{
Name: "invite_code",
Value: "",
Path: "/",
HttpOnly: true,
Secure: isSecure(c),
MaxAge: -1,
})
}
// Upsert user; a pending user may be activated by a valid invite code below
user, err := upsertOAuthUser(db, providerName, userInfo, adminEmail, registrationMode)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create user"})
}
// For new users that are pending, check if they have a valid invite
if user.Status != StatusActive && inviteCode != "" {
if invite, err := ValidateInvite(db, inviteCode, hmacSecret); err == nil {
user.Status = StatusActive
db.Model(user).Update("status", StatusActive)
ConsumeInvite(db, invite, user.ID)
}
}
if user.Status != StatusActive {
if registrationMode == "invite" {
return c.JSON(http.StatusForbidden, map[string]string{"error": "a valid invite code is required to register"})
}
return c.JSON(http.StatusForbidden, map[string]string{"error": "account pending approval"})
}
// Maybe promote on login
MaybePromote(db, user, adminEmail)
// Create session
sessionID, err := CreateSession(db, user.ID, hmacSecret)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create session"})
}
SetSessionCookie(c, sessionID)
return c.Redirect(http.StatusTemporaryRedirect, "/app")
}
}
// extractOIDCUserInfo extracts user info from the OIDC ID token.
func extractOIDCUserInfo(ctx context.Context, verifier *oidc.IDTokenVerifier, token *oauth2.Token) (*oauthUserInfo, error) {
rawIDToken, ok := token.Extra("id_token").(string)
if !ok || rawIDToken == "" {
return nil, fmt.Errorf("no id_token in token response")
}
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
return nil, fmt.Errorf("failed to verify ID token: %w", err)
}
var claims struct {
Sub string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
Picture string `json:"picture"`
}
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("failed to parse ID token claims: %w", err)
}
return &oauthUserInfo{
Subject: claims.Sub,
Email: claims.Email,
Name: claims.Name,
AvatarURL: claims.Picture,
}, nil
}
type githubUserInfo struct {
ID int `json:"id"`
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
AvatarURL string `json:"avatar_url"`
}
type githubEmail struct {
Email string `json:"email"`
Primary bool `json:"primary"`
Verified bool `json:"verified"`
}
// fetchGitHubUserInfoAsOAuth fetches GitHub user info and returns it as oauthUserInfo.
func fetchGitHubUserInfoAsOAuth(ctx context.Context, accessToken string) (*oauthUserInfo, error) {
info, err := fetchGitHubUserInfo(ctx, accessToken)
if err != nil {
return nil, err
}
return &oauthUserInfo{
Subject: fmt.Sprintf("%d", info.ID),
Email: info.Email,
Name: info.Name,
AvatarURL: info.AvatarURL,
}, nil
}
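func fetchGitHubUserInfo(ctx context.Context, accessToken string) (*githubUserInfo, error) {
client := &http.Client{Timeout: 10 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// A non-200 response carries an error payload that would otherwise decode
// into a zero-value githubUserInfo (ID 0), so reject it explicitly.
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub user request failed with status %d", resp.StatusCode)
}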
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var info githubUserInfo
if err := json.Unmarshal(body, &info); err != nil {
return nil, err
}
// If no public email, fetch from /user/emails
if info.Email == "" {
info.Email, _ = fetchGitHubPrimaryEmail(ctx, accessToken)
}
return &info, nil
}
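func fetchGitHubPrimaryEmail(ctx context.Context, accessToken string) (string, error) {
client := &http.Client{Timeout: 10 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/user/emails", nil)
if err != nil {
return "", err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
// Report a failed request by status rather than as a JSON decoding error.
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("GitHub emails request failed with status %d", resp.StatusCode)
}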
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
var emails []githubEmail
if err := json.Unmarshal(body, &emails); err != nil {
return "", err
}
for _, e := range emails {
if e.Primary && e.Verified {
return e.Email, nil
}
}
// Fall back to first verified email
for _, e := range emails {
if e.Verified {
return e.Email, nil
}
}
return "", fmt.Errorf("no verified email found")
}
func upsertOAuthUser(db *gorm.DB, provider string, info *oauthUserInfo, adminEmail, registrationMode string) (*User, error) {
// Normalize email from provider (#10)
if info.Email != "" {
info.Email = strings.ToLower(strings.TrimSpace(info.Email))
}
var user User
err := db.Where("provider = ? AND subject = ?", provider, info.Subject).First(&user).Error
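if err == nil {
// Existing user: update profile fields
user.Name = info.Name
user.AvatarURL = info.AvatarURL
if info.Email != "" {
user.Email = info.Email
}
if err := db.Save(&user).Error; err != nil {
return nil, err
}
return &user, nil
}
// Only a missing record means "new user"; any other lookup error is returned.
if err != gorm.ErrRecordNotFound {
return nil, err
}
// New user: empty registration mode defaults to "approval"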
effectiveMode := registrationMode
if effectiveMode == "" {
effectiveMode = "approval"
}
status := StatusActive
if effectiveMode == "approval" || effectiveMode == "invite" {
status = StatusPending
}
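// Assign the role and create the user in a single transaction, as the
// AssignRole contract requires, to narrow the first-user admin race.
err = db.Transaction(func(tx *gorm.DB) error {
role := AssignRole(tx, info.Email, adminEmail)
// Admins (the first user, or one matching adminEmail) are always active
// regardless of registration mode.
if role == RoleAdmin {
status = StatusActive
}
user = User{
ID: uuid.New().String(),
Email: info.Email,
Name: info.Name,
AvatarURL: info.AvatarURL,
Provider: provider,
Subject: info.Subject,
Role: role,
Status: status,
}
return tx.Create(&user).Error
})
if err != nil {
return nil, err
}
return &user, nil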
}
func generateState() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
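
The login and callback handlers above protect the OAuth flow with a random `state` value that is stored in a cookie and compared against the `state` query parameter on return. The sketch below isolates that check from Echo and the provider. `newState` follows `generateState`, while `stateMatches` is a hypothetical helper that restates the condition used in `CallbackHandler`.

```go
package main

import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/hex"
	"fmt"
)

// newState returns 16 random bytes, hex-encoded (32 characters).
func newState() (string, error) {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return hex.EncodeToString(b), nil
}

// stateMatches is a hypothetical helper: it accepts the callback only when
// the cookie is non-empty and equals the query parameter. The comparison is
// constant-time, as in the handler.
func stateMatches(cookieValue, queryValue string) bool {
	if cookieValue == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(cookieValue), []byte(queryValue)) == 1
}

func main() {
	state, err := newState()
	if err != nil {
		panic(err)
	}
	// Prints: 32 true false false
	fmt.Println(len(state), stateMatches(state, state), stateMatches(state, "forged"), stateMatches("", ""))
}
```

The explicit empty-cookie check matters: without it, a request with no cookie and no `state` parameter would compare two empty strings and pass.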


@@ -0,0 +1,14 @@
package auth
import "golang.org/x/crypto/bcrypt"
// HashPassword returns a bcrypt hash of the given password.
func HashPassword(password string) (string, error) {
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(bytes), err
}
// CheckPassword compares a bcrypt hash with a plaintext password.
func CheckPassword(hash, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil
}


@@ -0,0 +1,217 @@
package auth
import (
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"gorm.io/gorm"
)
const contextKeyPermissions = "auth_permissions"
// GetCachedUserPermissions returns the user's permission record, using a
// request-scoped cache stored in the echo context. This avoids duplicate
// DB lookups when multiple middlewares (RequireRouteFeature, RequireModelAccess)
// both need permissions in the same request.
func GetCachedUserPermissions(c echo.Context, db *gorm.DB, userID string) (*UserPermission, error) {
if perm, ok := c.Get(contextKeyPermissions).(*UserPermission); ok && perm != nil {
return perm, nil
}
perm, err := GetUserPermissions(db, userID)
if err != nil {
return nil, err
}
c.Set(contextKeyPermissions, perm)
return perm, nil
}
// Feature name constants — all code must use these, never bare strings.
const (
// Agent features (default OFF for new users)
FeatureAgents = "agents"
FeatureSkills = "skills"
FeatureCollections = "collections"
FeatureMCPJobs = "mcp_jobs"
// General features (default OFF for new users)
FeatureFineTuning = "fine_tuning"
// API features (default ON for new users)
FeatureChat = "chat"
FeatureImages = "images"
FeatureAudioSpeech = "audio_speech"
FeatureAudioTranscription = "audio_transcription"
FeatureVAD = "vad"
FeatureDetection = "detection"
FeatureVideo = "video"
FeatureEmbeddings = "embeddings"
FeatureSound = "sound"
FeatureRealtime = "realtime"
FeatureRerank = "rerank"
FeatureTokenize = "tokenize"
FeatureMCP = "mcp"
FeatureStores = "stores"
)
// AgentFeatures lists agent-related features (default OFF).
var AgentFeatures = []string{FeatureAgents, FeatureSkills, FeatureCollections, FeatureMCPJobs}
// GeneralFeatures lists general features (default OFF).
var GeneralFeatures = []string{FeatureFineTuning}
// APIFeatures lists API endpoint features (default ON).
var APIFeatures = []string{
FeatureChat, FeatureImages, FeatureAudioSpeech, FeatureAudioTranscription,
FeatureVAD, FeatureDetection, FeatureVideo, FeatureEmbeddings, FeatureSound,
FeatureRealtime, FeatureRerank, FeatureTokenize, FeatureMCP, FeatureStores,
}
// AllFeatures lists all known features (used by UI and validation).
var AllFeatures = append(append(append([]string{}, AgentFeatures...), GeneralFeatures...), APIFeatures...)
// defaultOnFeatures is the set of features that default to ON when absent from a user's permission map.
var defaultOnFeatures = func() map[string]bool {
m := map[string]bool{}
for _, f := range APIFeatures {
m[f] = true
}
return m
}()
// isDefaultOnFeature returns true if the feature defaults to ON when not explicitly set.
func isDefaultOnFeature(feature string) bool {
return defaultOnFeatures[feature]
}
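// GetUserPermissions returns the permission record for a user, creating a default
// record with an empty map if none exists. Absent keys fall back to per-feature
// defaults (see isDefaultOnFeature), so an empty map does not disable everything.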
func GetUserPermissions(db *gorm.DB, userID string) (*UserPermission, error) {
var perm UserPermission
err := db.Where("user_id = ?", userID).First(&perm).Error
if err == gorm.ErrRecordNotFound {
perm = UserPermission{
ID: uuid.New().String(),
UserID: userID,
Permissions: PermissionMap{},
}
if err := db.Create(&perm).Error; err != nil {
return nil, err
}
return &perm, nil
}
if err != nil {
return nil, err
}
return &perm, nil
}
// UpdateUserPermissions upserts the permission map for a user.
func UpdateUserPermissions(db *gorm.DB, userID string, perms PermissionMap) error {
var perm UserPermission
err := db.Where("user_id = ?", userID).First(&perm).Error
if err == gorm.ErrRecordNotFound {
perm = UserPermission{
ID: uuid.New().String(),
UserID: userID,
Permissions: perms,
}
return db.Create(&perm).Error
}
if err != nil {
return err
}
perm.Permissions = perms
return db.Save(&perm).Error
}
// HasFeatureAccess returns true if the user is an admin or has the given feature enabled.
// When a feature key is absent from the user's permission map, it checks whether the
// feature defaults to ON (API features) or OFF (agent and general features) for backward compatibility.
func HasFeatureAccess(db *gorm.DB, user *User, feature string) bool {
if user == nil {
return false
}
if user.Role == RoleAdmin {
return true
}
perm, err := GetUserPermissions(db, user.ID)
if err != nil {
return false
}
val, exists := perm.Permissions[feature]
if !exists {
return isDefaultOnFeature(feature)
}
return val
}
// GetPermissionMapForUser returns the effective permission map for a user.
// Admins get all features as true (virtual).
// For regular users, absent keys are filled with their defaults so the
// UI/API always returns a complete picture.
func GetPermissionMapForUser(db *gorm.DB, user *User) PermissionMap {
if user == nil {
return PermissionMap{}
}
if user.Role == RoleAdmin {
m := PermissionMap{}
for _, f := range AllFeatures {
m[f] = true
}
return m
}
perm, err := GetUserPermissions(db, user.ID)
if err != nil {
return PermissionMap{}
}
// Fill in defaults for absent keys
effective := PermissionMap{}
for _, f := range AllFeatures {
val, exists := perm.Permissions[f]
if exists {
effective[f] = val
} else {
effective[f] = isDefaultOnFeature(f)
}
}
return effective
}
// GetModelAllowlist returns the model allowlist for a user.
func GetModelAllowlist(db *gorm.DB, userID string) ModelAllowlist {
perm, err := GetUserPermissions(db, userID)
if err != nil {
return ModelAllowlist{}
}
return perm.AllowedModels
}
// UpdateModelAllowlist updates the model allowlist for a user.
func UpdateModelAllowlist(db *gorm.DB, userID string, allowlist ModelAllowlist) error {
perm, err := GetUserPermissions(db, userID)
if err != nil {
return err
}
perm.AllowedModels = allowlist
return db.Save(perm).Error
}
// IsModelAllowed returns true if the user is allowed to use the given model.
// Admins always have access. If the allowlist is not enabled, all models are allowed.
func IsModelAllowed(db *gorm.DB, user *User, modelName string) bool {
if user == nil {
return false
}
if user.Role == RoleAdmin {
return true
}
allowlist := GetModelAllowlist(db, user.ID)
if !allowlist.Enabled {
return true
}
for _, m := range allowlist.Models {
if m == modelName {
return true
}
}
return false
}
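
The allowlist decision reduces to three cases, sketched below. The struct is a trimmed-down assumption based on the two fields the diff reads (`Enabled`, `Models`), and `slices.Contains` (Go 1.21+) is swapped in for the hand-written loop.

```go
package main

import (
	"fmt"
	"slices"
)

// ModelAllowlist keeps only the two fields that IsModelAllowed reads.
type ModelAllowlist struct {
	Enabled bool
	Models  []string
}

// modelAllowed mirrors the decision order in the diff: admins and
// disabled allowlists allow everything, otherwise the model must be listed.
func modelAllowed(isAdmin bool, al ModelAllowlist, model string) bool {
	if isAdmin || !al.Enabled {
		return true
	}
	return slices.Contains(al.Models, model)
}

func main() {
	al := ModelAllowlist{Enabled: true, Models: []string{"model-a"}}
	fmt.Println(modelAllowed(false, al, "model-a"))
	fmt.Println(modelAllowed(false, al, "model-b"))
	fmt.Println(modelAllowed(false, ModelAllowlist{Enabled: true}, "model-a"))
}
```

Note the last case: an enabled allowlist with no entries denies every model to non-admin users.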

103
core/http/auth/roles.go Normal file
@@ -0,0 +1,103 @@
package auth
import (
"fmt"
"strings"
"time"
"gorm.io/gorm"
)
const (
RoleAdmin = "admin"
RoleUser = "user"
StatusActive = "active"
StatusPending = "pending"
StatusDisabled = "disabled"
)
// AssignRole determines the role for a new user.
// First user in the database becomes admin. If adminEmail is set and matches,
// the user becomes admin. Otherwise, the user gets the "user" role.
// Must be called within a transaction that also creates the user to prevent
// race conditions on the first-user admin assignment.
func AssignRole(tx *gorm.DB, email, adminEmail string) string {
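var count int64
// Only treat the table as empty when the count query itself succeeded;
// otherwise a failed query would leave count at 0 and grant admin.
err := tx.Model(&User{}).Count(&count).Error
if err == nil && count == 0 {
return RoleAdmin
}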
if adminEmail != "" && strings.EqualFold(email, adminEmail) {
return RoleAdmin
}
return RoleUser
}
// MaybePromote promotes a user to admin on login if their email matches
// adminEmail. It does not demote existing admins. Returns true if the user
// was promoted.
func MaybePromote(db *gorm.DB, user *User, adminEmail string) bool {
if user.Role == RoleAdmin {
return false
}
if adminEmail != "" && strings.EqualFold(user.Email, adminEmail) {
user.Role = RoleAdmin
db.Model(user).Update("role", RoleAdmin)
return true
}
return false
}
// ValidateInvite checks that an invite code exists, is unused, and has not expired.
// The code is hashed with HMAC-SHA256 before lookup.
func ValidateInvite(db *gorm.DB, code, hmacSecret string) (*InviteCode, error) {
hash := HashAPIKey(code, hmacSecret)
var invite InviteCode
if err := db.Where("code = ?", hash).First(&invite).Error; err != nil {
return nil, fmt.Errorf("invite code not found")
}
if invite.UsedBy != nil {
return nil, fmt.Errorf("invite code already used")
}
if time.Now().After(invite.ExpiresAt) {
return nil, fmt.Errorf("invite code expired")
}
return &invite, nil
}
// ConsumeInvite marks an invite code as used by the given user.
func ConsumeInvite(db *gorm.DB, invite *InviteCode, userID string) {
now := time.Now()
invite.UsedBy = &userID
invite.UsedAt = &now
db.Save(invite)
}
// NeedsInviteOrApproval returns true if registration gating applies for the given mode.
// Admins (first user or matching adminEmail) are never gated.
// Must be called within a transaction that also creates the user.
func NeedsInviteOrApproval(tx *gorm.DB, email, adminEmail, registrationMode string) bool {
// Empty registration mode defaults to "approval"
if registrationMode == "" {
registrationMode = "approval"
}
if registrationMode != "approval" && registrationMode != "invite" {
return false
}
// Admin email is never gated
if adminEmail != "" && strings.EqualFold(email, adminEmail) {
return false
}
// First user is never gated
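var count int64
// A failed count leaves count at 0, so require a successful query before
// skipping the gate.
if err := tx.Model(&User{}).Count(&count).Error; err == nil && count == 0 {
return false
}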
return true
}
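
Role assignment and registration gating both depend only on the current user count, the two email addresses, and the registration mode. The sketch below restates them as pure functions, assuming the count has already been read inside the transaction; the mode name "open" is a hypothetical example of a mode that is neither "approval" nor "invite".

```go
package main

import (
	"fmt"
	"strings"
)

// assignRole mirrors AssignRole with the user count passed in.
func assignRole(userCount int64, email, adminEmail string) string {
	if userCount == 0 {
		return "admin"
	}
	if adminEmail != "" && strings.EqualFold(email, adminEmail) {
		return "admin"
	}
	return "user"
}

// needsGate mirrors NeedsInviteOrApproval with the user count passed in.
func needsGate(userCount int64, email, adminEmail, mode string) bool {
	if mode == "" {
		mode = "approval" // empty mode defaults to approval
	}
	if mode != "approval" && mode != "invite" {
		return false
	}
	if adminEmail != "" && strings.EqualFold(email, adminEmail) {
		return false
	}
	return userCount > 0 // the first user is never gated
}

func main() {
	fmt.Println(assignRole(0, "first@example.com", ""))
	fmt.Println(assignRole(1, "Admin@Example.COM", "admin@example.com"))
	fmt.Println(needsGate(1, "new@example.com", "admin@example.com", ""))
	fmt.Println(needsGate(1, "new@example.com", "admin@example.com", "open"))
}
```

Keeping the count read and the user insert in one transaction, as the doc comments require, is what prevents two concurrent first registrations from both seeing a count of zero.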

@@ -0,0 +1,84 @@
//go:build auth
package auth_test
import (
"github.com/mudler/LocalAI/core/http/auth"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gorm.io/gorm"
)
var _ = Describe("Roles", func() {
var db *gorm.DB
BeforeEach(func() {
db = testDB()
})
Describe("AssignRole", func() {
It("returns admin for the first user (empty DB)", func() {
role := auth.AssignRole(db, "first@example.com", "")
Expect(role).To(Equal(auth.RoleAdmin))
})
It("returns user for the second user", func() {
createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub)
role := auth.AssignRole(db, "second@example.com", "")
Expect(role).To(Equal(auth.RoleUser))
})
It("returns admin when email matches adminEmail", func() {
createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub)
role := auth.AssignRole(db, "admin@example.com", "admin@example.com")
Expect(role).To(Equal(auth.RoleAdmin))
})
It("is case-insensitive for admin email match", func() {
createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub)
role := auth.AssignRole(db, "Admin@Example.COM", "admin@example.com")
Expect(role).To(Equal(auth.RoleAdmin))
})
It("returns user when email does not match adminEmail", func() {
createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub)
role := auth.AssignRole(db, "other@example.com", "admin@example.com")
Expect(role).To(Equal(auth.RoleUser))
})
})
Describe("MaybePromote", func() {
It("promotes user to admin when email matches", func() {
user := createTestUser(db, "promoted@example.com", auth.RoleUser, auth.ProviderGitHub)
promoted := auth.MaybePromote(db, user, "promoted@example.com")
Expect(promoted).To(BeTrue())
Expect(user.Role).To(Equal(auth.RoleAdmin))
// Verify in DB
var dbUser auth.User
db.First(&dbUser, "id = ?", user.ID)
Expect(dbUser.Role).To(Equal(auth.RoleAdmin))
})
It("does not promote when email does not match", func() {
user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub)
promoted := auth.MaybePromote(db, user, "admin@example.com")
Expect(promoted).To(BeFalse())
Expect(user.Role).To(Equal(auth.RoleUser))
})
It("does not demote an existing admin", func() {
user := createTestUser(db, "admin@example.com", auth.RoleAdmin, auth.ProviderGitHub)
promoted := auth.MaybePromote(db, user, "other@example.com")
Expect(promoted).To(BeFalse())
Expect(user.Role).To(Equal(auth.RoleAdmin))
})
})
})

182
core/http/auth/session.go Normal file
@@ -0,0 +1,182 @@
package auth
import (
"crypto/rand"
"encoding/hex"
"fmt"
"net/http"
"time"
"github.com/labstack/echo/v4"
"gorm.io/gorm"
)
const (
sessionDuration = 30 * 24 * time.Hour // 30 days
sessionIDBytes = 32 // 32 bytes = 64 hex chars
sessionCookie = "session"
sessionRotationInterval = 1 * time.Hour
)
// CreateSession creates a new session for the given user, returning the
// plaintext token (64-char hex string). The stored session ID is the
// HMAC-SHA256 hash of the token.
func CreateSession(db *gorm.DB, userID, hmacSecret string) (string, error) {
b := make([]byte, sessionIDBytes)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("failed to generate session ID: %w", err)
}
plaintext := hex.EncodeToString(b)
hash := HashAPIKey(plaintext, hmacSecret)
now := time.Now()
session := Session{
ID: hash,
UserID: userID,
ExpiresAt: now.Add(sessionDuration),
RotatedAt: now,
}
if err := db.Create(&session).Error; err != nil {
return "", fmt.Errorf("failed to create session: %w", err)
}
return plaintext, nil
}
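
The token scheme described above (random 32 bytes, hex-encoded, stored only as an HMAC-SHA256 digest) can be sketched with the standard library alone. `HashAPIKey` is not shown in this part of the diff, so `hashToken` below is an assumed equivalent; the real function may treat an empty secret differently.

```go
package main

import (
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// newToken returns a 64-character hex string backed by 32 random bytes.
func newToken() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("failed to generate token: %w", err)
	}
	return hex.EncodeToString(b), nil
}

// hashToken is an assumed equivalent of HashAPIKey: HMAC-SHA256 of the
// token keyed with the secret, hex-encoded.
func hashToken(token, secret string) string {
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(token))
	return hex.EncodeToString(mac.Sum(nil))
}

func main() {
	token, err := newToken()
	if err != nil {
		panic(err)
	}
	stored := hashToken(token, "example-secret")
	fmt.Println(len(token), len(stored), stored != token)
}
```

Only the digest is persisted, so a leaked sessions table does not yield usable cookies unless the HMAC secret leaks as well.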
// ValidateSession hashes the plaintext token and looks up the session.
// Returns the associated user and session, or (nil, nil) if the session is
// not found, has expired, or belongs to a user whose status is not active.
func ValidateSession(db *gorm.DB, token, hmacSecret string) (*User, *Session) {
hash := HashAPIKey(token, hmacSecret)
var session Session
if err := db.Preload("User").Where("id = ? AND expires_at > ?", hash, time.Now()).First(&session).Error; err != nil {
return nil, nil
}
if session.User.Status != StatusActive {
return nil, nil
}
return &session.User, &session
}
// DeleteSession removes a session by hashing the plaintext token.
func DeleteSession(db *gorm.DB, token, hmacSecret string) error {
hash := HashAPIKey(token, hmacSecret)
return db.Where("id = ?", hash).Delete(&Session{}).Error
}
// CleanExpiredSessions removes all sessions that have passed their expiry time.
func CleanExpiredSessions(db *gorm.DB) error {
return db.Where("expires_at < ?", time.Now()).Delete(&Session{}).Error
}
// DeleteUserSessions removes all sessions for the given user.
func DeleteUserSessions(db *gorm.DB, userID string) error {
return db.Where("user_id = ?", userID).Delete(&Session{}).Error
}
// RotateSession creates a new session for the same user, deletes the old one,
// and returns the new plaintext token.
func RotateSession(db *gorm.DB, oldSession *Session, hmacSecret string) (string, error) {
b := make([]byte, sessionIDBytes)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("failed to generate session ID: %w", err)
}
plaintext := hex.EncodeToString(b)
hash := HashAPIKey(plaintext, hmacSecret)
now := time.Now()
newSession := Session{
ID: hash,
UserID: oldSession.UserID,
ExpiresAt: oldSession.ExpiresAt,
RotatedAt: now,
}
err := db.Transaction(func(tx *gorm.DB) error {
if err := tx.Create(&newSession).Error; err != nil {
return err
}
return tx.Where("id = ?", oldSession.ID).Delete(&Session{}).Error
})
if err != nil {
return "", fmt.Errorf("failed to rotate session: %w", err)
}
return plaintext, nil
}
// MaybeRotateSession checks if the session should be rotated and does so if needed.
// Called from the auth middleware after successful cookie-based authentication.
func MaybeRotateSession(c echo.Context, db *gorm.DB, session *Session, hmacSecret string) {
if session == nil {
return
}
rotatedAt := session.RotatedAt
if rotatedAt.IsZero() {
rotatedAt = session.CreatedAt
}
if time.Since(rotatedAt) < sessionRotationInterval {
return
}
newToken, err := RotateSession(db, session, hmacSecret)
if err != nil {
// Rotation failure is non-fatal; the old session remains valid
return
}
SetSessionCookie(c, newToken)
}
// isSecure returns true when the request arrived over HTTPS, either directly
// or via a reverse proxy that sets X-Forwarded-Proto.
func isSecure(c echo.Context) bool {
return c.Scheme() == "https"
}
// SetSessionCookie sets the session cookie on the response.
func SetSessionCookie(c echo.Context, sessionID string) {
cookie := &http.Cookie{
Name: sessionCookie,
Value: sessionID,
Path: "/",
HttpOnly: true,
Secure: isSecure(c),
SameSite: http.SameSiteLaxMode,
MaxAge: int(sessionDuration.Seconds()),
}
c.SetCookie(cookie)
}
// SetTokenCookie sets an httpOnly "token" cookie for legacy API key auth.
func SetTokenCookie(c echo.Context, token string) {
cookie := &http.Cookie{
Name: "token",
Value: token,
Path: "/",
HttpOnly: true,
Secure: isSecure(c),
SameSite: http.SameSiteLaxMode,
MaxAge: int(sessionDuration.Seconds()),
}
c.SetCookie(cookie)
}
// ClearSessionCookie clears the session cookie.
func ClearSessionCookie(c echo.Context) {
cookie := &http.Cookie{
Name: sessionCookie,
Value: "",
Path: "/",
HttpOnly: true,
Secure: isSecure(c),
SameSite: http.SameSiteLaxMode,
MaxAge: -1,
}
c.SetCookie(cookie)
}

@@ -0,0 +1,272 @@
//go:build auth
package auth_test
import (
"time"
"github.com/mudler/LocalAI/core/http/auth"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gorm.io/gorm"
)
var _ = Describe("Sessions", func() {
var (
db *gorm.DB
user *auth.User
)
// Use empty HMAC secret for basic tests
hmacSecret := ""
BeforeEach(func() {
db = testDB()
user = createTestUser(db, "session@example.com", auth.RoleUser, auth.ProviderGitHub)
})
Describe("CreateSession", func() {
It("creates a session and returns 64-char hex plaintext token", func() {
token, err := auth.CreateSession(db, user.ID, hmacSecret)
Expect(err).ToNot(HaveOccurred())
Expect(token).To(HaveLen(64))
})
It("stores the hash (not plaintext) in the DB", func() {
token, err := auth.CreateSession(db, user.ID, hmacSecret)
Expect(err).ToNot(HaveOccurred())
hash := auth.HashAPIKey(token, hmacSecret)
var session auth.Session
err = db.First(&session, "id = ?", hash).Error
Expect(err).ToNot(HaveOccurred())
Expect(session.UserID).To(Equal(user.ID))
// The plaintext token should NOT be stored as the ID
Expect(session.ID).ToNot(Equal(token))
Expect(session.ID).To(Equal(hash))
})
It("sets expiry to approximately 30 days from now", func() {
token, err := auth.CreateSession(db, user.ID, hmacSecret)
Expect(err).ToNot(HaveOccurred())
hash := auth.HashAPIKey(token, hmacSecret)
var session auth.Session
db.First(&session, "id = ?", hash)
expectedExpiry := time.Now().Add(30 * 24 * time.Hour)
Expect(session.ExpiresAt).To(BeTemporally("~", expectedExpiry, time.Minute))
})
It("sets RotatedAt on creation", func() {
token, err := auth.CreateSession(db, user.ID, hmacSecret)
Expect(err).ToNot(HaveOccurred())
hash := auth.HashAPIKey(token, hmacSecret)
var session auth.Session
db.First(&session, "id = ?", hash)
Expect(session.RotatedAt).To(BeTemporally("~", time.Now(), time.Minute))
})
It("associates session with correct user", func() {
token, err := auth.CreateSession(db, user.ID, hmacSecret)
Expect(err).ToNot(HaveOccurred())
hash := auth.HashAPIKey(token, hmacSecret)
var session auth.Session
db.First(&session, "id = ?", hash)
Expect(session.UserID).To(Equal(user.ID))
})
})
Describe("ValidateSession", func() {
It("returns user for valid session", func() {
token := createTestSession(db, user.ID)
found, session := auth.ValidateSession(db, token, hmacSecret)
Expect(found).ToNot(BeNil())
Expect(found.ID).To(Equal(user.ID))
Expect(session).ToNot(BeNil())
})
It("returns nil for non-existent session", func() {
found, session := auth.ValidateSession(db, "nonexistent-session-id", hmacSecret)
Expect(found).To(BeNil())
Expect(session).To(BeNil())
})
It("returns nil for expired session", func() {
token := createTestSession(db, user.ID)
hash := auth.HashAPIKey(token, hmacSecret)
// Manually expire the session
db.Model(&auth.Session{}).Where("id = ?", hash).
Update("expires_at", time.Now().Add(-1*time.Hour))
found, _ := auth.ValidateSession(db, token, hmacSecret)
Expect(found).To(BeNil())
})
})
Describe("DeleteSession", func() {
It("removes the session from DB", func() {
token := createTestSession(db, user.ID)
err := auth.DeleteSession(db, token, hmacSecret)
Expect(err).ToNot(HaveOccurred())
found, _ := auth.ValidateSession(db, token, hmacSecret)
Expect(found).To(BeNil())
})
It("does not error on non-existent session", func() {
err := auth.DeleteSession(db, "nonexistent", hmacSecret)
Expect(err).ToNot(HaveOccurred())
})
})
Describe("CleanExpiredSessions", func() {
It("removes expired sessions", func() {
token := createTestSession(db, user.ID)
hash := auth.HashAPIKey(token, hmacSecret)
// Manually expire the session
db.Model(&auth.Session{}).Where("id = ?", hash).
Update("expires_at", time.Now().Add(-1*time.Hour))
err := auth.CleanExpiredSessions(db)
Expect(err).ToNot(HaveOccurred())
var count int64
db.Model(&auth.Session{}).Where("id = ?", hash).Count(&count)
Expect(count).To(Equal(int64(0)))
})
It("keeps active sessions", func() {
token := createTestSession(db, user.ID)
hash := auth.HashAPIKey(token, hmacSecret)
err := auth.CleanExpiredSessions(db)
Expect(err).ToNot(HaveOccurred())
var count int64
db.Model(&auth.Session{}).Where("id = ?", hash).Count(&count)
Expect(count).To(Equal(int64(1)))
})
})
Describe("RotateSession", func() {
It("creates a new session and deletes the old one", func() {
token := createTestSession(db, user.ID)
hash := auth.HashAPIKey(token, hmacSecret)
// Get the old session
var oldSession auth.Session
db.First(&oldSession, "id = ?", hash)
newToken, err := auth.RotateSession(db, &oldSession, hmacSecret)
Expect(err).ToNot(HaveOccurred())
Expect(newToken).To(HaveLen(64))
Expect(newToken).ToNot(Equal(token))
// Old session should be gone
var count int64
db.Model(&auth.Session{}).Where("id = ?", hash).Count(&count)
Expect(count).To(Equal(int64(0)))
// New session should exist and validate
found, _ := auth.ValidateSession(db, newToken, hmacSecret)
Expect(found).ToNot(BeNil())
Expect(found.ID).To(Equal(user.ID))
})
It("preserves user ID and expiry", func() {
token := createTestSession(db, user.ID)
hash := auth.HashAPIKey(token, hmacSecret)
var oldSession auth.Session
db.First(&oldSession, "id = ?", hash)
newToken, err := auth.RotateSession(db, &oldSession, hmacSecret)
Expect(err).ToNot(HaveOccurred())
newHash := auth.HashAPIKey(newToken, hmacSecret)
var newSession auth.Session
db.First(&newSession, "id = ?", newHash)
Expect(newSession.UserID).To(Equal(oldSession.UserID))
Expect(newSession.ExpiresAt).To(BeTemporally("~", oldSession.ExpiresAt, time.Second))
})
})
Context("with HMAC secret", func() {
hmacSecret := "test-hmac-secret-123"
It("creates and validates sessions with HMAC secret", func() {
token, err := auth.CreateSession(db, user.ID, hmacSecret)
Expect(err).ToNot(HaveOccurred())
found, session := auth.ValidateSession(db, token, hmacSecret)
Expect(found).ToNot(BeNil())
Expect(found.ID).To(Equal(user.ID))
Expect(session).ToNot(BeNil())
})
It("does not validate with wrong HMAC secret", func() {
token, err := auth.CreateSession(db, user.ID, hmacSecret)
Expect(err).ToNot(HaveOccurred())
found, _ := auth.ValidateSession(db, token, "wrong-secret")
Expect(found).To(BeNil())
})
It("does not validate with empty HMAC secret", func() {
token, err := auth.CreateSession(db, user.ID, hmacSecret)
Expect(err).ToNot(HaveOccurred())
found, _ := auth.ValidateSession(db, token, "")
Expect(found).To(BeNil())
})
It("session created with empty secret does not validate with non-empty secret", func() {
token, err := auth.CreateSession(db, user.ID, "")
Expect(err).ToNot(HaveOccurred())
found, _ := auth.ValidateSession(db, token, hmacSecret)
Expect(found).To(BeNil())
})
It("deletes session with correct HMAC secret", func() {
token, err := auth.CreateSession(db, user.ID, hmacSecret)
Expect(err).ToNot(HaveOccurred())
err = auth.DeleteSession(db, token, hmacSecret)
Expect(err).ToNot(HaveOccurred())
found, _ := auth.ValidateSession(db, token, hmacSecret)
Expect(found).To(BeNil())
})
It("rotates session with HMAC secret", func() {
token, err := auth.CreateSession(db, user.ID, hmacSecret)
Expect(err).ToNot(HaveOccurred())
hash := auth.HashAPIKey(token, hmacSecret)
var oldSession auth.Session
db.First(&oldSession, "id = ?", hash)
newToken, err := auth.RotateSession(db, &oldSession, hmacSecret)
Expect(err).ToNot(HaveOccurred())
// Old token should not validate
found, _ := auth.ValidateSession(db, token, hmacSecret)
Expect(found).To(BeNil())
// New token should validate
found, _ = auth.ValidateSession(db, newToken, hmacSecret)
Expect(found).ToNot(BeNil())
Expect(found.ID).To(Equal(user.ID))
})
})
})

151
core/http/auth/usage.go Normal file
@@ -0,0 +1,151 @@
package auth
import (
"fmt"
"strings"
"time"
"gorm.io/gorm"
)
// UsageRecord represents a single API request's token usage.
type UsageRecord struct {
ID uint `gorm:"primaryKey;autoIncrement"`
UserID string `gorm:"size:36;index:idx_usage_user_time"`
UserName string `gorm:"size:255"`
Model string `gorm:"size:255;index"`
Endpoint string `gorm:"size:255"`
PromptTokens int64
CompletionTokens int64
TotalTokens int64
Duration int64 // milliseconds
CreatedAt time.Time `gorm:"index:idx_usage_user_time"`
}
// RecordUsage inserts a usage record.
func RecordUsage(db *gorm.DB, record *UsageRecord) error {
return db.Create(record).Error
}
// UsageBucket is an aggregated time bucket for the dashboard.
type UsageBucket struct {
Bucket string `json:"bucket"`
Model string `json:"model"`
UserID string `json:"user_id,omitempty"`
UserName string `json:"user_name,omitempty"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
RequestCount int64 `json:"request_count"`
}
// UsageTotals is a summary of all usage.
type UsageTotals struct {
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
RequestCount int64 `json:"request_count"`
}
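// periodToWindow returns the start of the time window (zero for "all") and the
// SQL expression used to bucket created_at for a period. Unknown periods fall
// back to the 30-day "month" window.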
func periodToWindow(period string, isSQLite bool) (time.Time, string) {
now := time.Now()
var since time.Time
var dateFmt string
switch period {
case "day":
since = now.Add(-24 * time.Hour)
if isSQLite {
dateFmt = "strftime('%Y-%m-%d %H:00', created_at)"
} else {
dateFmt = "to_char(date_trunc('hour', created_at), 'YYYY-MM-DD HH24:00')"
}
case "week":
since = now.Add(-7 * 24 * time.Hour)
if isSQLite {
dateFmt = "strftime('%Y-%m-%d', created_at)"
} else {
dateFmt = "to_char(date_trunc('day', created_at), 'YYYY-MM-DD')"
}
case "all":
since = time.Time{} // zero time = no filter
if isSQLite {
dateFmt = "strftime('%Y-%m', created_at)"
} else {
dateFmt = "to_char(date_trunc('month', created_at), 'YYYY-MM')"
}
default: // "month"
since = now.Add(-30 * 24 * time.Hour)
if isSQLite {
dateFmt = "strftime('%Y-%m-%d', created_at)"
} else {
dateFmt = "to_char(date_trunc('day', created_at), 'YYYY-MM-DD')"
}
}
return since, dateFmt
}
func isSQLiteDB(db *gorm.DB) bool {
return strings.Contains(db.Dialector.Name(), "sqlite")
}
// GetUserUsage returns aggregated usage for a single user.
func GetUserUsage(db *gorm.DB, userID, period string) ([]UsageBucket, error) {
sqlite := isSQLiteDB(db)
since, dateFmt := periodToWindow(period, sqlite)
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
query := db.Model(&UsageRecord{}).
Select(bucketExpr+", model, "+
"SUM(prompt_tokens) as prompt_tokens, "+
"SUM(completion_tokens) as completion_tokens, "+
"SUM(total_tokens) as total_tokens, "+
"COUNT(*) as request_count").
Where("user_id = ?", userID).
Group("bucket, model").
Order("bucket ASC")
if !since.IsZero() {
query = query.Where("created_at >= ?", since)
}
var buckets []UsageBucket
if err := query.Find(&buckets).Error; err != nil {
return nil, err
}
return buckets, nil
}
// GetAllUsage returns aggregated usage across all users (admin view).
// A non-empty userID restricts the result to that user.
func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
sqlite := isSQLiteDB(db)
since, dateFmt := periodToWindow(period, sqlite)
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
query := db.Model(&UsageRecord{}).
Select(bucketExpr+", model, user_id, user_name, "+
"SUM(prompt_tokens) as prompt_tokens, "+
"SUM(completion_tokens) as completion_tokens, "+
"SUM(total_tokens) as total_tokens, "+
"COUNT(*) as request_count").
Group("bucket, model, user_id, user_name").
Order("bucket ASC")
if !since.IsZero() {
query = query.Where("created_at >= ?", since)
}
if userID != "" {
query = query.Where("user_id = ?", userID)
}
var buckets []UsageBucket
if err := query.Find(&buckets).Error; err != nil {
return nil, err
}
return buckets, nil
}

@@ -0,0 +1,161 @@
//go:build auth
package auth_test
import (
"time"
"github.com/mudler/LocalAI/core/http/auth"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Usage", func() {
Describe("RecordUsage", func() {
It("inserts a usage record", func() {
db := testDB()
record := &auth.UsageRecord{
UserID: "user-1",
UserName: "Test User",
Model: "gpt-4",
Endpoint: "/v1/chat/completions",
PromptTokens: 100,
CompletionTokens: 50,
TotalTokens: 150,
Duration: 1200,
CreatedAt: time.Now(),
}
err := auth.RecordUsage(db, record)
Expect(err).ToNot(HaveOccurred())
Expect(record.ID).ToNot(BeZero())
})
})
Describe("GetUserUsage", func() {
It("returns aggregated usage for a specific user", func() {
db := testDB()
// Insert records for two users
for i := 0; i < 3; i++ {
err := auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-a",
UserName: "Alice",
Model: "gpt-4",
Endpoint: "/v1/chat/completions",
PromptTokens: 100,
TotalTokens: 150,
CreatedAt: time.Now(),
})
Expect(err).ToNot(HaveOccurred())
}
err := auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-b",
UserName: "Bob",
Model: "gpt-4",
PromptTokens: 200,
TotalTokens: 300,
CreatedAt: time.Now(),
})
Expect(err).ToNot(HaveOccurred())
buckets, err := auth.GetUserUsage(db, "user-a", "month")
Expect(err).ToNot(HaveOccurred())
Expect(buckets).ToNot(BeEmpty())
// Only user-a's three records (3 x 100 prompt tokens) should be counted
totalPrompt := int64(0)
for _, b := range buckets {
totalPrompt += b.PromptTokens
}
Expect(totalPrompt).To(Equal(int64(300)))
})
It("filters by period", func() {
db := testDB()
// Record in the past (beyond day window)
err := auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-c",
UserName: "Carol",
Model: "gpt-4",
PromptTokens: 100,
TotalTokens: 100,
CreatedAt: time.Now().Add(-48 * time.Hour),
})
Expect(err).ToNot(HaveOccurred())
// Record now
err = auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-c",
UserName: "Carol",
Model: "gpt-4",
PromptTokens: 200,
TotalTokens: 200,
CreatedAt: time.Now(),
})
Expect(err).ToNot(HaveOccurred())
// Day period should only include recent record
buckets, err := auth.GetUserUsage(db, "user-c", "day")
Expect(err).ToNot(HaveOccurred())
totalPrompt := int64(0)
for _, b := range buckets {
totalPrompt += b.PromptTokens
}
Expect(totalPrompt).To(Equal(int64(200)))
// Month period should include both
buckets, err = auth.GetUserUsage(db, "user-c", "month")
Expect(err).ToNot(HaveOccurred())
totalPrompt = 0
for _, b := range buckets {
totalPrompt += b.PromptTokens
}
Expect(totalPrompt).To(Equal(int64(300)))
})
})
Describe("GetAllUsage", func() {
It("returns usage for all users", func() {
db := testDB()
for _, uid := range []string{"user-x", "user-y"} {
err := auth.RecordUsage(db, &auth.UsageRecord{
UserID: uid,
UserName: uid,
Model: "gpt-4",
PromptTokens: 100,
TotalTokens: 150,
CreatedAt: time.Now(),
})
Expect(err).ToNot(HaveOccurred())
}
buckets, err := auth.GetAllUsage(db, "month", "")
Expect(err).ToNot(HaveOccurred())
Expect(len(buckets)).To(BeNumerically(">=", 2))
})
It("filters by user ID when specified", func() {
db := testDB()
err := auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-p", UserName: "Pat", Model: "gpt-4",
PromptTokens: 100, TotalTokens: 100, CreatedAt: time.Now(),
})
Expect(err).ToNot(HaveOccurred())
err = auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-q", UserName: "Quinn", Model: "gpt-4",
PromptTokens: 200, TotalTokens: 200, CreatedAt: time.Now(),
})
Expect(err).ToNot(HaveOccurred())
buckets, err := auth.GetAllUsage(db, "month", "user-p")
Expect(err).ToNot(HaveOccurred())
for _, b := range buckets {
Expect(b.UserID).To(Equal("user-p"))
}
})
})
})

@@ -10,6 +10,7 @@ import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates"
@@ -197,54 +198,24 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
xlog.Debug("Anthropic MCP re-templating", "iteration", mcpIteration, "prompt_len", len(predInput))
}
images := []string{}
for _, m := range openAIReq.Messages {
images = append(images, m.StringImages...)
}
// Populate openAIReq fields for ComputeChoices
openAIReq.Tools = convertFuncsToOpenAITools(funcs)
openAIReq.ToolsChoice = input.ToolChoice
openAIReq.Metadata = input.Metadata
toolsJSON := ""
if len(funcs) > 0 {
openAITools := make([]functions.Tool, len(funcs))
for i, f := range funcs {
openAITools[i] = functions.Tool{Type: "function", Function: f}
}
if toolsBytes, err := json.Marshal(openAITools); err == nil {
toolsJSON = string(toolsBytes)
}
var result string
cb := func(s string, c *[]schema.Choice) {
result = s
}
toolChoiceJSON := ""
if input.ToolChoice != nil {
if toolChoiceBytes, err := json.Marshal(input.ToolChoice); err == nil {
toolChoiceJSON = string(toolChoiceBytes)
}
}
predFunc, err := backend.ModelInference(
input.Context, predInput, openAIReq.Messages, images, nil, nil, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, nil, nil, nil, input.Metadata)
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
if err != nil {
xlog.Error("Anthropic model inference failed", "error", err)
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err))
}
const maxEmptyRetries = 5
var prediction backend.LLMResponse
var result string
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
prediction, err = predFunc()
if err != nil {
xlog.Error("Anthropic prediction failed", "error", err)
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err))
}
result = backend.Finetune(*cfg, predInput, prediction.Response)
if result != "" || !shouldUseFn {
break
}
xlog.Warn("Anthropic: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
}
// Try pre-parsed tool calls from C++ autoparser first, fall back to text parsing
var toolCalls []functions.FuncCallResults
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
xlog.Debug("[ChatDeltas] Anthropic: using pre-parsed tool calls", "count", len(deltaToolCalls))
toolCalls = deltaToolCalls
} else {
@@ -350,8 +321,8 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
StopReason: &stopReason,
Content: contentBlocks,
Usage: schema.AnthropicUsage{
InputTokens: prediction.Usage.Prompt,
OutputTokens: prediction.Usage.Completion,
InputTokens: tokenUsage.Prompt,
OutputTokens: tokenUsage.Completion,
},
}
@@ -397,12 +368,6 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
xlog.Debug("Anthropic MCP stream re-templating", "iteration", mcpIteration)
}
openAIMessages := openAIReq.Messages
images := []string{}
for _, m := range openAIMessages {
images = append(images, m.StringImages...)
}
// Track accumulated content for tool call detection
accumulatedContent := ""
currentBlockIndex := 0
@@ -481,38 +446,19 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
return true
}
toolsJSON := ""
if len(funcs) > 0 {
openAITools := make([]functions.Tool, len(funcs))
for i, f := range funcs {
openAITools[i] = functions.Tool{Type: "function", Function: f}
}
if toolsBytes, err := json.Marshal(openAITools); err == nil {
toolsJSON = string(toolsBytes)
}
}
toolChoiceJSON := ""
if input.ToolChoice != nil {
if toolChoiceBytes, err := json.Marshal(input.ToolChoice); err == nil {
toolChoiceJSON = string(toolChoiceBytes)
}
}
// Populate openAIReq fields for ComputeChoices
openAIReq.Tools = convertFuncsToOpenAITools(funcs)
openAIReq.ToolsChoice = input.ToolChoice
openAIReq.Metadata = input.Metadata
predFunc, err := backend.ModelInference(
input.Context, predInput, openAIMessages, images, nil, nil, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, nil, nil, nil, input.Metadata)
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback)
if err != nil {
xlog.Error("Anthropic stream model inference failed", "error", err)
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err))
}
prediction, err := predFunc()
if err != nil {
xlog.Error("Anthropic stream prediction failed", "error", err)
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err))
}
// Also check chat deltas for tool calls
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 && len(collectedToolCalls) == 0 {
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 && len(collectedToolCalls) == 0 {
collectedToolCalls = deltaToolCalls
}
@@ -595,7 +541,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
StopReason: &stopReason,
},
Usage: &schema.AnthropicUsage{
OutputTokens: prediction.Usage.Completion,
OutputTokens: tokenUsage.Completion,
},
})
@@ -613,6 +559,14 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
return nil
}
func convertFuncsToOpenAITools(funcs functions.Functions) []functions.Tool {
tools := make([]functions.Tool, len(funcs))
for i, f := range funcs {
tools[i] = functions.Tool{Type: "function", Function: f}
}
return tools
}
func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) {
data, err := json.Marshal(event)
if err != nil {


@@ -12,27 +12,54 @@ import (
func ListCollectionsEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
collections, err := svc.ListCollections()
userID := getUserID(c)
cols, err := svc.ListCollectionsForUser(userID)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]any{
"collections": collections,
"count": len(collections),
})
resp := map[string]any{
"collections": cols,
"count": len(cols),
}
// Admin cross-user aggregation
if wantsAllUsers(c) {
usm := svc.UserServicesManager()
if usm != nil {
userIDs, _ := usm.ListAllUserIDs()
userGroups := map[string]any{}
for _, uid := range userIDs {
if uid == userID {
continue
}
userCols, err := svc.ListCollectionsForUser(uid)
if err != nil || len(userCols) == 0 {
continue
}
userGroups[uid] = map[string]any{"collections": userCols}
}
if len(userGroups) > 0 {
resp["user_groups"] = userGroups
}
}
}
return c.JSON(http.StatusOK, resp)
}
}
func CreateCollectionEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
var payload struct {
Name string `json:"name"`
}
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
if err := svc.CreateCollection(payload.Name); err != nil {
if err := svc.CreateCollectionForUser(userID, payload.Name); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusCreated, map[string]string{"status": "ok", "name": payload.Name})
@@ -42,33 +69,33 @@ func CreateCollectionEndpoint(app *application.Application) echo.HandlerFunc {
func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
file, err := c.FormFile("file")
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"})
}
if svc.CollectionEntryExists(name, file.Filename) {
return c.JSON(http.StatusConflict, map[string]string{"error": "entry already exists"})
}
src, err := file.Open()
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
defer src.Close()
if err := svc.UploadToCollection(name, file.Filename, src); err != nil {
key, err := svc.UploadToCollectionForUser(userID, name, file.Filename, src)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok", "filename": file.Filename})
return c.JSON(http.StatusOK, map[string]string{"status": "ok", "filename": file.Filename, "key": key})
}
}
func ListCollectionEntriesEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
entries, err := svc.ListCollectionEntries(c.Param("name"))
userID := effectiveUserID(c)
entries, err := svc.ListCollectionEntriesForUser(userID, c.Param("name"))
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
@@ -85,12 +112,13 @@ func ListCollectionEntriesEndpoint(app *application.Application) echo.HandlerFun
func GetCollectionEntryContentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
entryParam := c.Param("*")
entry, err := url.PathUnescape(entryParam)
if err != nil {
entry = entryParam
}
content, chunkCount, err := svc.GetCollectionEntryContent(c.Param("name"), entry)
content, chunkCount, err := svc.GetCollectionEntryContentForUser(userID, c.Param("name"), entry)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
@@ -107,6 +135,7 @@ func GetCollectionEntryContentEndpoint(app *application.Application) echo.Handle
func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
var payload struct {
Query string `json:"query"`
MaxResults int `json:"max_results"`
@@ -114,7 +143,7 @@ func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc {
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
results, err := svc.SearchCollection(c.Param("name"), payload.Query, payload.MaxResults)
results, err := svc.SearchCollectionForUser(userID, c.Param("name"), payload.Query, payload.MaxResults)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
@@ -131,7 +160,8 @@ func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc {
func ResetCollectionEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
if err := svc.ResetCollection(c.Param("name")); err != nil {
userID := effectiveUserID(c)
if err := svc.ResetCollectionForUser(userID, c.Param("name")); err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -144,13 +174,14 @@ func ResetCollectionEndpoint(app *application.Application) echo.HandlerFunc {
func DeleteCollectionEntryEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
var payload struct {
Entry string `json:"entry"`
}
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
remaining, err := svc.DeleteCollectionEntry(c.Param("name"), payload.Entry)
remaining, err := svc.DeleteCollectionEntryForUser(userID, c.Param("name"), payload.Entry)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
@@ -167,6 +198,7 @@ func DeleteCollectionEntryEndpoint(app *application.Application) echo.HandlerFun
func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
var payload struct {
URL string `json:"url"`
UpdateInterval int `json:"update_interval"`
@@ -177,7 +209,7 @@ func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc
if payload.UpdateInterval < 1 {
payload.UpdateInterval = 60
}
if err := svc.AddCollectionSource(c.Param("name"), payload.URL, payload.UpdateInterval); err != nil {
if err := svc.AddCollectionSourceForUser(userID, c.Param("name"), payload.URL, payload.UpdateInterval); err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -190,23 +222,46 @@ func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc
func RemoveCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
var payload struct {
URL string `json:"url"`
}
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
if err := svc.RemoveCollectionSource(c.Param("name"), payload.URL); err != nil {
if err := svc.RemoveCollectionSourceForUser(userID, c.Param("name"), payload.URL); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
}
}
// GetCollectionEntryRawFileEndpoint serves the original uploaded binary file.
func GetCollectionEntryRawFileEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
entryParam := c.Param("*")
entry, err := url.PathUnescape(entryParam)
if err != nil {
entry = entryParam
}
fpath, err := svc.GetCollectionEntryFilePathForUser(userID, c.Param("name"), entry)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.File(fpath)
}
}
func ListCollectionSourcesEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
sources, err := svc.ListCollectionSources(c.Param("name"))
userID := effectiveUserID(c)
sources, err := svc.ListCollectionSourcesForUser(userID, c.Param("name"))
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
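
The collection handlers above decide between 404 and 500 by searching the error text for "not found". A minimal sketch of an alternative that matches a sentinel error with `errors.Is`, assuming the service layer could wrap a shared `ErrNotFound` value. Nothing in this diff shows the services exporting one, so `ErrNotFound`, `lookupCollection`, and `statusForError` are all hypothetical names used only for illustration.

```go
package main

import (
	"errors"
	"fmt"
	"net/http"
)

// ErrNotFound is a hypothetical sentinel; the diff does not show the
// service layer exporting such a value.
var ErrNotFound = errors.New("not found")

// lookupCollection stands in for a service call. It wraps ErrNotFound with
// %w so callers can match it with errors.Is regardless of message text.
func lookupCollection(known map[string]bool, name string) error {
	if !known[name] {
		return fmt.Errorf("collection %q: %w", name, ErrNotFound)
	}
	return nil
}

// statusForError maps an error to an HTTP status without inspecting its text.
func statusForError(err error) int {
	switch {
	case err == nil:
		return http.StatusOK
	case errors.Is(err, ErrNotFound):
		return http.StatusNotFound
	default:
		return http.StatusInternalServerError
	}
}
```

With this shape a handler would call `statusForError(err)` once instead of repeating the substring check, and rewording an error message could not silently turn a 404 into a 500.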


@@ -8,19 +8,27 @@ import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
)
// CreateTaskEndpoint creates a new agent task
// @Summary Create a new agent task
// @Description Create a new reusable agent task with prompt template and configuration
// @Tags agent-jobs
// @Accept json
// @Produce json
// @Param task body schema.Task true "Task definition"
// @Success 201 {object} map[string]string "Task created"
// @Failure 400 {object} map[string]string "Invalid request"
// @Failure 500 {object} map[string]string "Internal server error"
// @Router /api/agent/tasks [post]
// getJobService returns the job service for the current user.
// Falls back to the global service when no user is authenticated.
func getJobService(app *application.Application, c echo.Context) *services.AgentJobService {
userID := getUserID(c)
if userID == "" {
return app.AgentJobService()
}
svc := app.AgentPoolService()
if svc == nil {
return app.AgentJobService()
}
jobSvc, err := svc.JobServiceForUser(userID)
if err != nil {
return app.AgentJobService()
}
return jobSvc
}
func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
var task schema.Task
@@ -28,7 +36,7 @@ func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()})
}
id, err := app.AgentJobService().CreateTask(task)
id, err := getJobService(app, c).CreateTask(task)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
@@ -37,18 +45,6 @@ func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
}
}
// UpdateTaskEndpoint updates an existing task
// @Summary Update an agent task
// @Description Update an existing agent task
// @Tags agent-jobs
// @Accept json
// @Produce json
// @Param id path string true "Task ID"
// @Param task body schema.Task true "Updated task definition"
// @Success 200 {object} map[string]string "Task updated"
// @Failure 400 {object} map[string]string "Invalid request"
// @Failure 404 {object} map[string]string "Task not found"
// @Router /api/agent/tasks/{id} [put]
func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
id := c.Param("id")
@@ -57,7 +53,7 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()})
}
if err := app.AgentJobService().UpdateTask(id, task); err != nil {
if err := getJobService(app, c).UpdateTask(id, task); err != nil {
if err.Error() == "task not found: "+id {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -68,19 +64,10 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
}
}
// DeleteTaskEndpoint deletes a task
// @Summary Delete an agent task
// @Description Delete an agent task by ID
// @Tags agent-jobs
// @Produce json
// @Param id path string true "Task ID"
// @Success 200 {object} map[string]string "Task deleted"
// @Failure 404 {object} map[string]string "Task not found"
// @Router /api/agent/tasks/{id} [delete]
func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
id := c.Param("id")
if err := app.AgentJobService().DeleteTask(id); err != nil {
if err := getJobService(app, c).DeleteTask(id); err != nil {
if err.Error() == "task not found: "+id {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -91,33 +78,52 @@ func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
}
}
// ListTasksEndpoint lists all tasks
// @Summary List all agent tasks
// @Description Get a list of all agent tasks
// @Tags agent-jobs
// @Produce json
// @Success 200 {array} schema.Task "List of tasks"
// @Router /api/agent/tasks [get]
func ListTasksEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
tasks := app.AgentJobService().ListTasks()
jobSvc := getJobService(app, c)
tasks := jobSvc.ListTasks()
// Admin cross-user aggregation
if wantsAllUsers(c) {
svc := app.AgentPoolService()
if svc != nil {
usm := svc.UserServicesManager()
if usm != nil {
userID := getUserID(c)
userIDs, _ := usm.ListAllUserIDs()
userGroups := map[string]any{}
for _, uid := range userIDs {
if uid == userID {
continue
}
userJobSvc, err := svc.JobServiceForUser(uid)
if err != nil {
continue
}
userTasks := userJobSvc.ListTasks()
if len(userTasks) == 0 {
continue
}
userGroups[uid] = map[string]any{"tasks": userTasks}
}
if len(userGroups) > 0 {
return c.JSON(http.StatusOK, map[string]any{
"tasks": tasks,
"user_groups": userGroups,
})
}
}
}
}
return c.JSON(http.StatusOK, tasks)
}
}
// GetTaskEndpoint gets a task by ID
// @Summary Get an agent task
// @Description Get an agent task by ID
// @Tags agent-jobs
// @Produce json
// @Param id path string true "Task ID"
// @Success 200 {object} schema.Task "Task details"
// @Failure 404 {object} map[string]string "Task not found"
// @Router /api/agent/tasks/{id} [get]
func GetTaskEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
id := c.Param("id")
task, err := app.AgentJobService().GetTask(id)
task, err := getJobService(app, c).GetTask(id)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -126,16 +132,6 @@ func GetTaskEndpoint(app *application.Application) echo.HandlerFunc {
}
}
// ExecuteJobEndpoint executes a job
// @Summary Execute an agent job
// @Description Create and execute a new agent job
// @Tags agent-jobs
// @Accept json
// @Produce json
// @Param request body schema.JobExecutionRequest true "Job execution request"
// @Success 201 {object} schema.JobExecutionResponse "Job created"
// @Failure 400 {object} map[string]string "Invalid request"
// @Router /api/agent/jobs/execute [post]
func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
var req schema.JobExecutionRequest
@@ -147,7 +143,6 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
req.Parameters = make(map[string]string)
}
// Build multimedia struct from request
var multimedia *schema.MultimediaAttachment
if len(req.Images) > 0 || len(req.Videos) > 0 || len(req.Audios) > 0 || len(req.Files) > 0 {
multimedia = &schema.MultimediaAttachment{
@@ -158,7 +153,7 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
}
}
jobID, err := app.AgentJobService().ExecuteJob(req.TaskID, req.Parameters, "api", multimedia)
jobID, err := getJobService(app, c).ExecuteJob(req.TaskID, req.Parameters, "api", multimedia)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
@@ -172,19 +167,10 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
}
}
// GetJobEndpoint gets a job by ID
// @Summary Get an agent job
// @Description Get an agent job by ID
// @Tags agent-jobs
// @Produce json
// @Param id path string true "Job ID"
// @Success 200 {object} schema.Job "Job details"
// @Failure 404 {object} map[string]string "Job not found"
// @Router /api/agent/jobs/{id} [get]
func GetJobEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
id := c.Param("id")
job, err := app.AgentJobService().GetJob(id)
job, err := getJobService(app, c).GetJob(id)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -193,16 +179,6 @@ func GetJobEndpoint(app *application.Application) echo.HandlerFunc {
}
}
// ListJobsEndpoint lists jobs with optional filtering
// @Summary List agent jobs
// @Description Get a list of agent jobs, optionally filtered by task_id and status
// @Tags agent-jobs
// @Produce json
// @Param task_id query string false "Filter by task ID"
// @Param status query string false "Filter by status (pending, running, completed, failed, cancelled)"
// @Param limit query int false "Limit number of results"
// @Success 200 {array} schema.Job "List of jobs"
// @Router /api/agent/jobs [get]
func ListJobsEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
var taskID *string
@@ -224,25 +200,50 @@ func ListJobsEndpoint(app *application.Application) echo.HandlerFunc {
}
}
jobs := app.AgentJobService().ListJobs(taskID, status, limit)
jobSvc := getJobService(app, c)
jobs := jobSvc.ListJobs(taskID, status, limit)
// Admin cross-user aggregation
if wantsAllUsers(c) {
svc := app.AgentPoolService()
if svc != nil {
usm := svc.UserServicesManager()
if usm != nil {
userID := getUserID(c)
userIDs, _ := usm.ListAllUserIDs()
userGroups := map[string]any{}
for _, uid := range userIDs {
if uid == userID {
continue
}
userJobSvc, err := svc.JobServiceForUser(uid)
if err != nil {
continue
}
userJobs := userJobSvc.ListJobs(taskID, status, limit)
if len(userJobs) == 0 {
continue
}
userGroups[uid] = map[string]any{"jobs": userJobs}
}
if len(userGroups) > 0 {
return c.JSON(http.StatusOK, map[string]any{
"jobs": jobs,
"user_groups": userGroups,
})
}
}
}
}
return c.JSON(http.StatusOK, jobs)
}
}
// CancelJobEndpoint cancels a running job
// @Summary Cancel an agent job
// @Description Cancel a running or pending agent job
// @Tags agent-jobs
// @Produce json
// @Param id path string true "Job ID"
// @Success 200 {object} map[string]string "Job cancelled"
// @Failure 400 {object} map[string]string "Job cannot be cancelled"
// @Failure 404 {object} map[string]string "Job not found"
// @Router /api/agent/jobs/{id}/cancel [post]
func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
id := c.Param("id")
if err := app.AgentJobService().CancelJob(id); err != nil {
if err := getJobService(app, c).CancelJob(id); err != nil {
if err.Error() == "job not found: "+id {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -253,19 +254,10 @@ func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
}
}
// DeleteJobEndpoint deletes a job
// @Summary Delete an agent job
// @Description Delete an agent job by ID
// @Tags agent-jobs
// @Produce json
// @Param id path string true "Job ID"
// @Success 200 {object} map[string]string "Job deleted"
// @Failure 404 {object} map[string]string "Job not found"
// @Router /api/agent/jobs/{id} [delete]
func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
id := c.Param("id")
if err := app.AgentJobService().DeleteJob(id); err != nil {
if err := getJobService(app, c).DeleteJob(id); err != nil {
if err.Error() == "job not found: "+id {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -276,52 +268,33 @@ func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
}
}
// ExecuteTaskByNameEndpoint executes a task by name
// @Summary Execute a task by name
// @Description Execute an agent task by its name (convenience endpoint). Parameters can be provided in the request body as a JSON object with string values.
// @Tags agent-jobs
// @Accept json
// @Produce json
// @Param name path string true "Task name"
// @Param request body map[string]string false "Template parameters (JSON object with string values)"
// @Success 201 {object} schema.JobExecutionResponse "Job created"
// @Failure 400 {object} map[string]string "Invalid request"
// @Failure 404 {object} map[string]string "Task not found"
// @Router /api/agent/tasks/{name}/execute [post]
func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
name := c.Param("name")
var params map[string]string
// Try to bind parameters from request body
// If body is empty or invalid, use empty params
if c.Request().ContentLength > 0 {
if err := c.Bind(&params); err != nil {
// If binding fails, try to read as raw JSON
body := make(map[string]interface{})
if err := c.Bind(&body); err == nil {
// Convert interface{} values to strings
params = make(map[string]string)
for k, v := range body {
if str, ok := v.(string); ok {
params[k] = str
} else {
// Convert non-string values to string
params[k] = fmt.Sprintf("%v", v)
}
}
} else {
// If all binding fails, use empty params
params = make(map[string]string)
}
}
} else {
// No body provided, use empty params
params = make(map[string]string)
}
// Find task by name
tasks := app.AgentJobService().ListTasks()
jobSvc := getJobService(app, c)
tasks := jobSvc.ListTasks()
var task *schema.Task
for _, t := range tasks {
if t.Name == name {
@@ -334,7 +307,7 @@ func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc {
return c.JSON(http.StatusNotFound, map[string]string{"error": "Task not found: " + name})
}
jobID, err := app.AgentJobService().ExecuteJob(task.ID, params, "api", nil)
jobID, err := jobSvc.ExecuteJob(task.ID, params, "api", nil)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
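
`ExecuteTaskByNameEndpoint` above binds the body into `map[string]string` and, if that fails, binds again into `map[string]interface{}` and stringifies the values. A sketch of doing the same conversion in a single decode, which avoids relying on the request body being readable a second time. `paramsFromJSON` is a hypothetical helper, not part of this change, and unlike the handler it reports malformed JSON as an error instead of falling back to empty parameters.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// paramsFromJSON decodes a JSON object once and converts every value to a
// string: strings are kept as-is, other values are formatted with %v.
// An empty body yields an empty map; malformed JSON returns an error.
func paramsFromJSON(data []byte) (map[string]string, error) {
	params := map[string]string{}
	if len(data) == 0 {
		return params, nil // no body provided
	}
	var body map[string]any
	if err := json.Unmarshal(data, &body); err != nil {
		return nil, err
	}
	for k, v := range body {
		if s, ok := v.(string); ok {
			params[k] = s
		} else {
			// Numbers, booleans and nested values are stringified.
			params[k] = fmt.Sprintf("%v", v)
		}
	}
	return params, nil
}
```

A handler could read the body with `io.ReadAll(c.Request().Body)` and pass the bytes in; whether a decode failure should be a 400 or be ignored, as it is today, is a separate design choice.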


@@ -28,7 +28,7 @@ func skillToResponse(s skilldomain.Skill) skillResponse {
out.License = s.Metadata.License
out.Compatibility = s.Metadata.Compatibility
out.Metadata = s.Metadata.Metadata
out.AllowedTools = s.Metadata.AllowedTools
out.AllowedTools = s.Metadata.AllowedTools.String()
}
return out
}
@@ -44,10 +44,38 @@ func skillsToResponses(skills []skilldomain.Skill) []skillResponse {
func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
skills, err := svc.ListSkills()
userID := getUserID(c)
skills, err := svc.ListSkillsForUser(userID)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
// Admin cross-user aggregation
if wantsAllUsers(c) {
usm := svc.UserServicesManager()
if usm != nil {
userIDs, _ := usm.ListAllUserIDs()
userGroups := map[string]any{}
for _, uid := range userIDs {
if uid == userID {
continue
}
userSkills, err := svc.ListSkillsForUser(uid)
if err != nil || len(userSkills) == 0 {
continue
}
userGroups[uid] = map[string]any{"skills": skillsToResponses(userSkills)}
}
resp := map[string]any{
"skills": skillsToResponses(skills),
}
if len(userGroups) > 0 {
resp["user_groups"] = userGroups
}
return c.JSON(http.StatusOK, resp)
}
}
return c.JSON(http.StatusOK, skillsToResponses(skills))
}
}
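
The "Admin cross-user aggregation" loop appears in nearly the same form in the collections, tasks, jobs, and skills list handlers above. A sketch of how it could be factored into one generic helper, assuming Go 1.18 or later. `groupByUser` and its parameters are hypothetical names introduced here, not part of this change.

```go
package main

// groupByUser collects per-user items for every user except the caller.
// list is called once per user ID; users whose lookup fails or returns
// nothing are skipped, mirroring the loops in the handlers above.
// key is the JSON field name used inside each group ("skills", "tasks", ...).
func groupByUser[T any](userIDs []string, callerID, key string, list func(uid string) ([]T, error)) map[string]any {
	groups := map[string]any{}
	for _, uid := range userIDs {
		if uid == callerID {
			continue // the caller's own items are returned at the top level
		}
		items, err := list(uid)
		if err != nil || len(items) == 0 {
			continue
		}
		groups[uid] = map[string]any{key: items}
	}
	return groups
}
```

Each handler would then pass its own lookup as the callback, for example a closure around `svc.ListSkillsForUser`, and attach the result as `user_groups` only when it is non-empty.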
@@ -55,7 +83,8 @@ func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc {
func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
cfg := svc.GetSkillsConfig()
userID := getUserID(c)
cfg := svc.GetSkillsConfigForUser(userID)
return c.JSON(http.StatusOK, cfg)
}
}
@@ -63,8 +92,9 @@ func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc {
func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
query := c.QueryParam("q")
skills, err := svc.SearchSkills(query)
skills, err := svc.SearchSkillsForUser(userID, query)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
@@ -75,6 +105,7 @@ func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc {
func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
var payload struct {
Name string `json:"name"`
Description string `json:"description"`
@@ -87,7 +118,7 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
skill, err := svc.CreateSkill(payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
skill, err := svc.CreateSkillForUser(userID, payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
if err != nil {
if strings.Contains(err.Error(), "already exists") {
return c.JSON(http.StatusConflict, map[string]string{"error": err.Error()})
@@ -101,7 +132,8 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
func GetSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
skill, err := svc.GetSkill(c.Param("name"))
userID := effectiveUserID(c)
skill, err := svc.GetSkillForUser(userID, c.Param("name"))
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -112,6 +144,7 @@ func GetSkillEndpoint(app *application.Application) echo.HandlerFunc {
func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
var payload struct {
Description string `json:"description"`
Content string `json:"content"`
@@ -123,7 +156,7 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
skill, err := svc.UpdateSkill(c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
skill, err := svc.UpdateSkillForUser(userID, c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
@@ -137,7 +170,8 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
if err := svc.DeleteSkill(c.Param("name")); err != nil {
userID := effectiveUserID(c)
if err := svc.DeleteSkillForUser(userID, c.Param("name")); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
@@ -147,9 +181,9 @@ func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc {
func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
// The wildcard param captures the path after /export/
userID := effectiveUserID(c)
name := c.Param("*")
data, err := svc.ExportSkill(name)
data, err := svc.ExportSkillForUser(userID, name)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -162,6 +196,7 @@ func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc {
func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
file, err := c.FormFile("file")
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"})
@@ -175,7 +210,7 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
skill, err := svc.ImportSkill(data)
skill, err := svc.ImportSkillForUser(userID, data)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
@@ -188,7 +223,8 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
resources, skill, err := svc.ListSkillResources(c.Param("name"))
userID := effectiveUserID(c)
resources, skill, err := svc.ListSkillResourcesForUser(userID, c.Param("name"))
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -225,7 +261,8 @@ func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc {
func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
content, info, err := svc.GetSkillResource(c.Param("name"), c.Param("*"))
userID := effectiveUserID(c)
content, info, err := svc.GetSkillResourceForUser(userID, c.Param("name"), c.Param("*"))
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -245,6 +282,7 @@ func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
file, err := c.FormFile("file")
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file is required"})
@@ -262,7 +300,7 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
if err := svc.CreateSkillResource(c.Param("name"), path, data); err != nil {
if err := svc.CreateSkillResourceForUser(userID, c.Param("name"), path, data); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusCreated, map[string]string{"path": path})
@@ -272,13 +310,14 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
var payload struct {
Content string `json:"content"`
}
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
if err := svc.UpdateSkillResource(c.Param("name"), c.Param("*"), payload.Content); err != nil {
if err := svc.UpdateSkillResourceForUser(userID, c.Param("name"), c.Param("*"), payload.Content); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
@@ -288,7 +327,8 @@ func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
if err := svc.DeleteSkillResource(c.Param("name"), c.Param("*")); err != nil {
userID := getUserID(c)
if err := svc.DeleteSkillResourceForUser(userID, c.Param("name"), c.Param("*")); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
@@ -300,7 +340,8 @@ func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
repos, err := svc.ListGitRepos()
userID := getUserID(c)
repos, err := svc.ListGitReposForUser(userID)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
@@ -311,13 +352,14 @@ func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc {
func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
var payload struct {
URL string `json:"url"`
}
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
repo, err := svc.AddGitRepo(payload.URL)
repo, err := svc.AddGitRepoForUser(userID, payload.URL)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
@@ -328,6 +370,7 @@ func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
var payload struct {
URL string `json:"url"`
Enabled *bool `json:"enabled"`
@@ -335,7 +378,7 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
repo, err := svc.UpdateGitRepo(c.Param("id"), payload.URL, payload.Enabled)
repo, err := svc.UpdateGitRepoForUser(userID, c.Param("id"), payload.URL, payload.Enabled)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
@@ -349,7 +392,8 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
if err := svc.DeleteGitRepo(c.Param("id")); err != nil {
userID := getUserID(c)
if err := svc.DeleteGitRepoForUser(userID, c.Param("id")); err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -362,7 +406,8 @@ func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
if err := svc.SyncGitRepo(c.Param("id")); err != nil {
userID := getUserID(c)
if err := svc.SyncGitRepoForUser(userID, c.Param("id")); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusAccepted, map[string]string{"status": "syncing"})
@@ -372,7 +417,8 @@ func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
func ToggleGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
repo, err := svc.ToggleGitRepo(c.Param("id"))
userID := getUserID(c)
repo, err := svc.ToggleGitRepoForUser(userID, c.Param("id"))
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}


@@ -12,6 +12,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/LocalAGI/core/state"
@@ -19,10 +20,42 @@ import (
agiServices "github.com/mudler/LocalAGI/services"
)
// getUserID extracts the scoped user ID from the request context.
// Returns empty string when auth is not active (backward compat).
func getUserID(c echo.Context) string {
user := auth.GetUser(c)
if user == nil {
return ""
}
return user.ID
}
// isAdminUser returns true if the authenticated user has admin role.
func isAdminUser(c echo.Context) bool {
user := auth.GetUser(c)
return user != nil && user.Role == auth.RoleAdmin
}
// wantsAllUsers returns true if the request has ?all_users=true and the user is admin.
func wantsAllUsers(c echo.Context) bool {
return c.QueryParam("all_users") == "true" && isAdminUser(c)
}
// effectiveUserID returns the user ID to scope operations to.
// SECURITY: Only admins may supply ?user_id=<id> to operate on another user's
// resources. Non-admin callers always get their own ID regardless of query params.
func effectiveUserID(c echo.Context) string {
if targetUID := c.QueryParam("user_id"); targetUID != "" && isAdminUser(c) {
return targetUID
}
return getUserID(c)
}
func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
statuses := svc.ListAgents()
userID := getUserID(c)
statuses := svc.ListAgentsForUser(userID)
agents := make([]string, 0, len(statuses))
for name := range statuses {
agents = append(agents, name)
@@ -38,6 +71,22 @@ func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc {
if hubURL := svc.AgentHubURL(); hubURL != "" {
resp["agent_hub_url"] = hubURL
}
// Admin cross-user aggregation
if wantsAllUsers(c) {
grouped := svc.ListAllAgentsGrouped()
userGroups := map[string]any{}
for uid, agentList := range grouped {
if uid == userID || uid == "" {
continue
}
userGroups[uid] = map[string]any{"agents": agentList}
}
if len(userGroups) > 0 {
resp["user_groups"] = userGroups
}
}
return c.JSON(http.StatusOK, resp)
}
}
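
The scoping rule behind `getUserID`, `isAdminUser`, and `effectiveUserID` in this hunk can be exercised without an HTTP context. The following is a minimal sketch, not the project's code: the `User` type, the `roleAdmin` value, and the `scopedID` helper are hypothetical stand-ins for `auth.GetUser`, `auth.RoleAdmin`, and the `?user_id=` query parameter.

```go
package main

import "fmt"

// User is a hypothetical stand-in for the record returned by auth.GetUser.
type User struct {
	ID   string
	Role string
}

// roleAdmin is an assumed role value used only in this sketch.
const roleAdmin = "admin"

// ownID returns the caller's ID, or "" when no user is attached
// (the "auth is not active" case in the diff).
func ownID(caller *User) string {
	if caller == nil {
		return ""
	}
	return caller.ID
}

// isAdmin reports whether the caller is authenticated with the admin role.
func isAdmin(caller *User) bool {
	return caller != nil && caller.Role == roleAdmin
}

// scopedID mirrors the effectiveUserID rule: a requested target ID is
// honoured only for admins; everyone else is scoped to their own ID.
func scopedID(caller *User, requested string) string {
	if requested != "" && isAdmin(caller) {
		return requested
	}
	return ownID(caller)
}

func main() {
	admin := &User{ID: "admin-1", Role: roleAdmin}
	alice := &User{ID: "alice", Role: "user"}

	fmt.Println(scopedID(admin, "alice"))      // admin may target another user
	fmt.Println(scopedID(alice, "admin-1"))    // non-admin override is ignored
	fmt.Printf("%q\n", scopedID(nil, "alice")) // no auth: empty scope
}
```

Keeping the admin check inside the resolver, rather than in each handler, means a handler that forgets the check cannot leak another user's resources.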
@@ -45,11 +94,12 @@ func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc {
func CreateAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
var cfg state.AgentConfig
if err := c.Bind(&cfg); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
if err := svc.CreateAgent(&cfg); err != nil {
if err := svc.CreateAgentForUser(userID, &cfg); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusCreated, map[string]string{"status": "ok"})
@@ -59,8 +109,9 @@ func CreateAgentEndpoint(app *application.Application) echo.HandlerFunc {
func GetAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
ag := svc.GetAgent(name)
ag := svc.GetAgentForUser(userID, name)
if ag == nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
}
@@ -73,12 +124,13 @@ func GetAgentEndpoint(app *application.Application) echo.HandlerFunc {
func UpdateAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
var cfg state.AgentConfig
if err := c.Bind(&cfg); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
if err := svc.UpdateAgent(name, &cfg); err != nil {
if err := svc.UpdateAgentForUser(userID, name, &cfg); err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -91,8 +143,9 @@ func UpdateAgentEndpoint(app *application.Application) echo.HandlerFunc {
func DeleteAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
if err := svc.DeleteAgent(name); err != nil {
if err := svc.DeleteAgentForUser(userID, name); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
@@ -102,8 +155,9 @@ func DeleteAgentEndpoint(app *application.Application) echo.HandlerFunc {
func GetAgentConfigEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
cfg := svc.GetAgentConfig(name)
cfg := svc.GetAgentConfigForUser(userID, name)
if cfg == nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
}
@@ -114,7 +168,8 @@ func GetAgentConfigEndpoint(app *application.Application) echo.HandlerFunc {
func PauseAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
if err := svc.PauseAgent(c.Param("name")); err != nil {
userID := effectiveUserID(c)
if err := svc.PauseAgentForUser(userID, c.Param("name")); err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
@@ -124,7 +179,8 @@ func PauseAgentEndpoint(app *application.Application) echo.HandlerFunc {
func ResumeAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
if err := svc.ResumeAgent(c.Param("name")); err != nil {
userID := effectiveUserID(c)
if err := svc.ResumeAgentForUser(userID, c.Param("name")); err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
@@ -134,8 +190,9 @@ func ResumeAgentEndpoint(app *application.Application) echo.HandlerFunc {
func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
history := svc.GetAgentStatus(name)
history := svc.GetAgentStatusForUser(userID, name)
if history == nil {
history = &state.Status{ActionResults: []coreTypes.ActionState{}}
}
@@ -162,8 +219,9 @@ func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc {
func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
history, err := svc.GetAgentObservables(name)
history, err := svc.GetAgentObservablesForUser(userID, name)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -177,8 +235,9 @@ func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc
func ClearAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
if err := svc.ClearAgentObservables(name); err != nil {
if err := svc.ClearAgentObservablesForUser(userID, name); err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]any{"Name": name, "cleared": true})
@@ -188,6 +247,7 @@ func ClearAgentObservablesEndpoint(app *application.Application) echo.HandlerFun
func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
var payload struct {
Message string `json:"message"`
@@ -199,7 +259,7 @@ func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc {
if message == "" {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Message cannot be empty"})
}
messageID, err := svc.Chat(name, message)
messageID, err := svc.ChatForUser(userID, name, message)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
@@ -216,8 +276,9 @@ func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc {
func AgentSSEEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
manager := svc.GetSSEManager(name)
manager := svc.GetSSEManagerForUser(userID, name)
if manager == nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
}
@@ -243,8 +304,9 @@ func GetAgentConfigMetaEndpoint(app *application.Application) echo.HandlerFunc {
func ExportAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
data, err := svc.ExportAgent(name)
data, err := svc.ExportAgentForUser(userID, name)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
@@ -256,6 +318,7 @@ func ExportAgentEndpoint(app *application.Application) echo.HandlerFunc {
func ImportAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := getUserID(c)
// Try multipart form file first
file, err := c.FormFile("file")
@@ -269,7 +332,7 @@ func ImportAgentEndpoint(app *application.Application) echo.HandlerFunc {
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "failed to read file"})
}
if err := svc.ImportAgent(data); err != nil {
if err := svc.ImportAgentForUser(userID, data); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusCreated, map[string]string{"status": "ok"})
@@ -284,7 +347,7 @@ func ImportAgentEndpoint(app *application.Application) echo.HandlerFunc {
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
if err := svc.ImportAgent(data); err != nil {
if err := svc.ImportAgentForUser(userID, data); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusCreated, map[string]string{"status": "ok"})
@@ -358,10 +421,16 @@ func AgentFileEndpoint(app *application.Application) echo.HandlerFunc {
return c.JSON(http.StatusNotFound, map[string]string{"error": "file not found"})
}
// Only serve files from the outputs subdirectory
outputsDir, _ := filepath.EvalSymlinks(filepath.Clean(svc.OutputsDir()))
// Determine the allowed outputs directory — scoped to the user when auth is active
allowedDir := svc.OutputsDir()
user := auth.GetUser(c)
if user != nil {
allowedDir = filepath.Join(allowedDir, user.ID)
}
if utils.InTrustedRoot(resolved, outputsDir) != nil {
allowedDirResolved, _ := filepath.EvalSymlinks(filepath.Clean(allowedDir))
if utils.InTrustedRoot(resolved, allowedDirResolved) != nil {
return c.JSON(http.StatusForbidden, map[string]string{"error": "access denied"})
}
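
The per-user outputs directory check in this hunk can be illustrated with the standard library alone. This is a sketch under stated assumptions: `utils.InTrustedRoot` is LocalAI's own helper and its implementation is not shown in the diff, so the hypothetical `withinRoot` below uses `filepath.Rel` as a stand-in and does a purely lexical comparison (the diff additionally resolves symlinks with `filepath.EvalSymlinks`).

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// allowedOutputsDir returns base when no user is attached, otherwise the
// user's own subdirectory, as in the AgentFileEndpoint hunk.
func allowedOutputsDir(base, userID string) string {
	if userID == "" {
		return base
	}
	return filepath.Join(base, userID)
}

// withinRoot reports whether target is root itself or nested under it.
// Hypothetical stand-in for utils.InTrustedRoot; lexical only.
func withinRoot(target, root string) bool {
	rel, err := filepath.Rel(root, target)
	if err != nil {
		return false
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}

func main() {
	root := allowedOutputsDir("/data/outputs", "u1")
	fmt.Println(withinRoot("/data/outputs/u1/report.txt", root))
	fmt.Println(withinRoot("/data/outputs/u2/report.txt", root))
	fmt.Println(withinRoot("/data/outputs/u1/../u2/report.txt", root))
}
```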


@@ -4,6 +4,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"github.com/labstack/echo/v4"
@@ -20,6 +21,9 @@ import (
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
modelName := c.Param("name")
if decoded, err := url.PathUnescape(modelName); err == nil {
modelName = decoded
}
if modelName == "" {
response := ModelResponse{
Success: false,
@@ -82,6 +86,9 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
func EditModelEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
modelName := c.Param("name")
if decoded, err := url.PathUnescape(modelName); err == nil {
modelName = decoded
}
if modelName == "" {
response := ModelResponse{
Success: false,
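
Both handlers in this file now decode the `name` path parameter and keep the raw value when decoding fails. A minimal sketch of that fallback, using a hypothetical `decodeModelName` helper around the standard `url.PathUnescape`:

```go
package main

import (
	"fmt"
	"net/url"
)

// decodeModelName percent-decodes a path segment and falls back to the
// raw value when the input is not validly escaped.
func decodeModelName(raw string) string {
	if decoded, err := url.PathUnescape(raw); err == nil {
		return decoded
	}
	return raw
}

func main() {
	fmt.Println(decodeModelName("org%2Fmodel")) // encoded slash is restored
	fmt.Println(decodeModelName("bad%zzname"))  // invalid escape: raw value kept
	fmt.Println(decodeModelName("a+b"))         // PathUnescape leaves '+' untouched
}
```

`url.PathUnescape` is used rather than `url.QueryUnescape` because the latter turns `+` into a space, which would alter model names that contain a literal plus sign.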


@@ -0,0 +1,362 @@
package localai
import (
"archive/tar"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
)
// StartFineTuneJobEndpoint starts a new fine-tuning job.
func StartFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
var req schema.FineTuneJobRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{
"error": "Invalid request: " + err.Error(),
})
}
if req.Model == "" {
return c.JSON(http.StatusBadRequest, map[string]string{
"error": "model is required",
})
}
if req.DatasetSource == "" {
return c.JSON(http.StatusBadRequest, map[string]string{
"error": "dataset_source is required",
})
}
resp, err := ftService.StartJob(c.Request().Context(), userID, req)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusCreated, resp)
}
}
// ListFineTuneJobsEndpoint lists fine-tuning jobs for the current user.
func ListFineTuneJobsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobs := ftService.ListJobs(userID)
if jobs == nil {
jobs = []*schema.FineTuneJob{}
}
return c.JSON(http.StatusOK, jobs)
}
}
// GetFineTuneJobEndpoint gets a specific fine-tuning job.
func GetFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
job, err := ftService.GetJob(userID, jobID)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusOK, job)
}
}
// StopFineTuneJobEndpoint stops a running fine-tuning job.
func StopFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
// Check for save_checkpoint query param
saveCheckpoint := c.QueryParam("save_checkpoint") == "true"
err := ftService.StopJob(c.Request().Context(), userID, jobID, saveCheckpoint)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusOK, map[string]string{
"status": "stopped",
"message": "Fine-tuning job stopped",
})
}
}
// DeleteFineTuneJobEndpoint deletes a fine-tuning job and its data.
func DeleteFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
err := ftService.DeleteJob(userID, jobID)
if err != nil {
status := http.StatusInternalServerError
if strings.Contains(err.Error(), "not found") {
status = http.StatusNotFound
} else if strings.Contains(err.Error(), "cannot delete") {
status = http.StatusConflict
}
return c.JSON(status, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusOK, map[string]string{
"status": "deleted",
"message": "Fine-tuning job deleted",
})
}
}
// FineTuneProgressEndpoint streams progress updates via SSE.
func FineTuneProgressEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
// Set SSE headers
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
c.Response().WriteHeader(http.StatusOK)
err := ftService.StreamProgress(c.Request().Context(), userID, jobID, func(event *schema.FineTuneProgressEvent) {
data, err := json.Marshal(event)
if err != nil {
return
}
fmt.Fprintf(c.Response(), "data: %s\n\n", data)
c.Response().Flush()
})
if err != nil {
// If headers already sent, we can't send a JSON error
fmt.Fprintf(c.Response(), "data: {\"status\":\"error\",\"message\":%q}\n\n", err.Error())
c.Response().Flush()
}
return nil
}
}
// ListCheckpointsEndpoint lists checkpoints for a job.
func ListCheckpointsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
checkpoints, err := ftService.ListCheckpoints(c.Request().Context(), userID, jobID)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusOK, map[string]any{
"checkpoints": checkpoints,
})
}
}
// ExportModelEndpoint exports a model from a checkpoint.
func ExportModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
var req schema.ExportRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{
"error": "Invalid request: " + err.Error(),
})
}
modelName, err := ftService.ExportModel(c.Request().Context(), userID, jobID, req)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusAccepted, map[string]string{
"status": "exporting",
"message": "Export started for model '" + modelName + "'",
"model_name": modelName,
})
}
}
// DownloadExportedModelEndpoint streams the exported model directory as a tar.gz archive.
func DownloadExportedModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
modelDir, modelName, err := ftService.GetExportedModelPath(userID, jobID)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{
"error": err.Error(),
})
}
c.Response().Header().Set("Content-Type", "application/gzip")
c.Response().Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s.tar.gz"`, modelName))
c.Response().WriteHeader(http.StatusOK)
gw := gzip.NewWriter(c.Response())
defer gw.Close()
tw := tar.NewWriter(gw)
defer tw.Close()
err = filepath.Walk(modelDir, func(path string, info os.FileInfo, walkErr error) error {
if walkErr != nil {
return walkErr
}
relPath, err := filepath.Rel(modelDir, path)
if err != nil {
return err
}
header, err := tar.FileInfoHeader(info, "")
if err != nil {
return err
}
header.Name = filepath.Join(modelName, relPath)
if err := tw.WriteHeader(header); err != nil {
return err
}
if info.IsDir() {
return nil
}
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(tw, f)
return err
})
if err != nil {
// Headers already sent, can't return JSON error
return err
}
return nil
}
}
// ListFineTuneBackendsEndpoint returns installed backends tagged with "fine-tuning".
func ListFineTuneBackendsEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "failed to list backends: " + err.Error(),
})
}
type backendInfo struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Tags []string `json:"tags,omitempty"`
}
var result []backendInfo
for _, b := range backends {
if !b.Installed {
continue
}
hasTag := false
for _, t := range b.Tags {
if strings.EqualFold(t, "fine-tuning") {
hasTag = true
break
}
}
if !hasTag {
continue
}
name := b.Name
if b.Alias != "" {
name = b.Alias
}
result = append(result, backendInfo{
Name: name,
Description: b.Description,
Tags: b.Tags,
})
}
if result == nil {
result = []backendInfo{}
}
return c.JSON(http.StatusOK, result)
}
}
// UploadDatasetEndpoint handles dataset file upload.
func UploadDatasetEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
file, err := c.FormFile("file")
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{
"error": "file is required",
})
}
src, err := file.Open()
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "failed to open file",
})
}
defer src.Close()
data, err := io.ReadAll(src)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "failed to read file",
})
}
path, err := ftService.UploadDataset(file.Filename, data)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusOK, map[string]string{
"path": path,
})
}
}
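
One detail in `FineTuneProgressEndpoint` may deserve a second look: the error fallback builds its payload with `%q`, which emits Go string syntax rather than JSON escaping. The two agree for plain text but can differ for control characters. A minimal sketch of building the same event with `encoding/json`, assuming a hypothetical `sseErrorEvent` helper that is not part of the diff:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// sseError is the payload shape used by the error fallback in the diff.
type sseError struct {
	Status  string `json:"status"`
	Message string `json:"message"`
}

// sseErrorEvent renders one SSE "data:" frame whose payload is valid JSON.
// Hypothetical helper for illustration only.
func sseErrorEvent(msg string) string {
	payload, err := json.Marshal(sseError{Status: "error", Message: msg})
	if err != nil {
		// Marshalling a struct of strings should not fail; keep a safe default.
		payload = []byte(`{"status":"error","message":"internal error"}`)
	}
	return fmt.Sprintf("data: %s\n\n", payload)
}

func main() {
	fmt.Print(sseErrorEvent("job not found"))
	fmt.Print(sseErrorEvent("bad byte: \x00"))
}
```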


@@ -20,6 +20,7 @@ type ModelGalleryEndpointService struct {
backendGalleries []config.Gallery
modelPath string
galleryApplier *services.GalleryService
configLoader *config.ModelConfigLoader
}
type GalleryModel struct {
@@ -27,12 +28,13 @@ type GalleryModel struct {
gallery.GalleryModel
}
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService, configLoader *config.ModelConfigLoader) ModelGalleryEndpointService {
return ModelGalleryEndpointService{
galleries: galleries,
backendGalleries: backendGalleries,
modelPath: systemState.Model.ModelsPath,
galleryApplier: galleryApplier,
configLoader: configLoader,
}
}
@@ -103,6 +105,8 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.Handle
GalleryElementName: modelName,
}
mgs.configLoader.RemoveModelConfig(modelName)
uuid, err := uuid.NewUUID()
if err != nil {
return err


@@ -136,6 +136,12 @@ func UpdateSettingsEndpoint(app *application.Application) echo.HandlerFunc {
appConfig.ApiKeys = append(envKeys, runtimeKeys...)
}
// Update backend logging dynamically
if settings.EnableBackendLogging != nil {
app.ModelLoader().SetBackendLoggingEnabled(*settings.EnableBackendLogging)
xlog.Info("Updated backend logging setting", "enableBackendLogging", *settings.EnableBackendLogging)
}
// Update watchdog dynamically for settings that don't require restart
if settings.ForceEvictionWhenBusy != nil {
currentWD := app.ModelLoader().GetWatchDog()


@@ -82,51 +82,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
template = s
}
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
// Track accumulated content for reasoning extraction
accumulatedContent := ""
lastEmittedReasoning := ""
lastEmittedCleanedContent := ""
extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig)
_, _, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
accumulatedContent += s
currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, config.ReasoningConfig)
// Calculate new reasoning delta (what we haven't emitted yet)
var reasoningDelta *string
if currentReasoning != lastEmittedReasoning {
// Extract only the new part
if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) {
newReasoning := currentReasoning[len(lastEmittedReasoning):]
reasoningDelta = &newReasoning
lastEmittedReasoning = currentReasoning
} else if currentReasoning != "" {
// If reasoning changed in a non-append way, emit the full current reasoning
reasoningDelta = &currentReasoning
lastEmittedReasoning = currentReasoning
}
}
// Calculate content delta from cleaned content
var deltaContent string
if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) {
deltaContent = cleanedContent[len(lastEmittedCleanedContent):]
lastEmittedCleanedContent = cleanedContent
} else if cleanedContent != lastEmittedCleanedContent {
// If cleaned content changed but not in a simple append, extract delta from cleaned content
// This handles cases where thinking tags are removed mid-stream
if lastEmittedCleanedContent == "" {
deltaContent = cleanedContent
lastEmittedCleanedContent = cleanedContent
} else {
// Content changed in non-append way, use the new cleaned content
deltaContent = cleanedContent
lastEmittedCleanedContent = cleanedContent
}
}
// Only emit content if there's actual content (not just thinking tags)
// If deltaContent is empty, we still emit the response but with empty content
reasoningDelta, contentDelta := extractor.ProcessToken(s)
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
@@ -139,12 +98,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}
delta := &schema.Message{}
// Only include content if there's actual content (not just thinking tags)
if deltaContent != "" {
delta.Content = &deltaContent
if contentDelta != "" {
delta.Content = &contentDelta
}
if reasoningDelta != nil && *reasoningDelta != "" {
delta.Reasoning = reasoningDelta
if reasoningDelta != "" {
delta.Reasoning = &reasoningDelta
}
resp := schema.OpenAIResponse{
@@ -171,11 +129,53 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
template = prompt
}
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig)
result := ""
lastEmittedCount := 0
sentInitialRole := false
_, tokenUsage, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
result += s
reasoningDelta, contentDelta := extractor.ProcessToken(s)
// Emit reasoning deltas in their own SSE chunks before any tool-call chunks
// (OpenAI spec: reasoning and tool_calls never share a delta)
if reasoningDelta != "" {
responses <- schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model,
Choices: []schema.Choice{{
Delta: &schema.Message{Reasoning: &reasoningDelta},
Index: 0,
}},
Object: "chat.completion.chunk",
}
}
// Stream content deltas (cleaned of reasoning tags) while no tool calls
// have been detected. Once the incremental parser finds tool calls,
// content stops — per OpenAI spec, content and tool_calls don't mix.
if lastEmittedCount == 0 && contentDelta != "" {
if !sentInitialRole {
responses <- schema.OpenAIResponse{
ID: id, Created: created, Model: req.Model,
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}},
Object: "chat.completion.chunk",
}
sentInitialRole = true
}
responses <- schema.OpenAIResponse{
ID: id, Created: created, Model: req.Model,
Choices: []schema.Choice{{
Delta: &schema.Message{Content: &contentDelta},
Index: 0,
}},
Object: "chat.completion.chunk",
}
}
// Try incremental XML parsing for streaming support using iterative parser
// This allows emitting partial tool calls as they're being generated
cleanedResult := functions.CleanupLLMResult(result, config.FunctionsConfig)
@@ -279,7 +279,25 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}
}
return true
})
},
func(attempt int) bool {
// After streaming completes: check if we got actionable content
cleaned := extractor.CleanedContent()
// Check for tool calls from chat deltas (will be re-checked after ComputeChoices,
// but we need to know here whether to retry)
hasToolCalls := lastEmittedCount > 0
if cleaned == "" && !hasToolCalls {
xlog.Warn("Streaming: backend produced only reasoning, retrying",
"reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1)
extractor.ResetAndSuppressReasoning()
result = ""
lastEmittedCount = 0
sentInitialRole = false
return true
}
return false
},
)
if err != nil {
return err
}
@@ -296,30 +314,17 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
} else {
// Fallback: parse tool calls from raw text (no chat deltas from backend)
xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing")
reasoning, result = reason.ExtractReasoningWithConfig(result, thinkingStartToken, config.ReasoningConfig)
textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
result = functions.CleanupLLMResult(result, config.FunctionsConfig)
functionResults = functions.ParseFunctionCall(result, config.FunctionsConfig)
reasoning = extractor.Reasoning()
cleanedResult := extractor.CleanedContent()
textContentToReturn = functions.ParseTextContent(cleanedResult, config.FunctionsConfig)
cleanedResult = functions.CleanupLLMResult(cleanedResult, config.FunctionsConfig)
functionResults = functions.ParseFunctionCall(cleanedResult, config.FunctionsConfig)
}
xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", textContentToReturn)
noActionToRun := len(functionResults) > 0 && functionResults[0].Name == noAction || len(functionResults) == 0
switch {
case noActionToRun:
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
result, err := handleQuestion(config, functionResults, result, prompt)
if err != nil {
xlog.Error("error handling question", "error", err)
return err
}
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
@@ -330,25 +335,43 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
}
var deltaReasoning *string
if reasoning != "" {
deltaReasoning = &reasoning
}
delta := &schema.Message{Content: &result}
if deltaReasoning != nil {
delta.Reasoning = deltaReasoning
}
if sentInitialRole {
// Content was already streamed during the callback — just emit usage.
delta := &schema.Message{}
if reasoning != "" && extractor.Reasoning() == "" {
delta.Reasoning = &reasoning
}
responses <- schema.OpenAIResponse{
ID: id, Created: created, Model: req.Model,
Choices: []schema.Choice{{Delta: delta, Index: 0}},
Object: "chat.completion.chunk",
Usage: usage,
}
} else {
// Content was NOT streamed — send everything at once (fallback).
responses <- schema.OpenAIResponse{
ID: id, Created: created, Model: req.Model,
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}},
Object: "chat.completion.chunk",
}
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
Object: "chat.completion.chunk",
Usage: usage,
}
result, err := handleQuestion(config, functionResults, extractor.CleanedContent(), prompt)
if err != nil {
xlog.Error("error handling question", "error", err)
return err
}
responses <- resp
delta := &schema.Message{Content: &result}
if reasoning != "" {
delta.Reasoning = &reasoning
}
responses <- schema.OpenAIResponse{
ID: id, Created: created, Model: req.Model,
Choices: []schema.Choice{{Delta: delta, Index: 0}},
Object: "chat.completion.chunk",
Usage: usage,
}
}
default:
for i, ss := range functionResults {
@@ -907,7 +930,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
// is deferred to after ComputeChoices so we can check chat deltas first
// and avoid redundant Go-side parsing.
var cbRawResult, cbReasoning string
var emptyRetryNeeded bool
tokenCallback := func(s string, c *[]schema.Choice) {
reasoning, s := reason.ExtractReasoningWithConfig(s, thinkingStartToken, config.ReasoningConfig)
@@ -927,146 +949,133 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
cbReasoning = reasoning
}
const maxEmptyRetries = 5
var result []schema.Choice
var tokenUsage backend.TokenUsage
var err error
var chatDeltas []*pb.ChatDelta
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
emptyRetryNeeded = false
result, tokenUsage, chatDeltas, err = ComputeChoices(
input,
predInput,
config,
cl,
startupOptions,
ml,
tokenCallback,
nil,
)
if err != nil {
break
}
// Tool parsing is deferred here (only when shouldUseFn)
if shouldUseFn {
var funcResults []functions.FuncCallResults
// Try pre-parsed tool calls from C++ autoparser first
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser tool calls, skipping Go-side parsing", "count", len(deltaToolCalls))
funcResults = deltaToolCalls
textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas)
} else {
// Fallback: parse tool calls from raw text
xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing")
textContentToReturn = functions.ParseTextContent(cbRawResult, config.FunctionsConfig)
cbRawResult = functions.CleanupLLMResult(cbRawResult, config.FunctionsConfig)
funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig)
result, tokenUsage, chatDeltas, err = ComputeChoices(
input,
predInput,
config,
cl,
startupOptions,
ml,
tokenCallback,
nil,
func(attempt int) bool {
if !shouldUseFn {
return false
}
noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0
switch {
case noActionsToRun:
if cbRawResult == "" && textContentToReturn == "" {
xlog.Warn("Backend returned empty content in tool-calling context, will retry")
emptyRetryNeeded = true
continue
}
qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput)
if qErr != nil {
xlog.Error("error handling question", "error", qErr)
emptyRetryNeeded = true
continue
}
stopReason := FinishReasonStop
message := &schema.Message{Role: "assistant", Content: &qResult}
if cbReasoning != "" {
message.Reasoning = &cbReasoning
}
result = append(result, schema.Choice{
FinishReason: &stopReason,
Message: message,
})
default:
toolCallsReason := FinishReasonToolCalls
toolChoice := schema.Choice{
FinishReason: &toolCallsReason,
Message: &schema.Message{
Role: "assistant",
},
}
if cbReasoning != "" {
toolChoice.Message.Reasoning = &cbReasoning
}
for _, ss := range funcResults {
name, args := ss.Name, ss.Arguments
toolCallID := ss.ID
if toolCallID == "" {
toolCallID = id
}
if len(input.Tools) > 0 {
toolChoice.Message.Content = textContentToReturn
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
schema.ToolCall{
ID: toolCallID,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
)
} else {
// Deprecated function_call format
functionCallReason := FinishReasonFunctionCall
message := &schema.Message{
Role: "assistant",
Content: &textContentToReturn,
FunctionCall: map[string]interface{}{
"name": name,
"arguments": args,
},
}
if cbReasoning != "" {
message.Reasoning = &cbReasoning
}
result = append(result, schema.Choice{
FinishReason: &functionCallReason,
Message: message,
})
}
}
if len(input.Tools) > 0 {
result = append(result, toolChoice)
}
// Retry when backend produced only reasoning and no content/tool calls.
// Full tool parsing is deferred until after ComputeChoices returns
// (when chat deltas are available), but we can detect the empty case here.
if cbRawResult == "" && textContentToReturn == "" {
xlog.Warn("Backend produced reasoning without actionable content, retrying",
"reasoning_len", len(cbReasoning), "attempt", attempt+1)
cbRawResult = ""
cbReasoning = ""
textContentToReturn = ""
return true
}
}
if !emptyRetryNeeded {
break
}
xlog.Warn("Retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
}
return false
},
)
if err != nil {
return err
}
if emptyRetryNeeded {
xlog.Warn("All retries exhausted, backend still returning empty content")
stopReason := FinishReasonStop
empty := ""
result = append(result, schema.Choice{
FinishReason: &stopReason,
Index: 0,
Message: &schema.Message{Role: "assistant", Content: &empty},
})
// Tool parsing is deferred here (only when shouldUseFn) so chat deltas are available
if shouldUseFn {
var funcResults []functions.FuncCallResults
// Try pre-parsed tool calls from C++ autoparser first
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser tool calls, skipping Go-side parsing", "count", len(deltaToolCalls))
funcResults = deltaToolCalls
textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas)
} else {
// Fallback: parse tool calls from raw text
xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing")
textContentToReturn = functions.ParseTextContent(cbRawResult, config.FunctionsConfig)
cbRawResult = functions.CleanupLLMResult(cbRawResult, config.FunctionsConfig)
funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig)
}
noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0
switch {
case noActionsToRun:
qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput)
if qErr != nil {
xlog.Error("error handling question", "error", qErr)
}
stopReason := FinishReasonStop
message := &schema.Message{Role: "assistant", Content: &qResult}
if cbReasoning != "" {
message.Reasoning = &cbReasoning
}
result = append(result, schema.Choice{
FinishReason: &stopReason,
Message: message,
})
default:
toolCallsReason := FinishReasonToolCalls
toolChoice := schema.Choice{
FinishReason: &toolCallsReason,
Message: &schema.Message{
Role: "assistant",
},
}
if cbReasoning != "" {
toolChoice.Message.Reasoning = &cbReasoning
}
for _, ss := range funcResults {
name, args := ss.Name, ss.Arguments
toolCallID := ss.ID
if toolCallID == "" {
toolCallID = id
}
if len(input.Tools) > 0 {
toolChoice.Message.Content = textContentToReturn
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
schema.ToolCall{
ID: toolCallID,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
)
} else {
// Deprecated function_call format
functionCallReason := FinishReasonFunctionCall
message := &schema.Message{
Role: "assistant",
Content: &textContentToReturn,
FunctionCall: map[string]interface{}{
"name": name,
"arguments": args,
},
}
if cbReasoning != "" {
message.Reasoning = &cbReasoning
}
result = append(result, schema.Choice{
FinishReason: &functionCallReason,
Message: message,
})
}
}
if len(input.Tools) > 0 {
result = append(result, toolChoice)
}
}
}
// MCP server-side tool execution loop:
@@ -1203,5 +1212,5 @@ func handleQuestion(config *config.ModelConfig, funcResults []functions.FuncCall
xlog.Debug("No action received from LLM, without a message, computing a reply")
return "", fmt.Errorf("no action received from LLM, without a message, computing a reply")
return "", nil
}
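
The streaming hunk above sends reasoning deltas in their own chunks and emits the `assistant` role exactly once, just before the first content delta. A minimal sketch of that ordering, assuming a trimmed-down `Chunk` type and an `Emitter` helper that are both hypothetical and not part of LocalAI:

```go
package main

import "fmt"

// Chunk is a hypothetical stand-in for one SSE delta (not schema.OpenAIResponse).
type Chunk struct {
	Role      string
	Content   string
	Reasoning string
}

// Emitter collects chunks in the order they would be streamed.
// The names here are illustrative assumptions.
type Emitter struct {
	sentRole bool
	Out      []Chunk
}

// Push records one token's reasoning and content deltas.
// Reasoning goes out in its own chunk; the role chunk is sent once,
// immediately before the first non-empty content delta.
func (e *Emitter) Push(reasoning, content string) {
	if reasoning != "" {
		e.Out = append(e.Out, Chunk{Reasoning: reasoning})
	}
	if content == "" {
		return
	}
	if !e.sentRole {
		e.Out = append(e.Out, Chunk{Role: "assistant"})
		e.sentRole = true
	}
	e.Out = append(e.Out, Chunk{Content: content})
}

func main() {
	e := &Emitter{}
	e.Push("thinking", "")
	e.Push("", "Hel")
	e.Push("", "lo")
	for _, c := range e.Out {
		fmt.Printf("%+v\n", c)
	}
}
```

If a generation yields only reasoning, no role chunk is ever pushed, which is the case the retry callback in the hunk looks for.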

@@ -0,0 +1,157 @@
package openai
import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/functions"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/schema"
)
var _ = Describe("handleQuestion", func() {
var cfg *config.ModelConfig
BeforeEach(func() {
cfg = &config.ModelConfig{}
})
Context("with no function results but non-empty result", func() {
It("should return the result directly", func() {
result, err := handleQuestion(cfg, nil, "Hello world", "prompt")
Expect(err).ToNot(HaveOccurred())
Expect(result).To(Equal("Hello world"))
})
})
Context("with no function results and empty result", func() {
It("should return empty string", func() {
result, err := handleQuestion(cfg, nil, "", "prompt")
Expect(err).ToNot(HaveOccurred())
Expect(result).To(BeEmpty())
})
})
Context("with function result containing a message argument", func() {
It("should extract the message from function arguments", func() {
funcResults := []functions.FuncCallResults{
{
Name: "answer",
Arguments: `{"message": "This is the answer"}`,
},
}
result, err := handleQuestion(cfg, funcResults, "", "prompt")
Expect(err).ToNot(HaveOccurred())
Expect(result).To(Equal("This is the answer"))
})
})
Context("with function result containing empty message", func() {
It("should return empty string when message is empty", func() {
funcResults := []functions.FuncCallResults{
{
Name: "answer",
Arguments: `{"message": ""}`,
},
}
result, err := handleQuestion(cfg, funcResults, "", "prompt")
Expect(err).ToNot(HaveOccurred())
Expect(result).To(BeEmpty())
})
})
Context("with function result containing invalid JSON arguments", func() {
It("should return empty string gracefully", func() {
funcResults := []functions.FuncCallResults{
{
Name: "answer",
Arguments: "not json",
},
}
result, err := handleQuestion(cfg, funcResults, "", "prompt")
Expect(err).ToNot(HaveOccurred())
Expect(result).To(BeEmpty())
})
})
Context("with cleaned content (no think tags)", func() {
It("should return content without think tags", func() {
// This tests the bug fix: handleQuestion should receive cleaned content,
// not raw text with <think> tags
result, err := handleQuestion(cfg, nil, "Just the answer", "prompt")
Expect(err).ToNot(HaveOccurred())
Expect(result).To(Equal("Just the answer"))
Expect(result).ToNot(ContainSubstring("<think>"))
})
})
Context("with raw think tags passed as result", func() {
It("would return content with think tags", func() {
result, err := handleQuestion(cfg, nil, "<think>reasoning</think>answer", "prompt")
Expect(err).ToNot(HaveOccurred())
Expect(result).To(Equal("<think>reasoning</think>answer"))
})
})
})
var _ = Describe("mergeToolCallDeltas", func() {
Context("with new tool calls", func() {
It("should append new tool calls", func() {
existing := []schema.ToolCall{}
deltas := []schema.ToolCall{
{Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search"}},
}
result := mergeToolCallDeltas(existing, deltas)
Expect(result).To(HaveLen(1))
Expect(result[0].ID).To(Equal("tc1"))
Expect(result[0].FunctionCall.Name).To(Equal("search"))
})
})
Context("with argument appending", func() {
It("should append arguments to existing tool call", func() {
existing := []schema.ToolCall{
{Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search", Arguments: `{"q":`}},
}
deltas := []schema.ToolCall{
{Index: 0, FunctionCall: schema.FunctionCall{Arguments: `"hello"}`}},
}
result := mergeToolCallDeltas(existing, deltas)
Expect(result).To(HaveLen(1))
Expect(result[0].FunctionCall.Arguments).To(Equal(`{"q":"hello"}`))
})
})
Context("with multiple tool calls", func() {
It("should track multiple tool calls by index", func() {
existing := []schema.ToolCall{}
deltas1 := []schema.ToolCall{
{Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search"}},
}
result := mergeToolCallDeltas(existing, deltas1)
deltas2 := []schema.ToolCall{
{Index: 1, ID: "tc2", Type: "function", FunctionCall: schema.FunctionCall{Name: "browse"}},
}
result = mergeToolCallDeltas(result, deltas2)
Expect(result).To(HaveLen(2))
Expect(result[0].FunctionCall.Name).To(Equal("search"))
Expect(result[1].FunctionCall.Name).To(Equal("browse"))
})
})
Context("with ID update on existing tool call", func() {
It("should update ID when provided in delta", func() {
existing := []schema.ToolCall{
{Index: 0, FunctionCall: schema.FunctionCall{Name: "search"}},
}
deltas := []schema.ToolCall{
{Index: 0, ID: "new-id"},
}
result := mergeToolCallDeltas(existing, deltas)
Expect(result).To(HaveLen(1))
Expect(result[0].ID).To(Equal("new-id"))
Expect(result[0].FunctionCall.Name).To(Equal("search"))
})
})
})
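
The `mergeToolCallDeltas` specs above describe index-based merging: unseen indexes are appended, while a delta for a known index updates the ID and appends its argument fragment. The implementation is not shown in this diff, so the following is only a sketch of that behaviour, using a simplified, hypothetical `ToolCall` type rather than `schema.ToolCall`:

```go
package main

import "fmt"

// ToolCall is a simplified, hypothetical stand-in for schema.ToolCall.
type ToolCall struct {
	Index     int
	ID        string
	Type      string
	Name      string
	Arguments string
}

// mergeDeltas folds streamed deltas into the accumulated calls, matching on Index.
// Non-empty ID, Type and Name overwrite; Arguments fragments are concatenated.
func mergeDeltas(existing, deltas []ToolCall) []ToolCall {
	for _, d := range deltas {
		found := false
		for i := range existing {
			if existing[i].Index != d.Index {
				continue
			}
			found = true
			if d.ID != "" {
				existing[i].ID = d.ID
			}
			if d.Type != "" {
				existing[i].Type = d.Type
			}
			if d.Name != "" {
				existing[i].Name = d.Name
			}
			existing[i].Arguments += d.Arguments
			break
		}
		if !found {
			existing = append(existing, d)
		}
	}
	return existing
}

func main() {
	calls := mergeDeltas(nil, []ToolCall{{Index: 0, ID: "tc1", Type: "function", Name: "search", Arguments: `{"q":`}})
	calls = mergeDeltas(calls, []ToolCall{{Index: 0, Arguments: `"hello"}`}})
	fmt.Println(calls[0].Name, calls[0].Arguments)
}
```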

@@ -2,6 +2,7 @@ package openai
import (
"encoding/json"
"strings"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
@@ -9,6 +10,7 @@ import (
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/xlog"
)
func ComputeChoices(
@@ -19,7 +21,9 @@ func ComputeChoices(
o *config.ApplicationConfig,
loader *model.ModelLoader,
cb func(string, *[]schema.Choice),
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, []*pb.ChatDelta, error) {
tokenCallback func(string, backend.TokenUsage) bool,
shouldRetry ...func(int) bool,
) ([]schema.Choice, backend.TokenUsage, []*pb.ChatDelta, error) {
n := req.N // number of completions to return
result := []schema.Choice{}
@@ -27,6 +31,12 @@ func ComputeChoices(
n = 1
}
// Extract the optional shouldRetry callback
var shouldRetryFn func(int) bool
if len(shouldRetry) > 0 {
shouldRetryFn = shouldRetry[0]
}
images := []string{}
for _, m := range req.Messages {
images = append(images, m.StringImages...)
@@ -82,7 +92,7 @@ func ComputeChoices(
}
// get the model function to call for the result
predFunc, err := backend.ModelInference(
predFunc, err := backend.ModelInferenceFunc(
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias, req.Metadata)
if err != nil {
return result, backend.TokenUsage{}, nil, err
@@ -91,32 +101,49 @@ func ComputeChoices(
tokenUsage := backend.TokenUsage{}
var allChatDeltas []*pb.ChatDelta
const maxRetries = 5
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
return result, backend.TokenUsage{}, nil, err
var prediction backend.LLMResponse
for attempt := 0; attempt <= maxRetries; attempt++ {
p, err := predFunc()
if err != nil {
return result, backend.TokenUsage{}, nil, err
}
prediction = p
// Built-in: retry on an empty or whitespace-only response
if strings.TrimSpace(prediction.Response) == "" && attempt < maxRetries {
xlog.Warn("Backend returned empty response, retrying",
"attempt", attempt+1, "maxRetries", maxRetries)
continue
}
tokenUsage.Prompt = prediction.Usage.Prompt
tokenUsage.Completion = prediction.Usage.Completion
tokenUsage.TimingPromptProcessing = prediction.Usage.TimingPromptProcessing
tokenUsage.TimingTokenGeneration = prediction.Usage.TimingTokenGeneration
allChatDeltas = prediction.ChatDeltas
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
// Caller-driven retry (tool parsing, reasoning-only, etc.)
if shouldRetryFn != nil && shouldRetryFn(attempt) && attempt < maxRetries {
// Caller has already reset its state inside shouldRetry
result = result[:0]
allChatDeltas = nil
continue
}
break
}
tokenUsage.Prompt += prediction.Usage.Prompt
tokenUsage.Completion += prediction.Usage.Completion
tokenUsage.TimingPromptProcessing += prediction.Usage.TimingPromptProcessing
tokenUsage.TimingTokenGeneration += prediction.Usage.TimingTokenGeneration
// Collect chat deltas from C++ autoparser
if len(prediction.ChatDeltas) > 0 {
allChatDeltas = append(allChatDeltas, prediction.ChatDeltas...)
}
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
// Add logprobs to the last choice if present
if prediction.Logprobs != nil && len(result) > 0 {
result[len(result)-1].Logprobs = prediction.Logprobs
}
//result = append(result, Choice{Text: prediction})
}
return result, tokenUsage, allChatDeltas, err
}
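
The retry loop in this hunk combines a built-in retry on blank output with an optional, caller-supplied `shouldRetry` callback passed as a variadic argument. A minimal standalone sketch of that shape, assuming a hypothetical `predictWithRetry` helper that works on plain strings. One deliberate difference from the diff: the sketch checks the retry budget before invoking the callback, so a callback that resets caller state only runs when a retry can still happen.

```go
package main

import (
	"fmt"
	"strings"
)

// maxRetries mirrors the constant in the diff; everything else is illustrative.
const maxRetries = 5

// predictWithRetry calls predict until it returns a non-blank response that the
// optional shouldRetry callback accepts, or until the retry budget is spent.
// It returns the last response and the number of predict calls made.
func predictWithRetry(predict func() (string, error), shouldRetry ...func(attempt int) bool) (string, int, error) {
	var retryFn func(int) bool
	if len(shouldRetry) > 0 {
		retryFn = shouldRetry[0]
	}
	var resp string
	calls := 0
	for attempt := 0; attempt <= maxRetries; attempt++ {
		r, err := predict()
		calls++
		if err != nil {
			return "", calls, err
		}
		resp = r
		// Built-in retry: empty or whitespace-only output.
		if strings.TrimSpace(resp) == "" && attempt < maxRetries {
			continue
		}
		// Caller-driven retry, consulted only while budget remains.
		if attempt < maxRetries && retryFn != nil && retryFn(attempt) {
			continue
		}
		break
	}
	return resp, calls, nil
}

func main() {
	answers := []string{"", "  ", "Got it"}
	i := 0
	resp, calls, err := predictWithRetry(func() (string, error) {
		r := answers[i]
		if i < len(answers)-1 {
			i++
		}
		return r, nil
	})
	fmt.Println(resp, calls, err)
}
```

Making the callback variadic keeps existing call sites compiling unchanged, at the cost of silently ignoring any extra callbacks after the first.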

@@ -0,0 +1,402 @@
package openai
import (
"context"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
type modelInferenceFunc = func(
ctx context.Context, s string, messages schema.Messages,
images, videos, audios []string,
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
o *config.ApplicationConfig,
tokenCallback func(string, backend.TokenUsage) bool,
tools, toolChoice string,
logprobs, topLogprobs *int,
logitBias map[string]float64,
metadata map[string]string,
) (func() (backend.LLMResponse, error), error)
var _ = Describe("ComputeChoices", func() {
var (
origInference modelInferenceFunc
cfg *config.ModelConfig
appCfg *config.ApplicationConfig
)
// mockInference installs a stub that yields the given responses sequentially.
// After all responses are consumed, the last one is repeated.
mockInference := func(responses []backend.LLMResponse) {
idx := 0
backend.ModelInferenceFunc = func(
ctx context.Context, s string, messages schema.Messages,
images, videos, audios []string,
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
o *config.ApplicationConfig,
tokenCallback func(string, backend.TokenUsage) bool,
tools, toolChoice string,
logprobs, topLogprobs *int,
logitBias map[string]float64,
metadata map[string]string,
) (func() (backend.LLMResponse, error), error) {
predFunc := func() (backend.LLMResponse, error) {
resp := responses[idx]
if idx < len(responses)-1 {
idx++
}
return resp, nil
}
return predFunc, nil
}
}
BeforeEach(func() {
origInference = backend.ModelInferenceFunc
cfg = &config.ModelConfig{}
appCfg = config.NewApplicationConfig()
})
AfterEach(func() {
backend.ModelInferenceFunc = origInference
})
makeReq := func() *schema.OpenAIRequest {
ctx, cancel := context.WithCancel(context.Background())
_ = cancel
return &schema.OpenAIRequest{
Context: ctx,
Cancel: cancel,
}
}
Context("normal response (no retry needed)", func() {
It("should return choices on first attempt", func() {
mockInference([]backend.LLMResponse{
{Response: "Hello world", Usage: backend.TokenUsage{Prompt: 10, Completion: 5}},
})
var captured string
choices, usage, _, err := ComputeChoices(
makeReq(), "test prompt", cfg, nil, appCfg, nil,
func(s string, c *[]schema.Choice) {
captured = s
*c = append(*c, schema.Choice{Text: s})
},
nil,
)
Expect(err).ToNot(HaveOccurred())
Expect(choices).To(HaveLen(1))
Expect(captured).To(Equal("Hello world"))
Expect(usage.Prompt).To(Equal(10))
Expect(usage.Completion).To(Equal(5))
})
})
Context("empty response triggers built-in retry", func() {
It("should retry and eventually return non-empty response", func() {
mockInference([]backend.LLMResponse{
{Response: ""}, // attempt 0: empty
{Response: " "}, // attempt 1: whitespace-only
{Response: "Got it", Usage: backend.TokenUsage{Prompt: 8, Completion: 3}}, // attempt 2: success
})
choices, usage, _, err := ComputeChoices(
makeReq(), "test", cfg, nil, appCfg, nil,
func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
},
nil,
)
Expect(err).ToNot(HaveOccurred())
Expect(choices).To(HaveLen(1))
Expect(choices[0].Text).To(Equal("Got it"))
Expect(usage.Prompt).To(Equal(8))
Expect(usage.Completion).To(Equal(3))
})
})
Context("all retries exhausted on empty response", func() {
It("should return the empty response after max retries", func() {
mockInference([]backend.LLMResponse{
{Response: ""}, // always empty
})
choices, _, _, err := ComputeChoices(
makeReq(), "test", cfg, nil, appCfg, nil,
func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
},
nil,
)
Expect(err).ToNot(HaveOccurred())
// After maxRetries, it proceeds with the empty response
Expect(choices).To(HaveLen(1))
Expect(choices[0].Text).To(BeEmpty())
})
})
Context("shouldRetry callback", func() {
It("should call shouldRetry and retry when it returns true", func() {
callCount := 0
mockInference([]backend.LLMResponse{
{Response: "reasoning-only", Usage: backend.TokenUsage{Prompt: 5, Completion: 2}},
{Response: "actual-answer", Usage: backend.TokenUsage{Prompt: 5, Completion: 4}},
})
retryAttempts := []int{}
choices, usage, _, err := ComputeChoices(
makeReq(), "test", cfg, nil, appCfg, nil,
func(s string, c *[]schema.Choice) {
callCount++
*c = append(*c, schema.Choice{Text: s})
},
nil,
func(attempt int) bool {
retryAttempts = append(retryAttempts, attempt)
// Retry on first attempt only
return attempt == 0
},
)
Expect(err).ToNot(HaveOccurred())
Expect(choices).To(HaveLen(1))
Expect(choices[0].Text).To(Equal("actual-answer"))
// shouldRetry was called twice: once returning true (retry), once returning false (proceed)
Expect(retryAttempts).To(Equal([]int{0, 1}))
// cb was called twice (once per attempt)
Expect(callCount).To(Equal(2))
// Token usage should be from the LATEST attempt
Expect(usage.Prompt).To(Equal(5))
Expect(usage.Completion).To(Equal(4))
})
It("should not retry when shouldRetry returns false", func() {
mockInference([]backend.LLMResponse{
{Response: "first-response"},
})
shouldRetryCalled := false
choices, _, _, err := ComputeChoices(
makeReq(), "test", cfg, nil, appCfg, nil,
func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
},
nil,
func(attempt int) bool {
shouldRetryCalled = true
return false
},
)
Expect(err).ToNot(HaveOccurred())
Expect(choices).To(HaveLen(1))
Expect(choices[0].Text).To(Equal("first-response"))
Expect(shouldRetryCalled).To(BeTrue())
})
})
Context("shouldRetry not provided (variadic omitted)", func() {
It("should work without shouldRetry parameter", func() {
mockInference([]backend.LLMResponse{
{Response: "works"},
})
choices, _, _, err := ComputeChoices(
makeReq(), "test", cfg, nil, appCfg, nil,
func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
},
nil,
)
Expect(err).ToNot(HaveOccurred())
Expect(choices).To(HaveLen(1))
Expect(choices[0].Text).To(Equal("works"))
})
})
Context("token usage from latest attempt", func() {
It("should use token usage from the last attempt, not accumulated", func() {
mockInference([]backend.LLMResponse{
{Response: "retry-me", Usage: backend.TokenUsage{Prompt: 100, Completion: 50}},
{Response: "final", Usage: backend.TokenUsage{Prompt: 10, Completion: 5}},
})
_, usage, _, err := ComputeChoices(
makeReq(), "test", cfg, nil, appCfg, nil,
func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
},
nil,
func(attempt int) bool { return attempt == 0 },
)
Expect(err).ToNot(HaveOccurred())
// Should be the LATEST attempt's usage, not accumulated
Expect(usage.Prompt).To(Equal(10))
Expect(usage.Completion).To(Equal(5))
})
})
Context("chat deltas from latest attempt", func() {
It("should return chat deltas from the last attempt only", func() {
mockInference([]backend.LLMResponse{
{
Response: "retry-me",
ChatDeltas: []*pb.ChatDelta{{Content: "old"}},
},
{
Response: "final",
ChatDeltas: []*pb.ChatDelta{{Content: "new"}},
},
})
_, _, deltas, err := ComputeChoices(
makeReq(), "test", cfg, nil, appCfg, nil,
func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
},
nil,
func(attempt int) bool { return attempt == 0 },
)
Expect(err).ToNot(HaveOccurred())
Expect(deltas).To(HaveLen(1))
Expect(deltas[0].Content).To(Equal("new"))
})
})
Context("result choices cleared on retry", func() {
It("should only contain choices from the final attempt", func() {
mockInference([]backend.LLMResponse{
{Response: "bad-choice"},
{Response: "good-choice"},
})
choices, _, _, err := ComputeChoices(
makeReq(), "test", cfg, nil, appCfg, nil,
func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
},
nil,
func(attempt int) bool { return attempt == 0 },
)
Expect(err).ToNot(HaveOccurred())
Expect(choices).To(HaveLen(1))
Expect(choices[0].Text).To(Equal("good-choice"))
})
})
Context("shouldRetry with max retries cap", func() {
It("should stop retrying after maxRetries even if shouldRetry returns true", func() {
attempts := 0
mockInference([]backend.LLMResponse{
{Response: "always-retry"},
})
choices, _, _, err := ComputeChoices(
makeReq(), "test", cfg, nil, appCfg, nil,
func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
},
nil,
func(attempt int) bool {
attempts++
return true // always want to retry
},
)
Expect(err).ToNot(HaveOccurred())
Expect(choices).To(HaveLen(1))
// maxRetries is 5, so the loop runs attempts 0..5. As written, shouldRetry is
// evaluated before the attempt < maxRetries check, so it is invoked on all six
// attempts; on attempt 5 its result cannot trigger another retry.
Expect(attempts).To(BeNumerically("<=", 6))
})
})
Context("N > 1 completions", func() {
It("should produce N separate completions", func() {
callIdx := 0
responses := []string{"first", "second", "third"}
backend.ModelInferenceFunc = func(
ctx context.Context, s string, messages schema.Messages,
images, videos, audios []string,
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
o *config.ApplicationConfig,
tokenCallback func(string, backend.TokenUsage) bool,
tools, toolChoice string,
logprobs, topLogprobs *int,
logitBias map[string]float64,
metadata map[string]string,
) (func() (backend.LLMResponse, error), error) {
predFunc := func() (backend.LLMResponse, error) {
resp := backend.LLMResponse{Response: responses[callIdx]}
if callIdx < len(responses)-1 {
callIdx++
}
return resp, nil
}
return predFunc, nil
}
req := makeReq()
req.N = 3
choices, _, _, err := ComputeChoices(
req, "test", cfg, nil, appCfg, nil,
func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
},
nil,
)
Expect(err).ToNot(HaveOccurred())
Expect(choices).To(HaveLen(3))
Expect(choices[0].Text).To(Equal("first"))
Expect(choices[1].Text).To(Equal("second"))
Expect(choices[2].Text).To(Equal("third"))
})
})
Context("with streaming token callback", func() {
It("should call tokenCallback for streaming responses", func() {
var streamedTokens []string
backend.ModelInferenceFunc = func(
ctx context.Context, s string, messages schema.Messages,
images, videos, audios []string,
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
o *config.ApplicationConfig,
tokenCallback func(string, backend.TokenUsage) bool,
tools, toolChoice string,
logprobs, topLogprobs *int,
logitBias map[string]float64,
metadata map[string]string,
) (func() (backend.LLMResponse, error), error) {
predFunc := func() (backend.LLMResponse, error) {
if tokenCallback != nil {
tokenCallback("Hello", backend.TokenUsage{Prompt: 5})
tokenCallback(" world", backend.TokenUsage{Prompt: 5, Completion: 2})
}
return backend.LLMResponse{
Response: "Hello world",
Usage: backend.TokenUsage{Prompt: 5, Completion: 2},
}, nil
}
return predFunc, nil
}
choices, _, _, err := ComputeChoices(
makeReq(), "test", cfg, nil, appCfg, nil,
func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
},
func(s string, usage backend.TokenUsage) bool {
streamedTokens = append(streamedTokens, s)
return true
},
)
Expect(err).ToNot(HaveOccurred())
Expect(choices).To(HaveLen(1))
Expect(streamedTokens).To(Equal([]string{"Hello", " world"}))
})
})
})
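
The specs above rely on `backend.ModelInferenceFunc` being a package-level function variable that a test can replace and later restore. A minimal sketch of that seam, assuming a hypothetical `inferFunc` hook and `stubInference` helper with much simpler signatures than the real ones:

```go
package main

import "fmt"

// inferFunc is a hypothetical package-level hook, analogous in spirit to
// backend.ModelInferenceFunc but with a simplified signature.
var inferFunc = func(prompt string) (string, error) {
	return "real:" + prompt, nil
}

// stubInference swaps inferFunc for a stub that yields responses in order,
// repeating the last one once the list is consumed. It returns a function
// that restores the original hook.
func stubInference(responses []string) (restore func()) {
	orig := inferFunc
	idx := 0
	inferFunc = func(string) (string, error) {
		r := responses[idx]
		if idx < len(responses)-1 {
			idx++
		}
		return r, nil
	}
	return func() { inferFunc = orig }
}

func main() {
	restore := stubInference([]string{"a", "b"})
	for i := 0; i < 3; i++ {
		r, _ := inferFunc("p")
		fmt.Println(r)
	}
	restore()
	r, _ := inferFunc("p")
	fmt.Println(r)
}
```

Because the hook is shared package state, tests that swap it should not run in parallel with each other.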

@@ -3,16 +3,22 @@ package openai
import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
model "github.com/mudler/LocalAI/pkg/model"
"gorm.io/gorm"
)
// ListModelsEndpoint is the OpenAI Models API endpoint https://platform.openai.com/docs/api-reference/models
// @Summary List and describe the various models available in the API.
// @Success 200 {object} schema.ModelsDataResponse "Response"
// @Router /v1/models [get]
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, db ...*gorm.DB) echo.HandlerFunc {
var authDB *gorm.DB
if len(db) > 0 {
authDB = db[0]
}
return func(c echo.Context) error {
// If blank, no filter is applied.
filter := c.QueryParam("filter")
@@ -36,6 +42,26 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap
return err
}
// Filter models by user's allowlist if auth is enabled
if authDB != nil {
if user := auth.GetUser(c); user != nil && user.Role != auth.RoleAdmin {
perm, err := auth.GetCachedUserPermissions(c, authDB, user.ID)
if err == nil && perm.AllowedModels.Enabled {
allowed := map[string]bool{}
for _, m := range perm.AllowedModels.Models {
allowed[m] = true
}
filtered := make([]string, 0, len(modelNames))
for _, m := range modelNames {
if allowed[m] {
filtered = append(filtered, m)
}
}
modelNames = filtered
}
}
}
// Map from a slice of names to a slice of OpenAIModel response objects
dataModels := []schema.OpenAIModel{}
for _, m := range modelNames {


@@ -12,6 +12,7 @@ import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates"
@@ -879,34 +880,14 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
xlog.Debug("Background MCP re-templating", "iteration", mcpIteration)
}
images := []string{}
videos := []string{}
audios := []string{}
for _, m := range openAIReq.Messages {
images = append(images, m.StringImages...)
videos = append(videos, m.StringVideos...)
audios = append(audios, m.StringAudios...)
}
toolsJSON := serializeToolsForBackend(input.Tools)
toolChoiceJSON := ""
if input.ToolChoice != nil {
toolChoiceBytes, err := json.Marshal(input.ToolChoice)
if err == nil {
toolChoiceJSON = string(toolChoiceBytes)
}
}
var logprobs *int
// Populate openAIReq fields for ComputeChoices
openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools)
openAIReq.ToolsChoice = input.ToolChoice
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
logprobs = input.TopLogprobs
}
predFunc, err := backend.ModelInference(
ctx, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, logprobs, input.TopLogprobs, input.LogitBias, nil)
if err != nil {
return nil, fmt.Errorf("model inference failed: %w", err)
openAIReq.TopLogprobs = input.TopLogprobs
openAIReq.Logprobs = schema.LogprobsValue{Enabled: true}
}
openAIReq.LogitBias = input.LogitBias
select {
case <-ctx.Done():
@@ -914,24 +895,19 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
default:
}
const maxEmptyRetries = 5
var prediction backend.LLMResponse
var result string
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
prediction, err = predFunc()
if err != nil {
return nil, fmt.Errorf("prediction failed: %w", err)
}
result = backend.Finetune(*cfg, predInput, prediction.Response)
if result != "" || !shouldUseFn {
break
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
xlog.Warn("Open Responses background: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
cb := func(s string, c *[]schema.Choice) {
result = s
}
choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
if err != nil {
return nil, fmt.Errorf("model inference failed: %w", err)
}
// Extract logprobs from choices if available
var resultLogprobs *schema.Logprobs
if len(choices) > 0 {
resultLogprobs = choices[0].Logprobs
}
// Parse tool calls
@@ -939,9 +915,9 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
var textContent string
if shouldUseFn {
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
funcCallResults = deltaToolCalls
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
textContent = functions.ContentFromChatDeltas(chatDeltas)
} else {
cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
funcCallResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
@@ -1021,7 +997,7 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
allOutputItems = append(allOutputItems, schema.ORItemField{
Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()),
Status: "completed", Role: "assistant",
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)},
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, resultLogprobs)},
})
}
for _, tc := range toolCalls {
@@ -1034,22 +1010,22 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
allOutputItems = append(allOutputItems, schema.ORItemField{
Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()),
Status: "completed", Role: "assistant",
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)},
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)},
})
}
} else {
allOutputItems = append(allOutputItems, schema.ORItemField{
Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()),
Status: "completed", Role: "assistant",
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)},
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)},
})
}
now := time.Now().Unix()
return buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, allOutputItems, &schema.ORUsage{
InputTokens: prediction.Usage.Prompt,
OutputTokens: prediction.Usage.Completion,
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
InputTokens: tokenUsage.Prompt,
OutputTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
}, true), nil
} // end MCP iteration loop
@@ -1058,23 +1034,14 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
// handleBackgroundStream handles background streaming responses with event buffering
func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) (*schema.ORResponseResource, error) {
images := []string{}
videos := []string{}
audios := []string{}
for _, m := range openAIReq.Messages {
images = append(images, m.StringImages...)
videos = append(videos, m.StringVideos...)
audios = append(audios, m.StringAudios...)
}
toolsJSON := serializeToolsForBackend(input.Tools)
toolChoiceJSON := ""
if input.ToolChoice != nil {
toolChoiceBytes, err := json.Marshal(input.ToolChoice)
if err == nil {
toolChoiceJSON = string(toolChoiceBytes)
}
// Populate openAIReq fields for ComputeChoices
openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools)
openAIReq.ToolsChoice = input.ToolChoice
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
openAIReq.TopLogprobs = input.TopLogprobs
openAIReq.Logprobs = schema.LogprobsValue{Enabled: true}
}
openAIReq.LogitBias = input.LogitBias
sequenceNumber := 0
@@ -1105,20 +1072,13 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
}
hasMCPTools := len(mcpToolInfos) > 0
var prediction backend.LLMResponse
var lastTokenUsage backend.TokenUsage
var lastLogprobs *schema.Logprobs
for mcpIter := 0; mcpIter <= mcpBgStreamMaxIterations; mcpIter++ {
if mcpIter > 0 {
predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn)
xlog.Debug("Background stream MCP re-templating", "iteration", mcpIter)
images = images[:0]
videos = videos[:0]
audios = audios[:0]
for _, m := range openAIReq.Messages {
images = append(images, m.StringImages...)
videos = append(videos, m.StringVideos...)
audios = append(audios, m.StringAudios...)
}
}
accumulatedText = ""
@@ -1177,28 +1137,23 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
return true
}
var streamLogprobs *int
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
streamLogprobs = input.TopLogprobs
var result string
cb := func(s string, c *[]schema.Choice) {
result = s
}
predFunc, err := backend.ModelInference(
ctx, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, streamLogprobs, input.TopLogprobs, input.LogitBias, nil)
choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, tokenCallback)
if err != nil {
return nil, fmt.Errorf("model inference failed: %w", err)
}
prediction, err = predFunc()
if err != nil {
return nil, fmt.Errorf("prediction failed: %w", err)
lastTokenUsage = tokenUsage
if len(choices) > 0 {
lastLogprobs = choices[0].Logprobs
}
result := backend.Finetune(*cfg, predInput, prediction.Response)
// Check for MCP tool calls in the streamed result
if shouldUseFn && hasMCPTools {
var funcCallResults []functions.FuncCallResults
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
funcCallResults = deltaToolCalls
} else {
cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
@@ -1315,7 +1270,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
}
// No MCP tools — close the message and break
streamEventLogprobs := convertLogprobsForStreaming(prediction.Logprobs)
streamEventLogprobs := convertLogprobsForStreaming(lastLogprobs)
bufferEvent(store, responseID, &schema.ORStreamEvent{
Type: "response.output_text.done",
SequenceNumber: sequenceNumber,
@@ -1327,7 +1282,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
})
sequenceNumber++
textPart := makeOutputTextPartWithLogprobs(accumulatedText, prediction.Logprobs)
textPart := makeOutputTextPartWithLogprobs(accumulatedText, lastLogprobs)
bufferEvent(store, responseID, &schema.ORStreamEvent{
Type: "response.content_part.done",
SequenceNumber: sequenceNumber,
@@ -1343,7 +1298,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
ID: currentMessageID,
Status: "completed",
Role: "assistant",
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(accumulatedText, prediction.Logprobs)},
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(accumulatedText, lastLogprobs)},
}
bufferEvent(store, responseID, &schema.ORStreamEvent{
Type: "response.output_item.done",
@@ -1360,9 +1315,9 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
// Build final response
now := time.Now().Unix()
response := buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, collectedOutputItems, &schema.ORUsage{
InputTokens: prediction.Usage.Prompt,
OutputTokens: prediction.Usage.Completion,
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
InputTokens: lastTokenUsage.Prompt,
OutputTokens: lastTokenUsage.Completion,
TotalTokens: lastTokenUsage.Prompt + lastTokenUsage.Completion,
}, true)
// Emit response.completed
@@ -1377,6 +1332,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
// bufferEvent stores an SSE event in the response store for streaming resume
func bufferEvent(store *ResponseStore, responseID string, event *schema.ORStreamEvent) {
normalizeORStreamEvent(event)
if err := store.AppendEvent(responseID, event); err != nil {
xlog.Error("Failed to buffer event", "response_id", responseID, "error", err)
}
@@ -1391,52 +1347,27 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
if mcpIteration > mcpMaxIterations {
return sendOpenResponsesError(c, 500, "server_error", "MCP iteration limit reached", "")
}
images := []string{}
videos := []string{}
audios := []string{}
for _, m := range openAIReq.Messages {
images = append(images, m.StringImages...)
videos = append(videos, m.StringVideos...)
audios = append(audios, m.StringAudios...)
}
// Convert and serialize tools to OpenAI format for the backend
toolsJSON := serializeToolsForBackend(input.Tools)
toolChoiceJSON := ""
if input.ToolChoice != nil {
toolChoiceBytes, err := json.Marshal(input.ToolChoice)
if err == nil {
toolChoiceJSON = string(toolChoiceBytes)
}
}
// Pass logprobs and logit_bias parameters if requested
var logprobs *int
// Populate openAIReq fields for ComputeChoices
openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools)
openAIReq.ToolsChoice = input.ToolChoice
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
logprobs = input.TopLogprobs
openAIReq.TopLogprobs = input.TopLogprobs
openAIReq.Logprobs = schema.LogprobsValue{Enabled: true}
}
openAIReq.LogitBias = input.LogitBias
predFunc, err := backend.ModelInference(
input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, logprobs, input.TopLogprobs, input.LogitBias, nil)
var result string
cb := func(s string, c *[]schema.Choice) {
result = s
}
choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
if err != nil {
xlog.Error("Open Responses model inference failed", "error", err)
return sendOpenResponsesError(c, 500, "model_error", fmt.Sprintf("model inference failed: %v", err), "")
}
const maxEmptyRetries = 5
var prediction backend.LLMResponse
var result string
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
prediction, err = predFunc()
if err != nil {
xlog.Error("Open Responses prediction failed", "error", err)
return sendOpenResponsesError(c, 500, "model_error", fmt.Sprintf("prediction failed: %v", err), "")
}
result = backend.Finetune(*cfg, predInput, prediction.Response)
if result != "" || !shouldUseFn {
break
}
xlog.Warn("Open Responses: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
var resultLogprobs *schema.Logprobs
if len(choices) > 0 {
resultLogprobs = choices[0].Logprobs
}
xlog.Debug("Open Responses - Raw model result", "result", result, "shouldUseFn", shouldUseFn)
@@ -1473,10 +1404,10 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
var textContent string
// Try pre-parsed tool calls from C++ autoparser first
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
xlog.Debug("[ChatDeltas] OpenResponses: using pre-parsed tool calls", "count", len(deltaToolCalls))
funcCallResults = deltaToolCalls
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
textContent = functions.ContentFromChatDeltas(chatDeltas)
} else {
xlog.Debug("[ChatDeltas] OpenResponses: no pre-parsed tool calls, falling back to Go-side text parsing")
// Clean up the result (already extracted reasoning above)
@@ -1574,7 +1505,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
ID: fmt.Sprintf("msg_%s", uuid.New().String()),
Status: "completed",
Role: "assistant",
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)},
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, resultLogprobs)},
})
}
@@ -1605,7 +1536,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
ID: fmt.Sprintf("msg_%s", uuid.New().String()),
Status: "completed",
Role: "assistant",
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, prediction.Logprobs)},
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)},
})
}
} else {
@@ -1615,7 +1546,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
ID: fmt.Sprintf("msg_%s", uuid.New().String()),
Status: "completed",
Role: "assistant",
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, prediction.Logprobs)},
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)},
}
outputItems = append(outputItems, messageItem)
}
@@ -1633,9 +1564,9 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
// Build response with all required fields
now := time.Now().Unix()
response := buildORResponse(responseID, createdAt, &now, "completed", input, outputItems, &schema.ORUsage{
InputTokens: prediction.Usage.Prompt,
OutputTokens: prediction.Usage.Completion,
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
InputTokens: tokenUsage.Prompt,
OutputTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
OutputTokensDetails: &schema.OROutputTokensDetails{
ReasoningTokens: reasoningTokens,
},
@@ -1675,24 +1606,14 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
})
sequenceNumber++
images := []string{}
videos := []string{}
audios := []string{}
for _, m := range openAIReq.Messages {
images = append(images, m.StringImages...)
videos = append(videos, m.StringVideos...)
audios = append(audios, m.StringAudios...)
}
// Convert and serialize tools to OpenAI format for the backend
toolsJSON := serializeToolsForBackend(input.Tools)
toolChoiceJSON := ""
if input.ToolChoice != nil {
toolChoiceBytes, err := json.Marshal(input.ToolChoice)
if err == nil {
toolChoiceJSON = string(toolChoiceBytes)
}
// Populate openAIReq fields for ComputeChoices
openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools)
openAIReq.ToolsChoice = input.ToolChoice
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
openAIReq.TopLogprobs = input.TopLogprobs
openAIReq.Logprobs = schema.LogprobsValue{Enabled: true}
}
openAIReq.LogitBias = input.LogitBias
// Detect if thinking token is already in prompt or template
var template string
@@ -1714,10 +1635,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
// Track reasoning state for streaming
var currentReasoningID string
var currentReasoningContentIndex int
var accumulatedContent string
var lastEmittedReasoning string
var lastEmittedCleanedContent string
var reasoningTokens int
extractor := reason.NewReasoningExtractor(thinkingStartToken, cfg.ReasoningConfig)
// Collect all output items for storage
var collectedOutputItems []schema.ORItemField
@@ -1729,24 +1648,25 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
}
hasMCPToolsStream := len(mcpToolInfos) > 0
var prediction backend.LLMResponse
var result, finalReasoning, finalCleanedResult string
var textContent string
var parsedToolCalls []functions.FuncCallResults
var toolCalls []functions.FuncCallResults
var lastStreamTokenUsage backend.TokenUsage
var lastStreamLogprobs *schema.Logprobs
for mcpStreamIter := 0; mcpStreamIter <= mcpStreamMaxIterations; mcpStreamIter++ {
if mcpStreamIter > 0 {
// Reset reasoning and tool-call state for re-inference so reasoning
// extraction runs again on subsequent iterations
inToolCallMode = false
extractor.Reset()
currentMessageID = ""
lastEmittedToolCallCount = 0
currentReasoningID = ""
predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn)
xlog.Debug("Open Responses stream MCP re-templating", "iteration", mcpStreamIter)
images = images[:0]
videos = videos[:0]
audios = audios[:0]
for _, m := range openAIReq.Messages {
images = append(images, m.StringImages...)
videos = append(videos, m.StringVideos...)
audios = append(audios, m.StringAudios...)
}
}
// For tool calls, we need to track accumulated result and parse incrementally
@@ -1901,11 +1821,10 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
// If no tool calls detected yet, handle reasoning and text
if !inToolCallMode {
accumulatedContent += token
currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, cfg.ReasoningConfig)
reasoningDelta, contentDelta := extractor.ProcessToken(token)
// Handle reasoning item
if currentReasoning != "" {
if extractor.Reasoning() != "" {
// Check if we need to create reasoning item
if currentReasoningID == "" {
outputIndex++
@@ -1937,16 +1856,6 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
sequenceNumber++
}
// Calculate reasoning delta
var reasoningDelta string
if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) {
reasoningDelta = currentReasoning[len(lastEmittedReasoning):]
lastEmittedReasoning = currentReasoning
} else if currentReasoning != lastEmittedReasoning {
reasoningDelta = currentReasoning
lastEmittedReasoning = currentReasoning
}
// Emit reasoning delta if there's new content
if reasoningDelta != "" {
sendSSEEvent(c, &schema.ORStreamEvent{
@@ -1963,23 +1872,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
}
}
// Handle message content (cleaned content without reasoning tags)
var deltaContent string
if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) {
deltaContent = cleanedContent[len(lastEmittedCleanedContent):]
lastEmittedCleanedContent = cleanedContent
} else if cleanedContent != lastEmittedCleanedContent {
if lastEmittedCleanedContent == "" {
deltaContent = cleanedContent
lastEmittedCleanedContent = cleanedContent
} else {
deltaContent = cleanedContent
lastEmittedCleanedContent = cleanedContent
}
}
// Only emit message content if there's actual content (not just reasoning)
if deltaContent != "" {
if contentDelta != "" {
if currentMessageID == "" {
// Emit output_item.added for message
outputIndex++
@@ -2020,7 +1914,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
ItemID: currentMessageID,
OutputIndex: &outputIndex,
ContentIndex: &currentContentIndex,
Delta: strPtr(deltaContent),
Delta: strPtr(contentDelta),
Logprobs: emptyLogprobs(),
})
sequenceNumber++
@@ -2030,14 +1924,11 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
return true
}
// Pass logprobs and logit_bias parameters if requested
var streamLogprobs *int
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
streamLogprobs = input.TopLogprobs
var ccResult string
ccCb := func(s string, c *[]schema.Choice) {
ccResult = s
}
predFunc, err := backend.ModelInference(
input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, streamLogprobs, input.TopLogprobs, input.LogitBias, nil)
choices, ccTokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, ccCb, tokenCallback)
if err != nil {
xlog.Error("Open Responses stream model inference failed", "error", err)
sendSSEEvent(c, &schema.ORStreamEvent{
@@ -2061,36 +1952,27 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
c.Response().Flush()
return nil
}
prediction, err = predFunc()
if err != nil {
xlog.Error("Open Responses stream prediction failed", "error", err)
sendSSEEvent(c, &schema.ORStreamEvent{
Type: "error",
SequenceNumber: sequenceNumber,
Error: &schema.ORErrorPayload{
Type: "model_error",
Message: fmt.Sprintf("prediction failed: %v", err),
},
})
sequenceNumber++
responseFailed := responseCreated
responseFailed.Status = "failed"
sendSSEEvent(c, &schema.ORStreamEvent{
Type: "response.failed",
SequenceNumber: sequenceNumber,
Response: responseFailed,
})
// Send [DONE] even on error
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
return nil
result = ccResult
lastStreamTokenUsage = ccTokenUsage
if len(choices) > 0 {
lastStreamLogprobs = choices[0].Logprobs
}
result = backend.Finetune(*cfg, predInput, prediction.Response)
// Extract reasoning from final result
finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
// Source reasoning from: (1) ChatDeltas from C++ autoparser, (2) extractor's
// streaming state, (3) final extraction from the finetuned result.
if chatDeltaReasoning := functions.ReasoningFromChatDeltas(chatDeltas); chatDeltaReasoning != "" {
finalReasoning = chatDeltaReasoning
finalCleanedResult = functions.ContentFromChatDeltas(chatDeltas)
if finalCleanedResult == "" {
finalCleanedResult = extractor.CleanedContent()
}
} else {
finalReasoning = extractor.Reasoning()
finalCleanedResult = extractor.CleanedContent()
}
if finalReasoning == "" && finalCleanedResult == "" {
finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
}
// Close reasoning item if it exists and wasn't closed yet
if currentReasoningID != "" && finalReasoning != "" {
@@ -2147,10 +2029,10 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
textContent = ""
// Try pre-parsed tool calls from C++ autoparser first
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
xlog.Debug("[ChatDeltas] OpenResponses Stream: using pre-parsed tool calls", "count", len(deltaToolCalls))
parsedToolCalls = deltaToolCalls
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
textContent = functions.ContentFromChatDeltas(chatDeltas)
} else {
xlog.Debug("[ChatDeltas] OpenResponses Stream: no pre-parsed tool calls, falling back to Go-side text parsing")
cleanedResult := functions.CleanupLLMResult(finalCleanedResult, cfg.FunctionsConfig)
@@ -2269,8 +2151,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
}
// Convert prediction logprobs for streaming events
streamEventLogprobs := convertLogprobsForStreaming(prediction.Logprobs)
// Convert logprobs for streaming events
streamEventLogprobs := convertLogprobsForStreaming(lastStreamLogprobs)
// If we have no output but the model did produce something, use the cleaned result (without reasoning tags)
if textContent == "" && len(toolCalls) == 0 && finalCleanedResult != "" {
@@ -2293,7 +2175,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
sequenceNumber++
// Emit content_part.done (with actual logprobs)
textPart := makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)
textPart := makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)
sendSSEEvent(c, &schema.ORStreamEvent{
Type: "response.content_part.done",
SequenceNumber: sequenceNumber,
@@ -2310,7 +2192,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
ID: currentMessageID,
Status: "completed",
Role: "assistant",
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)},
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)},
}
sendSSEEvent(c, &schema.ORStreamEvent{
Type: "response.output_item.done",
@@ -2379,7 +2261,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
ID: currentMessageID,
Status: "completed",
Role: "assistant",
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)},
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)},
})
}
// Add tool call items
@@ -2398,9 +2280,9 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
// Emit response.completed
now := time.Now().Unix()
responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, allOutputItems, &schema.ORUsage{
InputTokens: prediction.Usage.Prompt,
OutputTokens: prediction.Usage.Completion,
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
InputTokens: lastStreamTokenUsage.Prompt,
OutputTokens: lastStreamTokenUsage.Completion,
TotalTokens: lastStreamTokenUsage.Prompt + lastStreamTokenUsage.Completion,
OutputTokensDetails: &schema.OROutputTokensDetails{
ReasoningTokens: reasoningTokens,
},
@@ -2459,12 +2341,10 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
// Stream text deltas with reasoning extraction
tokenCallback := func(token string, tokenUsage backend.TokenUsage) bool {
accumulatedText += token
accumulatedContent += token
// Prepend thinking token if needed, then extract reasoning
currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, cfg.ReasoningConfig)
reasoningDelta, contentDelta := extractor.ProcessToken(token)
// Handle reasoning item
if currentReasoning != "" {
if extractor.Reasoning() != "" {
// Check if we need to create reasoning item
if currentReasoningID == "" {
outputIndex++
@@ -2496,16 +2376,6 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
sequenceNumber++
}
// Calculate reasoning delta
var reasoningDelta string
if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) {
reasoningDelta = currentReasoning[len(lastEmittedReasoning):]
lastEmittedReasoning = currentReasoning
} else if currentReasoning != lastEmittedReasoning {
reasoningDelta = currentReasoning
lastEmittedReasoning = currentReasoning
}
// Emit reasoning delta if there's new content
if reasoningDelta != "" {
sendSSEEvent(c, &schema.ORStreamEvent{
@@ -2522,23 +2392,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
}
}
// Handle message content (cleaned content without reasoning tags)
var deltaContent string
if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) {
deltaContent = cleanedContent[len(lastEmittedCleanedContent):]
lastEmittedCleanedContent = cleanedContent
} else if cleanedContent != lastEmittedCleanedContent {
if lastEmittedCleanedContent == "" {
deltaContent = cleanedContent
lastEmittedCleanedContent = cleanedContent
} else {
deltaContent = cleanedContent
lastEmittedCleanedContent = cleanedContent
}
}
// Only emit message content if there's actual content (not just reasoning)
if deltaContent != "" {
if contentDelta != "" {
// Emit text delta
sendSSEEvent(c, &schema.ORStreamEvent{
Type: "response.output_text.delta",
@@ -2546,7 +2401,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
ItemID: currentMessageID,
OutputIndex: &outputIndex,
ContentIndex: &currentContentIndex,
Delta: strPtr(deltaContent),
Delta: strPtr(contentDelta),
Logprobs: emptyLogprobs(),
})
sequenceNumber++
@@ -2555,14 +2410,11 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
return true
}
// Pass logprobs and logit_bias parameters if requested
var mcpLogprobs *int
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
mcpLogprobs = input.TopLogprobs
var noToolResult string
noToolCb := func(s string, c *[]schema.Choice) {
noToolResult = s
}
predFunc, err := backend.ModelInference(
input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, mcpLogprobs, input.TopLogprobs, input.LogitBias, nil)
noToolChoices, noToolTokenUsage, noToolChatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, noToolCb, tokenCallback)
if err != nil {
xlog.Error("Open Responses stream model inference failed", "error", err)
sendSSEEvent(c, &schema.ORStreamEvent{
@@ -2586,36 +2438,28 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
c.Response().Flush()
return nil
}
prediction, err := predFunc()
if err != nil {
xlog.Error("Open Responses stream prediction failed", "error", err)
sendSSEEvent(c, &schema.ORStreamEvent{
Type: "error",
SequenceNumber: sequenceNumber,
Error: &schema.ORErrorPayload{
Type: "model_error",
Message: fmt.Sprintf("prediction failed: %v", err),
},
})
sequenceNumber++
responseFailed := responseCreated
responseFailed.Status = "failed"
sendSSEEvent(c, &schema.ORStreamEvent{
Type: "response.failed",
SequenceNumber: sequenceNumber,
Response: responseFailed,
})
// Send [DONE] even on error
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
return nil
result := noToolResult
var noToolLogprobs *schema.Logprobs
if len(noToolChoices) > 0 {
noToolLogprobs = noToolChoices[0].Logprobs
}
result := backend.Finetune(*cfg, predInput, prediction.Response)
// Extract reasoning from final result for non-tool-call path
finalReasoning, finalCleanedResult := reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
// Source reasoning from: (1) ChatDeltas from C++ autoparser, (2) extractor's
// streaming state, (3) final extraction from the finetuned result.
var finalReasoning, finalCleanedResult string
if chatDeltaReasoning := functions.ReasoningFromChatDeltas(noToolChatDeltas); chatDeltaReasoning != "" {
finalReasoning = chatDeltaReasoning
finalCleanedResult = functions.ContentFromChatDeltas(noToolChatDeltas)
if finalCleanedResult == "" {
finalCleanedResult = extractor.CleanedContent()
}
} else {
finalReasoning = extractor.Reasoning()
finalCleanedResult = extractor.CleanedContent()
}
if finalReasoning == "" && finalCleanedResult == "" {
finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
}
// Close reasoning item if it exists and wasn't closed yet
if currentReasoningID != "" && finalReasoning != "" {
@@ -2670,8 +2514,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
result = finalCleanedResult
// Convert prediction logprobs for streaming events
mcpStreamLogprobs := convertLogprobsForStreaming(prediction.Logprobs)
// Convert logprobs for streaming events
mcpStreamLogprobs := convertLogprobsForStreaming(noToolLogprobs)
// Emit output_text.done
sendSSEEvent(c, &schema.ORStreamEvent{
@@ -2686,7 +2530,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
sequenceNumber++
// Emit content_part.done (with actual logprobs)
resultPart := makeOutputTextPartWithLogprobs(result, prediction.Logprobs)
resultPart := makeOutputTextPartWithLogprobs(result, noToolLogprobs)
sendSSEEvent(c, &schema.ORStreamEvent{
Type: "response.content_part.done",
SequenceNumber: sequenceNumber,
@@ -2699,7 +2543,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
// Emit output_item.done (with actual logprobs)
messageItem.Status = "completed"
messageItem.Content = []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)}
messageItem.Content = []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, noToolLogprobs)}
sendSSEEvent(c, &schema.ORStreamEvent{
Type: "response.output_item.done",
SequenceNumber: sequenceNumber,
@@ -2734,9 +2578,9 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
finalOutputItems = append(finalOutputItems, *messageItem)
}
responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, finalOutputItems, &schema.ORUsage{
InputTokens: prediction.Usage.Prompt,
OutputTokens: prediction.Usage.Completion,
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
InputTokens: noToolTokenUsage.Prompt,
OutputTokens: noToolTokenUsage.Completion,
TotalTokens: noToolTokenUsage.Prompt + noToolTokenUsage.Completion,
OutputTokensDetails: &schema.OROutputTokensDetails{
ReasoningTokens: reasoningTokens,
},
@@ -2762,6 +2606,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
// sendSSEEvent sends a Server-Sent Event
func sendSSEEvent(c echo.Context, event *schema.ORStreamEvent) {
normalizeORStreamEvent(event)
data, err := json.Marshal(event)
if err != nil {
xlog.Error("Failed to marshal SSE event", "error", err)
@@ -2770,6 +2615,13 @@ func sendSSEEvent(c echo.Context, event *schema.ORStreamEvent) {
fmt.Fprintf(c.Response().Writer, "event: %s\ndata: %s\n\n", event.Type, string(data))
}
// normalizeORStreamEvent ensures required fields like Summary are never null.
func normalizeORStreamEvent(event *schema.ORStreamEvent) {
if event.Item != nil && event.Item.Summary == nil {
event.Item.Summary = []schema.ORContentPart{}
}
}
// getTopLogprobs returns the top_logprobs value, defaulting to 0 if nil
func getTopLogprobs(topLogprobs *int) int {
if topLogprobs != nil {
@@ -2850,6 +2702,13 @@ func buildORResponse(responseID string, createdAt int64, completedAt *int64, sta
outputItems = []schema.ORItemField{}
}
// Ensure Summary is never null on any output item
for i := range outputItems {
if outputItems[i].Summary == nil {
outputItems[i].Summary = []schema.ORContentPart{}
}
}
// Ensure tools is never null - always an array
tools := input.Tools
if tools == nil {
@@ -3025,19 +2884,6 @@ func convertORToolsToOpenAIFormat(orTools []schema.ORFunctionTool) []functions.T
return result
}
// serializeToolsForBackend converts and serializes Open Responses tools to JSON for the backend
func serializeToolsForBackend(orTools []schema.ORFunctionTool) string {
if len(orTools) == 0 {
return ""
}
openAITools := convertORToolsToOpenAIFormat(orTools)
toolsBytes, err := json.Marshal(openAITools)
if err != nil {
return ""
}
return string(toolsBytes)
}
// GetResponseEndpoint returns a handler for GET /responses/:id
// This endpoint is used for polling background responses or resuming streaming
// @Summary Get a response by ID
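
Review note on the `Summary` guards added in `normalizeORStreamEvent` and `buildORResponse`: Go's `encoding/json` encodes a nil slice as `null` and an empty slice as `[]`, which is why the nil check matters for clients that require an array. The sketch below only illustrates that behaviour. `item`, `normalize`, and `marshal` are hypothetical stand-ins, and `[]string` replaces `[]schema.ORContentPart`, whose definition is not shown in this diff.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// item is a hypothetical stand-in for schema.ORItemField; only the
// Summary field is modelled here.
type item struct {
	Summary []string `json:"summary"`
}

// normalize follows the same idea as normalizeORStreamEvent: a nil
// slice is replaced with an empty one so it serializes as [].
func normalize(it *item) {
	if it != nil && it.Summary == nil {
		it.Summary = []string{}
	}
}

// marshal returns the JSON encoding of it, or "" on error.
func marshal(it item) string {
	b, err := json.Marshal(it)
	if err != nil {
		return ""
	}
	return string(b)
}

func main() {
	var it item
	fmt.Println(marshal(it)) // {"summary":null}
	normalize(&it)
	fmt.Println(marshal(it)) // {"summary":[]}
}
```

Normalizing once at the serialization boundary, as `sendSSEEvent` does, avoids having to initialize the field at every construction site.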


@@ -2,15 +2,16 @@ package middleware
import (
"bytes"
"github.com/emirpasic/gods/v2/queues/circularbuffer"
"io"
"net/http"
"sort"
"sync"
"time"
"github.com/emirpasic/gods/v2/queues/circularbuffer"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/xlog"
)
@@ -29,8 +30,12 @@ type APIExchangeResponse struct {
type APIExchange struct {
Timestamp time.Time `json:"timestamp"`
Duration time.Duration `json:"duration"`
Request APIExchangeRequest `json:"request"`
Response APIExchangeResponse `json:"response"`
Error string `json:"error,omitempty"`
UserID string `json:"user_id,omitempty"`
UserName string `json:"user_name,omitempty"`
}
var traceBuffer *circularbuffer.Queue[APIExchange]
@@ -108,13 +113,18 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
}
c.Response().Writer = mw
err = next(c)
if err != nil {
c.Response().Writer = mw.ResponseWriter // Restore original writer if error
return err
handlerErr := next(c)
// Restore original writer unconditionally
c.Response().Writer = mw.ResponseWriter
// Determine response status (use 500 if handler errored and no status was set)
status := c.Response().Status
if status == 0 && handlerErr != nil {
status = http.StatusInternalServerError
}
// Create exchange log
// Create exchange log (always, even on error)
requestHeaders := c.Request().Header.Clone()
requestBody := make([]byte, len(body))
copy(requestBody, body)
@@ -123,6 +133,7 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
copy(responseBody, resBody.Bytes())
exchange := APIExchange{
Timestamp: startTime,
Duration: time.Since(startTime),
Request: APIExchangeRequest{
Method: c.Request().Method,
Path: c.Path(),
@@ -130,11 +141,19 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
Body: &requestBody,
},
Response: APIExchangeResponse{
Status: c.Response().Status,
Status: status,
Headers: &responseHeaders,
Body: &responseBody,
},
}
if handlerErr != nil {
exchange.Error = handlerErr.Error()
}
if user := auth.GetUser(c); user != nil {
exchange.UserID = user.ID
exchange.UserName = user.Name
}
select {
case logChan <- exchange:
@@ -142,7 +161,7 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
xlog.Warn("Trace channel full, dropping trace")
}
return nil
return handlerErr
}
}
}


@@ -0,0 +1,185 @@
package middleware
import (
"bytes"
"encoding/json"
"sync"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/xlog"
"gorm.io/gorm"
)
const (
usageFlushInterval = 5 * time.Second
usageMaxPending = 5000
)
// usageBatcher accumulates usage records and flushes them to the DB periodically.
type usageBatcher struct {
mu sync.Mutex
pending []*auth.UsageRecord
db *gorm.DB
}
func (b *usageBatcher) add(r *auth.UsageRecord) {
b.mu.Lock()
b.pending = append(b.pending, r)
b.mu.Unlock()
}
func (b *usageBatcher) flush() {
b.mu.Lock()
batch := b.pending
b.pending = nil
b.mu.Unlock()
if len(batch) == 0 {
return
}
if err := b.db.Create(&batch).Error; err != nil {
xlog.Error("Failed to flush usage batch", "count", len(batch), "error", err)
// Re-queue failed records ahead of newer ones. If the queue already holds
// usageMaxPending or more records, the failed batch is dropped to limit growth.
b.mu.Lock()
if len(b.pending) < usageMaxPending {
b.pending = append(batch, b.pending...)
}
b.mu.Unlock()
}
}
var batcher *usageBatcher
// InitUsageRecorder starts a background goroutine that periodically flushes
// accumulated usage records to the database.
func InitUsageRecorder(db *gorm.DB) {
if db == nil {
return
}
batcher = &usageBatcher{db: db}
go func() {
ticker := time.NewTicker(usageFlushInterval)
defer ticker.Stop()
for range ticker.C {
batcher.flush()
}
}()
}
// usageResponseBody is the minimal structure we need from the response JSON.
type usageResponseBody struct {
Model string `json:"model"`
Usage *struct {
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
} `json:"usage"`
}
// UsageMiddleware extracts token usage from OpenAI-compatible response JSON
// and records it per-user.
func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if db == nil || batcher == nil {
return next(c)
}
startTime := time.Now()
// Wrap response writer to capture body
resBody := new(bytes.Buffer)
origWriter := c.Response().Writer
mw := &bodyWriter{
ResponseWriter: origWriter,
body: resBody,
}
c.Response().Writer = mw
handlerErr := next(c)
// Restore original writer
c.Response().Writer = origWriter
// Only record on successful responses
if c.Response().Status < 200 || c.Response().Status >= 300 {
return handlerErr
}
// Get authenticated user
user := auth.GetUser(c)
if user == nil {
return handlerErr
}
// Try to parse usage from response
responseBytes := resBody.Bytes()
if len(responseBytes) == 0 {
return handlerErr
}
// Check content type
ct := c.Response().Header().Get("Content-Type")
isJSON := ct == "" || ct == "application/json" || bytes.HasPrefix([]byte(ct), []byte("application/json"))
isSSE := bytes.HasPrefix([]byte(ct), []byte("text/event-stream"))
if !isJSON && !isSSE {
return handlerErr
}
var resp usageResponseBody
if isSSE {
last, ok := lastSSEData(responseBytes)
if !ok {
return handlerErr
}
if err := json.Unmarshal(last, &resp); err != nil {
return handlerErr
}
} else {
if err := json.Unmarshal(responseBytes, &resp); err != nil {
return handlerErr
}
}
if resp.Usage == nil {
return handlerErr
}
record := &auth.UsageRecord{
UserID: user.ID,
UserName: user.Name,
Model: resp.Model,
Endpoint: c.Request().URL.Path,
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
Duration: time.Since(startTime).Milliseconds(),
CreatedAt: startTime,
}
batcher.add(record)
return handlerErr
}
}
}
// lastSSEData returns the payload of the last "data: " line whose content is not "[DONE]".
func lastSSEData(b []byte) ([]byte, bool) {
prefix := []byte("data: ")
var last []byte
for _, line := range bytes.Split(b, []byte("\n")) {
line = bytes.TrimRight(line, "\r")
if bytes.HasPrefix(line, prefix) {
payload := line[len(prefix):]
if !bytes.Equal(payload, []byte("[DONE]")) {
last = payload
}
}
}
return last, last != nil
}
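
Review note on the SSE path of `UsageMiddleware`: usage is taken from the last `data: ` line that is not `[DONE]`, so a stream only yields a record if its final JSON chunk carries a top-level `usage` object. The sketch below exercises that logic on a made-up stream. `usageBody` and `parseUsage` are hypothetical names and the sample payload is an assumption, not captured server output; `lastSSEData` is copied from the diff.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// usageBody mirrors usageResponseBody from the diff.
type usageBody struct {
	Model string `json:"model"`
	Usage *struct {
		PromptTokens     int64 `json:"prompt_tokens"`
		CompletionTokens int64 `json:"completion_tokens"`
		TotalTokens      int64 `json:"total_tokens"`
	} `json:"usage"`
}

// lastSSEData returns the payload of the last "data: " line whose
// content is not "[DONE]".
func lastSSEData(b []byte) ([]byte, bool) {
	prefix := []byte("data: ")
	var last []byte
	for _, line := range bytes.Split(b, []byte("\n")) {
		line = bytes.TrimRight(line, "\r")
		if bytes.HasPrefix(line, prefix) {
			payload := line[len(prefix):]
			if !bytes.Equal(payload, []byte("[DONE]")) {
				last = payload
			}
		}
	}
	return last, last != nil
}

// parseUsage decodes the final chunk and reports whether it has usage.
func parseUsage(stream []byte) (usageBody, bool) {
	var u usageBody
	last, ok := lastSSEData(stream)
	if !ok {
		return u, false
	}
	if err := json.Unmarshal(last, &u); err != nil {
		return u, false
	}
	return u, u.Usage != nil
}

func main() {
	// Hypothetical stream: one content chunk, one usage chunk, then [DONE].
	stream := []byte("data: {\"model\":\"m\"}\n\n" +
		"data: {\"model\":\"m\",\"usage\":{\"prompt_tokens\":3,\"completion_tokens\":5,\"total_tokens\":8}}\n\n" +
		"data: [DONE]\n\n")
	if u, ok := parseUsage(stream); ok {
		fmt.Println(u.Model, u.Usage.TotalTokens)
	}
}
```

If the chunk carrying usage is not the last non-`[DONE]` data line, this approach finds no usage and nothing is recorded.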


@@ -0,0 +1,64 @@
import { test, expect } from '@playwright/test'
test.describe('Backend Logs', () => {
test('model detail page shows title', async ({ page }) => {
await page.goto('/app/backend-logs/mock-model')
await expect(page.locator('.page-title')).toContainText('mock-model')
})
test('no back arrow link on detail page', async ({ page }) => {
await page.goto('/app/backend-logs/mock-model')
await expect(page.locator('a[href="/app/backend-logs"]')).not.toBeVisible()
})
test('filter buttons are visible', async ({ page }) => {
await page.goto('/app/backend-logs/mock-model')
await expect(page.locator('button', { hasText: 'All' })).toBeVisible()
await expect(page.locator('button', { hasText: 'stdout' })).toBeVisible()
await expect(page.locator('button', { hasText: 'stderr' })).toBeVisible()
})
test('filter buttons toggle active state', async ({ page }) => {
await page.goto('/app/backend-logs/mock-model')
const allBtn = page.locator('button', { hasText: 'All' })
const stdoutBtn = page.locator('button', { hasText: 'stdout' })
// All is active by default
await expect(allBtn).toHaveClass(/btn-primary/)
// Click stdout
await stdoutBtn.click()
await expect(stdoutBtn).toHaveClass(/btn-primary/)
await expect(allBtn).not.toHaveClass(/btn-primary/)
})
test('export button is present', async ({ page }) => {
await page.goto('/app/backend-logs/mock-model')
await expect(page.locator('button', { hasText: 'Export' })).toBeVisible()
})
test('auto-scroll checkbox is present', async ({ page }) => {
await page.goto('/app/backend-logs/mock-model')
await expect(page.locator('text=Auto-scroll')).toBeVisible()
})
test('clear button is present', async ({ page }) => {
await page.goto('/app/backend-logs/mock-model')
await expect(page.locator('button', { hasText: 'Clear' })).toBeVisible()
})
test('details toggle button is present and toggles', async ({ page }) => {
await page.goto('/app/backend-logs/mock-model')
// "Text only" button visible by default (details are shown)
const toggleBtn = page.locator('button', { hasText: 'Text only' })
await expect(toggleBtn).toBeVisible()
// Click to hide details
await toggleBtn.click()
// Button label changes to "Show details"
await expect(page.locator('button', { hasText: 'Show details' })).toBeVisible()
})
})


@@ -0,0 +1,29 @@
import { test, expect } from '@playwright/test'
test.describe('Manage Page - Backend Logs Link', () => {
test('models table shows terminal icon for logs', async ({ page }) => {
await page.goto('/app/manage')
// Wait for models to load
await expect(page.locator('.table')).toBeVisible({ timeout: 10_000 })
// Check for terminal icon (backend logs link)
const terminalIcon = page.locator('a[title="Backend logs"] i.fa-terminal')
await expect(terminalIcon.first()).toBeVisible()
})
test('terminal icon links to backend-logs page', async ({ page }) => {
await page.goto('/app/manage')
await expect(page.locator('.table')).toBeVisible({ timeout: 10_000 })
const logsLink = page.locator('a[title="Backend logs"]').first()
await expect(logsLink).toBeVisible()
// Link uses href="#" with onClick for navigation
const href = await logsLink.getAttribute('href')
expect(href).toBe('#')
// Click and verify navigation
await logsLink.click()
await expect(page).toHaveURL(/\/app\/backend-logs\//)
})
})


@@ -0,0 +1,80 @@
import { test, expect } from '@playwright/test'
const MOCK_MODELS_RESPONSE = {
models: [
{ name: 'llama-model', description: 'A llama model', backend: 'llama-cpp', installed: false, tags: ['llm'] },
{ name: 'whisper-model', description: 'A whisper model', backend: 'whisper', installed: true, tags: ['stt'] },
{ name: 'stablediffusion-model', description: 'An image model', backend: 'stablediffusion', installed: false, tags: ['sd'] },
{ name: 'unknown-model', description: 'No backend', backend: '', installed: false, tags: [] },
],
allBackends: ['llama-cpp', 'stablediffusion', 'whisper'],
allTags: ['llm', 'sd', 'stt'],
availableModels: 4,
installedModels: 1,
totalPages: 1,
currentPage: 1,
}
test.describe('Models Gallery - Backend Features', () => {
test.beforeEach(async ({ page }) => {
await page.route('**/api/models*', (route) => {
route.fulfill({
contentType: 'application/json',
body: JSON.stringify(MOCK_MODELS_RESPONSE),
})
})
await page.goto('/app/models')
// Wait for the table to render
await expect(page.locator('th', { hasText: 'Backend' })).toBeVisible({ timeout: 10_000 })
})
test('backend column header is visible', async ({ page }) => {
await expect(page.locator('th', { hasText: 'Backend' })).toBeVisible()
})
test('backend badges shown in table rows', async ({ page }) => {
const table = page.locator('table')
await expect(table.locator('.badge', { hasText: 'llama-cpp' })).toBeVisible()
await expect(table.locator('.badge', { hasText: /^whisper$/ })).toBeVisible()
})
test('backend dropdown is visible', async ({ page }) => {
await expect(page.locator('button', { hasText: 'All Backends' })).toBeVisible()
})
test('clicking backend dropdown opens searchable panel', async ({ page }) => {
await page.locator('button', { hasText: 'All Backends' }).click()
await expect(page.locator('input[placeholder="Search backends..."]')).toBeVisible()
})
test('typing in search filters dropdown options', async ({ page }) => {
await page.locator('button', { hasText: 'All Backends' }).click()
const searchInput = page.locator('input[placeholder="Search backends..."]')
await searchInput.fill('llama')
// llama-cpp option should be visible, whisper should not
const dropdown = page.locator('input[placeholder="Search backends..."]').locator('..').locator('..')
await expect(dropdown.locator('text=llama-cpp')).toBeVisible()
await expect(dropdown.locator('text=whisper')).not.toBeVisible()
})
test('selecting a backend updates the dropdown label', async ({ page }) => {
await page.locator('button', { hasText: 'All Backends' }).click()
// Click the llama-cpp option within the dropdown (not the table badge)
const dropdown = page.locator('input[placeholder="Search backends..."]').locator('..').locator('..')
await dropdown.locator('text=llama-cpp').click()
// The dropdown button should now show the selected backend instead of "All Backends"
await expect(page.locator('button span', { hasText: 'llama-cpp' })).toBeVisible()
})
test('expanded row shows backend in detail', async ({ page }) => {
// Click the first model row to expand it
await page.locator('tr', { hasText: 'llama-model' }).click()
// The detail view should show Backend label and value
const detail = page.locator('td[colspan="8"]')
await expect(detail.locator('text=Backend')).toBeVisible()
await expect(detail.locator('text=llama-cpp')).toBeVisible()
})
})


@@ -0,0 +1,23 @@
import { test, expect } from '@playwright/test'
test.describe('Navigation', () => {
test('/ redirects to /app', async ({ page }) => {
await page.goto('/')
await expect(page).toHaveURL(/\/app/)
})
test('/app shows home page with LocalAI title', async ({ page }) => {
await page.goto('/app')
await expect(page.locator('.sidebar')).toBeVisible()
await expect(page.locator('.home-page')).toBeVisible()
})
test('sidebar traces link navigates to /app/traces', async ({ page }) => {
await page.goto('/app')
const tracesLink = page.locator('a.nav-item[href="/app/traces"]')
await expect(tracesLink).toBeVisible()
await tracesLink.click()
await expect(page).toHaveURL(/\/app\/traces/)
await expect(page.getByRole('heading', { name: 'Traces', exact: true })).toBeVisible()
})
})

Some files were not shown because too many files have changed in this diff.