mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-20 14:46:38 -04:00
Compare commits
48 Commits
v4.0.0
...
feat/fine-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8997ff6042 | ||
|
|
f1223b45b2 | ||
|
|
fa8b1a8673 | ||
|
|
3451dbdccd | ||
|
|
7b8afc9609 | ||
|
|
ae4b758a5a | ||
|
|
9cdbd89c1f | ||
|
|
7d81bf0aa3 | ||
|
|
a6d0e29eba | ||
|
|
6054d2a91b | ||
|
|
aea21951a2 | ||
|
|
bbe9067227 | ||
|
|
9a9da062e1 | ||
|
|
dd1a8b174f | ||
|
|
cfb7641eea | ||
|
|
e832efeb9e | ||
|
|
a42548e9d1 | ||
|
|
8615ce28a8 | ||
|
|
8336efec41 | ||
|
|
8560a1e571 | ||
|
|
29c33e6a6a | ||
|
|
a58475dbef | ||
|
|
8a0edd0809 | ||
|
|
35d509d8e7 | ||
|
|
eef808d921 | ||
|
|
9d9ea5c1a0 | ||
|
|
e21ad5cfaa | ||
|
|
05ab0c0aa2 | ||
|
|
e2b6233570 | ||
|
|
19f995f38f | ||
|
|
ac168bbc60 | ||
|
|
5c5e537b31 | ||
|
|
118bcee196 | ||
|
|
3eabd6d1d0 | ||
|
|
ee96e5e08d | ||
|
|
3d9ccd1ddc | ||
|
|
d8161bfe57 | ||
|
|
5fd42399d4 | ||
|
|
b2030255ca | ||
|
|
9f903ec06e | ||
|
|
4ea461c330 | ||
|
|
042a9b8ef6 | ||
|
|
65f1a4154a | ||
|
|
c6a51289b0 | ||
|
|
87525109f1 | ||
|
|
c596d8a5d9 | ||
|
|
d79ad76e48 | ||
|
|
dde0353432 |
259
.agents/api-endpoints-and-auth.md
Normal file
259
.agents/api-endpoints-and-auth.md
Normal file
@@ -0,0 +1,259 @@
|
||||
# API Endpoints and Authentication
|
||||
|
||||
This guide covers how to add new API endpoints and properly integrate them with the auth/permissions system.
|
||||
|
||||
## Architecture overview
|
||||
|
||||
Authentication and authorization flow through three layers:
|
||||
|
||||
1. **Global auth middleware** (`core/http/auth/middleware.go` → `auth.Middleware`) — applied to every request in `core/http/app.go`. Handles session cookies, Bearer tokens, API keys, and legacy API keys. Populates `auth_user` and `auth_role` in the Echo context.
|
||||
2. **Feature middleware** (`auth.RequireFeature`) — per-feature access control applied to route groups or individual routes. Checks if the authenticated user has the specific feature enabled.
|
||||
3. **Admin middleware** (`auth.RequireAdmin`) — restricts endpoints to admin users only.
|
||||
|
||||
When auth is disabled (no auth DB, no legacy API keys), all middleware becomes pass-through (`auth.NoopMiddleware`).
|
||||
|
||||
## Adding a new API endpoint
|
||||
|
||||
### Step 1: Create the handler
|
||||
|
||||
Write the endpoint handler in the appropriate package under `core/http/endpoints/`. Follow existing patterns:
|
||||
|
||||
```go
|
||||
// core/http/endpoints/localai/my_feature.go
|
||||
func MyFeatureEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Use auth.GetUser(c) to get the authenticated user (may be nil if auth is disabled)
|
||||
user := auth.GetUser(c)
|
||||
|
||||
// Your logic here
|
||||
return c.JSON(http.StatusOK, result)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Register routes
|
||||
|
||||
Add routes in the appropriate file under `core/http/routes/`. The file you use depends on the endpoint category:
|
||||
|
||||
| File | Category |
|
||||
|------|----------|
|
||||
| `routes/openai.go` | OpenAI-compatible API endpoints (`/v1/...`) |
|
||||
| `routes/localai.go` | LocalAI-specific endpoints (`/api/...`, `/models/...`, `/backends/...`) |
|
||||
| `routes/agents.go` | Agent pool endpoints (`/api/agents/...`) |
|
||||
| `routes/auth.go` | Auth endpoints (`/api/auth/...`) |
|
||||
| `routes/ui_api.go` | UI backend API endpoints |
|
||||
|
||||
### Step 3: Apply the right middleware
|
||||
|
||||
Choose the appropriate protection level:
|
||||
|
||||
#### No auth required (public)
|
||||
Exempt paths bypass auth entirely. Add to `isExemptPath()` in `middleware.go` or use the `/api/auth/` prefix (always exempt). Use sparingly — most endpoints should require auth.
|
||||
|
||||
#### Standard auth (any authenticated user)
|
||||
The global middleware already handles this. API paths (`/api/`, `/v1/`, etc.) automatically require authentication when auth is enabled. You don't need to add any extra middleware.
|
||||
|
||||
```go
|
||||
router.GET("/v1/my-endpoint", myHandler) // auth enforced by global middleware
|
||||
```
|
||||
|
||||
#### Admin only
|
||||
Pass `adminMiddleware` to the route. This is set up in `app.go` and passed to `Register*Routes` functions:
|
||||
|
||||
```go
|
||||
// In the Register function signature, accept the middleware:
|
||||
func RegisterMyRoutes(router *echo.Echo, app *application.Application, adminMiddleware echo.MiddlewareFunc) {
|
||||
router.POST("/models/apply", myHandler, adminMiddleware)
|
||||
}
|
||||
```
|
||||
|
||||
#### Feature-gated
|
||||
For endpoints that should be toggleable per-user, use feature middleware. There are two approaches:
|
||||
|
||||
**Approach A: Route-level middleware** (preferred for groups of related endpoints)
|
||||
|
||||
```go
|
||||
// In app.go, create the feature middleware:
|
||||
myFeatureMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMyFeature)
|
||||
|
||||
// Pass it to the route registration function:
|
||||
routes.RegisterMyRoutes(e, app, myFeatureMw)
|
||||
|
||||
// In the routes file, apply to a group:
|
||||
g := e.Group("/api/my-feature", myFeatureMw)
|
||||
g.GET("", listHandler)
|
||||
g.POST("", createHandler)
|
||||
```
|
||||
|
||||
**Approach B: RouteFeatureRegistry** (preferred for individual OpenAI-compatible endpoints)
|
||||
|
||||
Add an entry to `RouteFeatureRegistry` in `core/http/auth/features.go`. The `RequireRouteFeature` global middleware will automatically enforce it:
|
||||
|
||||
```go
|
||||
var RouteFeatureRegistry = []RouteFeature{
|
||||
// ... existing entries ...
|
||||
{"POST", "/v1/my-endpoint", FeatureMyFeature},
|
||||
}
|
||||
```
|
||||
|
||||
## Adding a new feature
|
||||
|
||||
When you need a new toggleable feature (not just a new endpoint under an existing feature):
|
||||
|
||||
### 1. Define the feature constant
|
||||
|
||||
Add to `core/http/auth/permissions.go`:
|
||||
|
||||
```go
|
||||
const (
|
||||
// Add to the appropriate group:
|
||||
// Agent features (default OFF for new users)
|
||||
FeatureMyFeature = "my_feature"
|
||||
|
||||
// OR API features (default ON for new users)
|
||||
FeatureMyFeature = "my_feature"
|
||||
)
|
||||
```
|
||||
|
||||
Then add it to the appropriate slice:
|
||||
|
||||
```go
|
||||
// Default OFF — user must be explicitly granted access:
|
||||
var AgentFeatures = []string{..., FeatureMyFeature}
|
||||
|
||||
// Default ON — user has access unless explicitly revoked:
|
||||
var APIFeatures = []string{..., FeatureMyFeature}
|
||||
```
|
||||
|
||||
### 2. Add feature metadata
|
||||
|
||||
In `core/http/auth/features.go`, add to the appropriate `FeatureMetas` function so the admin UI can display it:
|
||||
|
||||
```go
|
||||
func AgentFeatureMetas() []FeatureMeta {
|
||||
return []FeatureMeta{
|
||||
// ... existing ...
|
||||
{FeatureMyFeature, "My Feature", false}, // false = default OFF
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Wire up the middleware
|
||||
|
||||
In `core/http/app.go`:
|
||||
|
||||
```go
|
||||
myFeatureMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMyFeature)
|
||||
```
|
||||
|
||||
Then pass it to the route registration function.
|
||||
|
||||
### 4. Register route-feature mappings (if applicable)
|
||||
|
||||
If your feature gates standard API endpoints (like `/v1/...`), add entries to `RouteFeatureRegistry` in `features.go` instead of using per-route middleware.
|
||||
|
||||
## Accessing the authenticated user in handlers
|
||||
|
||||
```go
|
||||
import "github.com/mudler/LocalAI/core/http/auth"
|
||||
|
||||
func MyHandler(c echo.Context) error {
|
||||
// Get the user (nil when auth is disabled or unauthenticated)
|
||||
user := auth.GetUser(c)
|
||||
if user == nil {
|
||||
// Handle unauthenticated — or let middleware handle it
|
||||
}
|
||||
|
||||
// Check role
|
||||
if user != nil && user.Role == auth.RoleAdmin {
|
||||
// admin-specific logic
|
||||
}
|
||||
|
||||
// Check feature access programmatically (when you need conditional behavior, not full blocking)
|
||||
if auth.HasFeatureAccess(db, user, auth.FeatureMyFeature) {
|
||||
// feature-specific logic
|
||||
}
|
||||
|
||||
// Check model access
|
||||
if !auth.IsModelAllowed(db, user, modelName) {
|
||||
return c.JSON(http.StatusForbidden, ...)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Middleware composition patterns
|
||||
|
||||
Middleware can be composed at different levels. Here are the patterns used in the codebase:
|
||||
|
||||
### Group-level middleware (agents pattern)
|
||||
```go
|
||||
// All routes in the group share the middleware
|
||||
g := e.Group("/api/agents", poolReadyMw, agentsMw)
|
||||
g.GET("", listHandler)
|
||||
g.POST("", createHandler)
|
||||
```
|
||||
|
||||
### Per-route middleware (localai pattern)
|
||||
```go
|
||||
// Individual routes get middleware as extra arguments
|
||||
router.POST("/models/apply", applyHandler, adminMiddleware)
|
||||
router.GET("/metrics", metricsHandler, adminMiddleware)
|
||||
```
|
||||
|
||||
### Middleware slice (openai pattern)
|
||||
```go
|
||||
// Build a middleware chain for a handler
|
||||
chatMiddleware := []echo.MiddlewareFunc{
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
modelFilterMiddleware,
|
||||
}
|
||||
app.POST("/v1/chat/completions", chatHandler, chatMiddleware...)
|
||||
```
|
||||
|
||||
## Error response format
|
||||
|
||||
Always use `schema.ErrorResponse` for auth/permission errors to stay consistent with the OpenAI-compatible API:
|
||||
|
||||
```go
|
||||
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "feature not enabled for your account",
|
||||
Code: http.StatusForbidden,
|
||||
Type: "authorization_error",
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
Use these HTTP status codes:
|
||||
- `401 Unauthorized` — no valid credentials provided
|
||||
- `403 Forbidden` — authenticated but lacking permission
|
||||
- `429 Too Many Requests` — rate limited (auth endpoints)
|
||||
|
||||
## Usage tracking
|
||||
|
||||
If your endpoint should be tracked for usage (token counts, request counts), add the `usageMiddleware` to its middleware chain. See `core/http/middleware/usage.go` and how it's applied in `routes/openai.go`.
|
||||
|
||||
## Path protection rules
|
||||
|
||||
The global auth middleware classifies paths as API paths or non-API paths:
|
||||
|
||||
- **API paths** (always require auth when auth is enabled): `/api/`, `/v1/`, `/models/`, `/backends/`, `/backend/`, `/tts`, `/vad`, `/video`, `/stores/`, `/system`, `/ws/`, `/metrics`
|
||||
- **Exempt paths** (never require auth): `/api/auth/` prefix, anything in `appConfig.PathWithoutAuth`
|
||||
- **Non-API paths** (UI, static assets): pass through without auth — the React UI handles login redirects client-side
|
||||
|
||||
If you add endpoints under a new top-level path prefix, add it to `isAPIPath()` in `middleware.go` to ensure it requires authentication.
|
||||
|
||||
## Checklist
|
||||
|
||||
When adding a new endpoint:
|
||||
|
||||
- [ ] Handler in `core/http/endpoints/`
|
||||
- [ ] Route registered in appropriate `core/http/routes/` file
|
||||
- [ ] Auth level chosen: public / standard / admin / feature-gated
|
||||
- [ ] If feature-gated: constant in `permissions.go`, metadata in `features.go`, middleware in `app.go`
|
||||
- [ ] If new path prefix: added to `isAPIPath()` in `middleware.go`
|
||||
- [ ] If OpenAI-compatible: entry in `RouteFeatureRegistry`
|
||||
- [ ] If token-counting: `usageMiddleware` added to middleware chain
|
||||
- [ ] Error responses use `schema.ErrorResponse` format
|
||||
- [ ] Tests cover both authenticated and unauthenticated access
|
||||
141
.agents/debugging-backends.md
Normal file
141
.agents/debugging-backends.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# Debugging and Rebuilding Backends
|
||||
|
||||
When a backend fails at runtime (e.g. a gRPC method error, a Python import error, or a dependency conflict), use this guide to diagnose, fix, and rebuild.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
- **Source directory**: `backend/python/<name>/` (or `backend/go/<name>/`, `backend/cpp/<name>/`)
|
||||
- **Installed directory**: `backends/<name>/` — this is what LocalAI actually runs. It is populated by `make backends/<name>` which builds a Docker image, exports it, and installs it via `local-ai backends install`.
|
||||
- **Virtual environment**: `backends/<name>/venv/` — the installed Python venv (for Python backends). The Python binary is at `backends/<name>/venv/bin/python`.
|
||||
|
||||
Editing files in `backend/python/<name>/` does **not** affect the running backend until you rebuild with `make backends/<name>`.
|
||||
|
||||
## Diagnosing Failures
|
||||
|
||||
### 1. Check the logs
|
||||
|
||||
Backend gRPC processes log to LocalAI's stdout/stderr. Look for lines tagged with the backend's model ID:
|
||||
|
||||
```
|
||||
GRPC stderr id="trl-finetune-127.0.0.1:37335" line="..."
|
||||
```
|
||||
|
||||
Common error patterns:
|
||||
- **"Method not implemented"** — the backend is missing a gRPC method that the Go side calls. The model loader (`pkg/model/initializers.go`) always calls `LoadModel` after `Health`; fine-tuning backends must implement it even as a no-op stub.
|
||||
- **Python import errors / `AttributeError`** — usually a dependency version mismatch (e.g. `pyarrow` removing `PyExtensionType`).
|
||||
- **"failed to load backend"** — the gRPC process crashed or never started. Check stderr lines for the traceback.
|
||||
|
||||
### 2. Test the Python environment directly
|
||||
|
||||
You can run the installed venv's Python to check imports without starting the full server:
|
||||
|
||||
```bash
|
||||
backends/<name>/venv/bin/python -c "import datasets; print(datasets.__version__)"
|
||||
```
|
||||
|
||||
If `pip` is missing from the venv, bootstrap it:
|
||||
|
||||
```bash
|
||||
backends/<name>/venv/bin/python -m ensurepip
|
||||
```
|
||||
|
||||
Then use `backends/<name>/venv/bin/python -m pip install ...` to test fixes in the installed venv before committing them to the source requirements.
|
||||
|
||||
### 3. Check upstream dependency constraints
|
||||
|
||||
When you hit a dependency conflict, check what the main library expects. For example, TRL's upstream `requirements.txt`:
|
||||
|
||||
```
|
||||
https://github.com/huggingface/trl/blob/main/requirements.txt
|
||||
```
|
||||
|
||||
Pin minimum versions in the backend's requirements files to match upstream.
|
||||
|
||||
## Common Fixes
|
||||
|
||||
### Missing gRPC methods
|
||||
|
||||
If the Go side calls a method the backend doesn't implement (e.g. `LoadModel`), add a no-op stub in `backend.py`:
|
||||
|
||||
```python
|
||||
def LoadModel(self, request, context):
|
||||
"""No-op — actual loading happens elsewhere."""
|
||||
return backend_pb2.Result(success=True, message="OK")
|
||||
```
|
||||
|
||||
The gRPC contract requires `LoadModel` to succeed for the model loader to return a usable client, even if the backend doesn't need upfront model loading.
|
||||
|
||||
### Dependency version conflicts
|
||||
|
||||
Python backends often break when a transitive dependency releases a breaking change (e.g. `pyarrow` removing `PyExtensionType`). Steps:
|
||||
|
||||
1. Identify the broken import in the logs
|
||||
2. Test in the installed venv: `backends/<name>/venv/bin/python -c "import <module>"`
|
||||
3. Check upstream requirements for version constraints
|
||||
4. Update **all** requirements files in `backend/python/<name>/`:
|
||||
- `requirements.txt` — base deps (grpcio, protobuf)
|
||||
- `requirements-cpu.txt` — CPU-specific (includes PyTorch CPU index)
|
||||
- `requirements-cublas12.txt` — CUDA 12
|
||||
- `requirements-cublas13.txt` — CUDA 13
|
||||
5. Rebuild: `make backends/<name>`
|
||||
|
||||
### PyTorch index conflicts (uv resolver)
|
||||
|
||||
The Docker build uses `uv` for pip installs. When `--extra-index-url` points to the PyTorch wheel index, `uv` may refuse to fetch packages like `requests` from PyPI if it finds a different version on the PyTorch index first. Fix this by adding `--index-strategy=unsafe-first-match` to `install.sh`:
|
||||
|
||||
```bash
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
installRequirements
|
||||
```
|
||||
|
||||
Most Python backends already do this — check `backend/python/transformers/install.sh` or similar for reference.
|
||||
|
||||
## Rebuilding
|
||||
|
||||
### Rebuild a single backend
|
||||
|
||||
```bash
|
||||
make backends/<name>
|
||||
```
|
||||
|
||||
This runs the Docker build (`Dockerfile.python`), exports the image to `backend-images/<name>.tar`, and installs it into `backends/<name>/`. It also rebuilds the `local-ai` Go binary (without extra tags).
|
||||
|
||||
**Important**: If you were previously running with `GO_TAGS=auth`, the `make backends/<name>` step will overwrite your binary without that tag. Rebuild the Go binary afterward:
|
||||
|
||||
```bash
|
||||
GO_TAGS=auth make build
|
||||
```
|
||||
|
||||
### Rebuild and restart
|
||||
|
||||
After rebuilding a backend, you must restart LocalAI for it to pick up the new backend files. The backend gRPC process is spawned on demand when the model is first loaded.
|
||||
|
||||
```bash
|
||||
# Kill existing process
|
||||
kill <pid>
|
||||
|
||||
# Restart
|
||||
./local-ai run --debug [your flags]
|
||||
```
|
||||
|
||||
### Quick iteration (skip Docker rebuild)
|
||||
|
||||
For fast iteration on a Python backend's `backend.py` without a full Docker rebuild, you can edit the installed copy directly:
|
||||
|
||||
```bash
|
||||
# Edit the installed copy
|
||||
vim backends/<name>/backend.py
|
||||
|
||||
# Restart LocalAI to respawn the gRPC process
|
||||
```
|
||||
|
||||
This is useful for testing but **does not persist** — the next `make backends/<name>` will overwrite it. Always commit fixes to the source in `backend/python/<name>/`.
|
||||
|
||||
## Verification
|
||||
|
||||
After fixing and rebuilding:
|
||||
|
||||
1. Start LocalAI and confirm the backend registers: look for `Registering backend name="<name>"` in the logs
|
||||
2. Trigger the operation that failed (e.g. start a fine-tuning job)
|
||||
3. Watch the GRPC stderr/stdout lines for the backend's model ID
|
||||
4. Confirm that no errors or tracebacks appear in the output
|
||||
39
.github/workflows/backend.yml
vendored
39
.github/workflows/backend.yml
vendored
@@ -118,6 +118,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-trl'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'true'
|
||||
backend: "trl"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -366,6 +379,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-trl'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "trl"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -757,6 +783,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-trl'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "trl"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
|
||||
72
.github/workflows/tests-ui-e2e.yml
vendored
Normal file
72
.github/workflows/tests-ui-e2e.yml
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
---
|
||||
name: 'UI E2E Tests'
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'core/http/**'
|
||||
- 'tests/e2e-ui/**'
|
||||
- 'tests/e2e/mock-backend/**'
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
|
||||
concurrency:
|
||||
group: ci-tests-ui-e2e-${{ github.head_ref || github.ref }}-${{ github.repository }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
tests-ui-e2e:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ['1.26.x']
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Setup Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
cache: false
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: '22'
|
||||
- name: Proto Dependencies
|
||||
run: |
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
|
||||
- name: System Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential libopus-dev
|
||||
- name: Build UI test server
|
||||
run: PATH="$PATH:$HOME/go/bin" make build-ui-test-server
|
||||
- name: Install Playwright
|
||||
working-directory: core/http/react-ui
|
||||
run: |
|
||||
npm install
|
||||
npx playwright install --with-deps chromium
|
||||
- name: Run Playwright tests
|
||||
working-directory: core/http/react-ui
|
||||
run: npx playwright test
|
||||
- name: Upload Playwright report
|
||||
if: ${{ failure() }}
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: playwright-report
|
||||
path: core/http/react-ui/playwright-report/
|
||||
retention-days: 7
|
||||
- name: Setup tmate session if tests fail
|
||||
if: ${{ failure() }}
|
||||
uses: mxschmitt/action-tmate@v3.23
|
||||
with:
|
||||
detached: true
|
||||
connect-timeout-seconds: 180
|
||||
limit-access-to-actor: true
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -72,3 +72,8 @@ core/http/react-ui/dist
|
||||
|
||||
# Extracted backend binaries for container-based testing
|
||||
local-backends/
|
||||
|
||||
# UI E2E test artifacts
|
||||
tests/e2e-ui/ui-test-server
|
||||
core/http/react-ui/playwright-report/
|
||||
core/http/react-ui/test-results/
|
||||
|
||||
@@ -11,6 +11,8 @@ This file is an index to detailed topic guides in the `.agents/` directory. Read
|
||||
| [.agents/coding-style.md](.agents/coding-style.md) | Code style, editorconfig, logging, documentation conventions |
|
||||
| [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing |
|
||||
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
||||
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||
|
||||
## Quick Reference
|
||||
|
||||
|
||||
@@ -256,7 +256,7 @@ RUN apt-get update && \
|
||||
|
||||
FROM build-requirements AS builder-base
|
||||
|
||||
ARG GO_TAGS=""
|
||||
ARG GO_TAGS="auth"
|
||||
ARG GRPC_BACKENDS
|
||||
ARG MAKEFLAGS
|
||||
ARG LD_FLAGS="-s -w"
|
||||
|
||||
25
Makefile
25
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -421,6 +421,7 @@ prepare-test-extra: protogen-python
|
||||
$(MAKE) -C backend/python/voxcpm
|
||||
$(MAKE) -C backend/python/whisperx
|
||||
$(MAKE) -C backend/python/ace-step
|
||||
$(MAKE) -C backend/python/trl
|
||||
|
||||
test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/transformers test
|
||||
@@ -440,6 +441,7 @@ test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/voxcpm test
|
||||
$(MAKE) -C backend/python/whisperx test
|
||||
$(MAKE) -C backend/python/ace-step test
|
||||
$(MAKE) -C backend/python/trl test
|
||||
|
||||
DOCKER_IMAGE?=local-ai
|
||||
IMAGE_TYPE?=core
|
||||
@@ -572,6 +574,7 @@ BACKEND_VOXCPM = voxcpm|python|.|false|true
|
||||
BACKEND_WHISPERX = whisperx|python|.|false|true
|
||||
BACKEND_ACE_STEP = ace-step|python|.|false|true
|
||||
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
|
||||
BACKEND_TRL = trl|python|.|false|true
|
||||
|
||||
# Helper function to build docker image for a backend
|
||||
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
||||
@@ -629,12 +632,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
|
||||
|
||||
# Pattern rule for docker-save targets
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
@@ -646,6 +650,23 @@ build-mock-backend: protogen-go
|
||||
clean-mock-backend:
|
||||
rm -f tests/e2e/mock-backend/mock-backend
|
||||
|
||||
########################################################
|
||||
### UI E2E Test Server
|
||||
########################################################
|
||||
|
||||
build-ui-test-server: build-mock-backend react-ui protogen-go
|
||||
$(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui
|
||||
|
||||
test-ui-e2e: build-ui-test-server
|
||||
cd core/http/react-ui && npm install && npx playwright install --with-deps chromium && npx playwright test
|
||||
|
||||
test-ui-e2e-docker:
|
||||
docker build -t localai-ui-e2e -f tests/e2e-ui/Dockerfile .
|
||||
docker run --rm localai-ui-e2e
|
||||
|
||||
clean-ui-test-server:
|
||||
rm -f tests/e2e-ui/ui-test-server
|
||||
|
||||
########################################################
|
||||
### END Backends
|
||||
########################################################
|
||||
|
||||
71
README.md
71
README.md
@@ -62,49 +62,16 @@
|
||||
|
||||
**LocalAI** is the free, Open Source OpenAI alternative. LocalAI acts as a drop-in replacement REST API that's compatible with OpenAI (Elevenlabs, Anthropic...) API specifications for local AI inferencing. It allows you to run LLMs and generate images, audio, and more, locally or on-prem with consumer-grade hardware, supporting multiple model families. It does not require a GPU. It is created and maintained by [Ettore Di Giacinto](https://github.com/mudler).
|
||||
|
||||
<details>
|
||||
<summary><strong>Table of Contents</strong></summary>
|
||||
|
||||
- [Local Stack Family](#local-stack-family)
|
||||
- [Screenshots / Video](#screenshots--video)
|
||||
- [Quickstart](#-quickstart)
|
||||
- [macOS Download](#macos-download)
|
||||
- [Containers (Docker, podman, ...)](#containers-docker-podman-)
|
||||
- [Latest project news](#-latest-project-news)
|
||||
- [Features](#-features)
|
||||
- [Supported Backends & Acceleration](#-supported-backends--acceleration)
|
||||
- [Text Generation & Language Models](#text-generation--language-models)
|
||||
- [Audio & Speech Processing](#audio--speech-processing)
|
||||
- [Image & Video Generation](#image--video-generation)
|
||||
- [Specialized AI Tasks](#specialized-ai-tasks)
|
||||
- [Hardware Acceleration Matrix](#hardware-acceleration-matrix)
|
||||
- [Community and integrations](#-community-and-integrations)
|
||||
- [Resources](#-resources)
|
||||
- [Media, Blogs, Social](#book--media-blogs-social)
|
||||
- [Autonomous Development Team](#-autonomous-development-team)
|
||||
- [Citation](#citation)
|
||||
- [Sponsors](#️-sponsors)
|
||||
- [Individual sponsors](#individual-sponsors)
|
||||
- [Star history](#-star-history)
|
||||
- [License](#-license)
|
||||
- [Acknowledgements](#-acknowledgements)
|
||||
- [Contributors](#-contributors)
|
||||
|
||||
</details>
|
||||
|
||||
## Local Stack Family
|
||||
|
||||
Liking LocalAI? LocalAI is part of an integrated suite of AI infrastructure tools, you might also like:
|
||||
|
||||
- **[LocalAGI](https://github.com/mudler/LocalAGI)** - AI agent orchestration platform with OpenAI Responses API compatibility and advanced agentic capabilities
|
||||
- **[LocalRecall](https://github.com/mudler/LocalRecall)** - MCP/REST API knowledge base system providing persistent memory and storage for AI agents
|
||||
- 🆕 **[Cogito](https://github.com/mudler/cogito)** - Go library for building intelligent, co-operative agentic software and LLM-powered workflows, focusing on improving results for small, open source language models while scaling to any LLM. Powers LocalAGI and LocalAI MCP/Agentic capabilities
|
||||
- 🆕 **[Wiz](https://github.com/mudler/wiz)** - Terminal-based AI agent accessible via Ctrl+Space keybinding. Portable, local-LLM friendly shell assistant with TUI/CLI modes, tool execution with approval, MCP protocol support, and multi-shell compatibility (zsh, bash, fish)
|
||||
- 🆕 **[SkillServer](https://github.com/mudler/skillserver)** - Simple, centralized skills database for AI agents via MCP. Manages skills as Markdown files with MCP server integration, web UI for editing, Git synchronization, and full-text search capabilities
|
||||
|
||||
|
||||
## Screenshots / Video
|
||||
|
||||
### Chat, Model gallery
|
||||
|
||||
https://github.com/user-attachments/assets/08cbb692-57da-48f7-963d-2e7b43883c18
|
||||
|
||||
### Agents
|
||||
|
||||
https://github.com/user-attachments/assets/6270b331-e21d-4087-a540-6290006b381a
|
||||
|
||||
### YouTube video
|
||||
|
||||
<h1 align="center">
|
||||
@@ -113,29 +80,8 @@ Liking LocalAI? LocalAI is part of an integrated suite of AI infrastructure tool
|
||||
<br>
|
||||
</h1>
|
||||
|
||||
|
||||
### Screenshots
|
||||
|
||||
| Talk Interface | Generate Audio |
|
||||
| --- | --- |
|
||||
|  |  |
|
||||
|
||||
| Models Overview | Generate Images |
|
||||
| --- | --- |
|
||||
|  |  |
|
||||
|
||||
| Chat Interface | Home |
|
||||
| --- | --- |
|
||||
|  |  |
|
||||
|
||||
| Login | Swarm |
|
||||
| --- | --- |
|
||||
|  |  |
|
||||
|
||||
## 💻 Quickstart
|
||||
|
||||
|
||||
|
||||
### macOS Download:
|
||||
|
||||
<a href="https://github.com/mudler/LocalAI/releases/latest/download/LocalAI.dmg">
|
||||
@@ -143,6 +89,7 @@ Liking LocalAI? LocalAI is part of an integrated suite of AI infrastructure tool
|
||||
</a>
|
||||
|
||||
> Note: the DMGs are not signed by Apple, so macOS flags them as quarantined. See https://github.com/mudler/LocalAI/issues/6268 for a workaround; the fix is tracked here: https://github.com/mudler/LocalAI/issues/6244
|
||||
> Install the DMG and paste this command into a terminal: `sudo xattr -d com.apple.quarantine /Applications/LocalAI.app`
|
||||
|
||||
### Containers (Docker, podman, ...)
|
||||
|
||||
|
||||
@@ -39,6 +39,13 @@ service Backend {
|
||||
rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {}
|
||||
|
||||
rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
|
||||
|
||||
// Fine-tuning RPCs
|
||||
rpc StartFineTune(FineTuneRequest) returns (FineTuneJobResult) {}
|
||||
rpc FineTuneProgress(FineTuneProgressRequest) returns (stream FineTuneProgressUpdate) {}
|
||||
rpc StopFineTune(FineTuneStopRequest) returns (Result) {}
|
||||
rpc ListCheckpoints(ListCheckpointsRequest) returns (ListCheckpointsResponse) {}
|
||||
rpc ExportModel(ExportModelRequest) returns (Result) {}
|
||||
}
|
||||
|
||||
// Define the empty request
|
||||
@@ -528,3 +535,105 @@ message ModelMetadataResponse {
|
||||
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
|
||||
ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis
|
||||
}
|
||||
|
||||
// Fine-tuning messages
|
||||
|
||||
message FineTuneRequest {
|
||||
// Model identification
|
||||
string model = 1; // HF model name or local path
|
||||
string training_type = 2; // "lora", "loha", "lokr", "full" — what parameters to train
|
||||
string training_method = 3; // "sft", "dpo", "grpo", "rloo", "reward", "kto", "orpo", "network_training"
|
||||
|
||||
// Adapter config (universal across LoRA/LoHa/LoKr for LLM + diffusion)
|
||||
int32 adapter_rank = 10; // LoRA rank (r), default 16
|
||||
int32 adapter_alpha = 11; // scaling factor, default 16
|
||||
float adapter_dropout = 12; // default 0.0
|
||||
repeated string target_modules = 13; // layer names to adapt
|
||||
|
||||
// Universal training hyperparameters
|
||||
float learning_rate = 20; // default 2e-4
|
||||
int32 num_epochs = 21; // default 3
|
||||
int32 batch_size = 22; // default 2
|
||||
int32 gradient_accumulation_steps = 23; // default 4
|
||||
int32 warmup_steps = 24; // default 5
|
||||
int32 max_steps = 25; // 0 = use epochs
|
||||
int32 save_steps = 26; // 0 = only save final
|
||||
float weight_decay = 27; // default 0.01
|
||||
bool gradient_checkpointing = 28;
|
||||
string optimizer = 29; // adamw_8bit, adamw, sgd, adafactor, prodigy
|
||||
int32 seed = 30; // default 3407
|
||||
string mixed_precision = 31; // fp16, bf16, fp8, no
|
||||
|
||||
// Dataset
|
||||
string dataset_source = 40; // HF dataset ID, local file/dir path
|
||||
string dataset_split = 41; // train, test, etc.
|
||||
|
||||
// Output
|
||||
string output_dir = 50;
|
||||
string job_id = 51; // client-assigned or auto-generated
|
||||
|
||||
// Resume training from a checkpoint
|
||||
string resume_from_checkpoint = 55; // path to checkpoint dir to resume from
|
||||
|
||||
// Backend-specific AND method-specific extensibility
|
||||
map<string, string> extra_options = 60;
|
||||
}
|
||||
|
||||
message FineTuneJobResult {
|
||||
string job_id = 1;
|
||||
bool success = 2;
|
||||
string message = 3;
|
||||
}
|
||||
|
||||
message FineTuneProgressRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
message FineTuneProgressUpdate {
|
||||
string job_id = 1;
|
||||
int32 current_step = 2;
|
||||
int32 total_steps = 3;
|
||||
float current_epoch = 4;
|
||||
float total_epochs = 5;
|
||||
float loss = 6;
|
||||
float learning_rate = 7;
|
||||
float grad_norm = 8;
|
||||
float eval_loss = 9;
|
||||
float eta_seconds = 10;
|
||||
float progress_percent = 11;
|
||||
string status = 12; // queued, caching, loading_model, loading_dataset, training, saving, completed, failed, stopped
|
||||
string message = 13;
|
||||
string checkpoint_path = 14; // set when a checkpoint is saved
|
||||
string sample_path = 15; // set when a sample is generated (video/image backends)
|
||||
map<string, float> extra_metrics = 16; // method-specific metrics
|
||||
}
|
||||
|
||||
message FineTuneStopRequest {
|
||||
string job_id = 1;
|
||||
bool save_checkpoint = 2;
|
||||
}
|
||||
|
||||
message ListCheckpointsRequest {
|
||||
string output_dir = 1;
|
||||
}
|
||||
|
||||
message ListCheckpointsResponse {
|
||||
repeated CheckpointInfo checkpoints = 1;
|
||||
}
|
||||
|
||||
message CheckpointInfo {
|
||||
string path = 1;
|
||||
int32 step = 2;
|
||||
float epoch = 3;
|
||||
float loss = 4;
|
||||
string created_at = 5;
|
||||
}
|
||||
|
||||
message ExportModelRequest {
|
||||
string checkpoint_path = 1;
|
||||
string output_path = 2;
|
||||
string export_format = 3; // lora, loha, lokr, merged_16bit, merged_4bit, gguf, diffusers
|
||||
string quantization_method = 4; // for GGUF: q4_k_m, q5_k_m, q8_0, f16, etc.
|
||||
string model = 5; // base model name (for merge operations)
|
||||
map<string, string> extra_options = 6;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=e30f1fdf74ea9238ff562901aa974c75aab6619b
|
||||
LLAMA_VERSION?=5744d7ec430e2f875a393770195fda530560773f
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# acestep.cpp version
|
||||
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
|
||||
ACESTEP_CPP_VERSION?=5aa065445541094cba934299cd498bbb9fa5c434
|
||||
ACESTEP_CPP_VERSION?=ab020a9aefcd364423e0665da12babc6b0c7b507
|
||||
SO_TARGET?=libgoacestepcpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -112,6 +112,7 @@ func TestLoadModel(t *testing.T) {
|
||||
|
||||
resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
||||
ModelFile: mainModelPath,
|
||||
ModelPath: modelDir,
|
||||
Options: []string{
|
||||
"text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
|
||||
"dit_model:acestep-v15-turbo-Q8_0.gguf",
|
||||
@@ -151,6 +152,7 @@ func TestSoundGeneration(t *testing.T) {
|
||||
// Load models
|
||||
loadResp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
||||
ModelFile: mainModelPath,
|
||||
ModelPath: modelDir,
|
||||
Options: []string{
|
||||
"text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
|
||||
"dit_model:acestep-v15-turbo-Q8_0.gguf",
|
||||
|
||||
@@ -24,7 +24,7 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
|
||||
lmModel := opts.ModelFile
|
||||
|
||||
// Use the model path as the base directory for resolving relative paths
|
||||
baseDir := filepath.Dir(lmModel)
|
||||
baseDir := opts.ModelPath
|
||||
|
||||
var textEncoderModel, ditModel, vaeModel string
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=d6dd6d7b555c233bb9bc9f20b4751eb8c9269743
|
||||
STABLEDIFFUSION_GGML_VERSION?=545fac4f3fb0117a4e962b1a04cf933a7e635933
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=30c5194c9691e4e9a98b3dea9f19727397d3f46e
|
||||
WHISPER_CPP_VERSION?=9386f239401074690479731c1e41683fbbeac557
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -3029,3 +3029,54 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-voxtral
|
||||
- &trl
|
||||
name: "trl"
|
||||
alias: "trl"
|
||||
license: apache-2.0
|
||||
description: |
|
||||
HuggingFace TRL fine-tuning backend. Supports SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO training methods.
|
||||
Works on CPU and GPU.
|
||||
urls:
|
||||
- https://github.com/huggingface/trl
|
||||
tags:
|
||||
- fine-tuning
|
||||
- LLM
|
||||
- CPU
|
||||
- GPU
|
||||
- CUDA
|
||||
capabilities:
|
||||
default: "cpu-trl"
|
||||
nvidia: "cuda12-trl"
|
||||
nvidia-cuda-12: "cuda12-trl"
|
||||
nvidia-cuda-13: "cuda13-trl"
|
||||
## TRL backend images
|
||||
- !!merge <<: *trl
|
||||
name: "cpu-trl"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-trl
|
||||
- !!merge <<: *trl
|
||||
name: "cpu-trl-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-trl
|
||||
- !!merge <<: *trl
|
||||
name: "cuda12-trl"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda12-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cublas-cuda12-trl
|
||||
- !!merge <<: *trl
|
||||
name: "cuda12-trl-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda12-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cublas-cuda12-trl
|
||||
- !!merge <<: *trl
|
||||
name: "cuda13-trl"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda13-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cublas-cuda13-trl
|
||||
- !!merge <<: *trl
|
||||
name: "cuda13-trl-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda13-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cublas-cuda13-trl
|
||||
|
||||
26
backend/python/trl/Makefile
Normal file
26
backend/python/trl/Makefile
Normal file
@@ -0,0 +1,26 @@
|
||||
# Version of llama.cpp to fetch convert_hf_to_gguf.py from (for GGUF export)
|
||||
LLAMA_CPP_CONVERT_VERSION ?= master
|
||||
|
||||
.PHONY: trl
|
||||
trl:
|
||||
LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: trl
|
||||
@echo "Running trl..."
|
||||
bash run.sh
|
||||
@echo "trl run."
|
||||
|
||||
.PHONY: test
|
||||
test: trl
|
||||
@echo "Testing trl..."
|
||||
bash test.sh
|
||||
@echo "trl tested."
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
860
backend/python/trl/backend.py
Normal file
860
backend/python/trl/backend.py
Normal file
@@ -0,0 +1,860 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TRL fine-tuning backend for LocalAI.
|
||||
|
||||
Supports all TRL training methods (SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO)
|
||||
using standard HuggingFace transformers + PEFT. Works on both CPU and GPU.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent import futures
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
|
||||
|
||||
|
||||
class ProgressCallback:
    """Bridge between a HuggingFace ``Trainer`` and a progress queue.

    Holds the job identity and the queue; ``get_callback()`` returns a
    ``TrainerCallback`` instance that turns trainer events into
    ``backend_pb2.FineTuneProgressUpdate`` messages and puts them on the queue.

    Args:
        job_id: Identifier copied into every update.
        progress_queue: Object with a ``put()`` method receiving the updates.
        total_epochs: Configured number of epochs, reported as ``total_epochs``.
    """

    def __init__(self, job_id, progress_queue, total_epochs):
        self.job_id = job_id
        self.progress_queue = progress_queue
        self.total_epochs = total_epochs

    def get_callback(self):
        """Create the ``TrainerCallback`` that reports progress for this job.

        Returns:
            A new callback instance to pass in a trainer's ``callbacks`` list.
        """
        # Imported lazily, as in the rest of this module, so that importing the
        # backend does not require transformers to be loaded up front.
        from transformers import TrainerCallback

        parent = self

        # Keys that already have a dedicated field in the update message and
        # therefore must not be duplicated in ``extra_metrics``.
        log_fields = ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss')
        eval_fields = ('eval_loss', 'epoch')

        def _step_progress(state):
            """Return ``(total_steps, percent)``; both are 0 when max_steps is unset."""
            total_steps = state.max_steps if state.max_steps > 0 else 0
            progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
            return total_steps, progress

        def _numeric_extras(values, skip):
            """Collect the numeric entries of ``values`` whose key is not in ``skip``."""
            return {
                k: float(v)
                for k, v in values.items()
                if isinstance(v, (int, float)) and k not in skip
            }

        class _Callback(TrainerCallback):
            def __init__(self):
                self._train_start_time = None
                # Step counter value when this run started. Non-zero when the
                # trainer resumes from a checkpoint (assumes the trainer has
                # restored its state before on_train_begin fires).
                self._train_start_step = 0
                # Number of prediction steps seen in the current eval round.
                self._eval_update_counter = 0

            def on_train_begin(self, args, state, control, **kwargs):
                self._train_start_time = time.time()
                self._train_start_step = state.global_step

            def on_log(self, args, state, control, logs=None, **kwargs):
                """Forward a training log entry, including an ETA estimate."""
                if logs is None:
                    return
                total_steps, progress = _step_progress(state)

                # Base the step rate on steps executed in *this* run only;
                # dividing by the absolute global_step would underestimate the
                # ETA after resuming from a checkpoint.
                eta = 0.0
                steps_done = state.global_step - self._train_start_step
                if steps_done > 0 and total_steps > 0 and self._train_start_time:
                    elapsed = time.time() - self._train_start_time
                    remaining_steps = max(total_steps - state.global_step, 0)
                    eta = remaining_steps * (elapsed / steps_done)

                update = backend_pb2.FineTuneProgressUpdate(
                    job_id=parent.job_id,
                    current_step=state.global_step,
                    total_steps=total_steps,
                    current_epoch=float(logs.get('epoch', 0)),
                    total_epochs=float(parent.total_epochs),
                    loss=float(logs.get('loss', 0)),
                    learning_rate=float(logs.get('learning_rate', 0)),
                    grad_norm=float(logs.get('grad_norm', 0)),
                    eval_loss=float(logs.get('eval_loss', 0)),
                    eta_seconds=float(eta),
                    progress_percent=float(progress),
                    status="training",
                    extra_metrics=_numeric_extras(logs, log_fields),
                )
                parent.progress_queue.put(update)

            def on_prediction_step(self, args, state, control, **kwargs):
                """Send periodic updates during evaluation so the UI doesn't freeze."""
                self._eval_update_counter += 1
                # Throttle: send an update every 10 prediction steps
                if self._eval_update_counter % 10 != 0:
                    return
                total_steps, progress = _step_progress(state)
                update = backend_pb2.FineTuneProgressUpdate(
                    job_id=parent.job_id,
                    current_step=state.global_step,
                    total_steps=total_steps,
                    current_epoch=float(state.epoch or 0),
                    total_epochs=float(parent.total_epochs),
                    progress_percent=float(progress),
                    status="training",
                    message=f"Evaluating... (batch {self._eval_update_counter})",
                )
                parent.progress_queue.put(update)

            def on_evaluate(self, args, state, control, metrics=None, **kwargs):
                """Report eval results once evaluation is done."""
                # Reset prediction counter for next eval round
                self._eval_update_counter = 0

                total_steps, progress = _step_progress(state)

                eval_loss = 0.0
                extra_metrics = {}
                if metrics:
                    eval_loss = float(metrics.get('eval_loss', 0))
                    extra_metrics = _numeric_extras(metrics, eval_fields)

                update = backend_pb2.FineTuneProgressUpdate(
                    job_id=parent.job_id,
                    current_step=state.global_step,
                    total_steps=total_steps,
                    current_epoch=float(state.epoch or 0),
                    total_epochs=float(parent.total_epochs),
                    eval_loss=eval_loss,
                    progress_percent=float(progress),
                    status="training",
                    message=f"Evaluation complete at step {state.global_step}",
                    extra_metrics=extra_metrics,
                )
                parent.progress_queue.put(update)

            def on_save(self, args, state, control, **kwargs):
                """Announce a checkpoint, using the ``checkpoint-<step>`` directory name."""
                checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
                update = backend_pb2.FineTuneProgressUpdate(
                    job_id=parent.job_id,
                    current_step=state.global_step,
                    status="saving",
                    message=f"Checkpoint saved at step {state.global_step}",
                    checkpoint_path=checkpoint_path,
                )
                parent.progress_queue.put(update)

            def on_train_end(self, args, state, control, **kwargs):
                """Emit the final "completed" update at 100%."""
                # Clamp like every other hook so a non-positive max_steps is
                # reported as 0 rather than leaking a negative value.
                total_steps, _ = _step_progress(state)
                update = backend_pb2.FineTuneProgressUpdate(
                    job_id=parent.job_id,
                    current_step=state.global_step,
                    total_steps=total_steps,
                    progress_percent=100.0,
                    status="completed",
                    message="Training completed",
                )
                parent.progress_queue.put(update)

        return _Callback()
|
||||
|
||||
|
||||
class ActiveJob:
    """State holder for the fine-tuning job currently owned by the backend.

    A freshly created job has an empty progress queue, no trainer, thread,
    model, tokenizer or error attached, and is neither completed nor stopped.
    """

    def __init__(self, job_id):
        # Identity and the channel progress updates are delivered through.
        self.job_id = job_id
        self.progress_queue = queue.Queue()

        # Filled in later, once training is set up.
        self.trainer = self.thread = None
        self.model = self.tokenizer = None

        # Outcome tracking.
        self.error = None
        self.completed = self.stopped = False
|
||||
|
||||
|
||||
def _is_gated_repo_error(exc):
|
||||
"""Check if an exception is caused by a gated HuggingFace repo requiring authentication."""
|
||||
try:
|
||||
from huggingface_hub.utils import GatedRepoError
|
||||
if isinstance(exc, GatedRepoError):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
msg = str(exc).lower()
|
||||
if "gated repo" in msg or "access to model" in msg:
|
||||
return True
|
||||
if hasattr(exc, 'response') and hasattr(exc.response, 'status_code'):
|
||||
if exc.response.status_code in (401, 403):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def __init__(self):
|
||||
self.active_job = None
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""Accept LoadModel — actual model loading happens in StartFineTune."""
|
||||
return backend_pb2.Result(success=True, message="OK")
|
||||
|
||||
def StartFineTune(self, request, context):
|
||||
if self.active_job is not None and not self.active_job.completed:
|
||||
return backend_pb2.FineTuneJobResult(
|
||||
job_id="",
|
||||
success=False,
|
||||
message="A fine-tuning job is already running",
|
||||
)
|
||||
|
||||
job_id = request.job_id if request.job_id else str(uuid.uuid4())
|
||||
job = ActiveJob(job_id)
|
||||
self.active_job = job
|
||||
|
||||
# Start training in background thread
|
||||
thread = threading.Thread(target=self._run_training, args=(request, job), daemon=True)
|
||||
job.thread = thread
|
||||
thread.start()
|
||||
|
||||
return backend_pb2.FineTuneJobResult(
|
||||
job_id=job_id,
|
||||
success=True,
|
||||
message="Fine-tuning job started",
|
||||
)
|
||||
|
||||
def _run_training(self, request, job):
|
||||
try:
|
||||
self._do_training(request, job)
|
||||
except Exception as e:
|
||||
if _is_gated_repo_error(e):
|
||||
msg = (f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
|
||||
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
|
||||
else:
|
||||
msg = f"Training failed: {e}"
|
||||
job.error = msg
|
||||
job.completed = True
|
||||
update = backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=job.job_id,
|
||||
status="failed",
|
||||
message=msg,
|
||||
)
|
||||
job.progress_queue.put(update)
|
||||
# Send sentinel
|
||||
job.progress_queue.put(None)
|
||||
|
||||
def _do_training(self, request, job):
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from datasets import load_dataset, Dataset
|
||||
|
||||
extra = dict(request.extra_options)
|
||||
training_method = request.training_method or "sft"
|
||||
training_type = request.training_type or "lora"
|
||||
|
||||
# Send loading status
|
||||
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}",
|
||||
))
|
||||
|
||||
# Determine device and dtype
|
||||
device_map = "auto" if torch.cuda.is_available() else "cpu"
|
||||
dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16
|
||||
|
||||
# HuggingFace token for gated repos (from extra_options or HF_TOKEN env)
|
||||
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
|
||||
|
||||
# Load model
|
||||
model_kwargs = {"device_map": device_map, "torch_dtype": dtype}
|
||||
if hf_token:
|
||||
model_kwargs["token"] = hf_token
|
||||
if extra.get("trust_remote_code", "false").lower() == "true":
|
||||
model_kwargs["trust_remote_code"] = True
|
||||
if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available():
|
||||
from transformers import BitsAndBytesConfig
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
job.model = model
|
||||
job.tokenizer = tokenizer
|
||||
|
||||
# Apply LoRA if requested
|
||||
if training_type == "lora":
|
||||
from peft import LoraConfig, get_peft_model
|
||||
lora_r = request.adapter_rank if request.adapter_rank > 0 else 16
|
||||
lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16
|
||||
lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0
|
||||
|
||||
target_modules = list(request.target_modules) if request.target_modules else None
|
||||
peft_config = LoraConfig(
|
||||
r=lora_r,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
target_modules=target_modules or "all-linear",
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
# Load dataset
|
||||
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=job.job_id, status="loading_dataset", message="Loading dataset",
|
||||
))
|
||||
|
||||
dataset_split = request.dataset_split or "train"
|
||||
if os.path.exists(request.dataset_source):
|
||||
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
|
||||
dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split)
|
||||
elif request.dataset_source.endswith('.csv'):
|
||||
dataset = load_dataset("csv", data_files=request.dataset_source, split=dataset_split)
|
||||
else:
|
||||
dataset = load_dataset(request.dataset_source, split=dataset_split)
|
||||
else:
|
||||
dataset = load_dataset(request.dataset_source, split=dataset_split)
|
||||
|
||||
# Eval dataset setup
|
||||
eval_dataset = None
|
||||
eval_strategy = extra.get("eval_strategy", "steps")
|
||||
eval_steps = int(extra.get("eval_steps", str(request.save_steps if request.save_steps > 0 else 500)))
|
||||
|
||||
if eval_strategy != "no":
|
||||
eval_split = extra.get("eval_split")
|
||||
eval_dataset_source = extra.get("eval_dataset_source")
|
||||
if eval_split:
|
||||
# Load a specific split as eval dataset
|
||||
if os.path.exists(request.dataset_source):
|
||||
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
|
||||
eval_dataset = load_dataset("json", data_files=request.dataset_source, split=eval_split)
|
||||
elif request.dataset_source.endswith('.csv'):
|
||||
eval_dataset = load_dataset("csv", data_files=request.dataset_source, split=eval_split)
|
||||
else:
|
||||
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
|
||||
else:
|
||||
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
|
||||
elif eval_dataset_source:
|
||||
# Load eval dataset from a separate source
|
||||
eval_dataset = load_dataset(eval_dataset_source, split="train")
|
||||
else:
|
||||
# Auto-split from training set
|
||||
eval_split_ratio = float(extra.get("eval_split_ratio", "0.1"))
|
||||
split = dataset.train_test_split(test_size=eval_split_ratio)
|
||||
dataset = split["train"]
|
||||
eval_dataset = split["test"]
|
||||
|
||||
if eval_strategy == "no":
|
||||
eval_dataset = None
|
||||
|
||||
# Training config
|
||||
output_dir = request.output_dir or f"./output-{job.job_id}"
|
||||
num_epochs = request.num_epochs if request.num_epochs > 0 else 3
|
||||
batch_size = request.batch_size if request.batch_size > 0 else 2
|
||||
lr = request.learning_rate if request.learning_rate > 0 else 2e-4
|
||||
grad_accum = request.gradient_accumulation_steps if request.gradient_accumulation_steps > 0 else 4
|
||||
warmup_steps = request.warmup_steps if request.warmup_steps > 0 else 5
|
||||
weight_decay = request.weight_decay if request.weight_decay > 0 else 0.01
|
||||
max_steps = request.max_steps if request.max_steps > 0 else -1
|
||||
save_steps = request.save_steps if request.save_steps > 0 else 500
|
||||
seed = request.seed if request.seed > 0 else 3407
|
||||
optimizer = request.optimizer or "adamw_torch"
|
||||
|
||||
# Checkpoint save controls
|
||||
save_total_limit = int(extra.get("save_total_limit", "0")) or None # 0 = unlimited
|
||||
save_strategy = extra.get("save_strategy", "steps") # steps, epoch, no
|
||||
|
||||
# CPU vs GPU training args (can be overridden via extra_options)
|
||||
use_cpu = not torch.cuda.is_available()
|
||||
common_train_kwargs = {}
|
||||
if use_cpu:
|
||||
common_train_kwargs["use_cpu"] = True
|
||||
common_train_kwargs["fp16"] = False
|
||||
common_train_kwargs["bf16"] = False
|
||||
common_train_kwargs["gradient_checkpointing"] = False
|
||||
else:
|
||||
common_train_kwargs["bf16"] = True
|
||||
common_train_kwargs["gradient_checkpointing"] = request.gradient_checkpointing
|
||||
|
||||
# Allow extra_options to override training kwargs
|
||||
for flag in ("use_cpu", "bf16", "fp16", "gradient_checkpointing"):
|
||||
if flag in extra:
|
||||
common_train_kwargs[flag] = extra[flag].lower() == "true"
|
||||
|
||||
# Create progress callback
|
||||
progress_cb = ProgressCallback(job.job_id, job.progress_queue, num_epochs)
|
||||
|
||||
# Build save kwargs (shared across all methods)
|
||||
_save_kwargs = {}
|
||||
if save_strategy == "steps" and save_steps > 0:
|
||||
_save_kwargs["save_steps"] = save_steps
|
||||
_save_kwargs["save_strategy"] = "steps"
|
||||
elif save_strategy == "epoch":
|
||||
_save_kwargs["save_strategy"] = "epoch"
|
||||
elif save_strategy == "no":
|
||||
_save_kwargs["save_strategy"] = "no"
|
||||
else:
|
||||
_save_kwargs["save_steps"] = save_steps
|
||||
_save_kwargs["save_strategy"] = "steps"
|
||||
if save_total_limit:
|
||||
_save_kwargs["save_total_limit"] = save_total_limit
|
||||
|
||||
# Eval kwargs
|
||||
_eval_kwargs = {}
|
||||
if eval_dataset is not None:
|
||||
_eval_kwargs["eval_strategy"] = eval_strategy
|
||||
_eval_kwargs["eval_steps"] = eval_steps
|
||||
|
||||
# Common training arguments shared by all methods
|
||||
_common_args = dict(
|
||||
output_dir=output_dir,
|
||||
num_train_epochs=num_epochs,
|
||||
per_device_train_batch_size=batch_size,
|
||||
learning_rate=lr,
|
||||
gradient_accumulation_steps=grad_accum,
|
||||
warmup_steps=warmup_steps,
|
||||
weight_decay=weight_decay,
|
||||
max_steps=max_steps,
|
||||
seed=seed,
|
||||
optim=optimizer,
|
||||
logging_steps=1,
|
||||
report_to="none",
|
||||
**_save_kwargs,
|
||||
**common_train_kwargs,
|
||||
**_eval_kwargs,
|
||||
)
|
||||
|
||||
# Select trainer based on training method
|
||||
if training_method == "sft":
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
max_length = int(extra.get("max_seq_length", "512"))
|
||||
packing = extra.get("packing", "false").lower() == "true"
|
||||
|
||||
training_args = SFTConfig(
|
||||
max_length=max_length,
|
||||
packing=packing,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
elif training_method == "dpo":
|
||||
from trl import DPOTrainer, DPOConfig
|
||||
|
||||
beta = float(extra.get("beta", "0.1"))
|
||||
loss_type = extra.get("loss_type", "sigmoid")
|
||||
max_length = int(extra.get("max_length", "512"))
|
||||
|
||||
training_args = DPOConfig(
|
||||
beta=beta,
|
||||
loss_type=loss_type,
|
||||
max_length=max_length,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
elif training_method == "grpo":
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
num_generations = int(extra.get("num_generations", "4"))
|
||||
max_completion_length = int(extra.get("max_completion_length", "256"))
|
||||
|
||||
training_args = GRPOConfig(
|
||||
num_generations=num_generations,
|
||||
max_completion_length=max_completion_length,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
# GRPO requires reward functions passed via extra_options as a JSON list
|
||||
from reward_functions import build_reward_functions
|
||||
|
||||
reward_funcs = []
|
||||
if extra.get("reward_funcs"):
|
||||
reward_funcs = build_reward_functions(extra["reward_funcs"])
|
||||
|
||||
if not reward_funcs:
|
||||
raise ValueError(
|
||||
"GRPO requires at least one reward function. "
|
||||
"Specify reward_functions in the request or "
|
||||
"reward_funcs in extra_options."
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=reward_funcs,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
elif training_method == "orpo":
|
||||
from trl import ORPOTrainer, ORPOConfig
|
||||
|
||||
beta = float(extra.get("beta", "0.1"))
|
||||
max_length = int(extra.get("max_length", "512"))
|
||||
|
||||
training_args = ORPOConfig(
|
||||
beta=beta,
|
||||
max_length=max_length,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
trainer = ORPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
elif training_method == "kto":
|
||||
from trl import KTOTrainer, KTOConfig
|
||||
|
||||
beta = float(extra.get("beta", "0.1"))
|
||||
max_length = int(extra.get("max_length", "512"))
|
||||
|
||||
training_args = KTOConfig(
|
||||
beta=beta,
|
||||
max_length=max_length,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
trainer = KTOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
elif training_method == "rloo":
|
||||
from trl import RLOOTrainer, RLOOConfig
|
||||
|
||||
num_generations = int(extra.get("num_generations", "4"))
|
||||
max_completion_length = int(extra.get("max_completion_length", "256"))
|
||||
|
||||
training_args = RLOOConfig(
|
||||
num_generations=num_generations,
|
||||
max_new_tokens=max_completion_length,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
trainer = RLOOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
elif training_method == "reward":
|
||||
from trl import RewardTrainer, RewardConfig
|
||||
|
||||
max_length = int(extra.get("max_length", "512"))
|
||||
|
||||
training_args = RewardConfig(
|
||||
max_length=max_length,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported training method: {training_method}. "
|
||||
"Supported: sft, dpo, grpo, orpo, kto, rloo, reward")
|
||||
|
||||
job.trainer = trainer
|
||||
|
||||
# Start training
|
||||
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=job.job_id, status="training", message="Training started",
|
||||
))
|
||||
|
||||
resume_ckpt = request.resume_from_checkpoint if request.resume_from_checkpoint else None
|
||||
trainer.train(resume_from_checkpoint=resume_ckpt)
|
||||
|
||||
# Save final model
|
||||
trainer.save_model(output_dir)
|
||||
if tokenizer:
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
job.completed = True
|
||||
# Sentinel to signal stream end
|
||||
job.progress_queue.put(None)
|
||||
|
||||
def FineTuneProgress(self, request, context):
    """Stream progress updates for the active fine-tuning job.

    Yields every item taken from the job's ``progress_queue`` until one of
    the following happens:
      - a ``None`` sentinel is dequeued,
      - an update whose ``status`` is "completed", "failed" or "stopped"
        has been yielded,
      - the queue stays empty for a second while the job is flagged
        completed/stopped or the gRPC context is no longer active.

    If ``request.job_id`` does not match the active job, NOT_FOUND is set on
    the context and nothing is yielded.
    """
    current = self.active_job
    if current is None or current.job_id != request.job_id:
        context.set_code(grpc.StatusCode.NOT_FOUND)
        context.set_details(f"Job {request.job_id} not found")
        return

    terminal_states = ("completed", "failed", "stopped")
    while True:
        try:
            # Poll with a timeout so a finished job or a disconnected
            # client is noticed even when no update is pending.
            item = current.progress_queue.get(timeout=1.0)
        except queue.Empty:
            if current.completed or current.stopped or not context.is_active():
                return
            continue

        if item is None:
            # Sentinel pushed by the training thread: end of stream.
            return

        yield item
        if item.status in terminal_states:
            return
|
||||
|
||||
def StopFineTune(self, request, context):
    """Acknowledge a stop request; always returns a successful ``Result``.

    The request and context are not inspected and no job state is touched.
    """
    # Stopping is handled by killing the process from Go via ShutdownModel.
    return backend_pb2.Result(success=True, message="OK")
|
||||
|
||||
def ListCheckpoints(self, request, context):
    """List the ``checkpoint-*`` directories under ``request.output_dir``.

    Each directory becomes a ``CheckpointInfo`` carrying its path, the step
    parsed from the directory name (0 when it cannot be parsed), the loss and
    epoch read from ``trainer_state.json`` (0.0 when unavailable) and the
    directory mtime formatted as a UTC timestamp.

    Checkpoints are returned in ascending numeric step order. A missing
    output directory yields an empty response rather than an error.
    """
    output_dir = request.output_dir
    if not os.path.isdir(output_dir):
        return backend_pb2.ListCheckpointsResponse(checkpoints=[])

    found = []
    for entry in os.listdir(output_dir):
        if not entry.startswith("checkpoint-"):
            continue
        ckpt_path = os.path.join(output_dir, entry)
        if not os.path.isdir(ckpt_path):
            continue
        try:
            step = int(entry.split("-")[1])
        except (IndexError, ValueError):
            step = 0
        found.append((step, entry, ckpt_path))

    # Sort on the parsed step: a plain name sort would put
    # "checkpoint-1000" ahead of "checkpoint-200".
    found.sort()

    checkpoints = []
    for step, _entry, ckpt_path in found:
        # Metadata is best-effort: an unreadable or malformed state file
        # must not hide the checkpoint itself.
        loss = 0.0
        epoch = 0.0
        state_file = os.path.join(ckpt_path, "trainer_state.json")
        if os.path.exists(state_file):
            try:
                with open(state_file) as f:
                    state = json.load(f)
                history = state.get("log_history") if isinstance(state, dict) else None
                records = [r for r in (history or []) if isinstance(r, dict)]
                if records:
                    epoch = float(records[-1].get("epoch") or 0.0)
                    # The newest record may lack a "loss" key (e.g. an
                    # evaluation entry), so walk back to the most recent
                    # record that has one.
                    loss = next(
                        (float(r["loss"]) for r in reversed(records) if r.get("loss") is not None),
                        0.0,
                    )
            except (OSError, ValueError, TypeError):
                pass

        created_at = time.strftime(
            "%Y-%m-%dT%H:%M:%SZ",
            time.gmtime(os.path.getmtime(ckpt_path)),
        )

        checkpoints.append(backend_pb2.CheckpointInfo(
            path=ckpt_path,
            step=step,
            epoch=epoch,
            loss=loss,
            created_at=created_at,
        ))

    return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints)
|
||||
|
||||
def ExportModel(self, request, context):
    """Export a fine-tuned checkpoint in the format named by ``request.export_format``.

    Supported formats:
      - ``lora`` (default): copy the regular files found in the checkpoint
        directory into ``request.output_path``.
      - ``merged_16bit`` / ``merged_4bit``: load ``request.model``, merge the
        adapter from the checkpoint into it and save model plus tokenizer.
      - ``gguf``: merge as above into a temporary ``_hf_merged`` directory
        and convert it with the ``convert_hf_to_gguf.py`` script expected
        next to this file.

    An HF token is taken from ``extra_options["hf_token"]`` or the
    ``HF_TOKEN`` environment variable. Problems are reported through a
    ``Result`` with ``success=False`` and a descriptive message.
    """
    export_format = request.export_format or "lora"
    output_path = request.output_path
    checkpoint_path = request.checkpoint_path

    # Extract HF token for gated model access
    extra = dict(request.extra_options) if request.extra_options else {}
    hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")

    if not checkpoint_path or not os.path.isdir(checkpoint_path):
        return backend_pb2.Result(success=False, message=f"Checkpoint not found: {checkpoint_path}")

    # An empty path would make os.makedirs raise; report it like any other
    # invalid request instead.
    if not output_path:
        return backend_pb2.Result(success=False, message="Output path required for export")

    try:
        os.makedirs(output_path, exist_ok=True)

        if export_format == "lora":
            # Just copy the adapter files
            import shutil
            for f in os.listdir(checkpoint_path):
                src = os.path.join(checkpoint_path, f)
                dst = os.path.join(output_path, f)
                if os.path.isfile(src):
                    shutil.copy2(src, dst)

        elif export_format in ("merged_16bit", "merged_4bit"):
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer
            from peft import PeftModel

            base_model_name = request.model
            if not base_model_name:
                return backend_pb2.Result(success=False, message="Base model name required for merge export")

            # NOTE(review): "merged_4bit" loads the base model in float32 and
            # no 4-bit quantization is applied in this branch; confirm intent.
            dtype = torch.float16 if export_format == "merged_16bit" else torch.float32
            base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, token=hf_token)
            model = PeftModel.from_pretrained(base_model, checkpoint_path)
            merged = model.merge_and_unload()
            merged.save_pretrained(output_path)

            tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
            tokenizer.save_pretrained(output_path)

        elif export_format == "gguf":
            import torch
            import subprocess
            import shutil
            from transformers import AutoModelForCausalLM, AutoTokenizer
            from peft import PeftModel

            base_model_name = request.model
            if not base_model_name:
                return backend_pb2.Result(success=False, message="Base model name required for GGUF export")

            # Check for the converter before the expensive merge so a
            # missing script fails fast.
            script_dir = os.path.dirname(os.path.abspath(__file__))
            convert_script = os.path.join(script_dir, "convert_hf_to_gguf.py")
            if not os.path.exists(convert_script):
                return backend_pb2.Result(success=False,
                                          message="convert_hf_to_gguf.py not found. Install the GGUF conversion tools.")

            # Step 1: Merge LoRA into base model
            merge_dir = os.path.join(output_path, "_hf_merged")
            os.makedirs(merge_dir, exist_ok=True)
            try:
                base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, token=hf_token)
                model = PeftModel.from_pretrained(base_model, checkpoint_path)
                merged = model.merge_and_unload()
                merged.save_pretrained(merge_dir)

                tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
                tokenizer.save_pretrained(merge_dir)

                # Ensure tokenizer.model (SentencePiece) is present in merge_dir.
                # Gemma models need this file for GGUF conversion to use the
                # SentencePiece path; without it, the script falls back to BPE
                # handling which fails on unrecognized pre-tokenizer hashes.
                sp_model_path = os.path.join(merge_dir, "tokenizer.model")
                if not os.path.exists(sp_model_path):
                    sp_copied = False
                    # Method 1: Load the slow tokenizer which keeps the SP model file
                    try:
                        slow_tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=False, token=hf_token)
                        if hasattr(slow_tok, 'vocab_file') and slow_tok.vocab_file and os.path.exists(slow_tok.vocab_file):
                            shutil.copy2(slow_tok.vocab_file, sp_model_path)
                            sp_copied = True
                            print("Copied tokenizer.model from slow tokenizer cache")
                    except Exception as e:
                        print(f"Slow tokenizer method failed: {e}")
                    # Method 2: Download from HF hub
                    if not sp_copied:
                        try:
                            from huggingface_hub import hf_hub_download
                            cached_sp = hf_hub_download(repo_id=base_model_name, filename="tokenizer.model", token=hf_token)
                            shutil.copy2(cached_sp, sp_model_path)
                            sp_copied = True
                            print("Copied tokenizer.model from HF hub")
                        except Exception as e:
                            print(f"HF hub download method failed: {e}")
                    if not sp_copied:
                        print(f"WARNING: Could not obtain tokenizer.model for {base_model_name}. "
                              "GGUF conversion may fail for SentencePiece models.")

                # Free GPU memory before conversion
                del merged, model, base_model
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # Step 2: Convert to GGUF using convert_hf_to_gguf.py
                quant = request.quantization_method or "auto"
                outtype_map = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"}
                # Quantization names outside the map fall back to f16.
                outtype = outtype_map.get(quant, "f16")

                gguf_filename = f"{os.path.basename(output_path)}-{outtype}.gguf"
                gguf_path = os.path.join(output_path, gguf_filename)

                # Log merge_dir contents for debugging conversion issues
                merge_files = os.listdir(merge_dir) if os.path.isdir(merge_dir) else []
                print(f"Merge dir contents: {merge_files}", flush=True)

                env = os.environ.copy()
                env["NO_LOCAL_GGUF"] = "1"
                cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path]
                conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env)
                if conv_result.returncode != 0:
                    diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}"
                    return backend_pb2.Result(success=False,
                                              message=f"GGUF conversion failed: {diag}")
            finally:
                # Clean up the intermediate merged model on every exit path,
                # so a failed or timed-out conversion does not leave a full
                # copy of the model behind.
                shutil.rmtree(merge_dir, ignore_errors=True)

        else:
            return backend_pb2.Result(success=False, message=f"Unsupported export format: {export_format}")

    except Exception as e:
        if _is_gated_repo_error(e):
            return backend_pb2.Result(success=False,
                                      message=f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
                                              "Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
        return backend_pb2.Result(success=False, message=f"Export failed: {e}")

    return backend_pb2.Result(success=True, message=f"Model exported to {output_path}")
|
||||
|
||||
|
||||
def serve(address):
    """Start the gRPC backend on ``address`` and block until interrupted.

    Registers a ``BackendServicer`` on a thread-pool backed server with
    50 MiB message limits, installs SIGTERM/SIGINT handlers that stop the
    server and exit the process, then sleeps in a loop.
    """
    message_limit = 50 * 1024 * 1024
    worker_pool = futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
    channel_options = [
        ('grpc.max_message_length', message_limit),
        ('grpc.max_send_message_length', message_limit),
        ('grpc.max_receive_message_length', message_limit),
    ]
    server = grpc.server(worker_pool, options=channel_options)

    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print(f"TRL fine-tuning backend listening on {address}", file=sys.stderr, flush=True)

    # Handle graceful shutdown
    def stop(signum, frame):
        server.stop(0)
        sys.exit(0)

    for signo in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signo, stop)

    # Keep the main thread alive; the server runs on worker threads.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: read the listen address from --addr
    # (default localhost:50051) and hand it to serve().
    parser = argparse.ArgumentParser(description="TRL fine-tuning gRPC backend")
    parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
    args = parser.parse_args()
    serve(args.addr)
|
||||
37
backend/python/trl/install.sh
Normal file
37
backend/python/trl/install.sh
Normal file
@@ -0,0 +1,37 @@
|
||||
#!/bin/bash
# Installs the Python requirements for the TRL backend, then fetches the
# llama.cpp GGUF converter script and a matching gguf package.
set -e

backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
installRequirements

# Fetch convert_hf_to_gguf.py and gguf package from the same llama.cpp version
# NOTE(review): EDIR is not set in this script; presumably libbackend.sh provides it.
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
if [ ! -f "${CONVERT_SCRIPT}" ]; then
    echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
    # Download to a temporary name and rename only on success: an interrupted
    # transfer must not leave a truncated script that the -f check above
    # would accept as complete on the next run.
    if curl -L --fail --retry 3 \
        "https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
        -o "${CONVERT_SCRIPT}.tmp"; then
        mv "${CONVERT_SCRIPT}.tmp" "${CONVERT_SCRIPT}"
    else
        rm -f "${CONVERT_SCRIPT}.tmp"
        echo "Warning: Failed to download convert_hf_to_gguf.py. GGUF export will not be available."
    fi
fi

# Install gguf package from the same llama.cpp commit to keep them in sync
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
if [ "x${USE_PIP:-}" == "xtrue" ]; then
    pip install "${GGUF_PIP_SPEC}" || {
        echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
        pip install "gguf>=0.16.0"
    }
else
    uv pip install "${GGUF_PIP_SPEC}" || {
        echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
        uv pip install "gguf>=0.16.0"
    }
fi
|
||||
9
backend/python/trl/requirements-cpu.txt
Normal file
9
backend/python/trl/requirements-cpu.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch==2.10.0
|
||||
trl
|
||||
peft
|
||||
datasets>=3.0.0
|
||||
transformers>=4.56.2
|
||||
accelerate>=1.4.0
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
9
backend/python/trl/requirements-cublas12.txt
Normal file
9
backend/python/trl/requirements-cublas12.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
torch==2.10.0
|
||||
trl
|
||||
peft
|
||||
datasets>=3.0.0
|
||||
transformers>=4.56.2
|
||||
accelerate>=1.4.0
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
bitsandbytes
|
||||
9
backend/python/trl/requirements-cublas13.txt
Normal file
9
backend/python/trl/requirements-cublas13.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
torch==2.10.0
|
||||
trl
|
||||
peft
|
||||
datasets>=3.0.0
|
||||
transformers>=4.56.2
|
||||
accelerate>=1.4.0
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
bitsandbytes
|
||||
3
backend/python/trl/requirements.txt
Normal file
3
backend/python/trl/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
certifi
|
||||
236
backend/python/trl/reward_functions.py
Normal file
236
backend/python/trl/reward_functions.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
Built-in reward functions and inline function compiler for GRPO training.
|
||||
|
||||
All reward functions follow TRL's signature: (completions, **kwargs) -> list[float]
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import math
|
||||
import string
|
||||
import functools
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Built-in reward functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def format_reward(completions, **kwargs):
    """Score 1.0 when a completion has a <think>...</think> block followed by
    at least one non-whitespace character, else 0.0."""
    think_then_answer = re.compile(r"<think>.*?</think>\s*\S", re.DOTALL)
    scores = []
    for text in completions:
        scores.append(0.0 if think_then_answer.search(text) is None else 1.0)
    return scores
|
||||
|
||||
|
||||
def reasoning_accuracy_reward(completions, **kwargs):
    """Compare the text inside <answer>...</answer> with the expected answer.

    Expected answers come from ``kwargs["answer"]``, matched by position;
    positions beyond its end compare against the empty string. The
    comparison ignores case and surrounding whitespace. Returns 1.0 for a
    match and 0.0 otherwise; all zeros when no answers were supplied.
    """
    expected_answers = kwargs.get("answer", [])
    if not expected_answers:
        return [0.0] * len(completions)

    answer_tag = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

    def _score(index, text):
        target = expected_answers[index] if index < len(expected_answers) else ""
        found = answer_tag.search(text)
        if not found:
            return 0.0
        given = found.group(1).strip()
        return 1.0 if given.lower() == str(target).strip().lower() else 0.0

    return [_score(index, text) for index, text in enumerate(completions)]
|
||||
|
||||
|
||||
def length_reward(completions, target_length=200, **kwargs):
    """Score each completion by how close its length is to ``target_length``.

    The score is ``1 - |len - target| / target`` clamped at 0.0, so it lies
    in [0, 1]. A non-positive target scores every completion 0.0.
    """
    def _closeness(text):
        size = len(text)
        if target_length <= 0:
            return 0.0
        relative_gap = abs(size - target_length) / target_length
        return max(0.0, 1.0 - relative_gap)

    return [_closeness(text) for text in completions]
|
||||
|
||||
|
||||
def xml_tag_reward(completions, **kwargs):
    """Award 0.5 for each of the <think> and <answer> tags that appears with
    both its opening and closing form; the total is capped at 1.0."""
    required = ("think", "answer")

    def _score(text):
        earned = sum(
            (0.5 for name in required if f"<{name}>" in text and f"</{name}>" in text),
            0.0,
        )
        return min(earned, 1.0)

    return [_score(text) for text in completions]
|
||||
|
||||
|
||||
def no_repetition_reward(completions, n=4, **kwargs):
    """Score each completion by the share of distinct word n-grams.

    Words are whitespace-separated. The score is unique n-grams divided by
    total n-grams, so repetition lowers it; completions with fewer than
    ``n`` words score 1.0.
    """
    def _distinct_ratio(text):
        tokens = text.split()
        if len(tokens) < n:
            return 1.0
        windows = [tuple(tokens[start:start + n]) for start in range(len(tokens) - n + 1)]
        window_count = len(windows)
        distinct_count = len(set(windows))
        if window_count > 0:
            return distinct_count / window_count
        return 1.0

    return [_distinct_ratio(text) for text in completions]
|
||||
|
||||
|
||||
def code_execution_reward(completions, **kwargs):
    """Check that the first ```python fenced block in each completion compiles.

    The code is only passed to ``compile()``; it is never executed. Returns
    1.0 when the block compiles and 0.0 when there is no block or the block
    cannot be compiled.
    """
    pattern = re.compile(r"```python\s*\n(.*?)```", re.DOTALL)
    scores = []
    for c in completions:
        match = pattern.search(c)
        if not match:
            scores.append(0.0)
            continue
        code = match.group(1)
        try:
            compile(code, "<inline>", "exec")
        except (SyntaxError, ValueError, OverflowError, RecursionError):
            # Completions are model output, so any way compile() can reject
            # the text (null bytes, oversized literals, excessive nesting)
            # counts as invalid code rather than aborting the training step.
            scores.append(0.0)
        else:
            scores.append(1.0)
    return scores
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Maps the names accepted in "builtin" reward specs to their implementations;
# build_reward_functions resolves spec names against this table.
BUILTIN_REGISTRY = {
    "format_reward": format_reward,
    "reasoning_accuracy_reward": reasoning_accuracy_reward,
    "length_reward": length_reward,
    "xml_tag_reward": xml_tag_reward,
    "no_repetition_reward": no_repetition_reward,
    "code_execution_reward": code_execution_reward,
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inline function compiler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Builtins made available to inline reward code: compile_inline_reward installs
# this dict as ``__builtins__`` of the namespace the user code is exec'd in.
# NOTE(review): narrowing ``__builtins__`` limits which names resolve, but it
# does not restrict attribute access on the objects that are exposed, so this
# should not be relied on as a security sandbox -- confirm the threat model.
_SAFE_BUILTINS = {
    "len": len, "int": int, "float": float, "str": str, "bool": bool,
    "list": list, "dict": dict, "tuple": tuple, "set": set,
    "range": range, "enumerate": enumerate, "zip": zip,
    "map": map, "filter": filter, "sorted": sorted,
    "min": min, "max": max, "sum": sum, "abs": abs, "round": round,
    "any": any, "all": all, "isinstance": isinstance, "type": type,
    "print": print, "True": True, "False": False, "None": None,
    "ValueError": ValueError, "TypeError": TypeError,
    "KeyError": KeyError, "IndexError": IndexError,
}
|
||||
|
||||
|
||||
def compile_inline_reward(name, code):
    """Compile user-provided code into a reward function.

    The code should be the body of a function that receives
    `completions` (list[str]) and `**kwargs`, and returns list[float].

    Available modules: re, math, json, string.

    Args:
        name: Label for the function. Characters that are not ASCII letters,
            digits or underscores are replaced with "_" when building the
            generated function's identifier, so any label is accepted.
        code: Source text of the function body.

    Returns:
        The compiled callable.

    Raises:
        ValueError: if the code has a syntax error, or if a smoke-test call
            returns something other than a list.
    """
    # The name is spliced into generated source, so it must be reduced to a
    # valid identifier suffix; otherwise a label such as "my-reward" produces
    # a syntax error that has nothing to do with the user's code.
    safe_name = re.sub(r"\W", "_", str(name), flags=re.ASCII)
    func_name = f"_user_reward_{safe_name}"
    body = "\n".join(f"    {line}" for line in code.splitlines())
    func_source = f"def {func_name}(completions, **kwargs):\n{body}"

    restricted_globals = {
        "__builtins__": _SAFE_BUILTINS,
        "re": re,
        "math": math,
        "json": json,
        "string": string,
    }

    try:
        compiled = compile(func_source, f"<inline-reward-{safe_name}>", "exec")
    except SyntaxError as e:
        raise ValueError(f"Syntax error in inline reward function '{name}': {e}") from e

    # Executing the module-level code only defines the function.
    exec(compiled, restricted_globals)
    func = restricted_globals[func_name]

    # Validate with a quick smoke test
    try:
        result = func(["test"], answer=["test"])
    except Exception:
        # Errors during the smoke test are acceptable (e.g. missing kwargs);
        # only the return type of a successful call is enforced.
        return func

    if not isinstance(result, list):
        raise ValueError(
            f"Inline reward function '{name}' must return a list, got {type(result).__name__}"
        )
    return func
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispatcher
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_reward_functions(specs_json):
    """Parse a JSON list of reward function specs and return a list of callables.

    Each spec is a dict with:
    - type: "builtin" or "inline" (defaults to "builtin")
    - name: function name
    - code: (inline only) Python function body
    - params: (optional, builtin only) dict of params applied via
      functools.partial. String values are converted to int, then float,
      when they parse as such; values of any other type are passed through
      unchanged.

    Args:
        specs_json: a JSON string or an already-decoded list of specs.

    Raises:
        ValueError: if the input is not a list, a spec is not a dict, the
            type or builtin name is unknown, or an inline spec has no code.
    """
    def _coerce(value):
        # Only strings are converted: running int() over an already-typed
        # JSON number would truncate 2.5 to 2 and turn true into 1.
        if not isinstance(value, str):
            return value
        for cast in (int, float):
            try:
                return cast(value)
            except ValueError:
                continue
        return value

    specs = json.loads(specs_json) if isinstance(specs_json, str) else specs_json

    if not isinstance(specs, list):
        raise ValueError("reward_funcs must be a JSON array of reward function specs")

    reward_funcs = []
    for spec in specs:
        if not isinstance(spec, dict):
            raise ValueError("Each reward function spec must be a JSON object")

        spec_type = spec.get("type", "builtin")
        name = spec.get("name", "")
        params = spec.get("params") or {}

        if spec_type == "builtin":
            if name not in BUILTIN_REGISTRY:
                available = ", ".join(sorted(BUILTIN_REGISTRY.keys()))
                raise ValueError(
                    f"Unknown builtin reward function '{name}'. Available: {available}"
                )
            func = BUILTIN_REGISTRY[name]
            if params:
                typed_params = {key: _coerce(value) for key, value in params.items()}
                func = functools.partial(func, **typed_params)
            reward_funcs.append(func)

        elif spec_type == "inline":
            code = spec.get("code", "")
            if not code.strip():
                raise ValueError(f"Inline reward function '{name}' has no code")
            reward_funcs.append(compile_inline_reward(name, code))

        else:
            raise ValueError(f"Unknown reward function type '{spec_type}'. Use 'builtin' or 'inline'")

    return reward_funcs
|
||||
10
backend/python/trl/run.sh
Normal file
10
backend/python/trl/run.sh
Normal file
@@ -0,0 +1,10 @@
|
||||
#!/bin/bash
# Launches the TRL backend through the shared libbackend.sh helper.

backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# Quote "$@" so arguments containing spaces reach the backend unsplit.
startBackend "$@"
|
||||
58
backend/python/trl/test.py
Normal file
58
backend/python/trl/test.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Test script for the TRL fine-tuning gRPC backend.
|
||||
"""
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
    """Tests for the TRL fine-tuning gRPC service.

    unittest calls setUp/tearDown around every test, so each test gets one
    freshly started backend process. The tests must not call them again:
    doing so spawns a second server on the same port and loses the handle
    of the first one, which is then never killed.
    """

    def setUp(self):
        """Start the backend as a subprocess and give it time to come up."""
        self.service = subprocess.Popen(
            ["python3", "backend.py", "--addr", "localhost:50051"]
        )
        time.sleep(10)

    def tearDown(self):
        """Kill the backend subprocess and reap it."""
        self.service.kill()
        self.service.wait()

    def test_server_startup(self):
        """Test that the server starts and responds to health checks."""
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
        except grpc.RpcError as err:
            # Only transport errors mean the server is down; assertion
            # failures below keep their own message.
            self.fail(f"Server failed to start: {err}")
        self.assertEqual(response.message, b'OK')

    def test_list_checkpoints_empty(self):
        """Test listing checkpoints on a non-existent directory."""
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.ListCheckpoints(
                    backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent")
                )
        except grpc.RpcError as err:
            self.fail(f"ListCheckpoints service failed: {err}")
        self.assertEqual(len(response.checkpoints), 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
11
backend/python/trl/test.sh
Normal file
11
backend/python/trl/test.sh
Normal file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
# Runs the TRL backend unit tests through the shared libbackend.sh helper.
set -e

# Quote path expansions so a checkout path containing spaces still works.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

runUnittests
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Application struct {
|
||||
@@ -22,6 +23,7 @@ type Application struct {
|
||||
galleryService *services.GalleryService
|
||||
agentJobService *services.AgentJobService
|
||||
agentPoolService atomic.Pointer[services.AgentPoolService]
|
||||
authDB *gorm.DB
|
||||
watchdogMutex sync.Mutex
|
||||
watchdogStop chan bool
|
||||
p2pMutex sync.Mutex
|
||||
@@ -74,6 +76,11 @@ func (a *Application) AgentPoolService() *services.AgentPoolService {
|
||||
return a.agentPoolService.Load()
|
||||
}
|
||||
|
||||
// AuthDB returns the auth database connection, or nil if auth is not enabled.
// Callers must nil-check the result before using it.
func (a *Application) AuthDB() *gorm.DB {
	return a.authDB
}
|
||||
|
||||
// StartupConfig returns the original startup configuration (from env vars, before file loading)
|
||||
func (a *Application) StartupConfig() *config.ApplicationConfig {
|
||||
return a.startupConfig
|
||||
@@ -118,9 +125,23 @@ func (a *Application) StartAgentPool() {
|
||||
xlog.Error("Failed to create agent pool service", "error", err)
|
||||
return
|
||||
}
|
||||
if a.authDB != nil {
|
||||
aps.SetAuthDB(a.authDB)
|
||||
}
|
||||
if err := aps.Start(a.applicationConfig.Context); err != nil {
|
||||
xlog.Error("Failed to start agent pool", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Wire per-user scoped services so collections, skills, and jobs are isolated per user
|
||||
usm := services.NewUserServicesManager(
|
||||
aps.UserStorage(),
|
||||
a.applicationConfig,
|
||||
a.modelLoader,
|
||||
a.backendLoader,
|
||||
a.templatesEvaluator,
|
||||
)
|
||||
aps.SetUserServicesManager(usm)
|
||||
|
||||
a.agentPoolService.Store(aps)
|
||||
}
|
||||
|
||||
@@ -207,7 +207,7 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
|
||||
envF16 := appConfig.F16 == startupAppConfig.F16
|
||||
envDebug := appConfig.Debug == startupAppConfig.Debug
|
||||
envCORS := appConfig.CORS == startupAppConfig.CORS
|
||||
envCSRF := appConfig.CSRF == startupAppConfig.CSRF
|
||||
envCSRF := appConfig.DisableCSRF == startupAppConfig.DisableCSRF
|
||||
envCORSAllowOrigins := appConfig.CORSAllowOrigins == startupAppConfig.CORSAllowOrigins
|
||||
envP2PToken := appConfig.P2PToken == startupAppConfig.P2PToken
|
||||
envP2PNetworkID := appConfig.P2PNetworkID == startupAppConfig.P2PNetworkID
|
||||
@@ -313,7 +313,7 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
|
||||
appConfig.CORS = *settings.CORS
|
||||
}
|
||||
if settings.CSRF != nil && !envCSRF {
|
||||
appConfig.CSRF = *settings.CSRF
|
||||
appConfig.DisableCSRF = *settings.CSRF
|
||||
}
|
||||
if settings.CORSAllowOrigins != nil && !envCORSAllowOrigins {
|
||||
appConfig.CORSAllowOrigins = *settings.CORSAllowOrigins
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -10,6 +12,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
coreStartup "github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
@@ -81,6 +84,45 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize auth database if auth is enabled
|
||||
if options.Auth.Enabled {
|
||||
// Auto-generate HMAC secret if not provided
|
||||
if options.Auth.APIKeyHMACSecret == "" {
|
||||
secretFile := filepath.Join(options.DataPath, ".hmac_secret")
|
||||
secret, err := loadOrGenerateHMACSecret(secretFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize HMAC secret: %w", err)
|
||||
}
|
||||
options.Auth.APIKeyHMACSecret = secret
|
||||
}
|
||||
|
||||
authDB, err := auth.InitDB(options.Auth.DatabaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize auth database: %w", err)
|
||||
}
|
||||
application.authDB = authDB
|
||||
xlog.Info("Auth enabled", "database", options.Auth.DatabaseURL)
|
||||
|
||||
// Start session and expired API key cleanup goroutine
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-options.Context.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := auth.CleanExpiredSessions(authDB); err != nil {
|
||||
xlog.Error("failed to clean expired sessions", "error", err)
|
||||
}
|
||||
if err := auth.CleanExpiredAPIKeys(authDB); err != nil {
|
||||
xlog.Error("failed to clean expired API keys", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
|
||||
xlog.Error("error installing models", "error", err)
|
||||
}
|
||||
@@ -136,6 +178,8 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
loadRuntimeSettingsFromFile(options)
|
||||
}
|
||||
|
||||
application.ModelLoader().SetBackendLoggingEnabled(options.EnableBackendLogging)
|
||||
|
||||
// turn off any process that was started by GRPC if the context is canceled
|
||||
go func() {
|
||||
<-options.Context.Done()
|
||||
@@ -382,6 +426,12 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
}
|
||||
}
|
||||
|
||||
if settings.EnableBackendLogging != nil {
|
||||
if !options.EnableBackendLogging {
|
||||
options.EnableBackendLogging = *settings.EnableBackendLogging
|
||||
}
|
||||
}
|
||||
|
||||
xlog.Debug("Runtime settings loaded from runtime_settings.json")
|
||||
}
|
||||
|
||||
@@ -426,6 +476,31 @@ func initializeWatchdog(application *Application, options *config.ApplicationCon
|
||||
}
|
||||
}
|
||||
|
||||
// loadOrGenerateHMACSecret loads an HMAC secret from the given file path,
|
||||
// or generates a random 32-byte secret and persists it if the file doesn't exist.
|
||||
func loadOrGenerateHMACSecret(path string) (string, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err == nil {
|
||||
secret := string(data)
|
||||
if len(secret) >= 32 {
|
||||
return secret, nil
|
||||
}
|
||||
}
|
||||
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("failed to generate HMAC secret: %w", err)
|
||||
}
|
||||
secret := hex.EncodeToString(b)
|
||||
|
||||
if err := os.WriteFile(path, []byte(secret), 0600); err != nil {
|
||||
return "", fmt.Errorf("failed to persist HMAC secret: %w", err)
|
||||
}
|
||||
|
||||
xlog.Info("Generated new HMAC secret for API key hashing", "path", path)
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
// migrateDataFiles moves persistent data files from the old config directory
|
||||
// to the new data directory. Only moves files that exist in src but not in dst.
|
||||
func migrateDataFiles(srcDir, dstDir string) {
|
||||
|
||||
@@ -3,8 +3,10 @@ package backend
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
@@ -18,6 +20,7 @@ func Detection(
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
detectionModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -25,9 +28,35 @@ func Detection(
|
||||
return nil, fmt.Errorf("could not load detection model")
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{
|
||||
Src: sourceFile,
|
||||
})
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Type: trace.BackendTraceDetection,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(sourceFile, 200),
|
||||
Error: errStr,
|
||||
Data: map[string]any{
|
||||
"source_file": sourceFile,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConf
|
||||
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ func ImageGeneration(height, width, step, seed int, positive_prompt, negative_pr
|
||||
opts...,
|
||||
)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,10 @@ type TokenUsage struct {
|
||||
TimingTokenGeneration float64
|
||||
}
|
||||
|
||||
// ModelInferenceFunc is a test-friendly indirection to call model inference logic.
|
||||
// Tests can override this variable to provide a stub implementation.
|
||||
var ModelInferenceFunc = ModelInference
|
||||
|
||||
func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64, metadata map[string]string) (func() (LLMResponse, error), error) {
|
||||
modelFile := c.Model
|
||||
|
||||
@@ -61,6 +65,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
opts := ModelOptions(*c, o)
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(o, c.Name, c.Backend, err, map[string]any{"model_file": modelFile})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -4,13 +4,33 @@ import (
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// recordModelLoadFailure records a backend trace when model loading fails.
|
||||
func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, backend string, err error, data map[string]any) {
|
||||
if !appConfig.EnableTracing {
|
||||
return
|
||||
}
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems)
|
||||
trace.RecordBackendTrace(trace.BackendTrace{
|
||||
Timestamp: time.Now(),
|
||||
Type: trace.BackendTraceModelLoad,
|
||||
ModelName: modelName,
|
||||
Backend: backend,
|
||||
Summary: "Model load failed",
|
||||
Error: err.Error(),
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
|
||||
name := c.Name
|
||||
if name == "" {
|
||||
@@ -109,6 +129,16 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
|
||||
mmap = *c.MMap
|
||||
}
|
||||
|
||||
// Intel SYCL backend has issues with mmap enabled
|
||||
// See: https://github.com/mudler/LocalAI/issues/9012
|
||||
// Automatically disable mmap for Intel SYCL backends
|
||||
if c.Backend != "" {
|
||||
if strings.Contains(strings.ToLower(c.Backend), "intel") || strings.Contains(strings.ToLower(c.Backend), "sycl") {
|
||||
mmap = false
|
||||
xlog.Info("Auto-disabling mmap for Intel SYCL backend", "backend", c.Backend)
|
||||
}
|
||||
}
|
||||
|
||||
ctxSize := 4096
|
||||
if c.ContextSize != nil {
|
||||
ctxSize = *c.ContextSize
|
||||
|
||||
@@ -15,6 +15,7 @@ func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
rerankModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ func SoundGeneration(
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
soundGenModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ func TokenMetrics(
|
||||
opts := ModelOptions(modelConfig, appConfig, model.WithModel(modelFile))
|
||||
model, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.Model
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
inferenceModel, err = loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return schema.TokenizeResponse{}, err
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
||||
|
||||
transcriptionModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ func ModelTTS(
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
ttsModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
@@ -131,6 +132,7 @@ func ModelTTSStream(
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
ttsModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ func VAD(request *schema.VADRequest,
|
||||
opts := ModelOptions(modelConfig, appConfig)
|
||||
vadModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, en
|
||||
opts...,
|
||||
)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
235
core/cli/agent.go
Normal file
235
core/cli/agent.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package cli
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"os/signal"
	"path/filepath"
	"sort"
	"syscall"

	cliContext "github.com/mudler/LocalAI/core/cli/context"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/services"
	"github.com/mudler/LocalAGI/core/state"
	coreTypes "github.com/mudler/LocalAGI/core/types"
	"github.com/mudler/xlog"
)
|
||||
|
||||
type AgentCMD struct {
|
||||
Run AgentRunCMD `cmd:"" help:"Run an agent standalone (without the full LocalAI server)"`
|
||||
List AgentListCMD `cmd:"" help:"List agents in the pool registry"`
|
||||
}
|
||||
|
||||
// AgentRunCMD holds the arguments and flags of the "agent run" command.
//
// The agent to run is selected either by Name (looked up in pool.json inside
// StateDir) or by Config (a standalone JSON agent config file).
type AgentRunCMD struct {
	Name string `arg:"" optional:"" help:"Agent name to run from the pool registry (pool.json)"`

	Config string `short:"c" help:"Path to a JSON agent config file (alternative to loading by name)" type:"path"`
	Prompt string `short:"p" help:"Run in foreground mode: send a single prompt and print the response"`

	// Agent pool settings; these mirror the agent flags of RunCMD.
	APIURL                string `env:"LOCALAI_AGENT_POOL_API_URL" help:"API URL for the agent to call (e.g. http://127.0.0.1:8080)" group:"agents"`
	APIKey                string `env:"LOCALAI_AGENT_POOL_API_KEY" help:"API key for the agent" group:"agents"`
	DefaultModel          string `env:"LOCALAI_AGENT_POOL_DEFAULT_MODEL" help:"Default model for the agent" group:"agents"`
	MultimodalModel       string `env:"LOCALAI_AGENT_POOL_MULTIMODAL_MODEL" help:"Multimodal model for the agent" group:"agents"`
	TranscriptionModel    string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_MODEL" help:"Transcription model for the agent" group:"agents"`
	TranscriptionLanguage string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_LANGUAGE" help:"Transcription language for the agent" group:"agents"`
	TTSModel              string `env:"LOCALAI_AGENT_POOL_TTS_MODEL" help:"TTS model for the agent" group:"agents"`
	StateDir              string `env:"LOCALAI_AGENT_POOL_STATE_DIR" default:"agents" help:"State directory containing pool.json" type:"path" group:"agents"`
	Timeout               string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Agent timeout" group:"agents"`
	EnableSkills          bool   `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service" group:"agents"`
	EnableLogs            bool   `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"`
	CustomActionsDir      string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory" group:"agents"`
}
|
||||
|
||||
func (r *AgentRunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.Name == "" && r.Config == "" {
|
||||
return fmt.Errorf("either an agent name or --config must be provided")
|
||||
}
|
||||
|
||||
agentConfig, err := r.loadAgentConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Override agent config fields from CLI flags when provided
|
||||
r.applyOverrides(agentConfig)
|
||||
|
||||
xlog.Info("Starting standalone agent", "name", agentConfig.Name)
|
||||
|
||||
appConfig := r.buildAppConfig()
|
||||
|
||||
poolService, err := services.NewAgentPoolService(appConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create agent pool service: %w", err)
|
||||
}
|
||||
|
||||
if err := poolService.Start(appConfig.Context); err != nil {
|
||||
return fmt.Errorf("failed to start agent pool service: %w", err)
|
||||
}
|
||||
defer poolService.Stop()
|
||||
|
||||
pool := poolService.Pool()
|
||||
|
||||
// Start the agent standalone (does not persist to pool.json)
|
||||
if err := pool.StartAgentStandalone(agentConfig.Name, agentConfig); err != nil {
|
||||
return fmt.Errorf("failed to start agent %q: %w", agentConfig.Name, err)
|
||||
}
|
||||
|
||||
ag := pool.GetAgent(agentConfig.Name)
|
||||
if ag == nil {
|
||||
return fmt.Errorf("agent %q not found after start", agentConfig.Name)
|
||||
}
|
||||
|
||||
// Foreground mode: send a single prompt and exit
|
||||
if r.Prompt != "" {
|
||||
xlog.Info("Sending prompt to agent", "agent", agentConfig.Name)
|
||||
result := ag.Ask(coreTypes.WithText(r.Prompt))
|
||||
if result == nil {
|
||||
return fmt.Errorf("agent returned no result")
|
||||
}
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("agent error: %w", result.Error)
|
||||
}
|
||||
fmt.Println(result.Response)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Background mode: run until interrupted
|
||||
xlog.Info("Agent running in background mode. Press Ctrl+C to stop.", "agent", agentConfig.Name)
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
|
||||
xlog.Info("Shutting down agent", "agent", agentConfig.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *AgentRunCMD) loadAgentConfig() (*state.AgentConfig, error) {
|
||||
// Load from JSON config file
|
||||
if r.Config != "" {
|
||||
data, err := os.ReadFile(r.Config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file %q: %w", r.Config, err)
|
||||
}
|
||||
var cfg state.AgentConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file %q: %w", r.Config, err)
|
||||
}
|
||||
if cfg.Name == "" {
|
||||
return nil, fmt.Errorf("agent config must have a name")
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// Load from pool.json by name
|
||||
poolFile := r.StateDir + "/pool.json"
|
||||
data, err := os.ReadFile(poolFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read pool registry %q: %w", poolFile, err)
|
||||
}
|
||||
|
||||
var pool map[string]state.AgentConfig
|
||||
if err := json.Unmarshal(data, &pool); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse pool registry %q: %w", poolFile, err)
|
||||
}
|
||||
|
||||
cfg, ok := pool[r.Name]
|
||||
if !ok {
|
||||
available := make([]string, 0, len(pool))
|
||||
for name := range pool {
|
||||
available = append(available, name)
|
||||
}
|
||||
return nil, fmt.Errorf("agent %q not found in pool registry. Available agents: %v", r.Name, available)
|
||||
}
|
||||
|
||||
cfg.Name = r.Name
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (r *AgentRunCMD) applyOverrides(cfg *state.AgentConfig) {
|
||||
if r.APIURL != "" {
|
||||
cfg.APIURL = r.APIURL
|
||||
}
|
||||
if r.APIKey != "" {
|
||||
cfg.APIKey = r.APIKey
|
||||
}
|
||||
if r.DefaultModel != "" && cfg.Model == "" {
|
||||
cfg.Model = r.DefaultModel
|
||||
}
|
||||
if r.MultimodalModel != "" && cfg.MultimodalModel == "" {
|
||||
cfg.MultimodalModel = r.MultimodalModel
|
||||
}
|
||||
if r.TranscriptionModel != "" && cfg.TranscriptionModel == "" {
|
||||
cfg.TranscriptionModel = r.TranscriptionModel
|
||||
}
|
||||
if r.TranscriptionLanguage != "" && cfg.TranscriptionLanguage == "" {
|
||||
cfg.TranscriptionLanguage = r.TranscriptionLanguage
|
||||
}
|
||||
if r.TTSModel != "" && cfg.TTSModel == "" {
|
||||
cfg.TTSModel = r.TTSModel
|
||||
}
|
||||
}
|
||||
|
||||
func (r *AgentRunCMD) buildAppConfig() *config.ApplicationConfig {
|
||||
appConfig := &config.ApplicationConfig{
|
||||
Context: context.Background(),
|
||||
}
|
||||
appConfig.AgentPool = config.AgentPoolConfig{
|
||||
Enabled: true,
|
||||
APIURL: r.APIURL,
|
||||
APIKey: r.APIKey,
|
||||
DefaultModel: r.DefaultModel,
|
||||
MultimodalModel: r.MultimodalModel,
|
||||
TranscriptionModel: r.TranscriptionModel,
|
||||
TranscriptionLanguage: r.TranscriptionLanguage,
|
||||
TTSModel: r.TTSModel,
|
||||
StateDir: r.StateDir,
|
||||
Timeout: r.Timeout,
|
||||
EnableSkills: r.EnableSkills,
|
||||
EnableLogs: r.EnableLogs,
|
||||
CustomActionsDir: r.CustomActionsDir,
|
||||
}
|
||||
return appConfig
|
||||
}
|
||||
|
||||
// AgentListCMD holds the flags of the "agent list" command.
type AgentListCMD struct {
	// StateDir is the directory in which pool.json is looked up.
	StateDir string `env:"LOCALAI_AGENT_POOL_STATE_DIR" default:"agents" help:"State directory containing pool.json" type:"path" group:"agents"`
}
|
||||
|
||||
func (r *AgentListCMD) Run(ctx *cliContext.Context) error {
|
||||
poolFile := r.StateDir + "/pool.json"
|
||||
data, err := os.ReadFile(poolFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
fmt.Println("No agents found (pool.json does not exist)")
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to read pool registry %q: %w", poolFile, err)
|
||||
}
|
||||
|
||||
var pool map[string]state.AgentConfig
|
||||
if err := json.Unmarshal(data, &pool); err != nil {
|
||||
return fmt.Errorf("failed to parse pool registry %q: %w", poolFile, err)
|
||||
}
|
||||
|
||||
if len(pool) == 0 {
|
||||
fmt.Println("No agents found in pool registry")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("Agents in %s:\n", poolFile)
|
||||
for name, cfg := range pool {
|
||||
model := cfg.Model
|
||||
if model == "" {
|
||||
model = "(default)"
|
||||
}
|
||||
desc := cfg.Description
|
||||
if desc == "" {
|
||||
desc = "(no description)"
|
||||
}
|
||||
fmt.Printf(" - %s [model: %s] %s\n", name, model, desc)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
214
core/cli/agent_test.go
Normal file
214
core/cli/agent_test.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/mudler/LocalAGI/core/state"
|
||||
)
|
||||
|
||||
func TestAgentRunCMD_LoadAgentConfigFromFile(t *testing.T) {
|
||||
// Create a temporary agent config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "agent.json")
|
||||
|
||||
cfg := state.AgentConfig{
|
||||
Name: "test-agent",
|
||||
Model: "llama3",
|
||||
SystemPrompt: "You are a helpful assistant",
|
||||
}
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(configFile, data, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
Config: configFile,
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
loaded, err := cmd.loadAgentConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("loadAgentConfig() error: %v", err)
|
||||
}
|
||||
if loaded.Name != "test-agent" {
|
||||
t.Errorf("expected name %q, got %q", "test-agent", loaded.Name)
|
||||
}
|
||||
if loaded.Model != "llama3" {
|
||||
t.Errorf("expected model %q, got %q", "llama3", loaded.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRunCMD_LoadAgentConfigFromPool(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
pool := map[string]state.AgentConfig{
|
||||
"my-agent": {
|
||||
Model: "gpt-4",
|
||||
Description: "A test agent",
|
||||
SystemPrompt: "Hello",
|
||||
},
|
||||
"other-agent": {
|
||||
Model: "llama3",
|
||||
},
|
||||
}
|
||||
data, err := json.MarshalIndent(pool, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
Name: "my-agent",
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
loaded, err := cmd.loadAgentConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("loadAgentConfig() error: %v", err)
|
||||
}
|
||||
if loaded.Name != "my-agent" {
|
||||
t.Errorf("expected name %q, got %q", "my-agent", loaded.Name)
|
||||
}
|
||||
if loaded.Model != "gpt-4" {
|
||||
t.Errorf("expected model %q, got %q", "gpt-4", loaded.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRunCMD_LoadAgentConfigFromPool_NotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
pool := map[string]state.AgentConfig{
|
||||
"existing-agent": {Model: "llama3"},
|
||||
}
|
||||
data, err := json.MarshalIndent(pool, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
Name: "nonexistent",
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
_, err = cmd.loadAgentConfig()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing agent, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRunCMD_LoadAgentConfigNoNameOrConfig(t *testing.T) {
|
||||
cmd := &AgentRunCMD{
|
||||
StateDir: t.TempDir(),
|
||||
}
|
||||
|
||||
_, err := cmd.loadAgentConfig()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when no pool.json exists, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRunCMD_ApplyOverrides(t *testing.T) {
|
||||
cfg := &state.AgentConfig{
|
||||
Name: "test",
|
||||
}
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
APIURL: "http://localhost:9090",
|
||||
APIKey: "secret",
|
||||
DefaultModel: "my-model",
|
||||
}
|
||||
|
||||
cmd.applyOverrides(cfg)
|
||||
|
||||
if cfg.APIURL != "http://localhost:9090" {
|
||||
t.Errorf("expected APIURL %q, got %q", "http://localhost:9090", cfg.APIURL)
|
||||
}
|
||||
if cfg.APIKey != "secret" {
|
||||
t.Errorf("expected APIKey %q, got %q", "secret", cfg.APIKey)
|
||||
}
|
||||
if cfg.Model != "my-model" {
|
||||
t.Errorf("expected Model %q, got %q", "my-model", cfg.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRunCMD_ApplyOverridesDoesNotOverwriteExisting(t *testing.T) {
|
||||
cfg := &state.AgentConfig{
|
||||
Name: "test",
|
||||
Model: "existing-model",
|
||||
}
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
DefaultModel: "override-model",
|
||||
}
|
||||
|
||||
cmd.applyOverrides(cfg)
|
||||
|
||||
if cfg.Model != "existing-model" {
|
||||
t.Errorf("expected Model to remain %q, got %q", "existing-model", cfg.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRunCMD_LoadConfigMissingName(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "agent.json")
|
||||
|
||||
// Agent config with no name
|
||||
cfg := state.AgentConfig{
|
||||
Model: "llama3",
|
||||
}
|
||||
data, _ := json.MarshalIndent(cfg, "", " ")
|
||||
os.WriteFile(configFile, data, 0644)
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
Config: configFile,
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
_, err := cmd.loadAgentConfig()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for config with no name, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentListCMD_NoPoolFile(t *testing.T) {
|
||||
cmd := &AgentListCMD{
|
||||
StateDir: t.TempDir(),
|
||||
}
|
||||
|
||||
// Should not error, just print "no agents found"
|
||||
err := cmd.Run(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentListCMD_WithAgents(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
pool := map[string]state.AgentConfig{
|
||||
"agent-a": {Model: "llama3", Description: "First agent"},
|
||||
"agent-b": {Model: "gpt-4"},
|
||||
}
|
||||
data, _ := json.MarshalIndent(pool, "", " ")
|
||||
os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)
|
||||
|
||||
cmd := &AgentListCMD{
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
err := cmd.Run(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@ var CLI struct {
|
||||
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
|
||||
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
|
||||
Util UtilCMD `cmd:"" help:"Utility commands"`
|
||||
Agent AgentCMD `cmd:"" help:"Run agents standalone without the full LocalAI server"`
|
||||
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
|
||||
Completion CompletionCMD `cmd:"" help:"Generate shell completion scripts for bash, zsh, or fish"`
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ type RunCMD struct {
|
||||
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
|
||||
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
|
||||
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
|
||||
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
|
||||
DisableCSRF bool `env:"LOCALAI_DISABLE_CSRF" help:"Disable CSRF middleware (enabled by default)" group:"api"`
|
||||
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
|
||||
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
|
||||
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disables the web user interface. When set to true, the server will only expose API endpoints without serving the web interface" group:"api"`
|
||||
@@ -121,6 +121,24 @@ type RunCMD struct {
|
||||
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
|
||||
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`
|
||||
|
||||
// Fine-tuning
|
||||
EnableFineTuning bool `env:"LOCALAI_ENABLE_FINETUNING" default:"false" help:"Enable fine-tuning support" group:"finetuning"`
|
||||
|
||||
// Authentication
|
||||
AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"`
|
||||
AuthDatabaseURL string `env:"LOCALAI_AUTH_DATABASE_URL,DATABASE_URL" help:"Database URL for auth (postgres:// or file path for SQLite). Defaults to {DataPath}/database.db" group:"auth"`
|
||||
GitHubClientID string `env:"GITHUB_CLIENT_ID" help:"GitHub OAuth App Client ID (auto-enables auth when set)" group:"auth"`
|
||||
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET" help:"GitHub OAuth App Client Secret" group:"auth"`
|
||||
OIDCIssuer string `env:"LOCALAI_OIDC_ISSUER" help:"OIDC issuer URL for auto-discovery" group:"auth"`
|
||||
OIDCClientID string `env:"LOCALAI_OIDC_CLIENT_ID" help:"OIDC Client ID (auto-enables auth)" group:"auth"`
|
||||
OIDCClientSecret string `env:"LOCALAI_OIDC_CLIENT_SECRET" help:"OIDC Client Secret" group:"auth"`
|
||||
AuthBaseURL string `env:"LOCALAI_BASE_URL" help:"Base URL for OAuth callbacks (e.g. http://localhost:8080)" group:"auth"`
|
||||
AuthAdminEmail string `env:"LOCALAI_ADMIN_EMAIL" help:"Email address to auto-promote to admin role" group:"auth"`
|
||||
AuthRegistrationMode string `env:"LOCALAI_REGISTRATION_MODE" default:"open" help:"Registration mode: 'open' (default), 'approval', or 'invite' (invite code required)" group:"auth"`
|
||||
DisableLocalAuth bool `env:"LOCALAI_DISABLE_LOCAL_AUTH" default:"false" help:"Disable local email/password registration and login (use with OAuth/OIDC-only setups)" group:"auth"`
|
||||
AuthAPIKeyHMACSecret string `env:"LOCALAI_AUTH_HMAC_SECRET" help:"HMAC secret for API key hashing (auto-generated if empty)" group:"auth"`
|
||||
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
||||
|
||||
Version bool
|
||||
}
|
||||
|
||||
@@ -165,7 +183,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
config.WithBackendGalleries(r.BackendGalleries),
|
||||
config.WithCors(r.CORS),
|
||||
config.WithCorsAllowOrigins(r.CORSAllowOrigins),
|
||||
config.WithCsrf(r.CSRF),
|
||||
config.WithDisableCSRF(r.DisableCSRF),
|
||||
config.WithThreads(r.Threads),
|
||||
config.WithUploadLimitMB(r.UploadLimit),
|
||||
config.WithApiKeys(r.APIKeys),
|
||||
@@ -311,6 +329,51 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.WithAgentHubURL(r.AgentHubURL))
|
||||
}
|
||||
|
||||
// Fine-tuning
|
||||
if r.EnableFineTuning {
|
||||
opts = append(opts, config.EnableFineTuning)
|
||||
}
|
||||
|
||||
// Authentication
|
||||
authEnabled := r.AuthEnabled || r.GitHubClientID != "" || r.OIDCClientID != ""
|
||||
if authEnabled {
|
||||
opts = append(opts, config.WithAuthEnabled(true))
|
||||
|
||||
dbURL := r.AuthDatabaseURL
|
||||
if dbURL == "" {
|
||||
dbURL = filepath.Join(r.DataPath, "database.db")
|
||||
}
|
||||
opts = append(opts, config.WithAuthDatabaseURL(dbURL))
|
||||
|
||||
if r.GitHubClientID != "" {
|
||||
opts = append(opts, config.WithAuthGitHubClientID(r.GitHubClientID))
|
||||
opts = append(opts, config.WithAuthGitHubClientSecret(r.GitHubClientSecret))
|
||||
}
|
||||
if r.OIDCClientID != "" {
|
||||
opts = append(opts, config.WithAuthOIDCIssuer(r.OIDCIssuer))
|
||||
opts = append(opts, config.WithAuthOIDCClientID(r.OIDCClientID))
|
||||
opts = append(opts, config.WithAuthOIDCClientSecret(r.OIDCClientSecret))
|
||||
}
|
||||
if r.AuthBaseURL != "" {
|
||||
opts = append(opts, config.WithAuthBaseURL(r.AuthBaseURL))
|
||||
}
|
||||
if r.AuthAdminEmail != "" {
|
||||
opts = append(opts, config.WithAuthAdminEmail(r.AuthAdminEmail))
|
||||
}
|
||||
if r.AuthRegistrationMode != "" {
|
||||
opts = append(opts, config.WithAuthRegistrationMode(r.AuthRegistrationMode))
|
||||
}
|
||||
if r.DisableLocalAuth {
|
||||
opts = append(opts, config.WithAuthDisableLocalAuth(true))
|
||||
}
|
||||
if r.AuthAPIKeyHMACSecret != "" {
|
||||
opts = append(opts, config.WithAuthAPIKeyHMACSecret(r.AuthAPIKeyHMACSecret))
|
||||
}
|
||||
if r.DefaultAPIKeyExpiry != "" {
|
||||
opts = append(opts, config.WithAuthDefaultAPIKeyExpiry(r.DefaultAPIKeyExpiry))
|
||||
}
|
||||
}
|
||||
|
||||
if idleWatchDog || busyWatchDog {
|
||||
opts = append(opts, config.EnableWatchDog)
|
||||
if idleWatchDog {
|
||||
|
||||
@@ -21,6 +21,7 @@ type ApplicationConfig struct {
|
||||
Debug bool
|
||||
EnableTracing bool
|
||||
TracingMaxItems int
|
||||
EnableBackendLogging bool
|
||||
GeneratedContentDir string
|
||||
|
||||
UploadDir string
|
||||
@@ -29,7 +30,7 @@ type ApplicationConfig struct {
|
||||
DynamicConfigsDir string
|
||||
DynamicConfigsDirPollInterval time.Duration
|
||||
CORS bool
|
||||
CSRF bool
|
||||
DisableCSRF bool
|
||||
PreloadJSONModels string
|
||||
PreloadModelsFromPath string
|
||||
CORSAllowOrigins string
|
||||
@@ -95,6 +96,29 @@ type ApplicationConfig struct {
|
||||
|
||||
// Agent Pool (LocalAGI integration)
|
||||
AgentPool AgentPoolConfig
|
||||
|
||||
// Fine-tuning
|
||||
FineTuning FineTuningConfig
|
||||
|
||||
// Authentication & Authorization
|
||||
Auth AuthConfig
|
||||
}
|
||||
|
||||
// AuthConfig holds the settings for user authentication and authorization.
type AuthConfig struct {
	Enabled bool
	// DatabaseURL is either a "postgres://..." URL or a file path for SQLite.
	DatabaseURL        string
	GitHubClientID     string
	GitHubClientSecret string
	// OIDCIssuer is the OIDC issuer URL used for auto-discovery
	// (e.g. https://accounts.google.com).
	OIDCIssuer       string
	OIDCClientID     string
	OIDCClientSecret string
	// BaseURL is used to build OAuth callback URLs (e.g. "http://localhost:8080").
	BaseURL string
	// AdminEmail is auto-promoted to admin on login.
	AdminEmail string
	// RegistrationMode is one of "open", "approval" (default when empty), "invite".
	RegistrationMode string
	// DisableLocalAuth disables local email/password registration and login.
	DisableLocalAuth bool
	// APIKeyHMACSecret is the HMAC secret for API key hashing; auto-generated if empty.
	APIKeyHMACSecret string
	// DefaultAPIKeyExpiry is the default expiry duration for API keys
	// (e.g. "90d"); empty means no expiry.
	DefaultAPIKeyExpiry string
}
|
||||
|
||||
// AgentPoolConfig holds configuration for the LocalAGI agent pool integration.
|
||||
@@ -121,6 +145,11 @@ type AgentPoolConfig struct {
|
||||
AgentHubURL string // default: "https://agenthub.localai.io"
|
||||
}
|
||||
|
||||
// FineTuningConfig holds configuration for fine-tuning support.
|
||||
type FineTuningConfig struct {
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
type AppOption func(*ApplicationConfig)
|
||||
|
||||
func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||
@@ -149,6 +178,8 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||
"/favicon.svg",
|
||||
"/readyz",
|
||||
"/healthz",
|
||||
"/api/auth/",
|
||||
"/assets/",
|
||||
},
|
||||
}
|
||||
for _, oo := range o {
|
||||
@@ -193,9 +224,9 @@ func WithP2PNetworkID(s string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithCsrf(b bool) AppOption {
|
||||
func WithDisableCSRF(b bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.CSRF = b
|
||||
o.DisableCSRF = b
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,6 +244,10 @@ var EnableTracing = func(o *ApplicationConfig) {
|
||||
o.EnableTracing = true
|
||||
}
|
||||
|
||||
var EnableBackendLogging = func(o *ApplicationConfig) {
|
||||
o.EnableBackendLogging = true
|
||||
}
|
||||
|
||||
var EnableWatchDogIdleCheck = func(o *ApplicationConfig) {
|
||||
o.WatchDog = true
|
||||
o.WatchDogIdle = true
|
||||
@@ -706,6 +741,92 @@ func WithAgentHubURL(url string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
// Fine-tuning options
|
||||
|
||||
var EnableFineTuning = func(o *ApplicationConfig) {
|
||||
o.FineTuning.Enabled = true
|
||||
}
|
||||
|
||||
// Auth options
|
||||
|
||||
func WithAuthEnabled(enabled bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Auth.Enabled = enabled
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthDatabaseURL(url string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Auth.DatabaseURL = url
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthGitHubClientID(clientID string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Auth.GitHubClientID = clientID
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthGitHubClientSecret(clientSecret string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Auth.GitHubClientSecret = clientSecret
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthBaseURL(baseURL string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Auth.BaseURL = baseURL
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthAdminEmail(email string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Auth.AdminEmail = email
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthRegistrationMode(mode string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Auth.RegistrationMode = mode
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthDisableLocalAuth(disable bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Auth.DisableLocalAuth = disable
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthOIDCIssuer(issuer string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Auth.OIDCIssuer = issuer
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthOIDCClientID(clientID string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Auth.OIDCClientID = clientID
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthOIDCClientSecret(clientSecret string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Auth.OIDCClientSecret = clientSecret
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthAPIKeyHMACSecret(secret string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Auth.APIKeyHMACSecret = secret
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthDefaultAPIKeyExpiry(expiry string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Auth.DefaultAPIKeyExpiry = expiry
|
||||
}
|
||||
}
|
||||
|
||||
// ToConfigLoaderOptions returns a slice of ConfigLoader Option.
|
||||
// Some options defined at the application level are going to be passed as defaults for
|
||||
// all the configuration for the models.
|
||||
@@ -743,8 +864,9 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
debug := o.Debug
|
||||
tracingMaxItems := o.TracingMaxItems
|
||||
enableTracing := o.EnableTracing
|
||||
enableBackendLogging := o.EnableBackendLogging
|
||||
cors := o.CORS
|
||||
csrf := o.CSRF
|
||||
csrf := o.DisableCSRF
|
||||
corsAllowOrigins := o.CORSAllowOrigins
|
||||
p2pToken := o.P2PToken
|
||||
p2pNetworkID := o.P2PNetworkID
|
||||
@@ -816,6 +938,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
Debug: &debug,
|
||||
TracingMaxItems: &tracingMaxItems,
|
||||
EnableTracing: &enableTracing,
|
||||
EnableBackendLogging: &enableBackendLogging,
|
||||
CORS: &cors,
|
||||
CSRF: &csrf,
|
||||
CORSAllowOrigins: &corsAllowOrigins,
|
||||
@@ -944,11 +1067,14 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
||||
if settings.TracingMaxItems != nil {
|
||||
o.TracingMaxItems = *settings.TracingMaxItems
|
||||
}
|
||||
if settings.EnableBackendLogging != nil {
|
||||
o.EnableBackendLogging = *settings.EnableBackendLogging
|
||||
}
|
||||
if settings.CORS != nil {
|
||||
o.CORS = *settings.CORS
|
||||
}
|
||||
if settings.CSRF != nil {
|
||||
o.CSRF = *settings.CSRF
|
||||
o.DisableCSRF = *settings.CSRF
|
||||
}
|
||||
if settings.CORSAllowOrigins != nil {
|
||||
o.CORSAllowOrigins = *settings.CORSAllowOrigins
|
||||
|
||||
@@ -26,7 +26,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
F16: true,
|
||||
Debug: true,
|
||||
CORS: true,
|
||||
CSRF: true,
|
||||
DisableCSRF: true,
|
||||
CORSAllowOrigins: "https://example.com",
|
||||
P2PToken: "test-token",
|
||||
P2PNetworkID: "test-network",
|
||||
@@ -377,7 +377,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
appConfig.ApplyRuntimeSettings(rs)
|
||||
|
||||
Expect(appConfig.CORS).To(BeTrue())
|
||||
Expect(appConfig.CSRF).To(BeTrue())
|
||||
Expect(appConfig.DisableCSRF).To(BeTrue())
|
||||
Expect(appConfig.CORSAllowOrigins).To(Equal("https://example.com,https://other.com"))
|
||||
})
|
||||
|
||||
@@ -463,7 +463,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
F16: true,
|
||||
Debug: false,
|
||||
CORS: true,
|
||||
CSRF: false,
|
||||
DisableCSRF: false,
|
||||
CORSAllowOrigins: "https://test.com",
|
||||
P2PToken: "round-trip-token",
|
||||
P2PNetworkID: "round-trip-network",
|
||||
@@ -495,7 +495,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
Expect(target.F16).To(Equal(original.F16))
|
||||
Expect(target.Debug).To(Equal(original.Debug))
|
||||
Expect(target.CORS).To(Equal(original.CORS))
|
||||
Expect(target.CSRF).To(Equal(original.CSRF))
|
||||
Expect(target.DisableCSRF).To(Equal(original.DisableCSRF))
|
||||
Expect(target.CORSAllowOrigins).To(Equal(original.CORSAllowOrigins))
|
||||
Expect(target.P2PToken).To(Equal(original.P2PToken))
|
||||
Expect(target.P2PNetworkID).To(Equal(original.P2PNetworkID))
|
||||
|
||||
@@ -36,8 +36,9 @@ type RuntimeSettings struct {
|
||||
ContextSize *int `json:"context_size,omitempty"`
|
||||
F16 *bool `json:"f16,omitempty"`
|
||||
Debug *bool `json:"debug,omitempty"`
|
||||
EnableTracing *bool `json:"enable_tracing,omitempty"`
|
||||
TracingMaxItems *int `json:"tracing_max_items,omitempty"`
|
||||
EnableTracing *bool `json:"enable_tracing,omitempty"`
|
||||
TracingMaxItems *int `json:"tracing_max_items,omitempty"`
|
||||
EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"`
|
||||
|
||||
// Security/CORS settings
|
||||
CORS *bool `json:"cors,omitempty"`
|
||||
|
||||
173
core/gallery/backend_resolve.go
Normal file
173
core/gallery/backend_resolve.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/xsync"
|
||||
"github.com/mudler/xlog"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// modelConfigCacheEntry is a single cached result of parsing the config_file
// of a URL-referenced model config.
type modelConfigCacheEntry struct {
	// configMap is the parsed config_file content.
	configMap map[string]interface{}
	// lastUpdated is the time at which the entry was stored.
	lastUpdated time.Time
}

// hasExpired reports whether the entry was stored more than one hour ago.
// The zero value (never stored) is always considered expired.
func (e modelConfigCacheEntry) hasExpired() bool {
	oldestValid := time.Now().Add(-time.Hour)
	return oldestValid.After(e.lastUpdated)
}
|
||||
|
||||
// modelConfigCache caches parsed model config maps keyed by URL.
|
||||
var modelConfigCache = xsync.NewSyncedMap[string, modelConfigCacheEntry]()
|
||||
|
||||
// resolveBackend determines the backend for a GalleryModel by checking (in priority order):
|
||||
// 1. Overrides["backend"] — highest priority, same as install-time merge
|
||||
// 2. Inline ConfigFile["backend"] — for models with inline config maps
|
||||
// 3. URL-referenced config file — fetched, parsed, and cached
|
||||
//
|
||||
// The model's URL should already be resolved (local override applied) before calling this.
|
||||
func resolveBackend(m *GalleryModel, basePath string) string {
|
||||
// 1. Overrides take priority (matches install-time mergo.WithOverride behavior)
|
||||
if b, ok := m.Overrides["backend"].(string); ok && b != "" {
|
||||
return b
|
||||
}
|
||||
|
||||
// 2. Inline config_file map
|
||||
if b, ok := m.ConfigFile["backend"].(string); ok && b != "" {
|
||||
return b
|
||||
}
|
||||
|
||||
// 3. Fetch and parse the URL-referenced config
|
||||
if m.URL != "" {
|
||||
configMap := fetchModelConfigMap(m.URL, basePath)
|
||||
if b, ok := configMap["backend"].(string); ok && b != "" {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// fetchModelConfigMap fetches a model config URL, parses the config_file YAML string
|
||||
// inside it, and returns the result as a map. Results are cached for 1 hour.
|
||||
// Local file:// URLs skip the cache so edits are picked up immediately.
|
||||
func fetchModelConfigMap(modelURL, basePath string) map[string]interface{} {
|
||||
// Check cache (skip for file:// URLs so local edits are picked up immediately)
|
||||
isLocal := strings.HasPrefix(modelURL, downloader.LocalPrefix)
|
||||
if !isLocal && modelConfigCache.Exists(modelURL) {
|
||||
entry := modelConfigCache.Get(modelURL)
|
||||
if !entry.hasExpired() {
|
||||
return entry.configMap
|
||||
}
|
||||
modelConfigCache.Delete(modelURL)
|
||||
}
|
||||
|
||||
// Reuse existing gallery config fetcher
|
||||
modelConfig, err := GetGalleryConfigFromURL[ModelConfig](modelURL, basePath)
|
||||
if err != nil {
|
||||
xlog.Debug("Failed to fetch model config for backend resolution", "url", modelURL, "error", err)
|
||||
// Cache the failure for remote URLs to avoid repeated fetch attempts
|
||||
if !isLocal {
|
||||
modelConfigCache.Set(modelURL, modelConfigCacheEntry{
|
||||
configMap: map[string]interface{}{},
|
||||
lastUpdated: time.Now(),
|
||||
})
|
||||
}
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
|
||||
// Parse the config_file YAML string into a map
|
||||
configMap := make(map[string]interface{})
|
||||
if modelConfig.ConfigFile != "" {
|
||||
if err := yaml.Unmarshal([]byte(modelConfig.ConfigFile), &configMap); err != nil {
|
||||
xlog.Debug("Failed to parse config_file for backend resolution", "url", modelURL, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Cache for remote URLs
|
||||
if !isLocal {
|
||||
modelConfigCache.Set(modelURL, modelConfigCacheEntry{
|
||||
configMap: configMap,
|
||||
lastUpdated: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
return configMap
|
||||
}
|
||||
|
||||
// prefetchModelConfigs fetches model config URLs in parallel to warm the cache.
|
||||
// This avoids sequential HTTP requests on cold start (~50 unique gallery files).
|
||||
func prefetchModelConfigs(urls []string, basePath string) {
|
||||
const maxConcurrency = 10
|
||||
sem := make(chan struct{}, maxConcurrency)
|
||||
var wg sync.WaitGroup
|
||||
for _, url := range urls {
|
||||
wg.Add(1)
|
||||
go func(u string) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
fetchModelConfigMap(u, basePath)
|
||||
}(url)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// resolveModelURLLocally attempts to resolve a github: model URL to a local file://
|
||||
// path when the gallery itself was loaded from a local path. This supports development
|
||||
// workflows where new model files are added locally before being pushed to GitHub.
|
||||
//
|
||||
// For example, if the gallery was loaded from file:///path/to/gallery/index.yaml
|
||||
// and a model references github:mudler/LocalAI/gallery/foo.yaml@master, this will
|
||||
// check if /path/to/gallery/foo.yaml exists locally and return file:///path/to/gallery/foo.yaml.
|
||||
//
|
||||
// This is applied to model.URL in AvailableGalleryModels so that both listing (backend
|
||||
// resolution) and installation use the same resolved URL.
|
||||
func resolveModelURLLocally(modelURL, galleryURL string) string {
|
||||
galleryDir := localGalleryDir(galleryURL)
|
||||
if galleryDir == "" {
|
||||
return modelURL
|
||||
}
|
||||
|
||||
// Only handle github: URLs
|
||||
if !strings.HasPrefix(modelURL, downloader.GithubURI) && !strings.HasPrefix(modelURL, downloader.GithubURI2) {
|
||||
return modelURL
|
||||
}
|
||||
|
||||
// Extract the filename from the github URL
|
||||
// Format: github:org/repo/path/to/file.yaml@branch
|
||||
raw := strings.TrimPrefix(modelURL, downloader.GithubURI2)
|
||||
raw = strings.TrimPrefix(raw, downloader.GithubURI)
|
||||
// Remove @branch suffix
|
||||
if idx := strings.LastIndex(raw, "@"); idx >= 0 {
|
||||
raw = raw[:idx]
|
||||
}
|
||||
filename := filepath.Base(raw)
|
||||
|
||||
localPath := filepath.Join(galleryDir, filename)
|
||||
if _, err := os.Stat(localPath); err == nil {
|
||||
return downloader.LocalPrefix + localPath
|
||||
}
|
||||
|
||||
return modelURL
|
||||
}
|
||||
|
||||
// localGalleryDir returns the directory of a gallery URL if it's local, or "" if remote.
|
||||
func localGalleryDir(galleryURL string) string {
|
||||
if strings.HasPrefix(galleryURL, downloader.LocalPrefix) {
|
||||
return filepath.Dir(strings.TrimPrefix(galleryURL, downloader.LocalPrefix))
|
||||
}
|
||||
// Plain path (no scheme) that exists on disk
|
||||
if !strings.Contains(galleryURL, "://") && !strings.HasPrefix(galleryURL, downloader.GithubURI) {
|
||||
if info, err := os.Stat(galleryURL); err == nil && !info.IsDir() {
|
||||
return filepath.Dir(galleryURL)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -218,6 +218,36 @@ func AvailableGalleryModels(galleries []config.Gallery, systemState *system.Syst
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Resolve model URLs locally (for local galleries) and collect unique
|
||||
// URLs that need fetching for backend resolution.
|
||||
uniqueURLs := map[string]struct{}{}
|
||||
for _, m := range galleryModels {
|
||||
if m.URL != "" {
|
||||
m.URL = resolveModelURLLocally(m.URL, gallery.URL)
|
||||
}
|
||||
if m.Backend == "" && m.URL != "" {
|
||||
uniqueURLs[m.URL] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-warm cache with parallel fetches to avoid sequential HTTP
|
||||
// requests on cold start (~50 unique gallery config files).
|
||||
if len(uniqueURLs) > 0 {
|
||||
urls := make([]string, 0, len(uniqueURLs))
|
||||
for u := range uniqueURLs {
|
||||
urls = append(urls, u)
|
||||
}
|
||||
prefetchModelConfigs(urls, systemState.Model.ModelsPath)
|
||||
}
|
||||
|
||||
// Resolve backends from warm cache.
|
||||
for _, m := range galleryModels {
|
||||
if m.Backend == "" {
|
||||
m.Backend = resolveBackend(m, systemState.Model.ModelsPath)
|
||||
}
|
||||
}
|
||||
|
||||
models = append(models, galleryModels...)
|
||||
}
|
||||
|
||||
|
||||
205
core/gallery/importers/local.go
Normal file
205
core/gallery/importers/local.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// ImportLocalPath scans a local directory for exported model files and produces
|
||||
// a config.ModelConfig with the correct backend, model path, and options.
|
||||
// Paths in the returned config are relative to modelsPath when possible so that
|
||||
// the YAML config remains portable.
|
||||
//
|
||||
// Detection order:
|
||||
// 1. GGUF files (*.gguf) — uses llama-cpp backend
|
||||
// 2. LoRA adapter (adapter_config.json) — uses transformers backend with lora_adapter
|
||||
// 3. Merged model (*.safetensors or pytorch_model*.bin + config.json) — uses transformers backend
|
||||
func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
|
||||
// Make paths relative to the models directory (parent of dirPath)
|
||||
// so config YAML stays portable.
|
||||
modelsDir := filepath.Dir(dirPath)
|
||||
relPath := func(absPath string) string {
|
||||
if rel, err := filepath.Rel(modelsDir, absPath); err == nil {
|
||||
return rel
|
||||
}
|
||||
return absPath
|
||||
}
|
||||
|
||||
// 1. GGUF: check dirPath and dirPath_gguf/ (Unsloth convention)
|
||||
ggufFile := findGGUF(dirPath)
|
||||
if ggufFile == "" {
|
||||
ggufSubdir := dirPath + "_gguf"
|
||||
ggufFile = findGGUF(ggufSubdir)
|
||||
}
|
||||
if ggufFile != "" {
|
||||
xlog.Info("ImportLocalPath: detected GGUF model", "path", ggufFile)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "llama-cpp",
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
Options: []string{"use_jinja:true"},
|
||||
}
|
||||
cfg.Model = relPath(ggufFile)
|
||||
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
cfg.Description = buildDescription(dirPath, "GGUF")
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// 2. LoRA adapter: look for adapter_config.json
|
||||
|
||||
adapterConfigPath := filepath.Join(dirPath, "adapter_config.json")
|
||||
if fileExists(adapterConfigPath) {
|
||||
xlog.Info("ImportLocalPath: detected LoRA adapter", "path", dirPath)
|
||||
baseModel := readBaseModel(dirPath)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "transformers",
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
}
|
||||
cfg.Model = baseModel
|
||||
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
cfg.LLMConfig.LoraAdapter = relPath(dirPath)
|
||||
cfg.Description = buildDescription(dirPath, "LoRA adapter")
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// Also check for adapter_model.safetensors or adapter_model.bin without adapter_config.json
|
||||
if fileExists(filepath.Join(dirPath, "adapter_model.safetensors")) || fileExists(filepath.Join(dirPath, "adapter_model.bin")) {
|
||||
xlog.Info("ImportLocalPath: detected LoRA adapter (by model files)", "path", dirPath)
|
||||
baseModel := readBaseModel(dirPath)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "transformers",
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
}
|
||||
cfg.Model = baseModel
|
||||
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
cfg.LLMConfig.LoraAdapter = relPath(dirPath)
|
||||
cfg.Description = buildDescription(dirPath, "LoRA adapter")
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// 3. Merged model: *.safetensors or pytorch_model*.bin + config.json
|
||||
if fileExists(filepath.Join(dirPath, "config.json")) && (hasFileWithSuffix(dirPath, ".safetensors") || hasFileWithPrefix(dirPath, "pytorch_model")) {
|
||||
xlog.Info("ImportLocalPath: detected merged model", "path", dirPath)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "transformers",
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
}
|
||||
cfg.Model = relPath(dirPath)
|
||||
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
cfg.Description = buildDescription(dirPath, "merged model")
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("could not detect model format in directory %s", dirPath)
|
||||
}
|
||||
|
||||
// findGGUF returns the path (dir joined with the entry name) of the first
// non-directory entry in dir whose name ends in ".gguf", compared
// case-insensitively. It returns "" when there is no such entry or when dir
// cannot be read.
func findGGUF(dir string) string {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return ""
	}
	for i := range entries {
		entry := entries[i]
		if entry.IsDir() {
			continue
		}
		if lower := strings.ToLower(entry.Name()); strings.HasSuffix(lower, ".gguf") {
			return filepath.Join(dir, entry.Name())
		}
	}
	return ""
}
|
||||
|
||||
// readBaseModel returns the name of the base model a fine-tuned export in
// dirPath was derived from, or "" when it cannot be determined.
//
// Sources are tried in order and the first non-empty string wins:
//  1. adapter_config.json → base_model_name_or_path (TRL writes this)
//  2. export_metadata.json → base_model (Unsloth writes this)
//
// A file that is missing, unreadable, not valid JSON, or that lacks a
// non-empty string under its key is skipped.
func readBaseModel(dirPath string) string {
	sources := [...]struct{ file, key string }{
		{"adapter_config.json", "base_model_name_or_path"},
		{"export_metadata.json", "base_model"},
	}
	for _, src := range sources {
		raw, err := os.ReadFile(filepath.Join(dirPath, src.file))
		if err != nil {
			continue
		}
		var doc map[string]any
		if json.Unmarshal(raw, &doc) != nil {
			continue
		}
		if value, ok := doc[src.key].(string); ok && value != "" {
			return value
		}
	}
	return ""
}

// buildDescription creates a human-readable description for an imported
// model. When the base model can be read from the metadata in dirPath (see
// readBaseModel) the result is "Fine-tuned from <base> (<formatLabel>)",
// otherwise "Fine-tuned model (<formatLabel>)".
func buildDescription(dirPath, formatLabel string) string {
	if base := readBaseModel(dirPath); base != "" {
		return fmt.Sprintf("Fine-tuned from %s (%s)", base, formatLabel)
	}
	return fmt.Sprintf("Fine-tuned model (%s)", formatLabel)
}
|
||||
|
||||
// fileExists reports whether path can be stat'ed and is not a directory.
func fileExists(path string) bool {
	if info, err := os.Stat(path); err == nil {
		return !info.IsDir()
	}
	return false
}
|
||||
|
||||
// hasFileWithSuffix reports whether dir contains a non-directory entry whose
// lower-cased name ends with suffix. The suffix itself is compared as given,
// so callers should pass it in lower case. An unreadable dir yields false.
func hasFileWithSuffix(dir, suffix string) bool {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return false
	}
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		if strings.HasSuffix(strings.ToLower(entry.Name()), suffix) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// hasFileWithPrefix reports whether dir contains a non-directory entry whose
// name starts with prefix. The comparison is case-sensitive. An unreadable
// dir yields false.
func hasFileWithPrefix(dir, prefix string) bool {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return false
	}
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		if strings.HasPrefix(entry.Name(), prefix) {
			return true
		}
	}
	return false
}
|
||||
148
core/gallery/importers/local_test.go
Normal file
148
core/gallery/importers/local_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
)
|
||||
|
||||
var _ = Describe("ImportLocalPath", func() {
|
||||
var tmpDir string
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tmpDir, err = os.MkdirTemp("", "importers-local-test")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
os.RemoveAll(tmpDir)
|
||||
})
|
||||
|
||||
Context("GGUF detection", func() {
|
||||
It("detects a GGUF file in the directory", func() {
|
||||
modelDir := filepath.Join(tmpDir, "my-model")
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "model-q4_k_m.gguf"), []byte("fake"), 0644)).To(Succeed())
|
||||
|
||||
cfg, err := importers.ImportLocalPath(modelDir, "my-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Backend).To(Equal("llama-cpp"))
|
||||
Expect(cfg.Model).To(ContainSubstring(".gguf"))
|
||||
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
|
||||
Expect(cfg.KnownUsecaseStrings).To(ContainElement("chat"))
|
||||
Expect(cfg.Options).To(ContainElement("use_jinja:true"))
|
||||
})
|
||||
|
||||
It("detects GGUF in _gguf subdirectory", func() {
|
||||
modelDir := filepath.Join(tmpDir, "my-model")
|
||||
ggufDir := modelDir + "_gguf"
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
Expect(os.MkdirAll(ggufDir, 0755)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(ggufDir, "model.gguf"), []byte("fake"), 0644)).To(Succeed())
|
||||
|
||||
cfg, err := importers.ImportLocalPath(modelDir, "my-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Backend).To(Equal("llama-cpp"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("LoRA adapter detection", func() {
|
||||
It("detects LoRA adapter via adapter_config.json", func() {
|
||||
modelDir := filepath.Join(tmpDir, "lora-model")
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
|
||||
adapterConfig := map[string]any{
|
||||
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
|
||||
"peft_type": "LORA",
|
||||
}
|
||||
data, _ := json.Marshal(adapterConfig)
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
|
||||
|
||||
cfg, err := importers.ImportLocalPath(modelDir, "lora-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Backend).To(Equal("transformers"))
|
||||
Expect(cfg.Model).To(Equal("meta-llama/Llama-2-7b-hf"))
|
||||
Expect(cfg.LLMConfig.LoraAdapter).To(Equal("lora-model"))
|
||||
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
|
||||
})
|
||||
|
||||
It("reads base model from export_metadata.json as fallback", func() {
|
||||
modelDir := filepath.Join(tmpDir, "lora-unsloth")
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
|
||||
adapterConfig := map[string]any{"peft_type": "LORA"}
|
||||
data, _ := json.Marshal(adapterConfig)
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
|
||||
|
||||
metadata := map[string]any{"base_model": "unsloth/tinyllama-bnb-4bit"}
|
||||
data, _ = json.Marshal(metadata)
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "export_metadata.json"), data, 0644)).To(Succeed())
|
||||
|
||||
cfg, err := importers.ImportLocalPath(modelDir, "lora-unsloth")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Model).To(Equal("unsloth/tinyllama-bnb-4bit"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Merged model detection", func() {
|
||||
It("detects merged model with safetensors + config.json", func() {
|
||||
modelDir := filepath.Join(tmpDir, "merged")
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "model.safetensors"), []byte("fake"), 0644)).To(Succeed())
|
||||
|
||||
cfg, err := importers.ImportLocalPath(modelDir, "merged")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Backend).To(Equal("transformers"))
|
||||
Expect(cfg.Model).To(Equal("merged"))
|
||||
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
|
||||
})
|
||||
|
||||
It("detects merged model with pytorch_model files", func() {
|
||||
modelDir := filepath.Join(tmpDir, "merged-pt")
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "pytorch_model-00001-of-00002.bin"), []byte("fake"), 0644)).To(Succeed())
|
||||
|
||||
cfg, err := importers.ImportLocalPath(modelDir, "merged-pt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Backend).To(Equal("transformers"))
|
||||
Expect(cfg.Model).To(Equal("merged-pt"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("fallback", func() {
|
||||
It("returns error for empty directory", func() {
|
||||
modelDir := filepath.Join(tmpDir, "empty")
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
|
||||
_, err := importers.ImportLocalPath(modelDir, "empty")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("could not detect model format"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("description", func() {
|
||||
It("includes base model name in description", func() {
|
||||
modelDir := filepath.Join(tmpDir, "desc-test")
|
||||
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||
|
||||
adapterConfig := map[string]any{
|
||||
"base_model_name_or_path": "TinyLlama/TinyLlama-1.1B",
|
||||
}
|
||||
data, _ := json.Marshal(adapterConfig)
|
||||
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
|
||||
|
||||
cfg, err := importers.ImportLocalPath(modelDir, "desc-test")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Description).To(ContainSubstring("TinyLlama/TinyLlama-1.1B"))
|
||||
Expect(cfg.Description).To(ContainSubstring("Fine-tuned from"))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -19,4 +19,7 @@ type Metadata struct {
|
||||
Gallery config.Gallery `json:"gallery,omitempty" yaml:"gallery,omitempty"`
|
||||
// Installed is used to indicate if the model is installed or not
|
||||
Installed bool `json:"installed,omitempty" yaml:"installed,omitempty"`
|
||||
// Backend is the resolved backend engine for this model (e.g. "llama-cpp").
|
||||
// Populated at load time from overrides, inline config, or the URL-referenced config file.
|
||||
Backend string `json:"backend,omitempty" yaml:"backend,omitempty"`
|
||||
}
|
||||
|
||||
102
core/http/app.go
102
core/http/app.go
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/http/routes"
|
||||
@@ -170,11 +171,9 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
// Health Checks should always be exempt from auth, so register these first
|
||||
routes.HealthRoutes(e)
|
||||
|
||||
// Get key auth middleware
|
||||
keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create key auth config: %w", err)
|
||||
}
|
||||
// Build auth middleware: use the new auth.Middleware when auth is enabled or
|
||||
// as a unified replacement for the legacy key-auth middleware.
|
||||
authMiddleware := auth.Middleware(application.AuthDB(), application.ApplicationConfig())
|
||||
|
||||
// Favicon handler
|
||||
e.GET("/favicon.svg", func(c echo.Context) error {
|
||||
@@ -209,8 +208,20 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
e.Static("/generated-videos", videoPath)
|
||||
}
|
||||
|
||||
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration
|
||||
e.Use(keyAuthMiddleware)
|
||||
// Initialize usage recording when auth DB is available
|
||||
if application.AuthDB() != nil {
|
||||
httpMiddleware.InitUsageRecorder(application.AuthDB())
|
||||
}
|
||||
|
||||
// Auth is applied to _all_ endpoints. Filtering out endpoints to bypass is
|
||||
// the role of the exempt-path logic inside the middleware.
|
||||
e.Use(authMiddleware)
|
||||
|
||||
// Feature and model access control (after auth middleware, before routes)
|
||||
if application.AuthDB() != nil {
|
||||
e.Use(auth.RequireRouteFeature(application.AuthDB()))
|
||||
e.Use(auth.RequireModelAccess(application.AuthDB()))
|
||||
}
|
||||
|
||||
// CORS middleware
|
||||
if application.ApplicationConfig().CORS {
|
||||
@@ -223,14 +234,63 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
e.Use(middleware.CORS())
|
||||
}
|
||||
|
||||
// CSRF middleware
|
||||
if application.ApplicationConfig().CSRF {
|
||||
xlog.Debug("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
|
||||
e.Use(middleware.CSRF())
|
||||
// CSRF middleware (enabled by default, disable with LOCALAI_DISABLE_CSRF=true)
|
||||
//
|
||||
// Protection relies on Echo's Sec-Fetch-Site header check (supported by all
|
||||
// modern browsers). The legacy cookie+token approach is removed because
|
||||
// Echo's Sec-Fetch-Site short-circuit never sets the cookie, so the frontend
|
||||
// could never read a token to send back.
|
||||
if !application.ApplicationConfig().DisableCSRF {
|
||||
xlog.Debug("Enabling CSRF middleware (Sec-Fetch-Site mode)")
|
||||
e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
|
||||
Skipper: func(c echo.Context) bool {
|
||||
// Skip CSRF for API clients using auth headers (may be cross-origin)
|
||||
if c.Request().Header.Get("Authorization") != "" {
|
||||
return true
|
||||
}
|
||||
if c.Request().Header.Get("x-api-key") != "" || c.Request().Header.Get("xi-api-key") != "" {
|
||||
return true
|
||||
}
|
||||
// Skip when Sec-Fetch-Site header is absent (older browsers, reverse
|
||||
// proxies that strip the header). The SameSite=Lax cookie attribute
|
||||
// provides baseline CSRF protection for these clients.
|
||||
if c.Request().Header.Get("Sec-Fetch-Site") == "" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
// Allow same-site requests (subdomains / different ports) in addition
|
||||
// to same-origin which Echo already permits by default.
|
||||
AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) {
|
||||
secFetchSite := c.Request().Header.Get("Sec-Fetch-Site")
|
||||
if secFetchSite == "same-site" {
|
||||
return true, nil
|
||||
}
|
||||
// cross-site: block
|
||||
return false, nil
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// Admin middleware: enforces admin role when auth is enabled, no-op otherwise
|
||||
var adminMiddleware echo.MiddlewareFunc
|
||||
if application.AuthDB() != nil {
|
||||
adminMiddleware = auth.RequireAdmin()
|
||||
} else {
|
||||
adminMiddleware = auth.NoopMiddleware()
|
||||
}
|
||||
|
||||
// Feature middlewares: per-feature access control
|
||||
agentsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureAgents)
|
||||
skillsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureSkills)
|
||||
collectionsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureCollections)
|
||||
mcpJobsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCPJobs)
|
||||
|
||||
requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
// Register auth routes (login, callback, API keys, user management)
|
||||
routes.RegisterAuthRoutes(e, application)
|
||||
|
||||
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
|
||||
@@ -239,14 +299,26 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
opcache = services.NewOpCache(application.GalleryService())
|
||||
}
|
||||
|
||||
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application)
|
||||
routes.RegisterAgentPoolRoutes(e, application)
|
||||
mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP)
|
||||
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application, adminMiddleware, mcpJobsMw, mcpMw)
|
||||
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
|
||||
// Fine-tuning routes
|
||||
if application.ApplicationConfig().FineTuning.Enabled {
|
||||
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
|
||||
ftService := services.NewFineTuneService(
|
||||
application.ApplicationConfig(),
|
||||
application.ModelLoader(),
|
||||
application.ModelConfigLoader(),
|
||||
)
|
||||
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
|
||||
}
|
||||
|
||||
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
||||
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
|
||||
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application)
|
||||
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
|
||||
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application, adminMiddleware)
|
||||
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), adminMiddleware)
|
||||
|
||||
// Serve React SPA from / with SPA fallback via 404 handler
|
||||
reactFS, fsErr := fs.Sub(reactUI, "react-ui/dist")
|
||||
|
||||
@@ -428,8 +428,10 @@ var _ = Describe("API test", func() {
|
||||
"X-Forwarded-Prefix": {"/myprefix/"},
|
||||
})
|
||||
Expect(err).To(BeNil(), "error")
|
||||
Expect(sc).To(Equal(401), "status code")
|
||||
Expect(sc).To(Equal(200), "status code")
|
||||
// Non-API paths pass through to the React SPA (which handles login client-side)
|
||||
Expect(string(body)).To(ContainSubstring(`<base href="https://example.org/myprefix/" />`), "body")
|
||||
Expect(string(body)).To(ContainSubstring(`<div id="root">`), "should serve React SPA")
|
||||
})
|
||||
|
||||
It("Should support reverse-proxy when authenticated", func() {
|
||||
|
||||
121
core/http/auth/apikeys.go
Normal file
121
core/http/auth/apikeys.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Parameters that determine the shape of generated API keys.
const (
	apiKeyPrefix    = "lai-" // literal prefix prepended to every generated key
	apiKeyRandBytes = 32     // 32 bytes = 64 hex chars
	keyPrefixLen    = 8      // display prefix length (from the random part)
)
|
||||
|
||||
// GenerateAPIKey generates a new API key. Returns the plaintext key,
|
||||
// its HMAC-SHA256 hash, and a display prefix.
|
||||
func GenerateAPIKey(hmacSecret string) (plaintext, hash, prefix string, err error) {
|
||||
b := make([]byte, apiKeyRandBytes)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", "", "", fmt.Errorf("failed to generate API key: %w", err)
|
||||
}
|
||||
|
||||
randHex := hex.EncodeToString(b)
|
||||
plaintext = apiKeyPrefix + randHex
|
||||
hash = HashAPIKey(plaintext, hmacSecret)
|
||||
prefix = plaintext[:len(apiKeyPrefix)+keyPrefixLen]
|
||||
|
||||
return plaintext, hash, prefix, nil
|
||||
}
|
||||
|
||||
// HashAPIKey computes the hex-encoded digest under which an API key is stored.
//
// With a non-empty hmacSecret the digest is HMAC-SHA256 keyed by that secret.
// With an empty hmacSecret it is the plain SHA-256 of plaintext, which keeps
// hashes created without a secret verifiable.
func HashAPIKey(plaintext, hmacSecret string) string {
	data := []byte(plaintext)

	if hmacSecret != "" {
		keyed := hmac.New(sha256.New, []byte(hmacSecret))
		keyed.Write(data) // hash.Hash documents that Write never returns an error
		return hex.EncodeToString(keyed.Sum(nil))
	}

	digest := sha256.Sum256(data)
	return hex.EncodeToString(digest[:])
}
|
||||
|
||||
// CreateAPIKey generates and stores a new API key for the given user.
|
||||
// Returns the plaintext key (shown once) and the database record.
|
||||
func CreateAPIKey(db *gorm.DB, userID, name, role, hmacSecret string, expiresAt *time.Time) (string, *UserAPIKey, error) {
|
||||
plaintext, hash, prefix, err := GenerateAPIKey(hmacSecret)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
record := &UserAPIKey{
|
||||
ID: uuid.New().String(),
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
KeyHash: hash,
|
||||
KeyPrefix: prefix,
|
||||
Role: role,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
|
||||
if err := db.Create(record).Error; err != nil {
|
||||
return "", nil, fmt.Errorf("failed to store API key: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, record, nil
|
||||
}
|
||||
|
||||
// ValidateAPIKey looks up an API key by hashing the plaintext and searching
|
||||
// the database. Returns the key record if found, or an error.
|
||||
// Updates LastUsed on successful validation.
|
||||
func ValidateAPIKey(db *gorm.DB, plaintext, hmacSecret string) (*UserAPIKey, error) {
|
||||
hash := HashAPIKey(plaintext, hmacSecret)
|
||||
|
||||
var key UserAPIKey
|
||||
if err := db.Preload("User").Where("key_hash = ?", hash).First(&key).Error; err != nil {
|
||||
return nil, fmt.Errorf("invalid API key")
|
||||
}
|
||||
|
||||
if key.ExpiresAt != nil && time.Now().After(*key.ExpiresAt) {
|
||||
return nil, fmt.Errorf("API key expired")
|
||||
}
|
||||
|
||||
if key.User.Status != StatusActive {
|
||||
return nil, fmt.Errorf("user account is not active")
|
||||
}
|
||||
|
||||
// Update LastUsed
|
||||
now := time.Now()
|
||||
db.Model(&key).Update("last_used", now)
|
||||
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
// ListAPIKeys returns all API keys for the given user (without plaintext).
|
||||
func ListAPIKeys(db *gorm.DB, userID string) ([]UserAPIKey, error) {
|
||||
var keys []UserAPIKey
|
||||
if err := db.Where("user_id = ?", userID).Order("created_at DESC").Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// RevokeAPIKey deletes an API key. Only the owner can revoke their own key.
|
||||
func RevokeAPIKey(db *gorm.DB, keyID, userID string) error {
|
||||
result := db.Where("id = ? AND user_id = ?", keyID, userID).Delete(&UserAPIKey{})
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("API key not found or not owned by user")
|
||||
}
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// CleanExpiredAPIKeys removes all API keys that have passed their expiry time.
|
||||
func CleanExpiredAPIKeys(db *gorm.DB) error {
|
||||
return db.Where("expires_at IS NOT NULL AND expires_at < ?", time.Now()).Delete(&UserAPIKey{}).Error
|
||||
}
|
||||
212
core/http/auth/apikeys_test.go
Normal file
212
core/http/auth/apikeys_test.go
Normal file
@@ -0,0 +1,212 @@
|
||||
//go:build auth
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ = Describe("API Keys", func() {
|
||||
var (
|
||||
db *gorm.DB
|
||||
user *auth.User
|
||||
)
|
||||
|
||||
// Use empty HMAC secret for tests (falls back to plain SHA-256)
|
||||
hmacSecret := ""
|
||||
|
||||
BeforeEach(func() {
|
||||
db = testDB()
|
||||
user = createTestUser(db, "apikey@example.com", auth.RoleUser, auth.ProviderGitHub)
|
||||
})
|
||||
|
||||
Describe("GenerateAPIKey", func() {
|
||||
It("returns key with 'lai-' prefix", func() {
|
||||
plaintext, _, _, err := auth.GenerateAPIKey(hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(plaintext).To(HavePrefix("lai-"))
|
||||
})
|
||||
|
||||
It("returns consistent hash for same plaintext", func() {
|
||||
plaintext, hash, _, err := auth.GenerateAPIKey(hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(auth.HashAPIKey(plaintext, hmacSecret)).To(Equal(hash))
|
||||
})
|
||||
|
||||
It("returns prefix for display", func() {
|
||||
_, _, prefix, err := auth.GenerateAPIKey(hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(prefix).To(HavePrefix("lai-"))
|
||||
Expect(len(prefix)).To(Equal(12)) // "lai-" + 8 chars
|
||||
})
|
||||
|
||||
It("generates unique keys", func() {
|
||||
key1, _, _, _ := auth.GenerateAPIKey(hmacSecret)
|
||||
key2, _, _, _ := auth.GenerateAPIKey(hmacSecret)
|
||||
Expect(key1).ToNot(Equal(key2))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("CreateAPIKey", func() {
|
||||
It("stores hashed key in DB", func() {
|
||||
plaintext, record, err := auth.CreateAPIKey(db, user.ID, "test key", auth.RoleUser, hmacSecret, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(plaintext).To(HavePrefix("lai-"))
|
||||
Expect(record.KeyHash).To(Equal(auth.HashAPIKey(plaintext, hmacSecret)))
|
||||
})
|
||||
|
||||
It("does not store plaintext in DB", func() {
|
||||
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "test key", auth.RoleUser, hmacSecret, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var keys []auth.UserAPIKey
|
||||
db.Find(&keys)
|
||||
for _, k := range keys {
|
||||
Expect(k.KeyHash).ToNot(Equal(plaintext))
|
||||
Expect(strings.Contains(k.KeyHash, "lai-")).To(BeFalse())
|
||||
}
|
||||
})
|
||||
|
||||
It("inherits role from parameter", func() {
|
||||
_, record, err := auth.CreateAPIKey(db, user.ID, "admin key", auth.RoleAdmin, hmacSecret, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(record.Role).To(Equal(auth.RoleAdmin))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ValidateAPIKey", func() {
|
||||
It("returns UserAPIKey for valid key", func() {
|
||||
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "valid key", auth.RoleUser, hmacSecret, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
found, err := auth.ValidateAPIKey(db, plaintext, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(found).ToNot(BeNil())
|
||||
Expect(found.UserID).To(Equal(user.ID))
|
||||
})
|
||||
|
||||
It("returns error for invalid key", func() {
|
||||
_, err := auth.ValidateAPIKey(db, "lai-invalidkey12345678901234567890", hmacSecret)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("updates LastUsed timestamp", func() {
|
||||
plaintext, record, err := auth.CreateAPIKey(db, user.ID, "used key", auth.RoleUser, hmacSecret, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(record.LastUsed).To(BeNil())
|
||||
|
||||
_, err = auth.ValidateAPIKey(db, plaintext, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var updated auth.UserAPIKey
|
||||
db.First(&updated, "id = ?", record.ID)
|
||||
Expect(updated.LastUsed).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("loads associated user", func() {
|
||||
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "with user", auth.RoleUser, hmacSecret, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
found, err := auth.ValidateAPIKey(db, plaintext, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(found.User.ID).To(Equal(user.ID))
|
||||
Expect(found.User.Email).To(Equal("apikey@example.com"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ListAPIKeys", func() {
|
||||
It("returns all keys for the user", func() {
|
||||
auth.CreateAPIKey(db, user.ID, "key1", auth.RoleUser, hmacSecret, nil)
|
||||
auth.CreateAPIKey(db, user.ID, "key2", auth.RoleUser, hmacSecret, nil)
|
||||
|
||||
keys, err := auth.ListAPIKeys(db, user.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(keys).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("does not return other users' keys", func() {
|
||||
other := createTestUser(db, "other@example.com", auth.RoleUser, auth.ProviderGitHub)
|
||||
auth.CreateAPIKey(db, user.ID, "my key", auth.RoleUser, hmacSecret, nil)
|
||||
auth.CreateAPIKey(db, other.ID, "other key", auth.RoleUser, hmacSecret, nil)
|
||||
|
||||
keys, err := auth.ListAPIKeys(db, user.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(keys).To(HaveLen(1))
|
||||
Expect(keys[0].Name).To(Equal("my key"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with HMAC secret", func() {
|
||||
hmacSecretVal := "test-hmac-secret-456"
|
||||
|
||||
It("generates different hash than empty secret", func() {
|
||||
plaintext, _, _, err := auth.GenerateAPIKey("")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
hashEmpty := auth.HashAPIKey(plaintext, "")
|
||||
hashHMAC := auth.HashAPIKey(plaintext, hmacSecretVal)
|
||||
Expect(hashEmpty).ToNot(Equal(hashHMAC))
|
||||
})
|
||||
|
||||
It("round-trips CreateAPIKey and ValidateAPIKey with HMAC secret", func() {
|
||||
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "hmac key", auth.RoleUser, hmacSecretVal, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
found, err := auth.ValidateAPIKey(db, plaintext, hmacSecretVal)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(found).ToNot(BeNil())
|
||||
Expect(found.UserID).To(Equal(user.ID))
|
||||
})
|
||||
|
||||
It("does not validate with wrong HMAC secret", func() {
|
||||
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "hmac key2", auth.RoleUser, hmacSecretVal, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
_, err = auth.ValidateAPIKey(db, plaintext, "wrong-secret")
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("does not validate key created with empty secret using non-empty secret", func() {
|
||||
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "empty-secret key", auth.RoleUser, "", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
_, err = auth.ValidateAPIKey(db, plaintext, hmacSecretVal)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("does not validate key created with non-empty secret using empty secret", func() {
|
||||
plaintext, _, err := auth.CreateAPIKey(db, user.ID, "nonempty-secret key", auth.RoleUser, hmacSecretVal, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
_, err = auth.ValidateAPIKey(db, plaintext, "")
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("RevokeAPIKey", func() {
|
||||
It("deletes the key record", func() {
|
||||
plaintext, record, err := auth.CreateAPIKey(db, user.ID, "to revoke", auth.RoleUser, hmacSecret, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = auth.RevokeAPIKey(db, record.ID, user.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
_, err = auth.ValidateAPIKey(db, plaintext, hmacSecret)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("only allows owner to revoke their own key", func() {
|
||||
_, record, err := auth.CreateAPIKey(db, user.ID, "mine", auth.RoleUser, hmacSecret, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
other := createTestUser(db, "attacker@example.com", auth.RoleUser, auth.ProviderGitHub)
|
||||
err = auth.RevokeAPIKey(db, record.ID, other.ID)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
})
|
||||
15
core/http/auth/auth_suite_test.go
Normal file
15
core/http/auth/auth_suite_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build auth
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestAuth is the go test entry point for the auth package specs: it routes
// Gomega assertion failures to Ginkgo's Fail and runs the "Auth Suite".
func TestAuth(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Auth Suite")
}
|
||||
49
core/http/auth/db.go
Normal file
49
core/http/auth/db.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// InitDB initializes the auth database. If databaseURL starts with "postgres://"
|
||||
// or "postgresql://", it connects to PostgreSQL; otherwise it treats the value
|
||||
// as a SQLite file path (use ":memory:" for in-memory).
|
||||
// SQLite support requires building with the "auth" build tag (CGO).
|
||||
func InitDB(databaseURL string) (*gorm.DB, error) {
|
||||
var dialector gorm.Dialector
|
||||
|
||||
if strings.HasPrefix(databaseURL, "postgres://") || strings.HasPrefix(databaseURL, "postgresql://") {
|
||||
dialector = postgres.Open(databaseURL)
|
||||
} else {
|
||||
d, err := openSQLiteDialector(databaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dialector = d
|
||||
}
|
||||
|
||||
db, err := gorm.Open(dialector, &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open auth database: %w", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&User{}, &Session{}, &UserAPIKey{}, &UsageRecord{}, &UserPermission{}, &InviteCode{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to migrate auth tables: %w", err)
|
||||
}
|
||||
|
||||
// Create composite index on users(provider, subject) for fast OAuth lookups
|
||||
if err := db.Exec("CREATE INDEX IF NOT EXISTS idx_users_provider_subject ON users(provider, subject)").Error; err != nil {
|
||||
// Ignore error on postgres if index already exists
|
||||
if !strings.Contains(err.Error(), "already exists") {
|
||||
return nil, fmt.Errorf("failed to create composite index: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
13
core/http/auth/db_nosqlite.go
Normal file
13
core/http/auth/db_nosqlite.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !auth
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// openSQLiteDialector is the variant compiled when the "auth" build tag is
// absent. It ignores path and always returns an error telling the operator to
// rebuild with -tags auth or to use a PostgreSQL DATABASE_URL.
func openSQLiteDialector(path string) (gorm.Dialector, error) {
	return nil, fmt.Errorf("SQLite auth database requires building with -tags auth (CGO); use DATABASE_URL with PostgreSQL instead")
}
|
||||
12
core/http/auth/db_sqlite.go
Normal file
12
core/http/auth/db_sqlite.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build auth
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// openSQLiteDialector is the variant compiled with the "auth" build tag. It
// returns the SQLite dialector for path and never returns an error.
func openSQLiteDialector(path string) (gorm.Dialector, error) {
	return sqlite.Open(path), nil
}
|
||||
53
core/http/auth/db_test.go
Normal file
53
core/http/auth/db_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
//go:build auth
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("InitDB", func() {
|
||||
Context("SQLite", func() {
|
||||
It("creates all tables with in-memory SQLite", func() {
|
||||
db, err := auth.InitDB(":memory:")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(db).ToNot(BeNil())
|
||||
|
||||
// Verify tables exist
|
||||
Expect(db.Migrator().HasTable(&auth.User{})).To(BeTrue())
|
||||
Expect(db.Migrator().HasTable(&auth.Session{})).To(BeTrue())
|
||||
Expect(db.Migrator().HasTable(&auth.UserAPIKey{})).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is idempotent - running twice does not error", func() {
|
||||
db, err := auth.InitDB(":memory:")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Re-migrate on same DB should succeed
|
||||
err = db.AutoMigrate(&auth.User{}, &auth.Session{}, &auth.UserAPIKey{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("creates composite index on users(provider, subject)", func() {
|
||||
db, err := auth.InitDB(":memory:")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Insert a user to verify the index doesn't prevent normal operations
|
||||
user := &auth.User{
|
||||
ID: "test-1",
|
||||
Provider: auth.ProviderGitHub,
|
||||
Subject: "12345",
|
||||
Role: "admin",
|
||||
Status: auth.StatusActive,
|
||||
}
|
||||
Expect(db.Create(user).Error).ToNot(HaveOccurred())
|
||||
|
||||
// Query using the indexed columns should work
|
||||
var found auth.User
|
||||
Expect(db.Where("provider = ? AND subject = ?", auth.ProviderGitHub, "12345").First(&found).Error).ToNot(HaveOccurred())
|
||||
Expect(found.ID).To(Equal("test-1"))
|
||||
})
|
||||
})
|
||||
})
|
||||
144
core/http/auth/features.go
Normal file
144
core/http/auth/features.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package auth
|
||||
|
||||
// RouteFeature maps a route pattern + HTTP method to a required feature.
type RouteFeature struct {
	// Method is the HTTP method the entry applies to:
	// "POST", "GET", "*" (any)
	Method string
	// Pattern is the Echo route pattern, e.g. "/v1/chat/completions"
	Pattern string
	// Feature is the feature constant required for the route, e.g. FeatureChat
	Feature string
}
|
||||
|
||||
// RouteFeatureRegistry is the single source of truth for endpoint -> feature mappings.
|
||||
// To gate a new endpoint, add an entry here -- no other file changes needed.
|
||||
var RouteFeatureRegistry = []RouteFeature{
|
||||
// Chat / Completions
|
||||
{"POST", "/v1/chat/completions", FeatureChat},
|
||||
{"POST", "/chat/completions", FeatureChat},
|
||||
{"POST", "/v1/completions", FeatureChat},
|
||||
{"POST", "/completions", FeatureChat},
|
||||
{"POST", "/v1/engines/:model/completions", FeatureChat},
|
||||
{"POST", "/v1/edits", FeatureChat},
|
||||
{"POST", "/edits", FeatureChat},
|
||||
|
||||
// Anthropic
|
||||
{"POST", "/v1/messages", FeatureChat},
|
||||
{"POST", "/messages", FeatureChat},
|
||||
|
||||
// Open Responses
|
||||
{"POST", "/v1/responses", FeatureChat},
|
||||
{"POST", "/responses", FeatureChat},
|
||||
{"GET", "/v1/responses", FeatureChat},
|
||||
{"GET", "/responses", FeatureChat},
|
||||
|
||||
// Embeddings
|
||||
{"POST", "/v1/embeddings", FeatureEmbeddings},
|
||||
{"POST", "/embeddings", FeatureEmbeddings},
|
||||
{"POST", "/v1/engines/:model/embeddings", FeatureEmbeddings},
|
||||
|
||||
// Images
|
||||
{"POST", "/v1/images/generations", FeatureImages},
|
||||
{"POST", "/images/generations", FeatureImages},
|
||||
{"POST", "/v1/images/inpainting", FeatureImages},
|
||||
{"POST", "/images/inpainting", FeatureImages},
|
||||
|
||||
// Audio transcription
|
||||
{"POST", "/v1/audio/transcriptions", FeatureAudioTranscription},
|
||||
{"POST", "/audio/transcriptions", FeatureAudioTranscription},
|
||||
|
||||
// Audio speech / TTS
|
||||
{"POST", "/v1/audio/speech", FeatureAudioSpeech},
|
||||
{"POST", "/audio/speech", FeatureAudioSpeech},
|
||||
{"POST", "/tts", FeatureAudioSpeech},
|
||||
{"POST", "/v1/text-to-speech/:voice-id", FeatureAudioSpeech},
|
||||
|
||||
// VAD
|
||||
{"POST", "/vad", FeatureVAD},
|
||||
{"POST", "/v1/vad", FeatureVAD},
|
||||
|
||||
// Detection
|
||||
{"POST", "/v1/detection", FeatureDetection},
|
||||
|
||||
// Video
|
||||
{"POST", "/video", FeatureVideo},
|
||||
|
||||
// Sound generation
|
||||
{"POST", "/v1/sound-generation", FeatureSound},
|
||||
|
||||
// Realtime
|
||||
{"GET", "/v1/realtime", FeatureRealtime},
|
||||
{"POST", "/v1/realtime/sessions", FeatureRealtime},
|
||||
{"POST", "/v1/realtime/transcription_session", FeatureRealtime},
|
||||
{"POST", "/v1/realtime/calls", FeatureRealtime},
|
||||
|
||||
// MCP
|
||||
{"POST", "/v1/mcp/chat/completions", FeatureMCP},
|
||||
{"POST", "/mcp/v1/chat/completions", FeatureMCP},
|
||||
{"POST", "/mcp/chat/completions", FeatureMCP},
|
||||
|
||||
// Tokenize
|
||||
{"POST", "/v1/tokenize", FeatureTokenize},
|
||||
|
||||
// Rerank
|
||||
{"POST", "/v1/rerank", FeatureRerank},
|
||||
|
||||
// Stores
|
||||
{"POST", "/stores/set", FeatureStores},
|
||||
{"POST", "/stores/delete", FeatureStores},
|
||||
{"POST", "/stores/get", FeatureStores},
|
||||
{"POST", "/stores/find", FeatureStores},
|
||||
|
||||
// Fine-tuning
|
||||
{"POST", "/api/fine-tuning/jobs", FeatureFineTuning},
|
||||
{"GET", "/api/fine-tuning/jobs", FeatureFineTuning},
|
||||
{"GET", "/api/fine-tuning/jobs/:id", FeatureFineTuning},
|
||||
{"POST", "/api/fine-tuning/jobs/:id/stop", FeatureFineTuning},
|
||||
{"DELETE", "/api/fine-tuning/jobs/:id", FeatureFineTuning},
|
||||
{"GET", "/api/fine-tuning/jobs/:id/progress", FeatureFineTuning},
|
||||
{"GET", "/api/fine-tuning/jobs/:id/checkpoints", FeatureFineTuning},
|
||||
{"POST", "/api/fine-tuning/jobs/:id/export", FeatureFineTuning},
|
||||
{"GET", "/api/fine-tuning/jobs/:id/download", FeatureFineTuning},
|
||||
{"POST", "/api/fine-tuning/datasets", FeatureFineTuning},
|
||||
}
|
||||
|
||||
// FeatureMeta describes a feature for the admin API/UI.
type FeatureMeta struct {
	// Key is the feature constant, serialized as "key".
	Key string `json:"key"`
	// Label is the human-readable name, serialized as "label".
	Label string `json:"label"`
	// DefaultValue is the feature's default setting, serialized as "default".
	DefaultValue bool `json:"default"`
}
|
||||
|
||||
// AgentFeatureMetas returns metadata for agent features.
|
||||
func AgentFeatureMetas() []FeatureMeta {
|
||||
return []FeatureMeta{
|
||||
{FeatureAgents, "Agents", false},
|
||||
{FeatureSkills, "Skills", false},
|
||||
{FeatureCollections, "Collections", false},
|
||||
{FeatureMCPJobs, "MCP CI Jobs", false},
|
||||
}
|
||||
}
|
||||
|
||||
// GeneralFeatureMetas returns metadata for general features.
|
||||
func GeneralFeatureMetas() []FeatureMeta {
|
||||
return []FeatureMeta{
|
||||
{FeatureFineTuning, "Fine-Tuning", false},
|
||||
}
|
||||
}
|
||||
|
||||
// APIFeatureMetas returns metadata for API endpoint features.
|
||||
func APIFeatureMetas() []FeatureMeta {
|
||||
return []FeatureMeta{
|
||||
{FeatureChat, "Chat Completions", true},
|
||||
{FeatureImages, "Image Generation", true},
|
||||
{FeatureAudioSpeech, "Audio Speech / TTS", true},
|
||||
{FeatureAudioTranscription, "Audio Transcription", true},
|
||||
{FeatureVAD, "Voice Activity Detection", true},
|
||||
{FeatureDetection, "Detection", true},
|
||||
{FeatureVideo, "Video Generation", true},
|
||||
{FeatureEmbeddings, "Embeddings", true},
|
||||
{FeatureSound, "Sound Generation", true},
|
||||
{FeatureRealtime, "Realtime", true},
|
||||
{FeatureRerank, "Rerank", true},
|
||||
{FeatureTokenize, "Tokenize", true},
|
||||
{FeatureMCP, "MCP", true},
|
||||
{FeatureStores, "Stores", true},
|
||||
}
|
||||
}
|
||||
155
core/http/auth/helpers_test.go
Normal file
155
core/http/auth/helpers_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
//go:build auth
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
	"net/http"
	"net/http/httptest"
	"strconv"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/auth"
	. "github.com/onsi/gomega"
	"gorm.io/gorm"
)
|
||||
|
||||
// testDB creates an in-memory SQLite GORM instance with auto-migration.
|
||||
func testDB() *gorm.DB {
|
||||
db, err := auth.InitDB(":memory:")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return db
|
||||
}
|
||||
|
||||
// createTestUser inserts a user directly into the DB for test setup.
|
||||
func createTestUser(db *gorm.DB, email, role, provider string) *auth.User {
|
||||
user := &auth.User{
|
||||
ID: generateTestID(),
|
||||
Email: email,
|
||||
Name: "Test User",
|
||||
Provider: provider,
|
||||
Subject: generateTestID(),
|
||||
Role: role,
|
||||
Status: auth.StatusActive,
|
||||
}
|
||||
err := db.Create(user).Error
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return user
|
||||
}
|
||||
|
||||
// createTestSession creates a session for a user, returns plaintext session token.
|
||||
func createTestSession(db *gorm.DB, userID string) string {
|
||||
sessionID, err := auth.CreateSession(db, userID, "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return sessionID
|
||||
}
|
||||
|
||||
// testIDCounter is bumped on every generateTestID call; it is never reset.
var testIDCounter int

// generateTestID returns "test-id-" followed by the single rune that sits
// testIDCounter positions after 'a' (first call yields "test-id-b").
func generateTestID() string {
	testIDCounter++
	suffix := rune('a' + testIDCounter)
	return "test-id-" + string(suffix)
}
|
||||
|
||||
// ok is a simple handler that returns 200 OK.
|
||||
func ok(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
}
|
||||
|
||||
// newAuthTestApp creates a minimal Echo app with the new auth middleware.
|
||||
func newAuthTestApp(db *gorm.DB, appConfig *config.ApplicationConfig) *echo.Echo {
|
||||
e := echo.New()
|
||||
e.Use(auth.Middleware(db, appConfig))
|
||||
|
||||
// API routes (require auth)
|
||||
e.GET("/v1/models", ok)
|
||||
e.POST("/v1/chat/completions", ok)
|
||||
e.GET("/api/settings", ok)
|
||||
e.POST("/api/settings", ok)
|
||||
|
||||
// Auth routes (exempt)
|
||||
e.GET("/api/auth/status", ok)
|
||||
e.GET("/api/auth/github/login", ok)
|
||||
|
||||
// Static routes
|
||||
e.GET("/app", ok)
|
||||
e.GET("/app/*", ok)
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
// newAdminTestApp creates an Echo app with admin-protected routes.
|
||||
func newAdminTestApp(db *gorm.DB, appConfig *config.ApplicationConfig) *echo.Echo {
|
||||
e := echo.New()
|
||||
e.Use(auth.Middleware(db, appConfig))
|
||||
|
||||
// Regular routes
|
||||
e.GET("/v1/models", ok)
|
||||
e.POST("/v1/chat/completions", ok)
|
||||
|
||||
// Admin-only routes
|
||||
adminMw := auth.RequireAdmin()
|
||||
e.POST("/api/settings", ok, adminMw)
|
||||
e.POST("/models/apply", ok, adminMw)
|
||||
e.POST("/backends/apply", ok, adminMw)
|
||||
e.GET("/api/agents", ok, adminMw)
|
||||
|
||||
// Trace/log endpoints (admin only)
|
||||
e.GET("/api/traces", ok, adminMw)
|
||||
e.POST("/api/traces/clear", ok, adminMw)
|
||||
e.GET("/api/backend-logs", ok, adminMw)
|
||||
e.GET("/api/backend-logs/:modelId", ok, adminMw)
|
||||
|
||||
// Gallery/management reads (admin only)
|
||||
e.GET("/api/operations", ok, adminMw)
|
||||
e.GET("/api/models", ok, adminMw)
|
||||
e.GET("/api/backends", ok, adminMw)
|
||||
e.GET("/api/resources", ok, adminMw)
|
||||
e.GET("/api/p2p/workers", ok, adminMw)
|
||||
|
||||
// Agent task/job routes (admin only)
|
||||
e.POST("/api/agent/tasks", ok, adminMw)
|
||||
e.GET("/api/agent/tasks", ok, adminMw)
|
||||
e.GET("/api/agent/jobs", ok, adminMw)
|
||||
|
||||
// System info (admin only)
|
||||
e.GET("/system", ok, adminMw)
|
||||
e.GET("/backend/monitor", ok, adminMw)
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
// doRequest performs an HTTP request against the given Echo app and returns the recorder.
|
||||
func doRequest(e *echo.Echo, method, path string, opts ...func(*http.Request)) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest(method, path, nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
for _, opt := range opts {
|
||||
opt(req)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
return rec
|
||||
}
|
||||
|
||||
// withBearerToken returns a request option that sets the Authorization header
// to "Bearer <token>".
func withBearerToken(token string) func(*http.Request) {
	headerValue := "Bearer " + token
	return func(r *http.Request) {
		r.Header.Set("Authorization", headerValue)
	}
}
|
||||
|
||||
// withXApiKey returns a request option that sets the x-api-key header to key.
func withXApiKey(key string) func(*http.Request) {
	const headerName = "x-api-key"
	return func(r *http.Request) {
		r.Header.Set(headerName, key)
	}
}
|
||||
|
||||
// withSessionCookie returns a request option that attaches a cookie named
// "session" carrying sessionID.
func withSessionCookie(sessionID string) func(*http.Request) {
	return func(r *http.Request) {
		cookie := http.Cookie{Name: "session", Value: sessionID}
		r.AddCookie(&cookie)
	}
}
|
||||
|
||||
// withTokenCookie returns a request option that attaches a cookie named
// "token" carrying the given value.
func withTokenCookie(token string) func(*http.Request) {
	return func(r *http.Request) {
		cookie := http.Cookie{Name: "token", Value: token}
		r.AddCookie(&cookie)
	}
}
|
||||
522
core/http/auth/middleware.go
Normal file
522
core/http/auth/middleware.go
Normal file
@@ -0,0 +1,522 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Echo context keys written by Middleware and read back by GetUser / GetUserRole.
const (
	// contextKeyUser holds the authenticated *User.
	contextKeyUser = "auth_user"
	// contextKeyRole holds the authenticated user's role string.
	contextKeyRole = "auth_role"
)
|
||||
|
||||
// Middleware returns an Echo middleware that handles authentication.
//
// A nil db means DB-backed auth is disabled. Resolution order, as implemented:
//  1. If db is nil AND no legacy API keys are configured → pass through.
//  2. If db is set, try DB-backed authentication via tryAuthenticate. This runs
//     for every path, including exempt ones, so the user is available to
//     handlers whenever valid credentials are present. On success the user and
//     role are stored in the Echo context, and a session stored by
//     tryAuthenticate under "_auth_session" is handed to MaybeRotateSession.
//  3. If still unauthenticated and legacy API keys are configured, the single
//     key returned by extractKey is compared against them; a match installs a
//     synthetic admin user.
//  4. Authenticated requests and exempt paths (isExemptPath) proceed.
//  5. Unauthenticated requests to API paths (isAPIPath) receive authError,
//     unless legacy keys are configured, DisableApiKeyRequirementForHttpGet is
//     set, the method is GET and the Echo route pattern matches one of
//     HttpGetExemptedEndpoints.
//  6. Everything else (static assets, UI pages, etc.) passes through.
func Middleware(db *gorm.DB, appConfig *config.ApplicationConfig) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			authEnabled := db != nil
			hasLegacyKeys := len(appConfig.ApiKeys) > 0

			// 1. No auth at all
			if !authEnabled && !hasLegacyKeys {
				return next(c)
			}

			// The exempt check uses the raw URL path, not the route pattern.
			path := c.Request().URL.Path
			exempt := isExemptPath(path, appConfig)
			authenticated := false

			// 2. Try to authenticate (populates user in context if possible)
			if authEnabled {
				user := tryAuthenticate(c, db, appConfig)
				if user != nil {
					c.Set(contextKeyUser, user)
					c.Set(contextKeyRole, user.Role)
					authenticated = true

					// Session rotation for cookie-based sessions
					// (tryAuthenticate only sets "_auth_session" on the cookie path).
					if session, ok := c.Get("_auth_session").(*Session); ok {
						MaybeRotateSession(c, db, session, appConfig.Auth.APIKeyHMACSecret)
					}
				}
			}

			// 3. Legacy API key validation (works whether auth is enabled or not)
			if !authenticated && hasLegacyKeys {
				key := extractKey(c)
				if key != "" && isValidLegacyKey(key, appConfig) {
					// Legacy keys are not tied to a DB user; they act as an admin.
					syntheticUser := &User{
						ID:   "legacy-api-key",
						Name: "API Key User",
						Role: RoleAdmin,
					}
					c.Set(contextKeyUser, syntheticUser)
					c.Set(contextKeyRole, RoleAdmin)
					authenticated = true
				}
			}

			// 4. If authenticated or exempt path, proceed
			if authenticated || exempt {
				return next(c)
			}

			// 5. Require auth for API paths
			if isAPIPath(path) {
				// Check GET exemptions for legacy keys
				if hasLegacyKeys && appConfig.DisableApiKeyRequirementForHttpGet && c.Request().Method == http.MethodGet {
					for _, rx := range appConfig.HttpGetExemptedEndpoints {
						// Matched against the Echo route pattern (c.Path()), not the URL path.
						if rx.MatchString(c.Path()) {
							return next(c)
						}
					}
				}
				return authError(c, appConfig)
			}

			// 6. Non-API paths (UI, static assets) pass through.
			// The React UI handles login redirects client-side.
			return next(c)
		}
	}
}
|
||||
|
||||
// RequireAdmin returns middleware that checks the user has admin role.
|
||||
func RequireAdmin() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
user := GetUser(c)
|
||||
if user == nil {
|
||||
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "Authentication required",
|
||||
Code: http.StatusUnauthorized,
|
||||
Type: "authentication_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
if user.Role != RoleAdmin {
|
||||
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "Admin access required",
|
||||
Code: http.StatusForbidden,
|
||||
Type: "authorization_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NoopMiddleware returns a middleware that does nothing (pass-through).
|
||||
// Used when auth is disabled to satisfy route registration that expects
|
||||
// an admin middleware parameter.
|
||||
func NoopMiddleware() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return next
|
||||
}
|
||||
}
|
||||
|
||||
// RequireFeature returns middleware that checks the user has access to the given feature.
|
||||
// If no auth DB is provided, it passes through (backward compat).
|
||||
// Admins always pass. Regular users must have the feature enabled in their permissions.
|
||||
func RequireFeature(db *gorm.DB, feature string) echo.MiddlewareFunc {
|
||||
if db == nil {
|
||||
return NoopMiddleware()
|
||||
}
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
user := GetUser(c)
|
||||
if user == nil {
|
||||
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "Authentication required",
|
||||
Code: http.StatusUnauthorized,
|
||||
Type: "authentication_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
if user.Role == RoleAdmin {
|
||||
return next(c)
|
||||
}
|
||||
perm, err := GetCachedUserPermissions(c, db, user.ID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "feature not enabled for your account",
|
||||
Code: http.StatusForbidden,
|
||||
Type: "authorization_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
val, exists := perm.Permissions[feature]
|
||||
if !exists {
|
||||
if !isDefaultOnFeature(feature) {
|
||||
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "feature not enabled for your account",
|
||||
Code: http.StatusForbidden,
|
||||
Type: "authorization_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
} else if !val {
|
||||
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "feature not enabled for your account",
|
||||
Code: http.StatusForbidden,
|
||||
Type: "authorization_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetUser returns the authenticated user from the echo context, or nil.
|
||||
func GetUser(c echo.Context) *User {
|
||||
u, ok := c.Get(contextKeyUser).(*User)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
// GetUserRole returns the role of the authenticated user, or empty string.
|
||||
func GetUserRole(c echo.Context) string {
|
||||
role, _ := c.Get(contextKeyRole).(string)
|
||||
return role
|
||||
}
|
||||
|
||||
// RequireRouteFeature returns a global middleware that checks the user has access
|
||||
// to the feature required by the matched route. It uses the RouteFeatureRegistry
|
||||
// to look up the required feature for each route pattern + HTTP method.
|
||||
// If no entry matches, the request passes through (no restriction).
|
||||
func RequireRouteFeature(db *gorm.DB) echo.MiddlewareFunc {
|
||||
if db == nil {
|
||||
return NoopMiddleware()
|
||||
}
|
||||
// Pre-build lookup map: "METHOD:pattern" -> feature
|
||||
lookup := map[string]string{}
|
||||
for _, rf := range RouteFeatureRegistry {
|
||||
lookup[rf.Method+":"+rf.Pattern] = rf.Feature
|
||||
}
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
path := c.Path() // Echo route pattern (e.g. "/v1/engines/:model/completions")
|
||||
method := c.Request().Method
|
||||
feature := lookup[method+":"+path]
|
||||
if feature == "" {
|
||||
feature = lookup["*:"+path]
|
||||
}
|
||||
if feature == "" {
|
||||
return next(c) // no restriction for this route
|
||||
}
|
||||
user := GetUser(c)
|
||||
if user == nil {
|
||||
return next(c) // auth middleware handles unauthenticated
|
||||
}
|
||||
if user.Role == RoleAdmin {
|
||||
return next(c)
|
||||
}
|
||||
perm, err := GetCachedUserPermissions(c, db, user.ID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "failed to check permissions",
|
||||
Code: http.StatusInternalServerError,
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
val, exists := perm.Permissions[feature]
|
||||
if !exists {
|
||||
if !isDefaultOnFeature(feature) {
|
||||
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "feature not enabled for your account: " + feature,
|
||||
Code: http.StatusForbidden,
|
||||
Type: "authorization_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
} else if !val {
|
||||
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "feature not enabled for your account: " + feature,
|
||||
Code: http.StatusForbidden,
|
||||
Type: "authorization_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RequireModelAccess returns a global middleware that checks the user is allowed
|
||||
// to use the resolved model. It extracts the model name directly from the request
|
||||
// (path param, query param, JSON body, or form value) rather than relying on a
|
||||
// context key set by downstream route-specific middleware.
|
||||
func RequireModelAccess(db *gorm.DB) echo.MiddlewareFunc {
|
||||
if db == nil {
|
||||
return NoopMiddleware()
|
||||
}
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
user := GetUser(c)
|
||||
if user == nil {
|
||||
return next(c)
|
||||
}
|
||||
if user.Role == RoleAdmin {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Check if this user even has a model allowlist enabled before
|
||||
// doing the expensive body read. Most users won't have restrictions.
|
||||
// Uses request-scoped cache to avoid duplicate DB hit when
|
||||
// RequireRouteFeature already fetched permissions.
|
||||
perm, err := GetCachedUserPermissions(c, db, user.ID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "failed to check permissions",
|
||||
Code: http.StatusInternalServerError,
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
allowlist := perm.AllowedModels
|
||||
if !allowlist.Enabled {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
modelName := extractModelFromRequest(c)
|
||||
if modelName == "" {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
for _, m := range allowlist.Models {
|
||||
if m == modelName {
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "access denied to model: " + modelName,
|
||||
Code: http.StatusForbidden,
|
||||
Type: "authorization_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractModelFromRequest extracts the model name from various request sources.
|
||||
// It checks URL path params, query params, JSON body, and form values.
|
||||
// For JSON bodies, it peeks at the body and resets it so downstream handlers
|
||||
// can still read it.
|
||||
func extractModelFromRequest(c echo.Context) string {
|
||||
// 1. URL path param (e.g. /v1/engines/:model/completions)
|
||||
if model := c.Param("model"); model != "" {
|
||||
return model
|
||||
}
|
||||
// 2. Query param
|
||||
if model := c.QueryParam("model"); model != "" {
|
||||
return model
|
||||
}
|
||||
// 3. Peek at JSON body
|
||||
if strings.HasPrefix(c.Request().Header.Get("Content-Type"), "application/json") {
|
||||
body, err := io.ReadAll(c.Request().Body)
|
||||
c.Request().Body = io.NopCloser(bytes.NewReader(body)) // always reset
|
||||
if err == nil && len(body) > 0 {
|
||||
var m struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if json.Unmarshal(body, &m) == nil && m.Model != "" {
|
||||
return m.Model
|
||||
}
|
||||
}
|
||||
}
|
||||
// 4. Form value (multipart/form-data)
|
||||
if model := c.FormValue("model"); model != "" {
|
||||
return model
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// tryAuthenticate attempts to authenticate the request using the database.
//
// Credentials are tried in a fixed order and the first one that validates wins:
// the session cookie, an "Authorization: Bearer" token (as a session ID, then
// as a user API key), the x-api-key / xi-api-key headers, and finally the
// "token" cookie (as a user API key). It returns the matching user, or nil when
// nothing validates. Only the session-cookie path stores the session in the
// Echo context under "_auth_session".
func tryAuthenticate(c echo.Context, db *gorm.DB, appConfig *config.ApplicationConfig) *User {
	hmacSecret := appConfig.Auth.APIKeyHMACSecret

	// a. Session cookie
	if cookie, err := c.Cookie(sessionCookie); err == nil && cookie.Value != "" {
		if user, session := ValidateSession(db, cookie.Value, hmacSecret); user != nil {
			// Store session for rotation check in middleware
			c.Set("_auth_session", session)
			return user
		}
	}

	// b. Authorization: Bearer token
	authHeader := c.Request().Header.Get("Authorization")
	if strings.HasPrefix(authHeader, "Bearer ") {
		token := strings.TrimPrefix(authHeader, "Bearer ")

		// Try as session ID first
		// (the session itself is discarded here, so no rotation is triggered)
		if user, _ := ValidateSession(db, token, hmacSecret); user != nil {
			return user
		}

		// Try as user API key
		if key, err := ValidateAPIKey(db, token, hmacSecret); err == nil {
			return &key.User
		}
	}

	// c. x-api-key / xi-api-key headers
	for _, header := range []string{"x-api-key", "xi-api-key"} {
		if key := c.Request().Header.Get(header); key != "" {
			if apiKey, err := ValidateAPIKey(db, key, hmacSecret); err == nil {
				return &apiKey.User
			}
		}
	}

	// d. token cookie (legacy cookie name); validated here as a user API key only.
	if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" {
		// Try as user API key
		if key, err := ValidateAPIKey(db, cookie.Value, hmacSecret); err == nil {
			return &key.User
		}
	}

	return nil
}
|
||||
|
||||
// extractKey extracts an API key from the request (all sources).
|
||||
func extractKey(c echo.Context) string {
|
||||
// Authorization header
|
||||
auth := c.Request().Header.Get("Authorization")
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
return strings.TrimPrefix(auth, "Bearer ")
|
||||
}
|
||||
if auth != "" {
|
||||
return auth
|
||||
}
|
||||
|
||||
// x-api-key
|
||||
if key := c.Request().Header.Get("x-api-key"); key != "" {
|
||||
return key
|
||||
}
|
||||
|
||||
// xi-api-key
|
||||
if key := c.Request().Header.Get("xi-api-key"); key != "" {
|
||||
return key
|
||||
}
|
||||
|
||||
// token cookie
|
||||
if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" {
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// isValidLegacyKey checks if the key matches any configured API key
|
||||
// using constant-time comparison to prevent timing attacks.
|
||||
func isValidLegacyKey(key string, appConfig *config.ApplicationConfig) bool {
|
||||
for _, validKey := range appConfig.ApiKeys {
|
||||
if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isExemptPath returns true if the path should skip authentication.
|
||||
func isExemptPath(path string, appConfig *config.ApplicationConfig) bool {
|
||||
// Auth endpoints are always public
|
||||
if strings.HasPrefix(path, "/api/auth/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check configured exempt paths
|
||||
for _, p := range appConfig.PathWithoutAuth {
|
||||
if strings.HasPrefix(path, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isAPIPath reports whether path belongs to the set of paths that always
// require authentication: the exact path "/metrics" or any path starting with
// one of the prefixes listed below.
func isAPIPath(path string) bool {
	if path == "/metrics" {
		return true
	}

	protectedPrefixes := [...]string{
		"/api/",
		"/v1/",
		"/models/",
		"/backends/",
		"/backend/",
		"/tts",
		"/vad",
		"/video",
		"/stores/",
		"/system",
		"/ws/",
		"/generated-",
	}
	for _, prefix := range protectedPrefixes {
		if strings.HasPrefix(path, prefix) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// authError returns an appropriate error response.
|
||||
func authError(c echo.Context, appConfig *config.ApplicationConfig) error {
|
||||
c.Response().Header().Set("WWW-Authenticate", "Bearer")
|
||||
|
||||
if appConfig.OpaqueErrors {
|
||||
return c.NoContent(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "An authentication key is required",
|
||||
Code: http.StatusUnauthorized,
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "An authentication key is required",
|
||||
Code: http.StatusUnauthorized,
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
306
core/http/auth/middleware_test.go
Normal file
306
core/http/auth/middleware_test.go
Normal file
@@ -0,0 +1,306 @@
|
||||
//go:build auth
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ = Describe("Auth Middleware", func() {
|
||||
|
||||
// With a nil DB and no legacy keys configured, the middleware is expected to
// let every request through untouched.
Context("auth disabled, no API keys", func() {
	var app *echo.Echo

	BeforeEach(func() {
		appConfig := config.NewApplicationConfig()
		// nil DB = DB-backed auth disabled
		app = newAuthTestApp(nil, appConfig)
	})

	It("passes through all requests", func() {
		rec := doRequest(app, http.MethodGet, "/v1/models")
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("passes through POST requests", func() {
		rec := doRequest(app, http.MethodPost, "/v1/chat/completions")
		Expect(rec.Code).To(Equal(http.StatusOK))
	})
})
|
||||
|
||||
// With a nil DB but a legacy API key configured, API paths require that key.
// The key is accepted as a Bearer token, an x-api-key header or a "token" cookie.
Context("auth disabled, API keys configured", func() {
	var app *echo.Echo
	const validKey = "sk-test-key-123"

	BeforeEach(func() {
		appConfig := config.NewApplicationConfig()
		appConfig.ApiKeys = []string{validKey}
		// nil DB = DB-backed auth disabled; only the legacy key path is active
		app = newAuthTestApp(nil, appConfig)
	})

	It("returns 401 for request without key", func() {
		rec := doRequest(app, http.MethodGet, "/v1/models")
		Expect(rec.Code).To(Equal(http.StatusUnauthorized))
	})

	It("passes with valid Bearer token", func() {
		rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(validKey))
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("passes with valid x-api-key header", func() {
		rec := doRequest(app, http.MethodGet, "/v1/models", withXApiKey(validKey))
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("passes with valid token cookie", func() {
		rec := doRequest(app, http.MethodGet, "/v1/models", withTokenCookie(validKey))
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("returns 401 for invalid key", func() {
		rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken("wrong-key"))
		Expect(rec.Code).To(Equal(http.StatusUnauthorized))
	})
})
|
||||
|
||||
// With a DB present, requests authenticate through sessions or user API keys;
// a configured legacy key still works as an admin bypass. Each spec gets a
// fresh DB, app and a regular (non-admin) user.
Context("auth enabled with database", func() {
	var (
		db        *gorm.DB
		app       *echo.Echo
		appConfig *config.ApplicationConfig
		user      *auth.User
	)

	BeforeEach(func() {
		db = testDB()
		appConfig = config.NewApplicationConfig()
		app = newAuthTestApp(db, appConfig)
		user = createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub)
	})

	It("allows requests with valid session cookie", func() {
		sessionID := createTestSession(db, user.ID)
		rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("allows requests with valid session as Bearer token", func() {
		sessionID := createTestSession(db, user.ID)
		rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(sessionID))
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("allows requests with valid user API key as Bearer token", func() {
		plaintext, _, err := auth.CreateAPIKey(db, user.ID, "test", auth.RoleUser, "", nil)
		Expect(err).ToNot(HaveOccurred())

		rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(plaintext))
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("allows requests with legacy API_KEY as admin bypass", func() {
		appConfig.ApiKeys = []string{"legacy-key-123"}
		app = newAuthTestApp(db, appConfig)

		rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken("legacy-key-123"))
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("returns 401 for expired session", func() {
		sessionID := createTestSession(db, user.ID)
		// Manually expire (session ID in DB is the hash)
		hash := auth.HashAPIKey(sessionID, "")
		// Assert the update itself succeeded so a broken setup is reported
		// here rather than as a confusing status-code mismatch below.
		err := db.Model(&auth.Session{}).Where("id = ?", hash).
			Update("expires_at", "2020-01-01").Error
		Expect(err).ToNot(HaveOccurred())

		rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusUnauthorized))
	})

	It("returns 401 for invalid session ID", func() {
		rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie("invalid-session-id"))
		Expect(rec.Code).To(Equal(http.StatusUnauthorized))
	})

	It("returns 401 for revoked API key", func() {
		plaintext, record, err := auth.CreateAPIKey(db, user.ID, "to revoke", auth.RoleUser, "", nil)
		Expect(err).ToNot(HaveOccurred())

		err = auth.RevokeAPIKey(db, record.ID, user.ID)
		Expect(err).ToNot(HaveOccurred())

		rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(plaintext))
		Expect(rec.Code).To(Equal(http.StatusUnauthorized))
	})

	It("skips auth for /api/auth/* paths", func() {
		rec := doRequest(app, http.MethodGet, "/api/auth/status")
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("skips auth for PathWithoutAuth paths", func() {
		rec := doRequest(app, http.MethodGet, "/healthz")
		// healthz is not registered in our test app, so it'll be 404/405 but NOT 401
		Expect(rec.Code).ToNot(Equal(http.StatusUnauthorized))
	})

	It("returns 401 for unauthenticated API requests", func() {
		rec := doRequest(app, http.MethodGet, "/v1/models")
		Expect(rec.Code).To(Equal(http.StatusUnauthorized))
	})

	It("allows unauthenticated access to non-API paths when no legacy keys", func() {
		rec := doRequest(app, http.MethodGet, "/app")
		Expect(rec.Code).To(Equal(http.StatusOK))
	})
})
|
||||
|
||||
// RequireAdmin specs: admin-only routes of newAdminTestApp must answer 200 for
// admins (and for the legacy API key), 403 for regular users, and 401 when no
// credentials are sent. Each spec gets a fresh DB and config.
Describe("RequireAdmin", func() {
	var (
		db        *gorm.DB
		appConfig *config.ApplicationConfig
	)

	BeforeEach(func() {
		db = testDB()
		appConfig = config.NewApplicationConfig()
	})

	It("passes for admin user", func() {
		admin := createTestUser(db, "admin@example.com", auth.RoleAdmin, auth.ProviderGitHub)
		sessionID := createTestSession(db, admin.ID)
		app := newAdminTestApp(db, appConfig)

		rec := doRequest(app, http.MethodPost, "/api/settings", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("returns 403 for user role", func() {
		user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub)
		sessionID := createTestSession(db, user.ID)
		app := newAdminTestApp(db, appConfig)

		rec := doRequest(app, http.MethodPost, "/api/settings", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusForbidden))
	})

	It("returns 401 when no user in context", func() {
		app := newAdminTestApp(db, appConfig)

		// No credentials at all on an admin-only route.
		rec := doRequest(app, http.MethodPost, "/api/settings")
		Expect(rec.Code).To(Equal(http.StatusUnauthorized))
	})

	It("allows admin to access model management", func() {
		admin := createTestUser(db, "admin@example.com", auth.RoleAdmin, auth.ProviderGitHub)
		sessionID := createTestSession(db, admin.ID)
		app := newAdminTestApp(db, appConfig)

		rec := doRequest(app, http.MethodPost, "/models/apply", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("blocks user from model management", func() {
		user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub)
		sessionID := createTestSession(db, user.ID)
		app := newAdminTestApp(db, appConfig)

		rec := doRequest(app, http.MethodPost, "/models/apply", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusForbidden))
	})

	It("allows user to access regular inference endpoints", func() {
		user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub)
		sessionID := createTestSession(db, user.ID)
		app := newAdminTestApp(db, appConfig)

		// /v1/chat/completions is registered without the admin middleware.
		rec := doRequest(app, http.MethodPost, "/v1/chat/completions", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("allows legacy API key (admin bypass) on admin routes", func() {
		appConfig.ApiKeys = []string{"admin-key"}
		app := newAdminTestApp(db, appConfig)

		rec := doRequest(app, http.MethodPost, "/api/settings", withBearerToken("admin-key"))
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("allows admin to access trace endpoints", func() {
		admin := createTestUser(db, "admin2@example.com", auth.RoleAdmin, auth.ProviderGitHub)
		sessionID := createTestSession(db, admin.ID)
		app := newAdminTestApp(db, appConfig)

		rec := doRequest(app, http.MethodGet, "/api/traces", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusOK))

		rec = doRequest(app, http.MethodGet, "/api/backend-logs", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("blocks non-admin from trace endpoints", func() {
		user := createTestUser(db, "user2@example.com", auth.RoleUser, auth.ProviderGitHub)
		sessionID := createTestSession(db, user.ID)
		app := newAdminTestApp(db, appConfig)

		rec := doRequest(app, http.MethodGet, "/api/traces", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusForbidden))

		rec = doRequest(app, http.MethodGet, "/api/backend-logs", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusForbidden))
	})

	It("allows admin to access agent job endpoints", func() {
		admin := createTestUser(db, "admin3@example.com", auth.RoleAdmin, auth.ProviderGitHub)
		sessionID := createTestSession(db, admin.ID)
		app := newAdminTestApp(db, appConfig)

		rec := doRequest(app, http.MethodGet, "/api/agent/tasks", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusOK))

		rec = doRequest(app, http.MethodGet, "/api/agent/jobs", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusOK))
	})

	It("blocks non-admin from agent job endpoints", func() {
		user := createTestUser(db, "user3@example.com", auth.RoleUser, auth.ProviderGitHub)
		sessionID := createTestSession(db, user.ID)
		app := newAdminTestApp(db, appConfig)

		rec := doRequest(app, http.MethodGet, "/api/agent/tasks", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusForbidden))

		rec = doRequest(app, http.MethodGet, "/api/agent/jobs", withSessionCookie(sessionID))
		Expect(rec.Code).To(Equal(http.StatusForbidden))
	})

	It("blocks non-admin from system/management endpoints", func() {
		user := createTestUser(db, "user4@example.com", auth.RoleUser, auth.ProviderGitHub)
		sessionID := createTestSession(db, user.ID)
		app := newAdminTestApp(db, appConfig)

		// The path is appended to the failure message so the offending route is visible.
		for _, path := range []string{"/api/operations", "/api/models", "/api/backends", "/api/resources", "/api/p2p/workers", "/system", "/backend/monitor"} {
			rec := doRequest(app, http.MethodGet, path, withSessionCookie(sessionID))
			Expect(rec.Code).To(Equal(http.StatusForbidden), "expected 403 for path: "+path)
		}
	})

	It("allows admin to access system/management endpoints", func() {
		admin := createTestUser(db, "admin4@example.com", auth.RoleAdmin, auth.ProviderGitHub)
		sessionID := createTestSession(db, admin.ID)
		app := newAdminTestApp(db, appConfig)

		for _, path := range []string{"/api/operations", "/api/models", "/api/backends", "/api/resources", "/api/p2p/workers", "/system", "/backend/monitor"} {
			rec := doRequest(app, http.MethodGet, path, withSessionCookie(sessionID))
			Expect(rec.Code).To(Equal(http.StatusOK), "expected 200 for path: "+path)
		}
	})
})
|
||||
})
|
||||
148
core/http/auth/models.go
Normal file
148
core/http/auth/models.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Names of the supported authentication providers.
const (
	ProviderLocal  = "local"
	ProviderGitHub = "github"
	ProviderOIDC   = "oidc"
)
|
||||
|
||||
// User is the persisted account record of an authenticated user.
type User struct {
	ID        string `gorm:"primaryKey;size:36"`
	Email     string `gorm:"size:255;index"`
	Name      string `gorm:"size:255"`
	AvatarURL string `gorm:"size:512"`
	// Provider is one of ProviderLocal, ProviderGitHub or ProviderOIDC.
	Provider string `gorm:"size:50"`
	// Subject is the provider-specific user ID.
	Subject string `gorm:"size:255"`
	// PasswordHash is a bcrypt hash; it is empty for OAuth-only users and is
	// excluded from JSON output.
	PasswordHash string `json:"-"`
	Role         string `gorm:"size:20;default:user"`
	// Status is "active" or "pending" (a "disabled" status constant is also
	// declared in this package).
	Status    string `gorm:"size:20;default:active"`
	CreatedAt time.Time
	UpdatedAt time.Time
}
|
||||
|
||||
// Session represents a user login session.
|
||||
type Session struct {
|
||||
ID string `gorm:"primaryKey;size:64"` // HMAC-SHA256 hash of session token
|
||||
UserID string `gorm:"size:36;index"`
|
||||
ExpiresAt time.Time
|
||||
RotatedAt time.Time
|
||||
CreatedAt time.Time
|
||||
User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
|
||||
}
|
||||
|
||||
// UserAPIKey represents a user-generated API key for programmatic access.
|
||||
type UserAPIKey struct {
|
||||
ID string `gorm:"primaryKey;size:36"`
|
||||
UserID string `gorm:"size:36;index"`
|
||||
Name string `gorm:"size:255"` // user-provided label
|
||||
KeyHash string `gorm:"size:64;uniqueIndex"`
|
||||
KeyPrefix string `gorm:"size:12"` // first 8 chars of key for display
|
||||
Role string `gorm:"size:20"`
|
||||
CreatedAt time.Time
|
||||
ExpiresAt *time.Time `gorm:"index"`
|
||||
LastUsed *time.Time
|
||||
User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
|
||||
}
|
||||
|
||||
// PermissionMap is a flexible map of feature -> enabled, stored as JSON text.
// Known features: "agents", "skills", "collections", "mcp_jobs".
// New features can be added without schema changes.
type PermissionMap map[string]bool

// Value implements driver.Valuer for GORM JSON serialization.
// A nil map is stored as the empty JSON object "{}".
func (p PermissionMap) Value() (driver.Value, error) {
	if p == nil {
		return "{}", nil
	}
	b, err := json.Marshal(p)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal PermissionMap: %w", err)
	}
	return string(b), nil
}

// Scan implements sql.Scanner for GORM JSON deserialization.
//
// It accepts nil, string and []byte column values. The receiver is always
// replaced by the decoded content (never merged with what it held before),
// and nil, empty text and JSON null all produce an empty, non-nil map.
// On error the receiver is left untouched.
func (p *PermissionMap) Scan(value any) error {
	var raw []byte
	switch v := value.(type) {
	case nil:
		// SQL NULL: keep raw empty so the result is an empty map.
	case string:
		raw = []byte(v)
	case []byte:
		raw = v
	default:
		return fmt.Errorf("cannot scan %T into PermissionMap", value)
	}

	// Decode into a fresh map: json.Unmarshal adds to an existing map, which
	// would let keys from a previously scanned row leak into this one.
	decoded := PermissionMap{}
	if len(raw) > 0 {
		if err := json.Unmarshal(raw, &decoded); err != nil {
			return fmt.Errorf("failed to unmarshal PermissionMap: %w", err)
		}
	}
	if decoded == nil {
		// JSON null resets the map to nil; normalize to an empty map.
		decoded = PermissionMap{}
	}
	*p = decoded
	return nil
}
|
||||
|
||||
// InviteCode represents an admin-generated invitation for user registration.
|
||||
type InviteCode struct {
|
||||
ID string `gorm:"primaryKey;size:36"`
|
||||
Code string `gorm:"uniqueIndex;not null;size:64"` // HMAC-SHA256 hash of invite code
|
||||
CodePrefix string `gorm:"size:12"` // first 8 chars for admin display
|
||||
CreatedBy string `gorm:"size:36;not null"`
|
||||
UsedBy *string `gorm:"size:36"`
|
||||
UsedAt *time.Time
|
||||
ExpiresAt time.Time `gorm:"not null;index"`
|
||||
CreatedAt time.Time
|
||||
Creator User `gorm:"foreignKey:CreatedBy"`
|
||||
Consumer *User `gorm:"foreignKey:UsedBy"`
|
||||
}
|
||||
|
||||
// ModelAllowlist controls which models a user can access.
// When Enabled is false (default), all models are allowed.
type ModelAllowlist struct {
	Enabled bool     `json:"enabled"`
	Models  []string `json:"models,omitempty"`
}

// Value implements driver.Valuer for GORM JSON serialization.
func (m ModelAllowlist) Value() (driver.Value, error) {
	b, err := json.Marshal(m)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal ModelAllowlist: %w", err)
	}
	return string(b), nil
}

// Scan implements sql.Scanner for GORM JSON deserialization.
//
// It accepts nil, string and []byte column values. The receiver is always
// replaced by the decoded content; nil and empty text produce the zero value
// (allowlist disabled). On error the receiver is left untouched.
func (m *ModelAllowlist) Scan(value any) error {
	var raw []byte
	switch v := value.(type) {
	case nil:
		// SQL NULL: keep raw empty so the result is the zero value.
	case string:
		raw = []byte(v)
	case []byte:
		raw = v
	default:
		return fmt.Errorf("cannot scan %T into ModelAllowlist", value)
	}

	// Decode into a fresh value: fields absent from the JSON (for example an
	// omitted "models") would otherwise keep whatever the receiver held.
	var decoded ModelAllowlist
	if len(raw) > 0 {
		if err := json.Unmarshal(raw, &decoded); err != nil {
			return fmt.Errorf("failed to unmarshal ModelAllowlist: %w", err)
		}
	}
	*m = decoded
	return nil
}
|
||||
|
||||
// UserPermission stores per-user feature permissions.
|
||||
type UserPermission struct {
|
||||
ID string `gorm:"primaryKey;size:36"`
|
||||
UserID string `gorm:"size:36;uniqueIndex"`
|
||||
Permissions PermissionMap `gorm:"type:text"`
|
||||
AllowedModels ModelAllowlist `gorm:"type:text"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
|
||||
}
|
||||
439
core/http/auth/oauth.go
Normal file
439
core/http/auth/oauth.go
Normal file
@@ -0,0 +1,439 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/xlog"
|
||||
"golang.org/x/oauth2"
|
||||
githubOAuth "golang.org/x/oauth2/github"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// providerEntry holds the OAuth2/OIDC config for a single provider.
|
||||
type providerEntry struct {
|
||||
oauth2Config oauth2.Config
|
||||
oidcVerifier *oidc.IDTokenVerifier // nil for GitHub (API-based user info)
|
||||
name string
|
||||
userInfoURL string // only used for GitHub
|
||||
}
|
||||
|
||||
// oauthUserInfo describes an authenticated user independently of the
// provider that authenticated them.
type oauthUserInfo struct {
	Subject   string
	Email     string
	Name      string
	AvatarURL string
}
|
||||
|
||||
// OAuthManager manages multiple OAuth/OIDC providers.
|
||||
type OAuthManager struct {
|
||||
providers map[string]*providerEntry
|
||||
}
|
||||
|
||||
// OAuthParams carries the settings NewOAuthManager needs to build an
// OAuthManager.
type OAuthParams struct {
	// GitHub OAuth application credentials.
	GitHubClientID     string
	GitHubClientSecret string

	// OIDC issuer URL and client credentials.
	OIDCIssuer       string
	OIDCClientID     string
	OIDCClientSecret string
}
|
||||
|
||||
// NewOAuthManager creates an OAuthManager from the given params.
|
||||
func NewOAuthManager(baseURL string, params OAuthParams) (*OAuthManager, error) {
|
||||
m := &OAuthManager{providers: make(map[string]*providerEntry)}
|
||||
|
||||
if params.GitHubClientID != "" {
|
||||
m.providers[ProviderGitHub] = &providerEntry{
|
||||
name: ProviderGitHub,
|
||||
oauth2Config: oauth2.Config{
|
||||
ClientID: params.GitHubClientID,
|
||||
ClientSecret: params.GitHubClientSecret,
|
||||
Endpoint: githubOAuth.Endpoint,
|
||||
RedirectURL: baseURL + "/api/auth/github/callback",
|
||||
Scopes: []string{"user:email", "read:user"},
|
||||
},
|
||||
userInfoURL: "https://api.github.com/user",
|
||||
}
|
||||
}
|
||||
|
||||
if params.OIDCClientID != "" && params.OIDCIssuer != "" {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
provider, err := oidc.NewProvider(ctx, params.OIDCIssuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("OIDC discovery failed for %s: %w", params.OIDCIssuer, err)
|
||||
}
|
||||
|
||||
verifier := provider.Verifier(&oidc.Config{ClientID: params.OIDCClientID})
|
||||
|
||||
m.providers[ProviderOIDC] = &providerEntry{
|
||||
name: ProviderOIDC,
|
||||
oauth2Config: oauth2.Config{
|
||||
ClientID: params.OIDCClientID,
|
||||
ClientSecret: params.OIDCClientSecret,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: baseURL + "/api/auth/oidc/callback",
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
},
|
||||
oidcVerifier: verifier,
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Providers returns the list of configured provider names.
|
||||
func (m *OAuthManager) Providers() []string {
|
||||
names := make([]string, 0, len(m.providers))
|
||||
for name := range m.providers {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// LoginHandler redirects the user to the OAuth provider's login page.
|
||||
func (m *OAuthManager) LoginHandler(providerName string) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
provider, ok := m.providers[providerName]
|
||||
if !ok {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": "unknown provider"})
|
||||
}
|
||||
|
||||
state, err := generateState()
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to generate state"})
|
||||
}
|
||||
|
||||
secure := isSecure(c)
|
||||
c.SetCookie(&http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: state,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: 600, // 10 minutes
|
||||
})
|
||||
|
||||
// Store invite code in cookie if provided
|
||||
if inviteCode := c.QueryParam("invite_code"); inviteCode != "" {
|
||||
c.SetCookie(&http.Cookie{
|
||||
Name: "invite_code",
|
||||
Value: inviteCode,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: 600,
|
||||
})
|
||||
}
|
||||
|
||||
url := provider.oauth2Config.AuthCodeURL(state)
|
||||
return c.Redirect(http.StatusTemporaryRedirect, url)
|
||||
}
|
||||
}
|
||||
|
||||
// CallbackHandler handles the OAuth callback, creates/updates the user, and
|
||||
// creates a session.
|
||||
func (m *OAuthManager) CallbackHandler(providerName string, db *gorm.DB, adminEmail, registrationMode, hmacSecret string) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
provider, ok := m.providers[providerName]
|
||||
if !ok {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": "unknown provider"})
|
||||
}
|
||||
|
||||
// Validate state
|
||||
stateCookie, err := c.Cookie("oauth_state")
|
||||
if err != nil || stateCookie.Value == "" || subtle.ConstantTimeCompare([]byte(stateCookie.Value), []byte(c.QueryParam("state"))) != 1 {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid OAuth state"})
|
||||
}
|
||||
|
||||
// Clear state cookie
|
||||
c.SetCookie(&http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: isSecure(c),
|
||||
MaxAge: -1,
|
||||
})
|
||||
|
||||
// Exchange code for token
|
||||
code := c.QueryParam("code")
|
||||
if code == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing authorization code"})
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request().Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
token, err := provider.oauth2Config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
xlog.Error("OAuth code exchange failed", "provider", providerName, "error", err)
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "OAuth authentication failed"})
|
||||
}
|
||||
|
||||
// Fetch user info — branch based on provider type
|
||||
var userInfo *oauthUserInfo
|
||||
if provider.oidcVerifier != nil {
|
||||
userInfo, err = extractOIDCUserInfo(ctx, provider.oidcVerifier, token)
|
||||
} else {
|
||||
userInfo, err = fetchGitHubUserInfoAsOAuth(ctx, token.AccessToken)
|
||||
}
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to fetch user info"})
|
||||
}
|
||||
|
||||
// Retrieve invite code from cookie if present
|
||||
var inviteCode string
|
||||
if ic, err := c.Cookie("invite_code"); err == nil && ic.Value != "" {
|
||||
inviteCode = ic.Value
|
||||
// Clear the invite code cookie
|
||||
c.SetCookie(&http.Cookie{
|
||||
Name: "invite_code",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: isSecure(c),
|
||||
MaxAge: -1,
|
||||
})
|
||||
}
|
||||
|
||||
// Upsert user (with invite code support)
|
||||
user, err := upsertOAuthUser(db, providerName, userInfo, adminEmail, registrationMode)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create user"})
|
||||
}
|
||||
|
||||
// For new users that are pending, check if they have a valid invite
|
||||
if user.Status != StatusActive && inviteCode != "" {
|
||||
if invite, err := ValidateInvite(db, inviteCode, hmacSecret); err == nil {
|
||||
user.Status = StatusActive
|
||||
db.Model(user).Update("status", StatusActive)
|
||||
ConsumeInvite(db, invite, user.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != StatusActive {
|
||||
if registrationMode == "invite" {
|
||||
return c.JSON(http.StatusForbidden, map[string]string{"error": "a valid invite code is required to register"})
|
||||
}
|
||||
return c.JSON(http.StatusForbidden, map[string]string{"error": "account pending approval"})
|
||||
}
|
||||
|
||||
// Maybe promote on login
|
||||
MaybePromote(db, user, adminEmail)
|
||||
|
||||
// Create session
|
||||
sessionID, err := CreateSession(db, user.ID, hmacSecret)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create session"})
|
||||
}
|
||||
|
||||
SetSessionCookie(c, sessionID)
|
||||
return c.Redirect(http.StatusTemporaryRedirect, "/app")
|
||||
}
|
||||
}
|
||||
|
||||
// extractOIDCUserInfo extracts user info from the OIDC ID token.
|
||||
func extractOIDCUserInfo(ctx context.Context, verifier *oidc.IDTokenVerifier, token *oauth2.Token) (*oauthUserInfo, error) {
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok || rawIDToken == "" {
|
||||
return nil, fmt.Errorf("no id_token in token response")
|
||||
}
|
||||
|
||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify ID token: %w", err)
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ID token claims: %w", err)
|
||||
}
|
||||
|
||||
return &oauthUserInfo{
|
||||
Subject: claims.Sub,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
AvatarURL: claims.Picture,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// githubUserInfo is the subset of the GitHub user JSON object that this
// package decodes.
type githubUserInfo struct {
	ID        int    `json:"id"`
	Login     string `json:"login"`
	Name      string `json:"name"`
	Email     string `json:"email"`
	AvatarURL string `json:"avatar_url"`
}
|
||||
|
||||
// githubEmail is one entry of the JSON array decoded from GitHub's
// user-emails response.
type githubEmail struct {
	Email    string `json:"email"`
	Primary  bool   `json:"primary"`
	Verified bool   `json:"verified"`
}
|
||||
|
||||
// fetchGitHubUserInfoAsOAuth fetches GitHub user info and returns it as oauthUserInfo.
|
||||
func fetchGitHubUserInfoAsOAuth(ctx context.Context, accessToken string) (*oauthUserInfo, error) {
|
||||
info, err := fetchGitHubUserInfo(ctx, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oauthUserInfo{
|
||||
Subject: fmt.Sprintf("%d", info.ID),
|
||||
Email: info.Email,
|
||||
Name: info.Name,
|
||||
AvatarURL: info.AvatarURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func fetchGitHubUserInfo(ctx context.Context, accessToken string) (*githubUserInfo, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var info githubUserInfo
|
||||
if err := json.Unmarshal(body, &info); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If no public email, fetch from /user/emails
|
||||
if info.Email == "" {
|
||||
info.Email, _ = fetchGitHubPrimaryEmail(ctx, accessToken)
|
||||
}
|
||||
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
func fetchGitHubPrimaryEmail(ctx context.Context, accessToken string) (string, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user/emails", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var emails []githubEmail
|
||||
if err := json.Unmarshal(body, &emails); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, e := range emails {
|
||||
if e.Primary && e.Verified {
|
||||
return e.Email, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to first verified email
|
||||
for _, e := range emails {
|
||||
if e.Verified {
|
||||
return e.Email, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no verified email found")
|
||||
}
|
||||
|
||||
func upsertOAuthUser(db *gorm.DB, provider string, info *oauthUserInfo, adminEmail, registrationMode string) (*User, error) {
|
||||
// Normalize email from provider (#10)
|
||||
if info.Email != "" {
|
||||
info.Email = strings.ToLower(strings.TrimSpace(info.Email))
|
||||
}
|
||||
|
||||
var user User
|
||||
err := db.Where("provider = ? AND subject = ?", provider, info.Subject).First(&user).Error
|
||||
if err == nil {
|
||||
// Existing user — update profile fields
|
||||
user.Name = info.Name
|
||||
user.AvatarURL = info.AvatarURL
|
||||
if info.Email != "" {
|
||||
user.Email = info.Email
|
||||
}
|
||||
db.Save(&user)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// New user — empty registration mode defaults to "approval"
|
||||
effectiveMode := registrationMode
|
||||
if effectiveMode == "" {
|
||||
effectiveMode = "approval"
|
||||
}
|
||||
status := StatusActive
|
||||
if effectiveMode == "approval" || effectiveMode == "invite" {
|
||||
status = StatusPending
|
||||
}
|
||||
|
||||
role := AssignRole(db, info.Email, adminEmail)
|
||||
// First user is always active regardless of registration mode
|
||||
if role == RoleAdmin {
|
||||
status = StatusActive
|
||||
}
|
||||
|
||||
user = User{
|
||||
ID: uuid.New().String(),
|
||||
Email: info.Email,
|
||||
Name: info.Name,
|
||||
AvatarURL: info.AvatarURL,
|
||||
Provider: provider,
|
||||
Subject: info.Subject,
|
||||
Role: role,
|
||||
Status: status,
|
||||
}
|
||||
|
||||
if err := db.Create(&user).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// generateState returns 16 bytes read from crypto/rand, hex-encoded into a
// 32-character string, for use as the OAuth "state" value.
func generateState() (string, error) {
	var raw [16]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		return "", err
	}
	encoded := hex.EncodeToString(raw[:])
	return encoded, nil
}
|
||||
14
core/http/auth/password.go
Normal file
14
core/http/auth/password.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package auth
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
|
||||
// HashPassword returns a bcrypt hash of the given password.
|
||||
func HashPassword(password string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
// CheckPassword compares a bcrypt hash with a plaintext password.
|
||||
func CheckPassword(hash, password string) bool {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil
|
||||
}
|
||||
217
core/http/auth/permissions.go
Normal file
217
core/http/auth/permissions.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const contextKeyPermissions = "auth_permissions"
|
||||
|
||||
// GetCachedUserPermissions returns the user's permission record, using a
|
||||
// request-scoped cache stored in the echo context. This avoids duplicate
|
||||
// DB lookups when multiple middlewares (RequireRouteFeature, RequireModelAccess)
|
||||
// both need permissions in the same request.
|
||||
func GetCachedUserPermissions(c echo.Context, db *gorm.DB, userID string) (*UserPermission, error) {
|
||||
if perm, ok := c.Get(contextKeyPermissions).(*UserPermission); ok && perm != nil {
|
||||
return perm, nil
|
||||
}
|
||||
perm, err := GetUserPermissions(db, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Set(contextKeyPermissions, perm)
|
||||
return perm, nil
|
||||
}
|
||||
|
||||
// Feature name constants — all code must use these, never bare strings.
const (
	// Agent features (default OFF for new users)
	FeatureAgents      = "agents"
	FeatureSkills      = "skills"
	FeatureCollections = "collections"
	FeatureMCPJobs     = "mcp_jobs"

	// General features (default OFF for new users)
	FeatureFineTuning = "fine_tuning"

	// API features (default ON for new users)
	FeatureChat               = "chat"
	FeatureImages             = "images"
	FeatureAudioSpeech        = "audio_speech"
	FeatureAudioTranscription = "audio_transcription"
	FeatureVAD                = "vad"
	FeatureDetection          = "detection"
	FeatureVideo              = "video"
	FeatureEmbeddings         = "embeddings"
	FeatureSound              = "sound"
	FeatureRealtime           = "realtime"
	FeatureRerank             = "rerank"
	FeatureTokenize           = "tokenize"
	FeatureMCP                = "mcp"
	FeatureStores             = "stores"
)

// AgentFeatures is the group of agent-related features (default OFF).
var AgentFeatures = []string{
	FeatureAgents,
	FeatureSkills,
	FeatureCollections,
	FeatureMCPJobs,
}

// GeneralFeatures is the group of general features (default OFF).
var GeneralFeatures = []string{
	FeatureFineTuning,
}

// APIFeatures is the group of API endpoint features (default ON).
var APIFeatures = []string{
	FeatureChat,
	FeatureImages,
	FeatureAudioSpeech,
	FeatureAudioTranscription,
	FeatureVAD,
	FeatureDetection,
	FeatureVideo,
	FeatureEmbeddings,
	FeatureSound,
	FeatureRealtime,
	FeatureRerank,
	FeatureTokenize,
	FeatureMCP,
	FeatureStores,
}

// AllFeatures is every known feature (used by UI and validation): the agent
// group, then the general group, then the API group.
var AllFeatures = func() []string {
	var all []string
	for _, group := range [][]string{AgentFeatures, GeneralFeatures, APIFeatures} {
		all = append(all, group...)
	}
	return all
}()

// defaultOnFeatures holds the features that count as enabled when a user's
// permission map has no entry for them; it contains exactly the API features.
var defaultOnFeatures = func() map[string]bool {
	on := make(map[string]bool, len(APIFeatures))
	for i := range APIFeatures {
		on[APIFeatures[i]] = true
	}
	return on
}()

// isDefaultOnFeature reports whether feature is enabled by default, i.e. when
// it has not been set explicitly. Unknown features are not default-on.
func isDefaultOnFeature(feature string) bool {
	on, known := defaultOnFeatures[feature]
	return known && on
}
|
||||
|
||||
// GetUserPermissions returns the permission record for a user, creating a default
|
||||
// (empty map = all disabled) if none exists.
|
||||
func GetUserPermissions(db *gorm.DB, userID string) (*UserPermission, error) {
|
||||
var perm UserPermission
|
||||
err := db.Where("user_id = ?", userID).First(&perm).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
perm = UserPermission{
|
||||
ID: uuid.New().String(),
|
||||
UserID: userID,
|
||||
Permissions: PermissionMap{},
|
||||
}
|
||||
if err := db.Create(&perm).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &perm, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &perm, nil
|
||||
}
|
||||
|
||||
// UpdateUserPermissions upserts the permission map for a user.
|
||||
func UpdateUserPermissions(db *gorm.DB, userID string, perms PermissionMap) error {
|
||||
var perm UserPermission
|
||||
err := db.Where("user_id = ?", userID).First(&perm).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
perm = UserPermission{
|
||||
ID: uuid.New().String(),
|
||||
UserID: userID,
|
||||
Permissions: perms,
|
||||
}
|
||||
return db.Create(&perm).Error
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
perm.Permissions = perms
|
||||
return db.Save(&perm).Error
|
||||
}
|
||||
|
||||
// HasFeatureAccess returns true if the user is an admin or has the given feature enabled.
|
||||
// When a feature key is absent from the user's permission map, it checks whether the
|
||||
// feature defaults to ON (API features) or OFF (agent features) for backward compatibility.
|
||||
func HasFeatureAccess(db *gorm.DB, user *User, feature string) bool {
|
||||
if user == nil {
|
||||
return false
|
||||
}
|
||||
if user.Role == RoleAdmin {
|
||||
return true
|
||||
}
|
||||
perm, err := GetUserPermissions(db, user.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
val, exists := perm.Permissions[feature]
|
||||
if !exists {
|
||||
return isDefaultOnFeature(feature)
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// GetPermissionMapForUser returns the effective permission map for a user.
|
||||
// Admins get all features as true (virtual).
|
||||
// For regular users, absent keys are filled with their defaults so the
|
||||
// UI/API always returns a complete picture.
|
||||
func GetPermissionMapForUser(db *gorm.DB, user *User) PermissionMap {
|
||||
if user == nil {
|
||||
return PermissionMap{}
|
||||
}
|
||||
if user.Role == RoleAdmin {
|
||||
m := PermissionMap{}
|
||||
for _, f := range AllFeatures {
|
||||
m[f] = true
|
||||
}
|
||||
return m
|
||||
}
|
||||
perm, err := GetUserPermissions(db, user.ID)
|
||||
if err != nil {
|
||||
return PermissionMap{}
|
||||
}
|
||||
// Fill in defaults for absent keys
|
||||
effective := PermissionMap{}
|
||||
for _, f := range AllFeatures {
|
||||
val, exists := perm.Permissions[f]
|
||||
if exists {
|
||||
effective[f] = val
|
||||
} else {
|
||||
effective[f] = isDefaultOnFeature(f)
|
||||
}
|
||||
}
|
||||
return effective
|
||||
}
|
||||
|
||||
// GetModelAllowlist returns the model allowlist for a user.
|
||||
func GetModelAllowlist(db *gorm.DB, userID string) ModelAllowlist {
|
||||
perm, err := GetUserPermissions(db, userID)
|
||||
if err != nil {
|
||||
return ModelAllowlist{}
|
||||
}
|
||||
return perm.AllowedModels
|
||||
}
|
||||
|
||||
// UpdateModelAllowlist updates the model allowlist for a user.
|
||||
func UpdateModelAllowlist(db *gorm.DB, userID string, allowlist ModelAllowlist) error {
|
||||
perm, err := GetUserPermissions(db, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
perm.AllowedModels = allowlist
|
||||
return db.Save(perm).Error
|
||||
}
|
||||
|
||||
// IsModelAllowed returns true if the user is allowed to use the given model.
|
||||
// Admins always have access. If the allowlist is not enabled, all models are allowed.
|
||||
func IsModelAllowed(db *gorm.DB, user *User, modelName string) bool {
|
||||
if user == nil {
|
||||
return false
|
||||
}
|
||||
if user.Role == RoleAdmin {
|
||||
return true
|
||||
}
|
||||
allowlist := GetModelAllowlist(db, user.ID)
|
||||
if !allowlist.Enabled {
|
||||
return true
|
||||
}
|
||||
for _, m := range allowlist.Models {
|
||||
if m == modelName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
103
core/http/auth/roles.go
Normal file
103
core/http/auth/roles.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Role and account-status values used by this package.
const (
	// Roles.
	RoleAdmin = "admin"
	RoleUser  = "user"

	// Account statuses.
	StatusActive   = "active"
	StatusPending  = "pending"
	StatusDisabled = "disabled"
)
|
||||
|
||||
// AssignRole determines the role for a new user.
|
||||
// First user in the database becomes admin. If adminEmail is set and matches,
|
||||
// the user becomes admin. Otherwise, the user gets the "user" role.
|
||||
// Must be called within a transaction that also creates the user to prevent
|
||||
// race conditions on the first-user admin assignment.
|
||||
func AssignRole(tx *gorm.DB, email, adminEmail string) string {
|
||||
var count int64
|
||||
tx.Model(&User{}).Count(&count)
|
||||
if count == 0 {
|
||||
return RoleAdmin
|
||||
}
|
||||
|
||||
if adminEmail != "" && strings.EqualFold(email, adminEmail) {
|
||||
return RoleAdmin
|
||||
}
|
||||
|
||||
return RoleUser
|
||||
}
|
||||
|
||||
// MaybePromote promotes a user to admin on login if their email matches
|
||||
// adminEmail. It does not demote existing admins. Returns true if the user
|
||||
// was promoted.
|
||||
func MaybePromote(db *gorm.DB, user *User, adminEmail string) bool {
|
||||
if user.Role == RoleAdmin {
|
||||
return false
|
||||
}
|
||||
|
||||
if adminEmail != "" && strings.EqualFold(user.Email, adminEmail) {
|
||||
user.Role = RoleAdmin
|
||||
db.Model(user).Update("role", RoleAdmin)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateInvite checks that an invite code exists, is unused, and has not expired.
|
||||
// The code is hashed with HMAC-SHA256 before lookup.
|
||||
func ValidateInvite(db *gorm.DB, code, hmacSecret string) (*InviteCode, error) {
|
||||
hash := HashAPIKey(code, hmacSecret)
|
||||
var invite InviteCode
|
||||
if err := db.Where("code = ?", hash).First(&invite).Error; err != nil {
|
||||
return nil, fmt.Errorf("invite code not found")
|
||||
}
|
||||
if invite.UsedBy != nil {
|
||||
return nil, fmt.Errorf("invite code already used")
|
||||
}
|
||||
if time.Now().After(invite.ExpiresAt) {
|
||||
return nil, fmt.Errorf("invite code expired")
|
||||
}
|
||||
return &invite, nil
|
||||
}
|
||||
|
||||
// ConsumeInvite marks an invite code as used by the given user.
|
||||
func ConsumeInvite(db *gorm.DB, invite *InviteCode, userID string) {
|
||||
now := time.Now()
|
||||
invite.UsedBy = &userID
|
||||
invite.UsedAt = &now
|
||||
db.Save(invite)
|
||||
}
|
||||
|
||||
// NeedsInviteOrApproval returns true if registration gating applies for the given mode.
|
||||
// Admins (first user or matching adminEmail) are never gated.
|
||||
// Must be called within a transaction that also creates the user.
|
||||
func NeedsInviteOrApproval(tx *gorm.DB, email, adminEmail, registrationMode string) bool {
|
||||
// Empty registration mode defaults to "approval"
|
||||
if registrationMode == "" {
|
||||
registrationMode = "approval"
|
||||
}
|
||||
if registrationMode != "approval" && registrationMode != "invite" {
|
||||
return false
|
||||
}
|
||||
// Admin email is never gated
|
||||
if adminEmail != "" && strings.EqualFold(email, adminEmail) {
|
||||
return false
|
||||
}
|
||||
// First user is never gated
|
||||
var count int64
|
||||
tx.Model(&User{}).Count(&count)
|
||||
if count == 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
84
core/http/auth/roles_test.go
Normal file
84
core/http/auth/roles_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
//go:build auth
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ = Describe("Roles", func() {
|
||||
var db *gorm.DB
|
||||
|
||||
BeforeEach(func() {
|
||||
db = testDB()
|
||||
})
|
||||
|
||||
Describe("AssignRole", func() {
|
||||
It("returns admin for the first user (empty DB)", func() {
|
||||
role := auth.AssignRole(db, "first@example.com", "")
|
||||
Expect(role).To(Equal(auth.RoleAdmin))
|
||||
})
|
||||
|
||||
It("returns user for the second user", func() {
|
||||
createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub)
|
||||
|
||||
role := auth.AssignRole(db, "second@example.com", "")
|
||||
Expect(role).To(Equal(auth.RoleUser))
|
||||
})
|
||||
|
||||
It("returns admin when email matches adminEmail", func() {
|
||||
createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub)
|
||||
|
||||
role := auth.AssignRole(db, "admin@example.com", "admin@example.com")
|
||||
Expect(role).To(Equal(auth.RoleAdmin))
|
||||
})
|
||||
|
||||
It("is case-insensitive for admin email match", func() {
|
||||
createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub)
|
||||
|
||||
role := auth.AssignRole(db, "Admin@Example.COM", "admin@example.com")
|
||||
Expect(role).To(Equal(auth.RoleAdmin))
|
||||
})
|
||||
|
||||
It("returns user when email does not match adminEmail", func() {
|
||||
createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub)
|
||||
|
||||
role := auth.AssignRole(db, "other@example.com", "admin@example.com")
|
||||
Expect(role).To(Equal(auth.RoleUser))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("MaybePromote", func() {
|
||||
It("promotes user to admin when email matches", func() {
|
||||
user := createTestUser(db, "promoted@example.com", auth.RoleUser, auth.ProviderGitHub)
|
||||
|
||||
promoted := auth.MaybePromote(db, user, "promoted@example.com")
|
||||
Expect(promoted).To(BeTrue())
|
||||
Expect(user.Role).To(Equal(auth.RoleAdmin))
|
||||
|
||||
// Verify in DB
|
||||
var dbUser auth.User
|
||||
db.First(&dbUser, "id = ?", user.ID)
|
||||
Expect(dbUser.Role).To(Equal(auth.RoleAdmin))
|
||||
})
|
||||
|
||||
It("does not promote when email does not match", func() {
|
||||
user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub)
|
||||
|
||||
promoted := auth.MaybePromote(db, user, "admin@example.com")
|
||||
Expect(promoted).To(BeFalse())
|
||||
Expect(user.Role).To(Equal(auth.RoleUser))
|
||||
})
|
||||
|
||||
It("does not demote an existing admin", func() {
|
||||
user := createTestUser(db, "admin@example.com", auth.RoleAdmin, auth.ProviderGitHub)
|
||||
|
||||
promoted := auth.MaybePromote(db, user, "other@example.com")
|
||||
Expect(promoted).To(BeFalse())
|
||||
Expect(user.Role).To(Equal(auth.RoleAdmin))
|
||||
})
|
||||
})
|
||||
})
|
||||
182
core/http/auth/session.go
Normal file
182
core/http/auth/session.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Session tuning constants.
const (
	// sessionDuration is how long a newly created session stays valid (30 days).
	sessionDuration = 30 * 24 * time.Hour
	// sessionIDBytes is the number of random bytes in a token; 32 bytes
	// hex-encode to 64 characters.
	sessionIDBytes = 32
	// sessionCookie is the name of the cookie that carries the session token.
	sessionCookie = "session"
	// sessionRotationInterval is the minimum age of a session (since its last
	// rotation) before MaybeRotateSession replaces it.
	sessionRotationInterval = 1 * time.Hour
)
|
||||
|
||||
// CreateSession creates a new session for the given user, returning the
|
||||
// plaintext token (64-char hex string). The stored session ID is the
|
||||
// HMAC-SHA256 hash of the token.
|
||||
func CreateSession(db *gorm.DB, userID, hmacSecret string) (string, error) {
|
||||
b := make([]byte, sessionIDBytes)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("failed to generate session ID: %w", err)
|
||||
}
|
||||
|
||||
plaintext := hex.EncodeToString(b)
|
||||
hash := HashAPIKey(plaintext, hmacSecret)
|
||||
|
||||
now := time.Now()
|
||||
session := Session{
|
||||
ID: hash,
|
||||
UserID: userID,
|
||||
ExpiresAt: now.Add(sessionDuration),
|
||||
RotatedAt: now,
|
||||
}
|
||||
|
||||
if err := db.Create(&session).Error; err != nil {
|
||||
return "", fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// ValidateSession hashes the plaintext token and looks up the session.
|
||||
// Returns the associated user and session, or (nil, nil) if not found/expired.
|
||||
func ValidateSession(db *gorm.DB, token, hmacSecret string) (*User, *Session) {
|
||||
hash := HashAPIKey(token, hmacSecret)
|
||||
|
||||
var session Session
|
||||
if err := db.Preload("User").Where("id = ? AND expires_at > ?", hash, time.Now()).First(&session).Error; err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
if session.User.Status != StatusActive {
|
||||
return nil, nil
|
||||
}
|
||||
return &session.User, &session
|
||||
}
|
||||
|
||||
// DeleteSession removes a session by hashing the plaintext token.
|
||||
func DeleteSession(db *gorm.DB, token, hmacSecret string) error {
|
||||
hash := HashAPIKey(token, hmacSecret)
|
||||
return db.Where("id = ?", hash).Delete(&Session{}).Error
|
||||
}
|
||||
|
||||
// CleanExpiredSessions removes all sessions that have passed their expiry time.
|
||||
func CleanExpiredSessions(db *gorm.DB) error {
|
||||
return db.Where("expires_at < ?", time.Now()).Delete(&Session{}).Error
|
||||
}
|
||||
|
||||
// DeleteUserSessions removes all sessions for the given user.
|
||||
func DeleteUserSessions(db *gorm.DB, userID string) error {
|
||||
return db.Where("user_id = ?", userID).Delete(&Session{}).Error
|
||||
}
|
||||
|
||||
// RotateSession creates a new session for the same user, deletes the old one,
|
||||
// and returns the new plaintext token.
|
||||
func RotateSession(db *gorm.DB, oldSession *Session, hmacSecret string) (string, error) {
|
||||
b := make([]byte, sessionIDBytes)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("failed to generate session ID: %w", err)
|
||||
}
|
||||
|
||||
plaintext := hex.EncodeToString(b)
|
||||
hash := HashAPIKey(plaintext, hmacSecret)
|
||||
|
||||
now := time.Now()
|
||||
newSession := Session{
|
||||
ID: hash,
|
||||
UserID: oldSession.UserID,
|
||||
ExpiresAt: oldSession.ExpiresAt,
|
||||
RotatedAt: now,
|
||||
}
|
||||
|
||||
err := db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(&newSession).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Where("id = ?", oldSession.ID).Delete(&Session{}).Error
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to rotate session: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// MaybeRotateSession checks if the session should be rotated and does so if needed.
|
||||
// Called from the auth middleware after successful cookie-based authentication.
|
||||
func MaybeRotateSession(c echo.Context, db *gorm.DB, session *Session, hmacSecret string) {
|
||||
if session == nil {
|
||||
return
|
||||
}
|
||||
|
||||
rotatedAt := session.RotatedAt
|
||||
if rotatedAt.IsZero() {
|
||||
rotatedAt = session.CreatedAt
|
||||
}
|
||||
|
||||
if time.Since(rotatedAt) < sessionRotationInterval {
|
||||
return
|
||||
}
|
||||
|
||||
newToken, err := RotateSession(db, session, hmacSecret)
|
||||
if err != nil {
|
||||
// Rotation failure is non-fatal; the old session remains valid
|
||||
return
|
||||
}
|
||||
|
||||
SetSessionCookie(c, newToken)
|
||||
}
|
||||
|
||||
// isSecure returns true when the request arrived over HTTPS, either directly
|
||||
// or via a reverse proxy that sets X-Forwarded-Proto.
|
||||
func isSecure(c echo.Context) bool {
|
||||
return c.Scheme() == "https"
|
||||
}
|
||||
|
||||
// SetSessionCookie sets the session cookie on the response.
|
||||
func SetSessionCookie(c echo.Context, sessionID string) {
|
||||
cookie := &http.Cookie{
|
||||
Name: sessionCookie,
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: isSecure(c),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(sessionDuration.Seconds()),
|
||||
}
|
||||
c.SetCookie(cookie)
|
||||
}
|
||||
|
||||
// SetTokenCookie sets an httpOnly "token" cookie for legacy API key auth.
|
||||
func SetTokenCookie(c echo.Context, token string) {
|
||||
cookie := &http.Cookie{
|
||||
Name: "token",
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: isSecure(c),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(sessionDuration.Seconds()),
|
||||
}
|
||||
c.SetCookie(cookie)
|
||||
}
|
||||
|
||||
// ClearSessionCookie clears the session cookie.
|
||||
func ClearSessionCookie(c echo.Context) {
|
||||
cookie := &http.Cookie{
|
||||
Name: sessionCookie,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: isSecure(c),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: -1,
|
||||
}
|
||||
c.SetCookie(cookie)
|
||||
}
|
||||
272
core/http/auth/session_test.go
Normal file
272
core/http/auth/session_test.go
Normal file
@@ -0,0 +1,272 @@
|
||||
//go:build auth
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ = Describe("Sessions", func() {
|
||||
var (
|
||||
db *gorm.DB
|
||||
user *auth.User
|
||||
)
|
||||
|
||||
// Use empty HMAC secret for basic tests
|
||||
hmacSecret := ""
|
||||
|
||||
BeforeEach(func() {
|
||||
db = testDB()
|
||||
user = createTestUser(db, "session@example.com", auth.RoleUser, auth.ProviderGitHub)
|
||||
})
|
||||
|
||||
Describe("CreateSession", func() {
|
||||
It("creates a session and returns 64-char hex plaintext token", func() {
|
||||
token, err := auth.CreateSession(db, user.ID, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(token).To(HaveLen(64))
|
||||
})
|
||||
|
||||
It("stores the hash (not plaintext) in the DB", func() {
|
||||
token, err := auth.CreateSession(db, user.ID, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
hash := auth.HashAPIKey(token, hmacSecret)
|
||||
var session auth.Session
|
||||
err = db.First(&session, "id = ?", hash).Error
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(session.UserID).To(Equal(user.ID))
|
||||
// The plaintext token should NOT be stored as the ID
|
||||
Expect(session.ID).ToNot(Equal(token))
|
||||
Expect(session.ID).To(Equal(hash))
|
||||
})
|
||||
|
||||
It("sets expiry to approximately 30 days from now", func() {
|
||||
token, err := auth.CreateSession(db, user.ID, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
hash := auth.HashAPIKey(token, hmacSecret)
|
||||
var session auth.Session
|
||||
db.First(&session, "id = ?", hash)
|
||||
|
||||
expectedExpiry := time.Now().Add(30 * 24 * time.Hour)
|
||||
Expect(session.ExpiresAt).To(BeTemporally("~", expectedExpiry, time.Minute))
|
||||
})
|
||||
|
||||
It("sets RotatedAt on creation", func() {
|
||||
token, err := auth.CreateSession(db, user.ID, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
hash := auth.HashAPIKey(token, hmacSecret)
|
||||
var session auth.Session
|
||||
db.First(&session, "id = ?", hash)
|
||||
|
||||
Expect(session.RotatedAt).To(BeTemporally("~", time.Now(), time.Minute))
|
||||
})
|
||||
|
||||
It("associates session with correct user", func() {
|
||||
token, err := auth.CreateSession(db, user.ID, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
hash := auth.HashAPIKey(token, hmacSecret)
|
||||
var session auth.Session
|
||||
db.First(&session, "id = ?", hash)
|
||||
Expect(session.UserID).To(Equal(user.ID))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ValidateSession", func() {
|
||||
It("returns user for valid session", func() {
|
||||
token := createTestSession(db, user.ID)
|
||||
|
||||
found, session := auth.ValidateSession(db, token, hmacSecret)
|
||||
Expect(found).ToNot(BeNil())
|
||||
Expect(found.ID).To(Equal(user.ID))
|
||||
Expect(session).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("returns nil for non-existent session", func() {
|
||||
found, session := auth.ValidateSession(db, "nonexistent-session-id", hmacSecret)
|
||||
Expect(found).To(BeNil())
|
||||
Expect(session).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns nil for expired session", func() {
|
||||
token := createTestSession(db, user.ID)
|
||||
hash := auth.HashAPIKey(token, hmacSecret)
|
||||
|
||||
// Manually expire the session
|
||||
db.Model(&auth.Session{}).Where("id = ?", hash).
|
||||
Update("expires_at", time.Now().Add(-1*time.Hour))
|
||||
|
||||
found, _ := auth.ValidateSession(db, token, hmacSecret)
|
||||
Expect(found).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("DeleteSession", func() {
|
||||
It("removes the session from DB", func() {
|
||||
token := createTestSession(db, user.ID)
|
||||
|
||||
err := auth.DeleteSession(db, token, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
found, _ := auth.ValidateSession(db, token, hmacSecret)
|
||||
Expect(found).To(BeNil())
|
||||
})
|
||||
|
||||
It("does not error on non-existent session", func() {
|
||||
err := auth.DeleteSession(db, "nonexistent", hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("CleanExpiredSessions", func() {
|
||||
It("removes expired sessions", func() {
|
||||
token := createTestSession(db, user.ID)
|
||||
hash := auth.HashAPIKey(token, hmacSecret)
|
||||
|
||||
// Manually expire the session
|
||||
db.Model(&auth.Session{}).Where("id = ?", hash).
|
||||
Update("expires_at", time.Now().Add(-1*time.Hour))
|
||||
|
||||
err := auth.CleanExpiredSessions(db)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var count int64
|
||||
db.Model(&auth.Session{}).Where("id = ?", hash).Count(&count)
|
||||
Expect(count).To(Equal(int64(0)))
|
||||
})
|
||||
|
||||
It("keeps active sessions", func() {
|
||||
token := createTestSession(db, user.ID)
|
||||
hash := auth.HashAPIKey(token, hmacSecret)
|
||||
|
||||
err := auth.CleanExpiredSessions(db)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var count int64
|
||||
db.Model(&auth.Session{}).Where("id = ?", hash).Count(&count)
|
||||
Expect(count).To(Equal(int64(1)))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("RotateSession", func() {
|
||||
It("creates a new session and deletes the old one", func() {
|
||||
token := createTestSession(db, user.ID)
|
||||
hash := auth.HashAPIKey(token, hmacSecret)
|
||||
|
||||
// Get the old session
|
||||
var oldSession auth.Session
|
||||
db.First(&oldSession, "id = ?", hash)
|
||||
|
||||
newToken, err := auth.RotateSession(db, &oldSession, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(newToken).To(HaveLen(64))
|
||||
Expect(newToken).ToNot(Equal(token))
|
||||
|
||||
// Old session should be gone
|
||||
var count int64
|
||||
db.Model(&auth.Session{}).Where("id = ?", hash).Count(&count)
|
||||
Expect(count).To(Equal(int64(0)))
|
||||
|
||||
// New session should exist and validate
|
||||
found, _ := auth.ValidateSession(db, newToken, hmacSecret)
|
||||
Expect(found).ToNot(BeNil())
|
||||
Expect(found.ID).To(Equal(user.ID))
|
||||
})
|
||||
|
||||
It("preserves user ID and expiry", func() {
|
||||
token := createTestSession(db, user.ID)
|
||||
hash := auth.HashAPIKey(token, hmacSecret)
|
||||
|
||||
var oldSession auth.Session
|
||||
db.First(&oldSession, "id = ?", hash)
|
||||
|
||||
newToken, err := auth.RotateSession(db, &oldSession, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
newHash := auth.HashAPIKey(newToken, hmacSecret)
|
||||
var newSession auth.Session
|
||||
db.First(&newSession, "id = ?", newHash)
|
||||
|
||||
Expect(newSession.UserID).To(Equal(oldSession.UserID))
|
||||
Expect(newSession.ExpiresAt).To(BeTemporally("~", oldSession.ExpiresAt, time.Second))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with HMAC secret", func() {
|
||||
hmacSecret := "test-hmac-secret-123"
|
||||
|
||||
It("creates and validates sessions with HMAC secret", func() {
|
||||
token, err := auth.CreateSession(db, user.ID, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
found, session := auth.ValidateSession(db, token, hmacSecret)
|
||||
Expect(found).ToNot(BeNil())
|
||||
Expect(found.ID).To(Equal(user.ID))
|
||||
Expect(session).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("does not validate with wrong HMAC secret", func() {
|
||||
token, err := auth.CreateSession(db, user.ID, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
found, _ := auth.ValidateSession(db, token, "wrong-secret")
|
||||
Expect(found).To(BeNil())
|
||||
})
|
||||
|
||||
It("does not validate with empty HMAC secret", func() {
|
||||
token, err := auth.CreateSession(db, user.ID, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
found, _ := auth.ValidateSession(db, token, "")
|
||||
Expect(found).To(BeNil())
|
||||
})
|
||||
|
||||
It("session created with empty secret does not validate with non-empty secret", func() {
|
||||
token, err := auth.CreateSession(db, user.ID, "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
found, _ := auth.ValidateSession(db, token, hmacSecret)
|
||||
Expect(found).To(BeNil())
|
||||
})
|
||||
|
||||
It("deletes session with correct HMAC secret", func() {
|
||||
token, err := auth.CreateSession(db, user.ID, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = auth.DeleteSession(db, token, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
found, _ := auth.ValidateSession(db, token, hmacSecret)
|
||||
Expect(found).To(BeNil())
|
||||
})
|
||||
|
||||
It("rotates session with HMAC secret", func() {
|
||||
token, err := auth.CreateSession(db, user.ID, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
hash := auth.HashAPIKey(token, hmacSecret)
|
||||
var oldSession auth.Session
|
||||
db.First(&oldSession, "id = ?", hash)
|
||||
|
||||
newToken, err := auth.RotateSession(db, &oldSession, hmacSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Old token should not validate
|
||||
found, _ := auth.ValidateSession(db, token, hmacSecret)
|
||||
Expect(found).To(BeNil())
|
||||
|
||||
// New token should validate
|
||||
found, _ = auth.ValidateSession(db, newToken, hmacSecret)
|
||||
Expect(found).ToNot(BeNil())
|
||||
Expect(found.ID).To(Equal(user.ID))
|
||||
})
|
||||
})
|
||||
})
|
||||
151
core/http/auth/usage.go
Normal file
151
core/http/auth/usage.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UsageRecord is one stored row of token usage for a single API request.
type UsageRecord struct {
	// ID is the auto-incrementing primary key.
	ID uint `gorm:"primaryKey;autoIncrement"`
	// UserID identifies the requesting user; it shares the
	// idx_usage_user_time index with CreatedAt.
	UserID string `gorm:"size:36;index:idx_usage_user_time"`
	// UserName is stored alongside UserID on every record.
	UserName string `gorm:"size:255"`
	// Model is the model name the request targeted (indexed).
	Model string `gorm:"size:255;index"`
	// Endpoint is the API path that served the request.
	Endpoint string `gorm:"size:255"`

	// Token counters for the request.
	PromptTokens     int64
	CompletionTokens int64
	TotalTokens      int64

	// Duration of the request in milliseconds.
	Duration int64
	// CreatedAt is the record timestamp used for period filtering and bucketing.
	CreatedAt time.Time `gorm:"index:idx_usage_user_time"`
}
|
||||
|
||||
// RecordUsage inserts a usage record.
|
||||
func RecordUsage(db *gorm.DB, record *UsageRecord) error {
|
||||
return db.Create(record).Error
|
||||
}
|
||||
|
||||
// UsageBucket is one aggregated row of usage: the summed token counters and
// request count for a time bucket and model, optionally attributed to a user.
type UsageBucket struct {
	// Bucket is the formatted time-bucket label produced by the query.
	Bucket string `json:"bucket"`
	Model  string `json:"model"`

	// UserID and UserName are omitted from JSON when empty.
	UserID   string `json:"user_id,omitempty"`
	UserName string `json:"user_name,omitempty"`

	PromptTokens     int64 `json:"prompt_tokens"`
	CompletionTokens int64 `json:"completion_tokens"`
	TotalTokens      int64 `json:"total_tokens"`
	RequestCount     int64 `json:"request_count"`
}
|
||||
|
||||
// UsageTotals carries overall token counters and a request count, serialized
// with the same JSON field names as the counters on UsageBucket.
type UsageTotals struct {
	PromptTokens     int64 `json:"prompt_tokens"`
	CompletionTokens int64 `json:"completion_tokens"`
	TotalTokens      int64 `json:"total_tokens"`
	RequestCount     int64 `json:"request_count"`
}
|
||||
|
||||
// periodToWindow maps a reporting period to the earliest timestamp to include
// and the SQL expression that formats created_at into a bucket label.
//
//	"day"   -> last 24 hours, hourly buckets
//	"week"  -> last 7 days, daily buckets
//	"all"   -> zero time (callers skip the time filter), monthly buckets
//	other   -> treated as "month": last 30 days, daily buckets
//
// isSQLite selects the strftime form of the expression; otherwise the
// to_char/date_trunc form is returned.
func periodToWindow(period string, isSQLite bool) (time.Time, string) {
	const day = 24 * time.Hour

	var (
		lookback   time.Duration // zero means no lower bound
		sqliteExpr string
		otherExpr  string
	)

	switch period {
	case "day":
		lookback = day
		sqliteExpr = "strftime('%Y-%m-%d %H:00', created_at)"
		otherExpr = "to_char(date_trunc('hour', created_at), 'YYYY-MM-DD HH24:00')"
	case "all":
		sqliteExpr = "strftime('%Y-%m', created_at)"
		otherExpr = "to_char(date_trunc('month', created_at), 'YYYY-MM')"
	default: // "week", "month" and anything unrecognized bucket by day
		lookback = 30 * day
		if period == "week" {
			lookback = 7 * day
		}
		sqliteExpr = "strftime('%Y-%m-%d', created_at)"
		otherExpr = "to_char(date_trunc('day', created_at), 'YYYY-MM-DD')"
	}

	var since time.Time
	if lookback > 0 {
		since = time.Now().Add(-lookback)
	}

	if isSQLite {
		return since, sqliteExpr
	}
	return since, otherExpr
}
|
||||
|
||||
func isSQLiteDB(db *gorm.DB) bool {
|
||||
return strings.Contains(db.Dialector.Name(), "sqlite")
|
||||
}
|
||||
|
||||
// GetUserUsage returns aggregated usage for a single user.
|
||||
func GetUserUsage(db *gorm.DB, userID, period string) ([]UsageBucket, error) {
|
||||
sqlite := isSQLiteDB(db)
|
||||
since, dateFmt := periodToWindow(period, sqlite)
|
||||
|
||||
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
|
||||
|
||||
query := db.Model(&UsageRecord{}).
|
||||
Select(bucketExpr+", model, "+
|
||||
"SUM(prompt_tokens) as prompt_tokens, "+
|
||||
"SUM(completion_tokens) as completion_tokens, "+
|
||||
"SUM(total_tokens) as total_tokens, "+
|
||||
"COUNT(*) as request_count").
|
||||
Where("user_id = ?", userID).
|
||||
Group("bucket, model").
|
||||
Order("bucket ASC")
|
||||
|
||||
if !since.IsZero() {
|
||||
query = query.Where("created_at >= ?", since)
|
||||
}
|
||||
|
||||
var buckets []UsageBucket
|
||||
if err := query.Find(&buckets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buckets, nil
|
||||
}
|
||||
|
||||
// GetAllUsage returns aggregated usage for all users (admin). Optional userID filter.
|
||||
func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
|
||||
sqlite := isSQLiteDB(db)
|
||||
since, dateFmt := periodToWindow(period, sqlite)
|
||||
|
||||
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
|
||||
|
||||
query := db.Model(&UsageRecord{}).
|
||||
Select(bucketExpr+", model, user_id, user_name, "+
|
||||
"SUM(prompt_tokens) as prompt_tokens, "+
|
||||
"SUM(completion_tokens) as completion_tokens, "+
|
||||
"SUM(total_tokens) as total_tokens, "+
|
||||
"COUNT(*) as request_count").
|
||||
Group("bucket, model, user_id, user_name").
|
||||
Order("bucket ASC")
|
||||
|
||||
if !since.IsZero() {
|
||||
query = query.Where("created_at >= ?", since)
|
||||
}
|
||||
|
||||
if userID != "" {
|
||||
query = query.Where("user_id = ?", userID)
|
||||
}
|
||||
|
||||
var buckets []UsageBucket
|
||||
if err := query.Find(&buckets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buckets, nil
|
||||
}
|
||||
161
core/http/auth/usage_test.go
Normal file
161
core/http/auth/usage_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
//go:build auth
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Usage", func() {
|
||||
Describe("RecordUsage", func() {
|
||||
It("inserts a usage record", func() {
|
||||
db := testDB()
|
||||
record := &auth.UsageRecord{
|
||||
UserID: "user-1",
|
||||
UserName: "Test User",
|
||||
Model: "gpt-4",
|
||||
Endpoint: "/v1/chat/completions",
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
TotalTokens: 150,
|
||||
Duration: 1200,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
err := auth.RecordUsage(db, record)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(record.ID).ToNot(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GetUserUsage", func() {
|
||||
It("returns aggregated usage for a specific user", func() {
|
||||
db := testDB()
|
||||
|
||||
// Insert records for two users
|
||||
for i := 0; i < 3; i++ {
|
||||
err := auth.RecordUsage(db, &auth.UsageRecord{
|
||||
UserID: "user-a",
|
||||
UserName: "Alice",
|
||||
Model: "gpt-4",
|
||||
Endpoint: "/v1/chat/completions",
|
||||
PromptTokens: 100,
|
||||
TotalTokens: 150,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := auth.RecordUsage(db, &auth.UsageRecord{
|
||||
UserID: "user-b",
|
||||
UserName: "Bob",
|
||||
Model: "gpt-4",
|
||||
PromptTokens: 200,
|
||||
TotalTokens: 300,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
buckets, err := auth.GetUserUsage(db, "user-a", "month")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(buckets).ToNot(BeEmpty())
|
||||
|
||||
// All returned buckets should be for user-a's model
|
||||
totalPrompt := int64(0)
|
||||
for _, b := range buckets {
|
||||
totalPrompt += b.PromptTokens
|
||||
}
|
||||
Expect(totalPrompt).To(Equal(int64(300)))
|
||||
})
|
||||
|
||||
It("filters by period", func() {
|
||||
db := testDB()
|
||||
|
||||
// Record in the past (beyond day window)
|
||||
err := auth.RecordUsage(db, &auth.UsageRecord{
|
||||
UserID: "user-c",
|
||||
UserName: "Carol",
|
||||
Model: "gpt-4",
|
||||
PromptTokens: 100,
|
||||
TotalTokens: 100,
|
||||
CreatedAt: time.Now().Add(-48 * time.Hour),
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Record now
|
||||
err = auth.RecordUsage(db, &auth.UsageRecord{
|
||||
UserID: "user-c",
|
||||
UserName: "Carol",
|
||||
Model: "gpt-4",
|
||||
PromptTokens: 200,
|
||||
TotalTokens: 200,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Day period should only include recent record
|
||||
buckets, err := auth.GetUserUsage(db, "user-c", "day")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
totalPrompt := int64(0)
|
||||
for _, b := range buckets {
|
||||
totalPrompt += b.PromptTokens
|
||||
}
|
||||
Expect(totalPrompt).To(Equal(int64(200)))
|
||||
|
||||
// Month period should include both
|
||||
buckets, err = auth.GetUserUsage(db, "user-c", "month")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
totalPrompt = 0
|
||||
for _, b := range buckets {
|
||||
totalPrompt += b.PromptTokens
|
||||
}
|
||||
Expect(totalPrompt).To(Equal(int64(300)))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GetAllUsage", func() {
|
||||
It("returns usage for all users", func() {
|
||||
db := testDB()
|
||||
|
||||
for _, uid := range []string{"user-x", "user-y"} {
|
||||
err := auth.RecordUsage(db, &auth.UsageRecord{
|
||||
UserID: uid,
|
||||
UserName: uid,
|
||||
Model: "gpt-4",
|
||||
PromptTokens: 100,
|
||||
TotalTokens: 150,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
buckets, err := auth.GetAllUsage(db, "month", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(buckets)).To(BeNumerically(">=", 2))
|
||||
})
|
||||
|
||||
It("filters by user ID when specified", func() {
|
||||
db := testDB()
|
||||
|
||||
err := auth.RecordUsage(db, &auth.UsageRecord{
|
||||
UserID: "user-p", UserName: "Pat", Model: "gpt-4",
|
||||
PromptTokens: 100, TotalTokens: 100, CreatedAt: time.Now(),
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = auth.RecordUsage(db, &auth.UsageRecord{
|
||||
UserID: "user-q", UserName: "Quinn", Model: "gpt-4",
|
||||
PromptTokens: 200, TotalTokens: 200, CreatedAt: time.Now(),
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
buckets, err := auth.GetAllUsage(db, "month", "user-p")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for _, b := range buckets {
|
||||
Expect(b.UserID).To(Equal("user-p"))
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
@@ -197,54 +198,24 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
xlog.Debug("Anthropic MCP re-templating", "iteration", mcpIteration, "prompt_len", len(predInput))
|
||||
}
|
||||
|
||||
images := []string{}
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
// Populate openAIReq fields for ComputeChoices
|
||||
openAIReq.Tools = convertFuncsToOpenAITools(funcs)
|
||||
openAIReq.ToolsChoice = input.ToolChoice
|
||||
openAIReq.Metadata = input.Metadata
|
||||
|
||||
toolsJSON := ""
|
||||
if len(funcs) > 0 {
|
||||
openAITools := make([]functions.Tool, len(funcs))
|
||||
for i, f := range funcs {
|
||||
openAITools[i] = functions.Tool{Type: "function", Function: f}
|
||||
}
|
||||
if toolsBytes, err := json.Marshal(openAITools); err == nil {
|
||||
toolsJSON = string(toolsBytes)
|
||||
}
|
||||
var result string
|
||||
cb := func(s string, c *[]schema.Choice) {
|
||||
result = s
|
||||
}
|
||||
toolChoiceJSON := ""
|
||||
if input.ToolChoice != nil {
|
||||
if toolChoiceBytes, err := json.Marshal(input.ToolChoice); err == nil {
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
input.Context, predInput, openAIReq.Messages, images, nil, nil, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, nil, nil, nil, input.Metadata)
|
||||
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
|
||||
if err != nil {
|
||||
xlog.Error("Anthropic model inference failed", "error", err)
|
||||
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err))
|
||||
}
|
||||
|
||||
const maxEmptyRetries = 5
|
||||
var prediction backend.LLMResponse
|
||||
var result string
|
||||
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
|
||||
prediction, err = predFunc()
|
||||
if err != nil {
|
||||
xlog.Error("Anthropic prediction failed", "error", err)
|
||||
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err))
|
||||
}
|
||||
result = backend.Finetune(*cfg, predInput, prediction.Response)
|
||||
if result != "" || !shouldUseFn {
|
||||
break
|
||||
}
|
||||
xlog.Warn("Anthropic: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
|
||||
}
|
||||
|
||||
// Try pre-parsed tool calls from C++ autoparser first, fall back to text parsing
|
||||
var toolCalls []functions.FuncCallResults
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] Anthropic: using pre-parsed tool calls", "count", len(deltaToolCalls))
|
||||
toolCalls = deltaToolCalls
|
||||
} else {
|
||||
@@ -350,8 +321,8 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
StopReason: &stopReason,
|
||||
Content: contentBlocks,
|
||||
Usage: schema.AnthropicUsage{
|
||||
InputTokens: prediction.Usage.Prompt,
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
InputTokens: tokenUsage.Prompt,
|
||||
OutputTokens: tokenUsage.Completion,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -397,12 +368,6 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
xlog.Debug("Anthropic MCP stream re-templating", "iteration", mcpIteration)
|
||||
}
|
||||
|
||||
openAIMessages := openAIReq.Messages
|
||||
images := []string{}
|
||||
for _, m := range openAIMessages {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
|
||||
// Track accumulated content for tool call detection
|
||||
accumulatedContent := ""
|
||||
currentBlockIndex := 0
|
||||
@@ -481,38 +446,19 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
return true
|
||||
}
|
||||
|
||||
toolsJSON := ""
|
||||
if len(funcs) > 0 {
|
||||
openAITools := make([]functions.Tool, len(funcs))
|
||||
for i, f := range funcs {
|
||||
openAITools[i] = functions.Tool{Type: "function", Function: f}
|
||||
}
|
||||
if toolsBytes, err := json.Marshal(openAITools); err == nil {
|
||||
toolsJSON = string(toolsBytes)
|
||||
}
|
||||
}
|
||||
toolChoiceJSON := ""
|
||||
if input.ToolChoice != nil {
|
||||
if toolChoiceBytes, err := json.Marshal(input.ToolChoice); err == nil {
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
}
|
||||
// Populate openAIReq fields for ComputeChoices
|
||||
openAIReq.Tools = convertFuncsToOpenAITools(funcs)
|
||||
openAIReq.ToolsChoice = input.ToolChoice
|
||||
openAIReq.Metadata = input.Metadata
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
input.Context, predInput, openAIMessages, images, nil, nil, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, nil, nil, nil, input.Metadata)
|
||||
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback)
|
||||
if err != nil {
|
||||
xlog.Error("Anthropic stream model inference failed", "error", err)
|
||||
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err))
|
||||
}
|
||||
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
xlog.Error("Anthropic stream prediction failed", "error", err)
|
||||
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err))
|
||||
}
|
||||
|
||||
// Also check chat deltas for tool calls
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 && len(collectedToolCalls) == 0 {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 && len(collectedToolCalls) == 0 {
|
||||
collectedToolCalls = deltaToolCalls
|
||||
}
|
||||
|
||||
@@ -595,7 +541,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
StopReason: &stopReason,
|
||||
},
|
||||
Usage: &schema.AnthropicUsage{
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
OutputTokens: tokenUsage.Completion,
|
||||
},
|
||||
})
|
||||
|
||||
@@ -613,6 +559,14 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertFuncsToOpenAITools(funcs functions.Functions) []functions.Tool {
|
||||
tools := make([]functions.Tool, len(funcs))
|
||||
for i, f := range funcs {
|
||||
tools[i] = functions.Tool{Type: "function", Function: f}
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
|
||||
@@ -12,27 +12,54 @@ import (
|
||||
func ListCollectionsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
collections, err := svc.ListCollections()
|
||||
userID := getUserID(c)
|
||||
cols, err := svc.ListCollectionsForUser(userID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"collections": collections,
|
||||
"count": len(collections),
|
||||
})
|
||||
|
||||
resp := map[string]any{
|
||||
"collections": cols,
|
||||
"count": len(cols),
|
||||
}
|
||||
|
||||
// Admin cross-user aggregation
|
||||
if wantsAllUsers(c) {
|
||||
usm := svc.UserServicesManager()
|
||||
if usm != nil {
|
||||
userIDs, _ := usm.ListAllUserIDs()
|
||||
userGroups := map[string]any{}
|
||||
for _, uid := range userIDs {
|
||||
if uid == userID {
|
||||
continue
|
||||
}
|
||||
userCols, err := svc.ListCollectionsForUser(uid)
|
||||
if err != nil || len(userCols) == 0 {
|
||||
continue
|
||||
}
|
||||
userGroups[uid] = map[string]any{"collections": userCols}
|
||||
}
|
||||
if len(userGroups) > 0 {
|
||||
resp["user_groups"] = userGroups
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func CreateCollectionEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
var payload struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := svc.CreateCollection(payload.Name); err != nil {
|
||||
if err := svc.CreateCollectionForUser(userID, payload.Name); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusCreated, map[string]string{"status": "ok", "name": payload.Name})
|
||||
@@ -42,33 +69,33 @@ func CreateCollectionEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"})
|
||||
}
|
||||
if svc.CollectionEntryExists(name, file.Filename) {
|
||||
return c.JSON(http.StatusConflict, map[string]string{"error": "entry already exists"})
|
||||
}
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
defer src.Close()
|
||||
if err := svc.UploadToCollection(name, file.Filename, src); err != nil {
|
||||
key, err := svc.UploadToCollectionForUser(userID, name, file.Filename, src)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok", "filename": file.Filename})
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok", "filename": file.Filename, "key": key})
|
||||
}
|
||||
}
|
||||
|
||||
func ListCollectionEntriesEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
entries, err := svc.ListCollectionEntries(c.Param("name"))
|
||||
userID := effectiveUserID(c)
|
||||
entries, err := svc.ListCollectionEntriesForUser(userID, c.Param("name"))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -85,12 +112,13 @@ func ListCollectionEntriesEndpoint(app *application.Application) echo.HandlerFun
|
||||
func GetCollectionEntryContentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
entryParam := c.Param("*")
|
||||
entry, err := url.PathUnescape(entryParam)
|
||||
if err != nil {
|
||||
entry = entryParam
|
||||
}
|
||||
content, chunkCount, err := svc.GetCollectionEntryContent(c.Param("name"), entry)
|
||||
content, chunkCount, err := svc.GetCollectionEntryContentForUser(userID, c.Param("name"), entry)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -107,6 +135,7 @@ func GetCollectionEntryContentEndpoint(app *application.Application) echo.Handle
|
||||
func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
var payload struct {
|
||||
Query string `json:"query"`
|
||||
MaxResults int `json:"max_results"`
|
||||
@@ -114,7 +143,7 @@ func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
results, err := svc.SearchCollection(c.Param("name"), payload.Query, payload.MaxResults)
|
||||
results, err := svc.SearchCollectionForUser(userID, c.Param("name"), payload.Query, payload.MaxResults)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -131,7 +160,8 @@ func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func ResetCollectionEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
if err := svc.ResetCollection(c.Param("name")); err != nil {
|
||||
userID := effectiveUserID(c)
|
||||
if err := svc.ResetCollectionForUser(userID, c.Param("name")); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -144,13 +174,14 @@ func ResetCollectionEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func DeleteCollectionEntryEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
var payload struct {
|
||||
Entry string `json:"entry"`
|
||||
}
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
remaining, err := svc.DeleteCollectionEntry(c.Param("name"), payload.Entry)
|
||||
remaining, err := svc.DeleteCollectionEntryForUser(userID, c.Param("name"), payload.Entry)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -167,6 +198,7 @@ func DeleteCollectionEntryEndpoint(app *application.Application) echo.HandlerFun
|
||||
func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
var payload struct {
|
||||
URL string `json:"url"`
|
||||
UpdateInterval int `json:"update_interval"`
|
||||
@@ -177,7 +209,7 @@ func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||
if payload.UpdateInterval < 1 {
|
||||
payload.UpdateInterval = 60
|
||||
}
|
||||
if err := svc.AddCollectionSource(c.Param("name"), payload.URL, payload.UpdateInterval); err != nil {
|
||||
if err := svc.AddCollectionSourceForUser(userID, c.Param("name"), payload.URL, payload.UpdateInterval); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -190,23 +222,46 @@ func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||
func RemoveCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
var payload struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := svc.RemoveCollectionSource(c.Param("name"), payload.URL); err != nil {
|
||||
if err := svc.RemoveCollectionSourceForUser(userID, c.Param("name"), payload.URL); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
}
|
||||
}
|
||||
|
||||
// GetCollectionEntryRawFileEndpoint serves the original uploaded binary file.
|
||||
func GetCollectionEntryRawFileEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
entryParam := c.Param("*")
|
||||
entry, err := url.PathUnescape(entryParam)
|
||||
if err != nil {
|
||||
entry = entryParam
|
||||
}
|
||||
fpath, err := svc.GetCollectionEntryFilePathForUser(userID, c.Param("name"), entry)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.File(fpath)
|
||||
}
|
||||
}
|
||||
|
||||
func ListCollectionSourcesEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
sources, err := svc.ListCollectionSources(c.Param("name"))
|
||||
userID := effectiveUserID(c)
|
||||
sources, err := svc.ListCollectionSourcesForUser(userID, c.Param("name"))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
|
||||
@@ -8,19 +8,27 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
)
|
||||
|
||||
// CreateTaskEndpoint creates a new agent task
|
||||
// @Summary Create a new agent task
|
||||
// @Description Create a new reusable agent task with prompt template and configuration
|
||||
// @Tags agent-jobs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param task body schema.Task true "Task definition"
|
||||
// @Success 201 {object} map[string]string "Task created"
|
||||
// @Failure 400 {object} map[string]string "Invalid request"
|
||||
// @Failure 500 {object} map[string]string "Internal server error"
|
||||
// @Router /api/agent/tasks [post]
|
||||
// getJobService returns the job service for the current user.
|
||||
// Falls back to the global service when no user is authenticated.
|
||||
func getJobService(app *application.Application, c echo.Context) *services.AgentJobService {
|
||||
userID := getUserID(c)
|
||||
if userID == "" {
|
||||
return app.AgentJobService()
|
||||
}
|
||||
svc := app.AgentPoolService()
|
||||
if svc == nil {
|
||||
return app.AgentJobService()
|
||||
}
|
||||
jobSvc, err := svc.JobServiceForUser(userID)
|
||||
if err != nil {
|
||||
return app.AgentJobService()
|
||||
}
|
||||
return jobSvc
|
||||
}
|
||||
|
||||
func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var task schema.Task
|
||||
@@ -28,7 +36,7 @@ func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()})
|
||||
}
|
||||
|
||||
id, err := app.AgentJobService().CreateTask(task)
|
||||
id, err := getJobService(app, c).CreateTask(task)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -37,18 +45,6 @@ func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTaskEndpoint updates an existing task
|
||||
// @Summary Update an agent task
|
||||
// @Description Update an existing agent task
|
||||
// @Tags agent-jobs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "Task ID"
|
||||
// @Param task body schema.Task true "Updated task definition"
|
||||
// @Success 200 {object} map[string]string "Task updated"
|
||||
// @Failure 400 {object} map[string]string "Invalid request"
|
||||
// @Failure 404 {object} map[string]string "Task not found"
|
||||
// @Router /api/agent/tasks/{id} [put]
|
||||
func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
@@ -57,7 +53,7 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()})
|
||||
}
|
||||
|
||||
if err := app.AgentJobService().UpdateTask(id, task); err != nil {
|
||||
if err := getJobService(app, c).UpdateTask(id, task); err != nil {
|
||||
if err.Error() == "task not found: "+id {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -68,19 +64,10 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteTaskEndpoint deletes a task
|
||||
// @Summary Delete an agent task
|
||||
// @Description Delete an agent task by ID
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Task ID"
|
||||
// @Success 200 {object} map[string]string "Task deleted"
|
||||
// @Failure 404 {object} map[string]string "Task not found"
|
||||
// @Router /api/agent/tasks/{id} [delete]
|
||||
func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
if err := app.AgentJobService().DeleteTask(id); err != nil {
|
||||
if err := getJobService(app, c).DeleteTask(id); err != nil {
|
||||
if err.Error() == "task not found: "+id {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -91,33 +78,52 @@ func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ListTasksEndpoint lists all tasks
|
||||
// @Summary List all agent tasks
|
||||
// @Description Get a list of all agent tasks
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Success 200 {array} schema.Task "List of tasks"
|
||||
// @Router /api/agent/tasks [get]
|
||||
func ListTasksEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
tasks := app.AgentJobService().ListTasks()
|
||||
jobSvc := getJobService(app, c)
|
||||
tasks := jobSvc.ListTasks()
|
||||
|
||||
// Admin cross-user aggregation
|
||||
if wantsAllUsers(c) {
|
||||
svc := app.AgentPoolService()
|
||||
if svc != nil {
|
||||
usm := svc.UserServicesManager()
|
||||
if usm != nil {
|
||||
userID := getUserID(c)
|
||||
userIDs, _ := usm.ListAllUserIDs()
|
||||
userGroups := map[string]any{}
|
||||
for _, uid := range userIDs {
|
||||
if uid == userID {
|
||||
continue
|
||||
}
|
||||
userJobSvc, err := svc.JobServiceForUser(uid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
userTasks := userJobSvc.ListTasks()
|
||||
if len(userTasks) == 0 {
|
||||
continue
|
||||
}
|
||||
userGroups[uid] = map[string]any{"tasks": userTasks}
|
||||
}
|
||||
if len(userGroups) > 0 {
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"tasks": tasks,
|
||||
"user_groups": userGroups,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, tasks)
|
||||
}
|
||||
}
|
||||
|
||||
// GetTaskEndpoint gets a task by ID
|
||||
// @Summary Get an agent task
|
||||
// @Description Get an agent task by ID
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Task ID"
|
||||
// @Success 200 {object} schema.Task "Task details"
|
||||
// @Failure 404 {object} map[string]string "Task not found"
|
||||
// @Router /api/agent/tasks/{id} [get]
|
||||
func GetTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
task, err := app.AgentJobService().GetTask(id)
|
||||
task, err := getJobService(app, c).GetTask(id)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -126,16 +132,6 @@ func GetTaskEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteJobEndpoint executes a job
|
||||
// @Summary Execute an agent job
|
||||
// @Description Create and execute a new agent job
|
||||
// @Tags agent-jobs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body schema.JobExecutionRequest true "Job execution request"
|
||||
// @Success 201 {object} schema.JobExecutionResponse "Job created"
|
||||
// @Failure 400 {object} map[string]string "Invalid request"
|
||||
// @Router /api/agent/jobs/execute [post]
|
||||
func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req schema.JobExecutionRequest
|
||||
@@ -147,7 +143,6 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
req.Parameters = make(map[string]string)
|
||||
}
|
||||
|
||||
// Build multimedia struct from request
|
||||
var multimedia *schema.MultimediaAttachment
|
||||
if len(req.Images) > 0 || len(req.Videos) > 0 || len(req.Audios) > 0 || len(req.Files) > 0 {
|
||||
multimedia = &schema.MultimediaAttachment{
|
||||
@@ -158,7 +153,7 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
jobID, err := app.AgentJobService().ExecuteJob(req.TaskID, req.Parameters, "api", multimedia)
|
||||
jobID, err := getJobService(app, c).ExecuteJob(req.TaskID, req.Parameters, "api", multimedia)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -172,19 +167,10 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// GetJobEndpoint gets a job by ID
|
||||
// @Summary Get an agent job
|
||||
// @Description Get an agent job by ID
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Job ID"
|
||||
// @Success 200 {object} schema.Job "Job details"
|
||||
// @Failure 404 {object} map[string]string "Job not found"
|
||||
// @Router /api/agent/jobs/{id} [get]
|
||||
func GetJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
job, err := app.AgentJobService().GetJob(id)
|
||||
job, err := getJobService(app, c).GetJob(id)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -193,16 +179,6 @@ func GetJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ListJobsEndpoint lists jobs with optional filtering
|
||||
// @Summary List agent jobs
|
||||
// @Description Get a list of agent jobs, optionally filtered by task_id and status
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param task_id query string false "Filter by task ID"
|
||||
// @Param status query string false "Filter by status (pending, running, completed, failed, cancelled)"
|
||||
// @Param limit query int false "Limit number of results"
|
||||
// @Success 200 {array} schema.Job "List of jobs"
|
||||
// @Router /api/agent/jobs [get]
|
||||
func ListJobsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var taskID *string
|
||||
@@ -224,25 +200,50 @@ func ListJobsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
jobs := app.AgentJobService().ListJobs(taskID, status, limit)
|
||||
jobSvc := getJobService(app, c)
|
||||
jobs := jobSvc.ListJobs(taskID, status, limit)
|
||||
|
||||
// Admin cross-user aggregation
|
||||
if wantsAllUsers(c) {
|
||||
svc := app.AgentPoolService()
|
||||
if svc != nil {
|
||||
usm := svc.UserServicesManager()
|
||||
if usm != nil {
|
||||
userID := getUserID(c)
|
||||
userIDs, _ := usm.ListAllUserIDs()
|
||||
userGroups := map[string]any{}
|
||||
for _, uid := range userIDs {
|
||||
if uid == userID {
|
||||
continue
|
||||
}
|
||||
userJobSvc, err := svc.JobServiceForUser(uid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
userJobs := userJobSvc.ListJobs(taskID, status, limit)
|
||||
if len(userJobs) == 0 {
|
||||
continue
|
||||
}
|
||||
userGroups[uid] = map[string]any{"jobs": userJobs}
|
||||
}
|
||||
if len(userGroups) > 0 {
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"jobs": jobs,
|
||||
"user_groups": userGroups,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, jobs)
|
||||
}
|
||||
}
|
||||
|
||||
// CancelJobEndpoint cancels a running job
|
||||
// @Summary Cancel an agent job
|
||||
// @Description Cancel a running or pending agent job
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Job ID"
|
||||
// @Success 200 {object} map[string]string "Job cancelled"
|
||||
// @Failure 400 {object} map[string]string "Job cannot be cancelled"
|
||||
// @Failure 404 {object} map[string]string "Job not found"
|
||||
// @Router /api/agent/jobs/{id}/cancel [post]
|
||||
func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
if err := app.AgentJobService().CancelJob(id); err != nil {
|
||||
if err := getJobService(app, c).CancelJob(id); err != nil {
|
||||
if err.Error() == "job not found: "+id {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -253,19 +254,10 @@ func CancelJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteJobEndpoint deletes a job
|
||||
// @Summary Delete an agent job
|
||||
// @Description Delete an agent job by ID
|
||||
// @Tags agent-jobs
|
||||
// @Produce json
|
||||
// @Param id path string true "Job ID"
|
||||
// @Success 200 {object} map[string]string "Job deleted"
|
||||
// @Failure 404 {object} map[string]string "Job not found"
|
||||
// @Router /api/agent/jobs/{id} [delete]
|
||||
func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := c.Param("id")
|
||||
if err := app.AgentJobService().DeleteJob(id); err != nil {
|
||||
if err := getJobService(app, c).DeleteJob(id); err != nil {
|
||||
if err.Error() == "job not found: "+id {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -276,52 +268,33 @@ func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteTaskByNameEndpoint executes a task by name
|
||||
// @Summary Execute a task by name
|
||||
// @Description Execute an agent task by its name (convenience endpoint). Parameters can be provided in the request body as a JSON object with string values.
|
||||
// @Tags agent-jobs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param name path string true "Task name"
|
||||
// @Param request body map[string]string false "Template parameters (JSON object with string values)"
|
||||
// @Success 201 {object} schema.JobExecutionResponse "Job created"
|
||||
// @Failure 400 {object} map[string]string "Invalid request"
|
||||
// @Failure 404 {object} map[string]string "Task not found"
|
||||
// @Router /api/agent/tasks/{name}/execute [post]
|
||||
func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
name := c.Param("name")
|
||||
var params map[string]string
|
||||
|
||||
// Try to bind parameters from request body
|
||||
// If body is empty or invalid, use empty params
|
||||
if c.Request().ContentLength > 0 {
|
||||
if err := c.Bind(¶ms); err != nil {
|
||||
// If binding fails, try to read as raw JSON
|
||||
body := make(map[string]interface{})
|
||||
if err := c.Bind(&body); err == nil {
|
||||
// Convert interface{} values to strings
|
||||
params = make(map[string]string)
|
||||
for k, v := range body {
|
||||
if str, ok := v.(string); ok {
|
||||
params[k] = str
|
||||
} else {
|
||||
// Convert non-string values to string
|
||||
params[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If all binding fails, use empty params
|
||||
params = make(map[string]string)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No body provided, use empty params
|
||||
params = make(map[string]string)
|
||||
}
|
||||
|
||||
// Find task by name
|
||||
tasks := app.AgentJobService().ListTasks()
|
||||
jobSvc := getJobService(app, c)
|
||||
tasks := jobSvc.ListTasks()
|
||||
var task *schema.Task
|
||||
for _, t := range tasks {
|
||||
if t.Name == name {
|
||||
@@ -334,7 +307,7 @@ func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": "Task not found: " + name})
|
||||
}
|
||||
|
||||
jobID, err := app.AgentJobService().ExecuteJob(task.ID, params, "api", nil)
|
||||
jobID, err := jobSvc.ExecuteJob(task.ID, params, "api", nil)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ func skillToResponse(s skilldomain.Skill) skillResponse {
|
||||
out.License = s.Metadata.License
|
||||
out.Compatibility = s.Metadata.Compatibility
|
||||
out.Metadata = s.Metadata.Metadata
|
||||
out.AllowedTools = s.Metadata.AllowedTools
|
||||
out.AllowedTools = s.Metadata.AllowedTools.String()
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -44,10 +44,38 @@ func skillsToResponses(skills []skilldomain.Skill) []skillResponse {
|
||||
func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
skills, err := svc.ListSkills()
|
||||
userID := getUserID(c)
|
||||
skills, err := svc.ListSkillsForUser(userID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
// Admin cross-user aggregation
|
||||
if wantsAllUsers(c) {
|
||||
usm := svc.UserServicesManager()
|
||||
if usm != nil {
|
||||
userIDs, _ := usm.ListAllUserIDs()
|
||||
userGroups := map[string]any{}
|
||||
for _, uid := range userIDs {
|
||||
if uid == userID {
|
||||
continue
|
||||
}
|
||||
userSkills, err := svc.ListSkillsForUser(uid)
|
||||
if err != nil || len(userSkills) == 0 {
|
||||
continue
|
||||
}
|
||||
userGroups[uid] = map[string]any{"skills": skillsToResponses(userSkills)}
|
||||
}
|
||||
resp := map[string]any{
|
||||
"skills": skillsToResponses(skills),
|
||||
}
|
||||
if len(userGroups) > 0 {
|
||||
resp["user_groups"] = userGroups
|
||||
}
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, skillsToResponses(skills))
|
||||
}
|
||||
}
|
||||
@@ -55,7 +83,8 @@ func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
cfg := svc.GetSkillsConfig()
|
||||
userID := getUserID(c)
|
||||
cfg := svc.GetSkillsConfigForUser(userID)
|
||||
return c.JSON(http.StatusOK, cfg)
|
||||
}
|
||||
}
|
||||
@@ -63,8 +92,9 @@ func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
query := c.QueryParam("q")
|
||||
skills, err := svc.SearchSkills(query)
|
||||
skills, err := svc.SearchSkillsForUser(userID, query)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -75,6 +105,7 @@ func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
var payload struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
@@ -87,7 +118,7 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
skill, err := svc.CreateSkill(payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
|
||||
skill, err := svc.CreateSkillForUser(userID, payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "already exists") {
|
||||
return c.JSON(http.StatusConflict, map[string]string{"error": err.Error()})
|
||||
@@ -101,7 +132,8 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func GetSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
skill, err := svc.GetSkill(c.Param("name"))
|
||||
userID := effectiveUserID(c)
|
||||
skill, err := svc.GetSkillForUser(userID, c.Param("name"))
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -112,6 +144,7 @@ func GetSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
var payload struct {
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content"`
|
||||
@@ -123,7 +156,7 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
skill, err := svc.UpdateSkill(c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
|
||||
skill, err := svc.UpdateSkillForUser(userID, c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -137,7 +170,8 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
if err := svc.DeleteSkill(c.Param("name")); err != nil {
|
||||
userID := effectiveUserID(c)
|
||||
if err := svc.DeleteSkillForUser(userID, c.Param("name")); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
@@ -147,9 +181,9 @@ func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
// The wildcard param captures the path after /export/
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("*")
|
||||
data, err := svc.ExportSkill(name)
|
||||
data, err := svc.ExportSkillForUser(userID, name)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -162,6 +196,7 @@ func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"})
|
||||
@@ -175,7 +210,7 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
skill, err := svc.ImportSkill(data)
|
||||
skill, err := svc.ImportSkillForUser(userID, data)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -188,7 +223,8 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
resources, skill, err := svc.ListSkillResources(c.Param("name"))
|
||||
userID := effectiveUserID(c)
|
||||
resources, skill, err := svc.ListSkillResourcesForUser(userID, c.Param("name"))
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -225,7 +261,8 @@ func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
content, info, err := svc.GetSkillResource(c.Param("name"), c.Param("*"))
|
||||
userID := effectiveUserID(c)
|
||||
content, info, err := svc.GetSkillResourceForUser(userID, c.Param("name"), c.Param("*"))
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -245,6 +282,7 @@ func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file is required"})
|
||||
@@ -262,7 +300,7 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := svc.CreateSkillResource(c.Param("name"), path, data); err != nil {
|
||||
if err := svc.CreateSkillResourceForUser(userID, c.Param("name"), path, data); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusCreated, map[string]string{"path": path})
|
||||
@@ -272,13 +310,14 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||
func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
var payload struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := svc.UpdateSkillResource(c.Param("name"), c.Param("*"), payload.Content); err != nil {
|
||||
if err := svc.UpdateSkillResourceForUser(userID, c.Param("name"), c.Param("*"), payload.Content); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
@@ -288,7 +327,8 @@ func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||
func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
if err := svc.DeleteSkillResource(c.Param("name"), c.Param("*")); err != nil {
|
||||
userID := getUserID(c)
|
||||
if err := svc.DeleteSkillResourceForUser(userID, c.Param("name"), c.Param("*")); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
@@ -300,7 +340,8 @@ func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc
|
||||
func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
repos, err := svc.ListGitRepos()
|
||||
userID := getUserID(c)
|
||||
repos, err := svc.ListGitReposForUser(userID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -311,13 +352,14 @@ func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
var payload struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
repo, err := svc.AddGitRepo(payload.URL)
|
||||
repo, err := svc.AddGitRepoForUser(userID, payload.URL)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -328,6 +370,7 @@ func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
var payload struct {
|
||||
URL string `json:"url"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
@@ -335,7 +378,7 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
repo, err := svc.UpdateGitRepo(c.Param("id"), payload.URL, payload.Enabled)
|
||||
repo, err := svc.UpdateGitRepoForUser(userID, c.Param("id"), payload.URL, payload.Enabled)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -349,7 +392,8 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
if err := svc.DeleteGitRepo(c.Param("id")); err != nil {
|
||||
userID := getUserID(c)
|
||||
if err := svc.DeleteGitRepoForUser(userID, c.Param("id")); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -362,7 +406,8 @@ func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
if err := svc.SyncGitRepo(c.Param("id")); err != nil {
|
||||
userID := getUserID(c)
|
||||
if err := svc.SyncGitRepoForUser(userID, c.Param("id")); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusAccepted, map[string]string{"status": "syncing"})
|
||||
@@ -372,7 +417,8 @@ func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func ToggleGitRepoEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
repo, err := svc.ToggleGitRepo(c.Param("id"))
|
||||
userID := getUserID(c)
|
||||
repo, err := svc.ToggleGitRepoForUser(userID, c.Param("id"))
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/LocalAGI/core/state"
|
||||
@@ -19,10 +20,42 @@ import (
|
||||
agiServices "github.com/mudler/LocalAGI/services"
|
||||
)
|
||||
|
||||
// getUserID extracts the scoped user ID from the request context.
|
||||
// Returns empty string when auth is not active (backward compat).
|
||||
func getUserID(c echo.Context) string {
|
||||
user := auth.GetUser(c)
|
||||
if user == nil {
|
||||
return ""
|
||||
}
|
||||
return user.ID
|
||||
}
|
||||
|
||||
// isAdminUser returns true if the authenticated user has admin role.
|
||||
func isAdminUser(c echo.Context) bool {
|
||||
user := auth.GetUser(c)
|
||||
return user != nil && user.Role == auth.RoleAdmin
|
||||
}
|
||||
|
||||
// wantsAllUsers returns true if the request has ?all_users=true and the user is admin.
|
||||
func wantsAllUsers(c echo.Context) bool {
|
||||
return c.QueryParam("all_users") == "true" && isAdminUser(c)
|
||||
}
|
||||
|
||||
// effectiveUserID returns the user ID to scope operations to.
|
||||
// SECURITY: Only admins may supply ?user_id=<id> to operate on another user's
|
||||
// resources. Non-admin callers always get their own ID regardless of query params.
|
||||
func effectiveUserID(c echo.Context) string {
|
||||
if targetUID := c.QueryParam("user_id"); targetUID != "" && isAdminUser(c) {
|
||||
return targetUID
|
||||
}
|
||||
return getUserID(c)
|
||||
}
|
||||
|
||||
func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
statuses := svc.ListAgents()
|
||||
userID := getUserID(c)
|
||||
statuses := svc.ListAgentsForUser(userID)
|
||||
agents := make([]string, 0, len(statuses))
|
||||
for name := range statuses {
|
||||
agents = append(agents, name)
|
||||
@@ -38,6 +71,22 @@ func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if hubURL := svc.AgentHubURL(); hubURL != "" {
|
||||
resp["agent_hub_url"] = hubURL
|
||||
}
|
||||
|
||||
// Admin cross-user aggregation
|
||||
if wantsAllUsers(c) {
|
||||
grouped := svc.ListAllAgentsGrouped()
|
||||
userGroups := map[string]any{}
|
||||
for uid, agentList := range grouped {
|
||||
if uid == userID || uid == "" {
|
||||
continue
|
||||
}
|
||||
userGroups[uid] = map[string]any{"agents": agentList}
|
||||
}
|
||||
if len(userGroups) > 0 {
|
||||
resp["user_groups"] = userGroups
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
}
|
||||
@@ -45,11 +94,12 @@ func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func CreateAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
var cfg state.AgentConfig
|
||||
if err := c.Bind(&cfg); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := svc.CreateAgent(&cfg); err != nil {
|
||||
if err := svc.CreateAgentForUser(userID, &cfg); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusCreated, map[string]string{"status": "ok"})
|
||||
@@ -59,8 +109,9 @@ func CreateAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func GetAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
ag := svc.GetAgent(name)
|
||||
ag := svc.GetAgentForUser(userID, name)
|
||||
if ag == nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
|
||||
}
|
||||
@@ -73,12 +124,13 @@ func GetAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func UpdateAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
var cfg state.AgentConfig
|
||||
if err := c.Bind(&cfg); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := svc.UpdateAgent(name, &cfg); err != nil {
|
||||
if err := svc.UpdateAgentForUser(userID, name, &cfg); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -91,8 +143,9 @@ func UpdateAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func DeleteAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
if err := svc.DeleteAgent(name); err != nil {
|
||||
if err := svc.DeleteAgentForUser(userID, name); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
@@ -102,8 +155,9 @@ func DeleteAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func GetAgentConfigEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
cfg := svc.GetAgentConfig(name)
|
||||
cfg := svc.GetAgentConfigForUser(userID, name)
|
||||
if cfg == nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
|
||||
}
|
||||
@@ -114,7 +168,8 @@ func GetAgentConfigEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func PauseAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
if err := svc.PauseAgent(c.Param("name")); err != nil {
|
||||
userID := effectiveUserID(c)
|
||||
if err := svc.PauseAgentForUser(userID, c.Param("name")); err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
@@ -124,7 +179,8 @@ func PauseAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func ResumeAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
if err := svc.ResumeAgent(c.Param("name")); err != nil {
|
||||
userID := effectiveUserID(c)
|
||||
if err := svc.ResumeAgentForUser(userID, c.Param("name")); err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
@@ -134,8 +190,9 @@ func ResumeAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
history := svc.GetAgentStatus(name)
|
||||
history := svc.GetAgentStatusForUser(userID, name)
|
||||
if history == nil {
|
||||
history = &state.Status{ActionResults: []coreTypes.ActionState{}}
|
||||
}
|
||||
@@ -162,8 +219,9 @@ func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
history, err := svc.GetAgentObservables(name)
|
||||
history, err := svc.GetAgentObservablesForUser(userID, name)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -177,8 +235,9 @@ func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc
|
||||
func ClearAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
if err := svc.ClearAgentObservables(name); err != nil {
|
||||
if err := svc.ClearAgentObservablesForUser(userID, name); err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]any{"Name": name, "cleared": true})
|
||||
@@ -188,6 +247,7 @@ func ClearAgentObservablesEndpoint(app *application.Application) echo.HandlerFun
|
||||
func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
var payload struct {
|
||||
Message string `json:"message"`
|
||||
@@ -199,7 +259,7 @@ func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if message == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Message cannot be empty"})
|
||||
}
|
||||
messageID, err := svc.Chat(name, message)
|
||||
messageID, err := svc.ChatForUser(userID, name, message)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
@@ -216,8 +276,9 @@ func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func AgentSSEEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
manager := svc.GetSSEManager(name)
|
||||
manager := svc.GetSSEManagerForUser(userID, name)
|
||||
if manager == nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
|
||||
}
|
||||
@@ -243,8 +304,9 @@ func GetAgentConfigMetaEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func ExportAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := effectiveUserID(c)
|
||||
name := c.Param("name")
|
||||
data, err := svc.ExportAgent(name)
|
||||
data, err := svc.ExportAgentForUser(userID, name)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
}
|
||||
@@ -256,6 +318,7 @@ func ExportAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
func ImportAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
svc := app.AgentPoolService()
|
||||
userID := getUserID(c)
|
||||
|
||||
// Try multipart form file first
|
||||
file, err := c.FormFile("file")
|
||||
@@ -269,7 +332,7 @@ func ImportAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "failed to read file"})
|
||||
}
|
||||
if err := svc.ImportAgent(data); err != nil {
|
||||
if err := svc.ImportAgentForUser(userID, data); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusCreated, map[string]string{"status": "ok"})
|
||||
@@ -284,7 +347,7 @@ func ImportAgentEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
if err := svc.ImportAgent(data); err != nil {
|
||||
if err := svc.ImportAgentForUser(userID, data); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusCreated, map[string]string{"status": "ok"})
|
||||
@@ -358,10 +421,16 @@ func AgentFileEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{"error": "file not found"})
|
||||
}
|
||||
|
||||
// Only serve files from the outputs subdirectory
|
||||
outputsDir, _ := filepath.EvalSymlinks(filepath.Clean(svc.OutputsDir()))
|
||||
// Determine the allowed outputs directory — scoped to the user when auth is active
|
||||
allowedDir := svc.OutputsDir()
|
||||
user := auth.GetUser(c)
|
||||
if user != nil {
|
||||
allowedDir = filepath.Join(allowedDir, user.ID)
|
||||
}
|
||||
|
||||
if utils.InTrustedRoot(resolved, outputsDir) != nil {
|
||||
allowedDirResolved, _ := filepath.EvalSymlinks(filepath.Clean(allowedDir))
|
||||
|
||||
if utils.InTrustedRoot(resolved, allowedDirResolved) != nil {
|
||||
return c.JSON(http.StatusForbidden, map[string]string{"error": "access denied"})
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
@@ -20,6 +21,9 @@ import (
|
||||
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
if decoded, err := url.PathUnescape(modelName); err == nil {
|
||||
modelName = decoded
|
||||
}
|
||||
if modelName == "" {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
@@ -82,6 +86,9 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
func EditModelEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
if decoded, err := url.PathUnescape(modelName); err == nil {
|
||||
modelName = decoded
|
||||
}
|
||||
if modelName == "" {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
|
||||
362
core/http/endpoints/localai/finetune.go
Normal file
362
core/http/endpoints/localai/finetune.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
)
|
||||
|
||||
// StartFineTuneJobEndpoint starts a new fine-tuning job.
|
||||
func StartFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
|
||||
var req schema.FineTuneJobRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "Invalid request: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "model is required",
|
||||
})
|
||||
}
|
||||
if req.DatasetSource == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "dataset_source is required",
|
||||
})
|
||||
}
|
||||
|
||||
resp, err := ftService.StartJob(c.Request().Context(), userID, req)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusCreated, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// ListFineTuneJobsEndpoint lists fine-tuning jobs for the current user.
|
||||
func ListFineTuneJobsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobs := ftService.ListJobs(userID)
|
||||
if jobs == nil {
|
||||
jobs = []*schema.FineTuneJob{}
|
||||
}
|
||||
return c.JSON(http.StatusOK, jobs)
|
||||
}
|
||||
}
|
||||
|
||||
// GetFineTuneJobEndpoint gets a specific fine-tuning job.
|
||||
func GetFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
job, err := ftService.GetJob(userID, jobID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, job)
|
||||
}
|
||||
}
|
||||
|
||||
// StopFineTuneJobEndpoint stops a running fine-tuning job.
|
||||
func StopFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
// Check for save_checkpoint query param
|
||||
saveCheckpoint := c.QueryParam("save_checkpoint") == "true"
|
||||
|
||||
err := ftService.StopJob(c.Request().Context(), userID, jobID, saveCheckpoint)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]string{
|
||||
"status": "stopped",
|
||||
"message": "Fine-tuning job stopped",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteFineTuneJobEndpoint deletes a fine-tuning job and its data.
|
||||
func DeleteFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
err := ftService.DeleteJob(userID, jobID)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
status = http.StatusNotFound
|
||||
} else if strings.Contains(err.Error(), "cannot delete") {
|
||||
status = http.StatusConflict
|
||||
}
|
||||
return c.JSON(status, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]string{
|
||||
"status": "deleted",
|
||||
"message": "Fine-tuning job deleted",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// FineTuneProgressEndpoint streams progress updates via SSE.
|
||||
func FineTuneProgressEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
// Set SSE headers
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
c.Response().WriteHeader(http.StatusOK)
|
||||
|
||||
err := ftService.StreamProgress(c.Request().Context(), userID, jobID, func(event *schema.FineTuneProgressEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(c.Response(), "data: %s\n\n", data)
|
||||
c.Response().Flush()
|
||||
})
|
||||
if err != nil {
|
||||
// If headers already sent, we can't send a JSON error
|
||||
fmt.Fprintf(c.Response(), "data: {\"status\":\"error\",\"message\":%q}\n\n", err.Error())
|
||||
c.Response().Flush()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ListCheckpointsEndpoint lists checkpoints for a job.
|
||||
func ListCheckpointsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
checkpoints, err := ftService.ListCheckpoints(c.Request().Context(), userID, jobID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"checkpoints": checkpoints,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ExportModelEndpoint exports a model from a checkpoint.
|
||||
func ExportModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
var req schema.ExportRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "Invalid request: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
modelName, err := ftService.ExportModel(c.Request().Context(), userID, jobID, req)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusAccepted, map[string]string{
|
||||
"status": "exporting",
|
||||
"message": "Export started for model '" + modelName + "'",
|
||||
"model_name": modelName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// DownloadExportedModelEndpoint streams the exported model directory as a tar.gz archive.
|
||||
func DownloadExportedModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
modelDir, modelName, err := ftService.GetExportedModelPath(userID, jobID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
c.Response().Header().Set("Content-Type", "application/gzip")
|
||||
c.Response().Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s.tar.gz"`, modelName))
|
||||
c.Response().WriteHeader(http.StatusOK)
|
||||
|
||||
gw := gzip.NewWriter(c.Response())
|
||||
defer gw.Close()
|
||||
|
||||
tw := tar.NewWriter(gw)
|
||||
defer tw.Close()
|
||||
|
||||
err = filepath.Walk(modelDir, func(path string, info os.FileInfo, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(modelDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
header, err := tar.FileInfoHeader(info, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header.Name = filepath.Join(modelName, relPath)
|
||||
|
||||
if err := tw.WriteHeader(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
_, err = io.Copy(tw, f)
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// Headers already sent, can't return JSON error
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ListFineTuneBackendsEndpoint returns installed backends tagged with "fine-tuning".
|
||||
func ListFineTuneBackendsEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": "failed to list backends: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
type backendInfo struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
var result []backendInfo
|
||||
for _, b := range backends {
|
||||
if !b.Installed {
|
||||
continue
|
||||
}
|
||||
hasTag := false
|
||||
for _, t := range b.Tags {
|
||||
if strings.EqualFold(t, "fine-tuning") {
|
||||
hasTag = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasTag {
|
||||
continue
|
||||
}
|
||||
name := b.Name
|
||||
if b.Alias != "" {
|
||||
name = b.Alias
|
||||
}
|
||||
result = append(result, backendInfo{
|
||||
Name: name,
|
||||
Description: b.Description,
|
||||
Tags: b.Tags,
|
||||
})
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
result = []backendInfo{}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, result)
|
||||
}
|
||||
}
|
||||
|
||||
// UploadDatasetEndpoint handles dataset file upload.
|
||||
func UploadDatasetEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "file is required",
|
||||
})
|
||||
}
|
||||
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": "failed to open file",
|
||||
})
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
data, err := io.ReadAll(src)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": "failed to read file",
|
||||
})
|
||||
}
|
||||
|
||||
path, err := ftService.UploadDataset(file.Filename, data)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]string{
|
||||
"path": path,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,7 @@ type ModelGalleryEndpointService struct {
|
||||
backendGalleries []config.Gallery
|
||||
modelPath string
|
||||
galleryApplier *services.GalleryService
|
||||
configLoader *config.ModelConfigLoader
|
||||
}
|
||||
|
||||
type GalleryModel struct {
|
||||
@@ -27,12 +28,13 @@ type GalleryModel struct {
|
||||
gallery.GalleryModel
|
||||
}
|
||||
|
||||
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService, configLoader *config.ModelConfigLoader) ModelGalleryEndpointService {
|
||||
return ModelGalleryEndpointService{
|
||||
galleries: galleries,
|
||||
backendGalleries: backendGalleries,
|
||||
modelPath: systemState.Model.ModelsPath,
|
||||
galleryApplier: galleryApplier,
|
||||
configLoader: configLoader,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,6 +105,8 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.Handle
|
||||
GalleryElementName: modelName,
|
||||
}
|
||||
|
||||
mgs.configLoader.RemoveModelConfig(modelName)
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -136,6 +136,12 @@ func UpdateSettingsEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
appConfig.ApiKeys = append(envKeys, runtimeKeys...)
|
||||
}
|
||||
|
||||
// Update backend logging dynamically
|
||||
if settings.EnableBackendLogging != nil {
|
||||
app.ModelLoader().SetBackendLoggingEnabled(*settings.EnableBackendLogging)
|
||||
xlog.Info("Updated backend logging setting", "enableBackendLogging", *settings.EnableBackendLogging)
|
||||
}
|
||||
|
||||
// Update watchdog dynamically for settings that don't require restart
|
||||
if settings.ForceEvictionWhenBusy != nil {
|
||||
currentWD := app.ModelLoader().GetWatchDog()
|
||||
|
||||
@@ -82,51 +82,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
template = s
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
||||
|
||||
// Track accumulated content for reasoning extraction
|
||||
accumulatedContent := ""
|
||||
lastEmittedReasoning := ""
|
||||
lastEmittedCleanedContent := ""
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
_, _, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
accumulatedContent += s
|
||||
|
||||
currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
// Calculate new reasoning delta (what we haven't emitted yet)
|
||||
var reasoningDelta *string
|
||||
if currentReasoning != lastEmittedReasoning {
|
||||
// Extract only the new part
|
||||
if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) {
|
||||
newReasoning := currentReasoning[len(lastEmittedReasoning):]
|
||||
reasoningDelta = &newReasoning
|
||||
lastEmittedReasoning = currentReasoning
|
||||
} else if currentReasoning != "" {
|
||||
// If reasoning changed in a non-append way, emit the full current reasoning
|
||||
reasoningDelta = &currentReasoning
|
||||
lastEmittedReasoning = currentReasoning
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate content delta from cleaned content
|
||||
var deltaContent string
|
||||
if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) {
|
||||
deltaContent = cleanedContent[len(lastEmittedCleanedContent):]
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else if cleanedContent != lastEmittedCleanedContent {
|
||||
// If cleaned content changed but not in a simple append, extract delta from cleaned content
|
||||
// This handles cases where thinking tags are removed mid-stream
|
||||
if lastEmittedCleanedContent == "" {
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else {
|
||||
// Content changed in non-append way, use the new cleaned content
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
}
|
||||
}
|
||||
// Only emit content if there's actual content (not just thinking tags)
|
||||
// If deltaContent is empty, we still emit the response but with empty content
|
||||
reasoningDelta, contentDelta := extractor.ProcessToken(s)
|
||||
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
@@ -139,12 +98,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
|
||||
delta := &schema.Message{}
|
||||
// Only include content if there's actual content (not just thinking tags)
|
||||
if deltaContent != "" {
|
||||
delta.Content = &deltaContent
|
||||
if contentDelta != "" {
|
||||
delta.Content = &contentDelta
|
||||
}
|
||||
if reasoningDelta != nil && *reasoningDelta != "" {
|
||||
delta.Reasoning = reasoningDelta
|
||||
if reasoningDelta != "" {
|
||||
delta.Reasoning = &reasoningDelta
|
||||
}
|
||||
|
||||
resp := schema.OpenAIResponse{
|
||||
@@ -171,11 +129,53 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
template = prompt
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
result := ""
|
||||
lastEmittedCount := 0
|
||||
sentInitialRole := false
|
||||
|
||||
_, tokenUsage, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
reasoningDelta, contentDelta := extractor.ProcessToken(s)
|
||||
|
||||
// Emit reasoning deltas in their own SSE chunks before any tool-call chunks
|
||||
// (OpenAI spec: reasoning and tool_calls never share a delta)
|
||||
if reasoningDelta != "" {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Reasoning: &reasoningDelta},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
}
|
||||
|
||||
// Stream content deltas (cleaned of reasoning tags) while no tool calls
|
||||
// have been detected. Once the incremental parser finds tool calls,
|
||||
// content stops — per OpenAI spec, content and tool_calls don't mix.
|
||||
if lastEmittedCount == 0 && contentDelta != "" {
|
||||
if !sentInitialRole {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
sentInitialRole = true
|
||||
}
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Content: &contentDelta},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
}
|
||||
|
||||
// Try incremental XML parsing for streaming support using iterative parser
|
||||
// This allows emitting partial tool calls as they're being generated
|
||||
cleanedResult := functions.CleanupLLMResult(result, config.FunctionsConfig)
|
||||
@@ -279,7 +279,25 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
},
|
||||
func(attempt int) bool {
|
||||
// After streaming completes: check if we got actionable content
|
||||
cleaned := extractor.CleanedContent()
|
||||
// Check for tool calls from chat deltas (will be re-checked after ComputeChoices,
|
||||
// but we need to know here whether to retry)
|
||||
hasToolCalls := lastEmittedCount > 0
|
||||
if cleaned == "" && !hasToolCalls {
|
||||
xlog.Warn("Streaming: backend produced only reasoning, retrying",
|
||||
"reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1)
|
||||
extractor.ResetAndSuppressReasoning()
|
||||
result = ""
|
||||
lastEmittedCount = 0
|
||||
sentInitialRole = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -296,30 +314,17 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
} else {
|
||||
// Fallback: parse tool calls from raw text (no chat deltas from backend)
|
||||
xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
reasoning, result = reason.ExtractReasoningWithConfig(result, thinkingStartToken, config.ReasoningConfig)
|
||||
textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
|
||||
result = functions.CleanupLLMResult(result, config.FunctionsConfig)
|
||||
functionResults = functions.ParseFunctionCall(result, config.FunctionsConfig)
|
||||
reasoning = extractor.Reasoning()
|
||||
cleanedResult := extractor.CleanedContent()
|
||||
textContentToReturn = functions.ParseTextContent(cleanedResult, config.FunctionsConfig)
|
||||
cleanedResult = functions.CleanupLLMResult(cleanedResult, config.FunctionsConfig)
|
||||
functionResults = functions.ParseFunctionCall(cleanedResult, config.FunctionsConfig)
|
||||
}
|
||||
xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", textContentToReturn)
|
||||
noActionToRun := len(functionResults) > 0 && functionResults[0].Name == noAction || len(functionResults) == 0
|
||||
|
||||
switch {
|
||||
case noActionToRun:
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
|
||||
result, err := handleQuestion(config, functionResults, result, prompt)
|
||||
if err != nil {
|
||||
xlog.Error("error handling question", "error", err)
|
||||
return err
|
||||
}
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
@@ -330,25 +335,43 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
||||
}
|
||||
|
||||
var deltaReasoning *string
|
||||
if reasoning != "" {
|
||||
deltaReasoning = &reasoning
|
||||
}
|
||||
delta := &schema.Message{Content: &result}
|
||||
if deltaReasoning != nil {
|
||||
delta.Reasoning = deltaReasoning
|
||||
}
|
||||
if sentInitialRole {
|
||||
// Content was already streamed during the callback — just emit usage.
|
||||
delta := &schema.Message{}
|
||||
if reasoning != "" && extractor.Reasoning() == "" {
|
||||
delta.Reasoning = &reasoning
|
||||
}
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{Delta: delta, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: usage,
|
||||
}
|
||||
} else {
|
||||
// Content was NOT streamed — send everything at once (fallback).
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: usage,
|
||||
}
|
||||
result, err := handleQuestion(config, functionResults, extractor.CleanedContent(), prompt)
|
||||
if err != nil {
|
||||
xlog.Error("error handling question", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
responses <- resp
|
||||
delta := &schema.Message{Content: &result}
|
||||
if reasoning != "" {
|
||||
delta.Reasoning = &reasoning
|
||||
}
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{Delta: delta, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
for i, ss := range functionResults {
|
||||
@@ -907,7 +930,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
// is deferred to after ComputeChoices so we can check chat deltas first
|
||||
// and avoid redundant Go-side parsing.
|
||||
var cbRawResult, cbReasoning string
|
||||
var emptyRetryNeeded bool
|
||||
|
||||
tokenCallback := func(s string, c *[]schema.Choice) {
|
||||
reasoning, s := reason.ExtractReasoningWithConfig(s, thinkingStartToken, config.ReasoningConfig)
|
||||
@@ -927,146 +949,133 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
cbReasoning = reasoning
|
||||
}
|
||||
|
||||
const maxEmptyRetries = 5
|
||||
var result []schema.Choice
|
||||
var tokenUsage backend.TokenUsage
|
||||
var err error
|
||||
|
||||
var chatDeltas []*pb.ChatDelta
|
||||
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
|
||||
emptyRetryNeeded = false
|
||||
result, tokenUsage, chatDeltas, err = ComputeChoices(
|
||||
input,
|
||||
predInput,
|
||||
config,
|
||||
cl,
|
||||
startupOptions,
|
||||
ml,
|
||||
tokenCallback,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Tool parsing is deferred here (only when shouldUseFn)
|
||||
if shouldUseFn {
|
||||
var funcResults []functions.FuncCallResults
|
||||
|
||||
// Try pre-parsed tool calls from C++ autoparser first
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser tool calls, skipping Go-side parsing", "count", len(deltaToolCalls))
|
||||
funcResults = deltaToolCalls
|
||||
textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
|
||||
cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
// Fallback: parse tool calls from raw text
|
||||
xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing")
|
||||
textContentToReturn = functions.ParseTextContent(cbRawResult, config.FunctionsConfig)
|
||||
cbRawResult = functions.CleanupLLMResult(cbRawResult, config.FunctionsConfig)
|
||||
funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig)
|
||||
result, tokenUsage, chatDeltas, err = ComputeChoices(
|
||||
input,
|
||||
predInput,
|
||||
config,
|
||||
cl,
|
||||
startupOptions,
|
||||
ml,
|
||||
tokenCallback,
|
||||
nil,
|
||||
func(attempt int) bool {
|
||||
if !shouldUseFn {
|
||||
return false
|
||||
}
|
||||
|
||||
noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0
|
||||
|
||||
switch {
|
||||
case noActionsToRun:
|
||||
if cbRawResult == "" && textContentToReturn == "" {
|
||||
xlog.Warn("Backend returned empty content in tool-calling context, will retry")
|
||||
emptyRetryNeeded = true
|
||||
continue
|
||||
}
|
||||
qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput)
|
||||
if qErr != nil {
|
||||
xlog.Error("error handling question", "error", qErr)
|
||||
emptyRetryNeeded = true
|
||||
continue
|
||||
}
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
message := &schema.Message{Role: "assistant", Content: &qResult}
|
||||
if cbReasoning != "" {
|
||||
message.Reasoning = &cbReasoning
|
||||
}
|
||||
result = append(result, schema.Choice{
|
||||
FinishReason: &stopReason,
|
||||
Message: message,
|
||||
})
|
||||
default:
|
||||
toolCallsReason := FinishReasonToolCalls
|
||||
toolChoice := schema.Choice{
|
||||
FinishReason: &toolCallsReason,
|
||||
Message: &schema.Message{
|
||||
Role: "assistant",
|
||||
},
|
||||
}
|
||||
if cbReasoning != "" {
|
||||
toolChoice.Message.Reasoning = &cbReasoning
|
||||
}
|
||||
|
||||
for _, ss := range funcResults {
|
||||
name, args := ss.Name, ss.Arguments
|
||||
toolCallID := ss.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = id
|
||||
}
|
||||
if len(input.Tools) > 0 {
|
||||
toolChoice.Message.Content = textContentToReturn
|
||||
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
|
||||
schema.ToolCall{
|
||||
ID: toolCallID,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// Deprecated function_call format
|
||||
functionCallReason := FinishReasonFunctionCall
|
||||
message := &schema.Message{
|
||||
Role: "assistant",
|
||||
Content: &textContentToReturn,
|
||||
FunctionCall: map[string]interface{}{
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
},
|
||||
}
|
||||
if cbReasoning != "" {
|
||||
message.Reasoning = &cbReasoning
|
||||
}
|
||||
result = append(result, schema.Choice{
|
||||
FinishReason: &functionCallReason,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
result = append(result, toolChoice)
|
||||
}
|
||||
// Retry when backend produced only reasoning and no content/tool calls.
|
||||
// Full tool parsing is deferred until after ComputeChoices returns
|
||||
// (when chat deltas are available), but we can detect the empty case here.
|
||||
if cbRawResult == "" && textContentToReturn == "" {
|
||||
xlog.Warn("Backend produced reasoning without actionable content, retrying",
|
||||
"reasoning_len", len(cbReasoning), "attempt", attempt+1)
|
||||
cbRawResult = ""
|
||||
cbReasoning = ""
|
||||
textContentToReturn = ""
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if !emptyRetryNeeded {
|
||||
break
|
||||
}
|
||||
xlog.Warn("Retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
|
||||
}
|
||||
return false
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if emptyRetryNeeded {
|
||||
xlog.Warn("All retries exhausted, backend still returning empty content")
|
||||
stopReason := FinishReasonStop
|
||||
empty := ""
|
||||
result = append(result, schema.Choice{
|
||||
FinishReason: &stopReason,
|
||||
Index: 0,
|
||||
Message: &schema.Message{Role: "assistant", Content: &empty},
|
||||
})
|
||||
// Tool parsing is deferred here (only when shouldUseFn) so chat deltas are available
|
||||
if shouldUseFn {
|
||||
var funcResults []functions.FuncCallResults
|
||||
|
||||
// Try pre-parsed tool calls from C++ autoparser first
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser tool calls, skipping Go-side parsing", "count", len(deltaToolCalls))
|
||||
funcResults = deltaToolCalls
|
||||
textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
|
||||
cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
// Fallback: parse tool calls from raw text
|
||||
xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing")
|
||||
textContentToReturn = functions.ParseTextContent(cbRawResult, config.FunctionsConfig)
|
||||
cbRawResult = functions.CleanupLLMResult(cbRawResult, config.FunctionsConfig)
|
||||
funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig)
|
||||
}
|
||||
|
||||
noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0
|
||||
|
||||
switch {
|
||||
case noActionsToRun:
|
||||
qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput)
|
||||
if qErr != nil {
|
||||
xlog.Error("error handling question", "error", qErr)
|
||||
}
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
message := &schema.Message{Role: "assistant", Content: &qResult}
|
||||
if cbReasoning != "" {
|
||||
message.Reasoning = &cbReasoning
|
||||
}
|
||||
result = append(result, schema.Choice{
|
||||
FinishReason: &stopReason,
|
||||
Message: message,
|
||||
})
|
||||
default:
|
||||
toolCallsReason := FinishReasonToolCalls
|
||||
toolChoice := schema.Choice{
|
||||
FinishReason: &toolCallsReason,
|
||||
Message: &schema.Message{
|
||||
Role: "assistant",
|
||||
},
|
||||
}
|
||||
if cbReasoning != "" {
|
||||
toolChoice.Message.Reasoning = &cbReasoning
|
||||
}
|
||||
|
||||
for _, ss := range funcResults {
|
||||
name, args := ss.Name, ss.Arguments
|
||||
toolCallID := ss.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = id
|
||||
}
|
||||
if len(input.Tools) > 0 {
|
||||
toolChoice.Message.Content = textContentToReturn
|
||||
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
|
||||
schema.ToolCall{
|
||||
ID: toolCallID,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// Deprecated function_call format
|
||||
functionCallReason := FinishReasonFunctionCall
|
||||
message := &schema.Message{
|
||||
Role: "assistant",
|
||||
Content: &textContentToReturn,
|
||||
FunctionCall: map[string]interface{}{
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
},
|
||||
}
|
||||
if cbReasoning != "" {
|
||||
message.Reasoning = &cbReasoning
|
||||
}
|
||||
result = append(result, schema.Choice{
|
||||
FinishReason: &functionCallReason,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
result = append(result, toolChoice)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MCP server-side tool execution loop:
|
||||
@@ -1203,5 +1212,5 @@ func handleQuestion(config *config.ModelConfig, funcResults []functions.FuncCall
|
||||
|
||||
xlog.Debug("No action received from LLM, without a message, computing a reply")
|
||||
|
||||
return "", fmt.Errorf("no action received from LLM, without a message, computing a reply")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
157
core/http/endpoints/openai/chat_test.go
Normal file
157
core/http/endpoints/openai/chat_test.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
)
|
||||
|
||||
var _ = Describe("handleQuestion", func() {
|
||||
var cfg *config.ModelConfig
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = &config.ModelConfig{}
|
||||
})
|
||||
|
||||
Context("with no function results but non-empty result", func() {
|
||||
It("should return the result directly", func() {
|
||||
result, err := handleQuestion(cfg, nil, "Hello world", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(Equal("Hello world"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with no function results and empty result", func() {
|
||||
It("should return empty string", func() {
|
||||
result, err := handleQuestion(cfg, nil, "", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("with function result containing a message argument", func() {
|
||||
It("should extract the message from function arguments", func() {
|
||||
funcResults := []functions.FuncCallResults{
|
||||
{
|
||||
Name: "answer",
|
||||
Arguments: `{"message": "This is the answer"}`,
|
||||
},
|
||||
}
|
||||
result, err := handleQuestion(cfg, funcResults, "", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(Equal("This is the answer"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with function result containing empty message", func() {
|
||||
It("should return empty string when message is empty", func() {
|
||||
funcResults := []functions.FuncCallResults{
|
||||
{
|
||||
Name: "answer",
|
||||
Arguments: `{"message": ""}`,
|
||||
},
|
||||
}
|
||||
result, err := handleQuestion(cfg, funcResults, "", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("with function result containing invalid JSON arguments", func() {
|
||||
It("should return empty string gracefully", func() {
|
||||
funcResults := []functions.FuncCallResults{
|
||||
{
|
||||
Name: "answer",
|
||||
Arguments: "not json",
|
||||
},
|
||||
}
|
||||
result, err := handleQuestion(cfg, funcResults, "", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("with cleaned content (no think tags)", func() {
|
||||
It("should return content without think tags", func() {
|
||||
// This tests the bug fix: handleQuestion should receive cleaned content,
|
||||
// not raw text with <think> tags
|
||||
result, err := handleQuestion(cfg, nil, "Just the answer", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(Equal("Just the answer"))
|
||||
Expect(result).ToNot(ContainSubstring("<think>"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with raw think tags passed as result", func() {
|
||||
It("would return content with think tags", func() {
|
||||
result, err := handleQuestion(cfg, nil, "<think>reasoning</think>answer", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(Equal("<think>reasoning</think>answer"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("mergeToolCallDeltas", func() {
|
||||
Context("with new tool calls", func() {
|
||||
It("should append new tool calls", func() {
|
||||
existing := []schema.ToolCall{}
|
||||
deltas := []schema.ToolCall{
|
||||
{Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search"}},
|
||||
}
|
||||
result := mergeToolCallDeltas(existing, deltas)
|
||||
Expect(result).To(HaveLen(1))
|
||||
Expect(result[0].ID).To(Equal("tc1"))
|
||||
Expect(result[0].FunctionCall.Name).To(Equal("search"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with argument appending", func() {
|
||||
It("should append arguments to existing tool call", func() {
|
||||
existing := []schema.ToolCall{
|
||||
{Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search", Arguments: `{"q":`}},
|
||||
}
|
||||
deltas := []schema.ToolCall{
|
||||
{Index: 0, FunctionCall: schema.FunctionCall{Arguments: `"hello"}`}},
|
||||
}
|
||||
result := mergeToolCallDeltas(existing, deltas)
|
||||
Expect(result).To(HaveLen(1))
|
||||
Expect(result[0].FunctionCall.Arguments).To(Equal(`{"q":"hello"}`))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with multiple tool calls", func() {
|
||||
It("should track multiple tool calls by index", func() {
|
||||
existing := []schema.ToolCall{}
|
||||
deltas1 := []schema.ToolCall{
|
||||
{Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search"}},
|
||||
}
|
||||
result := mergeToolCallDeltas(existing, deltas1)
|
||||
|
||||
deltas2 := []schema.ToolCall{
|
||||
{Index: 1, ID: "tc2", Type: "function", FunctionCall: schema.FunctionCall{Name: "browse"}},
|
||||
}
|
||||
result = mergeToolCallDeltas(result, deltas2)
|
||||
Expect(result).To(HaveLen(2))
|
||||
Expect(result[0].FunctionCall.Name).To(Equal("search"))
|
||||
Expect(result[1].FunctionCall.Name).To(Equal("browse"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with ID update on existing tool call", func() {
|
||||
It("should update ID when provided in delta", func() {
|
||||
existing := []schema.ToolCall{
|
||||
{Index: 0, FunctionCall: schema.FunctionCall{Name: "search"}},
|
||||
}
|
||||
deltas := []schema.ToolCall{
|
||||
{Index: 0, ID: "new-id"},
|
||||
}
|
||||
result := mergeToolCallDeltas(existing, deltas)
|
||||
Expect(result).To(HaveLen(1))
|
||||
Expect(result[0].ID).To(Equal("new-id"))
|
||||
Expect(result[0].FunctionCall.Name).To(Equal("search"))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
func ComputeChoices(
|
||||
@@ -19,7 +21,9 @@ func ComputeChoices(
|
||||
o *config.ApplicationConfig,
|
||||
loader *model.ModelLoader,
|
||||
cb func(string, *[]schema.Choice),
|
||||
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, []*pb.ChatDelta, error) {
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
shouldRetry ...func(int) bool,
|
||||
) ([]schema.Choice, backend.TokenUsage, []*pb.ChatDelta, error) {
|
||||
n := req.N // number of completions to return
|
||||
result := []schema.Choice{}
|
||||
|
||||
@@ -27,6 +31,12 @@ func ComputeChoices(
|
||||
n = 1
|
||||
}
|
||||
|
||||
// Extract the optional shouldRetry callback
|
||||
var shouldRetryFn func(int) bool
|
||||
if len(shouldRetry) > 0 {
|
||||
shouldRetryFn = shouldRetry[0]
|
||||
}
|
||||
|
||||
images := []string{}
|
||||
for _, m := range req.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
@@ -82,7 +92,7 @@ func ComputeChoices(
|
||||
}
|
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := backend.ModelInference(
|
||||
predFunc, err := backend.ModelInferenceFunc(
|
||||
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias, req.Metadata)
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, nil, err
|
||||
@@ -91,32 +101,49 @@ func ComputeChoices(
|
||||
tokenUsage := backend.TokenUsage{}
|
||||
var allChatDeltas []*pb.ChatDelta
|
||||
|
||||
const maxRetries = 5
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, nil, err
|
||||
var prediction backend.LLMResponse
|
||||
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
p, err := predFunc()
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, nil, err
|
||||
}
|
||||
prediction = p
|
||||
|
||||
// Built-in: retry on truly empty response (no tokens at all)
|
||||
if strings.TrimSpace(prediction.Response) == "" && attempt < maxRetries {
|
||||
xlog.Warn("Backend returned empty response, retrying",
|
||||
"attempt", attempt+1, "maxRetries", maxRetries)
|
||||
continue
|
||||
}
|
||||
|
||||
tokenUsage.Prompt = prediction.Usage.Prompt
|
||||
tokenUsage.Completion = prediction.Usage.Completion
|
||||
tokenUsage.TimingPromptProcessing = prediction.Usage.TimingPromptProcessing
|
||||
tokenUsage.TimingTokenGeneration = prediction.Usage.TimingTokenGeneration
|
||||
|
||||
allChatDeltas = prediction.ChatDeltas
|
||||
|
||||
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
|
||||
cb(finetunedResponse, &result)
|
||||
|
||||
// Caller-driven retry (tool parsing, reasoning-only, etc.)
|
||||
if shouldRetryFn != nil && shouldRetryFn(attempt) && attempt < maxRetries {
|
||||
// Caller has already reset its state inside shouldRetry
|
||||
result = result[:0]
|
||||
allChatDeltas = nil
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
tokenUsage.Prompt += prediction.Usage.Prompt
|
||||
tokenUsage.Completion += prediction.Usage.Completion
|
||||
tokenUsage.TimingPromptProcessing += prediction.Usage.TimingPromptProcessing
|
||||
tokenUsage.TimingTokenGeneration += prediction.Usage.TimingTokenGeneration
|
||||
|
||||
// Collect chat deltas from C++ autoparser
|
||||
if len(prediction.ChatDeltas) > 0 {
|
||||
allChatDeltas = append(allChatDeltas, prediction.ChatDeltas...)
|
||||
}
|
||||
|
||||
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
|
||||
cb(finetunedResponse, &result)
|
||||
|
||||
// Add logprobs to the last choice if present
|
||||
if prediction.Logprobs != nil && len(result) > 0 {
|
||||
result[len(result)-1].Logprobs = prediction.Logprobs
|
||||
}
|
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
}
|
||||
return result, tokenUsage, allChatDeltas, err
|
||||
}
|
||||
|
||||
402
core/http/endpoints/openai/inference_test.go
Normal file
402
core/http/endpoints/openai/inference_test.go
Normal file
@@ -0,0 +1,402 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type modelInferenceFunc = func(
|
||||
ctx context.Context, s string, messages schema.Messages,
|
||||
images, videos, audios []string,
|
||||
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
tools, toolChoice string,
|
||||
logprobs, topLogprobs *int,
|
||||
logitBias map[string]float64,
|
||||
metadata map[string]string,
|
||||
) (func() (backend.LLMResponse, error), error)
|
||||
|
||||
var _ = Describe("ComputeChoices", func() {
|
||||
var (
|
||||
origInference modelInferenceFunc
|
||||
cfg *config.ModelConfig
|
||||
appCfg *config.ApplicationConfig
|
||||
)
|
||||
|
||||
// mockInference installs a stub that yields the given responses sequentially.
|
||||
// After all responses are consumed, the last one is repeated.
|
||||
mockInference := func(responses []backend.LLMResponse) {
|
||||
idx := 0
|
||||
backend.ModelInferenceFunc = func(
|
||||
ctx context.Context, s string, messages schema.Messages,
|
||||
images, videos, audios []string,
|
||||
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
tools, toolChoice string,
|
||||
logprobs, topLogprobs *int,
|
||||
logitBias map[string]float64,
|
||||
metadata map[string]string,
|
||||
) (func() (backend.LLMResponse, error), error) {
|
||||
predFunc := func() (backend.LLMResponse, error) {
|
||||
resp := responses[idx]
|
||||
if idx < len(responses)-1 {
|
||||
idx++
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
return predFunc, nil
|
||||
}
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
origInference = backend.ModelInferenceFunc
|
||||
cfg = &config.ModelConfig{}
|
||||
appCfg = config.NewApplicationConfig()
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
backend.ModelInferenceFunc = origInference
|
||||
})
|
||||
|
||||
makeReq := func() *schema.OpenAIRequest {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
_ = cancel
|
||||
return &schema.OpenAIRequest{
|
||||
Context: ctx,
|
||||
Cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
Context("normal response (no retry needed)", func() {
|
||||
It("should return choices on first attempt", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "Hello world", Usage: backend.TokenUsage{Prompt: 10, Completion: 5}},
|
||||
})
|
||||
|
||||
var captured string
|
||||
choices, usage, _, err := ComputeChoices(
|
||||
makeReq(), "test prompt", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
captured = s
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(captured).To(Equal("Hello world"))
|
||||
Expect(usage.Prompt).To(Equal(10))
|
||||
Expect(usage.Completion).To(Equal(5))
|
||||
})
|
||||
})
|
||||
|
||||
Context("empty response triggers built-in retry", func() {
|
||||
It("should retry and eventually return non-empty response", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: ""}, // attempt 0: empty
|
||||
{Response: " "}, // attempt 1: whitespace-only
|
||||
{Response: "Got it", Usage: backend.TokenUsage{Prompt: 8, Completion: 3}}, // attempt 2: success
|
||||
})
|
||||
|
||||
choices, usage, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(choices[0].Text).To(Equal("Got it"))
|
||||
Expect(usage.Prompt).To(Equal(8))
|
||||
Expect(usage.Completion).To(Equal(3))
|
||||
})
|
||||
})
|
||||
|
||||
Context("all retries exhausted on empty response", func() {
|
||||
It("should return the empty response after max retries", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: ""}, // always empty
|
||||
})
|
||||
|
||||
choices, _, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// After maxRetries, it proceeds with the empty response
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(choices[0].Text).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("shouldRetry callback", func() {
|
||||
It("should call shouldRetry and retry when it returns true", func() {
|
||||
callCount := 0
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "reasoning-only", Usage: backend.TokenUsage{Prompt: 5, Completion: 2}},
|
||||
{Response: "actual-answer", Usage: backend.TokenUsage{Prompt: 5, Completion: 4}},
|
||||
})
|
||||
|
||||
retryAttempts := []int{}
|
||||
choices, usage, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
callCount++
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
func(attempt int) bool {
|
||||
retryAttempts = append(retryAttempts, attempt)
|
||||
// Retry on first attempt only
|
||||
return attempt == 0
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(choices[0].Text).To(Equal("actual-answer"))
|
||||
// shouldRetry was called twice: once returning true (retry), once returning false (proceed)
|
||||
Expect(retryAttempts).To(Equal([]int{0, 1}))
|
||||
// cb was called twice (once per attempt)
|
||||
Expect(callCount).To(Equal(2))
|
||||
// Token usage should be from the LATEST attempt
|
||||
Expect(usage.Prompt).To(Equal(5))
|
||||
Expect(usage.Completion).To(Equal(4))
|
||||
})
|
||||
|
||||
It("should not retry when shouldRetry returns false", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "first-response"},
|
||||
})
|
||||
|
||||
shouldRetryCalled := false
|
||||
choices, _, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
func(attempt int) bool {
|
||||
shouldRetryCalled = true
|
||||
return false
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(choices[0].Text).To(Equal("first-response"))
|
||||
Expect(shouldRetryCalled).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("shouldRetry not provided (variadic omitted)", func() {
|
||||
It("should work without shouldRetry parameter", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "works"},
|
||||
})
|
||||
|
||||
choices, _, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(choices[0].Text).To(Equal("works"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("token usage from latest attempt", func() {
|
||||
It("should use token usage from the last attempt, not accumulated", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "retry-me", Usage: backend.TokenUsage{Prompt: 100, Completion: 50}},
|
||||
{Response: "final", Usage: backend.TokenUsage{Prompt: 10, Completion: 5}},
|
||||
})
|
||||
|
||||
_, usage, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
func(attempt int) bool { return attempt == 0 },
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Should be the LATEST attempt's usage, not accumulated
|
||||
Expect(usage.Prompt).To(Equal(10))
|
||||
Expect(usage.Completion).To(Equal(5))
|
||||
})
|
||||
})
|
||||
|
||||
Context("chat deltas from latest attempt", func() {
|
||||
It("should return chat deltas from the last attempt only", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{
|
||||
Response: "retry-me",
|
||||
ChatDeltas: []*pb.ChatDelta{{Content: "old"}},
|
||||
},
|
||||
{
|
||||
Response: "final",
|
||||
ChatDeltas: []*pb.ChatDelta{{Content: "new"}},
|
||||
},
|
||||
})
|
||||
|
||||
_, _, deltas, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
func(attempt int) bool { return attempt == 0 },
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(deltas).To(HaveLen(1))
|
||||
Expect(deltas[0].Content).To(Equal("new"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("result choices cleared on retry", func() {
|
||||
It("should only contain choices from the final attempt", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "bad-choice"},
|
||||
{Response: "good-choice"},
|
||||
})
|
||||
|
||||
choices, _, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
func(attempt int) bool { return attempt == 0 },
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(choices[0].Text).To(Equal("good-choice"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("shouldRetry with max retries cap", func() {
|
||||
It("should stop retrying after maxRetries even if shouldRetry returns true", func() {
|
||||
attempts := 0
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "always-retry"},
|
||||
})
|
||||
|
||||
choices, _, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
func(attempt int) bool {
|
||||
attempts++
|
||||
return true // always want to retry
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
// maxRetries is 5, so shouldRetry is called for attempts 0..4,
|
||||
// but attempt 5 is the final one where shouldRetry can't trigger continue
|
||||
Expect(attempts).To(BeNumerically("<=", 6))
|
||||
})
|
||||
})
|
||||
|
||||
Context("N > 1 completions", func() {
|
||||
It("should produce N separate completions", func() {
|
||||
callIdx := 0
|
||||
responses := []string{"first", "second", "third"}
|
||||
backend.ModelInferenceFunc = func(
|
||||
ctx context.Context, s string, messages schema.Messages,
|
||||
images, videos, audios []string,
|
||||
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
tools, toolChoice string,
|
||||
logprobs, topLogprobs *int,
|
||||
logitBias map[string]float64,
|
||||
metadata map[string]string,
|
||||
) (func() (backend.LLMResponse, error), error) {
|
||||
predFunc := func() (backend.LLMResponse, error) {
|
||||
resp := backend.LLMResponse{Response: responses[callIdx]}
|
||||
if callIdx < len(responses)-1 {
|
||||
callIdx++
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
return predFunc, nil
|
||||
}
|
||||
|
||||
req := makeReq()
|
||||
req.N = 3
|
||||
choices, _, _, err := ComputeChoices(
|
||||
req, "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(3))
|
||||
Expect(choices[0].Text).To(Equal("first"))
|
||||
Expect(choices[1].Text).To(Equal("second"))
|
||||
Expect(choices[2].Text).To(Equal("third"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with streaming token callback", func() {
|
||||
It("should call tokenCallback for streaming responses", func() {
|
||||
var streamedTokens []string
|
||||
backend.ModelInferenceFunc = func(
|
||||
ctx context.Context, s string, messages schema.Messages,
|
||||
images, videos, audios []string,
|
||||
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
tools, toolChoice string,
|
||||
logprobs, topLogprobs *int,
|
||||
logitBias map[string]float64,
|
||||
metadata map[string]string,
|
||||
) (func() (backend.LLMResponse, error), error) {
|
||||
predFunc := func() (backend.LLMResponse, error) {
|
||||
if tokenCallback != nil {
|
||||
tokenCallback("Hello", backend.TokenUsage{Prompt: 5})
|
||||
tokenCallback(" world", backend.TokenUsage{Prompt: 5, Completion: 2})
|
||||
}
|
||||
return backend.LLMResponse{
|
||||
Response: "Hello world",
|
||||
Usage: backend.TokenUsage{Prompt: 5, Completion: 2},
|
||||
}, nil
|
||||
}
|
||||
return predFunc, nil
|
||||
}
|
||||
|
||||
choices, _, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
func(s string, usage backend.TokenUsage) bool {
|
||||
streamedTokens = append(streamedTokens, s)
|
||||
return true
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(streamedTokens).To(Equal([]string{"Hello", " world"}))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -3,16 +3,22 @@ package openai
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ListModelsEndpoint is the OpenAI Models API endpoint https://platform.openai.com/docs/api-reference/models
|
||||
// @Summary List and describe the various models available in the API.
|
||||
// @Success 200 {object} schema.ModelsDataResponse "Response"
|
||||
// @Router /v1/models [get]
|
||||
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, db ...*gorm.DB) echo.HandlerFunc {
|
||||
var authDB *gorm.DB
|
||||
if len(db) > 0 {
|
||||
authDB = db[0]
|
||||
}
|
||||
return func(c echo.Context) error {
|
||||
// If blank, no filter is applied.
|
||||
filter := c.QueryParam("filter")
|
||||
@@ -36,6 +42,26 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap
|
||||
return err
|
||||
}
|
||||
|
||||
// Filter models by user's allowlist if auth is enabled
|
||||
if authDB != nil {
|
||||
if user := auth.GetUser(c); user != nil && user.Role != auth.RoleAdmin {
|
||||
perm, err := auth.GetCachedUserPermissions(c, authDB, user.ID)
|
||||
if err == nil && perm.AllowedModels.Enabled {
|
||||
allowed := map[string]bool{}
|
||||
for _, m := range perm.AllowedModels.Models {
|
||||
allowed[m] = true
|
||||
}
|
||||
filtered := make([]string, 0, len(modelNames))
|
||||
for _, m := range modelNames {
|
||||
if allowed[m] {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
modelNames = filtered
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map from a slice of names to a slice of OpenAIModel response objects
|
||||
dataModels := []schema.OpenAIModel{}
|
||||
for _, m := range modelNames {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
@@ -879,34 +880,14 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
xlog.Debug("Background MCP re-templating", "iteration", mcpIteration)
|
||||
}
|
||||
|
||||
images := []string{}
|
||||
videos := []string{}
|
||||
audios := []string{}
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
videos = append(videos, m.StringVideos...)
|
||||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
|
||||
toolsJSON := serializeToolsForBackend(input.Tools)
|
||||
toolChoiceJSON := ""
|
||||
if input.ToolChoice != nil {
|
||||
toolChoiceBytes, err := json.Marshal(input.ToolChoice)
|
||||
if err == nil {
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
}
|
||||
|
||||
var logprobs *int
|
||||
// Populate openAIReq fields for ComputeChoices
|
||||
openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools)
|
||||
openAIReq.ToolsChoice = input.ToolChoice
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
logprobs = input.TopLogprobs
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
ctx, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, logprobs, input.TopLogprobs, input.LogitBias, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model inference failed: %w", err)
|
||||
openAIReq.TopLogprobs = input.TopLogprobs
|
||||
openAIReq.Logprobs = schema.LogprobsValue{Enabled: true}
|
||||
}
|
||||
openAIReq.LogitBias = input.LogitBias
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -914,24 +895,19 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
default:
|
||||
}
|
||||
|
||||
const maxEmptyRetries = 5
|
||||
var prediction backend.LLMResponse
|
||||
var result string
|
||||
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
|
||||
prediction, err = predFunc()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("prediction failed: %w", err)
|
||||
}
|
||||
result = backend.Finetune(*cfg, predInput, prediction.Response)
|
||||
if result != "" || !shouldUseFn {
|
||||
break
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
xlog.Warn("Open Responses background: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
|
||||
cb := func(s string, c *[]schema.Choice) {
|
||||
result = s
|
||||
}
|
||||
choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model inference failed: %w", err)
|
||||
}
|
||||
|
||||
// Extract logprobs from choices if available
|
||||
var resultLogprobs *schema.Logprobs
|
||||
if len(choices) > 0 {
|
||||
resultLogprobs = choices[0].Logprobs
|
||||
}
|
||||
|
||||
// Parse tool calls
|
||||
@@ -939,9 +915,9 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
var textContent string
|
||||
|
||||
if shouldUseFn {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
funcCallResults = deltaToolCalls
|
||||
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
|
||||
textContent = functions.ContentFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
|
||||
funcCallResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
|
||||
@@ -1021,7 +997,7 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
allOutputItems = append(allOutputItems, schema.ORItemField{
|
||||
Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed", Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
for _, tc := range toolCalls {
|
||||
@@ -1034,22 +1010,22 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
allOutputItems = append(allOutputItems, schema.ORItemField{
|
||||
Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed", Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
allOutputItems = append(allOutputItems, schema.ORItemField{
|
||||
Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed", Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
return buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, allOutputItems, &schema.ORUsage{
|
||||
InputTokens: prediction.Usage.Prompt,
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
|
||||
InputTokens: tokenUsage.Prompt,
|
||||
OutputTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
}, true), nil
|
||||
} // end MCP iteration loop
|
||||
|
||||
@@ -1058,23 +1034,14 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
|
||||
// handleBackgroundStream handles background streaming responses with event buffering
|
||||
func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) (*schema.ORResponseResource, error) {
|
||||
images := []string{}
|
||||
videos := []string{}
|
||||
audios := []string{}
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
videos = append(videos, m.StringVideos...)
|
||||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
|
||||
toolsJSON := serializeToolsForBackend(input.Tools)
|
||||
toolChoiceJSON := ""
|
||||
if input.ToolChoice != nil {
|
||||
toolChoiceBytes, err := json.Marshal(input.ToolChoice)
|
||||
if err == nil {
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
// Populate openAIReq fields for ComputeChoices
|
||||
openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools)
|
||||
openAIReq.ToolsChoice = input.ToolChoice
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
openAIReq.TopLogprobs = input.TopLogprobs
|
||||
openAIReq.Logprobs = schema.LogprobsValue{Enabled: true}
|
||||
}
|
||||
openAIReq.LogitBias = input.LogitBias
|
||||
|
||||
sequenceNumber := 0
|
||||
|
||||
@@ -1105,20 +1072,13 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
|
||||
}
|
||||
hasMCPTools := len(mcpToolInfos) > 0
|
||||
|
||||
var prediction backend.LLMResponse
|
||||
var lastTokenUsage backend.TokenUsage
|
||||
var lastLogprobs *schema.Logprobs
|
||||
|
||||
for mcpIter := 0; mcpIter <= mcpBgStreamMaxIterations; mcpIter++ {
|
||||
if mcpIter > 0 {
|
||||
predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn)
|
||||
xlog.Debug("Background stream MCP re-templating", "iteration", mcpIter)
|
||||
images = images[:0]
|
||||
videos = videos[:0]
|
||||
audios = audios[:0]
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
videos = append(videos, m.StringVideos...)
|
||||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
}
|
||||
|
||||
accumulatedText = ""
|
||||
@@ -1177,28 +1137,23 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
|
||||
return true
|
||||
}
|
||||
|
||||
var streamLogprobs *int
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
streamLogprobs = input.TopLogprobs
|
||||
var result string
|
||||
cb := func(s string, c *[]schema.Choice) {
|
||||
result = s
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
ctx, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, streamLogprobs, input.TopLogprobs, input.LogitBias, nil)
|
||||
choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, tokenCallback)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model inference failed: %w", err)
|
||||
}
|
||||
|
||||
prediction, err = predFunc()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("prediction failed: %w", err)
|
||||
lastTokenUsage = tokenUsage
|
||||
if len(choices) > 0 {
|
||||
lastLogprobs = choices[0].Logprobs
|
||||
}
|
||||
|
||||
result := backend.Finetune(*cfg, predInput, prediction.Response)
|
||||
|
||||
// Check for MCP tool calls in the streamed result
|
||||
if shouldUseFn && hasMCPTools {
|
||||
var funcCallResults []functions.FuncCallResults
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
funcCallResults = deltaToolCalls
|
||||
} else {
|
||||
cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
|
||||
@@ -1315,7 +1270,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
|
||||
}
|
||||
|
||||
// No MCP tools — close the message and break
|
||||
streamEventLogprobs := convertLogprobsForStreaming(prediction.Logprobs)
|
||||
streamEventLogprobs := convertLogprobsForStreaming(lastLogprobs)
|
||||
bufferEvent(store, responseID, &schema.ORStreamEvent{
|
||||
Type: "response.output_text.done",
|
||||
SequenceNumber: sequenceNumber,
|
||||
@@ -1327,7 +1282,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
|
||||
})
|
||||
sequenceNumber++
|
||||
|
||||
textPart := makeOutputTextPartWithLogprobs(accumulatedText, prediction.Logprobs)
|
||||
textPart := makeOutputTextPartWithLogprobs(accumulatedText, lastLogprobs)
|
||||
bufferEvent(store, responseID, &schema.ORStreamEvent{
|
||||
Type: "response.content_part.done",
|
||||
SequenceNumber: sequenceNumber,
|
||||
@@ -1343,7 +1298,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
|
||||
ID: currentMessageID,
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(accumulatedText, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(accumulatedText, lastLogprobs)},
|
||||
}
|
||||
bufferEvent(store, responseID, &schema.ORStreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
@@ -1360,9 +1315,9 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
|
||||
// Build final response
|
||||
now := time.Now().Unix()
|
||||
response := buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, collectedOutputItems, &schema.ORUsage{
|
||||
InputTokens: prediction.Usage.Prompt,
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
|
||||
InputTokens: lastTokenUsage.Prompt,
|
||||
OutputTokens: lastTokenUsage.Completion,
|
||||
TotalTokens: lastTokenUsage.Prompt + lastTokenUsage.Completion,
|
||||
}, true)
|
||||
|
||||
// Emit response.completed
|
||||
@@ -1377,6 +1332,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
|
||||
|
||||
// bufferEvent stores an SSE event in the response store for streaming resume
|
||||
func bufferEvent(store *ResponseStore, responseID string, event *schema.ORStreamEvent) {
|
||||
normalizeORStreamEvent(event)
|
||||
if err := store.AppendEvent(responseID, event); err != nil {
|
||||
xlog.Error("Failed to buffer event", "response_id", responseID, "error", err)
|
||||
}
|
||||
@@ -1391,52 +1347,27 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
if mcpIteration > mcpMaxIterations {
|
||||
return sendOpenResponsesError(c, 500, "server_error", "MCP iteration limit reached", "")
|
||||
}
|
||||
images := []string{}
|
||||
videos := []string{}
|
||||
audios := []string{}
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
videos = append(videos, m.StringVideos...)
|
||||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
|
||||
// Convert and serialize tools to OpenAI format for the backend
|
||||
toolsJSON := serializeToolsForBackend(input.Tools)
|
||||
toolChoiceJSON := ""
|
||||
if input.ToolChoice != nil {
|
||||
toolChoiceBytes, err := json.Marshal(input.ToolChoice)
|
||||
if err == nil {
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// Pass logprobs and logit_bias parameters if requested
|
||||
var logprobs *int
|
||||
// Populate openAIReq fields for ComputeChoices
|
||||
openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools)
|
||||
openAIReq.ToolsChoice = input.ToolChoice
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
logprobs = input.TopLogprobs
|
||||
openAIReq.TopLogprobs = input.TopLogprobs
|
||||
openAIReq.Logprobs = schema.LogprobsValue{Enabled: true}
|
||||
}
|
||||
openAIReq.LogitBias = input.LogitBias
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, logprobs, input.TopLogprobs, input.LogitBias, nil)
|
||||
var result string
|
||||
cb := func(s string, c *[]schema.Choice) {
|
||||
result = s
|
||||
}
|
||||
choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
|
||||
if err != nil {
|
||||
xlog.Error("Open Responses model inference failed", "error", err)
|
||||
return sendOpenResponsesError(c, 500, "model_error", fmt.Sprintf("model inference failed: %v", err), "")
|
||||
}
|
||||
|
||||
const maxEmptyRetries = 5
|
||||
var prediction backend.LLMResponse
|
||||
var result string
|
||||
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
|
||||
prediction, err = predFunc()
|
||||
if err != nil {
|
||||
xlog.Error("Open Responses prediction failed", "error", err)
|
||||
return sendOpenResponsesError(c, 500, "model_error", fmt.Sprintf("prediction failed: %v", err), "")
|
||||
}
|
||||
result = backend.Finetune(*cfg, predInput, prediction.Response)
|
||||
if result != "" || !shouldUseFn {
|
||||
break
|
||||
}
|
||||
xlog.Warn("Open Responses: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
|
||||
var resultLogprobs *schema.Logprobs
|
||||
if len(choices) > 0 {
|
||||
resultLogprobs = choices[0].Logprobs
|
||||
}
|
||||
xlog.Debug("Open Responses - Raw model result", "result", result, "shouldUseFn", shouldUseFn)
|
||||
|
||||
@@ -1473,10 +1404,10 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
var textContent string
|
||||
|
||||
// Try pre-parsed tool calls from C++ autoparser first
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses: using pre-parsed tool calls", "count", len(deltaToolCalls))
|
||||
funcCallResults = deltaToolCalls
|
||||
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
|
||||
textContent = functions.ContentFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses: no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
// Clean up the result (already extracted reasoning above)
|
||||
@@ -1574,7 +1505,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1605,7 +1536,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
@@ -1615,7 +1546,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)},
|
||||
}
|
||||
outputItems = append(outputItems, messageItem)
|
||||
}
|
||||
@@ -1633,9 +1564,9 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
// Build response with all required fields
|
||||
now := time.Now().Unix()
|
||||
response := buildORResponse(responseID, createdAt, &now, "completed", input, outputItems, &schema.ORUsage{
|
||||
InputTokens: prediction.Usage.Prompt,
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
|
||||
InputTokens: tokenUsage.Prompt,
|
||||
OutputTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
OutputTokensDetails: &schema.OROutputTokensDetails{
|
||||
ReasoningTokens: reasoningTokens,
|
||||
},
|
||||
@@ -1675,24 +1606,14 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
})
|
||||
sequenceNumber++
|
||||
|
||||
images := []string{}
|
||||
videos := []string{}
|
||||
audios := []string{}
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
videos = append(videos, m.StringVideos...)
|
||||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
|
||||
// Convert and serialize tools to OpenAI format for the backend
|
||||
toolsJSON := serializeToolsForBackend(input.Tools)
|
||||
toolChoiceJSON := ""
|
||||
if input.ToolChoice != nil {
|
||||
toolChoiceBytes, err := json.Marshal(input.ToolChoice)
|
||||
if err == nil {
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
// Populate openAIReq fields for ComputeChoices
|
||||
openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools)
|
||||
openAIReq.ToolsChoice = input.ToolChoice
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
openAIReq.TopLogprobs = input.TopLogprobs
|
||||
openAIReq.Logprobs = schema.LogprobsValue{Enabled: true}
|
||||
}
|
||||
openAIReq.LogitBias = input.LogitBias
|
||||
|
||||
// Detect if thinking token is already in prompt or template
|
||||
var template string
|
||||
@@ -1714,10 +1635,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
// Track reasoning state for streaming
|
||||
var currentReasoningID string
|
||||
var currentReasoningContentIndex int
|
||||
var accumulatedContent string
|
||||
var lastEmittedReasoning string
|
||||
var lastEmittedCleanedContent string
|
||||
var reasoningTokens int
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, cfg.ReasoningConfig)
|
||||
|
||||
// Collect all output items for storage
|
||||
var collectedOutputItems []schema.ORItemField
|
||||
@@ -1729,24 +1648,25 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
}
|
||||
hasMCPToolsStream := len(mcpToolInfos) > 0
|
||||
|
||||
var prediction backend.LLMResponse
|
||||
var result, finalReasoning, finalCleanedResult string
|
||||
var textContent string
|
||||
var parsedToolCalls []functions.FuncCallResults
|
||||
var toolCalls []functions.FuncCallResults
|
||||
var lastStreamTokenUsage backend.TokenUsage
|
||||
var lastStreamLogprobs *schema.Logprobs
|
||||
|
||||
for mcpStreamIter := 0; mcpStreamIter <= mcpStreamMaxIterations; mcpStreamIter++ {
|
||||
if mcpStreamIter > 0 {
|
||||
// Reset reasoning and tool-call state for re-inference so reasoning
|
||||
// extraction runs again on subsequent iterations
|
||||
inToolCallMode = false
|
||||
extractor.Reset()
|
||||
currentMessageID = ""
|
||||
lastEmittedToolCallCount = 0
|
||||
currentReasoningID = ""
|
||||
|
||||
predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn)
|
||||
xlog.Debug("Open Responses stream MCP re-templating", "iteration", mcpStreamIter)
|
||||
images = images[:0]
|
||||
videos = videos[:0]
|
||||
audios = audios[:0]
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
videos = append(videos, m.StringVideos...)
|
||||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
}
|
||||
|
||||
// For tool calls, we need to track accumulated result and parse incrementally
|
||||
@@ -1901,11 +1821,10 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
|
||||
// If no tool calls detected yet, handle reasoning and text
|
||||
if !inToolCallMode {
|
||||
accumulatedContent += token
|
||||
currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, cfg.ReasoningConfig)
|
||||
reasoningDelta, contentDelta := extractor.ProcessToken(token)
|
||||
|
||||
// Handle reasoning item
|
||||
if currentReasoning != "" {
|
||||
if extractor.Reasoning() != "" {
|
||||
// Check if we need to create reasoning item
|
||||
if currentReasoningID == "" {
|
||||
outputIndex++
|
||||
@@ -1937,16 +1856,6 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
sequenceNumber++
|
||||
}
|
||||
|
||||
// Calculate reasoning delta
|
||||
var reasoningDelta string
|
||||
if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) {
|
||||
reasoningDelta = currentReasoning[len(lastEmittedReasoning):]
|
||||
lastEmittedReasoning = currentReasoning
|
||||
} else if currentReasoning != lastEmittedReasoning {
|
||||
reasoningDelta = currentReasoning
|
||||
lastEmittedReasoning = currentReasoning
|
||||
}
|
||||
|
||||
// Emit reasoning delta if there's new content
|
||||
if reasoningDelta != "" {
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
@@ -1963,23 +1872,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
}
|
||||
}
|
||||
|
||||
// Handle message content (cleaned content without reasoning tags)
|
||||
var deltaContent string
|
||||
if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) {
|
||||
deltaContent = cleanedContent[len(lastEmittedCleanedContent):]
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else if cleanedContent != lastEmittedCleanedContent {
|
||||
if lastEmittedCleanedContent == "" {
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else {
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
}
|
||||
}
|
||||
|
||||
// Only emit message content if there's actual content (not just reasoning)
|
||||
if deltaContent != "" {
|
||||
if contentDelta != "" {
|
||||
if currentMessageID == "" {
|
||||
// Emit output_item.added for message
|
||||
outputIndex++
|
||||
@@ -2020,7 +1914,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
ItemID: currentMessageID,
|
||||
OutputIndex: &outputIndex,
|
||||
ContentIndex: &currentContentIndex,
|
||||
Delta: strPtr(deltaContent),
|
||||
Delta: strPtr(contentDelta),
|
||||
Logprobs: emptyLogprobs(),
|
||||
})
|
||||
sequenceNumber++
|
||||
@@ -2030,14 +1924,11 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
return true
|
||||
}
|
||||
|
||||
// Pass logprobs and logit_bias parameters if requested
|
||||
var streamLogprobs *int
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
streamLogprobs = input.TopLogprobs
|
||||
var ccResult string
|
||||
ccCb := func(s string, c *[]schema.Choice) {
|
||||
ccResult = s
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, streamLogprobs, input.TopLogprobs, input.LogitBias, nil)
|
||||
choices, ccTokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, ccCb, tokenCallback)
|
||||
if err != nil {
|
||||
xlog.Error("Open Responses stream model inference failed", "error", err)
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
@@ -2061,36 +1952,27 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
prediction, err = predFunc()
|
||||
if err != nil {
|
||||
xlog.Error("Open Responses stream prediction failed", "error", err)
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "error",
|
||||
SequenceNumber: sequenceNumber,
|
||||
Error: &schema.ORErrorPayload{
|
||||
Type: "model_error",
|
||||
Message: fmt.Sprintf("prediction failed: %v", err),
|
||||
},
|
||||
})
|
||||
sequenceNumber++
|
||||
responseFailed := responseCreated
|
||||
responseFailed.Status = "failed"
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.failed",
|
||||
SequenceNumber: sequenceNumber,
|
||||
Response: responseFailed,
|
||||
})
|
||||
// Send [DONE] even on error
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
result = ccResult
|
||||
lastStreamTokenUsage = ccTokenUsage
|
||||
if len(choices) > 0 {
|
||||
lastStreamLogprobs = choices[0].Logprobs
|
||||
}
|
||||
|
||||
result = backend.Finetune(*cfg, predInput, prediction.Response)
|
||||
|
||||
// Extract reasoning from final result
|
||||
finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
// Source reasoning from: (1) ChatDeltas from C++ autoparser, (2) extractor's
|
||||
// streaming state, (3) final extraction from the finetuned result.
|
||||
if chatDeltaReasoning := functions.ReasoningFromChatDeltas(chatDeltas); chatDeltaReasoning != "" {
|
||||
finalReasoning = chatDeltaReasoning
|
||||
finalCleanedResult = functions.ContentFromChatDeltas(chatDeltas)
|
||||
if finalCleanedResult == "" {
|
||||
finalCleanedResult = extractor.CleanedContent()
|
||||
}
|
||||
} else {
|
||||
finalReasoning = extractor.Reasoning()
|
||||
finalCleanedResult = extractor.CleanedContent()
|
||||
}
|
||||
if finalReasoning == "" && finalCleanedResult == "" {
|
||||
finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
}
|
||||
|
||||
// Close reasoning item if it exists and wasn't closed yet
|
||||
if currentReasoningID != "" && finalReasoning != "" {
|
||||
@@ -2147,10 +2029,10 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
textContent = ""
|
||||
|
||||
// Try pre-parsed tool calls from C++ autoparser first
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses Stream: using pre-parsed tool calls", "count", len(deltaToolCalls))
|
||||
parsedToolCalls = deltaToolCalls
|
||||
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
|
||||
textContent = functions.ContentFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses Stream: no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
cleanedResult := functions.CleanupLLMResult(finalCleanedResult, cfg.FunctionsConfig)
|
||||
@@ -2269,8 +2151,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
}
|
||||
|
||||
|
||||
// Convert prediction logprobs for streaming events
|
||||
streamEventLogprobs := convertLogprobsForStreaming(prediction.Logprobs)
|
||||
// Convert logprobs for streaming events
|
||||
streamEventLogprobs := convertLogprobsForStreaming(lastStreamLogprobs)
|
||||
|
||||
// If we have no output but the model did produce something, use the cleaned result (without reasoning tags)
|
||||
if textContent == "" && len(toolCalls) == 0 && finalCleanedResult != "" {
|
||||
@@ -2293,7 +2175,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
sequenceNumber++
|
||||
|
||||
// Emit content_part.done (with actual logprobs)
|
||||
textPart := makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)
|
||||
textPart := makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.content_part.done",
|
||||
SequenceNumber: sequenceNumber,
|
||||
@@ -2310,7 +2192,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
ID: currentMessageID,
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)},
|
||||
}
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
@@ -2379,7 +2261,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
ID: currentMessageID,
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)},
|
||||
})
|
||||
}
|
||||
// Add tool call items
|
||||
@@ -2398,9 +2280,9 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
// Emit response.completed
|
||||
now := time.Now().Unix()
|
||||
responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, allOutputItems, &schema.ORUsage{
|
||||
InputTokens: prediction.Usage.Prompt,
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
|
||||
InputTokens: lastStreamTokenUsage.Prompt,
|
||||
OutputTokens: lastStreamTokenUsage.Completion,
|
||||
TotalTokens: lastStreamTokenUsage.Prompt + lastStreamTokenUsage.Completion,
|
||||
OutputTokensDetails: &schema.OROutputTokensDetails{
|
||||
ReasoningTokens: reasoningTokens,
|
||||
},
|
||||
@@ -2459,12 +2341,10 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
// Stream text deltas with reasoning extraction
|
||||
tokenCallback := func(token string, tokenUsage backend.TokenUsage) bool {
|
||||
accumulatedText += token
|
||||
accumulatedContent += token
|
||||
// Prepend thinking token if needed, then extract reasoning
|
||||
currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, cfg.ReasoningConfig)
|
||||
reasoningDelta, contentDelta := extractor.ProcessToken(token)
|
||||
|
||||
// Handle reasoning item
|
||||
if currentReasoning != "" {
|
||||
if extractor.Reasoning() != "" {
|
||||
// Check if we need to create reasoning item
|
||||
if currentReasoningID == "" {
|
||||
outputIndex++
|
||||
@@ -2496,16 +2376,6 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
sequenceNumber++
|
||||
}
|
||||
|
||||
// Calculate reasoning delta
|
||||
var reasoningDelta string
|
||||
if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) {
|
||||
reasoningDelta = currentReasoning[len(lastEmittedReasoning):]
|
||||
lastEmittedReasoning = currentReasoning
|
||||
} else if currentReasoning != lastEmittedReasoning {
|
||||
reasoningDelta = currentReasoning
|
||||
lastEmittedReasoning = currentReasoning
|
||||
}
|
||||
|
||||
// Emit reasoning delta if there's new content
|
||||
if reasoningDelta != "" {
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
@@ -2522,23 +2392,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
}
|
||||
}
|
||||
|
||||
// Handle message content (cleaned content without reasoning tags)
|
||||
var deltaContent string
|
||||
if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) {
|
||||
deltaContent = cleanedContent[len(lastEmittedCleanedContent):]
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else if cleanedContent != lastEmittedCleanedContent {
|
||||
if lastEmittedCleanedContent == "" {
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else {
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
}
|
||||
}
|
||||
|
||||
// Only emit message content if there's actual content (not just reasoning)
|
||||
if deltaContent != "" {
|
||||
if contentDelta != "" {
|
||||
// Emit text delta
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
@@ -2546,7 +2401,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
ItemID: currentMessageID,
|
||||
OutputIndex: &outputIndex,
|
||||
ContentIndex: &currentContentIndex,
|
||||
Delta: strPtr(deltaContent),
|
||||
Delta: strPtr(contentDelta),
|
||||
Logprobs: emptyLogprobs(),
|
||||
})
|
||||
sequenceNumber++
|
||||
@@ -2555,14 +2410,11 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
return true
|
||||
}
|
||||
|
||||
// Pass logprobs and logit_bias parameters if requested
|
||||
var mcpLogprobs *int
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
mcpLogprobs = input.TopLogprobs
|
||||
var noToolResult string
|
||||
noToolCb := func(s string, c *[]schema.Choice) {
|
||||
noToolResult = s
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, mcpLogprobs, input.TopLogprobs, input.LogitBias, nil)
|
||||
noToolChoices, noToolTokenUsage, noToolChatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, noToolCb, tokenCallback)
|
||||
if err != nil {
|
||||
xlog.Error("Open Responses stream model inference failed", "error", err)
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
@@ -2586,36 +2438,28 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
xlog.Error("Open Responses stream prediction failed", "error", err)
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "error",
|
||||
SequenceNumber: sequenceNumber,
|
||||
Error: &schema.ORErrorPayload{
|
||||
Type: "model_error",
|
||||
Message: fmt.Sprintf("prediction failed: %v", err),
|
||||
},
|
||||
})
|
||||
sequenceNumber++
|
||||
responseFailed := responseCreated
|
||||
responseFailed.Status = "failed"
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.failed",
|
||||
SequenceNumber: sequenceNumber,
|
||||
Response: responseFailed,
|
||||
})
|
||||
// Send [DONE] even on error
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
result := noToolResult
|
||||
var noToolLogprobs *schema.Logprobs
|
||||
if len(noToolChoices) > 0 {
|
||||
noToolLogprobs = noToolChoices[0].Logprobs
|
||||
}
|
||||
|
||||
result := backend.Finetune(*cfg, predInput, prediction.Response)
|
||||
|
||||
// Extract reasoning from final result for non-tool-call path
|
||||
finalReasoning, finalCleanedResult := reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
// Source reasoning from: (1) ChatDeltas from C++ autoparser, (2) extractor's
|
||||
// streaming state, (3) final extraction from the finetuned result.
|
||||
var finalReasoning, finalCleanedResult string
|
||||
if chatDeltaReasoning := functions.ReasoningFromChatDeltas(noToolChatDeltas); chatDeltaReasoning != "" {
|
||||
finalReasoning = chatDeltaReasoning
|
||||
finalCleanedResult = functions.ContentFromChatDeltas(noToolChatDeltas)
|
||||
if finalCleanedResult == "" {
|
||||
finalCleanedResult = extractor.CleanedContent()
|
||||
}
|
||||
} else {
|
||||
finalReasoning = extractor.Reasoning()
|
||||
finalCleanedResult = extractor.CleanedContent()
|
||||
}
|
||||
if finalReasoning == "" && finalCleanedResult == "" {
|
||||
finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
}
|
||||
|
||||
// Close reasoning item if it exists and wasn't closed yet
|
||||
if currentReasoningID != "" && finalReasoning != "" {
|
||||
@@ -2670,8 +2514,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
|
||||
result = finalCleanedResult
|
||||
|
||||
// Convert prediction logprobs for streaming events
|
||||
mcpStreamLogprobs := convertLogprobsForStreaming(prediction.Logprobs)
|
||||
// Convert logprobs for streaming events
|
||||
mcpStreamLogprobs := convertLogprobsForStreaming(noToolLogprobs)
|
||||
|
||||
// Emit output_text.done
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
@@ -2686,7 +2530,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
sequenceNumber++
|
||||
|
||||
// Emit content_part.done (with actual logprobs)
|
||||
resultPart := makeOutputTextPartWithLogprobs(result, prediction.Logprobs)
|
||||
resultPart := makeOutputTextPartWithLogprobs(result, noToolLogprobs)
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.content_part.done",
|
||||
SequenceNumber: sequenceNumber,
|
||||
@@ -2699,7 +2543,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
|
||||
// Emit output_item.done (with actual logprobs)
|
||||
messageItem.Status = "completed"
|
||||
messageItem.Content = []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)}
|
||||
messageItem.Content = []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, noToolLogprobs)}
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
SequenceNumber: sequenceNumber,
|
||||
@@ -2734,9 +2578,9 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
finalOutputItems = append(finalOutputItems, *messageItem)
|
||||
}
|
||||
responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, finalOutputItems, &schema.ORUsage{
|
||||
InputTokens: prediction.Usage.Prompt,
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
|
||||
InputTokens: noToolTokenUsage.Prompt,
|
||||
OutputTokens: noToolTokenUsage.Completion,
|
||||
TotalTokens: noToolTokenUsage.Prompt + noToolTokenUsage.Completion,
|
||||
OutputTokensDetails: &schema.OROutputTokensDetails{
|
||||
ReasoningTokens: reasoningTokens,
|
||||
},
|
||||
@@ -2762,6 +2606,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
|
||||
// sendSSEEvent sends a Server-Sent Event
|
||||
func sendSSEEvent(c echo.Context, event *schema.ORStreamEvent) {
|
||||
normalizeORStreamEvent(event)
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to marshal SSE event", "error", err)
|
||||
@@ -2770,6 +2615,13 @@ func sendSSEEvent(c echo.Context, event *schema.ORStreamEvent) {
|
||||
fmt.Fprintf(c.Response().Writer, "event: %s\ndata: %s\n\n", event.Type, string(data))
|
||||
}
|
||||
|
||||
// normalizeORStreamEvent ensures required fields like Summary are never null.
|
||||
func normalizeORStreamEvent(event *schema.ORStreamEvent) {
|
||||
if event.Item != nil && event.Item.Summary == nil {
|
||||
event.Item.Summary = []schema.ORContentPart{}
|
||||
}
|
||||
}
|
||||
|
||||
// getTopLogprobs returns the top_logprobs value, defaulting to 0 if nil
|
||||
func getTopLogprobs(topLogprobs *int) int {
|
||||
if topLogprobs != nil {
|
||||
@@ -2850,6 +2702,13 @@ func buildORResponse(responseID string, createdAt int64, completedAt *int64, sta
|
||||
outputItems = []schema.ORItemField{}
|
||||
}
|
||||
|
||||
// Ensure Summary is never null on any output item
|
||||
for i := range outputItems {
|
||||
if outputItems[i].Summary == nil {
|
||||
outputItems[i].Summary = []schema.ORContentPart{}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure tools is never null - always an array
|
||||
tools := input.Tools
|
||||
if tools == nil {
|
||||
@@ -3025,19 +2884,6 @@ func convertORToolsToOpenAIFormat(orTools []schema.ORFunctionTool) []functions.T
|
||||
return result
|
||||
}
|
||||
|
||||
// serializeToolsForBackend converts and serializes Open Responses tools to JSON for the backend
|
||||
func serializeToolsForBackend(orTools []schema.ORFunctionTool) string {
|
||||
if len(orTools) == 0 {
|
||||
return ""
|
||||
}
|
||||
openAITools := convertORToolsToOpenAIFormat(orTools)
|
||||
toolsBytes, err := json.Marshal(openAITools)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(toolsBytes)
|
||||
}
|
||||
|
||||
// GetResponseEndpoint returns a handler for GET /responses/:id
|
||||
// This endpoint is used for polling background responses or resuming streaming
|
||||
// @Summary Get a response by ID
|
||||
|
||||
@@ -2,15 +2,16 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/emirpasic/gods/v2/queues/circularbuffer"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/emirpasic/gods/v2/queues/circularbuffer"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -29,8 +30,12 @@ type APIExchangeResponse struct {
|
||||
|
||||
type APIExchange struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
Request APIExchangeRequest `json:"request"`
|
||||
Response APIExchangeResponse `json:"response"`
|
||||
Error string `json:"error,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
UserName string `json:"user_name,omitempty"`
|
||||
}
|
||||
|
||||
var traceBuffer *circularbuffer.Queue[APIExchange]
|
||||
@@ -108,13 +113,18 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
}
|
||||
c.Response().Writer = mw
|
||||
|
||||
err = next(c)
|
||||
if err != nil {
|
||||
c.Response().Writer = mw.ResponseWriter // Restore original writer if error
|
||||
return err
|
||||
handlerErr := next(c)
|
||||
|
||||
// Restore original writer unconditionally
|
||||
c.Response().Writer = mw.ResponseWriter
|
||||
|
||||
// Determine response status (use 500 if handler errored and no status was set)
|
||||
status := c.Response().Status
|
||||
if status == 0 && handlerErr != nil {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// Create exchange log
|
||||
// Create exchange log (always, even on error)
|
||||
requestHeaders := c.Request().Header.Clone()
|
||||
requestBody := make([]byte, len(body))
|
||||
copy(requestBody, body)
|
||||
@@ -123,6 +133,7 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
copy(responseBody, resBody.Bytes())
|
||||
exchange := APIExchange{
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Request: APIExchangeRequest{
|
||||
Method: c.Request().Method,
|
||||
Path: c.Path(),
|
||||
@@ -130,11 +141,19 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
Body: &requestBody,
|
||||
},
|
||||
Response: APIExchangeResponse{
|
||||
Status: c.Response().Status,
|
||||
Status: status,
|
||||
Headers: &responseHeaders,
|
||||
Body: &responseBody,
|
||||
},
|
||||
}
|
||||
if handlerErr != nil {
|
||||
exchange.Error = handlerErr.Error()
|
||||
}
|
||||
|
||||
if user := auth.GetUser(c); user != nil {
|
||||
exchange.UserID = user.ID
|
||||
exchange.UserName = user.Name
|
||||
}
|
||||
|
||||
select {
|
||||
case logChan <- exchange:
|
||||
@@ -142,7 +161,7 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
xlog.Warn("Trace channel full, dropping trace")
|
||||
}
|
||||
|
||||
return nil
|
||||
return handlerErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
185
core/http/middleware/usage.go
Normal file
185
core/http/middleware/usage.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
usageFlushInterval = 5 * time.Second
|
||||
usageMaxPending = 5000
|
||||
)
|
||||
|
||||
// usageBatcher accumulates usage records and flushes them to the DB periodically.
|
||||
type usageBatcher struct {
|
||||
mu sync.Mutex
|
||||
pending []*auth.UsageRecord
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func (b *usageBatcher) add(r *auth.UsageRecord) {
|
||||
b.mu.Lock()
|
||||
b.pending = append(b.pending, r)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
func (b *usageBatcher) flush() {
|
||||
b.mu.Lock()
|
||||
batch := b.pending
|
||||
b.pending = nil
|
||||
b.mu.Unlock()
|
||||
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if err := b.db.Create(&batch).Error; err != nil {
|
||||
xlog.Error("Failed to flush usage batch", "count", len(batch), "error", err)
|
||||
// Re-queue failed records with a cap to avoid unbounded growth
|
||||
b.mu.Lock()
|
||||
if len(b.pending) < usageMaxPending {
|
||||
b.pending = append(batch, b.pending...)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
var batcher *usageBatcher
|
||||
|
||||
// InitUsageRecorder starts a background goroutine that periodically flushes
|
||||
// accumulated usage records to the database.
|
||||
func InitUsageRecorder(db *gorm.DB) {
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
batcher = &usageBatcher{db: db}
|
||||
go func() {
|
||||
ticker := time.NewTicker(usageFlushInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
batcher.flush()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// usageTokenCounts mirrors the "usage" object of a response body.
type usageTokenCounts struct {
	PromptTokens     int64 `json:"prompt_tokens"`
	CompletionTokens int64 `json:"completion_tokens"`
	TotalTokens      int64 `json:"total_tokens"`
}

// usageResponseBody is the subset of the response JSON needed for usage
// accounting: the model name and the token counts. Usage stays nil when the
// response carries no "usage" object.
type usageResponseBody struct {
	Model string            `json:"model"`
	Usage *usageTokenCounts `json:"usage"`
}
|
||||
|
||||
// UsageMiddleware extracts token usage from OpenAI-compatible response JSON
|
||||
// and records it per-user.
|
||||
func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if db == nil || batcher == nil {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// Wrap response writer to capture body
|
||||
resBody := new(bytes.Buffer)
|
||||
origWriter := c.Response().Writer
|
||||
mw := &bodyWriter{
|
||||
ResponseWriter: origWriter,
|
||||
body: resBody,
|
||||
}
|
||||
c.Response().Writer = mw
|
||||
|
||||
handlerErr := next(c)
|
||||
|
||||
// Restore original writer
|
||||
c.Response().Writer = origWriter
|
||||
|
||||
// Only record on successful responses
|
||||
if c.Response().Status < 200 || c.Response().Status >= 300 {
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
// Get authenticated user
|
||||
user := auth.GetUser(c)
|
||||
if user == nil {
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
// Try to parse usage from response
|
||||
responseBytes := resBody.Bytes()
|
||||
if len(responseBytes) == 0 {
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
// Check content type
|
||||
ct := c.Response().Header().Get("Content-Type")
|
||||
isJSON := ct == "" || ct == "application/json" || bytes.HasPrefix([]byte(ct), []byte("application/json"))
|
||||
isSSE := bytes.HasPrefix([]byte(ct), []byte("text/event-stream"))
|
||||
|
||||
if !isJSON && !isSSE {
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
var resp usageResponseBody
|
||||
if isSSE {
|
||||
last, ok := lastSSEData(responseBytes)
|
||||
if !ok {
|
||||
return handlerErr
|
||||
}
|
||||
if err := json.Unmarshal(last, &resp); err != nil {
|
||||
return handlerErr
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal(responseBytes, &resp); err != nil {
|
||||
return handlerErr
|
||||
}
|
||||
}
|
||||
|
||||
if resp.Usage == nil {
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
record := &auth.UsageRecord{
|
||||
UserID: user.ID,
|
||||
UserName: user.Name,
|
||||
Model: resp.Model,
|
||||
Endpoint: c.Request().URL.Path,
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
CompletionTokens: resp.Usage.CompletionTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
CreatedAt: startTime,
|
||||
}
|
||||
|
||||
batcher.add(record)
|
||||
|
||||
return handlerErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// lastSSEData scans a Server-Sent Events stream and returns the payload of the
// last "data" line whose content is not the "[DONE]" terminator.
//
// Lines may end in "\n" or "\r\n". Following the SSE field syntax, the field
// name is followed by a colon and an optional single space, so both "data: x"
// and "data:x" are recognised. The boolean result is false when the stream
// contains no such line.
func lastSSEData(b []byte) ([]byte, bool) {
	field := []byte("data:")

	var last []byte
	for _, line := range bytes.Split(b, []byte("\n")) {
		line = bytes.TrimRight(line, "\r")
		if !bytes.HasPrefix(line, field) {
			continue
		}
		// Strip the field name and, if present, the single separating space.
		payload := bytes.TrimPrefix(line[len(field):], []byte(" "))
		if string(payload) != "[DONE]" {
			last = payload
		}
	}
	return last, last != nil
}
|
||||
64
core/http/react-ui/e2e/backend-logs.spec.js
Normal file
64
core/http/react-ui/e2e/backend-logs.spec.js
Normal file
@@ -0,0 +1,64 @@
|
||||
import { test, expect } from '@playwright/test'
|
||||
|
||||
test.describe('Backend Logs', () => {
  const DETAIL_PAGE = '/app/backend-logs/mock-model'

  // Locates a <button> by its visible text.
  const button = (page, label) => page.locator('button', { hasText: label })

  // Every test starts from the detail page of the mock model.
  test.beforeEach(async ({ page }) => {
    await page.goto(DETAIL_PAGE)
  })

  test('model detail page shows title', async ({ page }) => {
    await expect(page.locator('.page-title')).toContainText('mock-model')
  })

  test('no back arrow link on detail page', async ({ page }) => {
    await expect(page.locator('a[href="/app/backend-logs"]')).not.toBeVisible()
  })

  test('filter buttons are visible', async ({ page }) => {
    for (const label of ['All', 'stdout', 'stderr']) {
      await expect(button(page, label)).toBeVisible()
    }
  })

  test('filter buttons toggle active state', async ({ page }) => {
    const allFilter = button(page, 'All')
    const stdoutFilter = button(page, 'stdout')

    // "All" is the active filter on load.
    await expect(allFilter).toHaveClass(/btn-primary/)

    // Selecting "stdout" moves the active state to it.
    await stdoutFilter.click()
    await expect(stdoutFilter).toHaveClass(/btn-primary/)
    await expect(allFilter).not.toHaveClass(/btn-primary/)
  })

  test('export button is present', async ({ page }) => {
    await expect(button(page, 'Export')).toBeVisible()
  })

  test('auto-scroll checkbox is present', async ({ page }) => {
    await expect(page.locator('text=Auto-scroll')).toBeVisible()
  })

  test('clear button is present', async ({ page }) => {
    await expect(button(page, 'Clear')).toBeVisible()
  })

  test('details toggle button is present and toggles', async ({ page }) => {
    // Details are shown by default, so the button offers "Text only".
    const hideDetails = button(page, 'Text only')
    await expect(hideDetails).toBeVisible()

    // After hiding the details the label switches to "Show details".
    await hideDetails.click()
    await expect(button(page, 'Show details')).toBeVisible()
  })
})
|
||||
29
core/http/react-ui/e2e/manage-logs-link.spec.js
Normal file
29
core/http/react-ui/e2e/manage-logs-link.spec.js
Normal file
@@ -0,0 +1,29 @@
|
||||
import { test, expect } from '@playwright/test'
|
||||
|
||||
test.describe('Manage Page - Backend Logs Link', () => {
  const LOGS_LINK = 'a[title="Backend logs"]'

  // Opens the manage page and waits until the models table has rendered.
  const openManagePage = async (page) => {
    await page.goto('/app/manage')
    await expect(page.locator('.table')).toBeVisible({ timeout: 10_000 })
  }

  test('models table shows terminal icon for logs', async ({ page }) => {
    await openManagePage(page)

    // The backend logs link is rendered as a terminal icon.
    const icon = page.locator(`${LOGS_LINK} i.fa-terminal`).first()
    await expect(icon).toBeVisible()
  })

  test('terminal icon links to backend-logs page', async ({ page }) => {
    await openManagePage(page)

    const link = page.locator(LOGS_LINK).first()
    await expect(link).toBeVisible()

    // Navigation happens in an onClick handler; the href itself is "#".
    expect(await link.getAttribute('href')).toBe('#')

    // Clicking must land on a backend-logs detail page.
    await link.click()
    await expect(page).toHaveURL(/\/app\/backend-logs\//)
  })
})
|
||||
80
core/http/react-ui/e2e/models-gallery.spec.js
Normal file
80
core/http/react-ui/e2e/models-gallery.spec.js
Normal file
@@ -0,0 +1,80 @@
|
||||
import { test, expect } from '@playwright/test'
|
||||
|
||||
// Canned /api/models payload: four models covering three backends plus one
// model without a backend.
const MOCK_MODELS_RESPONSE = {
  models: [
    { name: 'llama-model', description: 'A llama model', backend: 'llama-cpp', installed: false, tags: ['llm'] },
    { name: 'whisper-model', description: 'A whisper model', backend: 'whisper', installed: true, tags: ['stt'] },
    { name: 'stablediffusion-model', description: 'An image model', backend: 'stablediffusion', installed: false, tags: ['sd'] },
    { name: 'unknown-model', description: 'No backend', backend: '', installed: false, tags: [] },
  ],
  allBackends: ['llama-cpp', 'stablediffusion', 'whisper'],
  allTags: ['llm', 'sd', 'stt'],
  availableModels: 4,
  installedModels: 1,
  totalPages: 1,
  currentPage: 1,
}

test.describe('Models Gallery - Backend Features', () => {
  const SEARCH_INPUT = 'input[placeholder="Search backends..."]'

  // The button that opens the backend filter while no backend is selected.
  const backendButton = (page) => page.locator('button', { hasText: 'All Backends' })

  // The dropdown panel, located as the grandparent of the search input.
  const backendDropdown = (page) => page.locator(SEARCH_INPUT).locator('..').locator('..')

  test.beforeEach(async ({ page }) => {
    // Serve the canned payload for every models API call.
    await page.route('**/api/models*', (route) => {
      route.fulfill({
        contentType: 'application/json',
        body: JSON.stringify(MOCK_MODELS_RESPONSE),
      })
    })
    await page.goto('/app/models')
    // The table is ready once the Backend column header is rendered.
    await expect(page.locator('th', { hasText: 'Backend' })).toBeVisible({ timeout: 10_000 })
  })

  test('backend column header is visible', async ({ page }) => {
    await expect(page.locator('th', { hasText: 'Backend' })).toBeVisible()
  })

  test('backend badges shown in table rows', async ({ page }) => {
    const rows = page.locator('table')
    await expect(rows.locator('.badge', { hasText: 'llama-cpp' })).toBeVisible()
    await expect(rows.locator('.badge', { hasText: /^whisper$/ })).toBeVisible()
  })

  test('backend dropdown is visible', async ({ page }) => {
    await expect(backendButton(page)).toBeVisible()
  })

  test('clicking backend dropdown opens searchable panel', async ({ page }) => {
    await backendButton(page).click()
    await expect(page.locator(SEARCH_INPUT)).toBeVisible()
  })

  test('typing in search filters dropdown options', async ({ page }) => {
    await backendButton(page).click()
    await page.locator(SEARCH_INPUT).fill('llama')

    // Only the matching backend stays listed in the panel.
    const panel = backendDropdown(page)
    await expect(panel.locator('text=llama-cpp')).toBeVisible()
    await expect(panel.locator('text=whisper')).not.toBeVisible()
  })

  test('selecting a backend updates the dropdown label', async ({ page }) => {
    await backendButton(page).click()
    // Pick the option inside the panel, not the badge in the table.
    await backendDropdown(page).locator('text=llama-cpp').click()

    // The button label now names the selected backend instead of "All Backends".
    await expect(page.locator('button span', { hasText: 'llama-cpp' })).toBeVisible()
  })

  test('expanded row shows backend in detail', async ({ page }) => {
    // Expand the row of the first model.
    await page.locator('tr', { hasText: 'llama-model' }).click()

    // The detail cell lists the Backend label together with its value.
    const detailCell = page.locator('td[colspan="8"]')
    await expect(detailCell.locator('text=Backend')).toBeVisible()
    await expect(detailCell.locator('text=llama-cpp')).toBeVisible()
  })
})
|
||||
23
core/http/react-ui/e2e/navigation.spec.js
Normal file
23
core/http/react-ui/e2e/navigation.spec.js
Normal file
@@ -0,0 +1,23 @@
|
||||
import { test, expect } from '@playwright/test'
|
||||
|
||||
test.describe('Navigation', () => {
  const APP_ROOT = '/app'

  test('/ redirects to /app', async ({ page }) => {
    await page.goto('/')
    await expect(page).toHaveURL(/\/app/)
  })

  test('/app shows home page with LocalAI title', async ({ page }) => {
    await page.goto(APP_ROOT)
    // Both the sidebar and the home page container must be rendered.
    for (const selector of ['.sidebar', '.home-page']) {
      await expect(page.locator(selector)).toBeVisible()
    }
  })

  test('sidebar traces link navigates to /app/traces', async ({ page }) => {
    await page.goto(APP_ROOT)

    const sidebarLink = page.locator('a.nav-item[href="/app/traces"]')
    await expect(sidebarLink).toBeVisible()
    await sidebarLink.click()

    // The URL changes and the Traces heading is shown.
    await expect(page).toHaveURL(/\/app\/traces/)
    const heading = page.getByRole('heading', { name: 'Traces', exact: true })
    await expect(heading).toBeVisible()
  })
})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user