mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-19 22:29:54 -04:00
Compare commits
129 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9f10f2f50 | ||
|
|
b95b0b72ff | ||
|
|
26f1b94f4d | ||
|
|
2d40725ca2 | ||
|
|
6c635e8353 | ||
|
|
cc5f33ce95 | ||
|
|
ba7cdd532a | ||
|
|
6b6c136210 | ||
|
|
e587ecc485 | ||
|
|
f259036a27 | ||
|
|
221ff0f28f | ||
|
|
16d5cb00bd | ||
|
|
952635fba6 | ||
|
|
3cc05af2e5 | ||
|
|
87a63316c7 | ||
|
|
efdcbbe332 | ||
|
|
b4fff9293d | ||
|
|
8180221b7e | ||
|
|
52a9755e08 | ||
|
|
a2a1d919f9 | ||
|
|
a3d37931ec | ||
|
|
5b2e25ebb0 | ||
|
|
b0b37a472f | ||
|
|
3db12eaa7a | ||
|
|
8862e3ce60 | ||
|
|
80699a3f70 | ||
|
|
309a59f61e | ||
|
|
65c9380389 | ||
|
|
79963c56bf | ||
|
|
7004ce0b78 | ||
|
|
702d0e0e4d | ||
|
|
d6de208d6c | ||
|
|
7451145e0c | ||
|
|
cfda3dd0df | ||
|
|
e0eb2fd734 | ||
|
|
dd3376e0a9 | ||
|
|
520e1ce3cd | ||
|
|
3d738164b7 | ||
|
|
56db76599a | ||
|
|
ad57cdfefe | ||
|
|
c2f7d1c18b | ||
|
|
afe79568d6 | ||
|
|
59108fbe32 | ||
|
|
4c870288d9 | ||
|
|
8da7212763 | ||
|
|
6e76052f9d | ||
|
|
cf84db36ec | ||
|
|
d3f629f183 | ||
|
|
b1aa707a92 | ||
|
|
731176ce3a | ||
|
|
b86fa63f70 | ||
|
|
00fcf6936c | ||
|
|
26384c5c70 | ||
|
|
7209457f53 | ||
|
|
9bc68b2721 | ||
|
|
7bdd198fd3 | ||
|
|
b296e3d94b | ||
|
|
c91855a9b2 | ||
|
|
e8e445cd43 | ||
|
|
735c426072 | ||
|
|
0976b8a17b | ||
|
|
2ad8c149e0 | ||
|
|
31fcb1425d | ||
|
|
470d5e506f | ||
|
|
0ee49cf42e | ||
|
|
cecd8d6aa5 | ||
|
|
15935e9d5f | ||
|
|
5d410e5a03 | ||
|
|
5df77d7e8c | ||
|
|
f891d60d26 | ||
|
|
be25217955 | ||
|
|
b74111feed | ||
|
|
bf92117259 | ||
|
|
031a36c995 | ||
|
|
8036d22ec6 | ||
|
|
f7e8d9e791 | ||
|
|
4b183b7bb6 | ||
|
|
f38e91d80b | ||
|
|
aa3e82976e | ||
|
|
d9c1db2b87 | ||
|
|
f7e3aab4fc | ||
|
|
73bdc3b50d | ||
|
|
cb63bdb9e4 | ||
|
|
8cd3f9fc47 | ||
|
|
e0ab1a8b43 | ||
|
|
c3174f9543 | ||
|
|
2b12875302 | ||
|
|
9cdbd89c1f | ||
|
|
7d81bf0aa3 | ||
|
|
a6d0e29eba | ||
|
|
6054d2a91b | ||
|
|
aea21951a2 | ||
|
|
bbe9067227 | ||
|
|
9a9da062e1 | ||
|
|
dd1a8b174f | ||
|
|
cfb7641eea | ||
|
|
e832efeb9e | ||
|
|
a42548e9d1 | ||
|
|
8615ce28a8 | ||
|
|
8336efec41 | ||
|
|
8560a1e571 | ||
|
|
29c33e6a6a | ||
|
|
a58475dbef | ||
|
|
8a0edd0809 | ||
|
|
35d509d8e7 | ||
|
|
eef808d921 | ||
|
|
9d9ea5c1a0 | ||
|
|
e21ad5cfaa | ||
|
|
05ab0c0aa2 | ||
|
|
e2b6233570 | ||
|
|
19f995f38f | ||
|
|
ac168bbc60 | ||
|
|
5c5e537b31 | ||
|
|
118bcee196 | ||
|
|
3eabd6d1d0 | ||
|
|
ee96e5e08d | ||
|
|
3d9ccd1ddc | ||
|
|
d8161bfe57 | ||
|
|
5fd42399d4 | ||
|
|
b2030255ca | ||
|
|
9f903ec06e | ||
|
|
4ea461c330 | ||
|
|
042a9b8ef6 | ||
|
|
65f1a4154a | ||
|
|
c6a51289b0 | ||
|
|
87525109f1 | ||
|
|
c596d8a5d9 | ||
|
|
d79ad76e48 | ||
|
|
dde0353432 |
259
.agents/api-endpoints-and-auth.md
Normal file
259
.agents/api-endpoints-and-auth.md
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
# API Endpoints and Authentication
|
||||||
|
|
||||||
|
This guide covers how to add new API endpoints and properly integrate them with the auth/permissions system.
|
||||||
|
|
||||||
|
## Architecture overview
|
||||||
|
|
||||||
|
Authentication and authorization flow through three layers:
|
||||||
|
|
||||||
|
1. **Global auth middleware** (`core/http/auth/middleware.go` → `auth.Middleware`) — applied to every request in `core/http/app.go`. Handles session cookies, Bearer tokens, API keys, and legacy API keys. Populates `auth_user` and `auth_role` in the Echo context.
|
||||||
|
2. **Feature middleware** (`auth.RequireFeature`) — per-feature access control applied to route groups or individual routes. Checks if the authenticated user has the specific feature enabled.
|
||||||
|
3. **Admin middleware** (`auth.RequireAdmin`) — restricts endpoints to admin users only.
|
||||||
|
|
||||||
|
When auth is disabled (no auth DB, no legacy API keys), all middleware becomes pass-through (`auth.NoopMiddleware`).
|
||||||
|
|
||||||
|
## Adding a new API endpoint
|
||||||
|
|
||||||
|
### Step 1: Create the handler
|
||||||
|
|
||||||
|
Write the endpoint handler in the appropriate package under `core/http/endpoints/`. Follow existing patterns:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// core/http/endpoints/localai/my_feature.go
|
||||||
|
func MyFeatureEndpoint(app *application.Application) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
// Use auth.GetUser(c) to get the authenticated user (may be nil if auth is disabled)
|
||||||
|
user := auth.GetUser(c)
|
||||||
|
|
||||||
|
// Your logic here
|
||||||
|
return c.JSON(http.StatusOK, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Register routes
|
||||||
|
|
||||||
|
Add routes in the appropriate file under `core/http/routes/`. The file you use depends on the endpoint category:
|
||||||
|
|
||||||
|
| File | Category |
|
||||||
|
|------|----------|
|
||||||
|
| `routes/openai.go` | OpenAI-compatible API endpoints (`/v1/...`) |
|
||||||
|
| `routes/localai.go` | LocalAI-specific endpoints (`/api/...`, `/models/...`, `/backends/...`) |
|
||||||
|
| `routes/agents.go` | Agent pool endpoints (`/api/agents/...`) |
|
||||||
|
| `routes/auth.go` | Auth endpoints (`/api/auth/...`) |
|
||||||
|
| `routes/ui_api.go` | UI backend API endpoints |
|
||||||
|
|
||||||
|
### Step 3: Apply the right middleware
|
||||||
|
|
||||||
|
Choose the appropriate protection level:
|
||||||
|
|
||||||
|
#### No auth required (public)
|
||||||
|
Exempt paths bypass auth entirely. Add to `isExemptPath()` in `middleware.go` or use the `/api/auth/` prefix (always exempt). Use sparingly — most endpoints should require auth.
|
||||||
|
|
||||||
|
#### Standard auth (any authenticated user)
|
||||||
|
The global middleware already handles this. API paths (`/api/`, `/v1/`, etc.) automatically require authentication when auth is enabled. You don't need to add any extra middleware.
|
||||||
|
|
||||||
|
```go
|
||||||
|
router.GET("/v1/my-endpoint", myHandler) // auth enforced by global middleware
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Admin only
|
||||||
|
Pass `adminMiddleware` to the route. This is set up in `app.go` and passed to `Register*Routes` functions:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// In the Register function signature, accept the middleware:
|
||||||
|
func RegisterMyRoutes(router *echo.Echo, app *application.Application, adminMiddleware echo.MiddlewareFunc) {
|
||||||
|
router.POST("/models/apply", myHandler, adminMiddleware)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Feature-gated
|
||||||
|
For endpoints that should be toggleable per-user, use feature middleware. There are two approaches:
|
||||||
|
|
||||||
|
**Approach A: Route-level middleware** (preferred for groups of related endpoints)
|
||||||
|
|
||||||
|
```go
|
||||||
|
// In app.go, create the feature middleware:
|
||||||
|
myFeatureMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMyFeature)
|
||||||
|
|
||||||
|
// Pass it to the route registration function:
|
||||||
|
routes.RegisterMyRoutes(e, app, myFeatureMw)
|
||||||
|
|
||||||
|
// In the routes file, apply to a group:
|
||||||
|
g := e.Group("/api/my-feature", myFeatureMw)
|
||||||
|
g.GET("", listHandler)
|
||||||
|
g.POST("", createHandler)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Approach B: RouteFeatureRegistry** (preferred for individual OpenAI-compatible endpoints)
|
||||||
|
|
||||||
|
Add an entry to `RouteFeatureRegistry` in `core/http/auth/features.go`. The `RequireRouteFeature` global middleware will automatically enforce it:
|
||||||
|
|
||||||
|
```go
|
||||||
|
var RouteFeatureRegistry = []RouteFeature{
|
||||||
|
// ... existing entries ...
|
||||||
|
{"POST", "/v1/my-endpoint", FeatureMyFeature},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Adding a new feature
|
||||||
|
|
||||||
|
When you need a new toggleable feature (not just a new endpoint under an existing feature):
|
||||||
|
|
||||||
|
### 1. Define the feature constant
|
||||||
|
|
||||||
|
Add to `core/http/auth/permissions.go`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
const (
|
||||||
|
// Add to the appropriate group:
|
||||||
|
// Agent features (default OFF for new users)
|
||||||
|
FeatureMyFeature = "my_feature"
|
||||||
|
|
||||||
|
// OR API features (default ON for new users)
|
||||||
|
FeatureMyFeature = "my_feature"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Then add it to the appropriate slice:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Default OFF — user must be explicitly granted access:
|
||||||
|
var AgentFeatures = []string{..., FeatureMyFeature}
|
||||||
|
|
||||||
|
// Default ON — user has access unless explicitly revoked:
|
||||||
|
var APIFeatures = []string{..., FeatureMyFeature}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Add feature metadata
|
||||||
|
|
||||||
|
In `core/http/auth/features.go`, add to the appropriate `FeatureMetas` function so the admin UI can display it:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func AgentFeatureMetas() []FeatureMeta {
|
||||||
|
return []FeatureMeta{
|
||||||
|
// ... existing ...
|
||||||
|
{FeatureMyFeature, "My Feature", false}, // false = default OFF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Wire up the middleware
|
||||||
|
|
||||||
|
In `core/http/app.go`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
myFeatureMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMyFeature)
|
||||||
|
```
|
||||||
|
|
||||||
|
Then pass it to the route registration function.
|
||||||
|
|
||||||
|
### 4. Register route-feature mappings (if applicable)
|
||||||
|
|
||||||
|
If your feature gates standard API endpoints (like `/v1/...`), add entries to `RouteFeatureRegistry` in `features.go` instead of using per-route middleware.
|
||||||
|
|
||||||
|
## Accessing the authenticated user in handlers
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/mudler/LocalAI/core/http/auth"
|
||||||
|
|
||||||
|
func MyHandler(c echo.Context) error {
|
||||||
|
// Get the user (nil when auth is disabled or unauthenticated)
|
||||||
|
user := auth.GetUser(c)
|
||||||
|
if user == nil {
|
||||||
|
// Handle unauthenticated — or let middleware handle it
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check role
|
||||||
|
if user.Role == auth.RoleAdmin {
|
||||||
|
// admin-specific logic
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check feature access programmatically (when you need conditional behavior, not full blocking)
|
||||||
|
if auth.HasFeatureAccess(db, user, auth.FeatureMyFeature) {
|
||||||
|
// feature-specific logic
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check model access
|
||||||
|
if !auth.IsModelAllowed(db, user, modelName) {
|
||||||
|
return c.JSON(http.StatusForbidden, ...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Middleware composition patterns
|
||||||
|
|
||||||
|
Middleware can be composed at different levels. Here are the patterns used in the codebase:
|
||||||
|
|
||||||
|
### Group-level middleware (agents pattern)
|
||||||
|
```go
|
||||||
|
// All routes in the group share the middleware
|
||||||
|
g := e.Group("/api/agents", poolReadyMw, agentsMw)
|
||||||
|
g.GET("", listHandler)
|
||||||
|
g.POST("", createHandler)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Per-route middleware (localai pattern)
|
||||||
|
```go
|
||||||
|
// Individual routes get middleware as extra arguments
|
||||||
|
router.POST("/models/apply", applyHandler, adminMiddleware)
|
||||||
|
router.GET("/metrics", metricsHandler, adminMiddleware)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Middleware slice (openai pattern)
|
||||||
|
```go
|
||||||
|
// Build a middleware chain for a handler
|
||||||
|
chatMiddleware := []echo.MiddlewareFunc{
|
||||||
|
usageMiddleware,
|
||||||
|
traceMiddleware,
|
||||||
|
modelFilterMiddleware,
|
||||||
|
}
|
||||||
|
app.POST("/v1/chat/completions", chatHandler, chatMiddleware...)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Error response format
|
||||||
|
|
||||||
|
Always use `schema.ErrorResponse` for auth/permission errors to stay consistent with the OpenAI-compatible API:
|
||||||
|
|
||||||
|
```go
|
||||||
|
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
|
||||||
|
Error: &schema.APIError{
|
||||||
|
Message: "feature not enabled for your account",
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
Type: "authorization_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
Use these HTTP status codes:
|
||||||
|
- `401 Unauthorized` — no valid credentials provided
|
||||||
|
- `403 Forbidden` — authenticated but lacking permission
|
||||||
|
- `429 Too Many Requests` — rate limited (auth endpoints)
|
||||||
|
|
||||||
|
## Usage tracking
|
||||||
|
|
||||||
|
If your endpoint should be tracked for usage (token counts, request counts), add the `usageMiddleware` to its middleware chain. See `core/http/middleware/usage.go` and how it's applied in `routes/openai.go`.
|
||||||
|
|
||||||
|
## Path protection rules
|
||||||
|
|
||||||
|
The global auth middleware classifies paths as API paths or non-API paths:
|
||||||
|
|
||||||
|
- **API paths** (always require auth when auth is enabled): `/api/`, `/v1/`, `/models/`, `/backends/`, `/backend/`, `/tts`, `/vad`, `/video`, `/stores/`, `/system`, `/ws/`, `/metrics`
|
||||||
|
- **Exempt paths** (never require auth): `/api/auth/` prefix, anything in `appConfig.PathWithoutAuth`
|
||||||
|
- **Non-API paths** (UI, static assets): pass through without auth — the React UI handles login redirects client-side
|
||||||
|
|
||||||
|
If you add endpoints under a new top-level path prefix, add it to `isAPIPath()` in `middleware.go` to ensure it requires authentication.
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
When adding a new endpoint:
|
||||||
|
|
||||||
|
- [ ] Handler in `core/http/endpoints/`
|
||||||
|
- [ ] Route registered in appropriate `core/http/routes/` file
|
||||||
|
- [ ] Auth level chosen: public / standard / admin / feature-gated
|
||||||
|
- [ ] If feature-gated: constant in `permissions.go`, metadata in `features.go`, middleware in `app.go`
|
||||||
|
- [ ] If new path prefix: added to `isAPIPath()` in `middleware.go`
|
||||||
|
- [ ] If OpenAI-compatible: entry in `RouteFeatureRegistry`
|
||||||
|
- [ ] If token-counting: `usageMiddleware` added to middleware chain
|
||||||
|
- [ ] Error responses use `schema.ErrorResponse` format
|
||||||
|
- [ ] Tests cover both authenticated and unauthenticated access
|
||||||
@@ -49,3 +49,4 @@ The project documentation is located in `docs/content`. When adding new features
|
|||||||
- **Feature Documentation**: If you add a new feature (like a new backend or API endpoint), create a new markdown file in `docs/content/features/` explaining what it is, how to configure it, and how to use it.
|
- **Feature Documentation**: If you add a new feature (like a new backend or API endpoint), create a new markdown file in `docs/content/features/` explaining what it is, how to configure it, and how to use it.
|
||||||
- **Configuration**: If you modify configuration options, update the relevant sections in `docs/content/`.
|
- **Configuration**: If you modify configuration options, update the relevant sections in `docs/content/`.
|
||||||
- **Examples**: providing concrete examples (like YAML configuration blocks) is highly encouraged to help users get started quickly.
|
- **Examples**: providing concrete examples (like YAML configuration blocks) is highly encouraged to help users get started quickly.
|
||||||
|
- **Shortcodes**: Use `{{% notice note %}}`, `{{% notice tip %}}`, or `{{% notice warning %}}` for callout boxes. Do **not** use `{{% alert %}}` — that shortcode does not exist in this project's Hugo theme and will break the docs build.
|
||||||
|
|||||||
141
.agents/debugging-backends.md
Normal file
141
.agents/debugging-backends.md
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# Debugging and Rebuilding Backends
|
||||||
|
|
||||||
|
When a backend fails at runtime (e.g. a gRPC method error, a Python import error, or a dependency conflict), use this guide to diagnose, fix, and rebuild.
|
||||||
|
|
||||||
|
## Architecture Overview
|
||||||
|
|
||||||
|
- **Source directory**: `backend/python/<name>/` (or `backend/go/<name>/`, `backend/cpp/<name>/`)
|
||||||
|
- **Installed directory**: `backends/<name>/` — this is what LocalAI actually runs. It is populated by `make backends/<name>` which builds a Docker image, exports it, and installs it via `local-ai backends install`.
|
||||||
|
- **Virtual environment**: `backends/<name>/venv/` — the installed Python venv (for Python backends). The Python binary is at `backends/<name>/venv/bin/python`.
|
||||||
|
|
||||||
|
Editing files in `backend/python/<name>/` does **not** affect the running backend until you rebuild with `make backends/<name>`.
|
||||||
|
|
||||||
|
## Diagnosing Failures
|
||||||
|
|
||||||
|
### 1. Check the logs
|
||||||
|
|
||||||
|
Backend gRPC processes log to LocalAI's stdout/stderr. Look for lines tagged with the backend's model ID:
|
||||||
|
|
||||||
|
```
|
||||||
|
GRPC stderr id="trl-finetune-127.0.0.1:37335" line="..."
|
||||||
|
```
|
||||||
|
|
||||||
|
Common error patterns:
|
||||||
|
- **"Method not implemented"** — the backend is missing a gRPC method that the Go side calls. The model loader (`pkg/model/initializers.go`) always calls `LoadModel` after `Health`; fine-tuning backends must implement it even as a no-op stub.
|
||||||
|
- **Python import errors / `AttributeError`** — usually a dependency version mismatch (e.g. `pyarrow` removing `PyExtensionType`).
|
||||||
|
- **"failed to load backend"** — the gRPC process crashed or never started. Check stderr lines for the traceback.
|
||||||
|
|
||||||
|
### 2. Test the Python environment directly
|
||||||
|
|
||||||
|
You can run the installed venv's Python to check imports without starting the full server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
backends/<name>/venv/bin/python -c "import datasets; print(datasets.__version__)"
|
||||||
|
```
|
||||||
|
|
||||||
|
If `pip` is missing from the venv, bootstrap it:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
backends/<name>/venv/bin/python -m ensurepip
|
||||||
|
```
|
||||||
|
|
||||||
|
Then use `backends/<name>/venv/bin/python -m pip install ...` to test fixes in the installed venv before committing them to the source requirements.
|
||||||
|
|
||||||
|
### 3. Check upstream dependency constraints
|
||||||
|
|
||||||
|
When you hit a dependency conflict, check what the main library expects. For example, TRL's upstream `requirements.txt`:
|
||||||
|
|
||||||
|
```
|
||||||
|
https://github.com/huggingface/trl/blob/main/requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Pin minimum versions in the backend's requirements files to match upstream.
|
||||||
|
|
||||||
|
## Common Fixes
|
||||||
|
|
||||||
|
### Missing gRPC methods
|
||||||
|
|
||||||
|
If the Go side calls a method the backend doesn't implement (e.g. `LoadModel`), add a no-op stub in `backend.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def LoadModel(self, request, context):
|
||||||
|
"""No-op — actual loading happens elsewhere."""
|
||||||
|
return backend_pb2.Result(success=True, message="OK")
|
||||||
|
```
|
||||||
|
|
||||||
|
The gRPC contract requires `LoadModel` to succeed for the model loader to return a usable client, even if the backend doesn't need upfront model loading.
|
||||||
|
|
||||||
|
### Dependency version conflicts
|
||||||
|
|
||||||
|
Python backends often break when a transitive dependency releases a breaking change (e.g. `pyarrow` removing `PyExtensionType`). Steps:
|
||||||
|
|
||||||
|
1. Identify the broken import in the logs
|
||||||
|
2. Test in the installed venv: `backends/<name>/venv/bin/python -c "import <module>"`
|
||||||
|
3. Check upstream requirements for version constraints
|
||||||
|
4. Update **all** requirements files in `backend/python/<name>/`:
|
||||||
|
- `requirements.txt` — base deps (grpcio, protobuf)
|
||||||
|
- `requirements-cpu.txt` — CPU-specific (includes PyTorch CPU index)
|
||||||
|
- `requirements-cublas12.txt` — CUDA 12
|
||||||
|
- `requirements-cublas13.txt` — CUDA 13
|
||||||
|
5. Rebuild: `make backends/<name>`
|
||||||
|
|
||||||
|
### PyTorch index conflicts (uv resolver)
|
||||||
|
|
||||||
|
The Docker build uses `uv` for pip installs. When `--extra-index-url` points to the PyTorch wheel index, `uv` may refuse to fetch packages like `requests` from PyPI if it finds a different version on the PyTorch index first. Fix this by adding `--index-strategy=unsafe-first-match` to `install.sh`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||||
|
installRequirements
|
||||||
|
```
|
||||||
|
|
||||||
|
Most Python backends already do this — check `backend/python/transformers/install.sh` or similar for reference.
|
||||||
|
|
||||||
|
## Rebuilding
|
||||||
|
|
||||||
|
### Rebuild a single backend
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make backends/<name>
|
||||||
|
```
|
||||||
|
|
||||||
|
This runs the Docker build (`Dockerfile.python`), exports the image to `backend-images/<name>.tar`, and installs it into `backends/<name>/`. It also rebuilds the `local-ai` Go binary (without extra tags).
|
||||||
|
|
||||||
|
**Important**: If you were previously running with `GO_TAGS=auth`, the `make backends/<name>` step will overwrite your binary without that tag. Rebuild the Go binary afterward:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
GO_TAGS=auth make build
|
||||||
|
```
|
||||||
|
|
||||||
|
### Rebuild and restart
|
||||||
|
|
||||||
|
After rebuilding a backend, you must restart LocalAI for it to pick up the new backend files. The backend gRPC process is spawned on demand when the model is first loaded.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Kill existing process
|
||||||
|
kill <pid>
|
||||||
|
|
||||||
|
# Restart
|
||||||
|
./local-ai run --debug [your flags]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Quick iteration (skip Docker rebuild)
|
||||||
|
|
||||||
|
For fast iteration on a Python backend's `backend.py` without a full Docker rebuild, you can edit the installed copy directly:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Edit the installed copy
|
||||||
|
vim backends/<name>/backend.py
|
||||||
|
|
||||||
|
# Restart LocalAI to respawn the gRPC process
|
||||||
|
```
|
||||||
|
|
||||||
|
This is useful for testing but **does not persist** — the next `make backends/<name>` will overwrite it. Always commit fixes to the source in `backend/python/<name>/`.
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
|
||||||
|
After fixing and rebuilding:
|
||||||
|
|
||||||
|
1. Start LocalAI and confirm the backend registers: look for `Registering backend name="<name>"` in the logs
|
||||||
|
2. Trigger the operation that failed (e.g. start a fine-tuning job)
|
||||||
|
3. Watch the GRPC stderr/stdout lines for the backend's model ID
|
||||||
|
4. Confirm no errors in the traceback
|
||||||
3
.github/gallery-agent/agent.go
vendored
3
.github/gallery-agent/agent.go
vendored
@@ -133,6 +133,7 @@ func getRealReadme(ctx context.Context, repository string) (string, error) {
|
|||||||
result, err := cogito.ExecuteTools(llm, fragment,
|
result, err := cogito.ExecuteTools(llm, fragment,
|
||||||
cogito.WithIterations(3),
|
cogito.WithIterations(3),
|
||||||
cogito.WithMaxAttempts(3),
|
cogito.WithMaxAttempts(3),
|
||||||
|
cogito.DisableSinkState,
|
||||||
cogito.WithTools(&HFReadmeTool{client: hfapi.NewClient()}))
|
cogito.WithTools(&HFReadmeTool{client: hfapi.NewClient()}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -406,7 +407,7 @@ func getHuggingFaceAvatarURL(author string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse the response to get avatar URL
|
// Parse the response to get avatar URL
|
||||||
var userInfo map[string]interface{}
|
var userInfo map[string]any
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
15
.github/gallery-agent/gallery.go
vendored
15
.github/gallery-agent/gallery.go
vendored
@@ -79,7 +79,20 @@ func generateYAMLEntry(model ProcessedModel, quantization string) string {
|
|||||||
description = cleanTextContent(description)
|
description = cleanTextContent(description)
|
||||||
formattedDescription := formatTextContent(description)
|
formattedDescription := formatTextContent(description)
|
||||||
|
|
||||||
configFile := formatTextContent(modelConfig.ConfigFile)
|
// Strip name and description from config file since they are
|
||||||
|
// already present at the gallery entry level and should not
|
||||||
|
// appear under overrides.
|
||||||
|
configFileContent := modelConfig.ConfigFile
|
||||||
|
var cfgMap map[string]any
|
||||||
|
if err := yaml.Unmarshal([]byte(configFileContent), &cfgMap); err == nil {
|
||||||
|
delete(cfgMap, "name")
|
||||||
|
delete(cfgMap, "description")
|
||||||
|
if cleaned, err := yaml.Marshal(cfgMap); err == nil {
|
||||||
|
configFileContent = string(cleaned)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
configFile := formatTextContent(configFileContent)
|
||||||
|
|
||||||
filesYAML, _ := yaml.Marshal(modelConfig.Files)
|
filesYAML, _ := yaml.Marshal(modelConfig.Files)
|
||||||
|
|
||||||
|
|||||||
40
.github/gallery-agent/testing.go
vendored
40
.github/gallery-agent/testing.go
vendored
@@ -3,7 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand/v2"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -13,11 +13,11 @@ func runSyntheticMode() error {
|
|||||||
generator := NewSyntheticDataGenerator()
|
generator := NewSyntheticDataGenerator()
|
||||||
|
|
||||||
// Generate a random number of synthetic models (1-3)
|
// Generate a random number of synthetic models (1-3)
|
||||||
numModels := generator.rand.Intn(3) + 1
|
numModels := generator.rand.IntN(3) + 1
|
||||||
fmt.Printf("Generating %d synthetic models for testing...\n", numModels)
|
fmt.Printf("Generating %d synthetic models for testing...\n", numModels)
|
||||||
|
|
||||||
var models []ProcessedModel
|
var models []ProcessedModel
|
||||||
for i := 0; i < numModels; i++ {
|
for range numModels {
|
||||||
model := generator.GenerateProcessedModel()
|
model := generator.GenerateProcessedModel()
|
||||||
models = append(models, model)
|
models = append(models, model)
|
||||||
fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
|
fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
|
||||||
@@ -42,14 +42,14 @@ type SyntheticDataGenerator struct {
|
|||||||
// NewSyntheticDataGenerator creates a new synthetic data generator
|
// NewSyntheticDataGenerator creates a new synthetic data generator
|
||||||
func NewSyntheticDataGenerator() *SyntheticDataGenerator {
|
func NewSyntheticDataGenerator() *SyntheticDataGenerator {
|
||||||
return &SyntheticDataGenerator{
|
return &SyntheticDataGenerator{
|
||||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
rand: rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateProcessedModelFile creates a synthetic ProcessedModelFile
|
// GenerateProcessedModelFile creates a synthetic ProcessedModelFile
|
||||||
func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile {
|
func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile {
|
||||||
fileTypes := []string{"model", "readme", "other"}
|
fileTypes := []string{"model", "readme", "other"}
|
||||||
fileType := fileTypes[g.rand.Intn(len(fileTypes))]
|
fileType := fileTypes[g.rand.IntN(len(fileTypes))]
|
||||||
|
|
||||||
var path string
|
var path string
|
||||||
var isReadme bool
|
var isReadme bool
|
||||||
@@ -68,7 +68,7 @@ func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile
|
|||||||
|
|
||||||
return ProcessedModelFile{
|
return ProcessedModelFile{
|
||||||
Path: path,
|
Path: path,
|
||||||
Size: int64(g.rand.Intn(1000000000) + 1000000), // 1MB to 1GB
|
Size: int64(g.rand.IntN(1000000000) + 1000000), // 1MB to 1GB
|
||||||
SHA256: g.randomSHA256(),
|
SHA256: g.randomSHA256(),
|
||||||
IsReadme: isReadme,
|
IsReadme: isReadme,
|
||||||
FileType: fileType,
|
FileType: fileType,
|
||||||
@@ -80,19 +80,19 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
|
|||||||
authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"}
|
authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"}
|
||||||
modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"}
|
modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"}
|
||||||
|
|
||||||
author := authors[g.rand.Intn(len(authors))]
|
author := authors[g.rand.IntN(len(authors))]
|
||||||
modelName := modelNames[g.rand.Intn(len(modelNames))]
|
modelName := modelNames[g.rand.IntN(len(modelNames))]
|
||||||
modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6))
|
modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6))
|
||||||
|
|
||||||
// Generate files
|
// Generate files
|
||||||
numFiles := g.rand.Intn(5) + 2 // 2-6 files
|
numFiles := g.rand.IntN(5) + 2 // 2-6 files
|
||||||
files := make([]ProcessedModelFile, numFiles)
|
files := make([]ProcessedModelFile, numFiles)
|
||||||
|
|
||||||
// Ensure at least one model file and one readme
|
// Ensure at least one model file and one readme
|
||||||
hasModelFile := false
|
hasModelFile := false
|
||||||
hasReadme := false
|
hasReadme := false
|
||||||
|
|
||||||
for i := 0; i < numFiles; i++ {
|
for i := range numFiles {
|
||||||
files[i] = g.GenerateProcessedModelFile()
|
files[i] = g.GenerateProcessedModelFile()
|
||||||
if files[i].FileType == "model" {
|
if files[i].FileType == "model" {
|
||||||
hasModelFile = true
|
hasModelFile = true
|
||||||
@@ -140,27 +140,27 @@ func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
|
|||||||
|
|
||||||
// Generate sample metadata
|
// Generate sample metadata
|
||||||
licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""}
|
licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""}
|
||||||
license := licenses[g.rand.Intn(len(licenses))]
|
license := licenses[g.rand.IntN(len(licenses))]
|
||||||
|
|
||||||
sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"}
|
sampleTags := []string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"}
|
||||||
numTags := g.rand.Intn(4) + 3 // 3-6 tags
|
numTags := g.rand.IntN(4) + 3 // 3-6 tags
|
||||||
tags := make([]string, numTags)
|
tags := make([]string, numTags)
|
||||||
for i := 0; i < numTags; i++ {
|
for i := range numTags {
|
||||||
tags[i] = sampleTags[g.rand.Intn(len(sampleTags))]
|
tags[i] = sampleTags[g.rand.IntN(len(sampleTags))]
|
||||||
}
|
}
|
||||||
// Remove duplicates
|
// Remove duplicates
|
||||||
tags = g.removeDuplicates(tags)
|
tags = g.removeDuplicates(tags)
|
||||||
|
|
||||||
// Optionally include icon (50% chance)
|
// Optionally include icon (50% chance)
|
||||||
icon := ""
|
icon := ""
|
||||||
if g.rand.Intn(2) == 0 {
|
if g.rand.IntN(2) == 0 {
|
||||||
icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24))
|
icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24))
|
||||||
}
|
}
|
||||||
|
|
||||||
return ProcessedModel{
|
return ProcessedModel{
|
||||||
ModelID: modelID,
|
ModelID: modelID,
|
||||||
Author: author,
|
Author: author,
|
||||||
Downloads: g.rand.Intn(1000000) + 1000,
|
Downloads: g.rand.IntN(1000000) + 1000,
|
||||||
LastModified: g.randomDate(),
|
LastModified: g.randomDate(),
|
||||||
Files: files,
|
Files: files,
|
||||||
PreferredModelFile: preferredModelFile,
|
PreferredModelFile: preferredModelFile,
|
||||||
@@ -180,7 +180,7 @@ func (g *SyntheticDataGenerator) randomString(length int) string {
|
|||||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
b := make([]byte, length)
|
b := make([]byte, length)
|
||||||
for i := range b {
|
for i := range b {
|
||||||
b[i] = charset[g.rand.Intn(len(charset))]
|
b[i] = charset[g.rand.IntN(len(charset))]
|
||||||
}
|
}
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
@@ -189,14 +189,14 @@ func (g *SyntheticDataGenerator) randomSHA256() string {
|
|||||||
const charset = "0123456789abcdef"
|
const charset = "0123456789abcdef"
|
||||||
b := make([]byte, 64)
|
b := make([]byte, 64)
|
||||||
for i := range b {
|
for i := range b {
|
||||||
b[i] = charset[g.rand.Intn(len(charset))]
|
b[i] = charset[g.rand.IntN(len(charset))]
|
||||||
}
|
}
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *SyntheticDataGenerator) randomDate() string {
|
func (g *SyntheticDataGenerator) randomDate() string {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
daysAgo := g.rand.Intn(365) // Random date within last year
|
daysAgo := g.rand.IntN(365) // Random date within last year
|
||||||
pastDate := now.AddDate(0, 0, -daysAgo)
|
pastDate := now.AddDate(0, 0, -daysAgo)
|
||||||
return pastDate.Format("2006-01-02T15:04:05.000Z")
|
return pastDate.Format("2006-01-02T15:04:05.000Z")
|
||||||
}
|
}
|
||||||
@@ -220,5 +220,5 @@ func (g *SyntheticDataGenerator) generateReadmeContent(modelName, author string)
|
|||||||
fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author),
|
fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author),
|
||||||
}
|
}
|
||||||
|
|
||||||
return templates[g.rand.Intn(len(templates))]
|
return templates[g.rand.IntN(len(templates))]
|
||||||
}
|
}
|
||||||
|
|||||||
55
.github/workflows/backend.yml
vendored
55
.github/workflows/backend.yml
vendored
@@ -118,6 +118,32 @@ jobs:
|
|||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./"
|
context: "./"
|
||||||
ubuntu-version: '2404'
|
ubuntu-version: '2404'
|
||||||
|
- build-type: ''
|
||||||
|
cuda-major-version: ""
|
||||||
|
cuda-minor-version: ""
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-cpu-trl'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'true'
|
||||||
|
backend: "trl"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
|
- build-type: ''
|
||||||
|
cuda-major-version: ""
|
||||||
|
cuda-minor-version: ""
|
||||||
|
platforms: 'linux/amd64,linux/arm64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-cpu-llama-cpp-quantization'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'true'
|
||||||
|
backend: "llama-cpp-quantization"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
- build-type: ''
|
- build-type: ''
|
||||||
cuda-major-version: ""
|
cuda-major-version: ""
|
||||||
cuda-minor-version: ""
|
cuda-minor-version: ""
|
||||||
@@ -366,6 +392,19 @@ jobs:
|
|||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./"
|
context: "./"
|
||||||
ubuntu-version: '2404'
|
ubuntu-version: '2404'
|
||||||
|
- build-type: 'cublas'
|
||||||
|
cuda-major-version: "12"
|
||||||
|
cuda-minor-version: "8"
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-gpu-nvidia-cuda-12-trl'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
backend: "trl"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
- build-type: 'cublas'
|
- build-type: 'cublas'
|
||||||
cuda-major-version: "12"
|
cuda-major-version: "12"
|
||||||
cuda-minor-version: "8"
|
cuda-minor-version: "8"
|
||||||
@@ -757,6 +796,19 @@ jobs:
|
|||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./"
|
context: "./"
|
||||||
ubuntu-version: '2404'
|
ubuntu-version: '2404'
|
||||||
|
- build-type: 'cublas'
|
||||||
|
cuda-major-version: "13"
|
||||||
|
cuda-minor-version: "0"
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-gpu-nvidia-cuda-13-trl'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
backend: "trl"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
- build-type: 'l4t'
|
- build-type: 'l4t'
|
||||||
cuda-major-version: "13"
|
cuda-major-version: "13"
|
||||||
cuda-minor-version: "0"
|
cuda-minor-version: "0"
|
||||||
@@ -2373,6 +2425,9 @@ jobs:
|
|||||||
tag-suffix: "-metal-darwin-arm64-local-store"
|
tag-suffix: "-metal-darwin-arm64-local-store"
|
||||||
build-type: "metal"
|
build-type: "metal"
|
||||||
lang: "go"
|
lang: "go"
|
||||||
|
- backend: "llama-cpp-quantization"
|
||||||
|
tag-suffix: "-metal-darwin-arm64-llama-cpp-quantization"
|
||||||
|
build-type: "mps"
|
||||||
with:
|
with:
|
||||||
backend: ${{ matrix.backend }}
|
backend: ${{ matrix.backend }}
|
||||||
build-type: ${{ matrix.build-type }}
|
build-type: ${{ matrix.build-type }}
|
||||||
|
|||||||
48
.github/workflows/bump-inference-defaults.yml
vendored
Normal file
48
.github/workflows/bump-inference-defaults.yml
vendored
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
name: Bump inference defaults
|
||||||
|
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
# Run daily at 06:00 UTC
|
||||||
|
- cron: '0 6 * * *'
|
||||||
|
workflow_dispatch: # Allow manual trigger
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
bump:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: go.mod
|
||||||
|
|
||||||
|
- name: Re-fetch inference defaults
|
||||||
|
run: make generate-force
|
||||||
|
|
||||||
|
- name: Check for changes
|
||||||
|
id: diff
|
||||||
|
run: |
|
||||||
|
if git diff --quiet core/config/inference_defaults.json; then
|
||||||
|
echo "changed=false" >> "$GITHUB_OUTPUT"
|
||||||
|
else
|
||||||
|
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Create Pull Request
|
||||||
|
if: steps.diff.outputs.changed == 'true'
|
||||||
|
uses: peter-evans/create-pull-request@v8
|
||||||
|
with:
|
||||||
|
commit-message: "chore: bump inference defaults from unsloth"
|
||||||
|
title: "chore: bump inference defaults from unsloth"
|
||||||
|
body: |
|
||||||
|
Auto-generated update of `core/config/inference_defaults.json` from
|
||||||
|
[unsloth's inference_defaults.json](https://github.com/unslothai/unsloth/blob/main/studio/backend/assets/configs/inference_defaults.json).
|
||||||
|
|
||||||
|
This PR was created automatically by the `bump-inference-defaults` workflow.
|
||||||
|
branch: chore/bump-inference-defaults
|
||||||
|
delete-branch: true
|
||||||
|
labels: automated
|
||||||
2
.github/workflows/gallery-agent.yaml
vendored
2
.github/workflows/gallery-agent.yaml
vendored
@@ -55,7 +55,7 @@ jobs:
|
|||||||
- name: Run gallery agent
|
- name: Run gallery agent
|
||||||
env:
|
env:
|
||||||
#OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }}
|
#OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||||
OPENAI_MODE: Qwen3.5-2B-GGUF
|
OPENAI_MODEL: Qwen3.5-2B-GGUF
|
||||||
OPENAI_BASE_URL: "http://localhost:8080"
|
OPENAI_BASE_URL: "http://localhost:8080"
|
||||||
OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
|
OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
|
||||||
#OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }}
|
#OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }}
|
||||||
|
|||||||
75
.github/workflows/gh-pages.yml
vendored
Normal file
75
.github/workflows/gh-pages.yml
vendored
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
name: Deploy docs to GitHub Pages
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- 'docs/**'
|
||||||
|
- 'gallery/**'
|
||||||
|
- 'images/**'
|
||||||
|
- '.github/ci/modelslist.go'
|
||||||
|
- '.github/workflows/gh-pages.yml'
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pages: write
|
||||||
|
id-token: write
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: pages
|
||||||
|
cancel-in-progress: false
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
HUGO_VERSION: "0.146.3"
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
fetch-depth: 0 # needed for enableGitInfo
|
||||||
|
submodules: true
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.22'
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Setup Hugo
|
||||||
|
uses: peaceiris/actions-hugo@v3
|
||||||
|
with:
|
||||||
|
hugo-version: ${{ env.HUGO_VERSION }}
|
||||||
|
extended: true
|
||||||
|
|
||||||
|
- name: Setup Pages
|
||||||
|
id: pages
|
||||||
|
uses: actions/configure-pages@v6
|
||||||
|
|
||||||
|
- name: Generate gallery
|
||||||
|
run: go run ./.github/ci/modelslist.go ./gallery/index.yaml > docs/static/gallery.html
|
||||||
|
|
||||||
|
- name: Build site
|
||||||
|
working-directory: docs
|
||||||
|
run: |
|
||||||
|
mkdir -p layouts/_default
|
||||||
|
hugo --minify --baseURL "${{ steps.pages.outputs.base_url }}/"
|
||||||
|
|
||||||
|
- name: Upload artifact
|
||||||
|
uses: actions/upload-pages-artifact@v4
|
||||||
|
with:
|
||||||
|
path: docs/public
|
||||||
|
|
||||||
|
deploy:
|
||||||
|
environment:
|
||||||
|
name: github-pages
|
||||||
|
url: ${{ steps.deployment.outputs.page_url }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: build
|
||||||
|
steps:
|
||||||
|
- name: Deploy to GitHub Pages
|
||||||
|
id: deployment
|
||||||
|
uses: actions/deploy-pages@v5
|
||||||
84
.github/workflows/test-extra.yml
vendored
84
.github/workflows/test-extra.yml
vendored
@@ -14,6 +14,37 @@ concurrency:
|
|||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
detect-changes:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
outputs:
|
||||||
|
run-all: ${{ steps.detect.outputs.run-all }}
|
||||||
|
transformers: ${{ steps.detect.outputs.transformers }}
|
||||||
|
rerankers: ${{ steps.detect.outputs.rerankers }}
|
||||||
|
diffusers: ${{ steps.detect.outputs.diffusers }}
|
||||||
|
coqui: ${{ steps.detect.outputs.coqui }}
|
||||||
|
moonshine: ${{ steps.detect.outputs.moonshine }}
|
||||||
|
pocket-tts: ${{ steps.detect.outputs.pocket-tts }}
|
||||||
|
qwen-tts: ${{ steps.detect.outputs.qwen-tts }}
|
||||||
|
qwen-asr: ${{ steps.detect.outputs.qwen-asr }}
|
||||||
|
nemo: ${{ steps.detect.outputs.nemo }}
|
||||||
|
voxcpm: ${{ steps.detect.outputs.voxcpm }}
|
||||||
|
llama-cpp-quantization: ${{ steps.detect.outputs.llama-cpp-quantization }}
|
||||||
|
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
||||||
|
voxtral: ${{ steps.detect.outputs.voxtral }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
- name: Setup Bun
|
||||||
|
uses: oven-sh/setup-bun@v2
|
||||||
|
- name: Install dependencies
|
||||||
|
run: bun add js-yaml @octokit/core
|
||||||
|
- name: Detect changed backends
|
||||||
|
id: detect
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
GITHUB_EVENT_PATH: ${{ github.event_path }}
|
||||||
|
run: bun run scripts/changed-backends.js
|
||||||
|
|
||||||
# Requires CUDA
|
# Requires CUDA
|
||||||
# tests-chatterbox-tts:
|
# tests-chatterbox-tts:
|
||||||
# runs-on: ubuntu-latest
|
# runs-on: ubuntu-latest
|
||||||
@@ -37,6 +68,8 @@ jobs:
|
|||||||
# make --jobs=5 --output-sync=target -C backend/python/chatterbox
|
# make --jobs=5 --output-sync=target -C backend/python/chatterbox
|
||||||
# make --jobs=5 --output-sync=target -C backend/python/chatterbox test
|
# make --jobs=5 --output-sync=target -C backend/python/chatterbox test
|
||||||
tests-transformers:
|
tests-transformers:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.transformers == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@@ -58,6 +91,8 @@ jobs:
|
|||||||
make --jobs=5 --output-sync=target -C backend/python/transformers
|
make --jobs=5 --output-sync=target -C backend/python/transformers
|
||||||
make --jobs=5 --output-sync=target -C backend/python/transformers test
|
make --jobs=5 --output-sync=target -C backend/python/transformers test
|
||||||
tests-rerankers:
|
tests-rerankers:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.rerankers == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@@ -80,6 +115,8 @@ jobs:
|
|||||||
make --jobs=5 --output-sync=target -C backend/python/rerankers test
|
make --jobs=5 --output-sync=target -C backend/python/rerankers test
|
||||||
|
|
||||||
tests-diffusers:
|
tests-diffusers:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.diffusers == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@@ -229,6 +266,8 @@ jobs:
|
|||||||
# make --jobs=5 --output-sync=target -C backend/python/vllm test
|
# make --jobs=5 --output-sync=target -C backend/python/vllm test
|
||||||
|
|
||||||
tests-coqui:
|
tests-coqui:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.coqui == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@@ -248,6 +287,8 @@ jobs:
|
|||||||
make --jobs=5 --output-sync=target -C backend/python/coqui
|
make --jobs=5 --output-sync=target -C backend/python/coqui
|
||||||
make --jobs=5 --output-sync=target -C backend/python/coqui test
|
make --jobs=5 --output-sync=target -C backend/python/coqui test
|
||||||
tests-moonshine:
|
tests-moonshine:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.moonshine == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@@ -267,6 +308,8 @@ jobs:
|
|||||||
make --jobs=5 --output-sync=target -C backend/python/moonshine
|
make --jobs=5 --output-sync=target -C backend/python/moonshine
|
||||||
make --jobs=5 --output-sync=target -C backend/python/moonshine test
|
make --jobs=5 --output-sync=target -C backend/python/moonshine test
|
||||||
tests-pocket-tts:
|
tests-pocket-tts:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.pocket-tts == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@@ -286,6 +329,8 @@ jobs:
|
|||||||
make --jobs=5 --output-sync=target -C backend/python/pocket-tts
|
make --jobs=5 --output-sync=target -C backend/python/pocket-tts
|
||||||
make --jobs=5 --output-sync=target -C backend/python/pocket-tts test
|
make --jobs=5 --output-sync=target -C backend/python/pocket-tts test
|
||||||
tests-qwen-tts:
|
tests-qwen-tts:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.qwen-tts == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@@ -327,6 +372,8 @@ jobs:
|
|||||||
# make --jobs=5 --output-sync=target -C backend/python/fish-speech
|
# make --jobs=5 --output-sync=target -C backend/python/fish-speech
|
||||||
# make --jobs=5 --output-sync=target -C backend/python/fish-speech test
|
# make --jobs=5 --output-sync=target -C backend/python/fish-speech test
|
||||||
tests-qwen-asr:
|
tests-qwen-asr:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.qwen-asr == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@@ -346,6 +393,8 @@ jobs:
|
|||||||
make --jobs=5 --output-sync=target -C backend/python/qwen-asr
|
make --jobs=5 --output-sync=target -C backend/python/qwen-asr
|
||||||
make --jobs=5 --output-sync=target -C backend/python/qwen-asr test
|
make --jobs=5 --output-sync=target -C backend/python/qwen-asr test
|
||||||
tests-nemo:
|
tests-nemo:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.nemo == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@@ -365,6 +414,8 @@ jobs:
|
|||||||
make --jobs=5 --output-sync=target -C backend/python/nemo
|
make --jobs=5 --output-sync=target -C backend/python/nemo
|
||||||
make --jobs=5 --output-sync=target -C backend/python/nemo test
|
make --jobs=5 --output-sync=target -C backend/python/nemo test
|
||||||
tests-voxcpm:
|
tests-voxcpm:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.voxcpm == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@@ -383,7 +434,38 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
make --jobs=5 --output-sync=target -C backend/python/voxcpm
|
make --jobs=5 --output-sync=target -C backend/python/voxcpm
|
||||||
make --jobs=5 --output-sync=target -C backend/python/voxcpm test
|
make --jobs=5 --output-sync=target -C backend/python/voxcpm test
|
||||||
|
tests-llama-cpp-quantization:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.llama-cpp-quantization == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 30
|
||||||
|
steps:
|
||||||
|
- name: Clone
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
submodules: true
|
||||||
|
- name: Dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y build-essential cmake curl git python3-pip
|
||||||
|
# Install UV
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
pip install --user --no-cache-dir grpcio-tools==1.64.1
|
||||||
|
- name: Build llama-quantize from llama.cpp
|
||||||
|
run: |
|
||||||
|
git clone --depth 1 https://github.com/ggml-org/llama.cpp.git /tmp/llama.cpp
|
||||||
|
cmake -B /tmp/llama.cpp/build -S /tmp/llama.cpp -DGGML_NATIVE=OFF
|
||||||
|
cmake --build /tmp/llama.cpp/build --target llama-quantize -j$(nproc)
|
||||||
|
sudo cp /tmp/llama.cpp/build/bin/llama-quantize /usr/local/bin/
|
||||||
|
- name: Install backend
|
||||||
|
run: |
|
||||||
|
make --jobs=5 --output-sync=target -C backend/python/llama-cpp-quantization
|
||||||
|
- name: Test llama-cpp-quantization
|
||||||
|
run: |
|
||||||
|
make --jobs=5 --output-sync=target -C backend/python/llama-cpp-quantization test
|
||||||
tests-acestep-cpp:
|
tests-acestep-cpp:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.acestep-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@@ -414,6 +496,8 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
make --jobs=5 --output-sync=target -C backend/go/acestep-cpp test
|
make --jobs=5 --output-sync=target -C backend/go/acestep-cpp test
|
||||||
tests-voxtral:
|
tests-voxtral:
|
||||||
|
needs: detect-changes
|
||||||
|
if: needs.detect-changes.outputs.voxtral == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
|
|||||||
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go-version: ['1.25.x']
|
go-version: ['1.26.x']
|
||||||
steps:
|
steps:
|
||||||
- name: Free Disk Space (Ubuntu)
|
- name: Free Disk Space (Ubuntu)
|
||||||
uses: jlumbroso/free-disk-space@main
|
uses: jlumbroso/free-disk-space@main
|
||||||
@@ -179,7 +179,7 @@ jobs:
|
|||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go-version: ['1.25.x']
|
go-version: ['1.26.x']
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v6
|
||||||
|
|||||||
72
.github/workflows/tests-ui-e2e.yml
vendored
Normal file
72
.github/workflows/tests-ui-e2e.yml
vendored
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
---
|
||||||
|
name: 'UI E2E Tests'
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'core/http/**'
|
||||||
|
- 'tests/e2e-ui/**'
|
||||||
|
- 'tests/e2e/mock-backend/**'
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ci-tests-ui-e2e-${{ github.head_ref || github.ref }}-${{ github.repository }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
tests-ui-e2e:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
go-version: ['1.26.x']
|
||||||
|
steps:
|
||||||
|
- name: Clone
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
submodules: true
|
||||||
|
- name: Setup Go ${{ matrix.go-version }}
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: ${{ matrix.go-version }}
|
||||||
|
cache: false
|
||||||
|
- name: Setup Node.js
|
||||||
|
uses: actions/setup-node@v6
|
||||||
|
with:
|
||||||
|
node-version: '22'
|
||||||
|
- name: Proto Dependencies
|
||||||
|
run: |
|
||||||
|
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
|
||||||
|
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||||
|
rm protoc.zip
|
||||||
|
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
|
||||||
|
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
|
||||||
|
- name: System Dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y build-essential libopus-dev
|
||||||
|
- name: Build UI test server
|
||||||
|
run: PATH="$PATH:$HOME/go/bin" make build-ui-test-server
|
||||||
|
- name: Install Playwright
|
||||||
|
working-directory: core/http/react-ui
|
||||||
|
run: |
|
||||||
|
npm install
|
||||||
|
npx playwright install --with-deps chromium
|
||||||
|
- name: Run Playwright tests
|
||||||
|
working-directory: core/http/react-ui
|
||||||
|
run: npx playwright test
|
||||||
|
- name: Upload Playwright report
|
||||||
|
if: ${{ failure() }}
|
||||||
|
uses: actions/upload-artifact@v7
|
||||||
|
with:
|
||||||
|
name: playwright-report
|
||||||
|
path: core/http/react-ui/playwright-report/
|
||||||
|
retention-days: 7
|
||||||
|
- name: Setup tmate session if tests fail
|
||||||
|
if: ${{ failure() }}
|
||||||
|
uses: mxschmitt/action-tmate@v3.23
|
||||||
|
with:
|
||||||
|
detached: true
|
||||||
|
connect-timeout-seconds: 180
|
||||||
|
limit-access-to-actor: true
|
||||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -72,3 +72,8 @@ core/http/react-ui/dist
|
|||||||
|
|
||||||
# Extracted backend binaries for container-based testing
|
# Extracted backend binaries for container-based testing
|
||||||
local-backends/
|
local-backends/
|
||||||
|
|
||||||
|
# UI E2E test artifacts
|
||||||
|
tests/e2e-ui/ui-test-server
|
||||||
|
core/http/react-ui/playwright-report/
|
||||||
|
core/http/react-ui/test-results/
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ This file is an index to detailed topic guides in the `.agents/` directory. Read
|
|||||||
| [.agents/coding-style.md](.agents/coding-style.md) | Code style, editorconfig, logging, documentation conventions |
|
| [.agents/coding-style.md](.agents/coding-style.md) | Code style, editorconfig, logging, documentation conventions |
|
||||||
| [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing |
|
| [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing |
|
||||||
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
||||||
|
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
||||||
|
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||||
|
|
||||||
## Quick Reference
|
## Quick Reference
|
||||||
|
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ ENV PATH=/opt/rocm/bin:${PATH}
|
|||||||
# The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it.
|
# The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it.
|
||||||
FROM requirements-drivers AS build-requirements
|
FROM requirements-drivers AS build-requirements
|
||||||
|
|
||||||
ARG GO_VERSION=1.25.4
|
ARG GO_VERSION=1.26.0
|
||||||
ARG CMAKE_VERSION=3.31.10
|
ARG CMAKE_VERSION=3.31.10
|
||||||
ARG CMAKE_FROM_SOURCE=false
|
ARG CMAKE_FROM_SOURCE=false
|
||||||
ARG TARGETARCH
|
ARG TARGETARCH
|
||||||
@@ -256,7 +256,7 @@ RUN apt-get update && \
|
|||||||
|
|
||||||
FROM build-requirements AS builder-base
|
FROM build-requirements AS builder-base
|
||||||
|
|
||||||
ARG GO_TAGS=""
|
ARG GO_TAGS="auth"
|
||||||
ARG GRPC_BACKENDS
|
ARG GRPC_BACKENDS
|
||||||
ARG MAKEFLAGS
|
ARG MAKEFLAGS
|
||||||
ARG LD_FLAGS="-s -w"
|
ARG LD_FLAGS="-s -w"
|
||||||
@@ -319,7 +319,6 @@ COPY ./.git ./.git
|
|||||||
# Some of the Go backends use libs from the main src, we could further optimize the caching by building the CPP backends before here
|
# Some of the Go backends use libs from the main src, we could further optimize the caching by building the CPP backends before here
|
||||||
COPY ./pkg/grpc ./pkg/grpc
|
COPY ./pkg/grpc ./pkg/grpc
|
||||||
COPY ./pkg/utils ./pkg/utils
|
COPY ./pkg/utils ./pkg/utils
|
||||||
COPY ./pkg/langchain ./pkg/langchain
|
|
||||||
|
|
||||||
RUN ls -l ./
|
RUN ls -l ./
|
||||||
RUN make protogen-go
|
RUN make protogen-go
|
||||||
|
|||||||
39
Makefile
39
Makefile
@@ -1,5 +1,5 @@
|
|||||||
# Disable parallel execution for backend builds
|
# Disable parallel execution for backend builds
|
||||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus
|
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization
|
||||||
|
|
||||||
GOCMD=go
|
GOCMD=go
|
||||||
GOTEST=$(GOCMD) test
|
GOTEST=$(GOCMD) test
|
||||||
@@ -107,7 +107,7 @@ core/http/react-ui/dist: react-ui
|
|||||||
|
|
||||||
## Build:
|
## Build:
|
||||||
|
|
||||||
build: protogen-go install-go-tools core/http/react-ui/dist ## Build the project
|
build: protogen-go generate install-go-tools core/http/react-ui/dist ## Build the project
|
||||||
$(info ${GREEN}I local-ai build info:${RESET})
|
$(info ${GREEN}I local-ai build info:${RESET})
|
||||||
$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
|
$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
|
||||||
$(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET})
|
$(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET})
|
||||||
@@ -398,6 +398,16 @@ protogen-go: protoc install-go-tools
|
|||||||
./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \
|
./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \
|
||||||
backend/backend.proto
|
backend/backend.proto
|
||||||
|
|
||||||
|
core/config/inference_defaults.json: ## Fetch inference defaults from unsloth (only if missing)
|
||||||
|
$(GOCMD) generate ./core/config/...
|
||||||
|
|
||||||
|
.PHONY: generate
|
||||||
|
generate: core/config/inference_defaults.json ## Ensure inference defaults exist
|
||||||
|
|
||||||
|
.PHONY: generate-force
|
||||||
|
generate-force: ## Re-fetch inference defaults from unsloth (always)
|
||||||
|
$(GOCMD) generate ./core/config/...
|
||||||
|
|
||||||
.PHONY: protogen-go-clean
|
.PHONY: protogen-go-clean
|
||||||
protogen-go-clean:
|
protogen-go-clean:
|
||||||
$(RM) pkg/grpc/proto/backend.pb.go pkg/grpc/proto/backend_grpc.pb.go
|
$(RM) pkg/grpc/proto/backend.pb.go pkg/grpc/proto/backend_grpc.pb.go
|
||||||
@@ -421,6 +431,7 @@ prepare-test-extra: protogen-python
|
|||||||
$(MAKE) -C backend/python/voxcpm
|
$(MAKE) -C backend/python/voxcpm
|
||||||
$(MAKE) -C backend/python/whisperx
|
$(MAKE) -C backend/python/whisperx
|
||||||
$(MAKE) -C backend/python/ace-step
|
$(MAKE) -C backend/python/ace-step
|
||||||
|
$(MAKE) -C backend/python/trl
|
||||||
|
|
||||||
test-extra: prepare-test-extra
|
test-extra: prepare-test-extra
|
||||||
$(MAKE) -C backend/python/transformers test
|
$(MAKE) -C backend/python/transformers test
|
||||||
@@ -440,6 +451,7 @@ test-extra: prepare-test-extra
|
|||||||
$(MAKE) -C backend/python/voxcpm test
|
$(MAKE) -C backend/python/voxcpm test
|
||||||
$(MAKE) -C backend/python/whisperx test
|
$(MAKE) -C backend/python/whisperx test
|
||||||
$(MAKE) -C backend/python/ace-step test
|
$(MAKE) -C backend/python/ace-step test
|
||||||
|
$(MAKE) -C backend/python/trl test
|
||||||
|
|
||||||
DOCKER_IMAGE?=local-ai
|
DOCKER_IMAGE?=local-ai
|
||||||
IMAGE_TYPE?=core
|
IMAGE_TYPE?=core
|
||||||
@@ -572,6 +584,8 @@ BACKEND_VOXCPM = voxcpm|python|.|false|true
|
|||||||
BACKEND_WHISPERX = whisperx|python|.|false|true
|
BACKEND_WHISPERX = whisperx|python|.|false|true
|
||||||
BACKEND_ACE_STEP = ace-step|python|.|false|true
|
BACKEND_ACE_STEP = ace-step|python|.|false|true
|
||||||
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
|
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
|
||||||
|
BACKEND_TRL = trl|python|.|false|true
|
||||||
|
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
|
||||||
|
|
||||||
# Helper function to build docker image for a backend
|
# Helper function to build docker image for a backend
|
||||||
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
||||||
@@ -629,12 +643,14 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
|
|||||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
||||||
|
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
|
||||||
|
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
|
||||||
|
|
||||||
# Pattern rule for docker-save targets
|
# Pattern rule for docker-save targets
|
||||||
docker-save-%: backend-images
|
docker-save-%: backend-images
|
||||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||||
|
|
||||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed
|
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
### Mock Backend for E2E Tests
|
### Mock Backend for E2E Tests
|
||||||
@@ -646,6 +662,23 @@ build-mock-backend: protogen-go
|
|||||||
clean-mock-backend:
|
clean-mock-backend:
|
||||||
rm -f tests/e2e/mock-backend/mock-backend
|
rm -f tests/e2e/mock-backend/mock-backend
|
||||||
|
|
||||||
|
########################################################
|
||||||
|
### UI E2E Test Server
|
||||||
|
########################################################
|
||||||
|
|
||||||
|
build-ui-test-server: build-mock-backend react-ui protogen-go
|
||||||
|
$(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui
|
||||||
|
|
||||||
|
test-ui-e2e: build-ui-test-server
|
||||||
|
cd core/http/react-ui && npm install && npx playwright install --with-deps chromium && npx playwright test
|
||||||
|
|
||||||
|
test-ui-e2e-docker:
|
||||||
|
docker build -t localai-ui-e2e -f tests/e2e-ui/Dockerfile .
|
||||||
|
docker run --rm localai-ui-e2e
|
||||||
|
|
||||||
|
clean-ui-test-server:
|
||||||
|
rm -f tests/e2e-ui/ui-test-server
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
### END Backends
|
### END Backends
|
||||||
########################################################
|
########################################################
|
||||||
|
|||||||
398
README.md
398
README.md
@@ -5,35 +5,17 @@
|
|||||||
</h1>
|
</h1>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://github.com/go-skynet/LocalAI/fork" target="blank">
|
|
||||||
<img src="https://img.shields.io/github/forks/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI forks"/>
|
|
||||||
</a>
|
|
||||||
<a href="https://github.com/go-skynet/LocalAI/stargazers" target="blank">
|
<a href="https://github.com/go-skynet/LocalAI/stargazers" target="blank">
|
||||||
<img src="https://img.shields.io/github/stars/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI stars"/>
|
<img src="https://img.shields.io/github/stars/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI stars"/>
|
||||||
</a>
|
</a>
|
||||||
<a href="https://github.com/go-skynet/LocalAI/pulls" target="blank">
|
|
||||||
<img src="https://img.shields.io/github/issues-pr/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI pull-requests"/>
|
|
||||||
</a>
|
|
||||||
<a href='https://github.com/go-skynet/LocalAI/releases'>
|
<a href='https://github.com/go-skynet/LocalAI/releases'>
|
||||||
<img src='https://img.shields.io/github/release/go-skynet/LocalAI?&label=Latest&style=for-the-badge'>
|
<img src='https://img.shields.io/github/release/go-skynet/LocalAI?&label=Latest&style=for-the-badge'>
|
||||||
</a>
|
</a>
|
||||||
</p>
|
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<a href="LICENSE" target="blank">
|
<a href="LICENSE" target="blank">
|
||||||
<img src="https://img.shields.io/badge/License-MIT-yellow.svg?style=for-the-badge" alt="LocalAI License"/>
|
<img src="https://img.shields.io/badge/License-MIT-yellow.svg?style=for-the-badge" alt="LocalAI License"/>
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<a href="https://hub.docker.com/r/localai/localai" target="blank">
|
|
||||||
<img src="https://img.shields.io/badge/dockerhub-images-important.svg?logo=Docker" alt="LocalAI Docker hub"/>
|
|
||||||
</a>
|
|
||||||
<a href="https://quay.io/repository/go-skynet/local-ai?tab=tags&tag=latest" target="blank">
|
|
||||||
<img src="https://img.shields.io/badge/quay.io-images-important.svg?" alt="LocalAI Quay.io"/>
|
|
||||||
</a>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://twitter.com/LocalAI_API" target="blank">
|
<a href="https://twitter.com/LocalAI_API" target="blank">
|
||||||
<img src="https://img.shields.io/badge/X-%23000000.svg?style=for-the-badge&logo=X&logoColor=white&label=LocalAI_API" alt="Follow LocalAI_API"/>
|
<img src="https://img.shields.io/badge/X-%23000000.svg?style=for-the-badge&logo=X&logoColor=white&label=LocalAI_API" alt="Follow LocalAI_API"/>
|
||||||
@@ -47,363 +29,161 @@
|
|||||||
<a href="https://trendshift.io/repositories/5539" target="_blank"><img src="https://trendshift.io/api/badge/repositories/5539" alt="mudler%2FLocalAI | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
<a href="https://trendshift.io/repositories/5539" target="_blank"><img src="https://trendshift.io/api/badge/repositories/5539" alt="mudler%2FLocalAI | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
|
**LocalAI** is the open-source AI engine. Run any model - LLMs, vision, voice, image, video - on any hardware. No GPU required.
|
||||||
>
|
|
||||||
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
|
|
||||||
[](https://t.me/localaiofficial_bot)
|
|
||||||
|
|
||||||
[](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[](https://artifacthub.io/packages/search?repo=localai)
|
- **Drop-in API compatibility** — OpenAI, Anthropic, ElevenLabs APIs
|
||||||
|
- **35+ backends** — llama.cpp, vLLM, transformers, whisper, diffusers, MLX...
|
||||||
|
- **Any hardware** — NVIDIA, AMD, Intel, Apple Silicon, Vulkan, or CPU-only
|
||||||
|
- **Multi-user ready** — API key auth, user quotas, role-based access
|
||||||
|
- **Built-in AI agents** — autonomous agents with tool use, RAG, MCP, and skills
|
||||||
|
- **Privacy-first** — your data never leaves your infrastructure
|
||||||
|
|
||||||
<p align="center">
|
Created and maintained by [Ettore Di Giacinto](https://github.com/mudler).
|
||||||
<a href="https://github.com/mudler/LocalAI-examples" target="blank">
|
|
||||||
<img src="https://img.shields.io/badge/📦_Examples_Repository-Browse_Ready--to--Run_Examples-blue?style=for-the-badge" alt="LocalAI Examples Repository"/>
|
|
||||||
</a>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
**LocalAI** is the free, Open Source OpenAI alternative. LocalAI act as a drop-in replacement REST API that's compatible with OpenAI (Elevenlabs, Anthropic... ) API specifications for local AI inferencing. It allows you to run LLMs, generate images, audio (and not only) locally or on-prem with consumer grade hardware, supporting multiple model families. Does not require GPU. It is created and maintained by [Ettore Di Giacinto](https://github.com/mudler).
|
> [:book: Documentation](https://localai.io/) | [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) | [💻 Quickstart](https://localai.io/basics/getting_started/) | [🖼️ Models](https://models.localai.io/) | [❓FAQ](https://localai.io/faq/)
|
||||||
|
|
||||||
<details>
|
## Screenshots
|
||||||
<summary><strong>Table of Contents</strong></summary>
|
|
||||||
|
|
||||||
- [Local Stack Family](#local-stack-family)
|
### Chat, Model gallery
|
||||||
- [Screenshots / Video](#screenshots--video)
|
|
||||||
- [Quickstart](#-quickstart)
|
|
||||||
- [macOS Download](#macos-download)
|
|
||||||
- [Containers (Docker, podman, ...)](#containers-docker-podman-)
|
|
||||||
- [Latest project news](#-latest-project-news)
|
|
||||||
- [Features](#-features)
|
|
||||||
- [Supported Backends & Acceleration](#-supported-backends--acceleration)
|
|
||||||
- [Text Generation & Language Models](#text-generation--language-models)
|
|
||||||
- [Audio & Speech Processing](#audio--speech-processing)
|
|
||||||
- [Image & Video Generation](#image--video-generation)
|
|
||||||
- [Specialized AI Tasks](#specialized-ai-tasks)
|
|
||||||
- [Hardware Acceleration Matrix](#hardware-acceleration-matrix)
|
|
||||||
- [Community and integrations](#-community-and-integrations)
|
|
||||||
- [Resources](#-resources)
|
|
||||||
- [Media, Blogs, Social](#book--media-blogs-social)
|
|
||||||
- [Autonomous Development Team](#-autonomous-development-team)
|
|
||||||
- [Citation](#citation)
|
|
||||||
- [Sponsors](#️-sponsors)
|
|
||||||
- [Individual sponsors](#individual-sponsors)
|
|
||||||
- [Star history](#-star-history)
|
|
||||||
- [License](#-license)
|
|
||||||
- [Acknowledgements](#-acknowledgements)
|
|
||||||
- [Contributors](#-contributors)
|
|
||||||
|
|
||||||
</details>
|
https://github.com/user-attachments/assets/08cbb692-57da-48f7-963d-2e7b43883c18
|
||||||
|
|
||||||
## Local Stack Family
|
### Agents
|
||||||
|
|
||||||
Liking LocalAI? LocalAI is part of an integrated suite of AI infrastructure tools, you might also like:
|
https://github.com/user-attachments/assets/6270b331-e21d-4087-a540-6290006b381a
|
||||||
|
|
||||||
- **[LocalAGI](https://github.com/mudler/LocalAGI)** - AI agent orchestration platform with OpenAI Responses API compatibility and advanced agentic capabilities
|
## Quickstart
|
||||||
- **[LocalRecall](https://github.com/mudler/LocalRecall)** - MCP/REST API knowledge base system providing persistent memory and storage for AI agents
|
|
||||||
- 🆕 **[Cogito](https://github.com/mudler/cogito)** - Go library for building intelligent, co-operative agentic software and LLM-powered workflows, focusing on improving results for small, open source language models that scales to any LLM. Powers LocalAGI and LocalAI MCP/Agentic capabilities
|
|
||||||
- 🆕 **[Wiz](https://github.com/mudler/wiz)** - Terminal-based AI agent accessible via Ctrl+Space keybinding. Portable, local-LLM friendly shell assistant with TUI/CLI modes, tool execution with approval, MCP protocol support, and multi-shell compatibility (zsh, bash, fish)
|
|
||||||
- 🆕 **[SkillServer](https://github.com/mudler/skillserver)** - Simple, centralized skills database for AI agents via MCP. Manages skills as Markdown files with MCP server integration, web UI for editing, Git synchronization, and full-text search capabilities
|
|
||||||
|
|
||||||
|
### macOS
|
||||||
## Screenshots / Video
|
|
||||||
|
|
||||||
### Youtube video
|
|
||||||
|
|
||||||
<h1 align="center">
|
|
||||||
<br>
|
|
||||||
<a href="https://www.youtube.com/watch?v=PDqYhB9nNHA" target="_blank"> <img width="300" src="https://img.youtube.com/vi/PDqYhB9nNHA/0.jpg"> </a><br>
|
|
||||||
<br>
|
|
||||||
</h1>
|
|
||||||
|
|
||||||
|
|
||||||
### Screenshots
|
|
||||||
|
|
||||||
| Talk Interface | Generate Audio |
|
|
||||||
| --- | --- |
|
|
||||||
|  |  |
|
|
||||||
|
|
||||||
| Models Overview | Generate Images |
|
|
||||||
| --- | --- |
|
|
||||||
|  |  |
|
|
||||||
|
|
||||||
| Chat Interface | Home |
|
|
||||||
| --- | --- |
|
|
||||||
|  |  |
|
|
||||||
|
|
||||||
| Login | Swarm |
|
|
||||||
| --- | --- |
|
|
||||||
|  |  |
|
|
||||||
|
|
||||||
## 💻 Quickstart
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### macOS Download:
|
|
||||||
|
|
||||||
<a href="https://github.com/mudler/LocalAI/releases/latest/download/LocalAI.dmg">
|
<a href="https://github.com/mudler/LocalAI/releases/latest/download/LocalAI.dmg">
|
||||||
<img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/>
|
<img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/>
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
> Note: the DMGs are not signed by Apple as quarantined. See https://github.com/mudler/LocalAI/issues/6268 for a workaround, fix is tracked here: https://github.com/mudler/LocalAI/issues/6244
|
> **Note:** The DMG is not signed by Apple. After installing, run: `sudo xattr -d com.apple.quarantine /Applications/LocalAI.app`. See [#6268](https://github.com/mudler/LocalAI/issues/6268) for details.
|
||||||
|
|
||||||
### Containers (Docker, podman, ...)
|
### Containers (Docker, podman, ...)
|
||||||
|
|
||||||
> **💡 Docker Run vs Docker Start**
|
> Already ran LocalAI before? Use `docker start -i local-ai` to restart an existing container.
|
||||||
>
|
|
||||||
> - `docker run` creates and starts a new container. If a container with the same name already exists, this command will fail.
|
|
||||||
> - `docker start` starts an existing container that was previously created with `docker run`.
|
|
||||||
>
|
|
||||||
> If you've already run LocalAI before and want to start it again, use: `docker start -i local-ai`
|
|
||||||
|
|
||||||
#### CPU only image:
|
#### CPU only:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest
|
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest
|
||||||
```
|
```
|
||||||
|
|
||||||
#### NVIDIA GPU Images:
|
#### NVIDIA GPU:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# CUDA 13.0
|
# CUDA 13
|
||||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-13
|
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-13
|
||||||
|
|
||||||
# CUDA 12.0
|
# CUDA 12
|
||||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-12
|
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-12
|
||||||
|
|
||||||
# NVIDIA Jetson (L4T) ARM64
|
# NVIDIA Jetson ARM64 (CUDA 12, for AGX Orin and similar)
|
||||||
# CUDA 12 (for Nvidia AGX Orin and similar platforms)
|
|
||||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64
|
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64
|
||||||
|
|
||||||
# CUDA 13 (for Nvidia DGX Spark)
|
# NVIDIA Jetson ARM64 (CUDA 13, for DGX Spark)
|
||||||
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64-cuda-13
|
docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64-cuda-13
|
||||||
```
|
```
|
||||||
|
|
||||||
#### AMD GPU Images (ROCm):
|
#### AMD GPU (ROCm):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri --group-add=video localai/localai:latest-gpu-hipblas
|
docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri --group-add=video localai/localai:latest-gpu-hipblas
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Intel GPU Images (oneAPI):
|
#### Intel GPU (oneAPI):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel
|
docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Vulkan GPU Images:
|
#### Vulkan GPU:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-gpu-vulkan
|
docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-gpu-vulkan
|
||||||
```
|
```
|
||||||
|
|
||||||
To load models:
|
### Loading models
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# From the model gallery (see available models with `local-ai models list`, in the WebUI from the model tab, or visiting https://models.localai.io)
|
# From the model gallery (see available models with `local-ai models list` or at https://models.localai.io)
|
||||||
local-ai run llama-3.2-1b-instruct:q4_k_m
|
local-ai run llama-3.2-1b-instruct:q4_k_m
|
||||||
# Start LocalAI with the phi-2 model directly from huggingface
|
# From Huggingface
|
||||||
local-ai run huggingface://TheBloke/phi-2-GGUF/phi-2.Q8_0.gguf
|
local-ai run huggingface://TheBloke/phi-2-GGUF/phi-2.Q8_0.gguf
|
||||||
# Install and run a model from the Ollama OCI registry
|
# From the Ollama OCI registry
|
||||||
local-ai run ollama://gemma:2b
|
local-ai run ollama://gemma:2b
|
||||||
# Run a model from a configuration file
|
# From a YAML config
|
||||||
local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
|
local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
|
||||||
# Install and run a model from a standard OCI registry (e.g., Docker Hub)
|
# From a standard OCI registry (e.g., Docker Hub)
|
||||||
local-ai run oci://localai/phi-2:latest
|
local-ai run oci://localai/phi-2:latest
|
||||||
```
|
```
|
||||||
|
|
||||||
> ⚡ **Automatic Backend Detection**: When you install models from the gallery or YAML files, LocalAI automatically detects your system's GPU capabilities (NVIDIA, AMD, Intel) and downloads the appropriate backend. For advanced configuration options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/#automatic-backend-detection).
|
> **Automatic Backend Detection**: LocalAI automatically detects your GPU capabilities and downloads the appropriate backend. For advanced options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/).
|
||||||
|
|
||||||
For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html), if you are interested in our roadmap items and future enhancements, you can see the [Issues labeled as Roadmap here](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
|
For more details, see the [Getting Started guide](https://localai.io/basics/getting_started/).
|
||||||
|
|
||||||
## 📰 Latest project news
|
## Latest News
|
||||||
- March 2026: [Agent management](https://github.com/mudler/LocalAI/pull/8820), [New React UI](https://github.com/mudler/LocalAI/pull/8772), [WebRTC](https://github.com/mudler/LocalAI/pull/8790),[MLX-distributed via P2P and RDMA](https://github.com/mudler/LocalAI/pull/8801), [MCP Apps, MCP Client-side](https://github.com/mudler/LocalAI/pull/8947)
|
|
||||||
- February 2026: [Realtime API for audio-to-audio with tool calling](https://github.com/mudler/LocalAI/pull/6245), [ACE-Step 1.5 support](https://github.com/mudler/LocalAI/pull/8396)
|
|
||||||
- January 2026: **LocalAI 3.10.0** - Major release with Anthropic API support, Open Responses API for stateful agents, video & image generation suite (LTX-2), unified GPU backends, tool streaming & XML parsing, system-aware backend gallery, crash fixes for AVX-only CPUs and AMD VRAM reporting, request tracing, and new backends: **Moonshine** (ultra-fast transcription), **Pocket-TTS** (lightweight TTS). Vulkan arm64 builds now available. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v3.10.0).
|
|
||||||
- December 2025: [Dynamic Memory Resource reclaimer](https://github.com/mudler/LocalAI/pull/7583), [Automatic fitting of models to multiple GPUS(llama.cpp)](https://github.com/mudler/LocalAI/pull/7584), [Added Vibevoice backend](https://github.com/mudler/LocalAI/pull/7494)
|
|
||||||
- November 2025: Major improvements to the UX. Among these: [Import models via URL](https://github.com/mudler/LocalAI/pull/7245) and [Multiple chats and history](https://github.com/mudler/LocalAI/pull/7325)
|
|
||||||
- October 2025: 🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) support added for agentic capabilities with external tools
|
|
||||||
- September 2025: New Launcher application for MacOS and Linux, extended support to many backends for Mac and Nvidia L4T devices. Models: Added MLX-Audio, WAN 2.2. WebUI improvements and Python-based backends now ships portable python environments.
|
|
||||||
- August 2025: MLX, MLX-VLM, Diffusers and llama.cpp are now supported on Mac M1/M2/M3+ chips ( with `development` suffix in the gallery ): https://github.com/mudler/LocalAI/pull/6049 https://github.com/mudler/LocalAI/pull/6119 https://github.com/mudler/LocalAI/pull/6121 https://github.com/mudler/LocalAI/pull/6060
|
|
||||||
- July/August 2025: 🔍 [Object Detection](https://localai.io/features/object-detection/) added to the API featuring [rf-detr](https://github.com/roboflow/rf-detr)
|
|
||||||
- July 2025: All backends migrated outside of the main binary. LocalAI is now more lightweight, small, and automatically downloads the required backend to run the model. [Read the release notes](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
|
|
||||||
- June 2025: [Backend management](https://github.com/mudler/LocalAI/pull/5607) has been added. Attention: extras images are going to be deprecated from the next release! Read [the backend management PR](https://github.com/mudler/LocalAI/pull/5607).
|
|
||||||
- May 2025: [Audio input](https://github.com/mudler/LocalAI/pull/5466) and [Reranking](https://github.com/mudler/LocalAI/pull/5396) in llama.cpp backend, [Realtime API](https://github.com/mudler/LocalAI/pull/5392), Support to Gemma, SmollVLM, and more multimodal models (available in the gallery).
|
|
||||||
- May 2025: Important: image name changes [See release](https://github.com/mudler/LocalAI/releases/tag/v2.29.0)
|
|
||||||
- Apr 2025: Rebrand, WebUI enhancements
|
|
||||||
- Apr 2025: [LocalAGI](https://github.com/mudler/LocalAGI) and [LocalRecall](https://github.com/mudler/LocalRecall) join the LocalAI family stack.
|
|
||||||
- Apr 2025: WebUI overhaul
|
|
||||||
- Feb 2025: Backend cleanup, Breaking changes, new backends (kokoro, OutelTTS, faster-whisper), Nvidia L4T images
|
|
||||||
- Jan 2025: LocalAI model release: https://huggingface.co/mudler/LocalAI-functioncall-phi-4-v0.3, SANA support in diffusers: https://github.com/mudler/LocalAI/pull/4603
|
|
||||||
- Dec 2024: stablediffusion.cpp backend (ggml) added ( https://github.com/mudler/LocalAI/pull/4289 )
|
|
||||||
- Nov 2024: Bark.cpp backend added ( https://github.com/mudler/LocalAI/pull/4287 )
|
|
||||||
- Nov 2024: Voice activity detection models (**VAD**) added to the API: https://github.com/mudler/LocalAI/pull/4204
|
|
||||||
- Oct 2024: examples moved to [LocalAI-examples](https://github.com/mudler/LocalAI-examples)
|
|
||||||
- Aug 2024: 🆕 FLUX-1, [P2P Explorer](https://explorer.localai.io)
|
|
||||||
- July 2024: 🔥🔥 🆕 P2P Dashboard, LocalAI Federated mode and AI Swarms: https://github.com/mudler/LocalAI/pull/2723. P2P Global community pools: https://github.com/mudler/LocalAI/issues/3113
|
|
||||||
- May 2024: 🔥🔥 Decentralized P2P llama.cpp: https://github.com/mudler/LocalAI/pull/2343 (peer2peer llama.cpp!) 👉 Docs https://localai.io/features/distribute/
|
|
||||||
- May 2024: 🔥🔥 Distributed inferencing: https://github.com/mudler/LocalAI/pull/2324
|
|
||||||
- April 2024: Reranker API: https://github.com/mudler/LocalAI/pull/2121
|
|
||||||
|
|
||||||
Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
|
- **March 2026**: [Agent management](https://github.com/mudler/LocalAI/pull/8820), [New React UI](https://github.com/mudler/LocalAI/pull/8772), [WebRTC](https://github.com/mudler/LocalAI/pull/8790), [MLX-distributed via P2P and RDMA](https://github.com/mudler/LocalAI/pull/8801), [MCP Apps, MCP Client-side](https://github.com/mudler/LocalAI/pull/8947)
|
||||||
|
- **February 2026**: [Realtime API for audio-to-audio with tool calling](https://github.com/mudler/LocalAI/pull/6245), [ACE-Step 1.5 support](https://github.com/mudler/LocalAI/pull/8396)
|
||||||
|
- **January 2026**: **LocalAI 3.10.0** — Anthropic API support, Open Responses API, video & image generation (LTX-2), unified GPU backends, tool streaming, Moonshine, Pocket-TTS. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v3.10.0)
|
||||||
|
- **December 2025**: [Dynamic Memory Resource reclaimer](https://github.com/mudler/LocalAI/pull/7583), [Automatic multi-GPU model fitting (llama.cpp)](https://github.com/mudler/LocalAI/pull/7584), [Vibevoice backend](https://github.com/mudler/LocalAI/pull/7494)
|
||||||
|
- **November 2025**: [Import models via URL](https://github.com/mudler/LocalAI/pull/7245), [Multiple chats and history](https://github.com/mudler/LocalAI/pull/7325)
|
||||||
|
- **October 2025**: [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) support for agentic capabilities
|
||||||
|
- **September 2025**: New Launcher for macOS and Linux, extended backend support for Mac and Nvidia L4T, MLX-Audio, WAN 2.2
|
||||||
|
- **August 2025**: MLX, MLX-VLM, Diffusers, llama.cpp now supported on Apple Silicon
|
||||||
|
- **July 2025**: All backends migrated outside the main binary — [lightweight, modular architecture](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
|
||||||
|
|
||||||
## 🚀 [Features](https://localai.io/features/)
|
For older news and full release notes, see [GitHub Releases](https://github.com/mudler/LocalAI/releases) and the [News page](https://localai.io/basics/news/).
|
||||||
|
|
||||||
- 🧩 [Backend Gallery](https://localai.io/backends/): Install/remove backends on the fly, powered by OCI images — fully customizable and API-driven.
|
## Features
|
||||||
- 📖 [Text generation with GPTs](https://localai.io/features/text-generation/) (`llama.cpp`, `transformers`, `vllm` ... [:book: and more](https://localai.io/model-compatibility/index.html#model-compatibility-table))
|
|
||||||
- 🗣 [Text to Audio](https://localai.io/features/text-to-audio/)
|
|
||||||
- 🔈 [Audio to Text](https://localai.io/features/audio-to-text/)
|
|
||||||
- 🎨 [Image generation](https://localai.io/features/image-generation)
|
|
||||||
- 🔥 [OpenAI-alike tools API](https://localai.io/features/openai-functions/)
|
|
||||||
- ⚡ [Realtime API](https://localai.io/features/openai-realtime/) (Speech-to-speech)
|
|
||||||
- 🧠 [Embeddings generation for vector databases](https://localai.io/features/embeddings/)
|
|
||||||
- ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/)
|
|
||||||
- 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/)
|
|
||||||
- 🥽 [Vision API](https://localai.io/features/gpt-vision/)
|
|
||||||
- 🔍 [Object Detection](https://localai.io/features/object-detection/)
|
|
||||||
- 📈 [Reranker API](https://localai.io/features/reranker/)
|
|
||||||
- 🆕🖧 [P2P Inferencing](https://localai.io/features/distribute/)
|
|
||||||
- 🆕🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) - Agentic capabilities with external tools and [LocalAGI's Agentic capabilities](https://github.com/mudler/LocalAGI)
|
|
||||||
- 🆕🤖 [Built-in Agents](https://localai.io/features/agents/) - Autonomous AI agents with tool use, knowledge base (RAG), skills, SSE streaming, import/export, and [Agent Hub](https://agenthub.localai.io) — powered by [LocalAGI](https://github.com/mudler/LocalAGI)
|
|
||||||
- 🔊 Voice activity detection (Silero-VAD support)
|
|
||||||
- 🌍 Integrated WebUI!
|
|
||||||
|
|
||||||
## 🧩 Supported Backends & Acceleration
|
- [Text generation](https://localai.io/features/text-generation/) (`llama.cpp`, `transformers`, `vllm` ... [and more](https://localai.io/model-compatibility/))
|
||||||
|
- [Text to Audio](https://localai.io/features/text-to-audio/)
|
||||||
|
- [Audio to Text](https://localai.io/features/audio-to-text/)
|
||||||
|
- [Image generation](https://localai.io/features/image-generation)
|
||||||
|
- [OpenAI-compatible tools API](https://localai.io/features/openai-functions/)
|
||||||
|
- [Realtime API](https://localai.io/features/openai-realtime/) (Speech-to-speech)
|
||||||
|
- [Embeddings generation](https://localai.io/features/embeddings/)
|
||||||
|
- [Constrained grammars](https://localai.io/features/constrained_grammars/)
|
||||||
|
- [Download models from Huggingface](https://localai.io/models/)
|
||||||
|
- [Vision API](https://localai.io/features/gpt-vision/)
|
||||||
|
- [Object Detection](https://localai.io/features/object-detection/)
|
||||||
|
- [Reranker API](https://localai.io/features/reranker/)
|
||||||
|
- [P2P Inferencing](https://localai.io/features/distribute/)
|
||||||
|
- [Distributed Mode](https://localai.io/features/distributed-mode/) — Horizontal scaling with PostgreSQL + NATS
|
||||||
|
- [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/)
|
||||||
|
- [Built-in Agents](https://localai.io/features/agents/) — Autonomous AI agents with tool use, RAG, skills, SSE streaming, and [Agent Hub](https://agenthub.localai.io)
|
||||||
|
- [Backend Gallery](https://localai.io/backends/) — Install/remove backends on the fly via OCI images
|
||||||
|
- Voice Activity Detection (Silero-VAD)
|
||||||
|
- Integrated WebUI
|
||||||
|
|
||||||
LocalAI supports a comprehensive range of AI backends with multiple acceleration options:
|
## Supported Backends & Acceleration
|
||||||
|
|
||||||
### Text Generation & Language Models
|
LocalAI supports **35+ backends** including llama.cpp, vLLM, transformers, whisper.cpp, diffusers, MLX, MLX-VLM, and many more. Hardware acceleration is available for **NVIDIA** (CUDA 12/13), **AMD** (ROCm), **Intel** (oneAPI/SYCL), **Apple Silicon** (Metal), **Vulkan**, and **NVIDIA Jetson** (L4T). All backends can be installed on-the-fly from the [Backend Gallery](https://localai.io/backends/).
|
||||||
| Backend | Description | Acceleration Support |
|
|
||||||
|---------|-------------|---------------------|
|
|
||||||
| **llama.cpp** | LLM inference in C/C++ | CUDA 12/13, ROCm, Intel SYCL, Vulkan, Metal, CPU |
|
|
||||||
| **vLLM** | Fast LLM inference with PagedAttention | CUDA 12/13, ROCm, Intel |
|
|
||||||
| **transformers** | HuggingFace transformers framework | CUDA 12/13, ROCm, Intel, CPU |
|
|
||||||
| **MLX** | Apple Silicon LLM inference | Metal (M1/M2/M3+) |
|
|
||||||
| **MLX-VLM** | Apple Silicon Vision-Language Models | Metal (M1/M2/M3+) |
|
|
||||||
| **vLLM Omni** | Multimodal vLLM with vision and audio | CUDA 12/13, ROCm, Intel |
|
|
||||||
|
|
||||||
### Audio & Speech Processing
|
See the full [Backend & Model Compatibility Table](https://localai.io/model-compatibility/) and [GPU Acceleration guide](https://localai.io/features/gpu-acceleration/).
|
||||||
| Backend | Description | Acceleration Support |
|
|
||||||
|---------|-------------|---------------------|
|
|
||||||
| **whisper.cpp** | OpenAI Whisper in C/C++ | CUDA 12/13, ROCm, Intel SYCL, Vulkan, CPU |
|
|
||||||
| **faster-whisper** | Fast Whisper with CTranslate2 | CUDA 12/13, ROCm, Intel, CPU |
|
|
||||||
| **moonshine** | Ultra-fast transcription engine for low-end devices | CUDA 12/13, Metal, CPU |
|
|
||||||
| **coqui** | Advanced TTS with 1100+ languages | CUDA 12/13, ROCm, Intel, CPU |
|
|
||||||
| **kokoro** | Lightweight TTS model | CUDA 12/13, ROCm, Intel, CPU |
|
|
||||||
| **chatterbox** | Production-grade TTS | CUDA 12/13, CPU |
|
|
||||||
| **piper** | Fast neural TTS system | CPU |
|
|
||||||
| **kitten-tts** | Kitten TTS models | CPU |
|
|
||||||
| **silero-vad** | Voice Activity Detection | CPU |
|
|
||||||
| **neutts** | Text-to-speech with voice cloning | CUDA 12/13, ROCm, CPU |
|
|
||||||
| **vibevoice** | Real-time TTS with voice cloning | CUDA 12/13, ROCm, Intel, CPU |
|
|
||||||
| **pocket-tts** | Lightweight CPU-based TTS | CUDA 12/13, ROCm, Intel, CPU |
|
|
||||||
| **qwen-tts** | High-quality TTS with custom voice, voice design, and voice cloning | CUDA 12/13, ROCm, Intel, CPU |
|
|
||||||
| **nemo** | NVIDIA NeMo framework for speech models | CUDA 12/13, ROCm, Intel, CPU |
|
|
||||||
| **outetts** | OuteTTS with voice cloning | CUDA 12/13, CPU |
|
|
||||||
| **faster-qwen3-tts** | Faster Qwen3 TTS | CUDA 12/13, ROCm, Intel, CPU |
|
|
||||||
| **qwen-asr** | Qwen ASR speech recognition | CUDA 12/13, ROCm, Intel, CPU |
|
|
||||||
| **voxcpm** | VoxCPM speech understanding | CUDA 12/13, Metal, CPU |
|
|
||||||
| **whisperx** | Enhanced Whisper transcription | CUDA 12/13, ROCm, Intel, CPU |
|
|
||||||
| **ace-step** | Music generation from text descriptions, lyrics, or audio samples | CUDA 12/13, ROCm, Intel, Metal, CPU |
|
|
||||||
|
|
||||||
### Image & Video Generation
|
## Resources
|
||||||
| Backend | Description | Acceleration Support |
|
|
||||||
|---------|-------------|---------------------|
|
|
||||||
| **stablediffusion.cpp** | Stable Diffusion in C/C++ | CUDA 12/13, Intel SYCL, Vulkan, CPU |
|
|
||||||
| **diffusers** | HuggingFace diffusion models | CUDA 12/13, ROCm, Intel, Metal, CPU |
|
|
||||||
|
|
||||||
### Specialized AI Tasks
|
- [Documentation](https://localai.io/)
|
||||||
| Backend | Description | Acceleration Support |
|
- [LLM fine-tuning guide](https://localai.io/docs/advanced/fine-tuning/)
|
||||||
|---------|-------------|---------------------|
|
- [Build from source](https://localai.io/basics/build/)
|
||||||
| **rfdetr** | Real-time object detection | CUDA 12/13, Intel, CPU |
|
- [Kubernetes installation](https://localai.io/basics/getting_started/#run-localai-in-kubernetes)
|
||||||
| **rerankers** | Document reranking API | CUDA 12/13, ROCm, Intel, CPU |
|
- [Integrations & community projects](https://localai.io/docs/integrations/)
|
||||||
| **local-store** | Vector database | CPU |
|
- [Media & blog posts](https://localai.io/basics/news/#media-blogs-social)
|
||||||
| **huggingface** | HuggingFace API integration | API-based |
|
- [Examples](https://github.com/mudler/LocalAI-examples)
|
||||||
|
|
||||||
### Hardware Acceleration Matrix
|
## Autonomous Development Team
|
||||||
|
|
||||||
| Acceleration Type | Supported Backends | Hardware Support |
|
LocalAI is helped being maintained by a team of autonomous AI agents led by an AI Scrum Master.
|
||||||
|-------------------|-------------------|------------------|
|
|
||||||
| **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware |
|
|
||||||
| **NVIDIA CUDA 13** | All CUDA-compatible backends | Nvidia hardware |
|
|
||||||
| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, neutts, vibevoice, pocket-tts, qwen-tts, ace-step | AMD Graphics |
|
|
||||||
| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, coqui, kokoro, vibevoice, pocket-tts, qwen-tts, ace-step | Intel Arc, Intel iGPUs |
|
|
||||||
| **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, moonshine, ace-step | Apple M1/M2/M3+ |
|
|
||||||
| **Vulkan** | llama.cpp, whisper, stablediffusion | Cross-platform GPUs |
|
|
||||||
| **NVIDIA Jetson (CUDA 12)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr, ace-step | ARM64 embedded AI (AGX Orin, etc.) |
|
|
||||||
| **NVIDIA Jetson (CUDA 13)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr | ARM64 embedded AI (DGX Spark) |
|
|
||||||
| **CPU Optimized** | All backends | AVX/AVX2/AVX512, quantization support |
|
|
||||||
|
|
||||||
### 🔗 Community and integrations
|
- **Live Reports**: [reports.localai.io](http://reports.localai.io)
|
||||||
|
- **Project Board**: [Agent task tracking](https://github.com/users/mudler/projects/6)
|
||||||
Build and deploy custom containers:
|
- **Blog Post**: [Learn about the experiment](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/)
|
||||||
- https://github.com/sozercan/aikit
|
|
||||||
|
|
||||||
WebUIs:
|
|
||||||
- https://github.com/Jirubizu/localai-admin
|
|
||||||
- https://github.com/go-skynet/LocalAI-frontend
|
|
||||||
- QA-Pilot(An interactive chat project that leverages LocalAI LLMs for rapid understanding and navigation of GitHub code repository) https://github.com/reid41/QA-Pilot
|
|
||||||
|
|
||||||
Agentic Libraries:
|
|
||||||
- https://github.com/mudler/cogito
|
|
||||||
|
|
||||||
MCPs:
|
|
||||||
- https://github.com/mudler/MCPs
|
|
||||||
|
|
||||||
OS Assistant:
|
|
||||||
|
|
||||||
- https://github.com/mudler/Keygeist - Keygeist is an AI-powered keyboard operator that listens for key combinations and responds with AI-generated text typed directly into your Linux box.
|
|
||||||
|
|
||||||
Model galleries
|
|
||||||
- https://github.com/go-skynet/model-gallery
|
|
||||||
|
|
||||||
Voice:
|
|
||||||
- https://github.com/richiejp/VoxInput
|
|
||||||
|
|
||||||
Other:
|
|
||||||
- Helm chart https://github.com/go-skynet/helm-charts
|
|
||||||
- VSCode extension https://github.com/badgooooor/localai-vscode-plugin
|
|
||||||
- Langchain: https://python.langchain.com/docs/integrations/providers/localai/
|
|
||||||
- Terminal utility https://github.com/djcopley/ShellOracle
|
|
||||||
- Local Smart assistant https://github.com/mudler/LocalAGI
|
|
||||||
- Home Assistant https://github.com/drndos/hass-openai-custom-conversation / https://github.com/valentinfrlch/ha-llmvision / https://github.com/loryanstrant/HA-LocalAI-Monitor
|
|
||||||
- Discord bot https://github.com/mudler/LocalAGI/tree/main/examples/discord
|
|
||||||
- Slack bot https://github.com/mudler/LocalAGI/tree/main/examples/slack
|
|
||||||
- Shell-Pilot(Interact with LLM using LocalAI models via pure shell scripts on your Linux or MacOS system) https://github.com/reid41/shell-pilot
|
|
||||||
- Telegram bot https://github.com/mudler/LocalAI/tree/master/examples/telegram-bot
|
|
||||||
- Another Telegram Bot https://github.com/JackBekket/Hellper
|
|
||||||
- Auto-documentation https://github.com/JackBekket/Reflexia
|
|
||||||
- Github bot which answer on issues, with code and documentation as context https://github.com/JackBekket/GitHelper
|
|
||||||
- Github Actions: https://github.com/marketplace/actions/start-localai
|
|
||||||
- Examples: https://github.com/mudler/LocalAI/tree/master/examples/
|
|
||||||
|
|
||||||
|
|
||||||
### 🔗 Resources
|
|
||||||
|
|
||||||
- [LLM finetuning guide](https://localai.io/docs/advanced/fine-tuning/)
|
|
||||||
- [How to build locally](https://localai.io/basics/build/index.html)
|
|
||||||
- [How to install in Kubernetes](https://localai.io/basics/getting_started/index.html#run-localai-in-kubernetes)
|
|
||||||
- [Projects integrating LocalAI](https://localai.io/docs/integrations/)
|
|
||||||
- [How tos section](https://io.midori-ai.xyz/howtos/) (curated by our community)
|
|
||||||
|
|
||||||
## :book: 🎥 [Media, Blogs, Social](https://localai.io/basics/news/#media-blogs-social)
|
|
||||||
|
|
||||||
- 🆕 [LocalAI Autonomous Dev Team Blog Post](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/)
|
|
||||||
|
|
||||||
- [Run Visual studio code with LocalAI (SUSE)](https://www.suse.com/c/running-ai-locally/)
|
|
||||||
- 🆕 [Run LocalAI on Jetson Nano Devkit](https://mudler.pm/posts/local-ai-jetson-nano-devkit/)
|
|
||||||
- [Run LocalAI on AWS EKS with Pulumi](https://www.pulumi.com/blog/low-code-llm-apps-with-local-ai-flowise-and-pulumi/)
|
|
||||||
- [Run LocalAI on AWS](https://staleks.hashnode.dev/installing-localai-on-aws-ec2-instance)
|
|
||||||
- [Create a slackbot for teams and OSS projects that answer to documentation](https://mudler.pm/posts/smart-slackbot-for-teams/)
|
|
||||||
- [LocalAI meets k8sgpt](https://www.youtube.com/watch?v=PKrDNuJ_dfE)
|
|
||||||
- [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/)
|
|
||||||
- [Tutorial to use k8sgpt with LocalAI](https://medium.com/@tyler_97636/k8sgpt-localai-unlock-kubernetes-superpowers-for-free-584790de9b65)
|
|
||||||
|
|
||||||
|
|
||||||
## 🤖 Autonomous Development Team
|
|
||||||
|
|
||||||
LocalAI is now helped being maintained (for small tasks!) by a full team of autonomous AI agents led by an AI Scrum Master! This experiment demonstrates how open source projects can leverage AI agents for sustainable, long-term maintenance.
|
|
||||||
|
|
||||||
- **📊 Live Reports**: [Automatically generated reports](http://reports.localai.io)
|
|
||||||
- **📋 Project Board**: [Agent task tracking](https://github.com/users/mudler/projects/6)
|
|
||||||
- **📝 Blog Post**: [Learn about the autonomous dev team experiment](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/)
|
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
@@ -419,7 +199,7 @@ If you utilize this repository, data in a downstream project, please consider ci
|
|||||||
howpublished = {\url{https://github.com/go-skynet/LocalAI}},
|
howpublished = {\url{https://github.com/go-skynet/LocalAI}},
|
||||||
```
|
```
|
||||||
|
|
||||||
## ❤️ Sponsors
|
## Sponsors
|
||||||
|
|
||||||
> Do you find LocalAI useful?
|
> Do you find LocalAI useful?
|
||||||
|
|
||||||
@@ -438,19 +218,19 @@ A huge thank you to our generous sponsors who support this project covering CI e
|
|||||||
|
|
||||||
### Individual sponsors
|
### Individual sponsors
|
||||||
|
|
||||||
A special thanks to individual sponsors that contributed to the project, a full list is in [Github](https://github.com/sponsors/mudler) and [buymeacoffee](https://buymeacoffee.com/mudler), a special shout out goes to [drikster80](https://github.com/drikster80) for being generous. Thank you everyone!
|
A special thanks to individual sponsors, a full list is on [GitHub](https://github.com/sponsors/mudler) and [buymeacoffee](https://buymeacoffee.com/mudler). Special shout out to [drikster80](https://github.com/drikster80) for being generous. Thank you everyone!
|
||||||
|
|
||||||
## 🌟 Star history
|
## Star history
|
||||||
|
|
||||||
[](https://star-history.com/#go-skynet/LocalAI&Date)
|
[](https://star-history.com/#go-skynet/LocalAI&Date)
|
||||||
|
|
||||||
## 📖 License
|
## License
|
||||||
|
|
||||||
LocalAI is a community-driven project created by [Ettore Di Giacinto](https://github.com/mudler/).
|
LocalAI is a community-driven project created by [Ettore Di Giacinto](https://github.com/mudler/).
|
||||||
|
|
||||||
MIT - Author Ettore Di Giacinto <mudler@localai.io>
|
MIT - Author Ettore Di Giacinto <mudler@localai.io>
|
||||||
|
|
||||||
## 🙇 Acknowledgements
|
## Acknowledgements
|
||||||
|
|
||||||
LocalAI couldn't have been built without the help of great software already available from the community. Thank you!
|
LocalAI couldn't have been built without the help of great software already available from the community. Thank you!
|
||||||
|
|
||||||
@@ -463,9 +243,9 @@ LocalAI couldn't have been built without the help of great software already avai
|
|||||||
- https://github.com/rhasspy/piper
|
- https://github.com/rhasspy/piper
|
||||||
- [exo](https://github.com/exo-explore/exo) for the MLX distributed auto-parallel sharding implementation
|
- [exo](https://github.com/exo-explore/exo) for the MLX distributed auto-parallel sharding implementation
|
||||||
|
|
||||||
## 🤗 Contributors
|
## Contributors
|
||||||
|
|
||||||
This is a community project, a special thanks to our contributors! 🤗
|
This is a community project, a special thanks to our contributors!
|
||||||
<a href="https://github.com/go-skynet/LocalAI/graphs/contributors">
|
<a href="https://github.com/go-skynet/LocalAI/graphs/contributors">
|
||||||
<img src="https://contrib.rocks/image?repo=go-skynet/LocalAI" />
|
<img src="https://contrib.rocks/image?repo=go-skynet/LocalAI" />
|
||||||
</a>
|
</a>
|
||||||
|
|||||||
@@ -39,6 +39,19 @@ service Backend {
|
|||||||
rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {}
|
rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {}
|
||||||
|
|
||||||
rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
|
rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
|
||||||
|
|
||||||
|
// Fine-tuning RPCs
|
||||||
|
rpc StartFineTune(FineTuneRequest) returns (FineTuneJobResult) {}
|
||||||
|
rpc FineTuneProgress(FineTuneProgressRequest) returns (stream FineTuneProgressUpdate) {}
|
||||||
|
rpc StopFineTune(FineTuneStopRequest) returns (Result) {}
|
||||||
|
rpc ListCheckpoints(ListCheckpointsRequest) returns (ListCheckpointsResponse) {}
|
||||||
|
rpc ExportModel(ExportModelRequest) returns (Result) {}
|
||||||
|
|
||||||
|
// Quantization RPCs
|
||||||
|
rpc StartQuantization(QuantizationRequest) returns (QuantizationJobResult) {}
|
||||||
|
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
|
||||||
|
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Define the empty request
|
// Define the empty request
|
||||||
@@ -166,6 +179,7 @@ message PredictOptions {
|
|||||||
int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
|
int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
|
||||||
int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
|
int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
|
||||||
map<string, string> Metadata = 52; // Generic per-request metadata (e.g., enable_thinking)
|
map<string, string> Metadata = 52; // Generic per-request metadata (e.g., enable_thinking)
|
||||||
|
float MinP = 53; // Minimum probability sampling threshold (0.0 = disabled)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolCallDelta represents an incremental tool call update from the C++ parser.
|
// ToolCallDelta represents an incremental tool call update from the C++ parser.
|
||||||
@@ -472,7 +486,7 @@ message ToolFormatMarkers {
|
|||||||
string id_field = 16; // e.g., "id"
|
string id_field = 16; // e.g., "id"
|
||||||
bool fun_name_is_key = 17;
|
bool fun_name_is_key = 17;
|
||||||
bool tools_array_wrapped = 18;
|
bool tools_array_wrapped = 18;
|
||||||
bool uses_python_dicts = 19;
|
reserved 19;
|
||||||
|
|
||||||
// Reasoning markers
|
// Reasoning markers
|
||||||
string reasoning_start = 20; // e.g., "<think>"
|
string reasoning_start = 20; // e.g., "<think>"
|
||||||
@@ -528,3 +542,139 @@ message ModelMetadataResponse {
|
|||||||
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
|
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
|
||||||
ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis
|
ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fine-tuning messages
|
||||||
|
|
||||||
|
message FineTuneRequest {
|
||||||
|
// Model identification
|
||||||
|
string model = 1; // HF model name or local path
|
||||||
|
string training_type = 2; // "lora", "loha", "lokr", "full" — what parameters to train
|
||||||
|
string training_method = 3; // "sft", "dpo", "grpo", "rloo", "reward", "kto", "orpo", "network_training"
|
||||||
|
|
||||||
|
// Adapter config (universal across LoRA/LoHa/LoKr for LLM + diffusion)
|
||||||
|
int32 adapter_rank = 10; // LoRA rank (r), default 16
|
||||||
|
int32 adapter_alpha = 11; // scaling factor, default 16
|
||||||
|
float adapter_dropout = 12; // default 0.0
|
||||||
|
repeated string target_modules = 13; // layer names to adapt
|
||||||
|
|
||||||
|
// Universal training hyperparameters
|
||||||
|
float learning_rate = 20; // default 2e-4
|
||||||
|
int32 num_epochs = 21; // default 3
|
||||||
|
int32 batch_size = 22; // default 2
|
||||||
|
int32 gradient_accumulation_steps = 23; // default 4
|
||||||
|
int32 warmup_steps = 24; // default 5
|
||||||
|
int32 max_steps = 25; // 0 = use epochs
|
||||||
|
int32 save_steps = 26; // 0 = only save final
|
||||||
|
float weight_decay = 27; // default 0.01
|
||||||
|
bool gradient_checkpointing = 28;
|
||||||
|
string optimizer = 29; // adamw_8bit, adamw, sgd, adafactor, prodigy
|
||||||
|
int32 seed = 30; // default 3407
|
||||||
|
string mixed_precision = 31; // fp16, bf16, fp8, no
|
||||||
|
|
||||||
|
// Dataset
|
||||||
|
string dataset_source = 40; // HF dataset ID, local file/dir path
|
||||||
|
string dataset_split = 41; // train, test, etc.
|
||||||
|
|
||||||
|
// Output
|
||||||
|
string output_dir = 50;
|
||||||
|
string job_id = 51; // client-assigned or auto-generated
|
||||||
|
|
||||||
|
// Resume training from a checkpoint
|
||||||
|
string resume_from_checkpoint = 55; // path to checkpoint dir to resume from
|
||||||
|
|
||||||
|
// Backend-specific AND method-specific extensibility
|
||||||
|
map<string, string> extra_options = 60;
|
||||||
|
}
|
||||||
|
|
||||||
|
message FineTuneJobResult {
|
||||||
|
string job_id = 1;
|
||||||
|
bool success = 2;
|
||||||
|
string message = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message FineTuneProgressRequest {
|
||||||
|
string job_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message FineTuneProgressUpdate {
|
||||||
|
string job_id = 1;
|
||||||
|
int32 current_step = 2;
|
||||||
|
int32 total_steps = 3;
|
||||||
|
float current_epoch = 4;
|
||||||
|
float total_epochs = 5;
|
||||||
|
float loss = 6;
|
||||||
|
float learning_rate = 7;
|
||||||
|
float grad_norm = 8;
|
||||||
|
float eval_loss = 9;
|
||||||
|
float eta_seconds = 10;
|
||||||
|
float progress_percent = 11;
|
||||||
|
string status = 12; // queued, caching, loading_model, loading_dataset, training, saving, completed, failed, stopped
|
||||||
|
string message = 13;
|
||||||
|
string checkpoint_path = 14; // set when a checkpoint is saved
|
||||||
|
string sample_path = 15; // set when a sample is generated (video/image backends)
|
||||||
|
map<string, float> extra_metrics = 16; // method-specific metrics
|
||||||
|
}
|
||||||
|
|
||||||
|
message FineTuneStopRequest {
|
||||||
|
string job_id = 1;
|
||||||
|
bool save_checkpoint = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListCheckpointsRequest {
|
||||||
|
string output_dir = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListCheckpointsResponse {
|
||||||
|
repeated CheckpointInfo checkpoints = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CheckpointInfo {
|
||||||
|
string path = 1;
|
||||||
|
int32 step = 2;
|
||||||
|
float epoch = 3;
|
||||||
|
float loss = 4;
|
||||||
|
string created_at = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ExportModelRequest {
|
||||||
|
string checkpoint_path = 1;
|
||||||
|
string output_path = 2;
|
||||||
|
string export_format = 3; // lora, loha, lokr, merged_16bit, merged_4bit, gguf, diffusers
|
||||||
|
string quantization_method = 4; // for GGUF: q4_k_m, q5_k_m, q8_0, f16, etc.
|
||||||
|
string model = 5; // base model name (for merge operations)
|
||||||
|
map<string, string> extra_options = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quantization messages
|
||||||
|
|
||||||
|
message QuantizationRequest {
|
||||||
|
string model = 1; // HF model name or local path
|
||||||
|
string quantization_type = 2; // q4_k_m, q5_k_m, q8_0, f16, etc.
|
||||||
|
string output_dir = 3; // where to write output files
|
||||||
|
string job_id = 4; // client-assigned job ID
|
||||||
|
map<string, string> extra_options = 5; // hf_token, custom flags, etc.
|
||||||
|
}
|
||||||
|
|
||||||
|
message QuantizationJobResult {
|
||||||
|
string job_id = 1;
|
||||||
|
bool success = 2;
|
||||||
|
string message = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message QuantizationProgressRequest {
|
||||||
|
string job_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message QuantizationProgressUpdate {
|
||||||
|
string job_id = 1;
|
||||||
|
float progress_percent = 2;
|
||||||
|
string status = 3; // queued, downloading, converting, quantizing, completed, failed, stopped
|
||||||
|
string message = 4;
|
||||||
|
string output_file = 5; // set when completed — path to the output GGUF file
|
||||||
|
map<string, float> extra_metrics = 6; // e.g. file_size_mb, compression_ratio
|
||||||
|
}
|
||||||
|
|
||||||
|
message QuantizationStopRequest {
|
||||||
|
string job_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
LLAMA_VERSION?=e30f1fdf74ea9238ff562901aa974c75aab6619b
|
LLAMA_VERSION?=95a6ebabb277c4cc18247e7bc2a5502133caca63
|
||||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||||
|
|
||||||
CMAKE_ARGS?=
|
CMAKE_ARGS?=
|
||||||
|
|||||||
@@ -22,8 +22,10 @@
|
|||||||
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
||||||
#include <grpcpp/grpcpp.h>
|
#include <grpcpp/grpcpp.h>
|
||||||
#include <grpcpp/health_check_service_interface.h>
|
#include <grpcpp/health_check_service_interface.h>
|
||||||
|
#include <grpcpp/security/server_credentials.h>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
#include <cstdlib>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <signal.h>
|
#include <signal.h>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
@@ -37,6 +39,47 @@ using grpc::Server;
|
|||||||
using grpc::ServerBuilder;
|
using grpc::ServerBuilder;
|
||||||
using grpc::ServerContext;
|
using grpc::ServerContext;
|
||||||
using grpc::Status;
|
using grpc::Status;
|
||||||
|
|
||||||
|
// gRPC bearer token auth via AuthMetadataProcessor for distributed mode.
|
||||||
|
// Reads LOCALAI_GRPC_AUTH_TOKEN from the environment. When set, rejects
|
||||||
|
// requests without a matching "authorization: Bearer <token>" metadata header.
|
||||||
|
class TokenAuthMetadataProcessor : public grpc::AuthMetadataProcessor {
|
||||||
|
public:
|
||||||
|
explicit TokenAuthMetadataProcessor(const std::string& token) : token_(token) {}
|
||||||
|
|
||||||
|
bool IsBlocking() const override { return false; }
|
||||||
|
|
||||||
|
grpc::Status Process(const InputMetadata& auth_metadata,
|
||||||
|
grpc::AuthContext* /*context*/,
|
||||||
|
OutputMetadata* /*consumed_auth_metadata*/,
|
||||||
|
OutputMetadata* /*response_metadata*/) override {
|
||||||
|
auto it = auth_metadata.find("authorization");
|
||||||
|
if (it != auth_metadata.end()) {
|
||||||
|
std::string expected = "Bearer " + token_;
|
||||||
|
std::string got(it->second.data(), it->second.size());
|
||||||
|
// Constant-time comparison
|
||||||
|
if (expected.size() == got.size() && ct_memcmp(expected.data(), got.data(), expected.size()) == 0) {
|
||||||
|
return grpc::Status::OK;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string token_;
|
||||||
|
|
||||||
|
// Minimal constant-time comparison (avoids OpenSSL dependency)
|
||||||
|
static int ct_memcmp(const void* a, const void* b, size_t n) {
|
||||||
|
const unsigned char* pa = static_cast<const unsigned char*>(a);
|
||||||
|
const unsigned char* pb = static_cast<const unsigned char*>(b);
|
||||||
|
unsigned char result = 0;
|
||||||
|
for (size_t i = 0; i < n; i++) {
|
||||||
|
result |= pa[i] ^ pb[i];
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// END LocalAI
|
// END LocalAI
|
||||||
|
|
||||||
|
|
||||||
@@ -136,6 +179,7 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
|
|||||||
data["mirostat_eta"] = predict->mirostateta();
|
data["mirostat_eta"] = predict->mirostateta();
|
||||||
data["n_keep"] = predict->nkeep();
|
data["n_keep"] = predict->nkeep();
|
||||||
data["seed"] = predict->seed();
|
data["seed"] = predict->seed();
|
||||||
|
data["min_p"] = predict->minp();
|
||||||
|
|
||||||
|
|
||||||
std::string grammar_str = predict->grammar();
|
std::string grammar_str = predict->grammar();
|
||||||
@@ -2687,7 +2731,6 @@ public:
|
|||||||
tf->set_id_field(ap.tools.format.id_field);
|
tf->set_id_field(ap.tools.format.id_field);
|
||||||
tf->set_fun_name_is_key(ap.tools.format.fun_name_is_key);
|
tf->set_fun_name_is_key(ap.tools.format.fun_name_is_key);
|
||||||
tf->set_tools_array_wrapped(ap.tools.format.tools_array_wrapped);
|
tf->set_tools_array_wrapped(ap.tools.format.tools_array_wrapped);
|
||||||
tf->set_uses_python_dicts(ap.tools.format.uses_python_dicts);
|
|
||||||
tf->set_function_field(ap.tools.format.function_field);
|
tf->set_function_field(ap.tools.format.function_field);
|
||||||
|
|
||||||
tf->set_gen_id_field(ap.tools.format.gen_id_field);
|
tf->set_gen_id_field(ap.tools.format.gen_id_field);
|
||||||
@@ -2760,11 +2803,24 @@ int main(int argc, char** argv) {
|
|||||||
BackendServiceImpl service(ctx_server);
|
BackendServiceImpl service(ctx_server);
|
||||||
|
|
||||||
ServerBuilder builder;
|
ServerBuilder builder;
|
||||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
// Add bearer token auth via AuthMetadataProcessor if LOCALAI_GRPC_AUTH_TOKEN is set
|
||||||
|
const char* auth_token = std::getenv("LOCALAI_GRPC_AUTH_TOKEN");
|
||||||
|
std::shared_ptr<grpc::ServerCredentials> creds;
|
||||||
|
if (auth_token != nullptr && auth_token[0] != '\0') {
|
||||||
|
creds = grpc::InsecureServerCredentials();
|
||||||
|
creds->SetAuthMetadataProcessor(
|
||||||
|
std::make_shared<TokenAuthMetadataProcessor>(auth_token));
|
||||||
|
std::cout << "gRPC auth enabled via LOCALAI_GRPC_AUTH_TOKEN" << std::endl;
|
||||||
|
} else {
|
||||||
|
creds = grpc::InsecureServerCredentials();
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.AddListeningPort(server_address, creds);
|
||||||
builder.RegisterService(&service);
|
builder.RegisterService(&service);
|
||||||
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
|
builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
|
||||||
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
|
builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
|
||||||
builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
|
builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
|
||||||
|
|
||||||
std::unique_ptr<Server> server(builder.BuildAndStart());
|
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||||
// run the HTTP server in a thread - see comment below
|
// run the HTTP server in a thread - see comment below
|
||||||
std::thread t([&]()
|
std::thread t([&]()
|
||||||
|
|||||||
@@ -24,6 +24,9 @@ if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
|||||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||||
|
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||||
# ARM64 architecture
|
# ARM64 architecture
|
||||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||||
@@ -33,6 +36,9 @@ elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
|||||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||||
|
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||||
else
|
else
|
||||||
echo "Error: Could not detect architecture"
|
echo "Error: Could not detect architecture"
|
||||||
exit 1
|
exit 1
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
|||||||
|
|
||||||
# acestep.cpp version
|
# acestep.cpp version
|
||||||
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
|
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
|
||||||
ACESTEP_CPP_VERSION?=5aa065445541094cba934299cd498bbb9fa5c434
|
ACESTEP_CPP_VERSION?=6f35c874ee11e86d511b860019b84976f5b52d3a
|
||||||
SO_TARGET?=libgoacestepcpp.so
|
SO_TARGET?=libgoacestepcpp.so
|
||||||
|
|
||||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||||
|
|||||||
@@ -106,12 +106,13 @@ func TestLoadModel(t *testing.T) {
|
|||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
client := pb.NewBackendClient(conn)
|
client := pb.NewBackendClient(conn)
|
||||||
|
|
||||||
// Get base directory from main model file for relative paths
|
// Get base directory from main model file for relative paths
|
||||||
mainModelPath := filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf")
|
mainModelPath := filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf")
|
||||||
|
|
||||||
resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
||||||
ModelFile: mainModelPath,
|
ModelFile: mainModelPath,
|
||||||
|
ModelPath: modelDir,
|
||||||
Options: []string{
|
Options: []string{
|
||||||
"text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
|
"text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
|
||||||
"dit_model:acestep-v15-turbo-Q8_0.gguf",
|
"dit_model:acestep-v15-turbo-Q8_0.gguf",
|
||||||
@@ -133,7 +134,7 @@ func TestSoundGeneration(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(tmpDir)
|
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||||
|
|
||||||
outputFile := filepath.Join(tmpDir, "output.wav")
|
outputFile := filepath.Join(tmpDir, "output.wav")
|
||||||
|
|
||||||
@@ -151,6 +152,7 @@ func TestSoundGeneration(t *testing.T) {
|
|||||||
// Load models
|
// Load models
|
||||||
loadResp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
loadResp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
|
||||||
ModelFile: mainModelPath,
|
ModelFile: mainModelPath,
|
||||||
|
ModelPath: modelDir,
|
||||||
Options: []string{
|
Options: []string{
|
||||||
"text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
|
"text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
|
||||||
"dit_model:acestep-v15-turbo-Q8_0.gguf",
|
"dit_model:acestep-v15-turbo-Q8_0.gguf",
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
|
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
|
||||||
CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
|
CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -24,23 +24,23 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
|
|||||||
lmModel := opts.ModelFile
|
lmModel := opts.ModelFile
|
||||||
|
|
||||||
// Get the base directory from ModelFile for resolving relative paths
|
// Get the base directory from ModelFile for resolving relative paths
|
||||||
baseDir := filepath.Dir(lmModel)
|
baseDir := opts.ModelPath
|
||||||
|
|
||||||
var textEncoderModel, ditModel, vaeModel string
|
var textEncoderModel, ditModel, vaeModel string
|
||||||
|
|
||||||
for _, oo := range opts.Options {
|
for _, oo := range opts.Options {
|
||||||
parts := strings.SplitN(oo, ":", 2)
|
key, value, found := strings.Cut(oo, ":")
|
||||||
if len(parts) != 2 {
|
if !found {
|
||||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch parts[0] {
|
switch key {
|
||||||
case "text_encoder_model":
|
case "text_encoder_model":
|
||||||
textEncoderModel = parts[1]
|
textEncoderModel = value
|
||||||
case "dit_model":
|
case "dit_model":
|
||||||
ditModel = parts[1]
|
ditModel = value
|
||||||
case "vae_model":
|
case "vae_model":
|
||||||
vaeModel = parts[1]
|
vaeModel = value
|
||||||
default:
|
default:
|
||||||
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ type LLM struct {
|
|||||||
draftModel *llama.LLama
|
draftModel *llama.LLama
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Free releases GPU resources and frees the llama model
|
// Free releases GPU resources and frees the llama model
|
||||||
// This should be called when the model is being unloaded to properly release VRAM
|
// This should be called when the model is being unloaded to properly release VRAM
|
||||||
func (llm *LLM) Free() error {
|
func (llm *LLM) Free() error {
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build debug
|
//go:build debug
|
||||||
// +build debug
|
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build !debug
|
//go:build !debug
|
||||||
// +build !debug
|
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
|||||||
@@ -332,7 +332,7 @@ func normalizedCosineSimilarity(k1, k2 []float32) float32 {
|
|||||||
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||||
|
|
||||||
var dot float32
|
var dot float32
|
||||||
for i := 0; i < len(k1); i++ {
|
for i := range len(k1) {
|
||||||
dot += k1[i] * k2[i]
|
dot += k1[i] * k2[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -419,7 +419,7 @@ func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
|
|||||||
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||||
|
|
||||||
var dot, mag2 float64
|
var dot, mag2 float64
|
||||||
for i := 0; i < len(k1); i++ {
|
for i := range len(k1) {
|
||||||
dot += float64(k1[i] * k2[i])
|
dot += float64(k1[i] * k2[i])
|
||||||
mag2 += float64(k2[i] * k2[i])
|
mag2 += float64(k2[i] * k2[i])
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -701,7 +701,7 @@ var _ = Describe("Opus", func() {
|
|||||||
// to one-shot (only difference is resampler batch boundaries).
|
// to one-shot (only difference is resampler batch boundaries).
|
||||||
var maxDiff float64
|
var maxDiff float64
|
||||||
var sumDiffSq float64
|
var sumDiffSq float64
|
||||||
for i := 0; i < minLen; i++ {
|
for i := range minLen {
|
||||||
diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i]))
|
diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i]))
|
||||||
if diff > maxDiff {
|
if diff > maxDiff {
|
||||||
maxDiff = diff
|
maxDiff = diff
|
||||||
@@ -774,7 +774,7 @@ var _ = Describe("Opus", func() {
|
|||||||
minLen := min(len(refTail), min(len(persistentTail), len(freshTail)))
|
minLen := min(len(refTail), min(len(persistentTail), len(freshTail)))
|
||||||
|
|
||||||
var persistentMaxDiff, freshMaxDiff float64
|
var persistentMaxDiff, freshMaxDiff float64
|
||||||
for i := 0; i < minLen; i++ {
|
for i := range minLen {
|
||||||
pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i]))
|
pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i]))
|
||||||
fd := math.Abs(float64(refTail[i]) - float64(freshTail[i]))
|
fd := math.Abs(float64(refTail[i]) - float64(freshTail[i]))
|
||||||
if pd > persistentMaxDiff {
|
if pd > persistentMaxDiff {
|
||||||
@@ -932,7 +932,7 @@ var _ = Describe("Opus", func() {
|
|||||||
GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n",
|
GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n",
|
||||||
mean, stddev, stddev/mean, 16000.0/440.0/2.0)
|
mean, stddev, stddev/mean, 16000.0/440.0/2.0)
|
||||||
|
|
||||||
Expect(stddev / mean).To(BeNumerically("<", 0.15),
|
Expect(stddev/mean).To(BeNumerically("<", 0.15),
|
||||||
fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean))
|
fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean))
|
||||||
|
|
||||||
// Also check frequency is correct
|
// Also check frequency is correct
|
||||||
@@ -978,7 +978,7 @@ var _ = Describe("Opus", func() {
|
|||||||
|
|
||||||
// Every sample must be identical — the resampler is deterministic
|
// Every sample must be identical — the resampler is deterministic
|
||||||
var maxDiff float64
|
var maxDiff float64
|
||||||
for i := 0; i < len(oneShot); i++ {
|
for i := range len(oneShot) {
|
||||||
diff := math.Abs(float64(oneShot[i]) - float64(batched[i]))
|
diff := math.Abs(float64(oneShot[i]) - float64(batched[i]))
|
||||||
if diff > maxDiff {
|
if diff > maxDiff {
|
||||||
maxDiff = diff
|
maxDiff = diff
|
||||||
@@ -1037,13 +1037,13 @@ var _ = Describe("Opus", func() {
|
|||||||
binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen))
|
binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen))
|
||||||
copy(hdr[8:12], "WAVE")
|
copy(hdr[8:12], "WAVE")
|
||||||
copy(hdr[12:16], "fmt ")
|
copy(hdr[12:16], "fmt ")
|
||||||
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
|
binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size
|
||||||
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
|
binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM
|
||||||
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
|
binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono
|
||||||
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
|
binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate
|
||||||
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
|
binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate
|
||||||
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
|
binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align
|
||||||
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
|
binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample
|
||||||
copy(hdr[36:40], "data")
|
copy(hdr[36:40], "data")
|
||||||
binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen))
|
binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen))
|
||||||
|
|
||||||
@@ -1126,7 +1126,7 @@ var _ = Describe("Opus", func() {
|
|||||||
)
|
)
|
||||||
|
|
||||||
pcm := make([]byte, toneNumSamples*2)
|
pcm := make([]byte, toneNumSamples*2)
|
||||||
for i := 0; i < toneNumSamples; i++ {
|
for i := range toneNumSamples {
|
||||||
sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate)))
|
sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate)))
|
||||||
binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample))
|
binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
|||||||
|
|
||||||
# stablediffusion.cpp (ggml)
|
# stablediffusion.cpp (ggml)
|
||||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||||
STABLEDIFFUSION_GGML_VERSION?=d6dd6d7b555c233bb9bc9f20b4751eb8c9269743
|
STABLEDIFFUSION_GGML_VERSION?=87ecb95cbc65dc8e58e3d88f4f4a59a0939796f5
|
||||||
|
|
||||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||||
|
|
||||||
|
|||||||
@@ -27,107 +27,7 @@
|
|||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
|
|
||||||
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
|
|
||||||
const char* sample_method_str[] = {
|
|
||||||
"euler",
|
|
||||||
"euler_a",
|
|
||||||
"heun",
|
|
||||||
"dpm2",
|
|
||||||
"dpm++2s_a",
|
|
||||||
"dpm++2m",
|
|
||||||
"dpm++2mv2",
|
|
||||||
"ipndm",
|
|
||||||
"ipndm_v",
|
|
||||||
"lcm",
|
|
||||||
"ddim_trailing",
|
|
||||||
"tcd",
|
|
||||||
"res_multistep",
|
|
||||||
"res_2s",
|
|
||||||
};
|
|
||||||
|
|
||||||
static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch");
|
|
||||||
|
|
||||||
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
|
|
||||||
const char* schedulers[] = {
|
|
||||||
"discrete",
|
|
||||||
"karras",
|
|
||||||
"exponential",
|
|
||||||
"ays",
|
|
||||||
"gits",
|
|
||||||
"sgm_uniform",
|
|
||||||
"simple",
|
|
||||||
"smoothstep",
|
|
||||||
"kl_optimal",
|
|
||||||
"lcm",
|
|
||||||
"bong_tangent",
|
|
||||||
};
|
|
||||||
|
|
||||||
static_assert(std::size(schedulers) == SCHEDULER_COUNT, "schedulers mismatch");
|
|
||||||
|
|
||||||
// New enum string arrays
|
|
||||||
const char* rng_type_str[] = {
|
|
||||||
"std_default",
|
|
||||||
"cuda",
|
|
||||||
"cpu",
|
|
||||||
};
|
|
||||||
static_assert(std::size(rng_type_str) == RNG_TYPE_COUNT, "rng type mismatch");
|
|
||||||
|
|
||||||
const char* prediction_str[] = {
|
|
||||||
"epsilon",
|
|
||||||
"v",
|
|
||||||
"edm_v",
|
|
||||||
"flow",
|
|
||||||
"flux_flow",
|
|
||||||
"flux2_flow",
|
|
||||||
};
|
|
||||||
static_assert(std::size(prediction_str) == PREDICTION_COUNT, "prediction mismatch");
|
|
||||||
|
|
||||||
const char* lora_apply_mode_str[] = {
|
|
||||||
"auto",
|
|
||||||
"immediately",
|
|
||||||
"at_runtime",
|
|
||||||
};
|
|
||||||
static_assert(std::size(lora_apply_mode_str) == LORA_APPLY_MODE_COUNT, "lora apply mode mismatch");
|
|
||||||
|
|
||||||
constexpr const char* sd_type_str[] = {
|
|
||||||
"f32", // 0
|
|
||||||
"f16", // 1
|
|
||||||
"q4_0", // 2
|
|
||||||
"q4_1", // 3
|
|
||||||
nullptr, // 4
|
|
||||||
nullptr, // 5
|
|
||||||
"q5_0", // 6
|
|
||||||
"q5_1", // 7
|
|
||||||
"q8_0", // 8
|
|
||||||
"q8_1", // 9
|
|
||||||
"q2_k", // 10
|
|
||||||
"q3_k", // 11
|
|
||||||
"q4_k", // 12
|
|
||||||
"q5_k", // 13
|
|
||||||
"q6_k", // 14
|
|
||||||
"q8_k", // 15
|
|
||||||
"iq2_xxs", // 16
|
|
||||||
"iq2_xs", // 17
|
|
||||||
"iq3_xxs", // 18
|
|
||||||
"iq1_s", // 19
|
|
||||||
"iq4_nl", // 20
|
|
||||||
"iq3_s", // 21
|
|
||||||
"iq2_s", // 22
|
|
||||||
"iq4_xs", // 23
|
|
||||||
"i8", // 24
|
|
||||||
"i16", // 25
|
|
||||||
"i32", // 26
|
|
||||||
"i64", // 27
|
|
||||||
"f64", // 28
|
|
||||||
"iq1_m", // 29
|
|
||||||
"bf16", // 30
|
|
||||||
nullptr, nullptr, nullptr, // 31-33
|
|
||||||
"tq1_0", // 34
|
|
||||||
"tq2_0", // 35
|
|
||||||
nullptr, nullptr, nullptr, // 36-38
|
|
||||||
"mxfp4" // 39
|
|
||||||
};
|
|
||||||
static_assert(std::size(sd_type_str) == SD_TYPE_COUNT, "sd type mismatch");
|
|
||||||
|
|
||||||
sd_ctx_params_t ctx_params;
|
sd_ctx_params_t ctx_params;
|
||||||
sd_ctx_t* sd_c;
|
sd_ctx_t* sd_c;
|
||||||
@@ -596,75 +496,45 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
|||||||
if (!strcmp(optname, "flow_shift")) flow_shift = atof(optval);
|
if (!strcmp(optname, "flow_shift")) flow_shift = atof(optval);
|
||||||
|
|
||||||
if (!strcmp(optname, "rng_type")) {
|
if (!strcmp(optname, "rng_type")) {
|
||||||
int found = -1;
|
rng_type_t parsed = str_to_rng_type(optval);
|
||||||
for (int m = 0; m < RNG_TYPE_COUNT; m++) {
|
if (parsed != RNG_TYPE_COUNT) {
|
||||||
if (!strcmp(optval, rng_type_str[m])) {
|
rng_type = parsed;
|
||||||
found = m;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (found != -1) {
|
|
||||||
rng_type = (rng_type_t)found;
|
|
||||||
fprintf(stderr, "Found rng_type: %s\n", optval);
|
fprintf(stderr, "Found rng_type: %s\n", optval);
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "Invalid rng_type: %s, using default\n", optval);
|
fprintf(stderr, "Invalid rng_type: %s, using default\n", optval);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!strcmp(optname, "sampler_rng_type")) {
|
if (!strcmp(optname, "sampler_rng_type")) {
|
||||||
int found = -1;
|
rng_type_t parsed = str_to_rng_type(optval);
|
||||||
for (int m = 0; m < RNG_TYPE_COUNT; m++) {
|
if (parsed != RNG_TYPE_COUNT) {
|
||||||
if (!strcmp(optval, rng_type_str[m])) {
|
sampler_rng_type = parsed;
|
||||||
found = m;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (found != -1) {
|
|
||||||
sampler_rng_type = (rng_type_t)found;
|
|
||||||
fprintf(stderr, "Found sampler_rng_type: %s\n", optval);
|
fprintf(stderr, "Found sampler_rng_type: %s\n", optval);
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "Invalid sampler_rng_type: %s, using default\n", optval);
|
fprintf(stderr, "Invalid sampler_rng_type: %s, using default\n", optval);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!strcmp(optname, "prediction")) {
|
if (!strcmp(optname, "prediction")) {
|
||||||
int found = -1;
|
prediction_t parsed = str_to_prediction(optval);
|
||||||
for (int m = 0; m < PREDICTION_COUNT; m++) {
|
if (parsed != PREDICTION_COUNT) {
|
||||||
if (!strcmp(optval, prediction_str[m])) {
|
prediction = parsed;
|
||||||
found = m;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (found != -1) {
|
|
||||||
prediction = (prediction_t)found;
|
|
||||||
fprintf(stderr, "Found prediction: %s\n", optval);
|
fprintf(stderr, "Found prediction: %s\n", optval);
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "Invalid prediction: %s, using default\n", optval);
|
fprintf(stderr, "Invalid prediction: %s, using default\n", optval);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!strcmp(optname, "lora_apply_mode")) {
|
if (!strcmp(optname, "lora_apply_mode")) {
|
||||||
int found = -1;
|
lora_apply_mode_t parsed = str_to_lora_apply_mode(optval);
|
||||||
for (int m = 0; m < LORA_APPLY_MODE_COUNT; m++) {
|
if (parsed != LORA_APPLY_MODE_COUNT) {
|
||||||
if (!strcmp(optval, lora_apply_mode_str[m])) {
|
lora_apply_mode = parsed;
|
||||||
found = m;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (found != -1) {
|
|
||||||
lora_apply_mode = (lora_apply_mode_t)found;
|
|
||||||
fprintf(stderr, "Found lora_apply_mode: %s\n", optval);
|
fprintf(stderr, "Found lora_apply_mode: %s\n", optval);
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "Invalid lora_apply_mode: %s, using default\n", optval);
|
fprintf(stderr, "Invalid lora_apply_mode: %s, using default\n", optval);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!strcmp(optname, "wtype")) {
|
if (!strcmp(optname, "wtype")) {
|
||||||
int found = -1;
|
sd_type_t parsed = str_to_sd_type(optval);
|
||||||
for (int m = 0; m < SD_TYPE_COUNT; m++) {
|
if (parsed != SD_TYPE_COUNT) {
|
||||||
if (sd_type_str[m] && !strcmp(optval, sd_type_str[m])) {
|
wtype = parsed;
|
||||||
found = m;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (found != -1) {
|
|
||||||
wtype = (sd_type_t)found;
|
|
||||||
fprintf(stderr, "Found wtype: %s\n", optval);
|
fprintf(stderr, "Found wtype: %s\n", optval);
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "Invalid wtype: %s, using default\n", optval);
|
fprintf(stderr, "Invalid wtype: %s, using default\n", optval);
|
||||||
@@ -735,27 +605,25 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
|||||||
fprintf (stderr, "Created context: OK\n");
|
fprintf (stderr, "Created context: OK\n");
|
||||||
|
|
||||||
int sample_method_found = -1;
|
int sample_method_found = -1;
|
||||||
for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) {
|
sample_method_t sm = str_to_sample_method(sampler);
|
||||||
if (!strcmp(sampler, sample_method_str[m])) {
|
if (sm != SAMPLE_METHOD_COUNT) {
|
||||||
sample_method_found = m;
|
sample_method_found = (int)sm;
|
||||||
fprintf(stderr, "Found sampler: %s\n", sampler);
|
fprintf(stderr, "Found sampler: %s\n", sampler);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (sample_method_found == -1) {
|
if (sample_method_found == -1) {
|
||||||
sample_method_found = sd_get_default_sample_method(sd_ctx);
|
sample_method_found = sd_get_default_sample_method(sd_ctx);
|
||||||
fprintf(stderr, "Invalid sample method, using default: %s\n", sample_method_str[sample_method_found]);
|
fprintf(stderr, "Invalid sample method, using default: %s\n", sd_sample_method_name((sample_method_t)sample_method_found));
|
||||||
}
|
}
|
||||||
sample_method = (sample_method_t)sample_method_found;
|
sample_method = (sample_method_t)sample_method_found;
|
||||||
|
|
||||||
for (int d = 0; d < SCHEDULER_COUNT; d++) {
|
scheduler_t sched = str_to_scheduler(scheduler_str);
|
||||||
if (!strcmp(scheduler_str, schedulers[d])) {
|
if (sched != SCHEDULER_COUNT) {
|
||||||
scheduler = (scheduler_t)d;
|
scheduler = sched;
|
||||||
fprintf (stderr, "Found scheduler: %s\n", scheduler_str);
|
fprintf(stderr, "Found scheduler: %s\n", scheduler_str);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (scheduler == SCHEDULER_COUNT) {
|
if (scheduler == SCHEDULER_COUNT) {
|
||||||
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
|
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
|
||||||
fprintf(stderr, "Invalid scheduler, using default: %s\n", schedulers[scheduler]);
|
fprintf(stderr, "Invalid scheduler, using default: %s\n", sd_scheduler_name(scheduler));
|
||||||
}
|
}
|
||||||
|
|
||||||
sd_c = sd_ctx;
|
sd_c = sd_ctx;
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ func TestAudioTranscription(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(tmpDir)
|
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||||
|
|
||||||
// Download sample audio — JFK "ask not what your country can do for you" clip
|
// Download sample audio — JFK "ask not what your country can do for you" clip
|
||||||
audioFile := filepath.Join(tmpDir, "sample.wav")
|
audioFile := filepath.Join(tmpDir, "sample.wav")
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
|||||||
|
|
||||||
# whisper.cpp version
|
# whisper.cpp version
|
||||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||||
WHISPER_CPP_VERSION?=30c5194c9691e4e9a98b3dea9f19727397d3f46e
|
WHISPER_CPP_VERSION?=95ea8f9bfb03a15db08a8989966fd1ae3361e20d
|
||||||
SO_TARGET?=libgowhisper.so
|
SO_TARGET?=libgowhisper.so
|
||||||
|
|
||||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||||
|
|||||||
@@ -726,6 +726,7 @@
|
|||||||
- TTS
|
- TTS
|
||||||
- &opus
|
- &opus
|
||||||
name: "opus"
|
name: "opus"
|
||||||
|
alias: "opus"
|
||||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-opus"
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-opus"
|
||||||
urls:
|
urls:
|
||||||
- https://opus-codec.org/
|
- https://opus-codec.org/
|
||||||
@@ -3029,3 +3030,82 @@
|
|||||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral"
|
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral"
|
||||||
mirrors:
|
mirrors:
|
||||||
- localai/localai-backends:master-metal-darwin-arm64-voxtral
|
- localai/localai-backends:master-metal-darwin-arm64-voxtral
|
||||||
|
- &trl
|
||||||
|
name: "trl"
|
||||||
|
alias: "trl"
|
||||||
|
license: apache-2.0
|
||||||
|
description: |
|
||||||
|
HuggingFace TRL fine-tuning backend. Supports SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO training methods.
|
||||||
|
Works on CPU and GPU.
|
||||||
|
urls:
|
||||||
|
- https://github.com/huggingface/trl
|
||||||
|
tags:
|
||||||
|
- fine-tuning
|
||||||
|
- LLM
|
||||||
|
- CPU
|
||||||
|
- GPU
|
||||||
|
- CUDA
|
||||||
|
capabilities:
|
||||||
|
default: "cpu-trl"
|
||||||
|
nvidia: "cuda12-trl"
|
||||||
|
nvidia-cuda-12: "cuda12-trl"
|
||||||
|
nvidia-cuda-13: "cuda13-trl"
|
||||||
|
## TRL backend images
|
||||||
|
- !!merge <<: *trl
|
||||||
|
name: "cpu-trl"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-trl"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-cpu-trl
|
||||||
|
- !!merge <<: *trl
|
||||||
|
name: "cpu-trl-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-trl"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-cpu-trl
|
||||||
|
- !!merge <<: *trl
|
||||||
|
name: "cuda12-trl"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda12-trl"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-cublas-cuda12-trl
|
||||||
|
- !!merge <<: *trl
|
||||||
|
name: "cuda12-trl-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda12-trl"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-cublas-cuda12-trl
|
||||||
|
- !!merge <<: *trl
|
||||||
|
name: "cuda13-trl"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda13-trl"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-cublas-cuda13-trl
|
||||||
|
- !!merge <<: *trl
|
||||||
|
name: "cuda13-trl-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda13-trl"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-cublas-cuda13-trl
|
||||||
|
## llama.cpp quantization backend
|
||||||
|
- &llama-cpp-quantization
|
||||||
|
name: "llama-cpp-quantization"
|
||||||
|
alias: "llama-cpp-quantization"
|
||||||
|
license: mit
|
||||||
|
icon: https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png
|
||||||
|
description: |
|
||||||
|
Model quantization backend using llama.cpp. Downloads HuggingFace models, converts them to GGUF format,
|
||||||
|
and quantizes them to various formats (q4_k_m, q5_k_m, q8_0, f16, etc.).
|
||||||
|
urls:
|
||||||
|
- https://github.com/ggml-org/llama.cpp
|
||||||
|
tags:
|
||||||
|
- quantization
|
||||||
|
- GGUF
|
||||||
|
- CPU
|
||||||
|
capabilities:
|
||||||
|
default: "cpu-llama-cpp-quantization"
|
||||||
|
metal: "metal-darwin-arm64-llama-cpp-quantization"
|
||||||
|
- !!merge <<: *llama-cpp-quantization
|
||||||
|
name: "cpu-llama-cpp-quantization"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-llama-cpp-quantization"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-cpu-llama-cpp-quantization
|
||||||
|
- !!merge <<: *llama-cpp-quantization
|
||||||
|
name: "metal-darwin-arm64-llama-cpp-quantization"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp-quantization"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-metal-darwin-arm64-llama-cpp-quantization
|
||||||
|
|||||||
@@ -19,6 +19,10 @@ import tempfile
|
|||||||
import backend_pb2
|
import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
from acestep.inference import (
|
from acestep.inference import (
|
||||||
GenerationParams,
|
GenerationParams,
|
||||||
GenerationConfig,
|
GenerationConfig,
|
||||||
@@ -444,6 +448,8 @@ def serve(address):
|
|||||||
("grpc.max_send_message_length", 50 * 1024 * 1024),
|
("grpc.max_send_message_length", 50 * 1024 * 1024),
|
||||||
("grpc.max_receive_message_length", 50 * 1024 * 1024),
|
("grpc.max_receive_message_length", 50 * 1024 * 1024),
|
||||||
],
|
],
|
||||||
|
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
)
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
|
|||||||
@@ -16,6 +16,10 @@ import torchaudio as ta
|
|||||||
from chatterbox.tts import ChatterboxTTS
|
from chatterbox.tts import ChatterboxTTS
|
||||||
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
|
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
@@ -225,7 +229,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
78
backend/python/common/grpc_auth.py
Normal file
78
backend/python/common/grpc_auth.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""Shared gRPC bearer token authentication interceptor for LocalAI Python backends.
|
||||||
|
|
||||||
|
When the environment variable LOCALAI_GRPC_AUTH_TOKEN is set, requests without
|
||||||
|
a valid Bearer token in the 'authorization' metadata header are rejected with
|
||||||
|
UNAUTHENTICATED. When the variable is empty or unset, no authentication is
|
||||||
|
performed (backward compatible).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hmac
|
||||||
|
import os
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
|
||||||
|
class _AbortHandler(grpc.RpcMethodHandler):
|
||||||
|
"""A method handler that immediately aborts with UNAUTHENTICATED."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.request_streaming = False
|
||||||
|
self.response_streaming = False
|
||||||
|
self.request_deserializer = None
|
||||||
|
self.response_serializer = None
|
||||||
|
self.unary_unary = self._abort
|
||||||
|
self.unary_stream = None
|
||||||
|
self.stream_unary = None
|
||||||
|
self.stream_stream = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _abort(request, context):
|
||||||
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "invalid token")
|
||||||
|
|
||||||
|
|
||||||
|
class TokenAuthInterceptor(grpc.ServerInterceptor):
|
||||||
|
"""Sync gRPC server interceptor that validates a bearer token."""
|
||||||
|
|
||||||
|
def __init__(self, token: str):
|
||||||
|
self._token = token
|
||||||
|
self._abort_handler = _AbortHandler()
|
||||||
|
|
||||||
|
def intercept_service(self, continuation, handler_call_details):
|
||||||
|
metadata = dict(handler_call_details.invocation_metadata)
|
||||||
|
auth = metadata.get("authorization", "")
|
||||||
|
expected = "Bearer " + self._token
|
||||||
|
if not hmac.compare_digest(auth, expected):
|
||||||
|
return self._abort_handler
|
||||||
|
return continuation(handler_call_details)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncTokenAuthInterceptor(grpc.aio.ServerInterceptor):
|
||||||
|
"""Async gRPC server interceptor that validates a bearer token."""
|
||||||
|
|
||||||
|
def __init__(self, token: str):
|
||||||
|
self._token = token
|
||||||
|
|
||||||
|
async def intercept_service(self, continuation, handler_call_details):
|
||||||
|
metadata = dict(handler_call_details.invocation_metadata)
|
||||||
|
auth = metadata.get("authorization", "")
|
||||||
|
expected = "Bearer " + self._token
|
||||||
|
if not hmac.compare_digest(auth, expected):
|
||||||
|
return _AbortHandler()
|
||||||
|
return await continuation(handler_call_details)
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_interceptors(*, aio: bool = False):
|
||||||
|
"""Return a list of gRPC interceptors for bearer token auth.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
aio: If True, return async-compatible interceptors for grpc.aio.server().
|
||||||
|
If False (default), return sync interceptors for grpc.server().
|
||||||
|
|
||||||
|
Returns an empty list when LOCALAI_GRPC_AUTH_TOKEN is not set.
|
||||||
|
"""
|
||||||
|
token = os.environ.get("LOCALAI_GRPC_AUTH_TOKEN", "")
|
||||||
|
if not token:
|
||||||
|
return []
|
||||||
|
if aio:
|
||||||
|
return [AsyncTokenAuthInterceptor(token)]
|
||||||
|
return [TokenAuthInterceptor(token)]
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
grpcio==1.78.1
|
grpcio==1.80.0
|
||||||
protobuf
|
protobuf
|
||||||
grpcio-tools
|
grpcio-tools
|
||||||
@@ -15,6 +15,10 @@ import torch
|
|||||||
from TTS.api import TTS
|
from TTS.api import TTS
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
@@ -93,7 +97,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
transformers==4.48.3
|
transformers==4.48.3
|
||||||
accelerate
|
accelerate
|
||||||
torch==2.4.1
|
torch==2.4.1
|
||||||
|
torchaudio==2.4.1
|
||||||
coqui-tts
|
coqui-tts
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
grpcio==1.78.1
|
grpcio==1.80.0
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
packaging==24.1
|
packaging==24.1
|
||||||
@@ -22,6 +22,10 @@ import backend_pb2
|
|||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
# Import dynamic loader for pipeline discovery
|
# Import dynamic loader for pipeline discovery
|
||||||
from diffusers_dynamic_loader import (
|
from diffusers_dynamic_loader import (
|
||||||
@@ -1042,7 +1046,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -15,6 +15,10 @@ import torch
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
@@ -165,6 +169,8 @@ def serve(address):
|
|||||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
]
|
]
|
||||||
|
,
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
)
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ import torch
|
|||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
@@ -70,7 +74,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -19,6 +19,10 @@ import numpy as np
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
@@ -424,6 +428,8 @@ def serve(address):
|
|||||||
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
||||||
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
||||||
],
|
],
|
||||||
|
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
)
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
|
|||||||
@@ -16,6 +16,10 @@ from kittentts import KittenTTS
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
@@ -77,7 +81,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -16,6 +16,10 @@ from kokoro import KPipeline
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
@@ -84,7 +88,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -21,3 +21,8 @@ if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
installRequirements
|
installRequirements
|
||||||
|
|
||||||
|
# spaCy is a dependency of misaki (used by kokoro for English phonemization).
|
||||||
|
# Pre-download the model here because at runtime the portable Python environment
|
||||||
|
# has no pip/uv, so spacy's auto-download would fail.
|
||||||
|
python -m spacy download en_core_web_sm
|
||||||
|
|||||||
26
backend/python/llama-cpp-quantization/Makefile
Normal file
26
backend/python/llama-cpp-quantization/Makefile
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Version of llama.cpp to fetch convert_hf_to_gguf.py from
|
||||||
|
LLAMA_CPP_CONVERT_VERSION ?= master
|
||||||
|
|
||||||
|
.PHONY: llama-cpp-quantization
|
||||||
|
llama-cpp-quantization:
|
||||||
|
LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
|
||||||
|
|
||||||
|
.PHONY: run
|
||||||
|
run: llama-cpp-quantization
|
||||||
|
@echo "Running llama-cpp-quantization..."
|
||||||
|
bash run.sh
|
||||||
|
@echo "llama-cpp-quantization run."
|
||||||
|
|
||||||
|
.PHONY: test
|
||||||
|
test: llama-cpp-quantization
|
||||||
|
@echo "Testing llama-cpp-quantization..."
|
||||||
|
bash test.sh
|
||||||
|
@echo "llama-cpp-quantization tested."
|
||||||
|
|
||||||
|
.PHONY: protogen-clean
|
||||||
|
protogen-clean:
|
||||||
|
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||||
|
|
||||||
|
.PHONY: clean
|
||||||
|
clean: protogen-clean
|
||||||
|
rm -rf venv __pycache__
|
||||||
426
backend/python/llama-cpp-quantization/backend.py
Normal file
426
backend/python/llama-cpp-quantization/backend.py
Normal file
@@ -0,0 +1,426 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
llama.cpp quantization backend for LocalAI.
|
||||||
|
|
||||||
|
Downloads HuggingFace models, converts them to GGUF format using
|
||||||
|
convert_hf_to_gguf.py, and quantizes using llama-quantize.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import re
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from concurrent import futures
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
import backend_pb2
|
||||||
|
import backend_pb2_grpc
|
||||||
|
|
||||||
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
|
||||||
|
|
||||||
|
|
||||||
|
class ActiveJob:
|
||||||
|
"""Tracks a running quantization job."""
|
||||||
|
def __init__(self, job_id):
|
||||||
|
self.job_id = job_id
|
||||||
|
self.progress_queue = queue.Queue()
|
||||||
|
self.stop_event = threading.Event()
|
||||||
|
self.thread = None
|
||||||
|
self.process = None # subprocess handle for killing
|
||||||
|
|
||||||
|
|
||||||
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
|
def __init__(self):
|
||||||
|
self.jobs = {} # job_id -> ActiveJob
|
||||||
|
|
||||||
|
def Health(self, request, context):
|
||||||
|
return backend_pb2.Reply(message=b"OK")
|
||||||
|
|
||||||
|
def LoadModel(self, request, context):
|
||||||
|
"""Accept LoadModel — actual work happens in StartQuantization."""
|
||||||
|
return backend_pb2.Result(success=True, message="OK")
|
||||||
|
|
||||||
|
def StartQuantization(self, request, context):
|
||||||
|
job_id = request.job_id
|
||||||
|
if job_id in self.jobs:
|
||||||
|
return backend_pb2.QuantizationJobResult(
|
||||||
|
job_id=job_id,
|
||||||
|
success=False,
|
||||||
|
message=f"Job {job_id} already exists",
|
||||||
|
)
|
||||||
|
|
||||||
|
job = ActiveJob(job_id)
|
||||||
|
self.jobs[job_id] = job
|
||||||
|
|
||||||
|
job.thread = threading.Thread(
|
||||||
|
target=self._do_quantization,
|
||||||
|
args=(job, request),
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
job.thread.start()
|
||||||
|
|
||||||
|
return backend_pb2.QuantizationJobResult(
|
||||||
|
job_id=job_id,
|
||||||
|
success=True,
|
||||||
|
message="Quantization job started",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _send_progress(self, job, status, message, progress_percent=0.0, output_file="", extra_metrics=None):
|
||||||
|
update = backend_pb2.QuantizationProgressUpdate(
|
||||||
|
job_id=job.job_id,
|
||||||
|
progress_percent=progress_percent,
|
||||||
|
status=status,
|
||||||
|
message=message,
|
||||||
|
output_file=output_file,
|
||||||
|
extra_metrics=extra_metrics or {},
|
||||||
|
)
|
||||||
|
job.progress_queue.put(update)
|
||||||
|
|
||||||
|
def _do_quantization(self, job, request):
|
||||||
|
try:
|
||||||
|
model = request.model
|
||||||
|
quant_type = request.quantization_type or "q4_k_m"
|
||||||
|
output_dir = request.output_dir
|
||||||
|
extra_options = dict(request.extra_options) if request.extra_options else {}
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if job.stop_event.is_set():
|
||||||
|
self._send_progress(job, "stopped", "Job stopped before starting")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 1: Download / resolve model
|
||||||
|
self._send_progress(job, "downloading", f"Resolving model: {model}", progress_percent=0.0)
|
||||||
|
|
||||||
|
model_path = self._resolve_model(job, model, output_dir, extra_options)
|
||||||
|
if model_path is None:
|
||||||
|
return # error already sent
|
||||||
|
|
||||||
|
if job.stop_event.is_set():
|
||||||
|
self._send_progress(job, "stopped", "Job stopped during download")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 2: Convert to f16 GGUF
|
||||||
|
self._send_progress(job, "converting", "Converting model to GGUF (f16)...", progress_percent=30.0)
|
||||||
|
|
||||||
|
f16_gguf_path = os.path.join(output_dir, "model-f16.gguf")
|
||||||
|
if not self._convert_to_gguf(job, model_path, f16_gguf_path, extra_options):
|
||||||
|
return # error already sent
|
||||||
|
|
||||||
|
if job.stop_event.is_set():
|
||||||
|
self._send_progress(job, "stopped", "Job stopped during conversion")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 3: Quantize
|
||||||
|
# If the user requested f16, skip quantization — the f16 GGUF is the final output
|
||||||
|
if quant_type.lower() in ("f16", "fp16"):
|
||||||
|
output_file = f16_gguf_path
|
||||||
|
self._send_progress(
|
||||||
|
job, "completed",
|
||||||
|
f"Model converted to f16 GGUF: {output_file}",
|
||||||
|
progress_percent=100.0,
|
||||||
|
output_file=output_file,
|
||||||
|
extra_metrics=self._file_metrics(output_file),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
output_file = os.path.join(output_dir, f"model-{quant_type}.gguf")
|
||||||
|
self._send_progress(job, "quantizing", f"Quantizing to {quant_type}...", progress_percent=50.0)
|
||||||
|
|
||||||
|
if not self._quantize(job, f16_gguf_path, output_file, quant_type):
|
||||||
|
return # error already sent
|
||||||
|
|
||||||
|
# Clean up f16 intermediate file to save disk space
|
||||||
|
try:
|
||||||
|
os.remove(f16_gguf_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._send_progress(
|
||||||
|
job, "completed",
|
||||||
|
f"Quantization complete: {quant_type}",
|
||||||
|
progress_percent=100.0,
|
||||||
|
output_file=output_file,
|
||||||
|
extra_metrics=self._file_metrics(output_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._send_progress(job, "failed", f"Quantization failed: {str(e)}")
|
||||||
|
|
||||||
|
def _resolve_model(self, job, model, output_dir, extra_options):
|
||||||
|
"""Download model from HuggingFace or return local path."""
|
||||||
|
# If it's a local path that exists, use it directly
|
||||||
|
if os.path.isdir(model):
|
||||||
|
return model
|
||||||
|
|
||||||
|
# If it looks like a GGUF file path, use it directly
|
||||||
|
if os.path.isfile(model) and model.endswith(".gguf"):
|
||||||
|
return model
|
||||||
|
|
||||||
|
# Download from HuggingFace
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
hf_token = extra_options.get("hf_token") or os.environ.get("HF_TOKEN")
|
||||||
|
cache_dir = os.path.join(output_dir, "hf_cache")
|
||||||
|
|
||||||
|
self._send_progress(job, "downloading", f"Downloading {model} from HuggingFace...", progress_percent=5.0)
|
||||||
|
|
||||||
|
local_path = snapshot_download(
|
||||||
|
repo_id=model,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
token=hf_token,
|
||||||
|
ignore_patterns=["*.md", "*.txt", "LICENSE*", ".gitattributes"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self._send_progress(job, "downloading", f"Downloaded {model}", progress_percent=25.0)
|
||||||
|
return local_path
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
if "gated" in error_msg.lower() or "access" in error_msg.lower():
|
||||||
|
self._send_progress(
|
||||||
|
job, "failed",
|
||||||
|
f"Access denied for {model}. This model may be gated — "
|
||||||
|
f"please accept the license at https://huggingface.co/{model} "
|
||||||
|
f"and provide your HF token in extra_options.",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._send_progress(job, "failed", f"Failed to download model: {error_msg}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _convert_to_gguf(self, job, model_path, output_path, extra_options):
|
||||||
|
"""Convert HF model to f16 GGUF using convert_hf_to_gguf.py."""
|
||||||
|
# If the model_path is already a GGUF file, just use it as-is
|
||||||
|
if isinstance(model_path, str) and model_path.endswith(".gguf"):
|
||||||
|
# Copy or symlink the GGUF file
|
||||||
|
import shutil
|
||||||
|
shutil.copy2(model_path, output_path)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Find convert_hf_to_gguf.py
|
||||||
|
convert_script = self._find_convert_script()
|
||||||
|
if convert_script is None:
|
||||||
|
self._send_progress(job, "failed", "convert_hf_to_gguf.py not found. Install it via the backend's install.sh.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
sys.executable, convert_script,
|
||||||
|
model_path,
|
||||||
|
"--outfile", output_path,
|
||||||
|
"--outtype", "f16",
|
||||||
|
]
|
||||||
|
|
||||||
|
self._send_progress(job, "converting", "Running convert_hf_to_gguf.py...", progress_percent=35.0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
process = subprocess.Popen(
|
||||||
|
cmd,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
text=True,
|
||||||
|
bufsize=1,
|
||||||
|
)
|
||||||
|
job.process = process
|
||||||
|
|
||||||
|
for line in process.stdout:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
self._send_progress(job, "converting", line, progress_percent=40.0)
|
||||||
|
if job.stop_event.is_set():
|
||||||
|
process.kill()
|
||||||
|
self._send_progress(job, "stopped", "Job stopped during conversion")
|
||||||
|
return False
|
||||||
|
|
||||||
|
process.wait()
|
||||||
|
job.process = None
|
||||||
|
|
||||||
|
if process.returncode != 0:
|
||||||
|
self._send_progress(job, "failed", f"convert_hf_to_gguf.py failed with exit code {process.returncode}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._send_progress(job, "failed", f"Conversion failed: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _quantize(self, job, input_path, output_path, quant_type):
|
||||||
|
"""Quantize a GGUF file using llama-quantize."""
|
||||||
|
quantize_bin = self._find_quantize_binary()
|
||||||
|
if quantize_bin is None:
|
||||||
|
self._send_progress(job, "failed", "llama-quantize binary not found. Ensure it is installed and in PATH.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
cmd = [quantize_bin, input_path, output_path, quant_type]
|
||||||
|
|
||||||
|
self._send_progress(job, "quantizing", f"Running llama-quantize ({quant_type})...", progress_percent=55.0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
process = subprocess.Popen(
|
||||||
|
cmd,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
text=True,
|
||||||
|
bufsize=1,
|
||||||
|
)
|
||||||
|
job.process = process
|
||||||
|
|
||||||
|
for line in process.stdout:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
# Try to parse progress from llama-quantize output
|
||||||
|
progress = self._parse_quantize_progress(line)
|
||||||
|
pct = 55.0 + (progress * 0.40) if progress else 60.0
|
||||||
|
self._send_progress(job, "quantizing", line, progress_percent=pct)
|
||||||
|
if job.stop_event.is_set():
|
||||||
|
process.kill()
|
||||||
|
self._send_progress(job, "stopped", "Job stopped during quantization")
|
||||||
|
return False
|
||||||
|
|
||||||
|
process.wait()
|
||||||
|
job.process = None
|
||||||
|
|
||||||
|
if process.returncode != 0:
|
||||||
|
self._send_progress(job, "failed", f"llama-quantize failed with exit code {process.returncode}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._send_progress(job, "failed", f"Quantization failed: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _parse_quantize_progress(self, line):
|
||||||
|
"""Try to parse a progress percentage from llama-quantize output."""
|
||||||
|
# llama-quantize typically outputs lines like:
|
||||||
|
# [ 123/ 1234] quantizing blk.0.attn_k.weight ...
|
||||||
|
match = re.search(r'\[\s*(\d+)\s*/\s*(\d+)\]', line)
|
||||||
|
if match:
|
||||||
|
current = int(match.group(1))
|
||||||
|
total = int(match.group(2))
|
||||||
|
if total > 0:
|
||||||
|
return current / total
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _find_convert_script(self):
|
||||||
|
"""Find convert_hf_to_gguf.py in known locations."""
|
||||||
|
candidates = [
|
||||||
|
# Same directory as this backend
|
||||||
|
os.path.join(os.path.dirname(__file__), "convert_hf_to_gguf.py"),
|
||||||
|
# Installed via install.sh
|
||||||
|
os.path.join(os.path.dirname(os.path.abspath(__file__)), "convert_hf_to_gguf.py"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Also check if it's on PATH
|
||||||
|
import shutil
|
||||||
|
path_script = shutil.which("convert_hf_to_gguf.py")
|
||||||
|
if path_script:
|
||||||
|
candidates.append(path_script)
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
if os.path.isfile(candidate):
|
||||||
|
return candidate
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _find_quantize_binary(self):
|
||||||
|
"""Find llama-quantize binary."""
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
# Check common names on PATH
|
||||||
|
for name in ["llama-quantize", "quantize"]:
|
||||||
|
path = shutil.which(name)
|
||||||
|
if path:
|
||||||
|
return path
|
||||||
|
|
||||||
|
# Check in the backend directory (built by install.sh)
|
||||||
|
backend_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
for name in ["llama-quantize", "quantize"]:
|
||||||
|
candidate = os.path.join(backend_dir, name)
|
||||||
|
if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _file_metrics(self, filepath):
|
||||||
|
"""Return file size metrics."""
|
||||||
|
try:
|
||||||
|
size_bytes = os.path.getsize(filepath)
|
||||||
|
return {"file_size_mb": size_bytes / (1024 * 1024)}
|
||||||
|
except OSError:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def QuantizationProgress(self, request, context):
|
||||||
|
job_id = request.job_id
|
||||||
|
job = self.jobs.get(job_id)
|
||||||
|
if job is None:
|
||||||
|
context.abort(grpc.StatusCode.NOT_FOUND, f"Job {job_id} not found")
|
||||||
|
return
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
update = job.progress_queue.get(timeout=1.0)
|
||||||
|
yield update
|
||||||
|
# If this is a terminal status, stop streaming
|
||||||
|
if update.status in ("completed", "failed", "stopped"):
|
||||||
|
break
|
||||||
|
except queue.Empty:
|
||||||
|
# Check if the thread is still alive
|
||||||
|
if job.thread and not job.thread.is_alive():
|
||||||
|
# Thread finished but no terminal update — drain queue
|
||||||
|
while not job.progress_queue.empty():
|
||||||
|
update = job.progress_queue.get_nowait()
|
||||||
|
yield update
|
||||||
|
break
|
||||||
|
# Check if client disconnected
|
||||||
|
if context.is_active() is False:
|
||||||
|
break
|
||||||
|
|
||||||
|
def StopQuantization(self, request, context):
|
||||||
|
job_id = request.job_id
|
||||||
|
job = self.jobs.get(job_id)
|
||||||
|
if job is None:
|
||||||
|
return backend_pb2.Result(success=False, message=f"Job {job_id} not found")
|
||||||
|
|
||||||
|
job.stop_event.set()
|
||||||
|
if job.process:
|
||||||
|
try:
|
||||||
|
job.process.kill()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return backend_pb2.Result(success=True, message="Stop signal sent")
|
||||||
|
|
||||||
|
|
||||||
|
def serve(address):
|
||||||
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
|
server.add_insecure_port(address)
|
||||||
|
server.start()
|
||||||
|
print(f"Quantization backend listening on {address}", file=sys.stderr, flush=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
server.stop(0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="llama.cpp quantization gRPC backend")
|
||||||
|
parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, lambda sig, frame: sys.exit(0))
|
||||||
|
serve(args.addr)
|
||||||
58
backend/python/llama-cpp-quantization/install.sh
Executable file
58
backend/python/llama-cpp-quantization/install.sh
Executable file
@@ -0,0 +1,58 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
backend_dir=$(dirname $0)
|
||||||
|
if [ -d $backend_dir/common ]; then
|
||||||
|
source $backend_dir/common/libbackend.sh
|
||||||
|
else
|
||||||
|
source $backend_dir/../common/libbackend.sh
|
||||||
|
fi
|
||||||
|
|
||||||
|
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade "
|
||||||
|
installRequirements
|
||||||
|
|
||||||
|
# Fetch convert_hf_to_gguf.py from llama.cpp
|
||||||
|
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
|
||||||
|
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
|
||||||
|
if [ ! -f "${CONVERT_SCRIPT}" ]; then
|
||||||
|
echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||||
|
curl -L --fail --retry 3 \
|
||||||
|
"https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
|
||||||
|
-o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py."
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install gguf package from the same llama.cpp commit to keep them in sync
|
||||||
|
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
|
||||||
|
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||||
|
if [ "x${USE_PIP:-}" == "xtrue" ]; then
|
||||||
|
pip install "${GGUF_PIP_SPEC}" || {
|
||||||
|
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
|
||||||
|
pip install "gguf>=0.16.0"
|
||||||
|
}
|
||||||
|
else
|
||||||
|
uv pip install "${GGUF_PIP_SPEC}" || {
|
||||||
|
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
|
||||||
|
uv pip install "gguf>=0.16.0"
|
||||||
|
}
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build llama-quantize from llama.cpp if not already present
|
||||||
|
QUANTIZE_BIN="${EDIR}/llama-quantize"
|
||||||
|
if [ ! -x "${QUANTIZE_BIN}" ] && ! command -v llama-quantize &>/dev/null; then
|
||||||
|
if command -v cmake &>/dev/null; then
|
||||||
|
echo "Building llama-quantize from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||||
|
LLAMA_CPP_SRC="${EDIR}/llama.cpp"
|
||||||
|
if [ ! -d "${LLAMA_CPP_SRC}" ]; then
|
||||||
|
git clone --depth 1 --branch "${LLAMA_CPP_CONVERT_VERSION}" \
|
||||||
|
https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}" 2>/dev/null || \
|
||||||
|
git clone --depth 1 https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}"
|
||||||
|
fi
|
||||||
|
cmake -B "${LLAMA_CPP_SRC}/build" -S "${LLAMA_CPP_SRC}" -DGGML_NATIVE=OFF -DBUILD_SHARED_LIBS=OFF
|
||||||
|
cmake --build "${LLAMA_CPP_SRC}/build" --target llama-quantize -j"$(nproc 2>/dev/null || echo 2)"
|
||||||
|
cp "${LLAMA_CPP_SRC}/build/bin/llama-quantize" "${QUANTIZE_BIN}"
|
||||||
|
chmod +x "${QUANTIZE_BIN}"
|
||||||
|
echo "Built llama-quantize at ${QUANTIZE_BIN}"
|
||||||
|
else
|
||||||
|
echo "Warning: cmake not found — llama-quantize will not be available. Install cmake or provide llama-quantize on PATH."
|
||||||
|
fi
|
||||||
|
fi
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
torch==2.10.0
|
||||||
|
transformers>=4.56.2
|
||||||
|
huggingface-hub>=1.3.0
|
||||||
|
sentencepiece
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
torch==2.10.0
|
||||||
|
transformers>=4.56.2
|
||||||
|
huggingface-hub>=1.3.0
|
||||||
|
sentencepiece
|
||||||
3
backend/python/llama-cpp-quantization/requirements.txt
Normal file
3
backend/python/llama-cpp-quantization/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
grpcio==1.78.1
|
||||||
|
protobuf
|
||||||
|
certifi
|
||||||
10
backend/python/llama-cpp-quantization/run.sh
Executable file
10
backend/python/llama-cpp-quantization/run.sh
Executable file
@@ -0,0 +1,10 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
backend_dir=$(dirname $0)
|
||||||
|
if [ -d $backend_dir/common ]; then
|
||||||
|
source $backend_dir/common/libbackend.sh
|
||||||
|
else
|
||||||
|
source $backend_dir/../common/libbackend.sh
|
||||||
|
fi
|
||||||
|
|
||||||
|
startBackend $@
|
||||||
99
backend/python/llama-cpp-quantization/test.py
Normal file
99
backend/python/llama-cpp-quantization/test.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""
|
||||||
|
Test script for the llama-cpp-quantization gRPC backend.
|
||||||
|
|
||||||
|
Downloads a small model (functiongemma-270m-it), converts it to GGUF,
|
||||||
|
and quantizes it to q4_k_m.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
import backend_pb2
|
||||||
|
import backend_pb2_grpc
|
||||||
|
|
||||||
|
|
||||||
|
SERVER_ADDR = "localhost:50051"
|
||||||
|
# Small model for CI testing (~540MB)
|
||||||
|
TEST_MODEL = "unsloth/functiongemma-270m-it"
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuantizationBackend(unittest.TestCase):
|
||||||
|
"""Tests for the llama-cpp-quantization gRPC service."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.service = subprocess.Popen(
|
||||||
|
["python3", "backend.py", "--addr", SERVER_ADDR]
|
||||||
|
)
|
||||||
|
time.sleep(5)
|
||||||
|
cls.output_dir = tempfile.mkdtemp(prefix="quantize-test-")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
cls.service.kill()
|
||||||
|
cls.service.wait()
|
||||||
|
# Clean up output directory
|
||||||
|
if os.path.isdir(cls.output_dir):
|
||||||
|
shutil.rmtree(cls.output_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
def _channel(self):
|
||||||
|
return grpc.insecure_channel(SERVER_ADDR)
|
||||||
|
|
||||||
|
def test_01_health(self):
|
||||||
|
"""Test that the server starts and responds to health checks."""
|
||||||
|
with self._channel() as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.Health(backend_pb2.HealthMessage())
|
||||||
|
self.assertEqual(response.message, b"OK")
|
||||||
|
|
||||||
|
def test_02_quantize_small_model(self):
|
||||||
|
"""Download, convert, and quantize functiongemma-270m-it to q4_k_m."""
|
||||||
|
with self._channel() as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
|
||||||
|
job_id = "test-quantize-001"
|
||||||
|
|
||||||
|
# Start quantization
|
||||||
|
result = stub.StartQuantization(
|
||||||
|
backend_pb2.QuantizationRequest(
|
||||||
|
model=TEST_MODEL,
|
||||||
|
quantization_type="q4_k_m",
|
||||||
|
output_dir=self.output_dir,
|
||||||
|
job_id=job_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(result.success, f"StartQuantization failed: {result.message}")
|
||||||
|
self.assertEqual(result.job_id, job_id)
|
||||||
|
|
||||||
|
# Stream progress until completion
|
||||||
|
final_status = None
|
||||||
|
output_file = None
|
||||||
|
for update in stub.QuantizationProgress(
|
||||||
|
backend_pb2.QuantizationProgressRequest(job_id=job_id)
|
||||||
|
):
|
||||||
|
print(f" [{update.status}] {update.progress_percent:.1f}% - {update.message}")
|
||||||
|
final_status = update.status
|
||||||
|
if update.output_file:
|
||||||
|
output_file = update.output_file
|
||||||
|
|
||||||
|
self.assertEqual(final_status, "completed", f"Expected completed, got {final_status}")
|
||||||
|
self.assertIsNotNone(output_file, "No output_file in progress updates")
|
||||||
|
self.assertTrue(os.path.isfile(output_file), f"Output file not found: {output_file}")
|
||||||
|
|
||||||
|
# Verify the output is a valid GGUF file (starts with "GGUF" magic)
|
||||||
|
with open(output_file, "rb") as f:
|
||||||
|
magic = f.read(4)
|
||||||
|
self.assertEqual(magic, b"GGUF", f"Output file does not have GGUF magic: {magic!r}")
|
||||||
|
|
||||||
|
# Verify reasonable file size (q4_k_m of 270M model should be ~150-400MB)
|
||||||
|
size_mb = os.path.getsize(output_file) / (1024 * 1024)
|
||||||
|
print(f" Output file size: {size_mb:.1f} MB")
|
||||||
|
self.assertGreater(size_mb, 10, "Output file suspiciously small")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
11
backend/python/llama-cpp-quantization/test.sh
Executable file
11
backend/python/llama-cpp-quantization/test.sh
Executable file
@@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
backend_dir=$(dirname $0)
|
||||||
|
if [ -d $backend_dir/common ]; then
|
||||||
|
source $backend_dir/common/libbackend.sh
|
||||||
|
else
|
||||||
|
source $backend_dir/../common/libbackend.sh
|
||||||
|
fi
|
||||||
|
|
||||||
|
runUnittests
|
||||||
@@ -15,6 +15,10 @@ import backend_pb2
|
|||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
from mlx_audio.tts.utils import load_model
|
from mlx_audio.tts.utils import load_model
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -436,7 +440,9 @@ async def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
|
)
|
||||||
# Add the servicer to the server
|
# Add the servicer to the server
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
# Bind the server to the address
|
# Bind the server to the address
|
||||||
|
|||||||
@@ -23,6 +23,10 @@ import tempfile
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
import backend_pb2
|
import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
@@ -468,6 +472,8 @@ async def serve(address):
|
|||||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
],
|
],
|
||||||
|
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
)
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ import backend_pb2
|
|||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
from mlx_vlm import load, generate, stream_generate
|
from mlx_vlm import load, generate, stream_generate
|
||||||
from mlx_vlm.prompt_utils import apply_chat_template
|
from mlx_vlm.prompt_utils import apply_chat_template
|
||||||
from mlx_vlm.utils import load_config, load_image
|
from mlx_vlm.utils import load_config, load_image
|
||||||
@@ -446,7 +450,9 @@ async def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
|
)
|
||||||
# Add the servicer to the server
|
# Add the servicer to the server
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
# Bind the server to the address
|
# Bind the server to the address
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ import backend_pb2
|
|||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
from mlx_lm import load, generate, stream_generate
|
from mlx_lm import load, generate, stream_generate
|
||||||
from mlx_lm.sample_utils import make_sampler
|
from mlx_lm.sample_utils import make_sampler
|
||||||
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
|
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
|
||||||
@@ -421,7 +425,9 @@ async def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
|
)
|
||||||
# Add the servicer to the server
|
# Add the servicer to the server
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
# Bind the server to the address
|
# Bind the server to the address
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ from moonshine_voice import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
@@ -128,7 +132,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ import torch
|
|||||||
import nemo.collections.asr as nemo_asr
|
import nemo.collections.asr as nemo_asr
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
@@ -119,7 +123,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -4,3 +4,4 @@ certifi
|
|||||||
packaging==24.1
|
packaging==24.1
|
||||||
setuptools
|
setuptools
|
||||||
pyarrow==20.0.0
|
pyarrow==20.0.0
|
||||||
|
pybind11
|
||||||
|
|||||||
@@ -15,6 +15,10 @@ from neuttsair.neutts import NeuTTSAir
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
"""Check if a string can be converted to float."""
|
"""Check if a string can be converted to float."""
|
||||||
@@ -130,7 +134,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ import backend_pb2
|
|||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
import outetts
|
import outetts
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
@@ -116,7 +120,9 @@ async def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,10 @@ import torch
|
|||||||
from pocket_tts import TTSModel
|
from pocket_tts import TTSModel
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
"""Check if a string can be converted to float."""
|
"""Check if a string can be converted to float."""
|
||||||
@@ -225,7 +229,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ import torch
|
|||||||
from qwen_asr import Qwen3ASRModel
|
from qwen_asr import Qwen3ASRModel
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
@@ -184,7 +188,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -23,6 +23,10 @@ import hashlib
|
|||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
@@ -900,6 +904,8 @@ def serve(address):
|
|||||||
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
||||||
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
||||||
],
|
],
|
||||||
|
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
)
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ import backend_pb2
|
|||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
from rerankers import Reranker
|
from rerankers import Reranker
|
||||||
|
|
||||||
@@ -97,7 +101,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
grpcio==1.78.1
|
grpcio==1.80.0
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
@@ -13,6 +13,10 @@ import base64
|
|||||||
import backend_pb2
|
import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -139,7 +143,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -16,16 +16,22 @@ import backend_pb2
|
|||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
|
|
||||||
|
|
||||||
XPU=os.environ.get("XPU", "0") == "1"
|
XPU=os.environ.get("XPU", "0") == "1"
|
||||||
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
|
import transformers as transformers_module
|
||||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
|
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
|
||||||
from scipy.io import wavfile
|
from scipy.io import wavfile
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
# Backward-compat aliases for model types
|
||||||
|
TYPE_ALIASES = {"Mamba": "MambaForCausalLM"}
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
||||||
@@ -52,32 +58,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
|
This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
|
||||||
"""
|
"""
|
||||||
def Health(self, request, context):
|
def Health(self, request, context):
|
||||||
"""
|
|
||||||
A gRPC method that returns the health status of the backend service.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: A HealthRequest object that contains the request parameters.
|
|
||||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A Reply object that contains the health status of the backend service.
|
|
||||||
"""
|
|
||||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||||
|
|
||||||
def LoadModel(self, request, context):
|
def LoadModel(self, request, context):
|
||||||
"""
|
|
||||||
A gRPC method that loads a model into memory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: A LoadModelRequest object that contains the request parameters.
|
|
||||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A Result object that contains the result of the LoadModel operation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_name = request.Model
|
model_name = request.Model
|
||||||
|
|
||||||
# Check to see if the Model exists in the filesystem already.
|
# Check to see if the Model exists in the filesystem already.
|
||||||
if os.path.exists(request.ModelFile):
|
if os.path.exists(request.ModelFile):
|
||||||
model_name = request.ModelFile
|
model_name = request.ModelFile
|
||||||
@@ -88,8 +73,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
|
|
||||||
self.CUDA = torch.cuda.is_available()
|
self.CUDA = torch.cuda.is_available()
|
||||||
self.OV=False
|
self.OV=False
|
||||||
self.DiaTTS=False
|
self.GenericTTS=False
|
||||||
self.SentenceTransformer = False
|
self.SentenceTransformer = False
|
||||||
|
self.processor = None
|
||||||
|
|
||||||
device_map="cpu"
|
device_map="cpu"
|
||||||
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||||
@@ -101,7 +87,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
# Parse options from request.Options
|
# Parse options from request.Options
|
||||||
self.options = {}
|
self.options = {}
|
||||||
options = request.Options
|
options = request.Options
|
||||||
|
|
||||||
# The options are a list of strings in this form optname:optvalue
|
# The options are a list of strings in this form optname:optvalue
|
||||||
# We are storing all the options in a dict so we can use it later when generating
|
# We are storing all the options in a dict so we can use it later when generating
|
||||||
# Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
|
# Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
|
||||||
@@ -123,7 +109,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
print(f"Parsed options: {self.options}", file=sys.stderr)
|
print(f"Parsed options: {self.options}", file=sys.stderr)
|
||||||
|
|
||||||
if self.CUDA:
|
if self.CUDA:
|
||||||
from transformers import BitsAndBytesConfig, AutoModelForCausalLM
|
from transformers import BitsAndBytesConfig
|
||||||
if request.MainGPU:
|
if request.MainGPU:
|
||||||
device_map=request.MainGPU
|
device_map=request.MainGPU
|
||||||
else:
|
else:
|
||||||
@@ -140,40 +126,31 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
quantization = BitsAndBytesConfig(
|
quantization = BitsAndBytesConfig(
|
||||||
load_in_4bit=False,
|
load_in_4bit=False,
|
||||||
bnb_4bit_compute_dtype = None,
|
bnb_4bit_compute_dtype = None,
|
||||||
load_in_8bit=True,
|
load_in_8bit=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if request.Type == "AutoModelForCausalLM":
|
if XPU and request.Type == "AutoModelForCausalLM":
|
||||||
if XPU:
|
import intel_extension_for_pytorch as ipex
|
||||||
import intel_extension_for_pytorch as ipex
|
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
|
||||||
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
|
|
||||||
|
|
||||||
device_map="xpu"
|
device_map="xpu"
|
||||||
compute=torch.float16
|
compute=torch.float16
|
||||||
if request.Quantization == "xpu_4bit":
|
if request.Quantization == "xpu_4bit":
|
||||||
xpu_4bit = True
|
xpu_4bit = True
|
||||||
xpu_8bit = False
|
xpu_8bit = False
|
||||||
elif request.Quantization == "xpu_8bit":
|
elif request.Quantization == "xpu_8bit":
|
||||||
xpu_4bit = False
|
xpu_4bit = False
|
||||||
xpu_8bit = True
|
xpu_8bit = True
|
||||||
else:
|
|
||||||
xpu_4bit = False
|
|
||||||
xpu_8bit = False
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(model_name,
|
|
||||||
trust_remote_code=request.TrustRemoteCode,
|
|
||||||
use_safetensors=True,
|
|
||||||
device_map=device_map,
|
|
||||||
load_in_4bit=xpu_4bit,
|
|
||||||
load_in_8bit=xpu_8bit,
|
|
||||||
torch_dtype=compute)
|
|
||||||
else:
|
else:
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(model_name,
|
xpu_4bit = False
|
||||||
trust_remote_code=request.TrustRemoteCode,
|
xpu_8bit = False
|
||||||
use_safetensors=True,
|
self.model = AutoModelForCausalLM.from_pretrained(model_name,
|
||||||
quantization_config=quantization,
|
trust_remote_code=request.TrustRemoteCode,
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
torch_dtype=compute)
|
load_in_4bit=xpu_4bit,
|
||||||
|
load_in_8bit=xpu_8bit,
|
||||||
|
torch_dtype=compute)
|
||||||
elif request.Type == "OVModelForCausalLM":
|
elif request.Type == "OVModelForCausalLM":
|
||||||
from optimum.intel.openvino import OVModelForCausalLM
|
from optimum.intel.openvino import OVModelForCausalLM
|
||||||
from openvino.runtime import Core
|
from openvino.runtime import Core
|
||||||
@@ -185,14 +162,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
devices = Core().available_devices
|
devices = Core().available_devices
|
||||||
if "GPU" in " ".join(devices):
|
if "GPU" in " ".join(devices):
|
||||||
device_map="AUTO:GPU"
|
device_map="AUTO:GPU"
|
||||||
# While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected.
|
|
||||||
# https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html
|
|
||||||
if "CPU" or "NPU" in device_map:
|
if "CPU" or "NPU" in device_map:
|
||||||
if "-CPU" or "-NPU" not in device_map:
|
if "-CPU" or "-NPU" not in device_map:
|
||||||
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
|
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
|
||||||
else:
|
else:
|
||||||
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
|
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
|
||||||
self.model = OVModelForCausalLM.from_pretrained(model_name,
|
self.model = OVModelForCausalLM.from_pretrained(model_name,
|
||||||
compile=True,
|
compile=True,
|
||||||
trust_remote_code=request.TrustRemoteCode,
|
trust_remote_code=request.TrustRemoteCode,
|
||||||
ov_config=ovconfig,
|
ov_config=ovconfig,
|
||||||
@@ -209,59 +184,60 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
devices = Core().available_devices
|
devices = Core().available_devices
|
||||||
if "GPU" in " ".join(devices):
|
if "GPU" in " ".join(devices):
|
||||||
device_map="AUTO:GPU"
|
device_map="AUTO:GPU"
|
||||||
# While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected.
|
|
||||||
# https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html
|
|
||||||
if "CPU" or "NPU" in device_map:
|
if "CPU" or "NPU" in device_map:
|
||||||
if "-CPU" or "-NPU" not in device_map:
|
if "-CPU" or "-NPU" not in device_map:
|
||||||
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
|
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}
|
||||||
else:
|
else:
|
||||||
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
|
ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}
|
||||||
self.model = OVModelForFeatureExtraction.from_pretrained(model_name,
|
self.model = OVModelForFeatureExtraction.from_pretrained(model_name,
|
||||||
compile=True,
|
compile=True,
|
||||||
trust_remote_code=request.TrustRemoteCode,
|
trust_remote_code=request.TrustRemoteCode,
|
||||||
ov_config=ovconfig,
|
ov_config=ovconfig,
|
||||||
export=True,
|
export=True,
|
||||||
device=device_map)
|
device=device_map)
|
||||||
self.OV = True
|
self.OV = True
|
||||||
elif request.Type == "MusicgenForConditionalGeneration":
|
|
||||||
autoTokenizer = False
|
|
||||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
|
||||||
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
|
||||||
elif request.Type == "DiaForConditionalGeneration":
|
|
||||||
autoTokenizer = False
|
|
||||||
print("DiaForConditionalGeneration", file=sys.stderr)
|
|
||||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
|
||||||
self.model = DiaForConditionalGeneration.from_pretrained(model_name)
|
|
||||||
if self.CUDA:
|
|
||||||
self.model = self.model.to("cuda")
|
|
||||||
self.processor = self.processor.to("cuda")
|
|
||||||
print("DiaForConditionalGeneration loaded", file=sys.stderr)
|
|
||||||
self.DiaTTS = True
|
|
||||||
elif request.Type == "SentenceTransformer":
|
elif request.Type == "SentenceTransformer":
|
||||||
autoTokenizer = False
|
autoTokenizer = False
|
||||||
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
|
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||||
self.SentenceTransformer = True
|
self.SentenceTransformer = True
|
||||||
elif request.Type == "Mamba":
|
|
||||||
autoTokenizer = False
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
||||||
self.model = MambaForCausalLM.from_pretrained(model_name)
|
|
||||||
else:
|
else:
|
||||||
print("Automodel", file=sys.stderr)
|
# Generic: dynamically resolve model class from transformers
|
||||||
self.model = AutoModel.from_pretrained(model_name,
|
model_type = TYPE_ALIASES.get(request.Type, request.Type)
|
||||||
trust_remote_code=request.TrustRemoteCode,
|
ModelClass = AutoModel # default
|
||||||
use_safetensors=True,
|
if model_type and hasattr(transformers_module, model_type):
|
||||||
quantization_config=quantization,
|
ModelClass = getattr(transformers_module, model_type)
|
||||||
device_map=device_map,
|
print(f"Using model class: {model_type}", file=sys.stderr)
|
||||||
torch_dtype=compute)
|
else:
|
||||||
|
print(f"Using default AutoModel (type={request.Type!r})", file=sys.stderr)
|
||||||
|
|
||||||
|
self.model = ModelClass.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
trust_remote_code=request.TrustRemoteCode,
|
||||||
|
quantization_config=quantization,
|
||||||
|
device_map=device_map,
|
||||||
|
torch_dtype=compute,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to load a processor (needed for TTS/audio models)
|
||||||
|
try:
|
||||||
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
trust_remote_code=request.TrustRemoteCode,
|
||||||
|
)
|
||||||
|
self.GenericTTS = True
|
||||||
|
print(f"Loaded processor for {model_name}", file=sys.stderr)
|
||||||
|
except Exception:
|
||||||
|
self.processor = None
|
||||||
|
|
||||||
if request.ContextSize > 0:
|
if request.ContextSize > 0:
|
||||||
self.max_tokens = request.ContextSize
|
self.max_tokens = request.ContextSize
|
||||||
elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
|
elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
|
||||||
self.max_tokens = self.model.config.max_position_embeddings
|
self.max_tokens = self.model.config.max_position_embeddings
|
||||||
else:
|
else:
|
||||||
self.max_tokens = self.options.get("max_new_tokens", 512)
|
self.max_tokens = self.options.get("max_new_tokens", 512)
|
||||||
|
|
||||||
if autoTokenizer:
|
if autoTokenizer:
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
self.XPU = False
|
self.XPU = False
|
||||||
|
|
||||||
if XPU and self.OV == False:
|
if XPU and self.OV == False:
|
||||||
@@ -275,22 +251,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
except Exception as err:
|
except Exception as err:
|
||||||
print("Error:", err, file=sys.stderr)
|
print("Error:", err, file=sys.stderr)
|
||||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
# Implement your logic here for the LoadModel service
|
|
||||||
# Replace this with your desired response
|
|
||||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||||
|
|
||||||
def Embedding(self, request, context):
|
def Embedding(self, request, context):
|
||||||
"""
|
|
||||||
A gRPC method that calculates embeddings for a given sentence.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: An EmbeddingRequest object that contains the request parameters.
|
|
||||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An EmbeddingResult object that contains the calculated embeddings.
|
|
||||||
"""
|
|
||||||
|
|
||||||
set_seed(request.Seed)
|
set_seed(request.Seed)
|
||||||
# Tokenize input
|
# Tokenize input
|
||||||
max_length = 512
|
max_length = 512
|
||||||
@@ -303,13 +266,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
|
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
|
||||||
embeds = self.model.encode(request.Embeddings)
|
embeds = self.model.encode(request.Embeddings)
|
||||||
else:
|
else:
|
||||||
encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
|
encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
|
||||||
|
|
||||||
# Create word embeddings
|
# Create word embeddings
|
||||||
if self.CUDA:
|
if self.CUDA:
|
||||||
encoded_input = encoded_input.to("cuda")
|
encoded_input = encoded_input.to("cuda")
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model_output = self.model(**encoded_input)
|
model_output = self.model(**encoded_input)
|
||||||
|
|
||||||
# Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
|
# Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
|
||||||
@@ -317,11 +280,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
embeds = sentence_embeddings[0]
|
embeds = sentence_embeddings[0]
|
||||||
return backend_pb2.EmbeddingResult(embeddings=embeds)
|
return backend_pb2.EmbeddingResult(embeddings=embeds)
|
||||||
|
|
||||||
async def _predict(self, request, context, streaming=False):
|
async def _predict(self, request, context, streaming=False):
|
||||||
set_seed(request.Seed)
|
set_seed(request.Seed)
|
||||||
if request.TopP < 0 or request.TopP > 1:
|
if request.TopP < 0 or request.TopP > 1:
|
||||||
request.TopP = 1
|
request.TopP = 1
|
||||||
|
|
||||||
if request.TopK <= 0:
|
if request.TopK <= 0:
|
||||||
request.TopK = 50
|
request.TopK = 50
|
||||||
|
|
||||||
@@ -334,7 +297,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
request.Temperature == None
|
request.Temperature == None
|
||||||
|
|
||||||
prompt = request.Prompt
|
prompt = request.Prompt
|
||||||
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
|
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
|
||||||
prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
|
prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
|
||||||
inputs = self.tokenizer(prompt, return_tensors="pt")
|
inputs = self.tokenizer(prompt, return_tensors="pt")
|
||||||
@@ -363,10 +326,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
skip_prompt=True,
|
skip_prompt=True,
|
||||||
skip_special_tokens=True)
|
skip_special_tokens=True)
|
||||||
config=dict(inputs,
|
config=dict(inputs,
|
||||||
max_new_tokens=max_tokens,
|
max_new_tokens=max_tokens,
|
||||||
temperature=request.Temperature,
|
temperature=request.Temperature,
|
||||||
top_p=request.TopP,
|
top_p=request.TopP,
|
||||||
top_k=request.TopK,
|
top_k=request.TopK,
|
||||||
do_sample=sample,
|
do_sample=sample,
|
||||||
attention_mask=inputs["attention_mask"],
|
attention_mask=inputs["attention_mask"],
|
||||||
eos_token_id=self.tokenizer.eos_token_id,
|
eos_token_id=self.tokenizer.eos_token_id,
|
||||||
@@ -387,18 +350,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
else:
|
else:
|
||||||
if XPU and self.OV == False:
|
if XPU and self.OV == False:
|
||||||
outputs = self.model.generate(inputs["input_ids"],
|
outputs = self.model.generate(inputs["input_ids"],
|
||||||
max_new_tokens=max_tokens,
|
max_new_tokens=max_tokens,
|
||||||
temperature=request.Temperature,
|
temperature=request.Temperature,
|
||||||
top_p=request.TopP,
|
top_p=request.TopP,
|
||||||
top_k=request.TopK,
|
top_k=request.TopK,
|
||||||
do_sample=sample,
|
do_sample=sample,
|
||||||
pad_token=self.tokenizer.eos_token_id)
|
pad_token=self.tokenizer.eos_token_id)
|
||||||
else:
|
else:
|
||||||
outputs = self.model.generate(**inputs,
|
outputs = self.model.generate(**inputs,
|
||||||
max_new_tokens=max_tokens,
|
max_new_tokens=max_tokens,
|
||||||
temperature=request.Temperature,
|
temperature=request.Temperature,
|
||||||
top_p=request.TopP,
|
top_p=request.TopP,
|
||||||
top_k=request.TopK,
|
top_k=request.TopK,
|
||||||
do_sample=sample,
|
do_sample=sample,
|
||||||
eos_token_id=self.tokenizer.eos_token_id,
|
eos_token_id=self.tokenizer.eos_token_id,
|
||||||
pad_token_id=self.tokenizer.eos_token_id,
|
pad_token_id=self.tokenizer.eos_token_id,
|
||||||
@@ -413,31 +376,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
|
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
|
||||||
|
|
||||||
async def Predict(self, request, context):
|
async def Predict(self, request, context):
|
||||||
"""
|
|
||||||
Generates text based on the given prompt and sampling parameters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The predict request.
|
|
||||||
context: The gRPC context.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
backend_pb2.Reply: The predict result.
|
|
||||||
"""
|
|
||||||
gen = self._predict(request, context, streaming=False)
|
gen = self._predict(request, context, streaming=False)
|
||||||
res = await gen.__anext__()
|
res = await gen.__anext__()
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def PredictStream(self, request, context):
|
async def PredictStream(self, request, context):
|
||||||
"""
|
|
||||||
Generates text based on the given prompt and sampling parameters, and streams the results.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The predict stream request.
|
|
||||||
context: The gRPC context.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
backend_pb2.Result: The predict stream result.
|
|
||||||
"""
|
|
||||||
iterations = self._predict(request, context, streaming=True)
|
iterations = self._predict(request, context, streaming=True)
|
||||||
try:
|
try:
|
||||||
async for iteration in iterations:
|
async for iteration in iterations:
|
||||||
@@ -455,18 +398,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
if self.model is None:
|
if self.model is None:
|
||||||
if model_name == "":
|
if model_name == "":
|
||||||
return backend_pb2.Result(success=False, message="request.model is required")
|
return backend_pb2.Result(success=False, message="request.model is required")
|
||||||
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
# Dynamically resolve model class if configured, otherwise default to MusicgenForConditionalGeneration
|
||||||
|
model_type = self.options.get("model_type", "MusicgenForConditionalGeneration")
|
||||||
|
ModelClass = getattr(transformers_module, model_type)
|
||||||
|
self.model = ModelClass.from_pretrained(model_name)
|
||||||
inputs = None
|
inputs = None
|
||||||
if request.text == "":
|
if request.text == "":
|
||||||
inputs = self.model.get_unconditional_inputs(num_samples=1)
|
inputs = self.model.get_unconditional_inputs(num_samples=1)
|
||||||
elif request.HasField('src'):
|
elif request.HasField('src'):
|
||||||
# TODO SECURITY CODE GOES HERE LOL
|
|
||||||
# WHO KNOWS IF THIS WORKS???
|
|
||||||
sample_rate, wsamples = wavfile.read('path_to_your_file.wav')
|
sample_rate, wsamples = wavfile.read('path_to_your_file.wav')
|
||||||
|
|
||||||
if request.HasField('src_divisor'):
|
if request.HasField('src_divisor'):
|
||||||
wsamples = wsamples[: len(wsamples) // request.src_divisor]
|
wsamples = wsamples[: len(wsamples) // request.src_divisor]
|
||||||
|
|
||||||
inputs = self.processor(
|
inputs = self.processor(
|
||||||
audio=wsamples,
|
audio=wsamples,
|
||||||
sampling_rate=sample_rate,
|
sampling_rate=sample_rate,
|
||||||
@@ -480,7 +424,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
padding=True,
|
padding=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
if request.HasField('duration'):
|
if request.HasField('duration'):
|
||||||
tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
|
tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
|
||||||
guidance = self.options.get("guidance_scale", 3.0)
|
guidance = self.options.get("guidance_scale", 3.0)
|
||||||
@@ -490,92 +434,97 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
if request.HasField('sample'):
|
if request.HasField('sample'):
|
||||||
dosample = request.sample
|
dosample = request.sample
|
||||||
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens)
|
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens)
|
||||||
print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
|
print("[transformers] SoundGeneration generated!", file=sys.stderr)
|
||||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
|
||||||
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
|
# Save audio output
|
||||||
print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr)
|
if hasattr(self.processor, 'save_audio'):
|
||||||
print("[transformers-musicgen] SoundGeneration for", file=sys.stderr)
|
if hasattr(self.processor, 'batch_decode'):
|
||||||
print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr)
|
try:
|
||||||
|
audio_values = self.processor.batch_decode(audio_values)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self.processor.save_audio(audio_values, request.dst)
|
||||||
|
else:
|
||||||
|
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||||
|
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
|
||||||
|
|
||||||
|
print("[transformers] SoundGeneration saved to", request.dst, file=sys.stderr)
|
||||||
print(request, file=sys.stderr)
|
print(request, file=sys.stderr)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
return backend_pb2.Result(success=True)
|
return backend_pb2.Result(success=True)
|
||||||
|
|
||||||
|
|
||||||
def CallDiaTTS(self, request, context):
|
|
||||||
"""
|
|
||||||
Generates dialogue audio using the Dia model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: A TTSRequest containing text dialogue and generation parameters
|
|
||||||
context: The gRPC context
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A Result object indicating success or failure
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
print("[DiaTTS] generating dialogue audio", file=sys.stderr)
|
|
||||||
|
|
||||||
# Prepare text input - expect dialogue format like [S1] ... [S2] ...
|
|
||||||
text = [request.text]
|
|
||||||
|
|
||||||
# Process the input
|
|
||||||
inputs = self.processor(text=text, padding=True, return_tensors="pt")
|
|
||||||
|
|
||||||
# Generate audio with parameters from options or defaults
|
|
||||||
generation_params = {
|
|
||||||
**inputs,
|
|
||||||
"max_new_tokens": self.max_tokens,
|
|
||||||
"guidance_scale": self.options.get("guidance_scale", 3.0),
|
|
||||||
"temperature": self.options.get("temperature", 1.8),
|
|
||||||
"top_p": self.options.get("top_p", 0.90),
|
|
||||||
"top_k": self.options.get("top_k", 45)
|
|
||||||
}
|
|
||||||
|
|
||||||
outputs = self.model.generate(**generation_params)
|
|
||||||
|
|
||||||
# Decode and save audio
|
|
||||||
outputs = self.processor.batch_decode(outputs)
|
|
||||||
self.processor.save_audio(outputs, request.dst)
|
|
||||||
|
|
||||||
print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
|
|
||||||
print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
|
|
||||||
print("[DiaTTS] Dialogue generation done", file=sys.stderr)
|
|
||||||
|
|
||||||
except Exception as err:
|
|
||||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
|
||||||
return backend_pb2.Result(success=True)
|
|
||||||
|
|
||||||
|
|
||||||
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
|
|
||||||
def TTS(self, request, context):
|
def TTS(self, request, context):
|
||||||
if self.DiaTTS:
|
|
||||||
print("DiaTTS", file=sys.stderr)
|
|
||||||
return self.CallDiaTTS(request, context)
|
|
||||||
|
|
||||||
model_name = request.model
|
|
||||||
try:
|
try:
|
||||||
if self.processor is None:
|
text = request.text
|
||||||
if model_name == "":
|
print(f"[transformers] TTS generating for text: {text[:100]}...", file=sys.stderr)
|
||||||
return backend_pb2.Result(success=False, message="request.model is required")
|
|
||||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
# Build inputs based on processor capabilities
|
||||||
if self.model is None:
|
if request.voice and os.path.exists(request.voice):
|
||||||
if model_name == "":
|
# Voice cloning: use chat template with reference audio
|
||||||
return backend_pb2.Result(success=False, message="request.model is required")
|
chat_template = [{
|
||||||
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
"role": "0",
|
||||||
inputs = self.processor(
|
"content": [
|
||||||
text=[request.text],
|
{"type": "text", "text": text},
|
||||||
padding=True,
|
{"type": "audio", "path": request.voice},
|
||||||
return_tensors="pt",
|
],
|
||||||
)
|
}]
|
||||||
tokens = self.max_tokens # No good place to set the "length" in TTS, so use 10s as a sane default
|
inputs = self.processor.apply_chat_template(
|
||||||
audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
|
chat_template, tokenize=True, return_dict=True,
|
||||||
print("[transformers-musicgen] TTS generated!", file=sys.stderr)
|
).to(self.model.device, self.model.dtype)
|
||||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
elif hasattr(self.processor, 'apply_chat_template'):
|
||||||
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
|
# Models that use chat template format (VibeVoice, CSM, etc.)
|
||||||
print("[transformers-musicgen] TTS saved to", request.dst, file=sys.stderr)
|
chat_template = [{"role": "0", "content": [{"type": "text", "text": text}]}]
|
||||||
print("[transformers-musicgen] TTS for", file=sys.stderr)
|
try:
|
||||||
print(request, file=sys.stderr)
|
inputs = self.processor.apply_chat_template(
|
||||||
|
chat_template, tokenize=True, return_dict=True,
|
||||||
|
).to(self.model.device, self.model.dtype)
|
||||||
|
except Exception:
|
||||||
|
# Fallback if chat template fails (not all processors support it)
|
||||||
|
inputs = self.processor(text=[text], padding=True, return_tensors="pt")
|
||||||
|
if self.CUDA:
|
||||||
|
inputs = inputs.to("cuda")
|
||||||
|
else:
|
||||||
|
# Direct processor call (Musicgen, etc.)
|
||||||
|
inputs = self.processor(text=[text], padding=True, return_tensors="pt")
|
||||||
|
if self.CUDA:
|
||||||
|
inputs = inputs.to("cuda")
|
||||||
|
|
||||||
|
# Build generation kwargs from self.options
|
||||||
|
gen_kwargs = {**inputs, "max_new_tokens": self.max_tokens}
|
||||||
|
for key in ["guidance_scale", "temperature", "top_p", "top_k", "do_sample"]:
|
||||||
|
if key in self.options:
|
||||||
|
gen_kwargs[key] = self.options[key]
|
||||||
|
|
||||||
|
# Add noise scheduler if configured (e.g., for VibeVoice)
|
||||||
|
noise_scheduler_type = self.options.get("noise_scheduler", None)
|
||||||
|
if noise_scheduler_type:
|
||||||
|
import diffusers
|
||||||
|
SchedulerClass = getattr(diffusers, noise_scheduler_type)
|
||||||
|
scheduler_kwargs = {}
|
||||||
|
for key in ["beta_schedule", "prediction_type"]:
|
||||||
|
if key in self.options:
|
||||||
|
scheduler_kwargs[key] = self.options[key]
|
||||||
|
gen_kwargs["noise_scheduler"] = SchedulerClass(**scheduler_kwargs)
|
||||||
|
|
||||||
|
# Generate audio
|
||||||
|
audio = self.model.generate(**gen_kwargs)
|
||||||
|
print("[transformers] TTS generated!", file=sys.stderr)
|
||||||
|
|
||||||
|
# Save audio output
|
||||||
|
if hasattr(self.processor, 'save_audio'):
|
||||||
|
if hasattr(self.processor, 'batch_decode'):
|
||||||
|
try:
|
||||||
|
audio = self.processor.batch_decode(audio)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self.processor.save_audio(audio, request.dst)
|
||||||
|
else:
|
||||||
|
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||||
|
wavfile.write(request.dst, rate=sampling_rate, data=audio[0, 0].numpy())
|
||||||
|
|
||||||
|
print("[transformers] TTS saved to", request.dst, file=sys.stderr)
|
||||||
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
return backend_pb2.Result(success=True)
|
return backend_pb2.Result(success=True)
|
||||||
@@ -587,7 +536,9 @@ async def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
|
)
|
||||||
# Add the servicer to the server
|
# Add the servicer to the server
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
# Bind the server to the address
|
# Bind the server to the address
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ torch==2.7.1
|
|||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
accelerate
|
accelerate
|
||||||
transformers
|
transformers>=5.0.0
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==5.2.3
|
sentence-transformers==5.2.3
|
||||||
|
diffusers
|
||||||
|
soundfile
|
||||||
protobuf==6.33.5
|
protobuf==6.33.5
|
||||||
@@ -2,7 +2,9 @@ torch==2.7.1
|
|||||||
accelerate
|
accelerate
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
transformers
|
transformers>=5.0.0
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==5.2.3
|
sentence-transformers==5.2.3
|
||||||
|
diffusers
|
||||||
|
soundfile
|
||||||
protobuf==6.33.5
|
protobuf==6.33.5
|
||||||
@@ -2,7 +2,9 @@
|
|||||||
torch==2.9.0
|
torch==2.9.0
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
transformers
|
transformers>=5.0.0
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==5.2.3
|
sentence-transformers==5.2.3
|
||||||
|
diffusers
|
||||||
|
soundfile
|
||||||
protobuf==6.33.5
|
protobuf==6.33.5
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/rocm6.4
|
--extra-index-url https://download.pytorch.org/whl/rocm6.4
|
||||||
torch==2.8.0+rocm6.4
|
torch==2.8.0+rocm6.4
|
||||||
accelerate
|
accelerate
|
||||||
transformers
|
transformers>=5.0.0
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==5.2.3
|
sentence-transformers==5.2.3
|
||||||
|
diffusers
|
||||||
|
soundfile
|
||||||
protobuf==6.33.5
|
protobuf==6.33.5
|
||||||
@@ -3,7 +3,9 @@ torch
|
|||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
transformers
|
transformers>=5.0.0
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==5.2.3
|
sentence-transformers==5.2.3
|
||||||
|
diffusers
|
||||||
|
soundfile
|
||||||
protobuf==6.33.5
|
protobuf==6.33.5
|
||||||
@@ -2,7 +2,9 @@ torch==2.7.1
|
|||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
accelerate
|
accelerate
|
||||||
transformers
|
transformers>=5.0.0
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==5.2.3
|
sentence-transformers==5.2.3
|
||||||
|
diffusers
|
||||||
|
soundfile
|
||||||
protobuf==6.33.5
|
protobuf==6.33.5
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
grpcio==1.78.1
|
grpcio==1.80.0
|
||||||
protobuf==6.33.5
|
protobuf==6.33.5
|
||||||
certifi
|
certifi
|
||||||
setuptools
|
setuptools
|
||||||
|
|||||||
26
backend/python/trl/Makefile
Normal file
26
backend/python/trl/Makefile
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Version of llama.cpp to fetch convert_hf_to_gguf.py from (for GGUF export)
|
||||||
|
LLAMA_CPP_CONVERT_VERSION ?= master
|
||||||
|
|
||||||
|
.PHONY: trl
|
||||||
|
trl:
|
||||||
|
LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
|
||||||
|
|
||||||
|
.PHONY: run
|
||||||
|
run: trl
|
||||||
|
@echo "Running trl..."
|
||||||
|
bash run.sh
|
||||||
|
@echo "trl run."
|
||||||
|
|
||||||
|
.PHONY: test
|
||||||
|
test: trl
|
||||||
|
@echo "Testing trl..."
|
||||||
|
bash test.sh
|
||||||
|
@echo "trl tested."
|
||||||
|
|
||||||
|
.PHONY: protogen-clean
|
||||||
|
protogen-clean:
|
||||||
|
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||||
|
|
||||||
|
.PHONY: clean
|
||||||
|
clean: protogen-clean
|
||||||
|
rm -rf venv __pycache__
|
||||||
866
backend/python/trl/backend.py
Normal file
866
backend/python/trl/backend.py
Normal file
@@ -0,0 +1,866 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
TRL fine-tuning backend for LocalAI.
|
||||||
|
|
||||||
|
Supports all TRL training methods (SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO)
|
||||||
|
using standard HuggingFace transformers + PEFT. Works on both CPU and GPU.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from concurrent import futures
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
import backend_pb2
|
||||||
|
import backend_pb2_grpc
|
||||||
|
|
||||||
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressCallback:
|
||||||
|
"""HuggingFace TrainerCallback that pushes progress updates to a queue."""
|
||||||
|
|
||||||
|
def __init__(self, job_id, progress_queue, total_epochs):
|
||||||
|
self.job_id = job_id
|
||||||
|
self.progress_queue = progress_queue
|
||||||
|
self.total_epochs = total_epochs
|
||||||
|
|
||||||
|
def get_callback(self):
|
||||||
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
|
parent = self
|
||||||
|
|
||||||
|
class _Callback(TrainerCallback):
|
||||||
|
def __init__(self):
|
||||||
|
self._train_start_time = None
|
||||||
|
|
||||||
|
def on_train_begin(self, args, state, control, **kwargs):
|
||||||
|
self._train_start_time = time.time()
|
||||||
|
|
||||||
|
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||||
|
if logs is None:
|
||||||
|
return
|
||||||
|
total_steps = state.max_steps if state.max_steps > 0 else 0
|
||||||
|
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
|
||||||
|
eta = 0.0
|
||||||
|
if state.global_step > 0 and total_steps > 0 and self._train_start_time:
|
||||||
|
elapsed = time.time() - self._train_start_time
|
||||||
|
remaining_steps = total_steps - state.global_step
|
||||||
|
if state.global_step > 0:
|
||||||
|
eta = remaining_steps * (elapsed / state.global_step)
|
||||||
|
|
||||||
|
extra_metrics = {}
|
||||||
|
for k, v in logs.items():
|
||||||
|
if isinstance(v, (int, float)) and k not in ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss'):
|
||||||
|
extra_metrics[k] = float(v)
|
||||||
|
|
||||||
|
update = backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=parent.job_id,
|
||||||
|
current_step=state.global_step,
|
||||||
|
total_steps=total_steps,
|
||||||
|
current_epoch=float(logs.get('epoch', 0)),
|
||||||
|
total_epochs=float(parent.total_epochs),
|
||||||
|
loss=float(logs.get('loss', 0)),
|
||||||
|
learning_rate=float(logs.get('learning_rate', 0)),
|
||||||
|
grad_norm=float(logs.get('grad_norm', 0)),
|
||||||
|
eval_loss=float(logs.get('eval_loss', 0)),
|
||||||
|
eta_seconds=float(eta),
|
||||||
|
progress_percent=float(progress),
|
||||||
|
status="training",
|
||||||
|
extra_metrics=extra_metrics,
|
||||||
|
)
|
||||||
|
parent.progress_queue.put(update)
|
||||||
|
|
||||||
|
def on_prediction_step(self, args, state, control, **kwargs):
|
||||||
|
"""Send periodic updates during evaluation so the UI doesn't freeze."""
|
||||||
|
if not hasattr(self, '_eval_update_counter'):
|
||||||
|
self._eval_update_counter = 0
|
||||||
|
self._eval_update_counter += 1
|
||||||
|
# Throttle: send an update every 10 prediction steps
|
||||||
|
if self._eval_update_counter % 10 != 0:
|
||||||
|
return
|
||||||
|
total_steps = state.max_steps if state.max_steps > 0 else 0
|
||||||
|
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
|
||||||
|
update = backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=parent.job_id,
|
||||||
|
current_step=state.global_step,
|
||||||
|
total_steps=total_steps,
|
||||||
|
current_epoch=float(state.epoch or 0),
|
||||||
|
total_epochs=float(parent.total_epochs),
|
||||||
|
progress_percent=float(progress),
|
||||||
|
status="training",
|
||||||
|
message=f"Evaluating... (batch {self._eval_update_counter})",
|
||||||
|
)
|
||||||
|
parent.progress_queue.put(update)
|
||||||
|
|
||||||
|
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
||||||
|
"""Report eval results once evaluation is done."""
|
||||||
|
# Reset prediction counter for next eval round
|
||||||
|
self._eval_update_counter = 0
|
||||||
|
|
||||||
|
total_steps = state.max_steps if state.max_steps > 0 else 0
|
||||||
|
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
|
||||||
|
|
||||||
|
eval_loss = 0.0
|
||||||
|
extra_metrics = {}
|
||||||
|
if metrics:
|
||||||
|
eval_loss = float(metrics.get('eval_loss', 0))
|
||||||
|
for k, v in metrics.items():
|
||||||
|
if isinstance(v, (int, float)) and k not in ('eval_loss', 'epoch'):
|
||||||
|
extra_metrics[k] = float(v)
|
||||||
|
|
||||||
|
update = backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=parent.job_id,
|
||||||
|
current_step=state.global_step,
|
||||||
|
total_steps=total_steps,
|
||||||
|
current_epoch=float(state.epoch or 0),
|
||||||
|
total_epochs=float(parent.total_epochs),
|
||||||
|
eval_loss=eval_loss,
|
||||||
|
progress_percent=float(progress),
|
||||||
|
status="training",
|
||||||
|
message=f"Evaluation complete at step {state.global_step}",
|
||||||
|
extra_metrics=extra_metrics,
|
||||||
|
)
|
||||||
|
parent.progress_queue.put(update)
|
||||||
|
|
||||||
|
def on_save(self, args, state, control, **kwargs):
|
||||||
|
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
|
||||||
|
update = backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=parent.job_id,
|
||||||
|
current_step=state.global_step,
|
||||||
|
status="saving",
|
||||||
|
message=f"Checkpoint saved at step {state.global_step}",
|
||||||
|
checkpoint_path=checkpoint_path,
|
||||||
|
)
|
||||||
|
parent.progress_queue.put(update)
|
||||||
|
|
||||||
|
def on_train_end(self, args, state, control, **kwargs):
|
||||||
|
update = backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=parent.job_id,
|
||||||
|
current_step=state.global_step,
|
||||||
|
total_steps=state.max_steps,
|
||||||
|
progress_percent=100.0,
|
||||||
|
status="completed",
|
||||||
|
message="Training completed",
|
||||||
|
)
|
||||||
|
parent.progress_queue.put(update)
|
||||||
|
|
||||||
|
return _Callback()
|
||||||
|
|
||||||
|
|
||||||
|
class ActiveJob:
|
||||||
|
"""Represents an active fine-tuning job."""
|
||||||
|
|
||||||
|
def __init__(self, job_id):
|
||||||
|
self.job_id = job_id
|
||||||
|
self.progress_queue = queue.Queue()
|
||||||
|
self.trainer = None
|
||||||
|
self.thread = None
|
||||||
|
self.model = None
|
||||||
|
self.tokenizer = None
|
||||||
|
self.error = None
|
||||||
|
self.completed = False
|
||||||
|
self.stopped = False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_gated_repo_error(exc):
|
||||||
|
"""Check if an exception is caused by a gated HuggingFace repo requiring authentication."""
|
||||||
|
try:
|
||||||
|
from huggingface_hub.utils import GatedRepoError
|
||||||
|
if isinstance(exc, GatedRepoError):
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
msg = str(exc).lower()
|
||||||
|
if "gated repo" in msg or "access to model" in msg:
|
||||||
|
return True
|
||||||
|
if hasattr(exc, 'response') and hasattr(exc.response, 'status_code'):
|
||||||
|
if exc.response.status_code in (401, 403):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
|
def __init__(self):
|
||||||
|
self.active_job = None
|
||||||
|
|
||||||
|
def Health(self, request, context):
|
||||||
|
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||||
|
|
||||||
|
def LoadModel(self, request, context):
|
||||||
|
"""Accept LoadModel — actual model loading happens in StartFineTune."""
|
||||||
|
return backend_pb2.Result(success=True, message="OK")
|
||||||
|
|
||||||
|
def StartFineTune(self, request, context):
|
||||||
|
if self.active_job is not None and not self.active_job.completed:
|
||||||
|
return backend_pb2.FineTuneJobResult(
|
||||||
|
job_id="",
|
||||||
|
success=False,
|
||||||
|
message="A fine-tuning job is already running",
|
||||||
|
)
|
||||||
|
|
||||||
|
job_id = request.job_id if request.job_id else str(uuid.uuid4())
|
||||||
|
job = ActiveJob(job_id)
|
||||||
|
self.active_job = job
|
||||||
|
|
||||||
|
# Start training in background thread
|
||||||
|
thread = threading.Thread(target=self._run_training, args=(request, job), daemon=True)
|
||||||
|
job.thread = thread
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
return backend_pb2.FineTuneJobResult(
|
||||||
|
job_id=job_id,
|
||||||
|
success=True,
|
||||||
|
message="Fine-tuning job started",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_training(self, request, job):
|
||||||
|
try:
|
||||||
|
self._do_training(request, job)
|
||||||
|
except Exception as e:
|
||||||
|
if _is_gated_repo_error(e):
|
||||||
|
msg = (f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
|
||||||
|
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
|
||||||
|
else:
|
||||||
|
msg = f"Training failed: {e}"
|
||||||
|
job.error = msg
|
||||||
|
job.completed = True
|
||||||
|
update = backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=job.job_id,
|
||||||
|
status="failed",
|
||||||
|
message=msg,
|
||||||
|
)
|
||||||
|
job.progress_queue.put(update)
|
||||||
|
# Send sentinel
|
||||||
|
job.progress_queue.put(None)
|
||||||
|
|
||||||
|
def _do_training(self, request, job):
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
from datasets import load_dataset, Dataset
|
||||||
|
|
||||||
|
extra = dict(request.extra_options)
|
||||||
|
training_method = request.training_method or "sft"
|
||||||
|
training_type = request.training_type or "lora"
|
||||||
|
|
||||||
|
# Send loading status
|
||||||
|
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}",
|
||||||
|
))
|
||||||
|
|
||||||
|
# Determine device and dtype
|
||||||
|
device_map = "auto" if torch.cuda.is_available() else "cpu"
|
||||||
|
dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16
|
||||||
|
|
||||||
|
# HuggingFace token for gated repos (from extra_options or HF_TOKEN env)
|
||||||
|
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
model_kwargs = {"device_map": device_map, "torch_dtype": dtype}
|
||||||
|
if hf_token:
|
||||||
|
model_kwargs["token"] = hf_token
|
||||||
|
if extra.get("trust_remote_code", "false").lower() == "true":
|
||||||
|
model_kwargs["trust_remote_code"] = True
|
||||||
|
if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available():
|
||||||
|
from transformers import BitsAndBytesConfig
|
||||||
|
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token)
|
||||||
|
if tokenizer.pad_token is None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
job.model = model
|
||||||
|
job.tokenizer = tokenizer
|
||||||
|
|
||||||
|
# Apply LoRA if requested
|
||||||
|
if training_type == "lora":
|
||||||
|
from peft import LoraConfig, get_peft_model
|
||||||
|
lora_r = request.adapter_rank if request.adapter_rank > 0 else 16
|
||||||
|
lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16
|
||||||
|
lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0
|
||||||
|
|
||||||
|
target_modules = list(request.target_modules) if request.target_modules else None
|
||||||
|
peft_config = LoraConfig(
|
||||||
|
r=lora_r,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
lora_dropout=lora_dropout,
|
||||||
|
target_modules=target_modules or "all-linear",
|
||||||
|
bias="none",
|
||||||
|
task_type="CAUSAL_LM",
|
||||||
|
)
|
||||||
|
model = get_peft_model(model, peft_config)
|
||||||
|
|
||||||
|
# Load dataset
|
||||||
|
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=job.job_id, status="loading_dataset", message="Loading dataset",
|
||||||
|
))
|
||||||
|
|
||||||
|
dataset_split = request.dataset_split or "train"
|
||||||
|
if os.path.exists(request.dataset_source):
|
||||||
|
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
|
||||||
|
dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split)
|
||||||
|
elif request.dataset_source.endswith('.csv'):
|
||||||
|
dataset = load_dataset("csv", data_files=request.dataset_source, split=dataset_split)
|
||||||
|
else:
|
||||||
|
dataset = load_dataset(request.dataset_source, split=dataset_split)
|
||||||
|
else:
|
||||||
|
dataset = load_dataset(request.dataset_source, split=dataset_split)
|
||||||
|
|
||||||
|
# Eval dataset setup
|
||||||
|
eval_dataset = None
|
||||||
|
eval_strategy = extra.get("eval_strategy", "steps")
|
||||||
|
eval_steps = int(extra.get("eval_steps", str(request.save_steps if request.save_steps > 0 else 500)))
|
||||||
|
|
||||||
|
if eval_strategy != "no":
|
||||||
|
eval_split = extra.get("eval_split")
|
||||||
|
eval_dataset_source = extra.get("eval_dataset_source")
|
||||||
|
if eval_split:
|
||||||
|
# Load a specific split as eval dataset
|
||||||
|
if os.path.exists(request.dataset_source):
|
||||||
|
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
|
||||||
|
eval_dataset = load_dataset("json", data_files=request.dataset_source, split=eval_split)
|
||||||
|
elif request.dataset_source.endswith('.csv'):
|
||||||
|
eval_dataset = load_dataset("csv", data_files=request.dataset_source, split=eval_split)
|
||||||
|
else:
|
||||||
|
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
|
||||||
|
else:
|
||||||
|
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
|
||||||
|
elif eval_dataset_source:
|
||||||
|
# Load eval dataset from a separate source
|
||||||
|
eval_dataset = load_dataset(eval_dataset_source, split="train")
|
||||||
|
else:
|
||||||
|
# Auto-split from training set
|
||||||
|
eval_split_ratio = float(extra.get("eval_split_ratio", "0.1"))
|
||||||
|
split = dataset.train_test_split(test_size=eval_split_ratio)
|
||||||
|
dataset = split["train"]
|
||||||
|
eval_dataset = split["test"]
|
||||||
|
|
||||||
|
if eval_strategy == "no":
|
||||||
|
eval_dataset = None
|
||||||
|
|
||||||
|
# Training config
|
||||||
|
output_dir = request.output_dir or f"./output-{job.job_id}"
|
||||||
|
num_epochs = request.num_epochs if request.num_epochs > 0 else 3
|
||||||
|
batch_size = request.batch_size if request.batch_size > 0 else 2
|
||||||
|
lr = request.learning_rate if request.learning_rate > 0 else 2e-4
|
||||||
|
grad_accum = request.gradient_accumulation_steps if request.gradient_accumulation_steps > 0 else 4
|
||||||
|
warmup_steps = request.warmup_steps if request.warmup_steps > 0 else 5
|
||||||
|
weight_decay = request.weight_decay if request.weight_decay > 0 else 0.01
|
||||||
|
max_steps = request.max_steps if request.max_steps > 0 else -1
|
||||||
|
save_steps = request.save_steps if request.save_steps > 0 else 500
|
||||||
|
seed = request.seed if request.seed > 0 else 3407
|
||||||
|
optimizer = request.optimizer or "adamw_torch"
|
||||||
|
|
||||||
|
# Checkpoint save controls
|
||||||
|
save_total_limit = int(extra.get("save_total_limit", "0")) or None # 0 = unlimited
|
||||||
|
save_strategy = extra.get("save_strategy", "steps") # steps, epoch, no
|
||||||
|
|
||||||
|
# CPU vs GPU training args (can be overridden via extra_options)
|
||||||
|
use_cpu = not torch.cuda.is_available()
|
||||||
|
common_train_kwargs = {}
|
||||||
|
if use_cpu:
|
||||||
|
common_train_kwargs["use_cpu"] = True
|
||||||
|
common_train_kwargs["fp16"] = False
|
||||||
|
common_train_kwargs["bf16"] = False
|
||||||
|
common_train_kwargs["gradient_checkpointing"] = False
|
||||||
|
else:
|
||||||
|
common_train_kwargs["bf16"] = True
|
||||||
|
common_train_kwargs["gradient_checkpointing"] = request.gradient_checkpointing
|
||||||
|
|
||||||
|
# Allow extra_options to override training kwargs
|
||||||
|
for flag in ("use_cpu", "bf16", "fp16", "gradient_checkpointing"):
|
||||||
|
if flag in extra:
|
||||||
|
common_train_kwargs[flag] = extra[flag].lower() == "true"
|
||||||
|
|
||||||
|
# Create progress callback
|
||||||
|
progress_cb = ProgressCallback(job.job_id, job.progress_queue, num_epochs)
|
||||||
|
|
||||||
|
# Build save kwargs (shared across all methods)
|
||||||
|
_save_kwargs = {}
|
||||||
|
if save_strategy == "steps" and save_steps > 0:
|
||||||
|
_save_kwargs["save_steps"] = save_steps
|
||||||
|
_save_kwargs["save_strategy"] = "steps"
|
||||||
|
elif save_strategy == "epoch":
|
||||||
|
_save_kwargs["save_strategy"] = "epoch"
|
||||||
|
elif save_strategy == "no":
|
||||||
|
_save_kwargs["save_strategy"] = "no"
|
||||||
|
else:
|
||||||
|
_save_kwargs["save_steps"] = save_steps
|
||||||
|
_save_kwargs["save_strategy"] = "steps"
|
||||||
|
if save_total_limit:
|
||||||
|
_save_kwargs["save_total_limit"] = save_total_limit
|
||||||
|
|
||||||
|
# Eval kwargs
|
||||||
|
_eval_kwargs = {}
|
||||||
|
if eval_dataset is not None:
|
||||||
|
_eval_kwargs["eval_strategy"] = eval_strategy
|
||||||
|
_eval_kwargs["eval_steps"] = eval_steps
|
||||||
|
|
||||||
|
# Common training arguments shared by all methods
|
||||||
|
_common_args = dict(
|
||||||
|
output_dir=output_dir,
|
||||||
|
num_train_epochs=num_epochs,
|
||||||
|
per_device_train_batch_size=batch_size,
|
||||||
|
learning_rate=lr,
|
||||||
|
gradient_accumulation_steps=grad_accum,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
max_steps=max_steps,
|
||||||
|
seed=seed,
|
||||||
|
optim=optimizer,
|
||||||
|
logging_steps=1,
|
||||||
|
report_to="none",
|
||||||
|
**_save_kwargs,
|
||||||
|
**common_train_kwargs,
|
||||||
|
**_eval_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select trainer based on training method
|
||||||
|
if training_method == "sft":
|
||||||
|
from trl import SFTTrainer, SFTConfig
|
||||||
|
|
||||||
|
max_length = int(extra.get("max_seq_length", "512"))
|
||||||
|
packing = extra.get("packing", "false").lower() == "true"
|
||||||
|
|
||||||
|
training_args = SFTConfig(
|
||||||
|
max_length=max_length,
|
||||||
|
packing=packing,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = SFTTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif training_method == "dpo":
|
||||||
|
from trl import DPOTrainer, DPOConfig
|
||||||
|
|
||||||
|
beta = float(extra.get("beta", "0.1"))
|
||||||
|
loss_type = extra.get("loss_type", "sigmoid")
|
||||||
|
max_length = int(extra.get("max_length", "512"))
|
||||||
|
|
||||||
|
training_args = DPOConfig(
|
||||||
|
beta=beta,
|
||||||
|
loss_type=loss_type,
|
||||||
|
max_length=max_length,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = DPOTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif training_method == "grpo":
|
||||||
|
from trl import GRPOTrainer, GRPOConfig
|
||||||
|
|
||||||
|
num_generations = int(extra.get("num_generations", "4"))
|
||||||
|
max_completion_length = int(extra.get("max_completion_length", "256"))
|
||||||
|
|
||||||
|
training_args = GRPOConfig(
|
||||||
|
num_generations=num_generations,
|
||||||
|
max_completion_length=max_completion_length,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
# GRPO requires reward functions passed via extra_options as a JSON list
|
||||||
|
from reward_functions import build_reward_functions
|
||||||
|
|
||||||
|
reward_funcs = []
|
||||||
|
if extra.get("reward_funcs"):
|
||||||
|
reward_funcs = build_reward_functions(extra["reward_funcs"])
|
||||||
|
|
||||||
|
if not reward_funcs:
|
||||||
|
raise ValueError(
|
||||||
|
"GRPO requires at least one reward function. "
|
||||||
|
"Specify reward_functions in the request or "
|
||||||
|
"reward_funcs in extra_options."
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = GRPOTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
reward_funcs=reward_funcs,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif training_method == "orpo":
|
||||||
|
from trl import ORPOTrainer, ORPOConfig
|
||||||
|
|
||||||
|
beta = float(extra.get("beta", "0.1"))
|
||||||
|
max_length = int(extra.get("max_length", "512"))
|
||||||
|
|
||||||
|
training_args = ORPOConfig(
|
||||||
|
beta=beta,
|
||||||
|
max_length=max_length,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = ORPOTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif training_method == "kto":
|
||||||
|
from trl import KTOTrainer, KTOConfig
|
||||||
|
|
||||||
|
beta = float(extra.get("beta", "0.1"))
|
||||||
|
max_length = int(extra.get("max_length", "512"))
|
||||||
|
|
||||||
|
training_args = KTOConfig(
|
||||||
|
beta=beta,
|
||||||
|
max_length=max_length,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = KTOTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif training_method == "rloo":
|
||||||
|
from trl import RLOOTrainer, RLOOConfig
|
||||||
|
|
||||||
|
num_generations = int(extra.get("num_generations", "4"))
|
||||||
|
max_completion_length = int(extra.get("max_completion_length", "256"))
|
||||||
|
|
||||||
|
training_args = RLOOConfig(
|
||||||
|
num_generations=num_generations,
|
||||||
|
max_new_tokens=max_completion_length,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = RLOOTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif training_method == "reward":
|
||||||
|
from trl import RewardTrainer, RewardConfig
|
||||||
|
|
||||||
|
max_length = int(extra.get("max_length", "512"))
|
||||||
|
|
||||||
|
training_args = RewardConfig(
|
||||||
|
max_length=max_length,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported training method: {training_method}. "
|
||||||
|
"Supported: sft, dpo, grpo, orpo, kto, rloo, reward")
|
||||||
|
|
||||||
|
job.trainer = trainer
|
||||||
|
|
||||||
|
# Start training
|
||||||
|
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=job.job_id, status="training", message="Training started",
|
||||||
|
))
|
||||||
|
|
||||||
|
resume_ckpt = request.resume_from_checkpoint if request.resume_from_checkpoint else None
|
||||||
|
trainer.train(resume_from_checkpoint=resume_ckpt)
|
||||||
|
|
||||||
|
# Save final model
|
||||||
|
trainer.save_model(output_dir)
|
||||||
|
if tokenizer:
|
||||||
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
job.completed = True
|
||||||
|
# Sentinel to signal stream end
|
||||||
|
job.progress_queue.put(None)
|
||||||
|
|
||||||
|
def FineTuneProgress(self, request, context):
|
||||||
|
if self.active_job is None or self.active_job.job_id != request.job_id:
|
||||||
|
context.set_code(grpc.StatusCode.NOT_FOUND)
|
||||||
|
context.set_details(f"Job {request.job_id} not found")
|
||||||
|
return
|
||||||
|
|
||||||
|
job = self.active_job
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
update = job.progress_queue.get(timeout=1.0)
|
||||||
|
if update is None:
|
||||||
|
break
|
||||||
|
yield update
|
||||||
|
if update.status in ("completed", "failed", "stopped"):
|
||||||
|
break
|
||||||
|
except queue.Empty:
|
||||||
|
if job.completed or job.stopped:
|
||||||
|
break
|
||||||
|
if not context.is_active():
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
def StopFineTune(self, request, context):
|
||||||
|
# Stopping is handled by killing the process from Go via ShutdownModel.
|
||||||
|
return backend_pb2.Result(success=True, message="OK")
|
||||||
|
|
||||||
|
def ListCheckpoints(self, request, context):
|
||||||
|
output_dir = request.output_dir
|
||||||
|
if not os.path.isdir(output_dir):
|
||||||
|
return backend_pb2.ListCheckpointsResponse(checkpoints=[])
|
||||||
|
|
||||||
|
checkpoints = []
|
||||||
|
for entry in sorted(os.listdir(output_dir)):
|
||||||
|
if entry.startswith("checkpoint-"):
|
||||||
|
ckpt_path = os.path.join(output_dir, entry)
|
||||||
|
if not os.path.isdir(ckpt_path):
|
||||||
|
continue
|
||||||
|
step = 0
|
||||||
|
try:
|
||||||
|
step = int(entry.split("-")[1])
|
||||||
|
except (IndexError, ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Try to read trainer_state.json for metadata
|
||||||
|
loss = 0.0
|
||||||
|
epoch = 0.0
|
||||||
|
state_file = os.path.join(ckpt_path, "trainer_state.json")
|
||||||
|
if os.path.exists(state_file):
|
||||||
|
try:
|
||||||
|
with open(state_file) as f:
|
||||||
|
state = json.load(f)
|
||||||
|
if state.get("log_history"):
|
||||||
|
last_log = state["log_history"][-1]
|
||||||
|
loss = last_log.get("loss", 0.0)
|
||||||
|
epoch = last_log.get("epoch", 0.0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
created_at = time.strftime(
|
||||||
|
"%Y-%m-%dT%H:%M:%SZ",
|
||||||
|
time.gmtime(os.path.getmtime(ckpt_path)),
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoints.append(backend_pb2.CheckpointInfo(
|
||||||
|
path=ckpt_path,
|
||||||
|
step=step,
|
||||||
|
epoch=float(epoch),
|
||||||
|
loss=float(loss),
|
||||||
|
created_at=created_at,
|
||||||
|
))
|
||||||
|
|
||||||
|
return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints)
|
||||||
|
|
||||||
|
def ExportModel(self, request, context):
|
||||||
|
export_format = request.export_format or "lora"
|
||||||
|
output_path = request.output_path
|
||||||
|
checkpoint_path = request.checkpoint_path
|
||||||
|
|
||||||
|
# Extract HF token for gated model access
|
||||||
|
extra = dict(request.extra_options) if request.extra_options else {}
|
||||||
|
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
|
||||||
|
|
||||||
|
if not checkpoint_path or not os.path.isdir(checkpoint_path):
|
||||||
|
return backend_pb2.Result(success=False, message=f"Checkpoint not found: {checkpoint_path}")
|
||||||
|
|
||||||
|
os.makedirs(output_path, exist_ok=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if export_format == "lora":
|
||||||
|
# Just copy the adapter files
|
||||||
|
import shutil
|
||||||
|
for f in os.listdir(checkpoint_path):
|
||||||
|
src = os.path.join(checkpoint_path, f)
|
||||||
|
dst = os.path.join(output_path, f)
|
||||||
|
if os.path.isfile(src):
|
||||||
|
shutil.copy2(src, dst)
|
||||||
|
|
||||||
|
elif export_format in ("merged_16bit", "merged_4bit"):
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
from peft import PeftModel
|
||||||
|
|
||||||
|
base_model_name = request.model
|
||||||
|
if not base_model_name:
|
||||||
|
return backend_pb2.Result(success=False, message="Base model name required for merge export")
|
||||||
|
|
||||||
|
dtype = torch.float16 if export_format == "merged_16bit" else torch.float32
|
||||||
|
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, token=hf_token)
|
||||||
|
model = PeftModel.from_pretrained(base_model, checkpoint_path)
|
||||||
|
merged = model.merge_and_unload()
|
||||||
|
merged.save_pretrained(output_path)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
|
||||||
|
tokenizer.save_pretrained(output_path)
|
||||||
|
|
||||||
|
elif export_format == "gguf":
|
||||||
|
import torch
|
||||||
|
import subprocess
|
||||||
|
import shutil
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
from peft import PeftModel
|
||||||
|
|
||||||
|
base_model_name = request.model
|
||||||
|
if not base_model_name:
|
||||||
|
return backend_pb2.Result(success=False, message="Base model name required for GGUF export")
|
||||||
|
|
||||||
|
# Step 1: Merge LoRA into base model
|
||||||
|
merge_dir = os.path.join(output_path, "_hf_merged")
|
||||||
|
os.makedirs(merge_dir, exist_ok=True)
|
||||||
|
|
||||||
|
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, token=hf_token)
|
||||||
|
model = PeftModel.from_pretrained(base_model, checkpoint_path)
|
||||||
|
merged = model.merge_and_unload()
|
||||||
|
merged.save_pretrained(merge_dir)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
|
||||||
|
tokenizer.save_pretrained(merge_dir)
|
||||||
|
|
||||||
|
# Ensure tokenizer.model (SentencePiece) is present in merge_dir.
|
||||||
|
# Gemma models need this file for GGUF conversion to use the
|
||||||
|
# SentencePiece path; without it, the script falls back to BPE
|
||||||
|
# handling which fails on unrecognized pre-tokenizer hashes.
|
||||||
|
sp_model_path = os.path.join(merge_dir, "tokenizer.model")
|
||||||
|
if not os.path.exists(sp_model_path):
|
||||||
|
sp_copied = False
|
||||||
|
# Method 1: Load the slow tokenizer which keeps the SP model file
|
||||||
|
try:
|
||||||
|
slow_tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=False, token=hf_token)
|
||||||
|
if hasattr(slow_tok, 'vocab_file') and slow_tok.vocab_file and os.path.exists(slow_tok.vocab_file):
|
||||||
|
import shutil as _shutil
|
||||||
|
_shutil.copy2(slow_tok.vocab_file, sp_model_path)
|
||||||
|
sp_copied = True
|
||||||
|
print(f"Copied tokenizer.model from slow tokenizer cache")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Slow tokenizer method failed: {e}")
|
||||||
|
# Method 2: Download from HF hub
|
||||||
|
if not sp_copied:
|
||||||
|
try:
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
cached_sp = hf_hub_download(repo_id=base_model_name, filename="tokenizer.model", token=hf_token)
|
||||||
|
import shutil as _shutil
|
||||||
|
_shutil.copy2(cached_sp, sp_model_path)
|
||||||
|
sp_copied = True
|
||||||
|
print(f"Copied tokenizer.model from HF hub")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"HF hub download method failed: {e}")
|
||||||
|
if not sp_copied:
|
||||||
|
print(f"WARNING: Could not obtain tokenizer.model for {base_model_name}. "
|
||||||
|
"GGUF conversion may fail for SentencePiece models.")
|
||||||
|
|
||||||
|
# Free GPU memory before conversion
|
||||||
|
del merged, model, base_model
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Step 2: Convert to GGUF using convert_hf_to_gguf.py
|
||||||
|
quant = request.quantization_method or "auto"
|
||||||
|
outtype_map = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"}
|
||||||
|
outtype = outtype_map.get(quant, "f16")
|
||||||
|
|
||||||
|
gguf_filename = f"{os.path.basename(output_path)}-{outtype}.gguf"
|
||||||
|
gguf_path = os.path.join(output_path, gguf_filename)
|
||||||
|
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
convert_script = os.path.join(script_dir, "convert_hf_to_gguf.py")
|
||||||
|
if not os.path.exists(convert_script):
|
||||||
|
return backend_pb2.Result(success=False,
|
||||||
|
message="convert_hf_to_gguf.py not found. Install the GGUF conversion tools.")
|
||||||
|
|
||||||
|
# Log merge_dir contents for debugging conversion issues
|
||||||
|
merge_files = os.listdir(merge_dir) if os.path.isdir(merge_dir) else []
|
||||||
|
print(f"Merge dir contents: {merge_files}", flush=True)
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["NO_LOCAL_GGUF"] = "1"
|
||||||
|
cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path]
|
||||||
|
conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env)
|
||||||
|
if conv_result.returncode != 0:
|
||||||
|
diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}"
|
||||||
|
return backend_pb2.Result(success=False,
|
||||||
|
message=f"GGUF conversion failed: {diag}")
|
||||||
|
|
||||||
|
# Clean up intermediate merged model
|
||||||
|
shutil.rmtree(merge_dir, ignore_errors=True)
|
||||||
|
else:
|
||||||
|
return backend_pb2.Result(success=False, message=f"Unsupported export format: {export_format}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if _is_gated_repo_error(e):
|
||||||
|
return backend_pb2.Result(success=False,
|
||||||
|
message=f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
|
||||||
|
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
|
||||||
|
return backend_pb2.Result(success=False, message=f"Export failed: {e}")
|
||||||
|
|
||||||
|
return backend_pb2.Result(success=True, message=f"Model exported to {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def serve(address):
|
||||||
|
server = grpc.server(
|
||||||
|
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||||
|
options=[
|
||||||
|
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||||
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
|
],
|
||||||
|
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
|
server.add_insecure_port(address)
|
||||||
|
server.start()
|
||||||
|
print(f"TRL fine-tuning backend listening on {address}", file=sys.stderr, flush=True)
|
||||||
|
|
||||||
|
# Handle graceful shutdown
|
||||||
|
def stop(signum, frame):
|
||||||
|
server.stop(0)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, stop)
|
||||||
|
signal.signal(signal.SIGINT, stop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
server.stop(0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="TRL fine-tuning gRPC backend")
|
||||||
|
parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
|
||||||
|
args = parser.parse_args()
|
||||||
|
serve(args.addr)
|
||||||
37
backend/python/trl/install.sh
Normal file
37
backend/python/trl/install.sh
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
backend_dir=$(dirname $0)
|
||||||
|
if [ -d $backend_dir/common ]; then
|
||||||
|
source $backend_dir/common/libbackend.sh
|
||||||
|
else
|
||||||
|
source $backend_dir/../common/libbackend.sh
|
||||||
|
fi
|
||||||
|
|
||||||
|
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||||
|
installRequirements
|
||||||
|
|
||||||
|
# Fetch convert_hf_to_gguf.py and gguf package from the same llama.cpp version
|
||||||
|
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
|
||||||
|
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
|
||||||
|
if [ ! -f "${CONVERT_SCRIPT}" ]; then
|
||||||
|
echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||||
|
curl -L --fail --retry 3 \
|
||||||
|
"https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
|
||||||
|
-o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py. GGUF export will not be available."
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install gguf package from the same llama.cpp commit to keep them in sync
|
||||||
|
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
|
||||||
|
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||||
|
if [ "x${USE_PIP:-}" == "xtrue" ]; then
|
||||||
|
pip install "${GGUF_PIP_SPEC}" || {
|
||||||
|
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
|
||||||
|
pip install "gguf>=0.16.0"
|
||||||
|
}
|
||||||
|
else
|
||||||
|
uv pip install "${GGUF_PIP_SPEC}" || {
|
||||||
|
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
|
||||||
|
uv pip install "gguf>=0.16.0"
|
||||||
|
}
|
||||||
|
fi
|
||||||
9
backend/python/trl/requirements-cpu.txt
Normal file
9
backend/python/trl/requirements-cpu.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
torch==2.10.0
|
||||||
|
trl
|
||||||
|
peft
|
||||||
|
datasets>=3.0.0
|
||||||
|
transformers>=4.56.2
|
||||||
|
accelerate>=1.4.0
|
||||||
|
huggingface-hub>=1.3.0
|
||||||
|
sentencepiece
|
||||||
9
backend/python/trl/requirements-cublas12.txt
Normal file
9
backend/python/trl/requirements-cublas12.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
torch==2.10.0
|
||||||
|
trl
|
||||||
|
peft
|
||||||
|
datasets>=3.0.0
|
||||||
|
transformers>=4.56.2
|
||||||
|
accelerate>=1.4.0
|
||||||
|
huggingface-hub>=1.3.0
|
||||||
|
sentencepiece
|
||||||
|
bitsandbytes
|
||||||
9
backend/python/trl/requirements-cublas13.txt
Normal file
9
backend/python/trl/requirements-cublas13.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
torch==2.10.0
|
||||||
|
trl
|
||||||
|
peft
|
||||||
|
datasets>=3.0.0
|
||||||
|
transformers>=4.56.2
|
||||||
|
accelerate>=1.4.0
|
||||||
|
huggingface-hub>=1.3.0
|
||||||
|
sentencepiece
|
||||||
|
bitsandbytes
|
||||||
3
backend/python/trl/requirements.txt
Normal file
3
backend/python/trl/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
grpcio==1.78.1
|
||||||
|
protobuf
|
||||||
|
certifi
|
||||||
236
backend/python/trl/reward_functions.py
Normal file
236
backend/python/trl/reward_functions.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
"""
|
||||||
|
Built-in reward functions and inline function compiler for GRPO training.
|
||||||
|
|
||||||
|
All reward functions follow TRL's signature: (completions, **kwargs) -> list[float]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import math
|
||||||
|
import string
|
||||||
|
import functools
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Built-in reward functions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def format_reward(completions, **kwargs):
|
||||||
|
"""Checks for <think>...</think> followed by an answer. Returns 1.0 or 0.0."""
|
||||||
|
pattern = re.compile(r"<think>.*?</think>\s*\S", re.DOTALL)
|
||||||
|
return [1.0 if pattern.search(c) else 0.0 for c in completions]
|
||||||
|
|
||||||
|
|
||||||
|
def reasoning_accuracy_reward(completions, **kwargs):
|
||||||
|
"""Extracts <answer>...</answer> content and compares to the expected answer."""
|
||||||
|
answers = kwargs.get("answer", [])
|
||||||
|
if not answers:
|
||||||
|
return [0.0] * len(completions)
|
||||||
|
|
||||||
|
pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
|
||||||
|
scores = []
|
||||||
|
for i, c in enumerate(completions):
|
||||||
|
expected = answers[i] if i < len(answers) else ""
|
||||||
|
match = pattern.search(c)
|
||||||
|
if match:
|
||||||
|
extracted = match.group(1).strip()
|
||||||
|
scores.append(1.0 if extracted.lower() == str(expected).strip().lower() else 0.0)
|
||||||
|
else:
|
||||||
|
scores.append(0.0)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def length_reward(completions, target_length=200, **kwargs):
|
||||||
|
"""Score based on proximity to target_length. Returns [0, 1]."""
|
||||||
|
scores = []
|
||||||
|
for c in completions:
|
||||||
|
length = len(c)
|
||||||
|
if target_length <= 0:
|
||||||
|
scores.append(0.0)
|
||||||
|
else:
|
||||||
|
diff = abs(length - target_length) / target_length
|
||||||
|
scores.append(max(0.0, 1.0 - diff))
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def xml_tag_reward(completions, **kwargs):
|
||||||
|
"""Scores properly opened/closed XML tags (<think>, <answer>)."""
|
||||||
|
tags = ["think", "answer"]
|
||||||
|
scores = []
|
||||||
|
for c in completions:
|
||||||
|
tag_score = 0.0
|
||||||
|
for tag in tags:
|
||||||
|
if f"<{tag}>" in c and f"</{tag}>" in c:
|
||||||
|
tag_score += 0.5
|
||||||
|
scores.append(min(tag_score, 1.0))
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def no_repetition_reward(completions, n=4, **kwargs):
|
||||||
|
"""Penalizes n-gram repetition. Returns [0, 1]."""
|
||||||
|
scores = []
|
||||||
|
for c in completions:
|
||||||
|
words = c.split()
|
||||||
|
if len(words) < n:
|
||||||
|
scores.append(1.0)
|
||||||
|
continue
|
||||||
|
ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)]
|
||||||
|
unique = len(set(ngrams))
|
||||||
|
total = len(ngrams)
|
||||||
|
scores.append(unique / total if total > 0 else 1.0)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def code_execution_reward(completions, **kwargs):
|
||||||
|
"""Checks Python code block syntax validity via compile(). Returns 1.0 or 0.0."""
|
||||||
|
pattern = re.compile(r"```python\s*\n(.*?)```", re.DOTALL)
|
||||||
|
scores = []
|
||||||
|
for c in completions:
|
||||||
|
match = pattern.search(c)
|
||||||
|
if not match:
|
||||||
|
scores.append(0.0)
|
||||||
|
continue
|
||||||
|
code = match.group(1)
|
||||||
|
try:
|
||||||
|
compile(code, "<inline>", "exec")
|
||||||
|
scores.append(1.0)
|
||||||
|
except SyntaxError:
|
||||||
|
scores.append(0.0)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Registry
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
BUILTIN_REGISTRY = {
|
||||||
|
"format_reward": format_reward,
|
||||||
|
"reasoning_accuracy_reward": reasoning_accuracy_reward,
|
||||||
|
"length_reward": length_reward,
|
||||||
|
"xml_tag_reward": xml_tag_reward,
|
||||||
|
"no_repetition_reward": no_repetition_reward,
|
||||||
|
"code_execution_reward": code_execution_reward,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Inline function compiler
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_SAFE_BUILTINS = {
|
||||||
|
"len": len, "int": int, "float": float, "str": str, "bool": bool,
|
||||||
|
"list": list, "dict": dict, "tuple": tuple, "set": set,
|
||||||
|
"range": range, "enumerate": enumerate, "zip": zip,
|
||||||
|
"map": map, "filter": filter, "sorted": sorted,
|
||||||
|
"min": min, "max": max, "sum": sum, "abs": abs, "round": round,
|
||||||
|
"any": any, "all": all, "isinstance": isinstance, "type": type,
|
||||||
|
"print": print, "True": True, "False": False, "None": None,
|
||||||
|
"ValueError": ValueError, "TypeError": TypeError,
|
||||||
|
"KeyError": KeyError, "IndexError": IndexError,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def compile_inline_reward(name, code):
|
||||||
|
"""Compile user-provided code into a reward function.
|
||||||
|
|
||||||
|
The code should be the body of a function that receives
|
||||||
|
`completions` (list[str]) and `**kwargs`, and returns list[float].
|
||||||
|
|
||||||
|
Available modules: re, math, json, string.
|
||||||
|
"""
|
||||||
|
func_source = (
|
||||||
|
f"def _user_reward_{name}(completions, **kwargs):\n"
|
||||||
|
+ "\n".join(f" {line}" for line in code.splitlines())
|
||||||
|
)
|
||||||
|
|
||||||
|
restricted_globals = {
|
||||||
|
"__builtins__": _SAFE_BUILTINS,
|
||||||
|
"re": re,
|
||||||
|
"math": math,
|
||||||
|
"json": json,
|
||||||
|
"string": string,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
compiled = compile(func_source, f"<inline-reward-{name}>", "exec")
|
||||||
|
except SyntaxError as e:
|
||||||
|
raise ValueError(f"Syntax error in inline reward function '{name}': {e}")
|
||||||
|
|
||||||
|
exec(compiled, restricted_globals)
|
||||||
|
func = restricted_globals[f"_user_reward_{name}"]
|
||||||
|
|
||||||
|
# Validate with a quick smoke test
|
||||||
|
try:
|
||||||
|
result = func(["test"], answer=["test"])
|
||||||
|
if not isinstance(result, list):
|
||||||
|
raise ValueError(
|
||||||
|
f"Inline reward function '{name}' must return a list, got {type(result).__name__}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
if "must return a list" in str(e):
|
||||||
|
raise
|
||||||
|
# Other errors during smoke test are acceptable (e.g. missing kwargs)
|
||||||
|
pass
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dispatcher
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def build_reward_functions(specs_json):
|
||||||
|
"""Parse a JSON list of reward function specs and return a list of callables.
|
||||||
|
|
||||||
|
Each spec is a dict with:
|
||||||
|
- type: "builtin" or "inline"
|
||||||
|
- name: function name
|
||||||
|
- code: (inline only) Python function body
|
||||||
|
- params: (optional) dict of string params applied via functools.partial
|
||||||
|
"""
|
||||||
|
if isinstance(specs_json, str):
|
||||||
|
specs = json.loads(specs_json)
|
||||||
|
else:
|
||||||
|
specs = specs_json
|
||||||
|
|
||||||
|
if not isinstance(specs, list):
|
||||||
|
raise ValueError("reward_funcs must be a JSON array of reward function specs")
|
||||||
|
|
||||||
|
reward_funcs = []
|
||||||
|
for spec in specs:
|
||||||
|
spec_type = spec.get("type", "builtin")
|
||||||
|
name = spec.get("name", "")
|
||||||
|
params = spec.get("params", {})
|
||||||
|
|
||||||
|
if spec_type == "builtin":
|
||||||
|
if name not in BUILTIN_REGISTRY:
|
||||||
|
available = ", ".join(sorted(BUILTIN_REGISTRY.keys()))
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown builtin reward function '{name}'. Available: {available}"
|
||||||
|
)
|
||||||
|
func = BUILTIN_REGISTRY[name]
|
||||||
|
if params:
|
||||||
|
# Convert string params to appropriate types
|
||||||
|
typed_params = {}
|
||||||
|
for k, v in params.items():
|
||||||
|
try:
|
||||||
|
typed_params[k] = int(v)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
try:
|
||||||
|
typed_params[k] = float(v)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
typed_params[k] = v
|
||||||
|
func = functools.partial(func, **typed_params)
|
||||||
|
reward_funcs.append(func)
|
||||||
|
|
||||||
|
elif spec_type == "inline":
|
||||||
|
code = spec.get("code", "")
|
||||||
|
if not code.strip():
|
||||||
|
raise ValueError(f"Inline reward function '{name}' has no code")
|
||||||
|
func = compile_inline_reward(name, code)
|
||||||
|
reward_funcs.append(func)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown reward function type '{spec_type}'. Use 'builtin' or 'inline'")
|
||||||
|
|
||||||
|
return reward_funcs
|
||||||
10
backend/python/trl/run.sh
Normal file
10
backend/python/trl/run.sh
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
backend_dir=$(dirname $0)
|
||||||
|
if [ -d $backend_dir/common ]; then
|
||||||
|
source $backend_dir/common/libbackend.sh
|
||||||
|
else
|
||||||
|
source $backend_dir/../common/libbackend.sh
|
||||||
|
fi
|
||||||
|
|
||||||
|
startBackend $@
|
||||||
58
backend/python/trl/test.py
Normal file
58
backend/python/trl/test.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""
|
||||||
|
Test script for the TRL fine-tuning gRPC backend.
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
import backend_pb2
|
||||||
|
import backend_pb2_grpc
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackendServicer(unittest.TestCase):
|
||||||
|
"""Tests for the TRL fine-tuning gRPC service."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.service = subprocess.Popen(
|
||||||
|
["python3", "backend.py", "--addr", "localhost:50051"]
|
||||||
|
)
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.service.kill()
|
||||||
|
self.service.wait()
|
||||||
|
|
||||||
|
def test_server_startup(self):
|
||||||
|
"""Test that the server starts and responds to health checks."""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.Health(backend_pb2.HealthMessage())
|
||||||
|
self.assertEqual(response.message, b'OK')
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("Server failed to start")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
||||||
|
|
||||||
|
def test_list_checkpoints_empty(self):
|
||||||
|
"""Test listing checkpoints on a non-existent directory."""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.ListCheckpoints(
|
||||||
|
backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent")
|
||||||
|
)
|
||||||
|
self.assertEqual(len(response.checkpoints), 0)
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("ListCheckpoints service failed")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
11
backend/python/trl/test.sh
Normal file
11
backend/python/trl/test.sh
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
backend_dir=$(dirname $0)
|
||||||
|
if [ -d $backend_dir/common ]; then
|
||||||
|
source $backend_dir/common/libbackend.sh
|
||||||
|
else
|
||||||
|
source $backend_dir/../common/libbackend.sh
|
||||||
|
fi
|
||||||
|
|
||||||
|
runUnittests
|
||||||
@@ -20,6 +20,10 @@ from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalG
|
|||||||
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
|
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
"""Check if a string can be converted to float."""
|
"""Check if a string can be converted to float."""
|
||||||
@@ -724,7 +728,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -27,6 +27,10 @@ import backend_pb2
|
|||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
from vllm_omni.entrypoints.omni import Omni
|
from vllm_omni.entrypoints.omni import Omni
|
||||||
from vllm_omni.outputs import OmniRequestOutput
|
from vllm_omni.outputs import OmniRequestOutput
|
||||||
@@ -650,7 +654,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ import backend_pb2
|
|||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
@@ -338,7 +342,9 @@ async def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(aio=True),
|
||||||
|
)
|
||||||
# Add the servicer to the server
|
# Add the servicer to the server
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
# Bind the server to the address
|
# Bind the server to the address
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
grpcio==1.78.1
|
grpcio==1.80.0
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
setuptools
|
setuptools
|
||||||
@@ -18,6 +18,10 @@ import backend_pb2_grpc
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||||
|
from grpc_auth import get_auth_interceptors
|
||||||
|
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
"""Check if a string can be converted to float."""
|
"""Check if a string can be converted to float."""
|
||||||
@@ -297,7 +301,9 @@ def serve(address):
|
|||||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
])
|
],
|
||||||
|
interceptors=get_auth_interceptors(),
|
||||||
|
)
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -8,6 +8,15 @@ else
|
|||||||
source $backend_dir/../common/libbackend.sh
|
source $backend_dir/../common/libbackend.sh
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# The PyTorch CPU/CUDA indexes mirror common packages (e.g. requests) with
|
||||||
|
# limited, often outdated version sets. uv's default "first-index" strategy
|
||||||
|
# locks to the first index that carries a package, so it can pick e.g.
|
||||||
|
# requests==2.28.1 from the PyTorch index instead of a newer version from
|
||||||
|
# PyPI. voxcpm's transitive deps (datasets>=3 → requests>=2.32.2) need the
|
||||||
|
# PyPI versions. "unsafe-best-match" is safe here because we control both
|
||||||
|
# indexes and there is no dependency confusion risk.
|
||||||
|
export UV_INDEX_STRATEGY=unsafe-best-match
|
||||||
|
|
||||||
installRequirements
|
installRequirements
|
||||||
|
|
||||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
torch
|
torch
|
||||||
|
torchaudio
|
||||||
soundfile
|
soundfile
|
||||||
numpy
|
numpy
|
||||||
voxcpm
|
voxcpm>=1.5.0
|
||||||
torchcodec
|
torchcodec
|
||||||
@@ -5,4 +5,3 @@ certifi
|
|||||||
packaging==24.1
|
packaging==24.1
|
||||||
soundfile
|
soundfile
|
||||||
numpy
|
numpy
|
||||||
voxcpm
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user