mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-20 14:46:38 -04:00
Compare commits
208 Commits
v4.0.0
...
feat/backe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5fe87cb0d5 | ||
|
|
6dd37a95c4 | ||
|
|
ee00a10836 | ||
|
|
948f3bfaa4 | ||
|
|
1e083cd870 | ||
|
|
b19e60d03a | ||
|
|
4d463e9f0d | ||
|
|
ae4ae5f425 | ||
|
|
7c1865b307 | ||
|
|
62a674ce12 | ||
|
|
c39213443b | ||
|
|
606f462da4 | ||
|
|
5c35e85fe2 | ||
|
|
062e0d0d00 | ||
|
|
d4cd6c284f | ||
|
|
3bb8b65d31 | ||
|
|
9748a1cbc6 | ||
|
|
6bc76dda6d | ||
|
|
e1a6010874 | ||
|
|
706cf5d43c | ||
|
|
13a6ed709c | ||
|
|
85be4ff03c | ||
|
|
b0d9ce4905 | ||
|
|
7081b54c09 | ||
|
|
2b05420f95 | ||
|
|
b64347b6aa | ||
|
|
e00ce981f0 | ||
|
|
285f7d4340 | ||
|
|
ea6e850809 | ||
|
|
b7247fc148 | ||
|
|
39c6b3ed66 | ||
|
|
0e9d1a6588 | ||
|
|
510d6759fe | ||
|
|
154fa000d3 | ||
|
|
0526e60f8d | ||
|
|
db600fb5b2 | ||
|
|
9ac1bdc587 | ||
|
|
fdc9f7bf35 | ||
|
|
8e59346091 | ||
|
|
e6e4e19633 | ||
|
|
505c417fa7 | ||
|
|
17215f6fbc | ||
|
|
bccaba1f66 | ||
|
|
0f9d516a6c | ||
|
|
33b124c6f1 | ||
|
|
6b8007e88e | ||
|
|
b3837c2078 | ||
|
|
92f99b1ec3 | ||
|
|
ad232fdb1a | ||
|
|
11637b5a1b | ||
|
|
0dda4fe6f0 | ||
|
|
773489eeb1 | ||
|
|
06fbe48b3f | ||
|
|
232e324a68 | ||
|
|
39c954764c | ||
|
|
9b7d5513fc | ||
|
|
84cd8c0e7f | ||
|
|
d990f2790c | ||
|
|
53deeb1107 | ||
|
|
c5a840f6af | ||
|
|
6d9d77d590 | ||
|
|
6f304d1201 | ||
|
|
557d0f0f04 | ||
|
|
b7e3589875 | ||
|
|
716ddd697b | ||
|
|
223deb908d | ||
|
|
9f8821bba8 | ||
|
|
84e51b68ef | ||
|
|
7962dd16f7 | ||
|
|
a1466b305a | ||
|
|
57c0026715 | ||
|
|
1ed6b9e5ed | ||
|
|
e4ee74354f | ||
|
|
8577bdcebc | ||
|
|
0d489c7a0d | ||
|
|
11dc54bda9 | ||
|
|
7e0b73deaa | ||
|
|
c0a023d13d | ||
|
|
0d3ae1c295 | ||
|
|
e9f10f2f50 | ||
|
|
b95b0b72ff | ||
|
|
26f1b94f4d | ||
|
|
2d40725ca2 | ||
|
|
6c635e8353 | ||
|
|
cc5f33ce95 | ||
|
|
ba7cdd532a | ||
|
|
6b6c136210 | ||
|
|
e587ecc485 | ||
|
|
f259036a27 | ||
|
|
221ff0f28f | ||
|
|
16d5cb00bd | ||
|
|
952635fba6 | ||
|
|
3cc05af2e5 | ||
|
|
87a63316c7 | ||
|
|
efdcbbe332 | ||
|
|
b4fff9293d | ||
|
|
8180221b7e | ||
|
|
52a9755e08 | ||
|
|
a2a1d919f9 | ||
|
|
a3d37931ec | ||
|
|
5b2e25ebb0 | ||
|
|
b0b37a472f | ||
|
|
3db12eaa7a | ||
|
|
8862e3ce60 | ||
|
|
80699a3f70 | ||
|
|
309a59f61e | ||
|
|
65c9380389 | ||
|
|
79963c56bf | ||
|
|
7004ce0b78 | ||
|
|
702d0e0e4d | ||
|
|
d6de208d6c | ||
|
|
7451145e0c | ||
|
|
cfda3dd0df | ||
|
|
e0eb2fd734 | ||
|
|
dd3376e0a9 | ||
|
|
520e1ce3cd | ||
|
|
3d738164b7 | ||
|
|
56db76599a | ||
|
|
ad57cdfefe | ||
|
|
c2f7d1c18b | ||
|
|
afe79568d6 | ||
|
|
59108fbe32 | ||
|
|
4c870288d9 | ||
|
|
8da7212763 | ||
|
|
6e76052f9d | ||
|
|
cf84db36ec | ||
|
|
d3f629f183 | ||
|
|
b1aa707a92 | ||
|
|
731176ce3a | ||
|
|
b86fa63f70 | ||
|
|
00fcf6936c | ||
|
|
26384c5c70 | ||
|
|
7209457f53 | ||
|
|
9bc68b2721 | ||
|
|
7bdd198fd3 | ||
|
|
b296e3d94b | ||
|
|
c91855a9b2 | ||
|
|
e8e445cd43 | ||
|
|
735c426072 | ||
|
|
0976b8a17b | ||
|
|
2ad8c149e0 | ||
|
|
31fcb1425d | ||
|
|
470d5e506f | ||
|
|
0ee49cf42e | ||
|
|
cecd8d6aa5 | ||
|
|
15935e9d5f | ||
|
|
5d410e5a03 | ||
|
|
5df77d7e8c | ||
|
|
f891d60d26 | ||
|
|
be25217955 | ||
|
|
b74111feed | ||
|
|
bf92117259 | ||
|
|
031a36c995 | ||
|
|
8036d22ec6 | ||
|
|
f7e8d9e791 | ||
|
|
4b183b7bb6 | ||
|
|
f38e91d80b | ||
|
|
aa3e82976e | ||
|
|
d9c1db2b87 | ||
|
|
f7e3aab4fc | ||
|
|
73bdc3b50d | ||
|
|
cb63bdb9e4 | ||
|
|
8cd3f9fc47 | ||
|
|
e0ab1a8b43 | ||
|
|
c3174f9543 | ||
|
|
2b12875302 | ||
|
|
9cdbd89c1f | ||
|
|
7d81bf0aa3 | ||
|
|
a6d0e29eba | ||
|
|
6054d2a91b | ||
|
|
aea21951a2 | ||
|
|
bbe9067227 | ||
|
|
9a9da062e1 | ||
|
|
dd1a8b174f | ||
|
|
cfb7641eea | ||
|
|
e832efeb9e | ||
|
|
a42548e9d1 | ||
|
|
8615ce28a8 | ||
|
|
8336efec41 | ||
|
|
8560a1e571 | ||
|
|
29c33e6a6a | ||
|
|
a58475dbef | ||
|
|
8a0edd0809 | ||
|
|
35d509d8e7 | ||
|
|
eef808d921 | ||
|
|
9d9ea5c1a0 | ||
|
|
e21ad5cfaa | ||
|
|
05ab0c0aa2 | ||
|
|
e2b6233570 | ||
|
|
19f995f38f | ||
|
|
ac168bbc60 | ||
|
|
5c5e537b31 | ||
|
|
118bcee196 | ||
|
|
3eabd6d1d0 | ||
|
|
ee96e5e08d | ||
|
|
3d9ccd1ddc | ||
|
|
d8161bfe57 | ||
|
|
5fd42399d4 | ||
|
|
b2030255ca | ||
|
|
9f903ec06e | ||
|
|
4ea461c330 | ||
|
|
042a9b8ef6 | ||
|
|
65f1a4154a | ||
|
|
c6a51289b0 | ||
|
|
87525109f1 | ||
|
|
c596d8a5d9 | ||
|
|
d79ad76e48 | ||
|
|
dde0353432 |
111
.agents/adding-gallery-models.md
Normal file
111
.agents/adding-gallery-models.md
Normal file
@@ -0,0 +1,111 @@
|
||||
# Adding GGUF Models from HuggingFace to the Gallery
|
||||
|
||||
When adding a GGUF model from HuggingFace to the LocalAI model gallery, follow this guide.
|
||||
|
||||
## Gallery file
|
||||
|
||||
All models are defined in `gallery/index.yaml`. Find the appropriate section (embedding models near other embeddings, chat models near similar chat models) and add a new entry.
|
||||
|
||||
## Getting the SHA256
|
||||
|
||||
GGUF files on HuggingFace expose their SHA256 via the `x-linked-etag` HTTP header. Fetch it with:
|
||||
|
||||
```bash
|
||||
curl -sI "https://huggingface.co/<org>/<repo>/resolve/main/<filename>.gguf" | grep -i x-linked-etag
|
||||
```
|
||||
|
||||
The value (without quotes) is the SHA256 hash. Example:
|
||||
|
||||
```bash
|
||||
curl -sI "https://huggingface.co/ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/resolve/main/embeddinggemma-300m-qat-Q8_0.gguf" | grep -i x-linked-etag
|
||||
# x-linked-etag: "6fa0c02a9c302be6f977521d399b4de3a46310a4f2621ee0063747881b673f67"
|
||||
```
|
||||
|
||||
**Important**: Pay attention to exact filename casing — HuggingFace filenames are case-sensitive (e.g., `Q8_0` vs `q8_0`). Check the repo's file listing to get the exact name.
|
||||
|
||||
## Entry format — Embedding models
|
||||
|
||||
Embedding models use `gallery/virtual.yaml` as the base config and set `embeddings: true`:
|
||||
|
||||
```yaml
|
||||
- name: "model-name"
|
||||
url: github:mudler/LocalAI/gallery/virtual.yaml@master
|
||||
urls:
|
||||
- https://huggingface.co/<original-model-org>/<original-model-name>
|
||||
- https://huggingface.co/<gguf-org>/<gguf-repo-name>
|
||||
description: |
|
||||
Short description of the model, its size, and capabilities.
|
||||
tags:
|
||||
- embeddings
|
||||
overrides:
|
||||
backend: llama-cpp
|
||||
embeddings: true
|
||||
parameters:
|
||||
model: <filename>.gguf
|
||||
files:
|
||||
- filename: <filename>.gguf
|
||||
uri: huggingface://<gguf-org>/<gguf-repo-name>/<filename>.gguf
|
||||
sha256: <sha256-hash>
|
||||
```
|
||||
|
||||
## Entry format — Chat/LLM models
|
||||
|
||||
Chat models typically reference a template config (e.g., `gallery/gemma.yaml`, `gallery/chatml.yaml`) that defines the prompt format. Use YAML anchors (`&name` / `*name`) if adding multiple quantization variants of the same model:
|
||||
|
||||
```yaml
|
||||
- &model-anchor
|
||||
url: "github:mudler/LocalAI/gallery/<template>.yaml@master"
|
||||
name: "model-name"
|
||||
icon: https://example.com/icon.png
|
||||
license: <license>
|
||||
urls:
|
||||
- https://huggingface.co/<org>/<model>
|
||||
- https://huggingface.co/<gguf-org>/<gguf-repo>
|
||||
description: |
|
||||
Model description.
|
||||
tags:
|
||||
- llm
|
||||
- gguf
|
||||
- gpu
|
||||
- cpu
|
||||
overrides:
|
||||
parameters:
|
||||
model: <filename>-Q4_K_M.gguf
|
||||
files:
|
||||
- filename: <filename>-Q4_K_M.gguf
|
||||
sha256: <sha256>
|
||||
uri: huggingface://<gguf-org>/<gguf-repo>/<filename>-Q4_K_M.gguf
|
||||
```
|
||||
|
||||
To add a variant (e.g., different quantization), use YAML merge:
|
||||
|
||||
```yaml
|
||||
- !!merge <<: *model-anchor
|
||||
name: "model-name-q8"
|
||||
overrides:
|
||||
parameters:
|
||||
model: <filename>-Q8_0.gguf
|
||||
files:
|
||||
- filename: <filename>-Q8_0.gguf
|
||||
sha256: <sha256>
|
||||
uri: huggingface://<gguf-org>/<gguf-repo>/<filename>-Q8_0.gguf
|
||||
```
|
||||
|
||||
## Available template configs
|
||||
|
||||
Look at existing `.yaml` files in `gallery/` to find the right prompt template for your model architecture:
|
||||
|
||||
- `gemma.yaml` — Gemma-family models (gemma, embeddinggemma, etc.)
|
||||
- `chatml.yaml` — ChatML format (many Mistral/OpenHermes models)
|
||||
- `deepseek.yaml` — DeepSeek models
|
||||
- `virtual.yaml` — Minimal base (good for embedding models that don't need chat templates)
|
||||
|
||||
## Checklist
|
||||
|
||||
1. **Find the GGUF file** on HuggingFace — note exact filename (case-sensitive)
|
||||
2. **Get the SHA256** using the `curl -sI` + `x-linked-etag` method above
|
||||
3. **Choose the right template** config from `gallery/` based on model architecture
|
||||
4. **Add the entry** to `gallery/index.yaml` near similar models
|
||||
5. **Set `embeddings: true`** if it's an embedding model
|
||||
6. **Include both URLs** — the original model page and the GGUF repo
|
||||
7. **Write a description** — mention model size, capabilities, and quantization type
|
||||
259
.agents/api-endpoints-and-auth.md
Normal file
259
.agents/api-endpoints-and-auth.md
Normal file
@@ -0,0 +1,259 @@
|
||||
# API Endpoints and Authentication
|
||||
|
||||
This guide covers how to add new API endpoints and properly integrate them with the auth/permissions system.
|
||||
|
||||
## Architecture overview
|
||||
|
||||
Authentication and authorization flow through three layers:
|
||||
|
||||
1. **Global auth middleware** (`core/http/auth/middleware.go` → `auth.Middleware`) — applied to every request in `core/http/app.go`. Handles session cookies, Bearer tokens, API keys, and legacy API keys. Populates `auth_user` and `auth_role` in the Echo context.
|
||||
2. **Feature middleware** (`auth.RequireFeature`) — per-feature access control applied to route groups or individual routes. Checks if the authenticated user has the specific feature enabled.
|
||||
3. **Admin middleware** (`auth.RequireAdmin`) — restricts endpoints to admin users only.
|
||||
|
||||
When auth is disabled (no auth DB, no legacy API keys), all middleware becomes pass-through (`auth.NoopMiddleware`).
|
||||
|
||||
## Adding a new API endpoint
|
||||
|
||||
### Step 1: Create the handler
|
||||
|
||||
Write the endpoint handler in the appropriate package under `core/http/endpoints/`. Follow existing patterns:
|
||||
|
||||
```go
|
||||
// core/http/endpoints/localai/my_feature.go
|
||||
func MyFeatureEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Use auth.GetUser(c) to get the authenticated user (may be nil if auth is disabled)
|
||||
user := auth.GetUser(c)
|
||||
|
||||
// Your logic here
|
||||
return c.JSON(http.StatusOK, result)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Register routes
|
||||
|
||||
Add routes in the appropriate file under `core/http/routes/`. The file you use depends on the endpoint category:
|
||||
|
||||
| File | Category |
|
||||
|------|----------|
|
||||
| `routes/openai.go` | OpenAI-compatible API endpoints (`/v1/...`) |
|
||||
| `routes/localai.go` | LocalAI-specific endpoints (`/api/...`, `/models/...`, `/backends/...`) |
|
||||
| `routes/agents.go` | Agent pool endpoints (`/api/agents/...`) |
|
||||
| `routes/auth.go` | Auth endpoints (`/api/auth/...`) |
|
||||
| `routes/ui_api.go` | UI backend API endpoints |
|
||||
|
||||
### Step 3: Apply the right middleware
|
||||
|
||||
Choose the appropriate protection level:
|
||||
|
||||
#### No auth required (public)
|
||||
Exempt paths bypass auth entirely. Add to `isExemptPath()` in `middleware.go` or use the `/api/auth/` prefix (always exempt). Use sparingly — most endpoints should require auth.
|
||||
|
||||
#### Standard auth (any authenticated user)
|
||||
The global middleware already handles this. API paths (`/api/`, `/v1/`, etc.) automatically require authentication when auth is enabled. You don't need to add any extra middleware.
|
||||
|
||||
```go
|
||||
router.GET("/v1/my-endpoint", myHandler) // auth enforced by global middleware
|
||||
```
|
||||
|
||||
#### Admin only
|
||||
Pass `adminMiddleware` to the route. This is set up in `app.go` and passed to `Register*Routes` functions:
|
||||
|
||||
```go
|
||||
// In the Register function signature, accept the middleware:
|
||||
func RegisterMyRoutes(router *echo.Echo, app *application.Application, adminMiddleware echo.MiddlewareFunc) {
|
||||
router.POST("/models/apply", myHandler, adminMiddleware)
|
||||
}
|
||||
```
|
||||
|
||||
#### Feature-gated
|
||||
For endpoints that should be toggleable per-user, use feature middleware. There are two approaches:
|
||||
|
||||
**Approach A: Route-level middleware** (preferred for groups of related endpoints)
|
||||
|
||||
```go
|
||||
// In app.go, create the feature middleware:
|
||||
myFeatureMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMyFeature)
|
||||
|
||||
// Pass it to the route registration function:
|
||||
routes.RegisterMyRoutes(e, app, myFeatureMw)
|
||||
|
||||
// In the routes file, apply to a group:
|
||||
g := e.Group("/api/my-feature", myFeatureMw)
|
||||
g.GET("", listHandler)
|
||||
g.POST("", createHandler)
|
||||
```
|
||||
|
||||
**Approach B: RouteFeatureRegistry** (preferred for individual OpenAI-compatible endpoints)
|
||||
|
||||
Add an entry to `RouteFeatureRegistry` in `core/http/auth/features.go`. The `RequireRouteFeature` global middleware will automatically enforce it:
|
||||
|
||||
```go
|
||||
var RouteFeatureRegistry = []RouteFeature{
|
||||
// ... existing entries ...
|
||||
{"POST", "/v1/my-endpoint", FeatureMyFeature},
|
||||
}
|
||||
```
|
||||
|
||||
## Adding a new feature
|
||||
|
||||
When you need a new toggleable feature (not just a new endpoint under an existing feature):
|
||||
|
||||
### 1. Define the feature constant
|
||||
|
||||
Add to `core/http/auth/permissions.go`:
|
||||
|
||||
```go
|
||||
const (
|
||||
// Add to the appropriate group:
|
||||
// Agent features (default OFF for new users)
|
||||
FeatureMyFeature = "my_feature"
|
||||
|
||||
// OR API features (default ON for new users)
|
||||
FeatureMyFeature = "my_feature"
|
||||
)
|
||||
```
|
||||
|
||||
Then add it to the appropriate slice:
|
||||
|
||||
```go
|
||||
// Default OFF — user must be explicitly granted access:
|
||||
var AgentFeatures = []string{..., FeatureMyFeature}
|
||||
|
||||
// Default ON — user has access unless explicitly revoked:
|
||||
var APIFeatures = []string{..., FeatureMyFeature}
|
||||
```
|
||||
|
||||
### 2. Add feature metadata
|
||||
|
||||
In `core/http/auth/features.go`, add to the appropriate `FeatureMetas` function so the admin UI can display it:
|
||||
|
||||
```go
|
||||
func AgentFeatureMetas() []FeatureMeta {
|
||||
return []FeatureMeta{
|
||||
// ... existing ...
|
||||
{FeatureMyFeature, "My Feature", false}, // false = default OFF
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Wire up the middleware
|
||||
|
||||
In `core/http/app.go`:
|
||||
|
||||
```go
|
||||
myFeatureMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMyFeature)
|
||||
```
|
||||
|
||||
Then pass it to the route registration function.
|
||||
|
||||
### 4. Register route-feature mappings (if applicable)
|
||||
|
||||
If your feature gates standard API endpoints (like `/v1/...`), add entries to `RouteFeatureRegistry` in `features.go` instead of using per-route middleware.
|
||||
|
||||
## Accessing the authenticated user in handlers
|
||||
|
||||
```go
|
||||
import "github.com/mudler/LocalAI/core/http/auth"
|
||||
|
||||
func MyHandler(c echo.Context) error {
|
||||
// Get the user (nil when auth is disabled or unauthenticated)
|
||||
user := auth.GetUser(c)
|
||||
if user == nil {
|
||||
// Handle unauthenticated — or let middleware handle it
|
||||
}
|
||||
|
||||
// Check role
|
||||
if user.Role == auth.RoleAdmin {
|
||||
// admin-specific logic
|
||||
}
|
||||
|
||||
// Check feature access programmatically (when you need conditional behavior, not full blocking)
|
||||
if auth.HasFeatureAccess(db, user, auth.FeatureMyFeature) {
|
||||
// feature-specific logic
|
||||
}
|
||||
|
||||
// Check model access
|
||||
if !auth.IsModelAllowed(db, user, modelName) {
|
||||
return c.JSON(http.StatusForbidden, ...)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Middleware composition patterns
|
||||
|
||||
Middleware can be composed at different levels. Here are the patterns used in the codebase:
|
||||
|
||||
### Group-level middleware (agents pattern)
|
||||
```go
|
||||
// All routes in the group share the middleware
|
||||
g := e.Group("/api/agents", poolReadyMw, agentsMw)
|
||||
g.GET("", listHandler)
|
||||
g.POST("", createHandler)
|
||||
```
|
||||
|
||||
### Per-route middleware (localai pattern)
|
||||
```go
|
||||
// Individual routes get middleware as extra arguments
|
||||
router.POST("/models/apply", applyHandler, adminMiddleware)
|
||||
router.GET("/metrics", metricsHandler, adminMiddleware)
|
||||
```
|
||||
|
||||
### Middleware slice (openai pattern)
|
||||
```go
|
||||
// Build a middleware chain for a handler
|
||||
chatMiddleware := []echo.MiddlewareFunc{
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
modelFilterMiddleware,
|
||||
}
|
||||
app.POST("/v1/chat/completions", chatHandler, chatMiddleware...)
|
||||
```
|
||||
|
||||
## Error response format
|
||||
|
||||
Always use `schema.ErrorResponse` for auth/permission errors to stay consistent with the OpenAI-compatible API:
|
||||
|
||||
```go
|
||||
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
|
||||
Error: &schema.APIError{
|
||||
Message: "feature not enabled for your account",
|
||||
Code: http.StatusForbidden,
|
||||
Type: "authorization_error",
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
Use these HTTP status codes:
|
||||
- `401 Unauthorized` — no valid credentials provided
|
||||
- `403 Forbidden` — authenticated but lacking permission
|
||||
- `429 Too Many Requests` — rate limited (auth endpoints)
|
||||
|
||||
## Usage tracking
|
||||
|
||||
If your endpoint should be tracked for usage (token counts, request counts), add the `usageMiddleware` to its middleware chain. See `core/http/middleware/usage.go` and how it's applied in `routes/openai.go`.
|
||||
|
||||
## Path protection rules
|
||||
|
||||
The global auth middleware classifies paths as API paths or non-API paths:
|
||||
|
||||
- **API paths** (always require auth when auth is enabled): `/api/`, `/v1/`, `/models/`, `/backends/`, `/backend/`, `/tts`, `/vad`, `/video`, `/stores/`, `/system`, `/ws/`, `/metrics`
|
||||
- **Exempt paths** (never require auth): `/api/auth/` prefix, anything in `appConfig.PathWithoutAuth`
|
||||
- **Non-API paths** (UI, static assets): pass through without auth — the React UI handles login redirects client-side
|
||||
|
||||
If you add endpoints under a new top-level path prefix, add it to `isAPIPath()` in `middleware.go` to ensure it requires authentication.
|
||||
|
||||
## Checklist
|
||||
|
||||
When adding a new endpoint:
|
||||
|
||||
- [ ] Handler in `core/http/endpoints/`
|
||||
- [ ] Route registered in appropriate `core/http/routes/` file
|
||||
- [ ] Auth level chosen: public / standard / admin / feature-gated
|
||||
- [ ] If feature-gated: constant in `permissions.go`, metadata in `features.go`, middleware in `app.go`
|
||||
- [ ] If new path prefix: added to `isAPIPath()` in `middleware.go`
|
||||
- [ ] If OpenAI-compatible: entry in `RouteFeatureRegistry`
|
||||
- [ ] If token-counting: `usageMiddleware` added to middleware chain
|
||||
- [ ] Error responses use `schema.ErrorResponse` format
|
||||
- [ ] Tests cover both authenticated and unauthenticated access
|
||||
@@ -49,3 +49,4 @@ The project documentation is located in `docs/content`. When adding new features
|
||||
- **Feature Documentation**: If you add a new feature (like a new backend or API endpoint), create a new markdown file in `docs/content/features/` explaining what it is, how to configure it, and how to use it.
|
||||
- **Configuration**: If you modify configuration options, update the relevant sections in `docs/content/`.
|
||||
- **Examples**: providing concrete examples (like YAML configuration blocks) is highly encouraged to help users get started quickly.
|
||||
- **Shortcodes**: Use `{{% notice note %}}`, `{{% notice tip %}}`, or `{{% notice warning %}}` for callout boxes. Do **not** use `{{% alert %}}` — that shortcode does not exist in this project's Hugo theme and will break the docs build.
|
||||
|
||||
141
.agents/debugging-backends.md
Normal file
141
.agents/debugging-backends.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# Debugging and Rebuilding Backends
|
||||
|
||||
When a backend fails at runtime (e.g. a gRPC method error, a Python import error, or a dependency conflict), use this guide to diagnose, fix, and rebuild.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
- **Source directory**: `backend/python/<name>/` (or `backend/go/<name>/`, `backend/cpp/<name>/`)
|
||||
- **Installed directory**: `backends/<name>/` — this is what LocalAI actually runs. It is populated by `make backends/<name>` which builds a Docker image, exports it, and installs it via `local-ai backends install`.
|
||||
- **Virtual environment**: `backends/<name>/venv/` — the installed Python venv (for Python backends). The Python binary is at `backends/<name>/venv/bin/python`.
|
||||
|
||||
Editing files in `backend/python/<name>/` does **not** affect the running backend until you rebuild with `make backends/<name>`.
|
||||
|
||||
## Diagnosing Failures
|
||||
|
||||
### 1. Check the logs
|
||||
|
||||
Backend gRPC processes log to LocalAI's stdout/stderr. Look for lines tagged with the backend's model ID:
|
||||
|
||||
```
|
||||
GRPC stderr id="trl-finetune-127.0.0.1:37335" line="..."
|
||||
```
|
||||
|
||||
Common error patterns:
|
||||
- **"Method not implemented"** — the backend is missing a gRPC method that the Go side calls. The model loader (`pkg/model/initializers.go`) always calls `LoadModel` after `Health`; fine-tuning backends must implement it even as a no-op stub.
|
||||
- **Python import errors / `AttributeError`** — usually a dependency version mismatch (e.g. `pyarrow` removing `PyExtensionType`).
|
||||
- **"failed to load backend"** — the gRPC process crashed or never started. Check stderr lines for the traceback.
|
||||
|
||||
### 2. Test the Python environment directly
|
||||
|
||||
You can run the installed venv's Python to check imports without starting the full server:
|
||||
|
||||
```bash
|
||||
backends/<name>/venv/bin/python -c "import datasets; print(datasets.__version__)"
|
||||
```
|
||||
|
||||
If `pip` is missing from the venv, bootstrap it:
|
||||
|
||||
```bash
|
||||
backends/<name>/venv/bin/python -m ensurepip
|
||||
```
|
||||
|
||||
Then use `backends/<name>/venv/bin/python -m pip install ...` to test fixes in the installed venv before committing them to the source requirements.
|
||||
|
||||
### 3. Check upstream dependency constraints
|
||||
|
||||
When you hit a dependency conflict, check what the main library expects. For example, TRL's upstream `requirements.txt`:
|
||||
|
||||
```
|
||||
https://github.com/huggingface/trl/blob/main/requirements.txt
|
||||
```
|
||||
|
||||
Pin minimum versions in the backend's requirements files to match upstream.
|
||||
|
||||
## Common Fixes
|
||||
|
||||
### Missing gRPC methods
|
||||
|
||||
If the Go side calls a method the backend doesn't implement (e.g. `LoadModel`), add a no-op stub in `backend.py`:
|
||||
|
||||
```python
|
||||
def LoadModel(self, request, context):
|
||||
"""No-op — actual loading happens elsewhere."""
|
||||
return backend_pb2.Result(success=True, message="OK")
|
||||
```
|
||||
|
||||
The gRPC contract requires `LoadModel` to succeed for the model loader to return a usable client, even if the backend doesn't need upfront model loading.
|
||||
|
||||
### Dependency version conflicts
|
||||
|
||||
Python backends often break when a transitive dependency releases a breaking change (e.g. `pyarrow` removing `PyExtensionType`). Steps:
|
||||
|
||||
1. Identify the broken import in the logs
|
||||
2. Test in the installed venv: `backends/<name>/venv/bin/python -c "import <module>"`
|
||||
3. Check upstream requirements for version constraints
|
||||
4. Update **all** requirements files in `backend/python/<name>/`:
|
||||
- `requirements.txt` — base deps (grpcio, protobuf)
|
||||
- `requirements-cpu.txt` — CPU-specific (includes PyTorch CPU index)
|
||||
- `requirements-cublas12.txt` — CUDA 12
|
||||
- `requirements-cublas13.txt` — CUDA 13
|
||||
5. Rebuild: `make backends/<name>`
|
||||
|
||||
### PyTorch index conflicts (uv resolver)
|
||||
|
||||
The Docker build uses `uv` for pip installs. When `--extra-index-url` points to the PyTorch wheel index, `uv` may refuse to fetch packages like `requests` from PyPI if it finds a different version on the PyTorch index first. Fix this by adding `--index-strategy=unsafe-first-match` to `install.sh`:
|
||||
|
||||
```bash
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
installRequirements
|
||||
```
|
||||
|
||||
Most Python backends already do this — check `backend/python/transformers/install.sh` or similar for reference.
|
||||
|
||||
## Rebuilding
|
||||
|
||||
### Rebuild a single backend
|
||||
|
||||
```bash
|
||||
make backends/<name>
|
||||
```
|
||||
|
||||
This runs the Docker build (`Dockerfile.python`), exports the image to `backend-images/<name>.tar`, and installs it into `backends/<name>/`. It also rebuilds the `local-ai` Go binary (without extra tags).
|
||||
|
||||
**Important**: If you were previously running with `GO_TAGS=auth`, the `make backends/<name>` step will overwrite your binary without that tag. Rebuild the Go binary afterward:
|
||||
|
||||
```bash
|
||||
GO_TAGS=auth make build
|
||||
```
|
||||
|
||||
### Rebuild and restart
|
||||
|
||||
After rebuilding a backend, you must restart LocalAI for it to pick up the new backend files. The backend gRPC process is spawned on demand when the model is first loaded.
|
||||
|
||||
```bash
|
||||
# Kill existing process
|
||||
kill <pid>
|
||||
|
||||
# Restart
|
||||
./local-ai run --debug [your flags]
|
||||
```
|
||||
|
||||
### Quick iteration (skip Docker rebuild)
|
||||
|
||||
For fast iteration on a Python backend's `backend.py` without a full Docker rebuild, you can edit the installed copy directly:
|
||||
|
||||
```bash
|
||||
# Edit the installed copy
|
||||
vim backends/<name>/backend.py
|
||||
|
||||
# Restart LocalAI to respawn the gRPC process
|
||||
```
|
||||
|
||||
This is useful for testing but **does not persist** — the next `make backends/<name>` will overwrite it. Always commit fixes to the source in `backend/python/<name>/`.
|
||||
|
||||
## Verification
|
||||
|
||||
After fixing and rebuilding:
|
||||
|
||||
1. Start LocalAI and confirm the backend registers: look for `Registering backend name="<name>"` in the logs
|
||||
2. Trigger the operation that failed (e.g. start a fine-tuning job)
|
||||
3. Watch the GRPC stderr/stdout lines for the backend's model ID
|
||||
4. Confirm no errors in the traceback
|
||||
3
.github/gallery-agent/agent.go
vendored
3
.github/gallery-agent/agent.go
vendored
@@ -133,6 +133,7 @@ func getRealReadme(ctx context.Context, repository string) (string, error) {
|
||||
result, err := cogito.ExecuteTools(llm, fragment,
|
||||
cogito.WithIterations(3),
|
||||
cogito.WithMaxAttempts(3),
|
||||
cogito.DisableSinkState,
|
||||
cogito.WithTools(&HFReadmeTool{client: hfapi.NewClient()}))
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -406,7 +407,7 @@ func getHuggingFaceAvatarURL(author string) string {
|
||||
}
|
||||
|
||||
// Parse the response to get avatar URL
|
||||
var userInfo map[string]interface{}
|
||||
var userInfo map[string]any
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ""
|
||||
|
||||
15
.github/gallery-agent/gallery.go
vendored
15
.github/gallery-agent/gallery.go
vendored
@@ -79,7 +79,20 @@ func generateYAMLEntry(model ProcessedModel, quantization string) string {
|
||||
description = cleanTextContent(description)
|
||||
formattedDescription := formatTextContent(description)
|
||||
|
||||
configFile := formatTextContent(modelConfig.ConfigFile)
|
||||
// Strip name and description from config file since they are
|
||||
// already present at the gallery entry level and should not
|
||||
// appear under overrides.
|
||||
configFileContent := modelConfig.ConfigFile
|
||||
var cfgMap map[string]any
|
||||
if err := yaml.Unmarshal([]byte(configFileContent), &cfgMap); err == nil {
|
||||
delete(cfgMap, "name")
|
||||
delete(cfgMap, "description")
|
||||
if cleaned, err := yaml.Marshal(cfgMap); err == nil {
|
||||
configFileContent = string(cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
configFile := formatTextContent(configFileContent)
|
||||
|
||||
filesYAML, _ := yaml.Marshal(modelConfig.Files)
|
||||
|
||||
|
||||
40
.github/gallery-agent/testing.go
vendored
40
.github/gallery-agent/testing.go
vendored
@@ -3,7 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"math/rand/v2"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -13,11 +13,11 @@ func runSyntheticMode() error {
|
||||
generator := NewSyntheticDataGenerator()
|
||||
|
||||
// Generate a random number of synthetic models (1-3)
|
||||
numModels := generator.rand.Intn(3) + 1
|
||||
numModels := generator.rand.IntN(3) + 1
|
||||
fmt.Printf("Generating %d synthetic models for testing...\n", numModels)
|
||||
|
||||
var models []ProcessedModel
|
||||
for i := 0; i < numModels; i++ {
|
||||
for range numModels {
|
||||
model := generator.GenerateProcessedModel()
|
||||
models = append(models, model)
|
||||
fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
|
||||
@@ -42,14 +42,14 @@ type SyntheticDataGenerator struct {
|
||||
// NewSyntheticDataGenerator creates a new synthetic data generator
|
||||
func NewSyntheticDataGenerator() *SyntheticDataGenerator {
|
||||
return &SyntheticDataGenerator{
|
||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
rand: rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0)),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateProcessedModelFile creates a synthetic ProcessedModelFile
|
||||
func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile {
|
||||
fileTypes := []string{"model", "readme", "other"}
|
||||
fileType := fileTypes[g.rand.Intn(len(fileTypes))]
|
||||
fileType := fileTypes[g.rand.IntN(len(fileTypes))]
|
||||
|
||||
var path string
|
||||
var isReadme bool
|
||||
@@ -68,7 +68,7 @@ func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile
|
||||
|
||||
return ProcessedModelFile{
|
||||
Path: path,
|
||||
Size: int64(g.rand.Intn(1000000000) + 1000000), // 1MB to 1GB
|
||||
Size: int64(g.rand.IntN(1000000000) + 1000000), // 1MB to 1GB
|
||||
SHA256: g.randomSHA256(),
|
||||
IsReadme: isReadme,
|
||||
FileType: fileType,
|
||||
@@ -80,19 +80,19 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
|
||||
authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"}
|
||||
modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"}
|
||||
|
||||
author := authors[g.rand.Intn(len(authors))]
|
||||
modelName := modelNames[g.rand.Intn(len(modelNames))]
|
||||
author := authors[g.rand.IntN(len(authors))]
|
||||
modelName := modelNames[g.rand.IntN(len(modelNames))]
|
||||
modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6))
|
||||
|
||||
// Generate files
|
||||
numFiles := g.rand.Intn(5) + 2 // 2-6 files
|
||||
numFiles := g.rand.IntN(5) + 2 // 2-6 files
|
||||
files := make([]ProcessedModelFile, numFiles)
|
||||
|
||||
// Ensure at least one model file and one readme
|
||||
hasModelFile := false
|
||||
hasReadme := false
|
||||
|
||||
for i := 0; i < numFiles; i++ {
|
||||
for i := range numFiles {
|
||||
files[i] = g.GenerateProcessedModelFile()
|
||||
if files[i].FileType == "model" {
|
||||
hasModelFile = true
|
||||
@@ -140,27 +140,27 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
|
||||
|
||||
// Generate sample metadata
|
||||
licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""}
|
||||
license := licenses[g.rand.Intn(len(licenses))]
|
||||
license := licenses[g.rand.IntN(len(licenses))]
|
||||
|
||||
sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"}
|
||||
numTags := g.rand.Intn(4) + 3 // 3-6 tags
|
||||
numTags := g.rand.IntN(4) + 3 // 3-6 tags
|
||||
tags := make([]string, numTags)
|
||||
for i := 0; i < numTags; i++ {
|
||||
tags[i] = sampleTags[g.rand.Intn(len(sampleTags))]
|
||||
for i := range numTags {
|
||||
tags[i] = sampleTags[g.rand.IntN(len(sampleTags))]
|
||||
}
|
||||
// Remove duplicates
|
||||
tags = g.removeDuplicates(tags)
|
||||
|
||||
// Optionally include icon (50% chance)
|
||||
icon := ""
|
||||
if g.rand.Intn(2) == 0 {
|
||||
if g.rand.IntN(2) == 0 {
|
||||
icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24))
|
||||
}
|
||||
|
||||
return ProcessedModel{
|
||||
ModelID: modelID,
|
||||
Author: author,
|
||||
Downloads: g.rand.Intn(1000000) + 1000,
|
||||
Downloads: g.rand.IntN(1000000) + 1000,
|
||||
LastModified: g.randomDate(),
|
||||
Files: files,
|
||||
PreferredModelFile: preferredModelFile,
|
||||
@@ -180,7 +180,7 @@ func (g *SyntheticDataGenerator) randomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[g.rand.Intn(len(charset))]
|
||||
b[i] = charset[g.rand.IntN(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
@@ -189,14 +189,14 @@ func (g *SyntheticDataGenerator) randomSHA256() string {
|
||||
const charset = "0123456789abcdef"
|
||||
b := make([]byte, 64)
|
||||
for i := range b {
|
||||
b[i] = charset[g.rand.Intn(len(charset))]
|
||||
b[i] = charset[g.rand.IntN(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (g *SyntheticDataGenerator) randomDate() string {
|
||||
now := time.Now()
|
||||
daysAgo := g.rand.Intn(365) // Random date within last year
|
||||
daysAgo := g.rand.IntN(365) // Random date within last year
|
||||
pastDate := now.AddDate(0, 0, -daysAgo)
|
||||
return pastDate.Format("2006-01-02T15:04:05.000Z")
|
||||
}
|
||||
@@ -220,5 +220,5 @@ func (g *SyntheticDataGenerator) generateReadmeContent(modelName, author string)
|
||||
fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author),
|
||||
}
|
||||
|
||||
return templates[g.rand.Intn(len(templates))]
|
||||
return templates[g.rand.IntN(len(templates))]
|
||||
}
|
||||
|
||||
225
.github/workflows/backend.yml
vendored
225
.github/workflows/backend.yml
vendored
@@ -105,6 +105,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-faster-whisper'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'true'
|
||||
backend: "faster-whisper"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -118,6 +131,32 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-trl'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'true'
|
||||
backend: "trl"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-llama-cpp-quantization'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'true'
|
||||
backend: "llama-cpp-quantization"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -366,6 +405,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-trl'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "trl"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -522,6 +574,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -757,6 +822,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-trl'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "trl"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -913,6 +991,32 @@ jobs:
|
||||
backend: "mlx-distributed"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-whisperx'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
ubuntu-version: '2404'
|
||||
backend: "whisperx"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-faster-whisper'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
ubuntu-version: '2404'
|
||||
backend: "faster-whisper"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1056,6 +1160,32 @@ jobs:
|
||||
backend: "stablediffusion-ggml"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-sam3-cpp'
|
||||
base-image: "ubuntu:24.04"
|
||||
ubuntu-version: '2404'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1592,6 +1722,32 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-whisperx'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
skip-drivers: 'true'
|
||||
backend: "whisperx"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'l4t'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-faster-whisper'
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
skip-drivers: 'true'
|
||||
backend: "faster-whisper"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
# SYCL additional backends
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
@@ -1790,6 +1946,59 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# sam3-cpp
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-sam3-cpp'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1842,6 +2051,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-sam3-cpp'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "sam3-cpp"
|
||||
dockerfile: "./backend/Dockerfile.golang"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
# whisper
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
@@ -2373,6 +2595,9 @@ jobs:
|
||||
tag-suffix: "-metal-darwin-arm64-local-store"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "llama-cpp-quantization"
|
||||
tag-suffix: "-metal-darwin-arm64-llama-cpp-quantization"
|
||||
build-type: "mps"
|
||||
with:
|
||||
backend: ${{ matrix.backend }}
|
||||
build-type: ${{ matrix.build-type }}
|
||||
|
||||
48
.github/workflows/bump-inference-defaults.yml
vendored
Normal file
48
.github/workflows/bump-inference-defaults.yml
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
name: Bump inference defaults
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Run daily at 06:00 UTC
|
||||
- cron: '0 6 * * *'
|
||||
workflow_dispatch: # Allow manual trigger
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
bump:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Re-fetch inference defaults
|
||||
run: make generate-force
|
||||
|
||||
- name: Check for changes
|
||||
id: diff
|
||||
run: |
|
||||
if git diff --quiet core/config/inference_defaults.json; then
|
||||
echo "changed=false" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Create Pull Request
|
||||
if: steps.diff.outputs.changed == 'true'
|
||||
uses: peter-evans/create-pull-request@v8
|
||||
with:
|
||||
commit-message: "chore: bump inference defaults from unsloth"
|
||||
title: "chore: bump inference defaults from unsloth"
|
||||
body: |
|
||||
Auto-generated update of `core/config/inference_defaults.json` from
|
||||
[unsloth's inference_defaults.json](https://github.com/unslothai/unsloth/blob/main/studio/backend/assets/configs/inference_defaults.json).
|
||||
|
||||
This PR was created automatically by the `bump-inference-defaults` workflow.
|
||||
branch: chore/bump-inference-defaults
|
||||
delete-branch: true
|
||||
labels: automated
|
||||
4
.github/workflows/bump_deps.yaml
vendored
4
.github/workflows/bump_deps.yaml
vendored
@@ -34,6 +34,10 @@ jobs:
|
||||
variable: "ACESTEP_CPP_VERSION"
|
||||
branch: "master"
|
||||
file: "backend/go/acestep-cpp/Makefile"
|
||||
- repository: "PABannier/sam3.cpp"
|
||||
variable: "SAM3_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/go/sam3-cpp/Makefile"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
2
.github/workflows/gallery-agent.yaml
vendored
2
.github/workflows/gallery-agent.yaml
vendored
@@ -55,7 +55,7 @@ jobs:
|
||||
- name: Run gallery agent
|
||||
env:
|
||||
#OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||
OPENAI_MODE: Qwen3.5-2B-GGUF
|
||||
OPENAI_MODEL: Qwen3.5-2B-GGUF
|
||||
OPENAI_BASE_URL: "http://localhost:8080"
|
||||
OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
|
||||
#OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }}
|
||||
|
||||
75
.github/workflows/gh-pages.yml
vendored
Normal file
75
.github/workflows/gh-pages.yml
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
name: Deploy docs to GitHub Pages
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths:
|
||||
- 'docs/**'
|
||||
- 'gallery/**'
|
||||
- 'images/**'
|
||||
- '.github/ci/modelslist.go'
|
||||
- '.github/workflows/gh-pages.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pages: write
|
||||
id-token: write
|
||||
|
||||
concurrency:
|
||||
group: pages
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
HUGO_VERSION: "0.146.3"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0 # needed for enableGitInfo
|
||||
submodules: true
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
cache: false
|
||||
|
||||
- name: Setup Hugo
|
||||
uses: peaceiris/actions-hugo@v3
|
||||
with:
|
||||
hugo-version: ${{ env.HUGO_VERSION }}
|
||||
extended: true
|
||||
|
||||
- name: Setup Pages
|
||||
id: pages
|
||||
uses: actions/configure-pages@v6
|
||||
|
||||
- name: Generate gallery
|
||||
run: go run ./.github/ci/modelslist.go ./gallery/index.yaml > docs/static/gallery.html
|
||||
|
||||
- name: Build site
|
||||
working-directory: docs
|
||||
run: |
|
||||
mkdir -p layouts/_default
|
||||
hugo --minify --baseURL "${{ steps.pages.outputs.base_url }}/"
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-pages-artifact@v4
|
||||
with:
|
||||
path: docs/public
|
||||
|
||||
deploy:
|
||||
environment:
|
||||
name: github-pages
|
||||
url: ${{ steps.deployment.outputs.page_url }}
|
||||
runs-on: ubuntu-latest
|
||||
needs: build
|
||||
steps:
|
||||
- name: Deploy to GitHub Pages
|
||||
id: deployment
|
||||
uses: actions/deploy-pages@v5
|
||||
107
.github/workflows/test-extra.yml
vendored
107
.github/workflows/test-extra.yml
vendored
@@ -14,6 +14,38 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
detect-changes:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
run-all: ${{ steps.detect.outputs.run-all }}
|
||||
transformers: ${{ steps.detect.outputs.transformers }}
|
||||
rerankers: ${{ steps.detect.outputs.rerankers }}
|
||||
diffusers: ${{ steps.detect.outputs.diffusers }}
|
||||
coqui: ${{ steps.detect.outputs.coqui }}
|
||||
moonshine: ${{ steps.detect.outputs.moonshine }}
|
||||
pocket-tts: ${{ steps.detect.outputs.pocket-tts }}
|
||||
qwen-tts: ${{ steps.detect.outputs.qwen-tts }}
|
||||
qwen-asr: ${{ steps.detect.outputs.qwen-asr }}
|
||||
nemo: ${{ steps.detect.outputs.nemo }}
|
||||
voxcpm: ${{ steps.detect.outputs.voxcpm }}
|
||||
llama-cpp-quantization: ${{ steps.detect.outputs.llama-cpp-quantization }}
|
||||
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
||||
voxtral: ${{ steps.detect.outputs.voxtral }}
|
||||
kokoros: ${{ steps.detect.outputs.kokoros }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
- name: Install dependencies
|
||||
run: bun add js-yaml @octokit/core
|
||||
- name: Detect changed backends
|
||||
id: detect
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_EVENT_PATH: ${{ github.event_path }}
|
||||
run: bun run scripts/changed-backends.js
|
||||
|
||||
# Requires CUDA
|
||||
# tests-chatterbox-tts:
|
||||
# runs-on: ubuntu-latest
|
||||
@@ -37,6 +69,8 @@ jobs:
|
||||
# make --jobs=5 --output-sync=target -C backend/python/chatterbox
|
||||
# make --jobs=5 --output-sync=target -C backend/python/chatterbox test
|
||||
tests-transformers:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.transformers == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -58,6 +92,8 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/transformers
|
||||
make --jobs=5 --output-sync=target -C backend/python/transformers test
|
||||
tests-rerankers:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.rerankers == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -80,6 +116,8 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/rerankers test
|
||||
|
||||
tests-diffusers:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.diffusers == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -229,6 +267,8 @@ jobs:
|
||||
# make --jobs=5 --output-sync=target -C backend/python/vllm test
|
||||
|
||||
tests-coqui:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.coqui == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -248,6 +288,8 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/coqui
|
||||
make --jobs=5 --output-sync=target -C backend/python/coqui test
|
||||
tests-moonshine:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.moonshine == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -267,6 +309,8 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/moonshine
|
||||
make --jobs=5 --output-sync=target -C backend/python/moonshine test
|
||||
tests-pocket-tts:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.pocket-tts == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -286,6 +330,8 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/pocket-tts
|
||||
make --jobs=5 --output-sync=target -C backend/python/pocket-tts test
|
||||
tests-qwen-tts:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.qwen-tts == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -327,6 +373,8 @@ jobs:
|
||||
# make --jobs=5 --output-sync=target -C backend/python/fish-speech
|
||||
# make --jobs=5 --output-sync=target -C backend/python/fish-speech test
|
||||
tests-qwen-asr:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.qwen-asr == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -346,6 +394,8 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/qwen-asr
|
||||
make --jobs=5 --output-sync=target -C backend/python/qwen-asr test
|
||||
tests-nemo:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.nemo == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -365,6 +415,8 @@ jobs:
|
||||
make --jobs=5 --output-sync=target -C backend/python/nemo
|
||||
make --jobs=5 --output-sync=target -C backend/python/nemo test
|
||||
tests-voxcpm:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.voxcpm == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -383,7 +435,38 @@ jobs:
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/voxcpm
|
||||
make --jobs=5 --output-sync=target -C backend/python/voxcpm test
|
||||
tests-llama-cpp-quantization:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.llama-cpp-quantization == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential cmake curl git python3-pip
|
||||
# Install UV
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
pip install --user --no-cache-dir grpcio-tools==1.64.1
|
||||
- name: Build llama-quantize from llama.cpp
|
||||
run: |
|
||||
git clone --depth 1 https://github.com/ggml-org/llama.cpp.git /tmp/llama.cpp
|
||||
cmake -B /tmp/llama.cpp/build -S /tmp/llama.cpp -DGGML_NATIVE=OFF
|
||||
cmake --build /tmp/llama.cpp/build --target llama-quantize -j$(nproc)
|
||||
sudo cp /tmp/llama.cpp/build/bin/llama-quantize /usr/local/bin/
|
||||
- name: Install backend
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/llama-cpp-quantization
|
||||
- name: Test llama-cpp-quantization
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/llama-cpp-quantization test
|
||||
tests-acestep-cpp:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.acestep-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -414,6 +497,8 @@ jobs:
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/go/acestep-cpp test
|
||||
tests-voxtral:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.voxtral == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -444,3 +529,25 @@ jobs:
|
||||
- name: Test voxtral
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/go/voxtral test
|
||||
tests-kokoros:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.kokoros == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential cmake pkg-config protobuf-compiler clang libclang-dev
|
||||
sudo apt-get install -y espeak-ng libespeak-ng-dev libsonic-dev libpcaudio-dev libopus-dev libssl-dev
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
- name: Build kokoros
|
||||
run: |
|
||||
make -C backend/rust/kokoros kokoros-grpc
|
||||
- name: Test kokoros
|
||||
run: |
|
||||
make -C backend/rust/kokoros test
|
||||
|
||||
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ['1.25.x']
|
||||
go-version: ['1.26.x']
|
||||
steps:
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
@@ -179,7 +179,7 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ['1.25.x']
|
||||
go-version: ['1.26.x']
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
|
||||
72
.github/workflows/tests-ui-e2e.yml
vendored
Normal file
72
.github/workflows/tests-ui-e2e.yml
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
---
|
||||
name: 'UI E2E Tests'
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'core/http/**'
|
||||
- 'tests/e2e-ui/**'
|
||||
- 'tests/e2e/mock-backend/**'
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
|
||||
concurrency:
|
||||
group: ci-tests-ui-e2e-${{ github.head_ref || github.ref }}-${{ github.repository }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
tests-ui-e2e:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ['1.26.x']
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Setup Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
cache: false
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: '22'
|
||||
- name: Proto Dependencies
|
||||
run: |
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
|
||||
- name: System Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential libopus-dev
|
||||
- name: Build UI test server
|
||||
run: PATH="$PATH:$HOME/go/bin" make build-ui-test-server
|
||||
- name: Install Playwright
|
||||
working-directory: core/http/react-ui
|
||||
run: |
|
||||
npm install
|
||||
npx playwright install --with-deps chromium
|
||||
- name: Run Playwright tests
|
||||
working-directory: core/http/react-ui
|
||||
run: npx playwright test
|
||||
- name: Upload Playwright report
|
||||
if: ${{ failure() }}
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: playwright-report
|
||||
path: core/http/react-ui/playwright-report/
|
||||
retention-days: 7
|
||||
- name: Setup tmate session if tests fail
|
||||
if: ${{ failure() }}
|
||||
uses: mxschmitt/action-tmate@v3.23
|
||||
with:
|
||||
detached: true
|
||||
connect-timeout-seconds: 180
|
||||
limit-access-to-actor: true
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -72,3 +72,8 @@ core/http/react-ui/dist
|
||||
|
||||
# Extracted backend binaries for container-based testing
|
||||
local-backends/
|
||||
|
||||
# UI E2E test artifacts
|
||||
tests/e2e-ui/ui-test-server
|
||||
core/http/react-ui/playwright-report/
|
||||
core/http/react-ui/test-results/
|
||||
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,3 +1,6 @@
|
||||
[submodule "docs/themes/hugo-theme-relearn"]
|
||||
path = docs/themes/hugo-theme-relearn
|
||||
url = https://github.com/McShelby/hugo-theme-relearn.git
|
||||
[submodule "backend/rust/kokoros/sources/Kokoros"]
|
||||
path = backend/rust/kokoros/sources/Kokoros
|
||||
url = https://github.com/lucasjinreal/Kokoros
|
||||
|
||||
@@ -11,6 +11,9 @@ This file is an index to detailed topic guides in the `.agents/` directory. Read
|
||||
| [.agents/coding-style.md](.agents/coding-style.md) | Code style, editorconfig, logging, documentation conventions |
|
||||
| [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing |
|
||||
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
||||
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||
| [.agents/adding-gallery-models.md](.agents/adding-gallery-models.md) | Adding GGUF models from HuggingFace to the model gallery |
|
||||
|
||||
## Quick Reference
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ ENV PATH=/opt/rocm/bin:${PATH}
|
||||
# The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it.
|
||||
FROM requirements-drivers AS build-requirements
|
||||
|
||||
ARG GO_VERSION=1.25.4
|
||||
ARG GO_VERSION=1.26.0
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
ARG TARGETARCH
|
||||
@@ -256,7 +256,7 @@ RUN apt-get update && \
|
||||
|
||||
FROM build-requirements AS builder-base
|
||||
|
||||
ARG GO_TAGS=""
|
||||
ARG GO_TAGS="auth"
|
||||
ARG GRPC_BACKENDS
|
||||
ARG MAKEFLAGS
|
||||
ARG LD_FLAGS="-s -w"
|
||||
@@ -319,7 +319,6 @@ COPY ./.git ./.git
|
||||
# Some of the Go backends use libs from the main src, we could further optimize the caching by building the CPP backends before here
|
||||
COPY ./pkg/grpc ./pkg/grpc
|
||||
COPY ./pkg/utils ./pkg/utils
|
||||
COPY ./pkg/langchain ./pkg/langchain
|
||||
|
||||
RUN ls -l ./
|
||||
RUN make protogen-go
|
||||
|
||||
52
Makefile
52
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -107,7 +107,7 @@ core/http/react-ui/dist: react-ui
|
||||
|
||||
## Build:
|
||||
|
||||
build: protogen-go install-go-tools core/http/react-ui/dist ## Build the project
|
||||
build: protogen-go generate install-go-tools core/http/react-ui/dist ## Build the project
|
||||
$(info ${GREEN}I local-ai build info:${RESET})
|
||||
$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
|
||||
$(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET})
|
||||
@@ -148,7 +148,6 @@ test-models/testmodel.ggml:
|
||||
mkdir -p test-dir
|
||||
wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml
|
||||
wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
|
||||
wget -q https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin -O test-models/bert
|
||||
wget -q https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
|
||||
cp tests/models_fixtures/* test-models
|
||||
|
||||
@@ -398,6 +397,16 @@ protogen-go: protoc install-go-tools
|
||||
./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \
|
||||
backend/backend.proto
|
||||
|
||||
core/config/inference_defaults.json: ## Fetch inference defaults from unsloth (only if missing)
|
||||
$(GOCMD) generate ./core/config/...
|
||||
|
||||
.PHONY: generate
|
||||
generate: core/config/inference_defaults.json ## Ensure inference defaults exist
|
||||
|
||||
.PHONY: generate-force
|
||||
generate-force: ## Re-fetch inference defaults from unsloth (always)
|
||||
$(GOCMD) generate ./core/config/...
|
||||
|
||||
.PHONY: protogen-go-clean
|
||||
protogen-go-clean:
|
||||
$(RM) pkg/grpc/proto/backend.pb.go pkg/grpc/proto/backend_grpc.pb.go
|
||||
@@ -419,8 +428,11 @@ prepare-test-extra: protogen-python
|
||||
$(MAKE) -C backend/python/qwen-asr
|
||||
$(MAKE) -C backend/python/nemo
|
||||
$(MAKE) -C backend/python/voxcpm
|
||||
$(MAKE) -C backend/python/faster-whisper
|
||||
$(MAKE) -C backend/python/whisperx
|
||||
$(MAKE) -C backend/python/ace-step
|
||||
$(MAKE) -C backend/python/trl
|
||||
$(MAKE) -C backend/rust/kokoros kokoros-grpc
|
||||
|
||||
test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/transformers test
|
||||
@@ -438,8 +450,11 @@ test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/qwen-asr test
|
||||
$(MAKE) -C backend/python/nemo test
|
||||
$(MAKE) -C backend/python/voxcpm test
|
||||
$(MAKE) -C backend/python/faster-whisper test
|
||||
$(MAKE) -C backend/python/whisperx test
|
||||
$(MAKE) -C backend/python/ace-step test
|
||||
$(MAKE) -C backend/python/trl test
|
||||
$(MAKE) -C backend/rust/kokoros test
|
||||
|
||||
DOCKER_IMAGE?=local-ai
|
||||
IMAGE_TYPE?=core
|
||||
@@ -572,6 +587,14 @@ BACKEND_VOXCPM = voxcpm|python|.|false|true
|
||||
BACKEND_WHISPERX = whisperx|python|.|false|true
|
||||
BACKEND_ACE_STEP = ace-step|python|.|false|true
|
||||
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
|
||||
BACKEND_TRL = trl|python|.|false|true
|
||||
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
|
||||
|
||||
# Rust backends
|
||||
BACKEND_KOKOROS = kokoros|rust|.|false|true
|
||||
|
||||
# C++ backends (Go wrapper with purego)
|
||||
BACKEND_SAM3_CPP = sam3-cpp|golang|.|false|true
|
||||
|
||||
# Helper function to build docker image for a backend
|
||||
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
||||
@@ -629,12 +652,16 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
|
||||
|
||||
# Pattern rule for docker-save targets
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-kokoros docker-build-sam3-cpp
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
@@ -646,6 +673,23 @@ build-mock-backend: protogen-go
|
||||
clean-mock-backend:
|
||||
rm -f tests/e2e/mock-backend/mock-backend
|
||||
|
||||
########################################################
|
||||
### UI E2E Test Server
|
||||
########################################################
|
||||
|
||||
build-ui-test-server: build-mock-backend react-ui protogen-go
|
||||
$(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui
|
||||
|
||||
test-ui-e2e: build-ui-test-server
|
||||
cd core/http/react-ui && npm install && npx playwright install --with-deps chromium && npx playwright test
|
||||
|
||||
test-ui-e2e-docker:
|
||||
docker build -t localai-ui-e2e -f tests/e2e-ui/Dockerfile .
|
||||
docker run --rm localai-ui-e2e
|
||||
|
||||
clean-ui-test-server:
|
||||
rm -f tests/e2e-ui/ui-test-server
|
||||
|
||||
########################################################
|
||||
### END Backends
|
||||
########################################################
|
||||
|
||||
421
README.md
421
README.md
@@ -5,35 +5,17 @@
|
||||
</h1>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/go-skynet/LocalAI/fork" target="blank">
|
||||
<img src="https://img.shields.io/github/forks/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI forks"/>
|
||||
</a>
|
||||
<a href="https://github.com/go-skynet/LocalAI/stargazers" target="blank">
|
||||
<img src="https://img.shields.io/github/stars/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI stars"/>
|
||||
</a>
|
||||
<a href="https://github.com/go-skynet/LocalAI/pulls" target="blank">
|
||||
<img src="https://img.shields.io/github/issues-pr/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI pull-requests"/>
|
||||
</a>
|
||||
<a href='https://github.com/go-skynet/LocalAI/releases'>
|
||||
<img src='https://img.shields.io/github/release/go-skynet/LocalAI?&label=Latest&style=for-the-badge'>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="LICENSE" target="blank">
|
||||
<img src="https://img.shields.io/badge/License-MIT-yellow.svg?style=for-the-badge" alt="LocalAI License"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://hub.docker.com/r/localai/localai" target="blank">
|
||||
<img src="https://img.shields.io/badge/dockerhub-images-important.svg?logo=Docker" alt="LocalAI Docker hub"/>
|
||||
</a>
|
||||
<a href="https://quay.io/repository/go-skynet/local-ai?tab=tags&tag=latest" target="blank">
|
||||
<img src="https://img.shields.io/badge/quay.io-images-important.svg?" alt="LocalAI Quay.io"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://twitter.com/LocalAI_API" target="blank">
|
||||
<img src="https://img.shields.io/badge/X-%23000000.svg?style=for-the-badge&logo=X&logoColor=white&label=LocalAI_API" alt="Follow LocalAI_API"/>
|
||||
@@ -47,363 +29,184 @@
|
||||
<a href="https://trendshift.io/repositories/5539" target="_blank"><img src="https://trendshift.io/api/badge/repositories/5539" alt="mudler%2FLocalAI | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
</p>
|
||||
|
||||
> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
|
||||
>
|
||||
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
|
||||
[](https://t.me/localaiofficial_bot)
|
||||
**LocalAI** is the open-source AI engine. Run any model - LLMs, vision, voice, image, video - on any hardware. No GPU required.
|
||||
|
||||
[](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[](https://artifacthub.io/packages/search?repo=localai)
|
||||
- **Drop-in API compatibility** — OpenAI, Anthropic, ElevenLabs APIs
|
||||
- **35+ backends** — llama.cpp, vLLM, transformers, whisper, diffusers, MLX...
|
||||
- **Any hardware** — NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||
- **Multi-user ready** — API key auth, user quotas, role-based access
|
||||
- **Built-in AI agents** — autonomous agents with tool use, RAG, MCP, and skills
|
||||
- **Privacy-first** — your data never leaves your infrastructure
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/mudler/LocalAI-examples" target="blank">
|
||||
<img src="https://img.shields.io/badge/📦_Examples_Repository-Browse_Ready--to--Run_Examples-blue?style=for-the-badge" alt="LocalAI Examples Repository"/>
|
||||
</a>
|
||||
</p>
|
||||
Created and maintained by [Ettore Di Giacinto](https://github.com/mudler).
|
||||
|
||||
**LocalAI** is the free, Open Source OpenAI alternative. LocalAI act as a drop-in replacement REST API that's compatible with OpenAI (Elevenlabs, Anthropic... ) API specifications for local AI inferencing. It allows you to run LLMs, generate images, audio (and not only) locally or on-prem with consumer grade hardware, supporting multiple model families. Does not require GPU. It is created and maintained by [Ettore Di Giacinto](https://github.com/mudler).
|
||||
> [:book: Documentation](https://localai.io/) | [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) | [💻 Quickstart](https://localai.io/basics/getting_started/) | [🖼️ Models](https://models.localai.io/) | [❓FAQ](https://localai.io/faq/)
|
||||
|
||||
## Guided tour
|
||||
|
||||
https://github.com/user-attachments/assets/08cbb692-57da-48f7-963d-2e7b43883c18
|
||||
|
||||
<details>
|
||||
<summary><strong>Table of Contents</strong></summary>
|
||||
|
||||
- [Local Stack Family](#local-stack-family)
|
||||
- [Screenshots / Video](#screenshots--video)
|
||||
- [Quickstart](#-quickstart)
|
||||
- [macOS Download](#macos-download)
|
||||
- [Containers (Docker, podman, ...)](#containers-docker-podman-)
|
||||
- [Latest project news](#-latest-project-news)
|
||||
- [Features](#-features)
|
||||
- [Supported Backends & Acceleration](#-supported-backends--acceleration)
|
||||
- [Text Generation & Language Models](#text-generation--language-models)
|
||||
- [Audio & Speech Processing](#audio--speech-processing)
|
||||
- [Image & Video Generation](#image--video-generation)
|
||||
- [Specialized AI Tasks](#specialized-ai-tasks)
|
||||
- [Hardware Acceleration Matrix](#hardware-acceleration-matrix)
|
||||
- [Community and integrations](#-community-and-integrations)
|
||||
- [Resources](#-resources)
|
||||
- [Media, Blogs, Social](#book--media-blogs-social)
|
||||
- [Autonomous Development Team](#-autonomous-development-team)
|
||||
- [Citation](#citation)
|
||||
- [Sponsors](#️-sponsors)
|
||||
- [Individual sponsors](#individual-sponsors)
|
||||
- [Star history](#-star-history)
|
||||
- [License](#-license)
|
||||
- [Acknowledgements](#-acknowledgements)
|
||||
- [Contributors](#-contributors)
|
||||
<summary>
|
||||
Click to see more!
|
||||
</summary>
|
||||
|
||||
#### User and auth
|
||||
|
||||
https://github.com/user-attachments/assets/228fa9ad-81a3-4d43-bfb9-31557e14a36c
|
||||
|
||||
#### Agents
|
||||
|
||||
https://github.com/user-attachments/assets/6270b331-e21d-4087-a540-6290006b381a
|
||||
|
||||
#### Usage metrics per user
|
||||
|
||||
https://github.com/user-attachments/assets/cbb03379-23b4-4e3d-bd26-d152f057007f
|
||||
|
||||
#### Fine-tuning and Quantization
|
||||
|
||||
https://github.com/user-attachments/assets/5ba4ace9-d3df-4795-b7d4-b0b404ea71ee
|
||||
|
||||
#### WebRTC
|
||||
|
||||
https://github.com/user-attachments/assets/ed88e34c-fed3-4b83-8a67-4716a9feeb7b
|
||||
|
||||
</details>
|
||||
|
||||
## Local Stack Family
|
||||
## Quickstart
|
||||
|
||||
Liking LocalAI? LocalAI is part of an integrated suite of AI infrastructure tools, you might also like:
|
||||
|
||||
- **[LocalAGI](https://github.com/mudler/LocalAGI)** - AI agent orchestration platform with OpenAI Responses API compatibility and advanced agentic capabilities
|
||||
- **[LocalRecall](https://github.com/mudler/LocalRecall)** - MCP/REST API knowledge base system providing persistent memory and storage for AI agents
|
||||
- 🆕 **[Cogito](https://github.com/mudler/cogito)** - Go library for building intelligent, co-operative agentic software and LLM-powered workflows, focusing on improving results for small, open source language models that scales to any LLM. Powers LocalAGI and LocalAI MCP/Agentic capabilities
|
||||
- 🆕 **[Wiz](https://github.com/mudler/wiz)** - Terminal-based AI agent accessible via Ctrl+Space keybinding. Portable, local-LLM friendly shell assistant with TUI/CLI modes, tool execution with approval, MCP protocol support, and multi-shell compatibility (zsh, bash, fish)
|
||||
- 🆕 **[SkillServer](https://github.com/mudler/skillserver)** - Simple, centralized skills database for AI agents via MCP. Manages skills as Markdown files with MCP server integration, web UI for editing, Git synchronization, and full-text search capabilities
|
||||
|
||||
|
||||
## Screenshots / Video
|
||||
|
||||
### Youtube video
|
||||
|
||||
<h1 align="center">
|
||||
<br>
|
||||
<a href="https://www.youtube.com/watch?v=PDqYhB9nNHA" target="_blank"> <img width="300" src="https://img.youtube.com/vi/PDqYhB9nNHA/0.jpg"> </a><br>
|
||||
<br>
|
||||
</h1>
|
||||
|
||||
|
||||
### Screenshots
|
||||
|
||||
| Talk Interface | Generate Audio |
|
||||
| --- | --- |
|
||||
|  |  |
|
||||
|
||||
| Models Overview | Generate Images |
|
||||
| --- | --- |
|
||||
|  |  |
|
||||
|
||||
| Chat Interface | Home |
|
||||
| --- | --- |
|
||||
|  |  |
|
||||
|
||||
| Login | Swarm |
|
||||
| --- | --- |
|
||||
|  |  |
|
||||
|
||||
## 💻 Quickstart
|
||||
|
||||
|
||||
|
||||
### macOS Download:
|
||||
### macOS
|
||||
|
||||
<a href="https://github.com/mudler/LocalAI/releases/latest/download/LocalAI.dmg">
|
||||
<img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/>
|
||||
</a>
|
||||
|
||||
> Note: the DMGs are not signed by Apple as quarantined. See https://github.com/mudler/LocalAI/issues/6268 for a workaround, fix is tracked here: https://github.com/mudler/LocalAI/issues/6244
|
||||
> **Note:** The DMG is not signed by Apple. After installing, run: `sudo xattr -d com.apple.quarantine /Applications/LocalAI.app`. See [#6268](https://github.com/mudler/LocalAI/issues/6268) for details.
|
||||
|
||||
### Containers (Docker, podman, ...)
|
||||
|
||||
> **💡 Docker Run vs Docker Start**
|
||||
>
|
||||
> - `docker run` creates and starts a new container. If a container with the same name already exists, this command will fail.
|
||||
> - `docker start` starts an existing container that was previously created with `docker run`.
|
||||
>
|
||||
> If you've already run LocalAI before and want to start it again, use: `docker start -i local-ai`
|
||||
> Already ran LocalAI before? Use `docker start -i local-ai` to restart an existing container.
|
||||
|
||||
#### CPU only image:
|
||||
#### CPU only:
|
||||
|
||||
```bash
|
||||
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest
|
||||
```
|
||||
|
||||
#### NVIDIA GPU Images:
|
||||
#### NVIDIA GPU:
|
||||
|
||||
```bash
|
||||
# CUDA 13.0
|
||||
# CUDA 13
|
||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-13
|
||||
|
||||
# CUDA 12.0
|
||||
# CUDA 12
|
||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-12
|
||||
|
||||
# NVIDIA Jetson (L4T) ARM64
|
||||
# CUDA 12 (for Nvidia AGX Orin and similar platforms)
|
||||
# NVIDIA Jetson ARM64 (CUDA 12, for AGX Orin and similar)
|
||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64
|
||||
|
||||
# CUDA 13 (for Nvidia DGX Spark)
|
||||
# NVIDIA Jetson ARM64 (CUDA 13, for DGX Spark)
|
||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64-cuda-13
|
||||
```
|
||||
|
||||
#### AMD GPU Images (ROCm):
|
||||
#### AMD GPU (ROCm):
|
||||
|
||||
```bash
|
||||
docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri --group-add=video localai/localai:latest-gpu-hipblas
|
||||
```
|
||||
|
||||
#### Intel GPU Images (oneAPI):
|
||||
#### Intel GPU (oneAPI):
|
||||
|
||||
```bash
|
||||
docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel
|
||||
```
|
||||
|
||||
#### Vulkan GPU Images:
|
||||
#### Vulkan GPU:
|
||||
|
||||
```bash
|
||||
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-gpu-vulkan
|
||||
```
|
||||
|
||||
To load models:
|
||||
### Loading models
|
||||
|
||||
```bash
|
||||
# From the model gallery (see available models with `local-ai models list`, in the WebUI from the model tab, or visiting https://models.localai.io)
|
||||
# From the model gallery (see available models with `local-ai models list` or at https://models.localai.io)
|
||||
local-ai run llama-3.2-1b-instruct:q4_k_m
|
||||
# Start LocalAI with the phi-2 model directly from huggingface
|
||||
# From Huggingface
|
||||
local-ai run huggingface://TheBloke/phi-2-GGUF/phi-2.Q8_0.gguf
|
||||
# Install and run a model from the Ollama OCI registry
|
||||
# From the Ollama OCI registry
|
||||
local-ai run ollama://gemma:2b
|
||||
# Run a model from a configuration file
|
||||
# From a YAML config
|
||||
local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
|
||||
# Install and run a model from a standard OCI registry (e.g., Docker Hub)
|
||||
# From a standard OCI registry (e.g., Docker Hub)
|
||||
local-ai run oci://localai/phi-2:latest
|
||||
```
|
||||
|
||||
> ⚡ **Automatic Backend Detection**: When you install models from the gallery or YAML files, LocalAI automatically detects your system's GPU capabilities (NVIDIA, AMD, Intel) and downloads the appropriate backend. For advanced configuration options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/#automatic-backend-detection).
|
||||
> **Automatic Backend Detection**: LocalAI automatically detects your GPU capabilities and downloads the appropriate backend. For advanced options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/).
|
||||
|
||||
For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html), if you are interested in our roadmap items and future enhancements, you can see the [Issues labeled as Roadmap here](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
|
||||
For more details, see the [Getting Started guide](https://localai.io/basics/getting_started/).
|
||||
|
||||
## 📰 Latest project news
|
||||
- March 2026: [Agent management](https://github.com/mudler/LocalAI/pull/8820), [New React UI](https://github.com/mudler/LocalAI/pull/8772), [WebRTC](https://github.com/mudler/LocalAI/pull/8790),[MLX-distributed via P2P and RDMA](https://github.com/mudler/LocalAI/pull/8801), [MCP Apps, MCP Client-side](https://github.com/mudler/LocalAI/pull/8947)
|
||||
- February 2026: [Realtime API for audio-to-audio with tool calling](https://github.com/mudler/LocalAI/pull/6245), [ACE-Step 1.5 support](https://github.com/mudler/LocalAI/pull/8396)
|
||||
- January 2026: **LocalAI 3.10.0** - Major release with Anthropic API support, Open Responses API for stateful agents, video & image generation suite (LTX-2), unified GPU backends, tool streaming & XML parsing, system-aware backend gallery, crash fixes for AVX-only CPUs and AMD VRAM reporting, request tracing, and new backends: **Moonshine** (ultra-fast transcription), **Pocket-TTS** (lightweight TTS). Vulkan arm64 builds now available. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v3.10.0).
|
||||
- December 2025: [Dynamic Memory Resource reclaimer](https://github.com/mudler/LocalAI/pull/7583), [Automatic fitting of models to multiple GPUS(llama.cpp)](https://github.com/mudler/LocalAI/pull/7584), [Added Vibevoice backend](https://github.com/mudler/LocalAI/pull/7494)
|
||||
- November 2025: Major improvements to the UX. Among these: [Import models via URL](https://github.com/mudler/LocalAI/pull/7245) and [Multiple chats and history](https://github.com/mudler/LocalAI/pull/7325)
|
||||
- October 2025: 🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) support added for agentic capabilities with external tools
|
||||
- September 2025: New Launcher application for MacOS and Linux, extended support to many backends for Mac and Nvidia L4T devices. Models: Added MLX-Audio, WAN 2.2. WebUI improvements and Python-based backends now ships portable python environments.
|
||||
- August 2025: MLX, MLX-VLM, Diffusers and llama.cpp are now supported on Mac M1/M2/M3+ chips ( with `development` suffix in the gallery ): https://github.com/mudler/LocalAI/pull/6049 https://github.com/mudler/LocalAI/pull/6119 https://github.com/mudler/LocalAI/pull/6121 https://github.com/mudler/LocalAI/pull/6060
|
||||
- July/August 2025: 🔍 [Object Detection](https://localai.io/features/object-detection/) added to the API featuring [rf-detr](https://github.com/roboflow/rf-detr)
|
||||
- July 2025: All backends migrated outside of the main binary. LocalAI is now more lightweight, small, and automatically downloads the required backend to run the model. [Read the release notes](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
|
||||
- June 2025: [Backend management](https://github.com/mudler/LocalAI/pull/5607) has been added. Attention: extras images are going to be deprecated from the next release! Read [the backend management PR](https://github.com/mudler/LocalAI/pull/5607).
|
||||
- May 2025: [Audio input](https://github.com/mudler/LocalAI/pull/5466) and [Reranking](https://github.com/mudler/LocalAI/pull/5396) in llama.cpp backend, [Realtime API](https://github.com/mudler/LocalAI/pull/5392), Support to Gemma, SmollVLM, and more multimodal models (available in the gallery).
|
||||
- May 2025: Important: image name changes [See release](https://github.com/mudler/LocalAI/releases/tag/v2.29.0)
|
||||
- Apr 2025: Rebrand, WebUI enhancements
|
||||
- Apr 2025: [LocalAGI](https://github.com/mudler/LocalAGI) and [LocalRecall](https://github.com/mudler/LocalRecall) join the LocalAI family stack.
|
||||
- Apr 2025: WebUI overhaul
|
||||
- Feb 2025: Backend cleanup, Breaking changes, new backends (kokoro, OutelTTS, faster-whisper), Nvidia L4T images
|
||||
- Jan 2025: LocalAI model release: https://huggingface.co/mudler/LocalAI-functioncall-phi-4-v0.3, SANA support in diffusers: https://github.com/mudler/LocalAI/pull/4603
|
||||
- Dec 2024: stablediffusion.cpp backend (ggml) added ( https://github.com/mudler/LocalAI/pull/4289 )
|
||||
- Nov 2024: Bark.cpp backend added ( https://github.com/mudler/LocalAI/pull/4287 )
|
||||
- Nov 2024: Voice activity detection models (**VAD**) added to the API: https://github.com/mudler/LocalAI/pull/4204
|
||||
- Oct 2024: examples moved to [LocalAI-examples](https://github.com/mudler/LocalAI-examples)
|
||||
- Aug 2024: 🆕 FLUX-1, [P2P Explorer](https://explorer.localai.io)
|
||||
- July 2024: 🔥🔥 🆕 P2P Dashboard, LocalAI Federated mode and AI Swarms: https://github.com/mudler/LocalAI/pull/2723. P2P Global community pools: https://github.com/mudler/LocalAI/issues/3113
|
||||
- May 2024: 🔥🔥 Decentralized P2P llama.cpp: https://github.com/mudler/LocalAI/pull/2343 (peer2peer llama.cpp!) 👉 Docs https://localai.io/features/distribute/
|
||||
- May 2024: 🔥🔥 Distributed inferencing: https://github.com/mudler/LocalAI/pull/2324
|
||||
- April 2024: Reranker API: https://github.com/mudler/LocalAI/pull/2121
|
||||
## Latest News
|
||||
|
||||
Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
|
||||
- **March 2026**: [Agent management](https://github.com/mudler/LocalAI/pull/8820), [New React UI](https://github.com/mudler/LocalAI/pull/8772), [WebRTC](https://github.com/mudler/LocalAI/pull/8790), [MLX-distributed via P2P and RDMA](https://github.com/mudler/LocalAI/pull/8801), [MCP Apps, MCP Client-side](https://github.com/mudler/LocalAI/pull/8947)
|
||||
- **February 2026**: [Realtime API for audio-to-audio with tool calling](https://github.com/mudler/LocalAI/pull/6245), [ACE-Step 1.5 support](https://github.com/mudler/LocalAI/pull/8396)
|
||||
- **January 2026**: **LocalAI 3.10.0** — Anthropic API support, Open Responses API, video & image generation (LTX-2), unified GPU backends, tool streaming, Moonshine, Pocket-TTS. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v3.10.0)
|
||||
- **December 2025**: [Dynamic Memory Resource reclaimer](https://github.com/mudler/LocalAI/pull/7583), [Automatic multi-GPU model fitting (llama.cpp)](https://github.com/mudler/LocalAI/pull/7584), [Vibevoice backend](https://github.com/mudler/LocalAI/pull/7494)
|
||||
- **November 2025**: [Import models via URL](https://github.com/mudler/LocalAI/pull/7245), [Multiple chats and history](https://github.com/mudler/LocalAI/pull/7325)
|
||||
- **October 2025**: [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) support for agentic capabilities
|
||||
- **September 2025**: New Launcher for macOS and Linux, extended backend support for Mac and Nvidia L4T, MLX-Audio, WAN 2.2
|
||||
- **August 2025**: MLX, MLX-VLM, Diffusers, llama.cpp now supported on Apple Silicon
|
||||
- **July 2025**: All backends migrated outside the main binary — [lightweight, modular architecture](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
|
||||
|
||||
## 🚀 [Features](https://localai.io/features/)
|
||||
For older news and full release notes, see [GitHub Releases](https://github.com/mudler/LocalAI/releases) and the [News page](https://localai.io/basics/news/).
|
||||
|
||||
- 🧩 [Backend Gallery](https://localai.io/backends/): Install/remove backends on the fly, powered by OCI images — fully customizable and API-driven.
|
||||
- 📖 [Text generation with GPTs](https://localai.io/features/text-generation/) (`llama.cpp`, `transformers`, `vllm` ... [:book: and more](https://localai.io/model-compatibility/index.html#model-compatibility-table))
|
||||
- 🗣 [Text to Audio](https://localai.io/features/text-to-audio/)
|
||||
- 🔈 [Audio to Text](https://localai.io/features/audio-to-text/)
|
||||
- 🎨 [Image generation](https://localai.io/features/image-generation)
|
||||
- 🔥 [OpenAI-alike tools API](https://localai.io/features/openai-functions/)
|
||||
- ⚡ [Realtime API](https://localai.io/features/openai-realtime/) (Speech-to-speech)
|
||||
- 🧠 [Embeddings generation for vector databases](https://localai.io/features/embeddings/)
|
||||
- ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/)
|
||||
- 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/)
|
||||
- 🥽 [Vision API](https://localai.io/features/gpt-vision/)
|
||||
- 🔍 [Object Detection](https://localai.io/features/object-detection/)
|
||||
- 📈 [Reranker API](https://localai.io/features/reranker/)
|
||||
- 🆕🖧 [P2P Inferencing](https://localai.io/features/distribute/)
|
||||
- 🆕🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) - Agentic capabilities with external tools and [LocalAGI's Agentic capabilities](https://github.com/mudler/LocalAGI)
|
||||
- 🆕🤖 [Built-in Agents](https://localai.io/features/agents/) - Autonomous AI agents with tool use, knowledge base (RAG), skills, SSE streaming, import/export, and [Agent Hub](https://agenthub.localai.io) — powered by [LocalAGI](https://github.com/mudler/LocalAGI)
|
||||
- 🔊 Voice activity detection (Silero-VAD support)
|
||||
- 🌍 Integrated WebUI!
|
||||
## Features
|
||||
|
||||
## 🧩 Supported Backends & Acceleration
|
||||
- [Text generation](https://localai.io/features/text-generation/) (`llama.cpp`, `transformers`, `vllm` ... [and more](https://localai.io/model-compatibility/))
|
||||
- [Text to Audio](https://localai.io/features/text-to-audio/)
|
||||
- [Audio to Text](https://localai.io/features/audio-to-text/)
|
||||
- [Image generation](https://localai.io/features/image-generation)
|
||||
- [OpenAI-compatible tools API](https://localai.io/features/openai-functions/)
|
||||
- [Realtime API](https://localai.io/features/openai-realtime/) (Speech-to-speech)
|
||||
- [Embeddings generation](https://localai.io/features/embeddings/)
|
||||
- [Constrained grammars](https://localai.io/features/constrained_grammars/)
|
||||
- [Download models from Huggingface](https://localai.io/models/)
|
||||
- [Vision API](https://localai.io/features/gpt-vision/)
|
||||
- [Object Detection](https://localai.io/features/object-detection/)
|
||||
- [Reranker API](https://localai.io/features/reranker/)
|
||||
- [P2P Inferencing](https://localai.io/features/distribute/)
|
||||
- [Distributed Mode](https://localai.io/features/distributed-mode/) — Horizontal scaling with PostgreSQL + NATS
|
||||
- [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/)
|
||||
- [Built-in Agents](https://localai.io/features/agents/) — Autonomous AI agents with tool use, RAG, skills, SSE streaming, and [Agent Hub](https://agenthub.localai.io)
|
||||
- [Backend Gallery](https://localai.io/backends/) — Install/remove backends on the fly via OCI images
|
||||
- Voice Activity Detection (Silero-VAD)
|
||||
- Integrated WebUI
|
||||
|
||||
LocalAI supports a comprehensive range of AI backends with multiple acceleration options:
|
||||
## Supported Backends & Acceleration
|
||||
|
||||
### Text Generation & Language Models
|
||||
| Backend | Description | Acceleration Support |
|
||||
|---------|-------------|---------------------|
|
||||
| **llama.cpp** | LLM inference in C/C++ | CUDA 12/13, ROCm, Intel SYCL, Vulkan, Metal, CPU |
|
||||
| **vLLM** | Fast LLM inference with PagedAttention | CUDA 12/13, ROCm, Intel |
|
||||
| **transformers** | HuggingFace transformers framework | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **MLX** | Apple Silicon LLM inference | Metal (M1/M2/M3+) |
|
||||
| **MLX-VLM** | Apple Silicon Vision-Language Models | Metal (M1/M2/M3+) |
|
||||
| **vLLM Omni** | Multimodal vLLM with vision and audio | CUDA 12/13, ROCm, Intel |
|
||||
LocalAI supports **35+ backends** including llama.cpp, vLLM, transformers, whisper.cpp, diffusers, MLX, MLX-VLM, and many more. Hardware acceleration is available for **NVIDIA** (CUDA 12/13), **AMD** (ROCm), **Intel** (oneAPI/SYCL), **Apple Silicon** (Metal), **Vulkan**, and **NVIDIA Jetson** (L4T). All backends can be installed on-the-fly from the [Backend Gallery](https://localai.io/backends/).
|
||||
|
||||
### Audio & Speech Processing
|
||||
| Backend | Description | Acceleration Support |
|
||||
|---------|-------------|---------------------|
|
||||
| **whisper.cpp** | OpenAI Whisper in C/C++ | CUDA 12/13, ROCm, Intel SYCL, Vulkan, CPU |
|
||||
| **faster-whisper** | Fast Whisper with CTranslate2 | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **moonshine** | Ultra-fast transcription engine for low-end devices | CUDA 12/13, Metal, CPU |
|
||||
| **coqui** | Advanced TTS with 1100+ languages | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **kokoro** | Lightweight TTS model | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **chatterbox** | Production-grade TTS | CUDA 12/13, CPU |
|
||||
| **piper** | Fast neural TTS system | CPU |
|
||||
| **kitten-tts** | Kitten TTS models | CPU |
|
||||
| **silero-vad** | Voice Activity Detection | CPU |
|
||||
| **neutts** | Text-to-speech with voice cloning | CUDA 12/13, ROCm, CPU |
|
||||
| **vibevoice** | Real-time TTS with voice cloning | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **pocket-tts** | Lightweight CPU-based TTS | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **qwen-tts** | High-quality TTS with custom voice, voice design, and voice cloning | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **nemo** | NVIDIA NeMo framework for speech models | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **outetts** | OuteTTS with voice cloning | CUDA 12/13, CPU |
|
||||
| **faster-qwen3-tts** | Faster Qwen3 TTS | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **qwen-asr** | Qwen ASR speech recognition | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **voxcpm** | VoxCPM speech understanding | CUDA 12/13, Metal, CPU |
|
||||
| **whisperx** | Enhanced Whisper transcription | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **ace-step** | Music generation from text descriptions, lyrics, or audio samples | CUDA 12/13, ROCm, Intel, Metal, CPU |
|
||||
See the full [Backend & Model Compatibility Table](https://localai.io/model-compatibility/) and [GPU Acceleration guide](https://localai.io/features/gpu-acceleration/).
|
||||
|
||||
### Image & Video Generation
|
||||
| Backend | Description | Acceleration Support |
|
||||
|---------|-------------|---------------------|
|
||||
| **stablediffusion.cpp** | Stable Diffusion in C/C++ | CUDA 12/13, Intel SYCL, Vulkan, CPU |
|
||||
| **diffusers** | HuggingFace diffusion models | CUDA 12/13, ROCm, Intel, Metal, CPU |
|
||||
## Resources
|
||||
|
||||
### Specialized AI Tasks
|
||||
| Backend | Description | Acceleration Support |
|
||||
|---------|-------------|---------------------|
|
||||
| **rfdetr** | Real-time object detection | CUDA 12/13, Intel, CPU |
|
||||
| **rerankers** | Document reranking API | CUDA 12/13, ROCm, Intel, CPU |
|
||||
| **local-store** | Vector database | CPU |
|
||||
| **huggingface** | HuggingFace API integration | API-based |
|
||||
- [Documentation](https://localai.io/)
|
||||
- [LLM fine-tuning guide](https://localai.io/docs/advanced/fine-tuning/)
|
||||
- [Build from source](https://localai.io/basics/build/)
|
||||
- [Kubernetes installation](https://localai.io/basics/getting_started/#run-localai-in-kubernetes)
|
||||
- [Integrations & community projects](https://localai.io/docs/integrations/)
|
||||
- [Installation video walkthrough](https://www.youtube.com/watch?v=cMVNnlqwfw4)
|
||||
- [Media & blog posts](https://localai.io/basics/news/#media-blogs-social)
|
||||
- [Examples](https://github.com/mudler/LocalAI-examples)
|
||||
|
||||
### Hardware Acceleration Matrix
|
||||
## Autonomous Development Team
|
||||
|
||||
| Acceleration Type | Supported Backends | Hardware Support |
|
||||
|-------------------|-------------------|------------------|
|
||||
| **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware |
|
||||
| **NVIDIA CUDA 13** | All CUDA-compatible backends | Nvidia hardware |
|
||||
| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, neutts, vibevoice, pocket-tts, qwen-tts, ace-step | AMD Graphics |
|
||||
| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, coqui, kokoro, vibevoice, pocket-tts, qwen-tts, ace-step | Intel Arc, Intel iGPUs |
|
||||
| **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, moonshine, ace-step | Apple M1/M2/M3+ |
|
||||
| **Vulkan** | llama.cpp, whisper, stablediffusion | Cross-platform GPUs |
|
||||
| **NVIDIA Jetson (CUDA 12)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr, ace-step | ARM64 embedded AI (AGX Orin, etc.) |
|
||||
| **NVIDIA Jetson (CUDA 13)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr | ARM64 embedded AI (DGX Spark) |
|
||||
| **CPU Optimized** | All backends | AVX/AVX2/AVX512, quantization support |
|
||||
LocalAI is helped being maintained by a team of autonomous AI agents led by an AI Scrum Master.
|
||||
|
||||
### 🔗 Community and integrations
|
||||
|
||||
Build and deploy custom containers:
|
||||
- https://github.com/sozercan/aikit
|
||||
|
||||
WebUIs:
|
||||
- https://github.com/Jirubizu/localai-admin
|
||||
- https://github.com/go-skynet/LocalAI-frontend
|
||||
- QA-Pilot(An interactive chat project that leverages LocalAI LLMs for rapid understanding and navigation of GitHub code repository) https://github.com/reid41/QA-Pilot
|
||||
|
||||
Agentic Libraries:
|
||||
- https://github.com/mudler/cogito
|
||||
|
||||
MCPs:
|
||||
- https://github.com/mudler/MCPs
|
||||
|
||||
OS Assistant:
|
||||
|
||||
- https://github.com/mudler/Keygeist - Keygeist is an AI-powered keyboard operator that listens for key combinations and responds with AI-generated text typed directly into your Linux box.
|
||||
|
||||
Model galleries
|
||||
- https://github.com/go-skynet/model-gallery
|
||||
|
||||
Voice:
|
||||
- https://github.com/richiejp/VoxInput
|
||||
|
||||
Other:
|
||||
- Helm chart https://github.com/go-skynet/helm-charts
|
||||
- VSCode extension https://github.com/badgooooor/localai-vscode-plugin
|
||||
- Langchain: https://python.langchain.com/docs/integrations/providers/localai/
|
||||
- Terminal utility https://github.com/djcopley/ShellOracle
|
||||
- Local Smart assistant https://github.com/mudler/LocalAGI
|
||||
- Home Assistant https://github.com/drndos/hass-openai-custom-conversation / https://github.com/valentinfrlch/ha-llmvision / https://github.com/loryanstrant/HA-LocalAI-Monitor
|
||||
- Discord bot https://github.com/mudler/LocalAGI/tree/main/examples/discord
|
||||
- Slack bot https://github.com/mudler/LocalAGI/tree/main/examples/slack
|
||||
- Shell-Pilot(Interact with LLM using LocalAI models via pure shell scripts on your Linux or MacOS system) https://github.com/reid41/shell-pilot
|
||||
- Telegram bot https://github.com/mudler/LocalAI/tree/master/examples/telegram-bot
|
||||
- Another Telegram Bot https://github.com/JackBekket/Hellper
|
||||
- Auto-documentation https://github.com/JackBekket/Reflexia
|
||||
- Github bot which answer on issues, with code and documentation as context https://github.com/JackBekket/GitHelper
|
||||
- Github Actions: https://github.com/marketplace/actions/start-localai
|
||||
- Examples: https://github.com/mudler/LocalAI/tree/master/examples/
|
||||
|
||||
|
||||
### 🔗 Resources
|
||||
|
||||
- [LLM finetuning guide](https://localai.io/docs/advanced/fine-tuning/)
|
||||
- [How to build locally](https://localai.io/basics/build/index.html)
|
||||
- [How to install in Kubernetes](https://localai.io/basics/getting_started/index.html#run-localai-in-kubernetes)
|
||||
- [Projects integrating LocalAI](https://localai.io/docs/integrations/)
|
||||
- [How tos section](https://io.midori-ai.xyz/howtos/) (curated by our community)
|
||||
|
||||
## :book: 🎥 [Media, Blogs, Social](https://localai.io/basics/news/#media-blogs-social)
|
||||
|
||||
- 🆕 [LocalAI Autonomous Dev Team Blog Post](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/)
|
||||
|
||||
- [Run Visual studio code with LocalAI (SUSE)](https://www.suse.com/c/running-ai-locally/)
|
||||
- 🆕 [Run LocalAI on Jetson Nano Devkit](https://mudler.pm/posts/local-ai-jetson-nano-devkit/)
|
||||
- [Run LocalAI on AWS EKS with Pulumi](https://www.pulumi.com/blog/low-code-llm-apps-with-local-ai-flowise-and-pulumi/)
|
||||
- [Run LocalAI on AWS](https://staleks.hashnode.dev/installing-localai-on-aws-ec2-instance)
|
||||
- [Create a slackbot for teams and OSS projects that answer to documentation](https://mudler.pm/posts/smart-slackbot-for-teams/)
|
||||
- [LocalAI meets k8sgpt](https://www.youtube.com/watch?v=PKrDNuJ_dfE)
|
||||
- [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/)
|
||||
- [Tutorial to use k8sgpt with LocalAI](https://medium.com/@tyler_97636/k8sgpt-localai-unlock-kubernetes-superpowers-for-free-584790de9b65)
|
||||
|
||||
|
||||
## 🤖 Autonomous Development Team
|
||||
|
||||
LocalAI is now helped being maintained (for small tasks!) by a full team of autonomous AI agents led by an AI Scrum Master! This experiment demonstrates how open source projects can leverage AI agents for sustainable, long-term maintenance.
|
||||
|
||||
- **📊 Live Reports**: [Automatically generated reports](http://reports.localai.io)
|
||||
- **📋 Project Board**: [Agent task tracking](https://github.com/users/mudler/projects/6)
|
||||
- **📝 Blog Post**: [Learn about the autonomous dev team experiment](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/)
|
||||
- **Live Reports**: [reports.localai.io](http://reports.localai.io)
|
||||
- **Project Board**: [Agent task tracking](https://github.com/users/mudler/projects/6)
|
||||
- **Blog Post**: [Learn about the experiment](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/)
|
||||
|
||||
## Citation
|
||||
|
||||
@@ -419,7 +222,7 @@ If you utilize this repository, data in a downstream project, please consider ci
|
||||
howpublished = {\url{https://github.com/go-skynet/LocalAI}},
|
||||
```
|
||||
|
||||
## ❤️ Sponsors
|
||||
## Sponsors
|
||||
|
||||
> Do you find LocalAI useful?
|
||||
|
||||
@@ -438,19 +241,19 @@ A huge thank you to our generous sponsors who support this project covering CI e
|
||||
|
||||
### Individual sponsors
|
||||
|
||||
A special thanks to individual sponsors that contributed to the project, a full list is in [Github](https://github.com/sponsors/mudler) and [buymeacoffee](https://buymeacoffee.com/mudler), a special shout out goes to [drikster80](https://github.com/drikster80) for being generous. Thank you everyone!
|
||||
A special thanks to individual sponsors, a full list is on [GitHub](https://github.com/sponsors/mudler) and [buymeacoffee](https://buymeacoffee.com/mudler). Special shout out to [drikster80](https://github.com/drikster80) for being generous. Thank you everyone!
|
||||
|
||||
## 🌟 Star history
|
||||
## Star history
|
||||
|
||||
[](https://star-history.com/#go-skynet/LocalAI&Date)
|
||||
|
||||
## 📖 License
|
||||
## License
|
||||
|
||||
LocalAI is a community-driven project created by [Ettore Di Giacinto](https://github.com/mudler/).
|
||||
|
||||
MIT - Author Ettore Di Giacinto <mudler@localai.io>
|
||||
|
||||
## 🙇 Acknowledgements
|
||||
## Acknowledgements
|
||||
|
||||
LocalAI couldn't have been built without the help of great software already available from the community. Thank you!
|
||||
|
||||
@@ -463,9 +266,9 @@ LocalAI couldn't have been built without the help of great software already avai
|
||||
- https://github.com/rhasspy/piper
|
||||
- [exo](https://github.com/exo-explore/exo) for the MLX distributed auto-parallel sharding implementation
|
||||
|
||||
## 🤗 Contributors
|
||||
## Contributors
|
||||
|
||||
This is a community project, a special thanks to our contributors! 🤗
|
||||
This is a community project, a special thanks to our contributors!
|
||||
<a href="https://github.com/go-skynet/LocalAI/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=go-skynet/LocalAI" />
|
||||
</a>
|
||||
|
||||
39
backend/Dockerfile.rust
Normal file
39
backend/Dockerfile.rust
Normal file
@@ -0,0 +1,39 @@
|
||||
ARG BASE_IMAGE=ubuntu:24.04
|
||||
|
||||
FROM ${BASE_IMAGE} AS builder
|
||||
ARG BACKEND=kokoros
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
git ccache \
|
||||
ca-certificates \
|
||||
make cmake wget \
|
||||
curl unzip \
|
||||
clang \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
espeak-ng libespeak-ng-dev \
|
||||
libsonic-dev libpcaudio-dev \
|
||||
libopus-dev \
|
||||
protobuf-compiler && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Rust
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||
|
||||
COPY . /LocalAI
|
||||
|
||||
RUN git config --global --add safe.directory /LocalAI
|
||||
|
||||
RUN make -C /LocalAI/backend/rust/${BACKEND} build
|
||||
|
||||
FROM scratch
|
||||
ARG BACKEND=kokoros
|
||||
|
||||
COPY --from=builder /LocalAI/backend/rust/${BACKEND}/package/. ./
|
||||
@@ -39,6 +39,19 @@ service Backend {
|
||||
rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {}
|
||||
|
||||
rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
|
||||
|
||||
// Fine-tuning RPCs
|
||||
rpc StartFineTune(FineTuneRequest) returns (FineTuneJobResult) {}
|
||||
rpc FineTuneProgress(FineTuneProgressRequest) returns (stream FineTuneProgressUpdate) {}
|
||||
rpc StopFineTune(FineTuneStopRequest) returns (Result) {}
|
||||
rpc ListCheckpoints(ListCheckpointsRequest) returns (ListCheckpointsResponse) {}
|
||||
rpc ExportModel(ExportModelRequest) returns (Result) {}
|
||||
|
||||
// Quantization RPCs
|
||||
rpc StartQuantization(QuantizationRequest) returns (QuantizationJobResult) {}
|
||||
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
|
||||
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
|
||||
|
||||
}
|
||||
|
||||
// Define the empty request
|
||||
@@ -166,6 +179,7 @@ message PredictOptions {
|
||||
int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
|
||||
int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
|
||||
map<string, string> Metadata = 52; // Generic per-request metadata (e.g., enable_thinking)
|
||||
float MinP = 53; // Minimum probability sampling threshold (0.0 = disabled)
|
||||
}
|
||||
|
||||
// ToolCallDelta represents an incremental tool call update from the C++ parser.
|
||||
@@ -430,6 +444,10 @@ message Message {
|
||||
|
||||
message DetectOptions {
|
||||
string src = 1;
|
||||
string prompt = 2; // Text prompt (for SAM 3 PCS mode)
|
||||
repeated float points = 3; // Point coordinates as [x1, y1, label1, x2, y2, label2, ...] (label: 1=pos, 0=neg)
|
||||
repeated float boxes = 4; // Box coordinates as [x1, y1, x2, y2, ...]
|
||||
float threshold = 5; // Detection confidence threshold
|
||||
}
|
||||
|
||||
message Detection {
|
||||
@@ -439,6 +457,7 @@ message Detection {
|
||||
float height = 4;
|
||||
float confidence = 5;
|
||||
string class_name = 6;
|
||||
bytes mask = 7; // PNG-encoded binary segmentation mask
|
||||
}
|
||||
|
||||
message DetectResponse {
|
||||
@@ -472,7 +491,7 @@ message ToolFormatMarkers {
|
||||
string id_field = 16; // e.g., "id"
|
||||
bool fun_name_is_key = 17;
|
||||
bool tools_array_wrapped = 18;
|
||||
bool uses_python_dicts = 19;
|
||||
reserved 19;
|
||||
|
||||
// Reasoning markers
|
||||
string reasoning_start = 20; // e.g., "<think>"
|
||||
@@ -528,3 +547,139 @@ message ModelMetadataResponse {
|
||||
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
|
||||
ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis
|
||||
}
|
||||
|
||||
// Fine-tuning messages
|
||||
|
||||
message FineTuneRequest {
|
||||
// Model identification
|
||||
string model = 1; // HF model name or local path
|
||||
string training_type = 2; // "lora", "loha", "lokr", "full" — what parameters to train
|
||||
string training_method = 3; // "sft", "dpo", "grpo", "rloo", "reward", "kto", "orpo", "network_training"
|
||||
|
||||
// Adapter config (universal across LoRA/LoHa/LoKr for LLM + diffusion)
|
||||
int32 adapter_rank = 10; // LoRA rank (r), default 16
|
||||
int32 adapter_alpha = 11; // scaling factor, default 16
|
||||
float adapter_dropout = 12; // default 0.0
|
||||
repeated string target_modules = 13; // layer names to adapt
|
||||
|
||||
// Universal training hyperparameters
|
||||
float learning_rate = 20; // default 2e-4
|
||||
int32 num_epochs = 21; // default 3
|
||||
int32 batch_size = 22; // default 2
|
||||
int32 gradient_accumulation_steps = 23; // default 4
|
||||
int32 warmup_steps = 24; // default 5
|
||||
int32 max_steps = 25; // 0 = use epochs
|
||||
int32 save_steps = 26; // 0 = only save final
|
||||
float weight_decay = 27; // default 0.01
|
||||
bool gradient_checkpointing = 28;
|
||||
string optimizer = 29; // adamw_8bit, adamw, sgd, adafactor, prodigy
|
||||
int32 seed = 30; // default 3407
|
||||
string mixed_precision = 31; // fp16, bf16, fp8, no
|
||||
|
||||
// Dataset
|
||||
string dataset_source = 40; // HF dataset ID, local file/dir path
|
||||
string dataset_split = 41; // train, test, etc.
|
||||
|
||||
// Output
|
||||
string output_dir = 50;
|
||||
string job_id = 51; // client-assigned or auto-generated
|
||||
|
||||
// Resume training from a checkpoint
|
||||
string resume_from_checkpoint = 55; // path to checkpoint dir to resume from
|
||||
|
||||
// Backend-specific AND method-specific extensibility
|
||||
map<string, string> extra_options = 60;
|
||||
}
|
||||
|
||||
message FineTuneJobResult {
|
||||
string job_id = 1;
|
||||
bool success = 2;
|
||||
string message = 3;
|
||||
}
|
||||
|
||||
message FineTuneProgressRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
message FineTuneProgressUpdate {
|
||||
string job_id = 1;
|
||||
int32 current_step = 2;
|
||||
int32 total_steps = 3;
|
||||
float current_epoch = 4;
|
||||
float total_epochs = 5;
|
||||
float loss = 6;
|
||||
float learning_rate = 7;
|
||||
float grad_norm = 8;
|
||||
float eval_loss = 9;
|
||||
float eta_seconds = 10;
|
||||
float progress_percent = 11;
|
||||
string status = 12; // queued, caching, loading_model, loading_dataset, training, saving, completed, failed, stopped
|
||||
string message = 13;
|
||||
string checkpoint_path = 14; // set when a checkpoint is saved
|
||||
string sample_path = 15; // set when a sample is generated (video/image backends)
|
||||
map<string, float> extra_metrics = 16; // method-specific metrics
|
||||
}
|
||||
|
||||
message FineTuneStopRequest {
|
||||
string job_id = 1;
|
||||
bool save_checkpoint = 2;
|
||||
}
|
||||
|
||||
message ListCheckpointsRequest {
|
||||
string output_dir = 1;
|
||||
}
|
||||
|
||||
message ListCheckpointsResponse {
|
||||
repeated CheckpointInfo checkpoints = 1;
|
||||
}
|
||||
|
||||
message CheckpointInfo {
|
||||
string path = 1;
|
||||
int32 step = 2;
|
||||
float epoch = 3;
|
||||
float loss = 4;
|
||||
string created_at = 5;
|
||||
}
|
||||
|
||||
message ExportModelRequest {
|
||||
string checkpoint_path = 1;
|
||||
string output_path = 2;
|
||||
string export_format = 3; // lora, loha, lokr, merged_16bit, merged_4bit, gguf, diffusers
|
||||
string quantization_method = 4; // for GGUF: q4_k_m, q5_k_m, q8_0, f16, etc.
|
||||
string model = 5; // base model name (for merge operations)
|
||||
map<string, string> extra_options = 6;
|
||||
}
|
||||
|
||||
// Quantization messages
|
||||
|
||||
message QuantizationRequest {
|
||||
string model = 1; // HF model name or local path
|
||||
string quantization_type = 2; // q4_k_m, q5_k_m, q8_0, f16, etc.
|
||||
string output_dir = 3; // where to write output files
|
||||
string job_id = 4; // client-assigned job ID
|
||||
map<string, string> extra_options = 5; // hf_token, custom flags, etc.
|
||||
}
|
||||
|
||||
message QuantizationJobResult {
|
||||
string job_id = 1;
|
||||
bool success = 2;
|
||||
string message = 3;
|
||||
}
|
||||
|
||||
message QuantizationProgressRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
message QuantizationProgressUpdate {
|
||||
string job_id = 1;
|
||||
float progress_percent = 2;
|
||||
string status = 3; // queued, downloading, converting, quantizing, completed, failed, stopped
|
||||
string message = 4;
|
||||
string output_file = 5; // set when completed — path to the output GGUF file
|
||||
map<string, float> extra_metrics = 6; // e.g. file_size_mb, compression_ratio
|
||||
}
|
||||
|
||||
message QuantizationStopRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=e30f1fdf74ea9238ff562901aa974c75aab6619b
|
||||
LLAMA_VERSION?=e62fa13c2497b2cd1958cb496e9489e86bbd5182
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -22,8 +22,10 @@
|
||||
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include <grpcpp/health_check_service_interface.h>
|
||||
#include <grpcpp/security/server_credentials.h>
|
||||
#include <regex>
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
#include <mutex>
|
||||
#include <signal.h>
|
||||
#include <thread>
|
||||
@@ -37,6 +39,43 @@ using grpc::Server;
|
||||
using grpc::ServerBuilder;
|
||||
using grpc::ServerContext;
|
||||
using grpc::Status;
|
||||
|
||||
// gRPC bearer token auth for distributed mode.
|
||||
// Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects
|
||||
// requests without a matching "authorization: Bearer <token>" metadata header.
|
||||
|
||||
// Cached auth token — empty means auth is disabled.
|
||||
static std::string g_grpc_auth_token;
|
||||
|
||||
// Minimal constant-time comparison (avoids OpenSSL dependency)
|
||||
static int ct_memcmp(const void* a, const void* b, size_t n) {
|
||||
const unsigned char* pa = static_cast<const unsigned char*>(a);
|
||||
const unsigned char* pb = static_cast<const unsigned char*>(b);
|
||||
unsigned char result = 0;
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
result |= pa[i] ^ pb[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns OK when auth is disabled or the token matches.
|
||||
static grpc::Status checkAuth(grpc::ServerContext* context) {
|
||||
if (g_grpc_auth_token.empty()) {
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
auto metadata = context->client_metadata();
|
||||
auto it = metadata.find("authorization");
|
||||
if (it != metadata.end()) {
|
||||
std::string expected = "Bearer " + g_grpc_auth_token;
|
||||
std::string got(it->second.data(), it->second.size());
|
||||
if (expected.size() == got.size() &&
|
||||
ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
}
|
||||
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
||||
}
|
||||
|
||||
// END LocalAI
|
||||
|
||||
|
||||
@@ -136,6 +175,7 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
||||
data["mirostat_eta"] = predict->mirostateta();
|
||||
data["n_keep"] = predict->nkeep();
|
||||
data["seed"] = predict->seed();
|
||||
data["min_p"] = predict->minp();
|
||||
|
||||
|
||||
std::string grammar_str = predict->grammar();
|
||||
@@ -244,6 +284,12 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
||||
data["ignore_eos"] = predict->ignoreeos();
|
||||
data["embeddings"] = predict->embeddings();
|
||||
|
||||
// Speculative decoding per-request overrides
|
||||
// NDraft maps to speculative.n_max (maximum draft tokens per speculation step)
|
||||
if (predict->ndraft() > 0) {
|
||||
data["speculative.n_max"] = predict->ndraft();
|
||||
}
|
||||
|
||||
// Add the correlationid to json data
|
||||
data["correlation_id"] = predict->correlationid();
|
||||
|
||||
@@ -362,6 +408,16 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
if (!request->mmproj().empty()) {
|
||||
params.mmproj.path = request->mmproj();
|
||||
}
|
||||
|
||||
// Draft model for speculative decoding
|
||||
if (!request->draftmodel().empty()) {
|
||||
params.speculative.mparams_dft.path = request->draftmodel();
|
||||
// Default to draft type if a draft model is set but no explicit type
|
||||
if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_DRAFT;
|
||||
}
|
||||
}
|
||||
|
||||
// params.model_alias ??
|
||||
params.model_alias.insert(request->modelfile());
|
||||
if (!request->cachetypekey().empty()) {
|
||||
@@ -569,6 +625,48 @@ static void params_parse(server_context& /*ctx_server*/, const backend::ModelOpt
|
||||
// If conversion fails, keep default value (8)
|
||||
}
|
||||
}
|
||||
// Speculative decoding options
|
||||
} else if (!strcmp(optname, "spec_type") || !strcmp(optname, "speculative_type")) {
|
||||
auto type = common_speculative_type_from_name(optval_str);
|
||||
if (type != COMMON_SPECULATIVE_TYPE_COUNT) {
|
||||
params.speculative.type = type;
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_n_max") || !strcmp(optname, "draft_max")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_max = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_n_min") || !strcmp(optname, "draft_min")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_min = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_p_min") || !strcmp(optname, "draft_p_min")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.p_min = std::stof(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_p_split")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.p_split = std::stof(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_ngram_size_n") || !strcmp(optname, "ngram_size_n")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.ngram_size_n = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_ngram_size_m") || !strcmp(optname, "ngram_size_m")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.ngram_size_m = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "spec_ngram_min_hits") || !strcmp(optname, "ngram_min_hits")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.ngram_min_hits = (uint16_t)std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "draft_gpu_layers")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_gpu_layers = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
} else if (!strcmp(optname, "draft_ctx_size")) {
|
||||
if (optval != NULL) {
|
||||
try { params.speculative.n_ctx = std::stoi(optval_str); } catch (...) {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -713,13 +811,17 @@ private:
|
||||
public:
|
||||
BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {}
|
||||
|
||||
grpc::Status Health(ServerContext* /*context*/, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
|
||||
grpc::Status Health(ServerContext* context, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
// Implement Health RPC
|
||||
reply->set_message("OK");
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status LoadModel(ServerContext* /*context*/, const backend::ModelOptions* request, backend::Result* result) override {
|
||||
grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
// Implement LoadModel RPC
|
||||
common_params params;
|
||||
params_parse(ctx_server, request, params);
|
||||
@@ -918,6 +1020,8 @@ public:
|
||||
}
|
||||
|
||||
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -1205,6 +1309,7 @@ public:
|
||||
|
||||
body_json["messages"] = messages_json;
|
||||
body_json["stream"] = true; // PredictStream is always streaming
|
||||
body_json["stream_options"] = {{"include_usage", true}}; // Ensure token counts in final chunk
|
||||
|
||||
// Check if grammar is provided from Go layer (NoGrammar=false)
|
||||
// If grammar is provided, we must use it and NOT let template generate grammar from tools
|
||||
@@ -1509,11 +1614,15 @@ public:
|
||||
ctx_server.impl->vocab,
|
||||
params_base,
|
||||
ctx_server.get_meta().slot_n_ctx,
|
||||
ctx_server.get_meta().logit_bias_eog,
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
|
||||
// reasoning, tool calls, and content are classified into ChatDeltas.
|
||||
// Without this, the PEG parser never produces diffs and the Go side
|
||||
// cannot detect tool calls or separate reasoning from content.
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
@@ -1538,19 +1647,47 @@ public:
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
|
||||
}
|
||||
|
||||
// Lambda to build a Reply from JSON + attach chat deltas from a result
|
||||
// Lambda to build a Reply from JSON + attach chat deltas from a result.
|
||||
// Handles both native format ({"content": "..."}) and OAI chat format
|
||||
// ({"choices": [{"delta": {"content": "...", "reasoning": "..."}}]}).
|
||||
auto build_reply_from_json = [](const json & res_json, server_task_result * raw_result) -> backend::Reply {
|
||||
backend::Reply reply;
|
||||
std::string completion_text = res_json.value("content", "");
|
||||
reply.set_message(completion_text);
|
||||
reply.set_tokens(res_json.value("tokens_predicted", 0));
|
||||
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
|
||||
std::string completion_text;
|
||||
|
||||
if (res_json.contains("choices")) {
|
||||
// OAI chat format — extract content from choices[0].delta
|
||||
const auto & choices = res_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & delta = choices[0].value("delta", json::object());
|
||||
if (delta.contains("content") && !delta.at("content").is_null()) {
|
||||
completion_text = delta.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Native llama.cpp format
|
||||
completion_text = res_json.value("content", "");
|
||||
}
|
||||
|
||||
reply.set_message(completion_text);
|
||||
|
||||
// Token counts: native format has top-level fields,
|
||||
// OAI format has them in "usage" (final chunk only)
|
||||
if (res_json.contains("usage")) {
|
||||
const auto & usage = res_json.at("usage");
|
||||
reply.set_tokens(usage.value("completion_tokens", 0));
|
||||
reply.set_prompt_tokens(usage.value("prompt_tokens", 0));
|
||||
} else {
|
||||
reply.set_tokens(res_json.value("tokens_predicted", 0));
|
||||
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
|
||||
}
|
||||
|
||||
// Timings: present as top-level "timings" in both formats
|
||||
if (res_json.contains("timings")) {
|
||||
reply.set_timing_prompt_processing(res_json.at("timings").value("prompt_ms", 0.0));
|
||||
reply.set_timing_token_generation(res_json.at("timings").value("predicted_ms", 0.0));
|
||||
}
|
||||
|
||||
// Logprobs: extract_logprobs_from_json handles both formats
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
reply.set_logprobs(logprobs_json.dump());
|
||||
@@ -1559,6 +1696,12 @@ public:
|
||||
return reply;
|
||||
};
|
||||
|
||||
// Attach chat deltas from the autoparser to a Reply.
|
||||
// When diffs are available, populate ChatDeltas on the reply.
|
||||
// The raw message is always preserved so the Go side can use it
|
||||
// for reasoning extraction and tool call parsing as a fallback
|
||||
// (important in distributed mode where ChatDeltas may not be
|
||||
// the primary parsing path).
|
||||
auto attach_chat_deltas = [](backend::Reply & reply, server_task_result * raw_result) {
|
||||
// Try streaming partial result first
|
||||
auto* partial = dynamic_cast<server_task_result_cmpl_partial*>(raw_result);
|
||||
@@ -1573,12 +1716,23 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// Process first result
|
||||
// Process first result.
|
||||
// When TASK_RESPONSE_TYPE_OAI_CHAT is used, the first token may
|
||||
// produce a JSON array with a role-init element followed by the
|
||||
// actual content element. We must only attach chat deltas to the
|
||||
// content element — attaching to both would duplicate the first
|
||||
// token since oaicompat_msg_diffs is the same for both.
|
||||
json first_res_json = first_result->to_json();
|
||||
if (first_res_json.is_array()) {
|
||||
for (const auto & res : first_res_json) {
|
||||
auto reply = build_reply_from_json(res, first_result.get());
|
||||
attach_chat_deltas(reply, first_result.get());
|
||||
// Skip chat deltas for role-init elements (have "role" in
|
||||
// delta but no content/reasoning diffs of their own).
|
||||
bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
|
||||
res["choices"][0].value("delta", json::object()).contains("role");
|
||||
if (!is_role_init) {
|
||||
attach_chat_deltas(reply, first_result.get());
|
||||
}
|
||||
writer->Write(reply);
|
||||
}
|
||||
} else {
|
||||
@@ -1602,7 +1756,11 @@ public:
|
||||
if (res_json.is_array()) {
|
||||
for (const auto & res : res_json) {
|
||||
auto reply = build_reply_from_json(res, result.get());
|
||||
attach_chat_deltas(reply, result.get());
|
||||
bool is_role_init = res.contains("choices") && !res["choices"].empty() &&
|
||||
res["choices"][0].value("delta", json::object()).contains("role");
|
||||
if (!is_role_init) {
|
||||
attach_chat_deltas(reply, result.get());
|
||||
}
|
||||
writer->Write(reply);
|
||||
}
|
||||
} else {
|
||||
@@ -1621,6 +1779,8 @@ public:
|
||||
}
|
||||
|
||||
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2238,11 +2398,13 @@ public:
|
||||
ctx_server.impl->vocab,
|
||||
params_base,
|
||||
ctx_server.get_meta().slot_n_ctx,
|
||||
ctx_server.get_meta().logit_bias_eog,
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
|
||||
// reasoning, tool calls, and content are classified into ChatDeltas.
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
@@ -2273,25 +2435,48 @@ public:
|
||||
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get());
|
||||
GGML_ASSERT(final_res != nullptr);
|
||||
json result_json = all_results.results[0]->to_json();
|
||||
reply->set_message(result_json.value("content", ""));
|
||||
|
||||
int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
|
||||
// Handle both native format ({"content": "...", "tokens_predicted": N})
|
||||
// and OAI chat format ({"choices": [{"message": {"content": "..."}}],
|
||||
// "usage": {"completion_tokens": N, "prompt_tokens": N}}).
|
||||
std::string completion_text;
|
||||
int32_t tokens_predicted = 0;
|
||||
int32_t tokens_evaluated = 0;
|
||||
|
||||
if (result_json.contains("choices")) {
|
||||
// OAI chat format
|
||||
const auto & choices = result_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & msg = choices[0].value("message", json::object());
|
||||
if (msg.contains("content") && !msg.at("content").is_null()) {
|
||||
completion_text = msg.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
if (result_json.contains("usage")) {
|
||||
const auto & usage = result_json.at("usage");
|
||||
tokens_predicted = usage.value("completion_tokens", 0);
|
||||
tokens_evaluated = usage.value("prompt_tokens", 0);
|
||||
}
|
||||
} else {
|
||||
// Native llama.cpp format
|
||||
completion_text = result_json.value("content", "");
|
||||
tokens_predicted = result_json.value("tokens_predicted", 0);
|
||||
tokens_evaluated = result_json.value("tokens_evaluated", 0);
|
||||
}
|
||||
reply->set_message(completion_text);
|
||||
reply->set_tokens(tokens_predicted);
|
||||
int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
|
||||
reply->set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
// Timings: present in both formats as a top-level "timings" object
|
||||
if (result_json.contains("timings")) {
|
||||
double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
|
||||
reply->set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
|
||||
reply->set_timing_token_generation(timing_token_generation);
|
||||
reply->set_timing_prompt_processing(result_json.at("timings").value("prompt_ms", 0.0));
|
||||
reply->set_timing_token_generation(result_json.at("timings").value("predicted_ms", 0.0));
|
||||
}
|
||||
|
||||
// Extract and set logprobs if present
|
||||
// Logprobs: extract_logprobs_from_json handles both formats
|
||||
json logprobs_json = extract_logprobs_from_json(result_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply->set_logprobs(logprobs_str);
|
||||
reply->set_logprobs(logprobs_json.dump());
|
||||
}
|
||||
|
||||
// Populate chat deltas from the autoparser's final parsed message
|
||||
@@ -2307,7 +2492,20 @@ public:
|
||||
for (auto & res : all_results.results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
|
||||
json res_json = res->to_json();
|
||||
arr.push_back(res_json.value("content", ""));
|
||||
// Handle both native and OAI chat formats
|
||||
std::string result_content;
|
||||
if (res_json.contains("choices")) {
|
||||
const auto & choices = res_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & msg = choices[0].value("message", json::object());
|
||||
if (msg.contains("content") && !msg.at("content").is_null()) {
|
||||
result_content = msg.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result_content = res_json.value("content", "");
|
||||
}
|
||||
arr.push_back(result_content);
|
||||
|
||||
// Extract logprobs for each result
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
@@ -2339,6 +2537,8 @@ public:
|
||||
}
|
||||
|
||||
grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2519,7 +2719,9 @@ public:
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status TokenizeString(ServerContext* /*context*/, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
@@ -2687,7 +2889,6 @@ public:
|
||||
tf->set_id_field(ap.tools.format.id_field);
|
||||
tf->set_fun_name_is_key(ap.tools.format.fun_name_is_key);
|
||||
tf->set_tools_array_wrapped(ap.tools.format.tools_array_wrapped);
|
||||
tf->set_uses_python_dicts(ap.tools.format.uses_python_dicts);
|
||||
tf->set_function_field(ap.tools.format.function_field);
|
||||
|
||||
tf->set_gen_id_field(ap.tools.format.gen_id_field);
|
||||
@@ -2761,10 +2962,18 @@ int main(int argc, char** argv) {
|
||||
|
||||
ServerBuilder builder;
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
|
||||
// Initialize bearer token auth if LOCALAI_GRPC_AUTH_TOKEN is set
|
||||
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
|
||||
if (auth_token != nullptr && auth_token[0] != '\0') {
|
||||
g_grpc_auth_token = auth_token;
|
||||
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
|
||||
}
|
||||
builder.RegisterService(&service);
|
||||
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
|
||||
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
|
||||
builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
|
||||
|
||||
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||
// run the HTTP server in a thread - see comment below
|
||||
std::thread t([&]()
|
||||
|
||||
@@ -24,6 +24,9 @@ if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
@@ -33,6 +36,9 @@ elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# acestep.cpp version
|
||||
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
|
||||
ACESTEP_CPP_VERSION?=5aa065445541094cba934299cd498bbb9fa5c434
|
||||
ACESTEP_CPP_VERSION?=e0c8d75a672fca5684c88c68dbf6d12f58754258
|
||||
SO_TARGET?=libgoacestepcpp.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -106,12 +106,13 @@ func TestLoadModel(t *testing.T) {
|
||||
defer conn.Close()
|
||||
|
||||
client := pb.NewBackendClient(conn)
|
||||
|
||||
|
||||
// Get base directory from main model file for relative paths
|
||||
mainModelPath := filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf")
|
||||
|
||||
|
||||
resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
||||
ModelFile: mainModelPath,
|
||||
ModelPath: modelDir,
|
||||
Options: []string{
|
||||
"text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
|
||||
"dit_model:acestep-v15-turbo-Q8_0.gguf",
|
||||
@@ -133,7 +134,7 @@ func TestSoundGeneration(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||
|
||||
outputFile := filepath.Join(tmpDir, "output.wav")
|
||||
|
||||
@@ -151,6 +152,7 @@ func TestSoundGeneration(t *testing.T) {
|
||||
// Load models
|
||||
loadResp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
||||
ModelFile: mainModelPath,
|
||||
ModelPath: modelDir,
|
||||
Options: []string{
|
||||
"text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
|
||||
"dit_model:acestep-v15-turbo-Q8_0.gguf",
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
|
||||
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
|
||||
CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
|
||||
)
|
||||
|
||||
@@ -24,23 +24,23 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
|
||||
lmModel := opts.ModelFile
|
||||
|
||||
// Get the base directory from ModelFile for resolving relative paths
|
||||
baseDir := filepath.Dir(lmModel)
|
||||
baseDir := opts.ModelPath
|
||||
|
||||
var textEncoderModel, ditModel, vaeModel string
|
||||
|
||||
for _, oo := range opts.Options {
|
||||
parts := strings.SplitN(oo, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
key, value, found := strings.Cut(oo, ":")
|
||||
if !found {
|
||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||
continue
|
||||
}
|
||||
switch parts[0] {
|
||||
switch key {
|
||||
case "text_encoder_model":
|
||||
textEncoderModel = parts[1]
|
||||
textEncoderModel = value
|
||||
case "dit_model":
|
||||
ditModel = parts[1]
|
||||
ditModel = value
|
||||
case "vae_model":
|
||||
vaeModel = parts[1]
|
||||
vaeModel = value
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ type LLM struct {
|
||||
draftModel *llama.LLama
|
||||
}
|
||||
|
||||
|
||||
// Free releases GPU resources and frees the llama model
|
||||
// This should be called when the model is being unloaded to properly release VRAM
|
||||
func (llm *LLM) Free() error {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build debug
|
||||
// +build debug
|
||||
|
||||
package main
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !debug
|
||||
// +build !debug
|
||||
|
||||
package main
|
||||
|
||||
|
||||
@@ -332,7 +332,7 @@ func normalizedCosineSimilarity(k1, k2 []float32) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
var dot float32
|
||||
for i := 0; i < len(k1); i++ {
|
||||
for i := range len(k1) {
|
||||
dot += k1[i] * k2[i]
|
||||
}
|
||||
|
||||
@@ -419,7 +419,7 @@ func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
var dot, mag2 float64
|
||||
for i := 0; i < len(k1); i++ {
|
||||
for i := range len(k1) {
|
||||
dot += float64(k1[i] * k2[i])
|
||||
mag2 += float64(k2[i] * k2[i])
|
||||
}
|
||||
|
||||
@@ -701,7 +701,7 @@ var _ = Describe("Opus", func() {
|
||||
// to one-shot (only difference is resampler batch boundaries).
|
||||
var maxDiff float64
|
||||
var sumDiffSq float64
|
||||
for i := 0; i < minLen; i++ {
|
||||
for i := range minLen {
|
||||
diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i]))
|
||||
if diff > maxDiff {
|
||||
maxDiff = diff
|
||||
@@ -774,7 +774,7 @@ var _ = Describe("Opus", func() {
|
||||
minLen := min(len(refTail), min(len(persistentTail), len(freshTail)))
|
||||
|
||||
var persistentMaxDiff, freshMaxDiff float64
|
||||
for i := 0; i < minLen; i++ {
|
||||
for i := range minLen {
|
||||
pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i]))
|
||||
fd := math.Abs(float64(refTail[i]) - float64(freshTail[i]))
|
||||
if pd > persistentMaxDiff {
|
||||
@@ -932,7 +932,7 @@ var _ = Describe("Opus", func() {
|
||||
GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n",
|
||||
mean, stddev, stddev/mean, 16000.0/440.0/2.0)
|
||||
|
||||
Expect(stddev / mean).To(BeNumerically("<", 0.15),
|
||||
Expect(stddev/mean).To(BeNumerically("<", 0.15),
|
||||
fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean))
|
||||
|
||||
// Also check frequency is correct
|
||||
@@ -978,7 +978,7 @@ var _ = Describe("Opus", func() {
|
||||
|
||||
// Every sample must be identical — the resampler is deterministic
|
||||
var maxDiff float64
|
||||
for i := 0; i < len(oneShot); i++ {
|
||||
for i := range len(oneShot) {
|
||||
diff := math.Abs(float64(oneShot[i]) - float64(batched[i]))
|
||||
if diff > maxDiff {
|
||||
maxDiff = diff
|
||||
@@ -1037,13 +1037,13 @@ var _ = Describe("Opus", func() {
|
||||
binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen))
|
||||
copy(hdr[8:12], "WAVE")
|
||||
copy(hdr[12:16], "fmt ")
|
||||
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
|
||||
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
|
||||
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
|
||||
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
|
||||
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
|
||||
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
|
||||
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
|
||||
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
|
||||
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
|
||||
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
|
||||
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
|
||||
copy(hdr[36:40], "data")
|
||||
binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen))
|
||||
|
||||
@@ -1126,7 +1126,7 @@ var _ = Describe("Opus", func() {
|
||||
)
|
||||
|
||||
pcm := make([]byte, toneNumSamples*2)
|
||||
for i := 0; i < toneNumSamples; i++ {
|
||||
for i := range toneNumSamples {
|
||||
sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate)))
|
||||
binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample))
|
||||
}
|
||||
|
||||
7
backend/go/sam3-cpp/.gitignore
vendored
Normal file
7
backend/go/sam3-cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
sources/
|
||||
build*/
|
||||
package/
|
||||
libgosam3*.so
|
||||
sam3-cpp
|
||||
test-models/
|
||||
test-data/
|
||||
26
backend/go/sam3-cpp/CMakeLists.txt
Normal file
26
backend/go/sam3-cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,26 @@
|
||||
cmake_minimum_required(VERSION 3.14)
|
||||
project(gosam3 LANGUAGES C CXX)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
# Build ggml as static libraries to avoid runtime .so dependencies
|
||||
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build static libraries" FORCE)
|
||||
|
||||
set(SAM3_BUILD_EXAMPLES OFF CACHE BOOL "Disable sam3.cpp examples" FORCE)
|
||||
set(SAM3_BUILD_TESTS OFF CACHE BOOL "Disable sam3.cpp tests" FORCE)
|
||||
|
||||
add_subdirectory(./sources/sam3.cpp)
|
||||
|
||||
add_library(gosam3 MODULE gosam3.cpp)
|
||||
target_link_libraries(gosam3 PRIVATE sam3 ggml)
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
|
||||
target_link_libraries(gosam3 PRIVATE stdc++fs)
|
||||
endif()
|
||||
|
||||
target_include_directories(gosam3 PUBLIC
|
||||
sources/sam3.cpp
|
||||
sources/sam3.cpp/ggml/include
|
||||
)
|
||||
|
||||
set_property(TARGET gosam3 PROPERTY CXX_STANDARD 14)
|
||||
set_target_properties(gosam3 PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||
122
backend/go/sam3-cpp/Makefile
Normal file
122
backend/go/sam3-cpp/Makefile
Normal file
@@ -0,0 +1,122 @@
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
|
||||
GOCMD?=go
|
||||
GO_TAGS?=
|
||||
JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# sam3.cpp
|
||||
SAM3_REPO?=https://github.com/PABannier/sam3.cpp
|
||||
SAM3_VERSION?=01832ef85fcc8eb6488f1d01cd247f07e96ff5a9
|
||||
|
||||
ifeq ($(NATIVE),false)
|
||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
||||
endif
|
||||
|
||||
# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically
|
||||
ifeq ($(BUILD_TYPE),cublas)
|
||||
CMAKE_ARGS+=-DGGML_CUDA=ON
|
||||
else ifeq ($(BUILD_TYPE),openblas)
|
||||
CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
|
||||
else ifeq ($(BUILD_TYPE),clblas)
|
||||
CMAKE_ARGS+=-DGGML_CLBLAST=ON
|
||||
else ifeq ($(BUILD_TYPE),hipblas)
|
||||
ROCM_HOME ?= /opt/rocm
|
||||
ROCM_PATH ?= /opt/rocm
|
||||
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
||||
export CC=$(ROCM_HOME)/llvm/bin/clang
|
||||
AMDGPU_TARGETS?=gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201
|
||||
CMAKE_ARGS+=-DGGML_HIPBLAS=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=ON
|
||||
else ifeq ($(OS),Darwin)
|
||||
ifneq ($(BUILD_TYPE),metal)
|
||||
CMAKE_ARGS+=-DGGML_METAL=OFF
|
||||
else
|
||||
CMAKE_ARGS+=-DGGML_METAL=ON
|
||||
CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f16)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
-DGGML_SYCL_F16=ON
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),sycl_f32)
|
||||
CMAKE_ARGS+=-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
endif
|
||||
|
||||
sources/sam3.cpp:
|
||||
git clone --recursive $(SAM3_REPO) sources/sam3.cpp && \
|
||||
cd sources/sam3.cpp && \
|
||||
git checkout $(SAM3_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
# Only build CPU variants on Linux
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
VARIANT_TARGETS = libgosam3-avx.so libgosam3-avx2.so libgosam3-avx512.so libgosam3-fallback.so
|
||||
else
|
||||
# On non-Linux (e.g., Darwin), build only fallback variant
|
||||
VARIANT_TARGETS = libgosam3-fallback.so
|
||||
endif
|
||||
|
||||
sam3-cpp: main.go gosam3.go $(VARIANT_TARGETS)
|
||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o sam3-cpp ./
|
||||
|
||||
package: sam3-cpp
|
||||
bash package.sh
|
||||
|
||||
build: package
|
||||
|
||||
clean: purge
|
||||
rm -rf libgosam3*.so sam3-cpp package sources
|
||||
|
||||
purge:
|
||||
rm -rf build*
|
||||
|
||||
# Build all variants (Linux only)
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
libgosam3-avx.so: sources/sam3.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I sam3-cpp build info:avx${RESET})
|
||||
SO_TARGET=libgosam3-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgosam3-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgosam3-avx2.so: sources/sam3.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I sam3-cpp build info:avx2${RESET})
|
||||
SO_TARGET=libgosam3-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgosam3-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgosam3-avx512.so: sources/sam3.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I sam3-cpp build info:avx512${RESET})
|
||||
SO_TARGET=libgosam3-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgosam3-custom
|
||||
rm -rfv build*
|
||||
endif
|
||||
|
||||
# Build fallback variant (all platforms)
|
||||
libgosam3-fallback.so: sources/sam3.cpp
|
||||
$(MAKE) purge
|
||||
$(info ${GREEN}I sam3-cpp build info:fallback${RESET})
|
||||
SO_TARGET=libgosam3-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgosam3-custom
|
||||
rm -rfv build*
|
||||
|
||||
libgosam3-custom: CMakeLists.txt gosam3.cpp gosam3.h
|
||||
mkdir -p build-$(SO_TARGET) && \
|
||||
cd build-$(SO_TARGET) && \
|
||||
cmake .. $(CMAKE_ARGS) && \
|
||||
cmake --build . --config Release -j$(JOBS) && \
|
||||
cd .. && \
|
||||
mv build-$(SO_TARGET)/libgosam3.so ./$(SO_TARGET)
|
||||
|
||||
all: sam3-cpp package
|
||||
193
backend/go/sam3-cpp/gosam3.cpp
Normal file
193
backend/go/sam3-cpp/gosam3.cpp
Normal file
@@ -0,0 +1,193 @@
|
||||
#include "sam3.h"
|
||||
#include "gosam3.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#define STB_IMAGE_WRITE_IMPLEMENTATION
|
||||
#define STB_IMAGE_WRITE_STATIC
|
||||
#include "stb_image_write.h"
|
||||
|
||||
// Static state
|
||||
static std::shared_ptr<sam3_model> g_model;
|
||||
static sam3_state_ptr g_state;
|
||||
static sam3_result g_result;
|
||||
static std::vector<std::vector<unsigned char>> g_mask_pngs;
|
||||
|
||||
// Callback for stbi_write_png_to_mem via stbi_write_png_to_func
|
||||
static void png_write_callback(void *context, void *data, int size) {
|
||||
auto *buf = static_cast<std::vector<unsigned char>*>(context);
|
||||
auto *bytes = static_cast<unsigned char*>(data);
|
||||
buf->insert(buf->end(), bytes, bytes + size);
|
||||
}
|
||||
|
||||
// Encode all masks as PNGs after segmentation
|
||||
static void encode_masks_as_png() {
|
||||
g_mask_pngs.clear();
|
||||
g_mask_pngs.resize(g_result.detections.size());
|
||||
|
||||
for (size_t i = 0; i < g_result.detections.size(); i++) {
|
||||
const auto &mask = g_result.detections[i].mask;
|
||||
if (mask.width > 0 && mask.height > 0 && !mask.data.empty()) {
|
||||
stbi_write_png_to_func(png_write_callback, &g_mask_pngs[i],
|
||||
mask.width, mask.height, 1,
|
||||
mask.data.data(), mask.width);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
int sam3_cpp_load_model(const char *model_path, int threads) {
|
||||
sam3_params params;
|
||||
params.model_path = model_path;
|
||||
params.n_threads = threads;
|
||||
params.use_gpu = true;
|
||||
|
||||
g_model = sam3_load_model(params);
|
||||
if (!g_model) {
|
||||
fprintf(stderr, "[sam3-cpp] Failed to load model: %s\n", model_path);
|
||||
return 1;
|
||||
}
|
||||
|
||||
g_state = sam3_create_state(*g_model, params);
|
||||
if (!g_state) {
|
||||
fprintf(stderr, "[sam3-cpp] Failed to create state\n");
|
||||
g_model.reset();
|
||||
return 2;
|
||||
}
|
||||
|
||||
fprintf(stderr, "[sam3-cpp] Model loaded: %s (threads=%d)\n", model_path, threads);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int sam3_cpp_encode_image(const char *image_path) {
|
||||
if (!g_model || !g_state) {
|
||||
fprintf(stderr, "[sam3-cpp] Model not loaded\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
sam3_image img = sam3_load_image(image_path);
|
||||
if (img.data.empty()) {
|
||||
fprintf(stderr, "[sam3-cpp] Failed to load image: %s\n", image_path);
|
||||
return 2;
|
||||
}
|
||||
|
||||
if (!sam3_encode_image(*g_state, *g_model, img)) {
|
||||
fprintf(stderr, "[sam3-cpp] Failed to encode image\n");
|
||||
return 3;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int sam3_cpp_segment_pvs(float *points, int n_point_triples,
|
||||
float *boxes, int n_box_quads,
|
||||
float threshold) {
|
||||
if (!g_model || !g_state) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
sam3_pvs_params pvs_params;
|
||||
|
||||
// Parse points: each triple is [x, y, label]
|
||||
for (int i = 0; i < n_point_triples; i++) {
|
||||
float x = points[i * 3];
|
||||
float y = points[i * 3 + 1];
|
||||
float label = points[i * 3 + 2];
|
||||
sam3_point pt = {x, y};
|
||||
if (label > 0.5f) {
|
||||
pvs_params.pos_points.push_back(pt);
|
||||
} else {
|
||||
pvs_params.neg_points.push_back(pt);
|
||||
}
|
||||
}
|
||||
|
||||
// Parse boxes: each quad is [x1, y1, x2, y2], use only first box
|
||||
if (n_box_quads > 0) {
|
||||
pvs_params.box = {boxes[0], boxes[1], boxes[2], boxes[3]};
|
||||
pvs_params.use_box = true;
|
||||
}
|
||||
|
||||
g_result = sam3_segment_pvs(*g_state, *g_model, pvs_params);
|
||||
encode_masks_as_png();
|
||||
|
||||
return static_cast<int>(g_result.detections.size());
|
||||
}
|
||||
|
||||
int sam3_cpp_segment_pcs(const char *text_prompt, float threshold) {
|
||||
if (!g_model || !g_state) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// PCS mode requires SAM 3 (full model with text encoder)
|
||||
if (sam3_is_visual_only(*g_model) ||
|
||||
sam3_get_model_type(*g_model) != SAM3_MODEL_SAM3) {
|
||||
fprintf(stderr, "[sam3-cpp] PCS mode requires full SAM 3 model\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
sam3_pcs_params pcs_params;
|
||||
pcs_params.text_prompt = text_prompt;
|
||||
pcs_params.score_threshold = threshold > 0 ? threshold : 0.5f;
|
||||
|
||||
g_result = sam3_segment_pcs(*g_state, *g_model, pcs_params);
|
||||
encode_masks_as_png();
|
||||
|
||||
return static_cast<int>(g_result.detections.size());
|
||||
}
|
||||
|
||||
int sam3_cpp_get_n_detections(void) {
|
||||
return static_cast<int>(g_result.detections.size());
|
||||
}
|
||||
|
||||
float sam3_cpp_get_detection_x(int i) {
|
||||
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||
return g_result.detections[i].box.x0;
|
||||
}
|
||||
|
||||
float sam3_cpp_get_detection_y(int i) {
|
||||
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||
return g_result.detections[i].box.y0;
|
||||
}
|
||||
|
||||
float sam3_cpp_get_detection_w(int i) {
|
||||
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||
const auto &box = g_result.detections[i].box;
|
||||
return box.x1 - box.x0;
|
||||
}
|
||||
|
||||
float sam3_cpp_get_detection_h(int i) {
|
||||
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||
const auto &box = g_result.detections[i].box;
|
||||
return box.y1 - box.y0;
|
||||
}
|
||||
|
||||
float sam3_cpp_get_detection_score(int i) {
|
||||
if (i < 0 || i >= static_cast<int>(g_result.detections.size())) return 0;
|
||||
return g_result.detections[i].score;
|
||||
}
|
||||
|
||||
int sam3_cpp_get_detection_mask_png(int i, unsigned char *buf, int buf_size) {
|
||||
if (i < 0 || i >= static_cast<int>(g_mask_pngs.size())) return 0;
|
||||
|
||||
const auto &png = g_mask_pngs[i];
|
||||
int size = static_cast<int>(png.size());
|
||||
|
||||
if (buf == nullptr) {
|
||||
return size;
|
||||
}
|
||||
|
||||
int to_copy = size < buf_size ? size : buf_size;
|
||||
memcpy(buf, png.data(), to_copy);
|
||||
return to_copy;
|
||||
}
|
||||
|
||||
void sam3_cpp_free_results(void) {
|
||||
g_result.detections.clear();
|
||||
g_mask_pngs.clear();
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
143
backend/go/sam3-cpp/gosam3.go
Normal file
143
backend/go/sam3-cpp/gosam3.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
type SAM3 struct {
|
||||
base.SingleThread
|
||||
}
|
||||
|
||||
var (
|
||||
CppLoadModel func(modelPath string, threads int) int
|
||||
CppEncodeImage func(imagePath string) int
|
||||
CppSegmentPVS func(points uintptr, nPointTriples int, boxes uintptr, nBoxQuads int, threshold float32) int
|
||||
CppSegmentPCS func(textPrompt string, threshold float32) int
|
||||
CppGetNDetections func() int
|
||||
CppGetDetectionX func(i int) float32
|
||||
CppGetDetectionY func(i int) float32
|
||||
CppGetDetectionW func(i int) float32
|
||||
CppGetDetectionH func(i int) float32
|
||||
CppGetDetectionScore func(i int) float32
|
||||
CppGetDetectionMaskPNG func(i int, buf uintptr, bufSize int) int
|
||||
CppFreeResults func()
|
||||
)
|
||||
|
||||
func (s *SAM3) Load(opts *pb.ModelOptions) error {
|
||||
modelFile := opts.ModelFile
|
||||
if modelFile == "" {
|
||||
modelFile = opts.Model
|
||||
}
|
||||
|
||||
var modelPath string
|
||||
if filepath.IsAbs(modelFile) {
|
||||
modelPath = modelFile
|
||||
} else {
|
||||
modelPath = filepath.Join(opts.ModelPath, modelFile)
|
||||
}
|
||||
|
||||
threads := int(opts.Threads)
|
||||
if threads <= 0 {
|
||||
threads = 4
|
||||
}
|
||||
|
||||
ret := CppLoadModel(modelPath, threads)
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("failed to load SAM3 model (error %d): %s", ret, modelPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SAM3) Detect(opts *pb.DetectOptions) (pb.DetectResponse, error) {
|
||||
// Decode base64 image and write to temp file
|
||||
imgData, err := base64.StdEncoding.DecodeString(opts.Src)
|
||||
if err != nil {
|
||||
return pb.DetectResponse{}, fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "sam3-*.png")
|
||||
if err != nil {
|
||||
return pb.DetectResponse{}, fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
if _, err := tmpFile.Write(imgData); err != nil {
|
||||
tmpFile.Close()
|
||||
return pb.DetectResponse{}, fmt.Errorf("failed to write temp file: %w", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
// Encode image
|
||||
ret := CppEncodeImage(tmpFile.Name())
|
||||
if ret != 0 {
|
||||
return pb.DetectResponse{}, fmt.Errorf("failed to encode image (error %d)", ret)
|
||||
}
|
||||
|
||||
threshold := opts.Threshold
|
||||
if threshold <= 0 {
|
||||
threshold = 0.5
|
||||
}
|
||||
|
||||
// Determine segmentation mode
|
||||
var nDetections int
|
||||
if opts.Prompt != "" {
|
||||
// Text-prompted segmentation (PCS mode, SAM 3 only)
|
||||
nDetections = CppSegmentPCS(opts.Prompt, threshold)
|
||||
} else {
|
||||
// Point/box-prompted segmentation (PVS mode)
|
||||
var pointsPtr uintptr
|
||||
var boxesPtr uintptr
|
||||
nPointTriples := len(opts.Points) / 3
|
||||
nBoxQuads := len(opts.Boxes) / 4
|
||||
|
||||
if nPointTriples > 0 {
|
||||
pointsPtr = uintptr(unsafe.Pointer(&opts.Points[0]))
|
||||
}
|
||||
if nBoxQuads > 0 {
|
||||
boxesPtr = uintptr(unsafe.Pointer(&opts.Boxes[0]))
|
||||
}
|
||||
|
||||
nDetections = CppSegmentPVS(pointsPtr, nPointTriples, boxesPtr, nBoxQuads, threshold)
|
||||
}
|
||||
|
||||
if nDetections < 0 {
|
||||
return pb.DetectResponse{}, fmt.Errorf("segmentation failed")
|
||||
}
|
||||
|
||||
defer CppFreeResults()
|
||||
|
||||
// Build response
|
||||
detections := make([]*pb.Detection, nDetections)
|
||||
for i := 0; i < nDetections; i++ {
|
||||
det := &pb.Detection{
|
||||
X: CppGetDetectionX(i),
|
||||
Y: CppGetDetectionY(i),
|
||||
Width: CppGetDetectionW(i),
|
||||
Height: CppGetDetectionH(i),
|
||||
Confidence: CppGetDetectionScore(i),
|
||||
ClassName: "segment",
|
||||
}
|
||||
|
||||
// Get mask PNG
|
||||
maskSize := CppGetDetectionMaskPNG(i, 0, 0)
|
||||
if maskSize > 0 {
|
||||
maskBuf := make([]byte, maskSize)
|
||||
CppGetDetectionMaskPNG(i, uintptr(unsafe.Pointer(&maskBuf[0])), maskSize)
|
||||
det.Mask = maskBuf
|
||||
}
|
||||
|
||||
detections[i] = det
|
||||
}
|
||||
|
||||
return pb.DetectResponse{
|
||||
Detections: detections,
|
||||
}, nil
|
||||
}
|
||||
51
backend/go/sam3-cpp/gosam3.h
Normal file
51
backend/go/sam3-cpp/gosam3.h
Normal file
@@ -0,0 +1,51 @@
|
||||
#ifndef GOSAM3_H
|
||||
#define GOSAM3_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Load model from file. Returns 0 on success, non-zero on failure.
|
||||
int sam3_cpp_load_model(const char *model_path, int threads);
|
||||
|
||||
// Encode an image from file path. Must be called before segmentation.
|
||||
// Returns 0 on success.
|
||||
int sam3_cpp_encode_image(const char *image_path);
|
||||
|
||||
// Segment with point/box prompts (PVS mode).
|
||||
// points: flat array of [x, y, label] triples (label: 1=positive, 0=negative)
|
||||
// boxes: flat array of [x1, y1, x2, y2] quads
|
||||
// Returns number of detections, or -1 on error.
|
||||
int sam3_cpp_segment_pvs(float *points, int n_point_triples,
|
||||
float *boxes, int n_box_quads,
|
||||
float threshold);
|
||||
|
||||
// Segment with text prompt (PCS mode, SAM 3 only).
|
||||
// Returns number of detections, or -1 on error.
|
||||
int sam3_cpp_segment_pcs(const char *text_prompt, float threshold);
|
||||
|
||||
// Access detection results (valid after a segment call).
|
||||
int sam3_cpp_get_n_detections(void);
|
||||
|
||||
// Get bounding box for detection i (as x, y, width, height).
|
||||
float sam3_cpp_get_detection_x(int i);
|
||||
float sam3_cpp_get_detection_y(int i);
|
||||
float sam3_cpp_get_detection_w(int i);
|
||||
float sam3_cpp_get_detection_h(int i);
|
||||
|
||||
// Get confidence score for detection i.
|
||||
float sam3_cpp_get_detection_score(int i);
|
||||
|
||||
// Get mask as PNG-encoded bytes.
|
||||
// If buf is NULL, returns the required buffer size.
|
||||
// Otherwise writes up to buf_size bytes and returns bytes written.
|
||||
int sam3_cpp_get_detection_mask_png(int i, unsigned char *buf, int buf_size);
|
||||
|
||||
// Free current detection results.
|
||||
void sam3_cpp_free_results(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // GOSAM3_H
|
||||
56
backend/go/sam3-cpp/main.go
Normal file
56
backend/go/sam3-cpp/main.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
type LibFuncs struct {
|
||||
FuncPtr any
|
||||
Name string
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Get library name from environment variable, default to fallback
|
||||
libName := os.Getenv("SAM3_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgosam3-fallback.so"
|
||||
}
|
||||
|
||||
gosamLib, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
libFuncs := []LibFuncs{
|
||||
{&CppLoadModel, "sam3_cpp_load_model"},
|
||||
{&CppEncodeImage, "sam3_cpp_encode_image"},
|
||||
{&CppSegmentPVS, "sam3_cpp_segment_pvs"},
|
||||
{&CppSegmentPCS, "sam3_cpp_segment_pcs"},
|
||||
{&CppGetNDetections, "sam3_cpp_get_n_detections"},
|
||||
{&CppGetDetectionX, "sam3_cpp_get_detection_x"},
|
||||
{&CppGetDetectionY, "sam3_cpp_get_detection_y"},
|
||||
{&CppGetDetectionW, "sam3_cpp_get_detection_w"},
|
||||
{&CppGetDetectionH, "sam3_cpp_get_detection_h"},
|
||||
{&CppGetDetectionScore, "sam3_cpp_get_detection_score"},
|
||||
{&CppGetDetectionMaskPNG, "sam3_cpp_get_detection_mask_png"},
|
||||
{&CppFreeResults, "sam3_cpp_free_results"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
purego.RegisterLibFunc(lf.FuncPtr, gosamLib, lf.Name)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &SAM3{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
59
backend/go/sam3-cpp/package.sh
Executable file
59
backend/go/sam3-cpp/package.sh
Executable file
@@ -0,0 +1,59 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avf $CURDIR/libgosam3-*.so $CURDIR/package/
|
||||
cp -avf $CURDIR/sam3-cpp $CURDIR/package/
|
||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ $(uname -s) = "Darwin" ]; then
|
||||
echo "Detected Darwin"
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
52
backend/go/sam3-cpp/run.sh
Executable file
52
backend/go/sam3-cpp/run.sh
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
fi
|
||||
|
||||
LIBRARY="$CURDIR/libgosam3-fallback.so"
|
||||
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/libgosam3-avx.so ]; then
|
||||
LIBRARY="$CURDIR/libgosam3-avx.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/libgosam3-avx2.so ]; then
|
||||
LIBRARY="$CURDIR/libgosam3-avx2.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/libgosam3-avx512.so ]; then
|
||||
LIBRARY="$CURDIR/libgosam3-avx512.so"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
export SAM3_LIBRARY=$LIBRARY
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/sam3-cpp "$@"
|
||||
fi
|
||||
|
||||
echo "Using library: $LIBRARY"
|
||||
exec $CURDIR/sam3-cpp "$@"
|
||||
50
backend/go/sam3-cpp/test.sh
Executable file
50
backend/go/sam3-cpp/test.sh
Executable file
@@ -0,0 +1,50 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
echo "Running sam3-cpp backend tests..."
|
||||
|
||||
# The test requires a SAM model in GGML format.
|
||||
# Uses EdgeTAM Q4_0 (~15MB) for fast CI testing.
|
||||
SAM3_MODEL_DIR="${SAM3_MODEL_DIR:-$CURDIR/test-models}"
|
||||
SAM3_MODEL_FILE="${SAM3_MODEL_FILE:-edgetam_q4_0.ggml}"
|
||||
SAM3_MODEL_URL="${SAM3_MODEL_URL:-https://huggingface.co/PABannier/sam3.cpp/resolve/main/edgetam_q4_0.ggml}"
|
||||
|
||||
# Download model if not present
|
||||
if [ ! -f "$SAM3_MODEL_DIR/$SAM3_MODEL_FILE" ]; then
|
||||
echo "Downloading EdgeTAM Q4_0 model for testing..."
|
||||
mkdir -p "$SAM3_MODEL_DIR"
|
||||
curl -L -o "$SAM3_MODEL_DIR/$SAM3_MODEL_FILE" "$SAM3_MODEL_URL" --progress-bar
|
||||
echo "Model downloaded."
|
||||
fi
|
||||
|
||||
# Create a test image (4x4 red pixel PNG) using base64
|
||||
# This is a minimal valid PNG for testing the pipeline
|
||||
TEST_IMAGE_DIR="$CURDIR/test-data"
|
||||
mkdir -p "$TEST_IMAGE_DIR"
|
||||
|
||||
# Generate a simple test image using Python if available, otherwise use a pre-encoded one
|
||||
if command -v python3 &> /dev/null; then
|
||||
python3 -c "
|
||||
import struct, zlib, base64
|
||||
def create_png(width, height, r, g, b):
|
||||
raw = b''
|
||||
for y in range(height):
|
||||
raw += b'\x00' # filter byte
|
||||
for x in range(width):
|
||||
raw += bytes([r, g, b])
|
||||
def chunk(ctype, data):
|
||||
c = ctype + data
|
||||
return struct.pack('>I', len(data)) + c + struct.pack('>I', zlib.crc32(c) & 0xffffffff)
|
||||
ihdr = struct.pack('>IIBBBBB', width, height, 8, 2, 0, 0, 0)
|
||||
return b'\x89PNG\r\n\x1a\n' + chunk(b'IHDR', ihdr) + chunk(b'IDAT', zlib.compress(raw)) + chunk(b'IEND', b'')
|
||||
with open('$TEST_IMAGE_DIR/test.png', 'wb') as f:
|
||||
f.write(create_png(64, 64, 255, 0, 0))
|
||||
"
|
||||
echo "Test image created."
|
||||
fi
|
||||
|
||||
echo "sam3-cpp test setup complete."
|
||||
echo "Model: $SAM3_MODEL_DIR/$SAM3_MODEL_FILE"
|
||||
echo "Note: Full integration tests run via the LocalAI test-extra target."
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=d6dd6d7b555c233bb9bc9f20b4751eb8c9269743
|
||||
STABLEDIFFUSION_GGML_VERSION?=e8323cabb0e4511ba18a50b1cb34cf1f87fc71ef
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -27,107 +27,7 @@
|
||||
#include <stdlib.h>
|
||||
#include <regex>
|
||||
|
||||
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
|
||||
const char* sample_method_str[] = {
|
||||
"euler",
|
||||
"euler_a",
|
||||
"heun",
|
||||
"dpm2",
|
||||
"dpm++2s_a",
|
||||
"dpm++2m",
|
||||
"dpm++2mv2",
|
||||
"ipndm",
|
||||
"ipndm_v",
|
||||
"lcm",
|
||||
"ddim_trailing",
|
||||
"tcd",
|
||||
"res_multistep",
|
||||
"res_2s",
|
||||
};
|
||||
|
||||
static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch");
|
||||
|
||||
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
|
||||
const char* schedulers[] = {
|
||||
"discrete",
|
||||
"karras",
|
||||
"exponential",
|
||||
"ays",
|
||||
"gits",
|
||||
"sgm_uniform",
|
||||
"simple",
|
||||
"smoothstep",
|
||||
"kl_optimal",
|
||||
"lcm",
|
||||
"bong_tangent",
|
||||
};
|
||||
|
||||
static_assert(std::size(schedulers) == SCHEDULER_COUNT, "schedulers mismatch");
|
||||
|
||||
// New enum string arrays
|
||||
const char* rng_type_str[] = {
|
||||
"std_default",
|
||||
"cuda",
|
||||
"cpu",
|
||||
};
|
||||
static_assert(std::size(rng_type_str) == RNG_TYPE_COUNT, "rng type mismatch");
|
||||
|
||||
const char* prediction_str[] = {
|
||||
"epsilon",
|
||||
"v",
|
||||
"edm_v",
|
||||
"flow",
|
||||
"flux_flow",
|
||||
"flux2_flow",
|
||||
};
|
||||
static_assert(std::size(prediction_str) == PREDICTION_COUNT, "prediction mismatch");
|
||||
|
||||
const char* lora_apply_mode_str[] = {
|
||||
"auto",
|
||||
"immediately",
|
||||
"at_runtime",
|
||||
};
|
||||
static_assert(std::size(lora_apply_mode_str) == LORA_APPLY_MODE_COUNT, "lora apply mode mismatch");
|
||||
|
||||
constexpr const char* sd_type_str[] = {
|
||||
"f32", // 0
|
||||
"f16", // 1
|
||||
"q4_0", // 2
|
||||
"q4_1", // 3
|
||||
nullptr, // 4
|
||||
nullptr, // 5
|
||||
"q5_0", // 6
|
||||
"q5_1", // 7
|
||||
"q8_0", // 8
|
||||
"q8_1", // 9
|
||||
"q2_k", // 10
|
||||
"q3_k", // 11
|
||||
"q4_k", // 12
|
||||
"q5_k", // 13
|
||||
"q6_k", // 14
|
||||
"q8_k", // 15
|
||||
"iq2_xxs", // 16
|
||||
"iq2_xs", // 17
|
||||
"iq3_xxs", // 18
|
||||
"iq1_s", // 19
|
||||
"iq4_nl", // 20
|
||||
"iq3_s", // 21
|
||||
"iq2_s", // 22
|
||||
"iq4_xs", // 23
|
||||
"i8", // 24
|
||||
"i16", // 25
|
||||
"i32", // 26
|
||||
"i64", // 27
|
||||
"f64", // 28
|
||||
"iq1_m", // 29
|
||||
"bf16", // 30
|
||||
nullptr, nullptr, nullptr, // 31-33
|
||||
"tq1_0", // 34
|
||||
"tq2_0", // 35
|
||||
nullptr, nullptr, nullptr, // 36-38
|
||||
"mxfp4" // 39
|
||||
};
|
||||
static_assert(std::size(sd_type_str) == SD_TYPE_COUNT, "sd type mismatch");
|
||||
|
||||
sd_ctx_params_t ctx_params;
|
||||
sd_ctx_t* sd_c;
|
||||
@@ -596,75 +496,45 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
if (!strcmp(optname, "flow_shift")) flow_shift = atof(optval);
|
||||
|
||||
if (!strcmp(optname, "rng_type")) {
|
||||
int found = -1;
|
||||
for (int m = 0; m < RNG_TYPE_COUNT; m++) {
|
||||
if (!strcmp(optval, rng_type_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
rng_type = (rng_type_t)found;
|
||||
rng_type_t parsed = str_to_rng_type(optval);
|
||||
if (parsed != RNG_TYPE_COUNT) {
|
||||
rng_type = parsed;
|
||||
fprintf(stderr, "Found rng_type: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid rng_type: %s, using default\n", optval);
|
||||
}
|
||||
}
|
||||
if (!strcmp(optname, "sampler_rng_type")) {
|
||||
int found = -1;
|
||||
for (int m = 0; m < RNG_TYPE_COUNT; m++) {
|
||||
if (!strcmp(optval, rng_type_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
sampler_rng_type = (rng_type_t)found;
|
||||
rng_type_t parsed = str_to_rng_type(optval);
|
||||
if (parsed != RNG_TYPE_COUNT) {
|
||||
sampler_rng_type = parsed;
|
||||
fprintf(stderr, "Found sampler_rng_type: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid sampler_rng_type: %s, using default\n", optval);
|
||||
}
|
||||
}
|
||||
if (!strcmp(optname, "prediction")) {
|
||||
int found = -1;
|
||||
for (int m = 0; m < PREDICTION_COUNT; m++) {
|
||||
if (!strcmp(optval, prediction_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
prediction = (prediction_t)found;
|
||||
prediction_t parsed = str_to_prediction(optval);
|
||||
if (parsed != PREDICTION_COUNT) {
|
||||
prediction = parsed;
|
||||
fprintf(stderr, "Found prediction: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid prediction: %s, using default\n", optval);
|
||||
}
|
||||
}
|
||||
if (!strcmp(optname, "lora_apply_mode")) {
|
||||
int found = -1;
|
||||
for (int m = 0; m < LORA_APPLY_MODE_COUNT; m++) {
|
||||
if (!strcmp(optval, lora_apply_mode_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
lora_apply_mode = (lora_apply_mode_t)found;
|
||||
lora_apply_mode_t parsed = str_to_lora_apply_mode(optval);
|
||||
if (parsed != LORA_APPLY_MODE_COUNT) {
|
||||
lora_apply_mode = parsed;
|
||||
fprintf(stderr, "Found lora_apply_mode: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid lora_apply_mode: %s, using default\n", optval);
|
||||
}
|
||||
}
|
||||
if (!strcmp(optname, "wtype")) {
|
||||
int found = -1;
|
||||
for (int m = 0; m < SD_TYPE_COUNT; m++) {
|
||||
if (sd_type_str[m] && !strcmp(optval, sd_type_str[m])) {
|
||||
found = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found != -1) {
|
||||
wtype = (sd_type_t)found;
|
||||
sd_type_t parsed = str_to_sd_type(optval);
|
||||
if (parsed != SD_TYPE_COUNT) {
|
||||
wtype = parsed;
|
||||
fprintf(stderr, "Found wtype: %s\n", optval);
|
||||
} else {
|
||||
fprintf(stderr, "Invalid wtype: %s, using default\n", optval);
|
||||
@@ -735,27 +605,25 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
||||
fprintf (stderr, "Created context: OK\n");
|
||||
|
||||
int sample_method_found = -1;
|
||||
for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) {
|
||||
if (!strcmp(sampler, sample_method_str[m])) {
|
||||
sample_method_found = m;
|
||||
fprintf(stderr, "Found sampler: %s\n", sampler);
|
||||
}
|
||||
sample_method_t sm = str_to_sample_method(sampler);
|
||||
if (sm != SAMPLE_METHOD_COUNT) {
|
||||
sample_method_found = (int)sm;
|
||||
fprintf(stderr, "Found sampler: %s\n", sampler);
|
||||
}
|
||||
if (sample_method_found == -1) {
|
||||
sample_method_found = sd_get_default_sample_method(sd_ctx);
|
||||
fprintf(stderr, "Invalid sample method, using default: %s\n", sample_method_str[sample_method_found]);
|
||||
fprintf(stderr, "Invalid sample method, using default: %s\n", sd_sample_method_name((sample_method_t)sample_method_found));
|
||||
}
|
||||
sample_method = (sample_method_t)sample_method_found;
|
||||
|
||||
for (int d = 0; d < SCHEDULER_COUNT; d++) {
|
||||
if (!strcmp(scheduler_str, schedulers[d])) {
|
||||
scheduler = (scheduler_t)d;
|
||||
fprintf (stderr, "Found scheduler: %s\n", scheduler_str);
|
||||
}
|
||||
scheduler_t sched = str_to_scheduler(scheduler_str);
|
||||
if (sched != SCHEDULER_COUNT) {
|
||||
scheduler = sched;
|
||||
fprintf(stderr, "Found scheduler: %s\n", scheduler_str);
|
||||
}
|
||||
if (scheduler == SCHEDULER_COUNT) {
|
||||
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
|
||||
fprintf(stderr, "Invalid scheduler, using default: %s\n", schedulers[scheduler]);
|
||||
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
|
||||
fprintf(stderr, "Invalid scheduler, using default: %s\n", sd_scheduler_name(scheduler));
|
||||
}
|
||||
|
||||
sd_c = sd_ctx;
|
||||
|
||||
@@ -138,7 +138,7 @@ func TestAudioTranscription(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||
|
||||
// Download sample audio — JFK "ask not what your country can do for you" clip
|
||||
audioFile := filepath.Join(tmpDir, "sample.wav")
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=30c5194c9691e4e9a98b3dea9f19727397d3f46e
|
||||
WHISPER_CPP_VERSION?=95ea8f9bfb03a15db08a8989966fd1ae3361e20d
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -125,6 +125,31 @@
|
||||
nvidia-cuda-13: "cuda13-rfdetr"
|
||||
nvidia-cuda-12: "cuda12-rfdetr"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-rfdetr"
|
||||
- &sam3cpp
|
||||
name: "sam3-cpp"
|
||||
alias: "sam3-cpp"
|
||||
license: mit
|
||||
description: |
|
||||
Segment Anything Model (SAM 3/2/EdgeTAM) in C/C++ using GGML.
|
||||
Supports text-prompted and point/box-prompted image segmentation.
|
||||
urls:
|
||||
- https://github.com/PABannier/sam3.cpp
|
||||
tags:
|
||||
- image-segmentation
|
||||
- object-detection
|
||||
- sam3
|
||||
- gpu
|
||||
- cpu
|
||||
capabilities:
|
||||
default: "cpu-sam3-cpp"
|
||||
nvidia: "cuda12-sam3-cpp"
|
||||
nvidia-cuda-12: "cuda12-sam3-cpp"
|
||||
nvidia-cuda-13: "cuda13-sam3-cpp"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-sam3-cpp"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-sam3-cpp"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-sam3-cpp"
|
||||
intel: "intel-sycl-f32-sam3-cpp"
|
||||
vulkan: "vulkan-sam3-cpp"
|
||||
- &vllm
|
||||
name: "vllm"
|
||||
license: apache-2.0
|
||||
@@ -400,12 +425,15 @@
|
||||
license: MIT
|
||||
name: "faster-whisper"
|
||||
capabilities:
|
||||
default: "cpu-faster-whisper"
|
||||
nvidia: "cuda12-faster-whisper"
|
||||
intel: "intel-faster-whisper"
|
||||
amd: "rocm-faster-whisper"
|
||||
metal: "metal-faster-whisper"
|
||||
nvidia-cuda-13: "cuda13-faster-whisper"
|
||||
nvidia-cuda-12: "cuda12-faster-whisper"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-faster-whisper"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-faster-whisper"
|
||||
- &moonshine
|
||||
description: |
|
||||
Moonshine is a fast, accurate, and efficient speech-to-text transcription model using ONNX Runtime.
|
||||
@@ -438,6 +466,7 @@
|
||||
- whisperx
|
||||
license: BSD-4-Clause
|
||||
name: "whisperx"
|
||||
alias: "whisperx"
|
||||
capabilities:
|
||||
nvidia: "cuda12-whisperx"
|
||||
amd: "rocm-whisperx"
|
||||
@@ -445,6 +474,8 @@
|
||||
default: "cpu-whisperx"
|
||||
nvidia-cuda-13: "cuda13-whisperx"
|
||||
nvidia-cuda-12: "cuda12-whisperx"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-whisperx"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-whisperx"
|
||||
- &kokoro
|
||||
icon: https://avatars.githubusercontent.com/u/166769057?v=4
|
||||
description: |
|
||||
@@ -468,6 +499,26 @@
|
||||
nvidia-cuda-13: "cuda13-kokoro"
|
||||
nvidia-cuda-12: "cuda12-kokoro"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-kokoro"
|
||||
- &kokoros
|
||||
icon: https://avatars.githubusercontent.com/u/166769057?v=4
|
||||
description: |
|
||||
Kokoros is a pure Rust TTS backend using the Kokoro ONNX model (82M parameters).
|
||||
It provides fast, high-quality text-to-speech with streaming support, built on
|
||||
ONNX Runtime for efficient CPU inference. Supports English, Japanese, Mandarin
|
||||
Chinese, and German.
|
||||
urls:
|
||||
- https://huggingface.co/hexgrad/Kokoro-82M
|
||||
- https://github.com/lucasjinreal/Kokoros
|
||||
tags:
|
||||
- text-to-speech
|
||||
- TTS
|
||||
- Rust
|
||||
- ONNX
|
||||
license: apache-2.0
|
||||
alias: "kokoros"
|
||||
name: "kokoros"
|
||||
capabilities:
|
||||
default: "cpu-kokoros"
|
||||
- &coqui
|
||||
urls:
|
||||
- https://github.com/idiap/coqui-ai-TTS
|
||||
@@ -726,6 +777,7 @@
|
||||
- TTS
|
||||
- &opus
|
||||
name: "opus"
|
||||
alias: "opus"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-opus"
|
||||
urls:
|
||||
- https://opus-codec.org/
|
||||
@@ -1601,6 +1653,89 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-rfdetr"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-rfdetr
|
||||
## sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "sam3-cpp-development"
|
||||
capabilities:
|
||||
default: "cpu-sam3-cpp-development"
|
||||
nvidia: "cuda12-sam3-cpp-development"
|
||||
nvidia-cuda-12: "cuda12-sam3-cpp-development"
|
||||
nvidia-cuda-13: "cuda13-sam3-cpp-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-sam3-cpp-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-sam3-cpp-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-sam3-cpp-development"
|
||||
intel: "intel-sycl-f32-sam3-cpp-development"
|
||||
vulkan: "vulkan-sam3-cpp-development"
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cpu-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cpu-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cuda12-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cuda12-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cuda13-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cuda13-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "nvidia-l4t-arm64-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "nvidia-l4t-arm64-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cuda13-nvidia-l4t-arm64-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "cuda13-nvidia-l4t-arm64-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "intel-sycl-f32-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "intel-sycl-f32-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "vulkan-sam3-cpp"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-sam3-cpp
|
||||
- !!merge <<: *sam3cpp
|
||||
name: "vulkan-sam3-cpp-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-sam3-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-sam3-cpp
|
||||
## Rerankers
|
||||
- !!merge <<: *rerankers
|
||||
name: "rerankers-development"
|
||||
@@ -2041,15 +2176,32 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-kokoro"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-kokoro
|
||||
## kokoros (Rust)
|
||||
- !!merge <<: *kokoros
|
||||
name: "kokoros-development"
|
||||
capabilities:
|
||||
default: "cpu-kokoros-development"
|
||||
- !!merge <<: *kokoros
|
||||
name: "cpu-kokoros"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-kokoros"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-kokoros
|
||||
- !!merge <<: *kokoros
|
||||
name: "cpu-kokoros-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-kokoros"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-kokoros
|
||||
## faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "faster-whisper-development"
|
||||
capabilities:
|
||||
default: "cpu-faster-whisper-development"
|
||||
nvidia: "cuda12-faster-whisper-development"
|
||||
intel: "intel-faster-whisper-development"
|
||||
amd: "rocm-faster-whisper-development"
|
||||
metal: "metal-faster-whisper-development"
|
||||
nvidia-cuda-13: "cuda13-faster-whisper-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-faster-whisper-development"
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "cuda12-faster-whisper-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-faster-whisper"
|
||||
@@ -2090,6 +2242,36 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "cuda12-faster-whisper"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "rocm-faster-whisper"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "cpu-faster-whisper"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "cpu-faster-whisper-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "nvidia-l4t-arm64-faster-whisper"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-faster-whisper
|
||||
- !!merge <<: *faster-whisper
|
||||
name: "nvidia-l4t-arm64-faster-whisper-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-faster-whisper"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-faster-whisper
|
||||
## moonshine
|
||||
- !!merge <<: *moonshine
|
||||
name: "moonshine-development"
|
||||
@@ -2148,6 +2330,7 @@
|
||||
default: "cpu-whisperx-development"
|
||||
nvidia-cuda-13: "cuda13-whisperx-development"
|
||||
nvidia-cuda-12: "cuda12-whisperx-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-whisperx-development"
|
||||
- !!merge <<: *whisperx
|
||||
name: "cpu-whisperx"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-whisperx"
|
||||
@@ -2198,6 +2381,16 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-whisperx"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-whisperx
|
||||
- !!merge <<: *whisperx
|
||||
name: "nvidia-l4t-arm64-whisperx"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-whisperx"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-whisperx
|
||||
- !!merge <<: *whisperx
|
||||
name: "nvidia-l4t-arm64-whisperx-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-whisperx"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-whisperx
|
||||
## coqui
|
||||
|
||||
- !!merge <<: *coqui
|
||||
@@ -3029,3 +3222,82 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-voxtral
|
||||
- &trl
|
||||
name: "trl"
|
||||
alias: "trl"
|
||||
license: apache-2.0
|
||||
description: |
|
||||
HuggingFace TRL fine-tuning backend. Supports SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO training methods.
|
||||
Works on CPU and GPU.
|
||||
urls:
|
||||
- https://github.com/huggingface/trl
|
||||
tags:
|
||||
- fine-tuning
|
||||
- LLM
|
||||
- CPU
|
||||
- GPU
|
||||
- CUDA
|
||||
capabilities:
|
||||
default: "cpu-trl"
|
||||
nvidia: "cuda12-trl"
|
||||
nvidia-cuda-12: "cuda12-trl"
|
||||
nvidia-cuda-13: "cuda13-trl"
|
||||
## TRL backend images
|
||||
- !!merge <<: *trl
|
||||
name: "cpu-trl"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-trl
|
||||
- !!merge <<: *trl
|
||||
name: "cpu-trl-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-trl
|
||||
- !!merge <<: *trl
|
||||
name: "cuda12-trl"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda12-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cublas-cuda12-trl
|
||||
- !!merge <<: *trl
|
||||
name: "cuda12-trl-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda12-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cublas-cuda12-trl
|
||||
- !!merge <<: *trl
|
||||
name: "cuda13-trl"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda13-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cublas-cuda13-trl
|
||||
- !!merge <<: *trl
|
||||
name: "cuda13-trl-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda13-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cublas-cuda13-trl
|
||||
## llama.cpp quantization backend
|
||||
- &llama-cpp-quantization
|
||||
name: "llama-cpp-quantization"
|
||||
alias: "llama-cpp-quantization"
|
||||
license: mit
|
||||
icon: https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png
|
||||
description: |
|
||||
Model quantization backend using llama.cpp. Downloads HuggingFace models, converts them to GGUF format,
|
||||
and quantizes them to various formats (q4_k_m, q5_k_m, q8_0, f16, etc.).
|
||||
urls:
|
||||
- https://github.com/ggml-org/llama.cpp
|
||||
tags:
|
||||
- quantization
|
||||
- GGUF
|
||||
- CPU
|
||||
capabilities:
|
||||
default: "cpu-llama-cpp-quantization"
|
||||
metal: "metal-darwin-arm64-llama-cpp-quantization"
|
||||
- !!merge <<: *llama-cpp-quantization
|
||||
name: "cpu-llama-cpp-quantization"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-llama-cpp-quantization"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-llama-cpp-quantization
|
||||
- !!merge <<: *llama-cpp-quantization
|
||||
name: "metal-darwin-arm64-llama-cpp-quantization"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp-quantization"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-llama-cpp-quantization
|
||||
|
||||
@@ -19,6 +19,10 @@ import tempfile
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from acestep.inference import (
|
||||
GenerationParams,
|
||||
GenerationConfig,
|
||||
@@ -444,6 +448,8 @@ def serve(address):
|
||||
("grpc.max_send_message_length", 50 * 1024 * 1024),
|
||||
("grpc.max_receive_message_length", 50 * 1024 * 1024),
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -16,6 +16,10 @@ import torchaudio as ta
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import tempfile
|
||||
|
||||
def is_float(s):
|
||||
@@ -225,7 +229,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
78
backend/python/common/grpc_auth.py
Normal file
78
backend/python/common/grpc_auth.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Shared gRPC bearer token authentication interceptor for LocalAI Python backends.
|
||||
|
||||
When the environment variable LOCALAI_GRPC_AUTH_TOKEN is set, requests without
|
||||
a valid Bearer token in the 'authorization' metadata header are rejected with
|
||||
UNAUTHENTICATED. When the variable is empty or unset, no authentication is
|
||||
performed (backward compatible).
|
||||
"""
|
||||
|
||||
import hmac
|
||||
import os
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
class _AbortHandler(grpc.RpcMethodHandler):
|
||||
"""A method handler that immediately aborts with UNAUTHENTICATED."""
|
||||
|
||||
def __init__(self):
|
||||
self.request_streaming = False
|
||||
self.response_streaming = False
|
||||
self.request_deserializer = None
|
||||
self.response_serializer = None
|
||||
self.unary_unary = self._abort
|
||||
self.unary_stream = None
|
||||
self.stream_unary = None
|
||||
self.stream_stream = None
|
||||
|
||||
@staticmethod
|
||||
def _abort(request, context):
|
||||
context.abort(grpc.StatusCode.UNAUTHENTICATED, "invalid token")
|
||||
|
||||
|
||||
class TokenAuthInterceptor(grpc.ServerInterceptor):
|
||||
"""Sync gRPC server interceptor that validates a bearer token."""
|
||||
|
||||
def __init__(self, token: str):
|
||||
self._token = token
|
||||
self._abort_handler = _AbortHandler()
|
||||
|
||||
def intercept_service(self, continuation, handler_call_details):
|
||||
metadata = dict(handler_call_details.invocation_metadata)
|
||||
auth = metadata.get("authorization", "")
|
||||
expected = "Bearer " + self._token
|
||||
if not hmac.compare_digest(auth, expected):
|
||||
return self._abort_handler
|
||||
return continuation(handler_call_details)
|
||||
|
||||
|
||||
class AsyncTokenAuthInterceptor(grpc.aio.ServerInterceptor):
|
||||
"""Async gRPC server interceptor that validates a bearer token."""
|
||||
|
||||
def __init__(self, token: str):
|
||||
self._token = token
|
||||
|
||||
async def intercept_service(self, continuation, handler_call_details):
|
||||
metadata = dict(handler_call_details.invocation_metadata)
|
||||
auth = metadata.get("authorization", "")
|
||||
expected = "Bearer " + self._token
|
||||
if not hmac.compare_digest(auth, expected):
|
||||
return _AbortHandler()
|
||||
return await continuation(handler_call_details)
|
||||
|
||||
|
||||
def get_auth_interceptors(*, aio: bool = False):
|
||||
"""Return a list of gRPC interceptors for bearer token auth.
|
||||
|
||||
Args:
|
||||
aio: If True, return async-compatible interceptors for grpc.aio.server().
|
||||
If False (default), return sync interceptors for grpc.server().
|
||||
|
||||
Returns an empty list when LOCALAI_GRPC_AUTH_TOKEN is not set.
|
||||
"""
|
||||
token = os.environ.get("LOCALAI_GRPC_AUTH_TOKEN", "")
|
||||
if not token:
|
||||
return []
|
||||
if aio:
|
||||
return [AsyncTokenAuthInterceptor(token)]
|
||||
return [TokenAuthInterceptor(token)]
|
||||
@@ -1,3 +1,3 @@
|
||||
grpcio==1.78.1
|
||||
grpcio==1.80.0
|
||||
protobuf
|
||||
grpcio-tools
|
||||
@@ -15,6 +15,10 @@ import torch
|
||||
from TTS.api import TTS
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -93,7 +97,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
transformers==4.48.3
|
||||
accelerate
|
||||
torch==2.4.1
|
||||
torchaudio==2.4.1
|
||||
coqui-tts
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.78.1
|
||||
grpcio==1.80.0
|
||||
protobuf
|
||||
certifi
|
||||
packaging==24.1
|
||||
@@ -22,6 +22,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
# Import dynamic loader for pipeline discovery
|
||||
from diffusers_dynamic_loader import (
|
||||
@@ -1042,7 +1046,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -15,6 +15,10 @@ import torch
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -165,6 +169,8 @@ def serve(address):
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
]
|
||||
,
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -14,6 +14,10 @@ import torch
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -70,7 +74,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -16,4 +16,14 @@ if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
|
||||
PYTHON_VERSION="3.12"
|
||||
PYTHON_PATCH="12"
|
||||
PY_STANDALONE_TAG="20251120"
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
USE_PIP=true
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
3
backend/python/faster-whisper/requirements-l4t12.txt
Normal file
3
backend/python/faster-whisper/requirements-l4t12.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
|
||||
torch
|
||||
faster-whisper
|
||||
3
backend/python/faster-whisper/requirements-l4t13.txt
Normal file
3
backend/python/faster-whisper/requirements-l4t13.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch
|
||||
faster-whisper
|
||||
@@ -19,6 +19,10 @@ import numpy as np
|
||||
import json
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -424,6 +428,8 @@ def serve(address):
|
||||
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
||||
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -16,6 +16,10 @@ from kittentts import KittenTTS
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -77,7 +81,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -16,6 +16,10 @@ from kokoro import KPipeline
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -84,7 +88,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -21,3 +21,8 @@ if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
# spaCy is a dependency of misaki (used by kokoro for English phonemization).
|
||||
# Pre-download the model here because at runtime the portable Python environment
|
||||
# has no pip/uv, so spacy's auto-download would fail.
|
||||
python -m spacy download en_core_web_sm
|
||||
|
||||
26
backend/python/llama-cpp-quantization/Makefile
Normal file
26
backend/python/llama-cpp-quantization/Makefile
Normal file
@@ -0,0 +1,26 @@
|
||||
# Version of llama.cpp to fetch convert_hf_to_gguf.py from
|
||||
LLAMA_CPP_CONVERT_VERSION ?= master
|
||||
|
||||
.PHONY: llama-cpp-quantization
|
||||
llama-cpp-quantization:
|
||||
LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: llama-cpp-quantization
|
||||
@echo "Running llama-cpp-quantization..."
|
||||
bash run.sh
|
||||
@echo "llama-cpp-quantization run."
|
||||
|
||||
.PHONY: test
|
||||
test: llama-cpp-quantization
|
||||
@echo "Testing llama-cpp-quantization..."
|
||||
bash test.sh
|
||||
@echo "llama-cpp-quantization tested."
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
426
backend/python/llama-cpp-quantization/backend.py
Normal file
426
backend/python/llama-cpp-quantization/backend.py
Normal file
@@ -0,0 +1,426 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
llama.cpp quantization backend for LocalAI.
|
||||
|
||||
Downloads HuggingFace models, converts them to GGUF format using
|
||||
convert_hf_to_gguf.py, and quantizes using llama-quantize.
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
|
||||
|
||||
|
||||
class ActiveJob:
|
||||
"""Tracks a running quantization job."""
|
||||
def __init__(self, job_id):
|
||||
self.job_id = job_id
|
||||
self.progress_queue = queue.Queue()
|
||||
self.stop_event = threading.Event()
|
||||
self.thread = None
|
||||
self.process = None # subprocess handle for killing
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def __init__(self):
|
||||
self.jobs = {} # job_id -> ActiveJob
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=b"OK")
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""Accept LoadModel — actual work happens in StartQuantization."""
|
||||
return backend_pb2.Result(success=True, message="OK")
|
||||
|
||||
def StartQuantization(self, request, context):
|
||||
job_id = request.job_id
|
||||
if job_id in self.jobs:
|
||||
return backend_pb2.QuantizationJobResult(
|
||||
job_id=job_id,
|
||||
success=False,
|
||||
message=f"Job {job_id} already exists",
|
||||
)
|
||||
|
||||
job = ActiveJob(job_id)
|
||||
self.jobs[job_id] = job
|
||||
|
||||
job.thread = threading.Thread(
|
||||
target=self._do_quantization,
|
||||
args=(job, request),
|
||||
daemon=True,
|
||||
)
|
||||
job.thread.start()
|
||||
|
||||
return backend_pb2.QuantizationJobResult(
|
||||
job_id=job_id,
|
||||
success=True,
|
||||
message="Quantization job started",
|
||||
)
|
||||
|
||||
def _send_progress(self, job, status, message, progress_percent=0.0, output_file="", extra_metrics=None):
|
||||
update = backend_pb2.QuantizationProgressUpdate(
|
||||
job_id=job.job_id,
|
||||
progress_percent=progress_percent,
|
||||
status=status,
|
||||
message=message,
|
||||
output_file=output_file,
|
||||
extra_metrics=extra_metrics or {},
|
||||
)
|
||||
job.progress_queue.put(update)
|
||||
|
||||
def _do_quantization(self, job, request):
|
||||
try:
|
||||
model = request.model
|
||||
quant_type = request.quantization_type or "q4_k_m"
|
||||
output_dir = request.output_dir
|
||||
extra_options = dict(request.extra_options) if request.extra_options else {}
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if job.stop_event.is_set():
|
||||
self._send_progress(job, "stopped", "Job stopped before starting")
|
||||
return
|
||||
|
||||
# Step 1: Download / resolve model
|
||||
self._send_progress(job, "downloading", f"Resolving model: {model}", progress_percent=0.0)
|
||||
|
||||
model_path = self._resolve_model(job, model, output_dir, extra_options)
|
||||
if model_path is None:
|
||||
return # error already sent
|
||||
|
||||
if job.stop_event.is_set():
|
||||
self._send_progress(job, "stopped", "Job stopped during download")
|
||||
return
|
||||
|
||||
# Step 2: Convert to f16 GGUF
|
||||
self._send_progress(job, "converting", "Converting model to GGUF (f16)...", progress_percent=30.0)
|
||||
|
||||
f16_gguf_path = os.path.join(output_dir, "model-f16.gguf")
|
||||
if not self._convert_to_gguf(job, model_path, f16_gguf_path, extra_options):
|
||||
return # error already sent
|
||||
|
||||
if job.stop_event.is_set():
|
||||
self._send_progress(job, "stopped", "Job stopped during conversion")
|
||||
return
|
||||
|
||||
# Step 3: Quantize
|
||||
# If the user requested f16, skip quantization — the f16 GGUF is the final output
|
||||
if quant_type.lower() in ("f16", "fp16"):
|
||||
output_file = f16_gguf_path
|
||||
self._send_progress(
|
||||
job, "completed",
|
||||
f"Model converted to f16 GGUF: {output_file}",
|
||||
progress_percent=100.0,
|
||||
output_file=output_file,
|
||||
extra_metrics=self._file_metrics(output_file),
|
||||
)
|
||||
return
|
||||
|
||||
output_file = os.path.join(output_dir, f"model-{quant_type}.gguf")
|
||||
self._send_progress(job, "quantizing", f"Quantizing to {quant_type}...", progress_percent=50.0)
|
||||
|
||||
if not self._quantize(job, f16_gguf_path, output_file, quant_type):
|
||||
return # error already sent
|
||||
|
||||
# Clean up f16 intermediate file to save disk space
|
||||
try:
|
||||
os.remove(f16_gguf_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
self._send_progress(
|
||||
job, "completed",
|
||||
f"Quantization complete: {quant_type}",
|
||||
progress_percent=100.0,
|
||||
output_file=output_file,
|
||||
extra_metrics=self._file_metrics(output_file),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._send_progress(job, "failed", f"Quantization failed: {str(e)}")
|
||||
|
||||
def _resolve_model(self, job, model, output_dir, extra_options):
|
||||
"""Download model from HuggingFace or return local path."""
|
||||
# If it's a local path that exists, use it directly
|
||||
if os.path.isdir(model):
|
||||
return model
|
||||
|
||||
# If it looks like a GGUF file path, use it directly
|
||||
if os.path.isfile(model) and model.endswith(".gguf"):
|
||||
return model
|
||||
|
||||
# Download from HuggingFace
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
hf_token = extra_options.get("hf_token") or os.environ.get("HF_TOKEN")
|
||||
cache_dir = os.path.join(output_dir, "hf_cache")
|
||||
|
||||
self._send_progress(job, "downloading", f"Downloading {model} from HuggingFace...", progress_percent=5.0)
|
||||
|
||||
local_path = snapshot_download(
|
||||
repo_id=model,
|
||||
cache_dir=cache_dir,
|
||||
token=hf_token,
|
||||
ignore_patterns=["*.md", "*.txt", "LICENSE*", ".gitattributes"],
|
||||
)
|
||||
|
||||
self._send_progress(job, "downloading", f"Downloaded {model}", progress_percent=25.0)
|
||||
return local_path
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if "gated" in error_msg.lower() or "access" in error_msg.lower():
|
||||
self._send_progress(
|
||||
job, "failed",
|
||||
f"Access denied for {model}. This model may be gated — "
|
||||
f"please accept the license at https://huggingface.co/{model} "
|
||||
f"and provide your HF token in extra_options.",
|
||||
)
|
||||
else:
|
||||
self._send_progress(job, "failed", f"Failed to download model: {error_msg}")
|
||||
return None
|
||||
|
||||
def _convert_to_gguf(self, job, model_path, output_path, extra_options):
|
||||
"""Convert HF model to f16 GGUF using convert_hf_to_gguf.py."""
|
||||
# If the model_path is already a GGUF file, just use it as-is
|
||||
if isinstance(model_path, str) and model_path.endswith(".gguf"):
|
||||
# Copy or symlink the GGUF file
|
||||
import shutil
|
||||
shutil.copy2(model_path, output_path)
|
||||
return True
|
||||
|
||||
# Find convert_hf_to_gguf.py
|
||||
convert_script = self._find_convert_script()
|
||||
if convert_script is None:
|
||||
self._send_progress(job, "failed", "convert_hf_to_gguf.py not found. Install it via the backend's install.sh.")
|
||||
return False
|
||||
|
||||
cmd = [
|
||||
sys.executable, convert_script,
|
||||
model_path,
|
||||
"--outfile", output_path,
|
||||
"--outtype", "f16",
|
||||
]
|
||||
|
||||
self._send_progress(job, "converting", "Running convert_hf_to_gguf.py...", progress_percent=35.0)
|
||||
|
||||
try:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
job.process = process
|
||||
|
||||
for line in process.stdout:
|
||||
line = line.strip()
|
||||
if line:
|
||||
self._send_progress(job, "converting", line, progress_percent=40.0)
|
||||
if job.stop_event.is_set():
|
||||
process.kill()
|
||||
self._send_progress(job, "stopped", "Job stopped during conversion")
|
||||
return False
|
||||
|
||||
process.wait()
|
||||
job.process = None
|
||||
|
||||
if process.returncode != 0:
|
||||
self._send_progress(job, "failed", f"convert_hf_to_gguf.py failed with exit code {process.returncode}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._send_progress(job, "failed", f"Conversion failed: {str(e)}")
|
||||
return False
|
||||
|
||||
def _quantize(self, job, input_path, output_path, quant_type):
|
||||
"""Quantize a GGUF file using llama-quantize."""
|
||||
quantize_bin = self._find_quantize_binary()
|
||||
if quantize_bin is None:
|
||||
self._send_progress(job, "failed", "llama-quantize binary not found. Ensure it is installed and in PATH.")
|
||||
return False
|
||||
|
||||
cmd = [quantize_bin, input_path, output_path, quant_type]
|
||||
|
||||
self._send_progress(job, "quantizing", f"Running llama-quantize ({quant_type})...", progress_percent=55.0)
|
||||
|
||||
try:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
job.process = process
|
||||
|
||||
for line in process.stdout:
|
||||
line = line.strip()
|
||||
if line:
|
||||
# Try to parse progress from llama-quantize output
|
||||
progress = self._parse_quantize_progress(line)
|
||||
pct = 55.0 + (progress * 0.40) if progress else 60.0
|
||||
self._send_progress(job, "quantizing", line, progress_percent=pct)
|
||||
if job.stop_event.is_set():
|
||||
process.kill()
|
||||
self._send_progress(job, "stopped", "Job stopped during quantization")
|
||||
return False
|
||||
|
||||
process.wait()
|
||||
job.process = None
|
||||
|
||||
if process.returncode != 0:
|
||||
self._send_progress(job, "failed", f"llama-quantize failed with exit code {process.returncode}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._send_progress(job, "failed", f"Quantization failed: {str(e)}")
|
||||
return False
|
||||
|
||||
def _parse_quantize_progress(self, line):
|
||||
"""Try to parse a progress percentage from llama-quantize output."""
|
||||
# llama-quantize typically outputs lines like:
|
||||
# [ 123/ 1234] quantizing blk.0.attn_k.weight ...
|
||||
match = re.search(r'\[\s*(\d+)\s*/\s*(\d+)\]', line)
|
||||
if match:
|
||||
current = int(match.group(1))
|
||||
total = int(match.group(2))
|
||||
if total > 0:
|
||||
return current / total
|
||||
return None
|
||||
|
||||
def _find_convert_script(self):
|
||||
"""Find convert_hf_to_gguf.py in known locations."""
|
||||
candidates = [
|
||||
# Same directory as this backend
|
||||
os.path.join(os.path.dirname(__file__), "convert_hf_to_gguf.py"),
|
||||
# Installed via install.sh
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "convert_hf_to_gguf.py"),
|
||||
]
|
||||
|
||||
# Also check if it's on PATH
|
||||
import shutil
|
||||
path_script = shutil.which("convert_hf_to_gguf.py")
|
||||
if path_script:
|
||||
candidates.append(path_script)
|
||||
|
||||
for candidate in candidates:
|
||||
if os.path.isfile(candidate):
|
||||
return candidate
|
||||
return None
|
||||
|
||||
def _find_quantize_binary(self):
|
||||
"""Find llama-quantize binary."""
|
||||
import shutil
|
||||
|
||||
# Check common names on PATH
|
||||
for name in ["llama-quantize", "quantize"]:
|
||||
path = shutil.which(name)
|
||||
if path:
|
||||
return path
|
||||
|
||||
# Check in the backend directory (built by install.sh)
|
||||
backend_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
for name in ["llama-quantize", "quantize"]:
|
||||
candidate = os.path.join(backend_dir, name)
|
||||
if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
def _file_metrics(self, filepath):
|
||||
"""Return file size metrics."""
|
||||
try:
|
||||
size_bytes = os.path.getsize(filepath)
|
||||
return {"file_size_mb": size_bytes / (1024 * 1024)}
|
||||
except OSError:
|
||||
return {}
|
||||
|
||||
def QuantizationProgress(self, request, context):
|
||||
job_id = request.job_id
|
||||
job = self.jobs.get(job_id)
|
||||
if job is None:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Job {job_id} not found")
|
||||
return
|
||||
|
||||
while True:
|
||||
try:
|
||||
update = job.progress_queue.get(timeout=1.0)
|
||||
yield update
|
||||
# If this is a terminal status, stop streaming
|
||||
if update.status in ("completed", "failed", "stopped"):
|
||||
break
|
||||
except queue.Empty:
|
||||
# Check if the thread is still alive
|
||||
if job.thread and not job.thread.is_alive():
|
||||
# Thread finished but no terminal update — drain queue
|
||||
while not job.progress_queue.empty():
|
||||
update = job.progress_queue.get_nowait()
|
||||
yield update
|
||||
break
|
||||
# Check if client disconnected
|
||||
if context.is_active() is False:
|
||||
break
|
||||
|
||||
def StopQuantization(self, request, context):
|
||||
job_id = request.job_id
|
||||
job = self.jobs.get(job_id)
|
||||
if job is None:
|
||||
return backend_pb2.Result(success=False, message=f"Job {job_id} not found")
|
||||
|
||||
job.stop_event.set()
|
||||
if job.process:
|
||||
try:
|
||||
job.process.kill()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return backend_pb2.Result(success=True, message="Stop signal sent")
|
||||
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print(f"Quantization backend listening on {address}", file=sys.stderr, flush=True)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="llama.cpp quantization gRPC backend")
|
||||
parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
|
||||
args = parser.parse_args()
|
||||
|
||||
signal.signal(signal.SIGINT, lambda sig, frame: sys.exit(0))
|
||||
serve(args.addr)
|
||||
58
backend/python/llama-cpp-quantization/install.sh
Executable file
58
backend/python/llama-cpp-quantization/install.sh
Executable file
@@ -0,0 +1,58 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade "
|
||||
installRequirements
|
||||
|
||||
# Fetch convert_hf_to_gguf.py from llama.cpp
|
||||
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
|
||||
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
|
||||
if [ ! -f "${CONVERT_SCRIPT}" ]; then
|
||||
echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||
curl -L --fail --retry 3 \
|
||||
"https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
|
||||
-o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py."
|
||||
fi
|
||||
|
||||
# Install gguf package from the same llama.cpp commit to keep them in sync
|
||||
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
|
||||
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||
if [ "x${USE_PIP:-}" == "xtrue" ]; then
|
||||
pip install "${GGUF_PIP_SPEC}" || {
|
||||
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
|
||||
pip install "gguf>=0.16.0"
|
||||
}
|
||||
else
|
||||
uv pip install "${GGUF_PIP_SPEC}" || {
|
||||
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
|
||||
uv pip install "gguf>=0.16.0"
|
||||
}
|
||||
fi
|
||||
|
||||
# Build llama-quantize from llama.cpp if not already present
|
||||
QUANTIZE_BIN="${EDIR}/llama-quantize"
|
||||
if [ ! -x "${QUANTIZE_BIN}" ] && ! command -v llama-quantize &>/dev/null; then
|
||||
if command -v cmake &>/dev/null; then
|
||||
echo "Building llama-quantize from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||
LLAMA_CPP_SRC="${EDIR}/llama.cpp"
|
||||
if [ ! -d "${LLAMA_CPP_SRC}" ]; then
|
||||
git clone --depth 1 --branch "${LLAMA_CPP_CONVERT_VERSION}" \
|
||||
https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}" 2>/dev/null || \
|
||||
git clone --depth 1 https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}"
|
||||
fi
|
||||
cmake -B "${LLAMA_CPP_SRC}/build" -S "${LLAMA_CPP_SRC}" -DGGML_NATIVE=OFF -DBUILD_SHARED_LIBS=OFF
|
||||
cmake --build "${LLAMA_CPP_SRC}/build" --target llama-quantize -j"$(nproc 2>/dev/null || echo 2)"
|
||||
cp "${LLAMA_CPP_SRC}/build/bin/llama-quantize" "${QUANTIZE_BIN}"
|
||||
chmod +x "${QUANTIZE_BIN}"
|
||||
echo "Built llama-quantize at ${QUANTIZE_BIN}"
|
||||
else
|
||||
echo "Warning: cmake not found — llama-quantize will not be available. Install cmake or provide llama-quantize on PATH."
|
||||
fi
|
||||
fi
|
||||
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch==2.10.0
|
||||
transformers>=4.56.2
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
@@ -0,0 +1,4 @@
|
||||
torch==2.10.0
|
||||
transformers>=4.56.2
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
3
backend/python/llama-cpp-quantization/requirements.txt
Normal file
3
backend/python/llama-cpp-quantization/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
certifi
|
||||
10
backend/python/llama-cpp-quantization/run.sh
Executable file
10
backend/python/llama-cpp-quantization/run.sh
Executable file
@@ -0,0 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
99
backend/python/llama-cpp-quantization/test.py
Normal file
99
backend/python/llama-cpp-quantization/test.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Test script for the llama-cpp-quantization gRPC backend.
|
||||
|
||||
Downloads a small model (functiongemma-270m-it), converts it to GGUF,
|
||||
and quantizes it to q4_k_m.
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
|
||||
SERVER_ADDR = "localhost:50051"
|
||||
# Small model for CI testing (~540MB)
|
||||
TEST_MODEL = "unsloth/functiongemma-270m-it"
|
||||
|
||||
|
||||
class TestQuantizationBackend(unittest.TestCase):
|
||||
"""Tests for the llama-cpp-quantization gRPC service."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.service = subprocess.Popen(
|
||||
["python3", "backend.py", "--addr", SERVER_ADDR]
|
||||
)
|
||||
time.sleep(5)
|
||||
cls.output_dir = tempfile.mkdtemp(prefix="quantize-test-")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.service.kill()
|
||||
cls.service.wait()
|
||||
# Clean up output directory
|
||||
if os.path.isdir(cls.output_dir):
|
||||
shutil.rmtree(cls.output_dir, ignore_errors=True)
|
||||
|
||||
def _channel(self):
|
||||
return grpc.insecure_channel(SERVER_ADDR)
|
||||
|
||||
def test_01_health(self):
|
||||
"""Test that the server starts and responds to health checks."""
|
||||
with self._channel() as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.Health(backend_pb2.HealthMessage())
|
||||
self.assertEqual(response.message, b"OK")
|
||||
|
||||
def test_02_quantize_small_model(self):
|
||||
"""Download, convert, and quantize functiongemma-270m-it to q4_k_m."""
|
||||
with self._channel() as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
|
||||
job_id = "test-quantize-001"
|
||||
|
||||
# Start quantization
|
||||
result = stub.StartQuantization(
|
||||
backend_pb2.QuantizationRequest(
|
||||
model=TEST_MODEL,
|
||||
quantization_type="q4_k_m",
|
||||
output_dir=self.output_dir,
|
||||
job_id=job_id,
|
||||
)
|
||||
)
|
||||
self.assertTrue(result.success, f"StartQuantization failed: {result.message}")
|
||||
self.assertEqual(result.job_id, job_id)
|
||||
|
||||
# Stream progress until completion
|
||||
final_status = None
|
||||
output_file = None
|
||||
for update in stub.QuantizationProgress(
|
||||
backend_pb2.QuantizationProgressRequest(job_id=job_id)
|
||||
):
|
||||
print(f" [{update.status}] {update.progress_percent:.1f}% - {update.message}")
|
||||
final_status = update.status
|
||||
if update.output_file:
|
||||
output_file = update.output_file
|
||||
|
||||
self.assertEqual(final_status, "completed", f"Expected completed, got {final_status}")
|
||||
self.assertIsNotNone(output_file, "No output_file in progress updates")
|
||||
self.assertTrue(os.path.isfile(output_file), f"Output file not found: {output_file}")
|
||||
|
||||
# Verify the output is a valid GGUF file (starts with "GGUF" magic)
|
||||
with open(output_file, "rb") as f:
|
||||
magic = f.read(4)
|
||||
self.assertEqual(magic, b"GGUF", f"Output file does not have GGUF magic: {magic!r}")
|
||||
|
||||
# Verify reasonable file size (q4_k_m of 270M model should be ~150-400MB)
|
||||
size_mb = os.path.getsize(output_file) / (1024 * 1024)
|
||||
print(f" Output file size: {size_mb:.1f} MB")
|
||||
self.assertGreater(size_mb, 10, "Output file suspiciously small")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
11
backend/python/llama-cpp-quantization/test.sh
Executable file
11
backend/python/llama-cpp-quantization/test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runUnittests
|
||||
@@ -15,6 +15,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from mlx_audio.tts.utils import load_model
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
@@ -436,7 +440,9 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -23,6 +23,10 @@ import tempfile
|
||||
from typing import List
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
@@ -468,6 +472,8 @@ async def serve(address):
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -12,6 +12,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from mlx_vlm import load, generate, stream_generate
|
||||
from mlx_vlm.prompt_utils import apply_chat_template
|
||||
from mlx_vlm.utils import load_config, load_image
|
||||
@@ -446,7 +450,9 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -12,6 +12,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
from mlx_lm import load, generate, stream_generate
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
|
||||
@@ -421,7 +425,9 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -17,6 +17,10 @@ from moonshine_voice import (
|
||||
)
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -128,7 +132,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -14,6 +14,10 @@ import torch
|
||||
import nemo.collections.asr as nemo_asr
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -119,7 +123,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -4,3 +4,4 @@ certifi
|
||||
packaging==24.1
|
||||
setuptools
|
||||
pyarrow==20.0.0
|
||||
pybind11
|
||||
|
||||
@@ -15,6 +15,10 @@ from neuttsair.neutts import NeuTTSAir
|
||||
import soundfile as sf
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
@@ -130,7 +134,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -14,6 +14,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import outetts
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@@ -116,7 +120,9 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
|
||||
@@ -16,6 +16,10 @@ import torch
|
||||
from pocket_tts import TTSModel
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
@@ -225,7 +229,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -14,6 +14,10 @@ import torch
|
||||
from qwen_asr import Qwen3ASRModel
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -143,7 +147,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if request.language and request.language.strip():
|
||||
language = request.language.strip()
|
||||
|
||||
results = self.model.transcribe(audio=audio_path, language=language)
|
||||
context = ""
|
||||
if request.prompt and request.prompt.strip():
|
||||
context = request.prompt.strip()
|
||||
|
||||
results = self.model.transcribe(audio=audio_path, language=language, context=context)
|
||||
|
||||
if not results:
|
||||
return backend_pb2.TranscriptResult(segments=[], text="")
|
||||
@@ -184,7 +192,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -23,6 +23,10 @@ import hashlib
|
||||
import pickle
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
|
||||
def is_float(s):
|
||||
@@ -900,6 +904,8 @@ def serve(address):
|
||||
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
||||
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
@@ -14,6 +14,10 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
from rerankers import Reranker
|
||||
|
||||
@@ -97,7 +101,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
grpcio==1.78.1
|
||||
grpcio==1.80.0
|
||||
protobuf
|
||||
certifi
|
||||
@@ -13,6 +13,10 @@ import base64
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
|
||||
import requests
|
||||
|
||||
@@ -139,7 +143,9 @@ def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
|
||||
@@ -16,16 +16,22 @@ import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
|
||||
|
||||
XPU=os.environ.get("XPU", "0") == "1"
|
||||
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
|
||||
import transformers as transformers_module
|
||||
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
|
||||
from scipy.io import wavfile
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# Backward-compat aliases for model types
|
||||
TYPE_ALIASES = {"Mamba": "MambaForCausalLM"}
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
@@ -52,32 +58,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
|
||||
"""
|
||||
def Health(self, request, context):
|
||||
"""
|
||||
A gRPC method that returns the health status of the backend service.
|
||||
|
||||
Args:
|
||||
request: A HealthRequest object that contains the request parameters.
|
||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||
|
||||
Returns:
|
||||
A Reply object that contains the health status of the backend service.
|
||||
"""
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""
|
||||
A gRPC method that loads a model into memory.
|
||||
|
||||
Args:
|
||||
request: A LoadModelRequest object that contains the request parameters.
|
||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||
|
||||
Returns:
|
||||
A Result object that contains the result of the LoadModel operation.
|
||||
"""
|
||||
|
||||
model_name = request.Model
|
||||
|
||||
|
||||
# Check to see if the Model exists in the filesystem already.
|
||||
if os.path.exists(request.ModelFile):
|
||||
model_name = request.ModelFile
|
||||
@@ -88,8 +73,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
self.CUDA = torch.cuda.is_available()
|
||||
self.OV=False
|
||||
self.DiaTTS=False
|
||||
self.GenericTTS=False
|
||||
self.SentenceTransformer = False
|
||||
self.processor = None
|
||||
|
||||
device_map="cpu"
|
||||
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
@@ -101,7 +87,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# Parse options from request.Options
|
||||
self.options = {}
|
||||
options = request.Options
|
||||
|
||||
|
||||
# The options are a list of strings in this form optname:optvalue
|
||||
# We are storing all the options in a dict so we can use it later when generating
|
||||
# Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
|
||||
@@ -123,7 +109,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
print(f"Parsed options: {self.options}", file=sys.stderr)
|
||||
|
||||
if self.CUDA:
|
||||
from transformers import BitsAndBytesConfig, AutoModelForCausalLM
|
||||
from transformers import BitsAndBytesConfig
|
||||
if request.MainGPU:
|
||||
device_map=request.MainGPU
|
||||
else:
|
||||
@@ -140,40 +126,31 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
quantization = BitsAndBytesConfig(
|
||||
load_in_4bit=False,
|
||||
bnb_4bit_compute_dtype = None,
|
||||
load_in_8bit=True,
|
||||
load_in_8bit=True,
|
||||
)
|
||||
|
||||
try:
|
||||
if request.Type == "AutoModelForCausalLM":
|
||||
if XPU:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
|
||||
if XPU and request.Type == "AutoModelForCausalLM":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
|
||||
|
||||
device_map="xpu"
|
||||
compute=torch.float16
|
||||
if request.Quantization == "xpu_4bit":
|
||||
xpu_4bit = True
|
||||
xpu_8bit = False
|
||||
elif request.Quantization == "xpu_8bit":
|
||||
xpu_4bit = False
|
||||
xpu_8bit = True
|
||||
else:
|
||||
xpu_4bit = False
|
||||
xpu_8bit = False
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_name,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
use_safetensors=True,
|
||||
device_map=device_map,
|
||||
load_in_4bit=xpu_4bit,
|
||||
load_in_8bit=xpu_8bit,
|
||||
torch_dtype=compute)
|
||||
device_map="xpu"
|
||||
compute=torch.float16
|
||||
if request.Quantization == "xpu_4bit":
|
||||
xpu_4bit = True
|
||||
xpu_8bit = False
|
||||
elif request.Quantization == "xpu_8bit":
|
||||
xpu_4bit = False
|
||||
xpu_8bit = True
|
||||
else:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_name,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
use_safetensors=True,
|
||||
quantization_config=quantization,
|
||||
device_map=device_map,
|
||||
torch_dtype=compute)
|
||||
xpu_4bit = False
|
||||
xpu_8bit = False
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_name,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
device_map=device_map,
|
||||
load_in_4bit=xpu_4bit,
|
||||
load_in_8bit=xpu_8bit,
|
||||
torch_dtype=compute)
|
||||
elif request.Type == "OVModelForCausalLM":
|
||||
from optimum.intel.openvino import OVModelForCausalLM
|
||||
from openvino.runtime import Core
|
||||
@@ -185,14 +162,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
devices = Core().available_devices
|
||||
if "GPU" in " ".join(devices):
|
||||
device_map="AUTO:GPU"
|
||||
# While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected.
|
||||
# https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html
|
||||
if "CPU" or "NPU" in device_map:
|
||||
if "-CPU" or "-NPU" not in device_map:
|
||||
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
|
||||
else:
|
||||
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
|
||||
self.model = OVModelForCausalLM.from_pretrained(model_name,
|
||||
self.model = OVModelForCausalLM.from_pretrained(model_name,
|
||||
compile=True,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
ov_config=ovconfig,
|
||||
@@ -209,59 +184,60 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
devices = Core().available_devices
|
||||
if "GPU" in " ".join(devices):
|
||||
device_map="AUTO:GPU"
|
||||
# While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected.
|
||||
# https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html
|
||||
if "CPU" or "NPU" in device_map:
|
||||
if "-CPU" or "-NPU" not in device_map:
|
||||
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
|
||||
else:
|
||||
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
|
||||
self.model = OVModelForFeatureExtraction.from_pretrained(model_name,
|
||||
self.model = OVModelForFeatureExtraction.from_pretrained(model_name,
|
||||
compile=True,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
ov_config=ovconfig,
|
||||
ov_config=ovconfig,
|
||||
export=True,
|
||||
device=device_map)
|
||||
self.OV = True
|
||||
elif request.Type == "MusicgenForConditionalGeneration":
|
||||
autoTokenizer = False
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
||||
elif request.Type == "DiaForConditionalGeneration":
|
||||
autoTokenizer = False
|
||||
print("DiaForConditionalGeneration", file=sys.stderr)
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
self.model = DiaForConditionalGeneration.from_pretrained(model_name)
|
||||
if self.CUDA:
|
||||
self.model = self.model.to("cuda")
|
||||
self.processor = self.processor.to("cuda")
|
||||
print("DiaForConditionalGeneration loaded", file=sys.stderr)
|
||||
self.DiaTTS = True
|
||||
elif request.Type == "SentenceTransformer":
|
||||
autoTokenizer = False
|
||||
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||
self.SentenceTransformer = True
|
||||
elif request.Type == "Mamba":
|
||||
autoTokenizer = False
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = MambaForCausalLM.from_pretrained(model_name)
|
||||
else:
|
||||
print("Automodel", file=sys.stderr)
|
||||
self.model = AutoModel.from_pretrained(model_name,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
use_safetensors=True,
|
||||
quantization_config=quantization,
|
||||
device_map=device_map,
|
||||
torch_dtype=compute)
|
||||
# Generic: dynamically resolve model class from transformers
|
||||
model_type = TYPE_ALIASES.get(request.Type, request.Type)
|
||||
ModelClass = AutoModel # default
|
||||
if model_type and hasattr(transformers_module, model_type):
|
||||
ModelClass = getattr(transformers_module, model_type)
|
||||
print(f"Using model class: {model_type}", file=sys.stderr)
|
||||
else:
|
||||
print(f"Using default AutoModel (type={request.Type!r})", file=sys.stderr)
|
||||
|
||||
self.model = ModelClass.from_pretrained(
|
||||
model_name,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
quantization_config=quantization,
|
||||
device_map=device_map,
|
||||
torch_dtype=compute,
|
||||
)
|
||||
|
||||
# Try to load a processor (needed for TTS/audio models)
|
||||
try:
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_name,
|
||||
trust_remote_code=request.TrustRemoteCode,
|
||||
)
|
||||
self.GenericTTS = True
|
||||
print(f"Loaded processor for {model_name}", file=sys.stderr)
|
||||
except Exception:
|
||||
self.processor = None
|
||||
|
||||
if request.ContextSize > 0:
|
||||
self.max_tokens = request.ContextSize
|
||||
elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
|
||||
self.max_tokens = self.model.config.max_position_embeddings
|
||||
else:
|
||||
self.max_tokens = self.options.get("max_new_tokens", 512)
|
||||
|
||||
|
||||
if autoTokenizer:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.XPU = False
|
||||
|
||||
if XPU and self.OV == False:
|
||||
@@ -275,22 +251,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
except Exception as err:
|
||||
print("Error:", err, file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
# Implement your logic here for the LoadModel service
|
||||
# Replace this with your desired response
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def Embedding(self, request, context):
|
||||
"""
|
||||
A gRPC method that calculates embeddings for a given sentence.
|
||||
|
||||
Args:
|
||||
request: An EmbeddingRequest object that contains the request parameters.
|
||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||
|
||||
Returns:
|
||||
An EmbeddingResult object that contains the calculated embeddings.
|
||||
"""
|
||||
|
||||
set_seed(request.Seed)
|
||||
# Tokenize input
|
||||
max_length = 512
|
||||
@@ -303,13 +266,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
|
||||
embeds = self.model.encode(request.Embeddings)
|
||||
else:
|
||||
encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
|
||||
encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
|
||||
|
||||
# Create word embeddings
|
||||
if self.CUDA:
|
||||
encoded_input = encoded_input.to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
with torch.no_grad():
|
||||
model_output = self.model(**encoded_input)
|
||||
|
||||
# Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
|
||||
@@ -317,11 +280,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
embeds = sentence_embeddings[0]
|
||||
return backend_pb2.EmbeddingResult(embeddings=embeds)
|
||||
|
||||
async def _predict(self, request, context, streaming=False):
|
||||
async def _predict(self, request, context, streaming=False):
|
||||
set_seed(request.Seed)
|
||||
if request.TopP < 0 or request.TopP > 1:
|
||||
request.TopP = 1
|
||||
|
||||
|
||||
if request.TopK <= 0:
|
||||
request.TopK = 50
|
||||
|
||||
@@ -334,7 +297,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
request.Temperature == None
|
||||
|
||||
prompt = request.Prompt
|
||||
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
|
||||
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
|
||||
prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt")
|
||||
@@ -363,10 +326,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=True)
|
||||
config=dict(inputs,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=request.Temperature,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=request.Temperature,
|
||||
top_p=request.TopP,
|
||||
top_k=request.TopK,
|
||||
top_k=request.TopK,
|
||||
do_sample=sample,
|
||||
attention_mask=inputs["attention_mask"],
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
@@ -387,18 +350,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
else:
|
||||
if XPU and self.OV == False:
|
||||
outputs = self.model.generate(inputs["input_ids"],
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=request.Temperature,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=request.Temperature,
|
||||
top_p=request.TopP,
|
||||
top_k=request.TopK,
|
||||
top_k=request.TopK,
|
||||
do_sample=sample,
|
||||
pad_token=self.tokenizer.eos_token_id)
|
||||
else:
|
||||
outputs = self.model.generate(**inputs,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=request.Temperature,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=request.Temperature,
|
||||
top_p=request.TopP,
|
||||
top_k=request.TopK,
|
||||
top_k=request.TopK,
|
||||
do_sample=sample,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
@@ -413,31 +376,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
|
||||
|
||||
async def Predict(self, request, context):
|
||||
"""
|
||||
Generates text based on the given prompt and sampling parameters.
|
||||
|
||||
Args:
|
||||
request: The predict request.
|
||||
context: The gRPC context.
|
||||
|
||||
Returns:
|
||||
backend_pb2.Reply: The predict result.
|
||||
"""
|
||||
gen = self._predict(request, context, streaming=False)
|
||||
res = await gen.__anext__()
|
||||
return res
|
||||
|
||||
async def PredictStream(self, request, context):
|
||||
"""
|
||||
Generates text based on the given prompt and sampling parameters, and streams the results.
|
||||
|
||||
Args:
|
||||
request: The predict stream request.
|
||||
context: The gRPC context.
|
||||
|
||||
Returns:
|
||||
backend_pb2.Result: The predict stream result.
|
||||
"""
|
||||
iterations = self._predict(request, context, streaming=True)
|
||||
try:
|
||||
async for iteration in iterations:
|
||||
@@ -455,18 +398,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if self.model is None:
|
||||
if model_name == "":
|
||||
return backend_pb2.Result(success=False, message="request.model is required")
|
||||
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
||||
# Dynamically resolve model class if configured, otherwise default to MusicgenForConditionalGeneration
|
||||
model_type = self.options.get("model_type", "MusicgenForConditionalGeneration")
|
||||
ModelClass = getattr(transformers_module, model_type)
|
||||
self.model = ModelClass.from_pretrained(model_name)
|
||||
inputs = None
|
||||
if request.text == "":
|
||||
inputs = self.model.get_unconditional_inputs(num_samples=1)
|
||||
elif request.HasField('src'):
|
||||
# TODO SECURITY CODE GOES HERE LOL
|
||||
# WHO KNOWS IF THIS WORKS???
|
||||
sample_rate, wsamples = wavfile.read('path_to_your_file.wav')
|
||||
|
||||
|
||||
if request.HasField('src_divisor'):
|
||||
wsamples = wsamples[: len(wsamples) // request.src_divisor]
|
||||
|
||||
|
||||
inputs = self.processor(
|
||||
audio=wsamples,
|
||||
sampling_rate=sample_rate,
|
||||
@@ -480,7 +424,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
|
||||
if request.HasField('duration'):
|
||||
tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
|
||||
guidance = self.options.get("guidance_scale", 3.0)
|
||||
@@ -490,92 +434,97 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if request.HasField('sample'):
|
||||
dosample = request.sample
|
||||
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens)
|
||||
print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
|
||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
|
||||
print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr)
|
||||
print("[transformers-musicgen] SoundGeneration for", file=sys.stderr)
|
||||
print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr)
|
||||
print("[transformers] SoundGeneration generated!", file=sys.stderr)
|
||||
|
||||
# Save audio output
|
||||
if hasattr(self.processor, 'save_audio'):
|
||||
if hasattr(self.processor, 'batch_decode'):
|
||||
try:
|
||||
audio_values = self.processor.batch_decode(audio_values)
|
||||
except Exception:
|
||||
pass
|
||||
self.processor.save_audio(audio_values, request.dst)
|
||||
else:
|
||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
|
||||
|
||||
print("[transformers] SoundGeneration saved to", request.dst, file=sys.stderr)
|
||||
print(request, file=sys.stderr)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
|
||||
def CallDiaTTS(self, request, context):
|
||||
"""
|
||||
Generates dialogue audio using the Dia model.
|
||||
|
||||
Args:
|
||||
request: A TTSRequest containing text dialogue and generation parameters
|
||||
context: The gRPC context
|
||||
|
||||
Returns:
|
||||
A Result object indicating success or failure
|
||||
"""
|
||||
try:
|
||||
print("[DiaTTS] generating dialogue audio", file=sys.stderr)
|
||||
|
||||
# Prepare text input - expect dialogue format like [S1] ... [S2] ...
|
||||
text = [request.text]
|
||||
|
||||
# Process the input
|
||||
inputs = self.processor(text=text, padding=True, return_tensors="pt")
|
||||
|
||||
# Generate audio with parameters from options or defaults
|
||||
generation_params = {
|
||||
**inputs,
|
||||
"max_new_tokens": self.max_tokens,
|
||||
"guidance_scale": self.options.get("guidance_scale", 3.0),
|
||||
"temperature": self.options.get("temperature", 1.8),
|
||||
"top_p": self.options.get("top_p", 0.90),
|
||||
"top_k": self.options.get("top_k", 45)
|
||||
}
|
||||
|
||||
outputs = self.model.generate(**generation_params)
|
||||
|
||||
# Decode and save audio
|
||||
outputs = self.processor.batch_decode(outputs)
|
||||
self.processor.save_audio(outputs, request.dst)
|
||||
|
||||
print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
|
||||
print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
|
||||
print("[DiaTTS] Dialogue generation done", file=sys.stderr)
|
||||
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
|
||||
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
|
||||
def TTS(self, request, context):
|
||||
if self.DiaTTS:
|
||||
print("DiaTTS", file=sys.stderr)
|
||||
return self.CallDiaTTS(request, context)
|
||||
|
||||
model_name = request.model
|
||||
try:
|
||||
if self.processor is None:
|
||||
if model_name == "":
|
||||
return backend_pb2.Result(success=False, message="request.model is required")
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
if self.model is None:
|
||||
if model_name == "":
|
||||
return backend_pb2.Result(success=False, message="request.model is required")
|
||||
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
||||
inputs = self.processor(
|
||||
text=[request.text],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = self.max_tokens # No good place to set the "length" in TTS, so use 10s as a sane default
|
||||
audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
|
||||
print("[transformers-musicgen] TTS generated!", file=sys.stderr)
|
||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
|
||||
print("[transformers-musicgen] TTS saved to", request.dst, file=sys.stderr)
|
||||
print("[transformers-musicgen] TTS for", file=sys.stderr)
|
||||
print(request, file=sys.stderr)
|
||||
text = request.text
|
||||
print(f"[transformers] TTS generating for text: {text[:100]}...", file=sys.stderr)
|
||||
|
||||
# Build inputs based on processor capabilities
|
||||
if request.voice and os.path.exists(request.voice):
|
||||
# Voice cloning: use chat template with reference audio
|
||||
chat_template = [{
|
||||
"role": "0",
|
||||
"content": [
|
||||
{"type": "text", "text": text},
|
||||
{"type": "audio", "path": request.voice},
|
||||
],
|
||||
}]
|
||||
inputs = self.processor.apply_chat_template(
|
||||
chat_template, tokenize=True, return_dict=True,
|
||||
).to(self.model.device, self.model.dtype)
|
||||
elif hasattr(self.processor, 'apply_chat_template'):
|
||||
# Models that use chat template format (VibeVoice, CSM, etc.)
|
||||
chat_template = [{"role": "0", "content": [{"type": "text", "text": text}]}]
|
||||
try:
|
||||
inputs = self.processor.apply_chat_template(
|
||||
chat_template, tokenize=True, return_dict=True,
|
||||
).to(self.model.device, self.model.dtype)
|
||||
except Exception:
|
||||
# Fallback if chat template fails (not all processors support it)
|
||||
inputs = self.processor(text=[text], padding=True, return_tensors="pt")
|
||||
if self.CUDA:
|
||||
inputs = inputs.to("cuda")
|
||||
else:
|
||||
# Direct processor call (Musicgen, etc.)
|
||||
inputs = self.processor(text=[text], padding=True, return_tensors="pt")
|
||||
if self.CUDA:
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Build generation kwargs from self.options
|
||||
gen_kwargs = {**inputs, "max_new_tokens": self.max_tokens}
|
||||
for key in ["guidance_scale", "temperature", "top_p", "top_k", "do_sample"]:
|
||||
if key in self.options:
|
||||
gen_kwargs[key] = self.options[key]
|
||||
|
||||
# Add noise scheduler if configured (e.g., for VibeVoice)
|
||||
noise_scheduler_type = self.options.get("noise_scheduler", None)
|
||||
if noise_scheduler_type:
|
||||
import diffusers
|
||||
SchedulerClass = getattr(diffusers, noise_scheduler_type)
|
||||
scheduler_kwargs = {}
|
||||
for key in ["beta_schedule", "prediction_type"]:
|
||||
if key in self.options:
|
||||
scheduler_kwargs[key] = self.options[key]
|
||||
gen_kwargs["noise_scheduler"] = SchedulerClass(**scheduler_kwargs)
|
||||
|
||||
# Generate audio
|
||||
audio = self.model.generate(**gen_kwargs)
|
||||
print("[transformers] TTS generated!", file=sys.stderr)
|
||||
|
||||
# Save audio output
|
||||
if hasattr(self.processor, 'save_audio'):
|
||||
if hasattr(self.processor, 'batch_decode'):
|
||||
try:
|
||||
audio = self.processor.batch_decode(audio)
|
||||
except Exception:
|
||||
pass
|
||||
self.processor.save_audio(audio, request.dst)
|
||||
else:
|
||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||
wavfile.write(request.dst, rate=sampling_rate, data=audio[0, 0].numpy())
|
||||
|
||||
print("[transformers] TTS saved to", request.dst, file=sys.stderr)
|
||||
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
@@ -587,7 +536,9 @@ async def serve(address):
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
|
||||
@@ -2,7 +2,9 @@ torch==2.7.1
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
accelerate
|
||||
transformers
|
||||
transformers>=5.0.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -2,7 +2,9 @@ torch==2.7.1
|
||||
accelerate
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers
|
||||
transformers>=5.0.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -2,7 +2,9 @@
|
||||
torch==2.9.0
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers
|
||||
transformers>=5.0.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -1,9 +1,11 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.4
|
||||
torch==2.8.0+rocm6.4
|
||||
accelerate
|
||||
transformers
|
||||
transformers>=5.0.0
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -3,7 +3,9 @@ torch
|
||||
optimum[openvino]
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers
|
||||
transformers>=5.0.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -2,7 +2,9 @@ torch==2.7.1
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
accelerate
|
||||
transformers
|
||||
transformers>=5.0.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
grpcio==1.78.1
|
||||
grpcio==1.80.0
|
||||
protobuf==6.33.5
|
||||
certifi
|
||||
setuptools
|
||||
|
||||
26
backend/python/trl/Makefile
Normal file
26
backend/python/trl/Makefile
Normal file
@@ -0,0 +1,26 @@
|
||||
# Version of llama.cpp to fetch convert_hf_to_gguf.py from (for GGUF export)
|
||||
LLAMA_CPP_CONVERT_VERSION ?= master
|
||||
|
||||
.PHONY: trl
|
||||
trl:
|
||||
LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: trl
|
||||
@echo "Running trl..."
|
||||
bash run.sh
|
||||
@echo "trl run."
|
||||
|
||||
.PHONY: test
|
||||
test: trl
|
||||
@echo "Testing trl..."
|
||||
bash test.sh
|
||||
@echo "trl tested."
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
866
backend/python/trl/backend.py
Normal file
866
backend/python/trl/backend.py
Normal file
@@ -0,0 +1,866 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TRL fine-tuning backend for LocalAI.
|
||||
|
||||
Supports all TRL training methods (SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO)
|
||||
using standard HuggingFace transformers + PEFT. Works on both CPU and GPU.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent import futures
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
|
||||
|
||||
|
||||
class ProgressCallback:
|
||||
"""HuggingFace TrainerCallback that pushes progress updates to a queue."""
|
||||
|
||||
def __init__(self, job_id, progress_queue, total_epochs):
|
||||
self.job_id = job_id
|
||||
self.progress_queue = progress_queue
|
||||
self.total_epochs = total_epochs
|
||||
|
||||
def get_callback(self):
|
||||
from transformers import TrainerCallback
|
||||
|
||||
parent = self
|
||||
|
||||
class _Callback(TrainerCallback):
|
||||
def __init__(self):
|
||||
self._train_start_time = None
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
self._train_start_time = time.time()
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if logs is None:
|
||||
return
|
||||
total_steps = state.max_steps if state.max_steps > 0 else 0
|
||||
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
|
||||
eta = 0.0
|
||||
if state.global_step > 0 and total_steps > 0 and self._train_start_time:
|
||||
elapsed = time.time() - self._train_start_time
|
||||
remaining_steps = total_steps - state.global_step
|
||||
if state.global_step > 0:
|
||||
eta = remaining_steps * (elapsed / state.global_step)
|
||||
|
||||
extra_metrics = {}
|
||||
for k, v in logs.items():
|
||||
if isinstance(v, (int, float)) and k not in ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss'):
|
||||
extra_metrics[k] = float(v)
|
||||
|
||||
update = backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=parent.job_id,
|
||||
current_step=state.global_step,
|
||||
total_steps=total_steps,
|
||||
current_epoch=float(logs.get('epoch', 0)),
|
||||
total_epochs=float(parent.total_epochs),
|
||||
loss=float(logs.get('loss', 0)),
|
||||
learning_rate=float(logs.get('learning_rate', 0)),
|
||||
grad_norm=float(logs.get('grad_norm', 0)),
|
||||
eval_loss=float(logs.get('eval_loss', 0)),
|
||||
eta_seconds=float(eta),
|
||||
progress_percent=float(progress),
|
||||
status="training",
|
||||
extra_metrics=extra_metrics,
|
||||
)
|
||||
parent.progress_queue.put(update)
|
||||
|
||||
def on_prediction_step(self, args, state, control, **kwargs):
|
||||
"""Send periodic updates during evaluation so the UI doesn't freeze."""
|
||||
if not hasattr(self, '_eval_update_counter'):
|
||||
self._eval_update_counter = 0
|
||||
self._eval_update_counter += 1
|
||||
# Throttle: send an update every 10 prediction steps
|
||||
if self._eval_update_counter % 10 != 0:
|
||||
return
|
||||
total_steps = state.max_steps if state.max_steps > 0 else 0
|
||||
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
|
||||
update = backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=parent.job_id,
|
||||
current_step=state.global_step,
|
||||
total_steps=total_steps,
|
||||
current_epoch=float(state.epoch or 0),
|
||||
total_epochs=float(parent.total_epochs),
|
||||
progress_percent=float(progress),
|
||||
status="training",
|
||||
message=f"Evaluating... (batch {self._eval_update_counter})",
|
||||
)
|
||||
parent.progress_queue.put(update)
|
||||
|
||||
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
||||
"""Report eval results once evaluation is done."""
|
||||
# Reset prediction counter for next eval round
|
||||
self._eval_update_counter = 0
|
||||
|
||||
total_steps = state.max_steps if state.max_steps > 0 else 0
|
||||
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
|
||||
|
||||
eval_loss = 0.0
|
||||
extra_metrics = {}
|
||||
if metrics:
|
||||
eval_loss = float(metrics.get('eval_loss', 0))
|
||||
for k, v in metrics.items():
|
||||
if isinstance(v, (int, float)) and k not in ('eval_loss', 'epoch'):
|
||||
extra_metrics[k] = float(v)
|
||||
|
||||
update = backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=parent.job_id,
|
||||
current_step=state.global_step,
|
||||
total_steps=total_steps,
|
||||
current_epoch=float(state.epoch or 0),
|
||||
total_epochs=float(parent.total_epochs),
|
||||
eval_loss=eval_loss,
|
||||
progress_percent=float(progress),
|
||||
status="training",
|
||||
message=f"Evaluation complete at step {state.global_step}",
|
||||
extra_metrics=extra_metrics,
|
||||
)
|
||||
parent.progress_queue.put(update)
|
||||
|
||||
def on_save(self, args, state, control, **kwargs):
|
||||
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
|
||||
update = backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=parent.job_id,
|
||||
current_step=state.global_step,
|
||||
status="saving",
|
||||
message=f"Checkpoint saved at step {state.global_step}",
|
||||
checkpoint_path=checkpoint_path,
|
||||
)
|
||||
parent.progress_queue.put(update)
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
update = backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=parent.job_id,
|
||||
current_step=state.global_step,
|
||||
total_steps=state.max_steps,
|
||||
progress_percent=100.0,
|
||||
status="completed",
|
||||
message="Training completed",
|
||||
)
|
||||
parent.progress_queue.put(update)
|
||||
|
||||
return _Callback()
|
||||
|
||||
|
||||
class ActiveJob:
|
||||
"""Represents an active fine-tuning job."""
|
||||
|
||||
def __init__(self, job_id):
|
||||
self.job_id = job_id
|
||||
self.progress_queue = queue.Queue()
|
||||
self.trainer = None
|
||||
self.thread = None
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.error = None
|
||||
self.completed = False
|
||||
self.stopped = False
|
||||
|
||||
|
||||
def _is_gated_repo_error(exc):
|
||||
"""Check if an exception is caused by a gated HuggingFace repo requiring authentication."""
|
||||
try:
|
||||
from huggingface_hub.utils import GatedRepoError
|
||||
if isinstance(exc, GatedRepoError):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
msg = str(exc).lower()
|
||||
if "gated repo" in msg or "access to model" in msg:
|
||||
return True
|
||||
if hasattr(exc, 'response') and hasattr(exc.response, 'status_code'):
|
||||
if exc.response.status_code in (401, 403):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def __init__(self):
|
||||
self.active_job = None
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""Accept LoadModel — actual model loading happens in StartFineTune."""
|
||||
return backend_pb2.Result(success=True, message="OK")
|
||||
|
||||
def StartFineTune(self, request, context):
|
||||
if self.active_job is not None and not self.active_job.completed:
|
||||
return backend_pb2.FineTuneJobResult(
|
||||
job_id="",
|
||||
success=False,
|
||||
message="A fine-tuning job is already running",
|
||||
)
|
||||
|
||||
job_id = request.job_id if request.job_id else str(uuid.uuid4())
|
||||
job = ActiveJob(job_id)
|
||||
self.active_job = job
|
||||
|
||||
# Start training in background thread
|
||||
thread = threading.Thread(target=self._run_training, args=(request, job), daemon=True)
|
||||
job.thread = thread
|
||||
thread.start()
|
||||
|
||||
return backend_pb2.FineTuneJobResult(
|
||||
job_id=job_id,
|
||||
success=True,
|
||||
message="Fine-tuning job started",
|
||||
)
|
||||
|
||||
def _run_training(self, request, job):
|
||||
try:
|
||||
self._do_training(request, job)
|
||||
except Exception as e:
|
||||
if _is_gated_repo_error(e):
|
||||
msg = (f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
|
||||
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
|
||||
else:
|
||||
msg = f"Training failed: {e}"
|
||||
job.error = msg
|
||||
job.completed = True
|
||||
update = backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=job.job_id,
|
||||
status="failed",
|
||||
message=msg,
|
||||
)
|
||||
job.progress_queue.put(update)
|
||||
# Send sentinel
|
||||
job.progress_queue.put(None)
|
||||
|
||||
def _do_training(self, request, job):
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from datasets import load_dataset, Dataset
|
||||
|
||||
extra = dict(request.extra_options)
|
||||
training_method = request.training_method or "sft"
|
||||
training_type = request.training_type or "lora"
|
||||
|
||||
# Send loading status
|
||||
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}",
|
||||
))
|
||||
|
||||
# Determine device and dtype
|
||||
device_map = "auto" if torch.cuda.is_available() else "cpu"
|
||||
dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16
|
||||
|
||||
# HuggingFace token for gated repos (from extra_options or HF_TOKEN env)
|
||||
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
|
||||
|
||||
# Load model
|
||||
model_kwargs = {"device_map": device_map, "torch_dtype": dtype}
|
||||
if hf_token:
|
||||
model_kwargs["token"] = hf_token
|
||||
if extra.get("trust_remote_code", "false").lower() == "true":
|
||||
model_kwargs["trust_remote_code"] = True
|
||||
if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available():
|
||||
from transformers import BitsAndBytesConfig
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
job.model = model
|
||||
job.tokenizer = tokenizer
|
||||
|
||||
# Apply LoRA if requested
|
||||
if training_type == "lora":
|
||||
from peft import LoraConfig, get_peft_model
|
||||
lora_r = request.adapter_rank if request.adapter_rank > 0 else 16
|
||||
lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16
|
||||
lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0
|
||||
|
||||
target_modules = list(request.target_modules) if request.target_modules else None
|
||||
peft_config = LoraConfig(
|
||||
r=lora_r,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
target_modules=target_modules or "all-linear",
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
# Load dataset
|
||||
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=job.job_id, status="loading_dataset", message="Loading dataset",
|
||||
))
|
||||
|
||||
dataset_split = request.dataset_split or "train"
|
||||
if os.path.exists(request.dataset_source):
|
||||
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
|
||||
dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split)
|
||||
elif request.dataset_source.endswith('.csv'):
|
||||
dataset = load_dataset("csv", data_files=request.dataset_source, split=dataset_split)
|
||||
else:
|
||||
dataset = load_dataset(request.dataset_source, split=dataset_split)
|
||||
else:
|
||||
dataset = load_dataset(request.dataset_source, split=dataset_split)
|
||||
|
||||
# Eval dataset setup
|
||||
eval_dataset = None
|
||||
eval_strategy = extra.get("eval_strategy", "steps")
|
||||
eval_steps = int(extra.get("eval_steps", str(request.save_steps if request.save_steps > 0 else 500)))
|
||||
|
||||
if eval_strategy != "no":
|
||||
eval_split = extra.get("eval_split")
|
||||
eval_dataset_source = extra.get("eval_dataset_source")
|
||||
if eval_split:
|
||||
# Load a specific split as eval dataset
|
||||
if os.path.exists(request.dataset_source):
|
||||
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
|
||||
eval_dataset = load_dataset("json", data_files=request.dataset_source, split=eval_split)
|
||||
elif request.dataset_source.endswith('.csv'):
|
||||
eval_dataset = load_dataset("csv", data_files=request.dataset_source, split=eval_split)
|
||||
else:
|
||||
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
|
||||
else:
|
||||
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
|
||||
elif eval_dataset_source:
|
||||
# Load eval dataset from a separate source
|
||||
eval_dataset = load_dataset(eval_dataset_source, split="train")
|
||||
else:
|
||||
# Auto-split from training set
|
||||
eval_split_ratio = float(extra.get("eval_split_ratio", "0.1"))
|
||||
split = dataset.train_test_split(test_size=eval_split_ratio)
|
||||
dataset = split["train"]
|
||||
eval_dataset = split["test"]
|
||||
|
||||
if eval_strategy == "no":
|
||||
eval_dataset = None
|
||||
|
||||
# Training config
|
||||
output_dir = request.output_dir or f"./output-{job.job_id}"
|
||||
num_epochs = request.num_epochs if request.num_epochs > 0 else 3
|
||||
batch_size = request.batch_size if request.batch_size > 0 else 2
|
||||
lr = request.learning_rate if request.learning_rate > 0 else 2e-4
|
||||
grad_accum = request.gradient_accumulation_steps if request.gradient_accumulation_steps > 0 else 4
|
||||
warmup_steps = request.warmup_steps if request.warmup_steps > 0 else 5
|
||||
weight_decay = request.weight_decay if request.weight_decay > 0 else 0.01
|
||||
max_steps = request.max_steps if request.max_steps > 0 else -1
|
||||
save_steps = request.save_steps if request.save_steps > 0 else 500
|
||||
seed = request.seed if request.seed > 0 else 3407
|
||||
optimizer = request.optimizer or "adamw_torch"
|
||||
|
||||
# Checkpoint save controls
|
||||
save_total_limit = int(extra.get("save_total_limit", "0")) or None # 0 = unlimited
|
||||
save_strategy = extra.get("save_strategy", "steps") # steps, epoch, no
|
||||
|
||||
# CPU vs GPU training args (can be overridden via extra_options)
|
||||
use_cpu = not torch.cuda.is_available()
|
||||
common_train_kwargs = {}
|
||||
if use_cpu:
|
||||
common_train_kwargs["use_cpu"] = True
|
||||
common_train_kwargs["fp16"] = False
|
||||
common_train_kwargs["bf16"] = False
|
||||
common_train_kwargs["gradient_checkpointing"] = False
|
||||
else:
|
||||
common_train_kwargs["bf16"] = True
|
||||
common_train_kwargs["gradient_checkpointing"] = request.gradient_checkpointing
|
||||
|
||||
# Allow extra_options to override training kwargs
|
||||
for flag in ("use_cpu", "bf16", "fp16", "gradient_checkpointing"):
|
||||
if flag in extra:
|
||||
common_train_kwargs[flag] = extra[flag].lower() == "true"
|
||||
|
||||
# Create progress callback
|
||||
progress_cb = ProgressCallback(job.job_id, job.progress_queue, num_epochs)
|
||||
|
||||
# Build save kwargs (shared across all methods)
|
||||
_save_kwargs = {}
|
||||
if save_strategy == "steps" and save_steps > 0:
|
||||
_save_kwargs["save_steps"] = save_steps
|
||||
_save_kwargs["save_strategy"] = "steps"
|
||||
elif save_strategy == "epoch":
|
||||
_save_kwargs["save_strategy"] = "epoch"
|
||||
elif save_strategy == "no":
|
||||
_save_kwargs["save_strategy"] = "no"
|
||||
else:
|
||||
_save_kwargs["save_steps"] = save_steps
|
||||
_save_kwargs["save_strategy"] = "steps"
|
||||
if save_total_limit:
|
||||
_save_kwargs["save_total_limit"] = save_total_limit
|
||||
|
||||
# Eval kwargs
|
||||
_eval_kwargs = {}
|
||||
if eval_dataset is not None:
|
||||
_eval_kwargs["eval_strategy"] = eval_strategy
|
||||
_eval_kwargs["eval_steps"] = eval_steps
|
||||
|
||||
# Common training arguments shared by all methods
|
||||
_common_args = dict(
|
||||
output_dir=output_dir,
|
||||
num_train_epochs=num_epochs,
|
||||
per_device_train_batch_size=batch_size,
|
||||
learning_rate=lr,
|
||||
gradient_accumulation_steps=grad_accum,
|
||||
warmup_steps=warmup_steps,
|
||||
weight_decay=weight_decay,
|
||||
max_steps=max_steps,
|
||||
seed=seed,
|
||||
optim=optimizer,
|
||||
logging_steps=1,
|
||||
report_to="none",
|
||||
**_save_kwargs,
|
||||
**common_train_kwargs,
|
||||
**_eval_kwargs,
|
||||
)
|
||||
|
||||
# Select trainer based on training method
|
||||
if training_method == "sft":
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
max_length = int(extra.get("max_seq_length", "512"))
|
||||
packing = extra.get("packing", "false").lower() == "true"
|
||||
|
||||
training_args = SFTConfig(
|
||||
max_length=max_length,
|
||||
packing=packing,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
elif training_method == "dpo":
|
||||
from trl import DPOTrainer, DPOConfig
|
||||
|
||||
beta = float(extra.get("beta", "0.1"))
|
||||
loss_type = extra.get("loss_type", "sigmoid")
|
||||
max_length = int(extra.get("max_length", "512"))
|
||||
|
||||
training_args = DPOConfig(
|
||||
beta=beta,
|
||||
loss_type=loss_type,
|
||||
max_length=max_length,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
elif training_method == "grpo":
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
num_generations = int(extra.get("num_generations", "4"))
|
||||
max_completion_length = int(extra.get("max_completion_length", "256"))
|
||||
|
||||
training_args = GRPOConfig(
|
||||
num_generations=num_generations,
|
||||
max_completion_length=max_completion_length,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
# GRPO requires reward functions passed via extra_options as a JSON list
|
||||
from reward_functions import build_reward_functions
|
||||
|
||||
reward_funcs = []
|
||||
if extra.get("reward_funcs"):
|
||||
reward_funcs = build_reward_functions(extra["reward_funcs"])
|
||||
|
||||
if not reward_funcs:
|
||||
raise ValueError(
|
||||
"GRPO requires at least one reward function. "
|
||||
"Specify reward_functions in the request or "
|
||||
"reward_funcs in extra_options."
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=reward_funcs,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
elif training_method == "orpo":
|
||||
from trl import ORPOTrainer, ORPOConfig
|
||||
|
||||
beta = float(extra.get("beta", "0.1"))
|
||||
max_length = int(extra.get("max_length", "512"))
|
||||
|
||||
training_args = ORPOConfig(
|
||||
beta=beta,
|
||||
max_length=max_length,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
trainer = ORPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
elif training_method == "kto":
|
||||
from trl import KTOTrainer, KTOConfig
|
||||
|
||||
beta = float(extra.get("beta", "0.1"))
|
||||
max_length = int(extra.get("max_length", "512"))
|
||||
|
||||
training_args = KTOConfig(
|
||||
beta=beta,
|
||||
max_length=max_length,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
trainer = KTOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
elif training_method == "rloo":
|
||||
from trl import RLOOTrainer, RLOOConfig
|
||||
|
||||
num_generations = int(extra.get("num_generations", "4"))
|
||||
max_completion_length = int(extra.get("max_completion_length", "256"))
|
||||
|
||||
training_args = RLOOConfig(
|
||||
num_generations=num_generations,
|
||||
max_new_tokens=max_completion_length,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
trainer = RLOOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
elif training_method == "reward":
|
||||
from trl import RewardTrainer, RewardConfig
|
||||
|
||||
max_length = int(extra.get("max_length", "512"))
|
||||
|
||||
training_args = RewardConfig(
|
||||
max_length=max_length,
|
||||
**_common_args,
|
||||
)
|
||||
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[progress_cb.get_callback()],
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported training method: {training_method}. "
|
||||
"Supported: sft, dpo, grpo, orpo, kto, rloo, reward")
|
||||
|
||||
job.trainer = trainer
|
||||
|
||||
# Start training
|
||||
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=job.job_id, status="training", message="Training started",
|
||||
))
|
||||
|
||||
resume_ckpt = request.resume_from_checkpoint if request.resume_from_checkpoint else None
|
||||
trainer.train(resume_from_checkpoint=resume_ckpt)
|
||||
|
||||
# Save final model
|
||||
trainer.save_model(output_dir)
|
||||
if tokenizer:
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
job.completed = True
|
||||
# Sentinel to signal stream end
|
||||
job.progress_queue.put(None)
|
||||
|
||||
def FineTuneProgress(self, request, context):
|
||||
if self.active_job is None or self.active_job.job_id != request.job_id:
|
||||
context.set_code(grpc.StatusCode.NOT_FOUND)
|
||||
context.set_details(f"Job {request.job_id} not found")
|
||||
return
|
||||
|
||||
job = self.active_job
|
||||
while True:
|
||||
try:
|
||||
update = job.progress_queue.get(timeout=1.0)
|
||||
if update is None:
|
||||
break
|
||||
yield update
|
||||
if update.status in ("completed", "failed", "stopped"):
|
||||
break
|
||||
except queue.Empty:
|
||||
if job.completed or job.stopped:
|
||||
break
|
||||
if not context.is_active():
|
||||
break
|
||||
continue
|
||||
|
||||
def StopFineTune(self, request, context):
|
||||
# Stopping is handled by killing the process from Go via ShutdownModel.
|
||||
return backend_pb2.Result(success=True, message="OK")
|
||||
|
||||
def ListCheckpoints(self, request, context):
|
||||
output_dir = request.output_dir
|
||||
if not os.path.isdir(output_dir):
|
||||
return backend_pb2.ListCheckpointsResponse(checkpoints=[])
|
||||
|
||||
checkpoints = []
|
||||
for entry in sorted(os.listdir(output_dir)):
|
||||
if entry.startswith("checkpoint-"):
|
||||
ckpt_path = os.path.join(output_dir, entry)
|
||||
if not os.path.isdir(ckpt_path):
|
||||
continue
|
||||
step = 0
|
||||
try:
|
||||
step = int(entry.split("-")[1])
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
# Try to read trainer_state.json for metadata
|
||||
loss = 0.0
|
||||
epoch = 0.0
|
||||
state_file = os.path.join(ckpt_path, "trainer_state.json")
|
||||
if os.path.exists(state_file):
|
||||
try:
|
||||
with open(state_file) as f:
|
||||
state = json.load(f)
|
||||
if state.get("log_history"):
|
||||
last_log = state["log_history"][-1]
|
||||
loss = last_log.get("loss", 0.0)
|
||||
epoch = last_log.get("epoch", 0.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
created_at = time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%SZ",
|
||||
time.gmtime(os.path.getmtime(ckpt_path)),
|
||||
)
|
||||
|
||||
checkpoints.append(backend_pb2.CheckpointInfo(
|
||||
path=ckpt_path,
|
||||
step=step,
|
||||
epoch=float(epoch),
|
||||
loss=float(loss),
|
||||
created_at=created_at,
|
||||
))
|
||||
|
||||
return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints)
|
||||
|
||||
def ExportModel(self, request, context):
|
||||
export_format = request.export_format or "lora"
|
||||
output_path = request.output_path
|
||||
checkpoint_path = request.checkpoint_path
|
||||
|
||||
# Extract HF token for gated model access
|
||||
extra = dict(request.extra_options) if request.extra_options else {}
|
||||
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
|
||||
|
||||
if not checkpoint_path or not os.path.isdir(checkpoint_path):
|
||||
return backend_pb2.Result(success=False, message=f"Checkpoint not found: {checkpoint_path}")
|
||||
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
try:
|
||||
if export_format == "lora":
|
||||
# Just copy the adapter files
|
||||
import shutil
|
||||
for f in os.listdir(checkpoint_path):
|
||||
src = os.path.join(checkpoint_path, f)
|
||||
dst = os.path.join(output_path, f)
|
||||
if os.path.isfile(src):
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
elif export_format in ("merged_16bit", "merged_4bit"):
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import PeftModel
|
||||
|
||||
base_model_name = request.model
|
||||
if not base_model_name:
|
||||
return backend_pb2.Result(success=False, message="Base model name required for merge export")
|
||||
|
||||
dtype = torch.float16 if export_format == "merged_16bit" else torch.float32
|
||||
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, token=hf_token)
|
||||
model = PeftModel.from_pretrained(base_model, checkpoint_path)
|
||||
merged = model.merge_and_unload()
|
||||
merged.save_pretrained(output_path)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
|
||||
tokenizer.save_pretrained(output_path)
|
||||
|
||||
elif export_format == "gguf":
|
||||
import torch
|
||||
import subprocess
|
||||
import shutil
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import PeftModel
|
||||
|
||||
base_model_name = request.model
|
||||
if not base_model_name:
|
||||
return backend_pb2.Result(success=False, message="Base model name required for GGUF export")
|
||||
|
||||
# Step 1: Merge LoRA into base model
|
||||
merge_dir = os.path.join(output_path, "_hf_merged")
|
||||
os.makedirs(merge_dir, exist_ok=True)
|
||||
|
||||
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, token=hf_token)
|
||||
model = PeftModel.from_pretrained(base_model, checkpoint_path)
|
||||
merged = model.merge_and_unload()
|
||||
merged.save_pretrained(merge_dir)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
|
||||
tokenizer.save_pretrained(merge_dir)
|
||||
|
||||
# Ensure tokenizer.model (SentencePiece) is present in merge_dir.
|
||||
# Gemma models need this file for GGUF conversion to use the
|
||||
# SentencePiece path; without it, the script falls back to BPE
|
||||
# handling which fails on unrecognized pre-tokenizer hashes.
|
||||
sp_model_path = os.path.join(merge_dir, "tokenizer.model")
|
||||
if not os.path.exists(sp_model_path):
|
||||
sp_copied = False
|
||||
# Method 1: Load the slow tokenizer which keeps the SP model file
|
||||
try:
|
||||
slow_tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=False, token=hf_token)
|
||||
if hasattr(slow_tok, 'vocab_file') and slow_tok.vocab_file and os.path.exists(slow_tok.vocab_file):
|
||||
import shutil as _shutil
|
||||
_shutil.copy2(slow_tok.vocab_file, sp_model_path)
|
||||
sp_copied = True
|
||||
print(f"Copied tokenizer.model from slow tokenizer cache")
|
||||
except Exception as e:
|
||||
print(f"Slow tokenizer method failed: {e}")
|
||||
# Method 2: Download from HF hub
|
||||
if not sp_copied:
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
cached_sp = hf_hub_download(repo_id=base_model_name, filename="tokenizer.model", token=hf_token)
|
||||
import shutil as _shutil
|
||||
_shutil.copy2(cached_sp, sp_model_path)
|
||||
sp_copied = True
|
||||
print(f"Copied tokenizer.model from HF hub")
|
||||
except Exception as e:
|
||||
print(f"HF hub download method failed: {e}")
|
||||
if not sp_copied:
|
||||
print(f"WARNING: Could not obtain tokenizer.model for {base_model_name}. "
|
||||
"GGUF conversion may fail for SentencePiece models.")
|
||||
|
||||
# Free GPU memory before conversion
|
||||
del merged, model, base_model
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Step 2: Convert to GGUF using convert_hf_to_gguf.py
|
||||
quant = request.quantization_method or "auto"
|
||||
outtype_map = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"}
|
||||
outtype = outtype_map.get(quant, "f16")
|
||||
|
||||
gguf_filename = f"{os.path.basename(output_path)}-{outtype}.gguf"
|
||||
gguf_path = os.path.join(output_path, gguf_filename)
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
convert_script = os.path.join(script_dir, "convert_hf_to_gguf.py")
|
||||
if not os.path.exists(convert_script):
|
||||
return backend_pb2.Result(success=False,
|
||||
message="convert_hf_to_gguf.py not found. Install the GGUF conversion tools.")
|
||||
|
||||
# Log merge_dir contents for debugging conversion issues
|
||||
merge_files = os.listdir(merge_dir) if os.path.isdir(merge_dir) else []
|
||||
print(f"Merge dir contents: {merge_files}", flush=True)
|
||||
|
||||
env = os.environ.copy()
|
||||
env["NO_LOCAL_GGUF"] = "1"
|
||||
cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path]
|
||||
conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env)
|
||||
if conv_result.returncode != 0:
|
||||
diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}"
|
||||
return backend_pb2.Result(success=False,
|
||||
message=f"GGUF conversion failed: {diag}")
|
||||
|
||||
# Clean up intermediate merged model
|
||||
shutil.rmtree(merge_dir, ignore_errors=True)
|
||||
else:
|
||||
return backend_pb2.Result(success=False, message=f"Unsupported export format: {export_format}")
|
||||
|
||||
except Exception as e:
|
||||
if _is_gated_repo_error(e):
|
||||
return backend_pb2.Result(success=False,
|
||||
message=f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
|
||||
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
|
||||
return backend_pb2.Result(success=False, message=f"Export failed: {e}")
|
||||
|
||||
return backend_pb2.Result(success=True, message=f"Model exported to {output_path}")
|
||||
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(
|
||||
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
],
|
||||
|
||||
interceptors=get_auth_interceptors(),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print(f"TRL fine-tuning backend listening on {address}", file=sys.stderr, flush=True)
|
||||
|
||||
# Handle graceful shutdown
|
||||
def stop(signum, frame):
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, stop)
|
||||
signal.signal(signal.SIGINT, stop)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="TRL fine-tuning gRPC backend")
|
||||
parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
|
||||
args = parser.parse_args()
|
||||
serve(args.addr)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user