mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-19 14:17:21 -04:00
Compare commits
109 Commits
feat/vllm-
...
update/TUR
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9eb21e9a20 | ||
|
|
85ff7a310f | ||
|
|
607efe5a4c | ||
|
|
7d8c1d5e45 | ||
|
|
d18d434bb2 | ||
|
|
39573ecd2a | ||
|
|
a7dbb2a83d | ||
|
|
3ad9b16c29 | ||
|
|
c806d5ab73 | ||
|
|
47efaf5b43 | ||
|
|
315b634a91 | ||
|
|
6b245299d7 | ||
|
|
677c0315c1 | ||
|
|
478522ce4d | ||
|
|
c54897ad44 | ||
|
|
8bb1e8f21f | ||
|
|
cd94a0b61a | ||
|
|
047bc48fa9 | ||
|
|
01bd8ae5d0 | ||
|
|
d9808769be | ||
|
|
5973c0a9df | ||
|
|
486b5e25a3 | ||
|
|
c66c41e8d7 | ||
|
|
02bb715c0a | ||
|
|
8ab56e2ad3 | ||
|
|
ecf85fde9e | ||
|
|
6480715a16 | ||
|
|
f683231811 | ||
|
|
960757f0e8 | ||
|
|
865fd552f5 | ||
|
|
cb77a5a4b9 | ||
|
|
60633c4dd5 | ||
|
|
9e44944cc1 | ||
|
|
372eb08dcf | ||
|
|
28091d626e | ||
|
|
cae79d9107 | ||
|
|
babbbc6ec8 | ||
|
|
3804497186 | ||
|
|
fda1c553a1 | ||
|
|
b27de08fff | ||
|
|
510f791ccc | ||
|
|
369c50a41c | ||
|
|
75a63f87d8 | ||
|
|
9cd8d7951f | ||
|
|
884bfb84c9 | ||
|
|
e94a9a8f10 | ||
|
|
054c4b4b45 | ||
|
|
6e49dba27c | ||
|
|
e463820566 | ||
|
|
8839a71c87 | ||
|
|
117f6430b8 | ||
|
|
7809c5f5d0 | ||
|
|
ad742738cb | ||
|
|
86c673fd94 | ||
|
|
c49feb546f | ||
|
|
844b0b760b | ||
|
|
55c05211d3 | ||
|
|
a90a8cf1d0 | ||
|
|
12b069f9bd | ||
|
|
48e87db400 | ||
|
|
7dbd9c056a | ||
|
|
7c5d6162f7 | ||
|
|
5837b14888 | ||
|
|
b6a68e5df4 | ||
|
|
c6dfb4acaf | ||
|
|
ec5935421c | ||
|
|
a0cbc46be9 | ||
|
|
b4e30692a2 | ||
|
|
61d34ccb11 | ||
|
|
7f88a3ba30 | ||
|
|
c4f309388e | ||
|
|
ab326a9c61 | ||
|
|
df2d25cee5 | ||
|
|
96cd561d9d | ||
|
|
08445b1b89 | ||
|
|
ad3c8c4832 | ||
|
|
6f0051301b | ||
|
|
8487058673 | ||
|
|
62862ca06b | ||
|
|
07e244d869 | ||
|
|
95efb8a562 | ||
|
|
410d100cc3 | ||
|
|
833b7e8557 | ||
|
|
87e6de1989 | ||
|
|
b361d2ddd6 | ||
|
|
1e4c4577bb | ||
|
|
98fd9d5cc6 | ||
|
|
0c725f5702 | ||
|
|
7661a4ffa5 | ||
|
|
24ad6e4be1 | ||
|
|
c0648b8836 | ||
|
|
a05c7def59 | ||
|
|
906acba8db | ||
|
|
4226ca4aee | ||
|
|
c6d5dc3374 | ||
|
|
7ce675af21 | ||
|
|
be1b8d56c9 | ||
|
|
97f087ed31 | ||
|
|
8691bbe663 | ||
|
|
7998f96f11 | ||
|
|
cada97ee46 | ||
|
|
3375ea1a2c | ||
|
|
0e7c0adee4 | ||
|
|
016da02845 | ||
|
|
daa0272f2e | ||
|
|
d67623230f | ||
|
|
0f90d17aac | ||
|
|
ea32b8953f | ||
|
|
bc7578bdb1 |
@@ -129,6 +129,30 @@ After adding a new backend, verify:
|
||||
- [ ] No Makefile syntax errors (check with linter)
|
||||
- [ ] Follows the same pattern as similar backends (e.g., if it's a transcription backend, follow `faster-whisper` pattern)
|
||||
|
||||
## Bundling runtime shared libraries (`package.sh`)
|
||||
|
||||
The final `Dockerfile.python` stage is `FROM scratch` — there is no system `libc`, no `apt`, no fallback library path. Only files explicitly copied from the builder stage end up in the backend image. That means any runtime `dlopen` your backend (or its Python deps) needs **must** be packaged into `${BACKEND}/lib/`.
|
||||
|
||||
Pattern:
|
||||
|
||||
1. Make sure the library is installed in the builder stage of `backend/Dockerfile.python` (add it to the top-level `apt-get install`).
|
||||
2. Drop a `package.sh` in your backend directory that copies the library — and its soname symlinks — into `$(dirname $0)/lib`. See `backend/python/vllm/package.sh` for a reference implementation that walks `/usr/lib/x86_64-linux-gnu`, `/usr/lib/aarch64-linux-gnu`, etc.
|
||||
3. `Dockerfile.python` already runs `package.sh` automatically if it exists, after `package-gpu-libs.sh`.
|
||||
4. `libbackend.sh` automatically prepends `${EDIR}/lib` to `LD_LIBRARY_PATH` at run time, so anything packaged this way is found by `dlopen`.
|
||||
|
||||
How to find missing libs: when a Python module silently fails to register torch ops or you see `AttributeError: '_OpNamespace' '...' object has no attribute '...'`, run the backend image's Python with `LD_DEBUG=libs` to see which `dlopen` failed. The filename in the error message (e.g. `libnuma.so.1`) is what you need to package.
|
||||
|
||||
To verify packaging works without trusting the host:
|
||||
|
||||
```bash
|
||||
make docker-build-<backend>
|
||||
CID=$(docker create --entrypoint=/run.sh local-ai-backend:<backend>)
|
||||
docker cp $CID:/lib /tmp/check && docker rm $CID
|
||||
ls /tmp/check # expect the bundled .so files + symlinks
|
||||
```
|
||||
|
||||
Then boot it inside a fresh `ubuntu:24.04` (which intentionally does *not* have the lib installed) to confirm it actually loads from the backend dir.
|
||||
|
||||
## 6. Example: Adding a Python Backend
|
||||
|
||||
For reference, when `moonshine` was added:
|
||||
|
||||
101
.agents/ai-coding-assistants.md
Normal file
101
.agents/ai-coding-assistants.md
Normal file
@@ -0,0 +1,101 @@
|
||||
# AI Coding Assistants
|
||||
|
||||
This document provides guidance for AI tools and developers using AI
|
||||
assistance when contributing to LocalAI.
|
||||
|
||||
**LocalAI follows the same guidelines as the Linux kernel project for
|
||||
AI-assisted contributions.** See the upstream policy here:
|
||||
<https://docs.kernel.org/process/coding-assistants.html>
|
||||
|
||||
The rules below mirror that policy, adapted to LocalAI's license and
|
||||
project layout. If anything is unclear, the kernel document is the
|
||||
authoritative reference for intent.
|
||||
|
||||
AI tools helping with LocalAI development should follow the standard
|
||||
project development process:
|
||||
|
||||
- [CONTRIBUTING.md](../CONTRIBUTING.md) — development workflow, commit
|
||||
conventions, and PR guidelines
|
||||
- [.agents/coding-style.md](coding-style.md) — code style, editorconfig,
|
||||
logging, and documentation conventions
|
||||
- [.agents/building-and-testing.md](building-and-testing.md) — build and
|
||||
test procedures
|
||||
|
||||
## Licensing and Legal Requirements
|
||||
|
||||
All contributions must comply with LocalAI's licensing requirements:
|
||||
|
||||
- LocalAI is licensed under the **MIT License** — see the [LICENSE](../LICENSE)
|
||||
file
|
||||
- New source files should use the SPDX license identifier `MIT` where
|
||||
applicable to the file type
|
||||
- Contributions must be compatible with the MIT License and must not
|
||||
introduce code under incompatible licenses (e.g., GPL) without an
|
||||
explicit discussion with maintainers
|
||||
|
||||
## Signed-off-by and Developer Certificate of Origin
|
||||
|
||||
**AI agents MUST NOT add `Signed-off-by` tags.** Only humans can legally
|
||||
certify the Developer Certificate of Origin (DCO). The human submitter
|
||||
is responsible for:
|
||||
|
||||
- Reviewing all AI-generated code
|
||||
- Ensuring compliance with licensing requirements
|
||||
- Adding their own `Signed-off-by` tag (when the project requires DCO)
|
||||
to certify the contribution
|
||||
- Taking full responsibility for the contribution
|
||||
|
||||
AI agents MUST NOT add `Co-Authored-By` trailers for themselves either.
|
||||
A human reviewer owns the contribution; the AI's involvement is recorded
|
||||
via `Assisted-by` (see below).
|
||||
|
||||
## Attribution
|
||||
|
||||
When AI tools contribute to LocalAI development, proper attribution helps
|
||||
track the evolving role of AI in the development process. Contributions
|
||||
should include an `Assisted-by` tag in the commit message trailer in the
|
||||
following format:
|
||||
|
||||
```
|
||||
Assisted-by: AGENT_NAME:MODEL_VERSION [TOOL1] [TOOL2]
|
||||
```
|
||||
|
||||
Where:
|
||||
|
||||
- `AGENT_NAME` — name of the AI tool or framework (e.g., `Claude`,
|
||||
`Copilot`, `Cursor`)
|
||||
- `MODEL_VERSION` — specific model version used (e.g.,
|
||||
`claude-opus-4-7`, `gpt-5`)
|
||||
- `[TOOL1] [TOOL2]` — optional specialized analysis tools invoked by the
|
||||
agent (e.g., `golangci-lint`, `staticcheck`, `go vet`)
|
||||
|
||||
Basic development tools (git, go, make, editors) should **not** be listed.
|
||||
|
||||
### Example
|
||||
|
||||
```
|
||||
fix(llama-cpp): handle empty tool call arguments
|
||||
|
||||
Previously the parser panicked when the model returned a tool call with
|
||||
an empty arguments object. Fall back to an empty JSON object in that
|
||||
case so downstream consumers receive a valid payload.
|
||||
|
||||
Assisted-by: Claude:claude-opus-4-7 golangci-lint
|
||||
Signed-off-by: Jane Developer <jane@example.com>
|
||||
```
|
||||
|
||||
## Scope and Responsibility
|
||||
|
||||
Using an AI assistant does not reduce the contributor's responsibility.
|
||||
The human submitter must:
|
||||
|
||||
- Understand every line that lands in the PR
|
||||
- Verify that generated code compiles, passes tests, and follows the
|
||||
project style
|
||||
- Confirm that any referenced APIs, flags, or file paths actually exist
|
||||
in the current tree (AI models may hallucinate identifiers)
|
||||
- Not submit AI output verbatim without review
|
||||
|
||||
Reviewers may ask for clarification on any change regardless of how it
|
||||
was produced. "An AI wrote it" is not an acceptable answer to a design
|
||||
question.
|
||||
115
.agents/vllm-backend.md
Normal file
115
.agents/vllm-backend.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# Working on the vLLM Backend
|
||||
|
||||
The vLLM backend lives at `backend/python/vllm/backend.py` (async gRPC) and the multimodal variant at `backend/python/vllm-omni/backend.py` (sync gRPC). Both wrap vLLM's `AsyncLLMEngine` / `Omni` and translate the LocalAI gRPC `PredictOptions` into vLLM `SamplingParams` + outputs into `Reply.chat_deltas`.
|
||||
|
||||
This file captures the non-obvious bits — most of the bring-up was a single PR (`feat/vllm-parity`) and the things below are easy to get wrong.
|
||||
|
||||
## Tool calling and reasoning use vLLM's *native* parsers
|
||||
|
||||
Do not write regex-based tool-call extractors for vLLM. vLLM ships:
|
||||
|
||||
- `vllm.tool_parsers.ToolParserManager` — 50+ registered parsers (`hermes`, `llama3_json`, `llama4_pythonic`, `mistral`, `qwen3_xml`, `deepseek_v3`, `granite4`, `openai`, `kimi_k2`, `glm45`, …)
|
||||
- `vllm.reasoning.ReasoningParserManager` — 25+ registered parsers (`deepseek_r1`, `qwen3`, `mistral`, `gemma4`, …)
|
||||
|
||||
Both can be used standalone: instantiate with a tokenizer, call `extract_tool_calls(text, request=None)` / `extract_reasoning(text, request=None)`. The backend stores the parser *classes* on `self.tool_parser_cls` / `self.reasoning_parser_cls` at LoadModel time and instantiates them per request.
|
||||
|
||||
**Selection:** vLLM does *not* auto-detect parsers from model name — neither does the LocalAI backend. The user (or `core/config/hooks_vllm.go`) must pick one and pass it via `Options[]`:
|
||||
|
||||
```yaml
|
||||
options:
|
||||
- tool_parser:hermes
|
||||
- reasoning_parser:qwen3
|
||||
```
|
||||
|
||||
Auto-defaults for known model families live in `core/config/parser_defaults.json` and are applied:
|
||||
- at gallery import time by `core/gallery/importers/vllm.go`
|
||||
- at model load time by the `vllm` / `vllm-omni` backend hook in `core/config/hooks_vllm.go`
|
||||
|
||||
User-supplied `tool_parser:`/`reasoning_parser:` in the config wins over defaults — the hook checks for existing entries before appending.
|
||||
|
||||
**When to update `parser_defaults.json`:** any time vLLM ships a new tool or reasoning parser, or you onboard a new model family that LocalAI users will pull from HuggingFace. The file is keyed by *family pattern* matched against `normalizeModelID(cfg.Model)` (lowercase, org-prefix stripped, `_`→`-`). Patterns are checked **longest-first** — keep `qwen3.5` before `qwen3`, `llama-3.3` before `llama-3`, etc., or the wrong family wins. Add a covering test in `core/config/hooks_test.go`.
|
||||
|
||||
**Sister file — `core/config/inference_defaults.json`:** same pattern but for sampling parameters (temperature, top_p, top_k, min_p, repeat_penalty, presence_penalty). Loaded by `core/config/inference_defaults.go` and applied by `ApplyInferenceDefaults()`. The schema is `map[string]float64` only — *strings don't fit*, which is why parser defaults needed their own JSON file. The inference file is **auto-generated from unsloth** via `go generate ./core/config/` (see `core/config/gen_inference_defaults/`) — don't hand-edit it; instead update the upstream source or regenerate. Both files share `normalizeModelID()` and the longest-first pattern ordering.
|
||||
|
||||
**Constructor compatibility gotcha:** the abstract `ToolParser.__init__` accepts `tools=`, but several concrete parsers (Hermes2ProToolParser, etc.) override `__init__` and *only* accept `tokenizer`. Always:
|
||||
|
||||
```python
|
||||
try:
|
||||
tp = self.tool_parser_cls(self.tokenizer, tools=tools)
|
||||
except TypeError:
|
||||
tp = self.tool_parser_cls(self.tokenizer)
|
||||
```
|
||||
|
||||
## ChatDelta is the streaming contract
|
||||
|
||||
The Go side (`core/backend/llm.go`, `pkg/functions/chat_deltas.go`) consumes `Reply.chat_deltas` to assemble the OpenAI response. For tool calls to surface in `chat/completions`, the Python backend **must** populate `Reply.chat_deltas[].tool_calls` with `ToolCallDelta{index, id, name, arguments}`. Returning the raw `<tool_call>...</tool_call>` text in `Reply.message` is *not* enough — the Go regex fallback exists for llama.cpp, not for vllm.
|
||||
|
||||
Same story for `reasoning_content` — emit it on `ChatDelta.reasoning_content`, not as part of `content`.
|
||||
|
||||
## Message conversion to chat templates
|
||||
|
||||
`tokenizer.apply_chat_template()` expects a list of dicts, not proto Messages. The shared helper in `backend/python/common/vllm_utils.py` (`messages_to_dicts`) handles the mapping including:
|
||||
|
||||
- `tool_call_id` and `name` for `role="tool"` messages
|
||||
- `tool_calls` JSON-string field → parsed Python list for `role="assistant"`
|
||||
- `reasoning_content` for thinking models
|
||||
|
||||
Pass `tools=json.loads(request.Tools)` and (when `request.Metadata.get("enable_thinking") == "true"`) `enable_thinking=True` to `apply_chat_template`. Wrap in `try/except TypeError` because not every tokenizer template accepts those kwargs.
|
||||
|
||||
## CPU support and the SIMD/library minefield
|
||||
|
||||
vLLM publishes prebuilt CPU wheels at `https://github.com/vllm-project/vllm/releases/...`. The pin lives in `backend/python/vllm/requirements-cpu-after.txt`.
|
||||
|
||||
**Version compatibility — important:** newer vllm CPU wheels (≥ 0.15) declare `torch==2.10.0+cpu` as a hard dep, but `torch==2.10.0` only exists on the PyTorch test channel and pulls in an incompatible `torchvision`. Stay on **`vllm 0.14.1+cpu` + `torch 2.9.1+cpu`** until both upstream catch up. Bumping requires verifying torchvision/torchaudio match.
|
||||
|
||||
`requirements-cpu.txt` uses `--extra-index-url https://download.pytorch.org/whl/cpu`. `install.sh` adds `--index-strategy=unsafe-best-match` for the `cpu` profile so uv resolves transformers/vllm from PyPI while pulling torch from the PyTorch index.
|
||||
|
||||
**SIMD baseline:** the prebuilt CPU wheel is compiled with AVX-512 VNNI/BF16. On a CPU without those instructions, importing `vllm.model_executor.models.registry` SIGILLs at `_run_in_subprocess` time during model inspection. There is no runtime flag to disable it. Workarounds:
|
||||
|
||||
1. **Run on a host with the right SIMD baseline** (default — fast)
|
||||
2. **Build from source** with `FROM_SOURCE=true` env var. Plumbing exists end-to-end:
|
||||
- `install.sh` hides `requirements-cpu-after.txt`, runs `installRequirements` for the base deps, then clones vllm and `VLLM_TARGET_DEVICE=cpu uv pip install --no-deps .`
|
||||
- `backend/Dockerfile.python` declares `ARG FROM_SOURCE` + `ENV FROM_SOURCE`
|
||||
- `Makefile` `docker-build-backend` macro forwards `--build-arg FROM_SOURCE=$(FROM_SOURCE)` when set
|
||||
- Source build takes 30–50 minutes — too slow for per-PR CI but fine for local.
|
||||
|
||||
**Runtime shared libraries:** vLLM's `vllm._C` extension `dlopen`s `libnuma.so.1` at import time. If missing, the C extension silently fails and `torch.ops._C_utils.init_cpu_threads_env` is never registered → `EngineCore` crashes on `init_device` with:
|
||||
|
||||
```
|
||||
AttributeError: '_OpNamespace' '_C_utils' object has no attribute 'init_cpu_threads_env'
|
||||
```
|
||||
|
||||
`backend/python/vllm/package.sh` bundles `libnuma.so.1` and `libgomp.so.1` into `${BACKEND}/lib/`, which `libbackend.sh` adds to `LD_LIBRARY_PATH` at run time. The builder stage in `backend/Dockerfile.python` installs `libnuma1`/`libgomp1` so package.sh has something to copy. Do *not* assume the production host has these — backend images are `FROM scratch`.
|
||||
|
||||
## Backend hook system (`core/config/backend_hooks.go`)
|
||||
|
||||
Per-backend defaults that used to be hardcoded in `ModelConfig.Prepare()` now live in `core/config/hooks_*.go` files and self-register via `init()`:
|
||||
|
||||
- `hooks_llamacpp.go` → GGUF metadata parsing, context size, GPU layers, jinja template
|
||||
- `hooks_vllm.go` → tool/reasoning parser auto-selection from `parser_defaults.json`
|
||||
|
||||
Hook keys:
|
||||
- `"llama-cpp"`, `"vllm"`, `"vllm-omni"`, … — backend-specific
|
||||
- `""` — runs only when `cfg.Backend` is empty (auto-detect case)
|
||||
- `"*"` — global catch-all, runs for every backend before specific hooks
|
||||
|
||||
Multiple hooks per key are supported and run in registration order. Adding a new backend default:
|
||||
|
||||
```go
|
||||
// core/config/hooks_<backend>.go
|
||||
func init() {
|
||||
RegisterBackendHook("<backend>", myDefaults)
|
||||
}
|
||||
func myDefaults(cfg *ModelConfig, modelPath string) {
|
||||
// only fill in fields the user didn't set
|
||||
}
|
||||
```
|
||||
|
||||
## The `Messages.ToProto()` fields you need to set
|
||||
|
||||
`core/schema/message.go:ToProto()` must serialize:
|
||||
- `ToolCallID` → `proto.Message.ToolCallId` (for `role="tool"` messages — links result back to the call)
|
||||
- `Reasoning` → `proto.Message.ReasoningContent`
|
||||
- `ToolCalls` → `proto.Message.ToolCalls` (JSON-encoded string)
|
||||
|
||||
These were originally not serialized and tool-calling conversations broke silently — the C++ llama.cpp backend reads them but always got empty strings. Any new field added to `schema.Message` *and* `proto.Message` needs a matching line in `ToProto()`.
|
||||
446
.github/gallery-agent/agent.go
vendored
446
.github/gallery-agent/agent.go
vendored
@@ -1,446 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ghodss/yaml"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
"github.com/mudler/cogito"
|
||||
"github.com/mudler/cogito/clients"
|
||||
"github.com/mudler/cogito/structures"
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
var (
|
||||
openAIModel = os.Getenv("OPENAI_MODEL")
|
||||
openAIKey = os.Getenv("OPENAI_KEY")
|
||||
openAIBaseURL = os.Getenv("OPENAI_BASE_URL")
|
||||
galleryIndexPath = os.Getenv("GALLERY_INDEX_PATH")
|
||||
//defaultclient
|
||||
llm = clients.NewOpenAILLM(openAIModel, openAIKey, openAIBaseURL)
|
||||
)
|
||||
|
||||
// cleanTextContent removes trailing spaces, tabs, and normalizes line endings
|
||||
// to prevent YAML linting issues like trailing spaces and multiple empty lines
|
||||
func cleanTextContent(text string) string {
|
||||
lines := strings.Split(text, "\n")
|
||||
var cleanedLines []string
|
||||
var prevEmpty bool
|
||||
for _, line := range lines {
|
||||
// Remove all trailing whitespace (spaces, tabs, etc.)
|
||||
trimmed := strings.TrimRight(line, " \t\r")
|
||||
// Avoid multiple consecutive empty lines
|
||||
if trimmed == "" {
|
||||
if !prevEmpty {
|
||||
cleanedLines = append(cleanedLines, "")
|
||||
}
|
||||
prevEmpty = true
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, trimmed)
|
||||
prevEmpty = false
|
||||
}
|
||||
}
|
||||
// Remove trailing empty lines from the result
|
||||
result := strings.Join(cleanedLines, "\n")
|
||||
return stripThinkingTags(strings.TrimRight(result, "\n"))
|
||||
}
|
||||
|
||||
type galleryModel struct {
|
||||
Name string `yaml:"name"`
|
||||
Urls []string `yaml:"urls"`
|
||||
}
|
||||
|
||||
// isModelExisting checks if a specific model ID exists in the gallery using text search
|
||||
func isModelExisting(modelID string) (bool, error) {
|
||||
indexPath := getGalleryIndexPath()
|
||||
content, err := os.ReadFile(indexPath)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to read %s: %w", indexPath, err)
|
||||
}
|
||||
|
||||
var galleryModels []galleryModel
|
||||
|
||||
err = yaml.Unmarshal(content, &galleryModels)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to unmarshal %s: %w", indexPath, err)
|
||||
}
|
||||
|
||||
for _, galleryModel := range galleryModels {
|
||||
if slices.Contains(galleryModel.Urls, modelID) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// filterExistingModels removes models that already exist in the gallery
|
||||
func filterExistingModels(models []ProcessedModel) ([]ProcessedModel, error) {
|
||||
var filteredModels []ProcessedModel
|
||||
for _, model := range models {
|
||||
exists, err := isModelExisting(model.ModelID)
|
||||
if err != nil {
|
||||
fmt.Printf("Error checking if model %s exists: %v, skipping\n", model.ModelID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !exists {
|
||||
filteredModels = append(filteredModels, model)
|
||||
} else {
|
||||
fmt.Printf("Skipping existing model: %s\n", model.ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Filtered out %d existing models, %d new models remaining\n",
|
||||
len(models)-len(filteredModels), len(filteredModels))
|
||||
|
||||
return filteredModels, nil
|
||||
}
|
||||
|
||||
// getGalleryIndexPath returns the gallery index file path, with a default fallback
|
||||
func getGalleryIndexPath() string {
|
||||
if galleryIndexPath != "" {
|
||||
return galleryIndexPath
|
||||
}
|
||||
return "gallery/index.yaml"
|
||||
}
|
||||
|
||||
func stripThinkingTags(content string) string {
|
||||
// Remove content between <thinking> and </thinking> (including multi-line)
|
||||
content = regexp.MustCompile(`(?s)<thinking>.*?</thinking>`).ReplaceAllString(content, "")
|
||||
// Remove content between <think> and </think> (including multi-line)
|
||||
content = regexp.MustCompile(`(?s)<think>.*?</think>`).ReplaceAllString(content, "")
|
||||
// Clean up any extra whitespace
|
||||
content = strings.TrimSpace(content)
|
||||
return content
|
||||
}
|
||||
|
||||
func getRealReadme(ctx context.Context, repository string) (string, error) {
|
||||
// Create a conversation fragment
|
||||
fragment := cogito.NewEmptyFragment().
|
||||
AddMessage("user",
|
||||
`Your task is to get a clear description of a large language model from huggingface by using the provided tool. I will share with you a repository that might be quantized, and as such probably not by the original model author. We need to get the real description of the model, and not the one that might be quantized. You will have to call the tool to get the readme more than once by figuring out from the quantized readme which is the base model readme. This is the repository: `+repository)
|
||||
|
||||
// Execute with tools
|
||||
result, err := cogito.ExecuteTools(llm, fragment,
|
||||
cogito.WithIterations(3),
|
||||
cogito.WithMaxAttempts(3),
|
||||
cogito.DisableSinkState,
|
||||
cogito.WithTools(&HFReadmeTool{client: hfapi.NewClient()}))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = result.AddMessage("user", "Describe the model in a clear and concise way that can be shared in a model gallery.")
|
||||
|
||||
// Get a response
|
||||
_, err = llm.Ask(ctx, result)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
content := result.LastMessage().Content
|
||||
return cleanTextContent(content), nil
|
||||
}
|
||||
|
||||
func selectMostInterestingModels(ctx context.Context, searchResult *SearchResult) ([]ProcessedModel, error) {
|
||||
|
||||
if len(searchResult.Models) == 1 {
|
||||
return searchResult.Models, nil
|
||||
}
|
||||
|
||||
// Create a conversation fragment
|
||||
fragment := cogito.NewEmptyFragment().
|
||||
AddMessage("user",
|
||||
`Your task is to analyze a list of AI models and select the most interesting ones for a model gallery. You will be given detailed information about multiple models including their metadata, file information, and README content.
|
||||
|
||||
Consider the following criteria when selecting models:
|
||||
1. Model popularity (download count)
|
||||
2. Model recency (last modified date)
|
||||
3. Model completeness (has preferred model file, README, etc.)
|
||||
4. Model uniqueness (not duplicates or very similar models)
|
||||
5. Model quality (based on README content and description)
|
||||
6. Model utility (practical applications)
|
||||
|
||||
You should select models that would be most valuable for users browsing a model gallery. Prioritize models that are:
|
||||
- Well-documented with clear READMEs
|
||||
- Recently updated
|
||||
- Popular (high download count)
|
||||
- Have the preferred quantization format available
|
||||
- Offer unique capabilities or are from reputable authors
|
||||
|
||||
Return your analysis and selection reasoning.`)
|
||||
|
||||
// Add the search results as context
|
||||
modelsInfo := fmt.Sprintf("Found %d models matching '%s' with quantization preference '%s':\n\n",
|
||||
searchResult.TotalModelsFound, searchResult.SearchTerm, searchResult.Quantization)
|
||||
|
||||
for i, model := range searchResult.Models {
|
||||
modelsInfo += fmt.Sprintf("Model %d:\n", i+1)
|
||||
modelsInfo += fmt.Sprintf(" ID: %s\n", model.ModelID)
|
||||
modelsInfo += fmt.Sprintf(" Author: %s\n", model.Author)
|
||||
modelsInfo += fmt.Sprintf(" Downloads: %d\n", model.Downloads)
|
||||
modelsInfo += fmt.Sprintf(" Last Modified: %s\n", model.LastModified)
|
||||
modelsInfo += fmt.Sprintf(" Files: %d files\n", len(model.Files))
|
||||
|
||||
if model.PreferredModelFile != nil {
|
||||
modelsInfo += fmt.Sprintf(" Preferred Model File: %s (%d bytes)\n",
|
||||
model.PreferredModelFile.Path, model.PreferredModelFile.Size)
|
||||
} else {
|
||||
modelsInfo += " No preferred model file found\n"
|
||||
}
|
||||
|
||||
if model.ReadmeContent != "" {
|
||||
modelsInfo += fmt.Sprintf(" README: %s\n", model.ReadmeContent)
|
||||
}
|
||||
|
||||
if model.ProcessingError != "" {
|
||||
modelsInfo += fmt.Sprintf(" Processing Error: %s\n", model.ProcessingError)
|
||||
}
|
||||
|
||||
modelsInfo += "\n"
|
||||
}
|
||||
|
||||
fragment = fragment.AddMessage("user", modelsInfo)
|
||||
|
||||
fragment = fragment.AddMessage("user", "Based on your analysis, select the top 5 most interesting models and provide a brief explanation for each selection. Also, create a filtered SearchResult with only the selected models. Return just a list of repositories IDs, you will later be asked to output it as a JSON array with the json tool.")
|
||||
|
||||
// Get a response
|
||||
newFragment, err := llm.Ask(ctx, fragment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Println(newFragment.LastMessage().Content)
|
||||
repositories := struct {
|
||||
Repositories []string `json:"repositories"`
|
||||
}{}
|
||||
|
||||
s := structures.Structure{
|
||||
Schema: jsonschema.Definition{
|
||||
Type: jsonschema.Object,
|
||||
AdditionalProperties: false,
|
||||
Properties: map[string]jsonschema.Definition{
|
||||
"repositories": {
|
||||
Type: jsonschema.Array,
|
||||
Items: &jsonschema.Definition{Type: jsonschema.String},
|
||||
Description: "The trending repositories IDs",
|
||||
},
|
||||
},
|
||||
Required: []string{"repositories"},
|
||||
},
|
||||
Object: &repositories,
|
||||
}
|
||||
|
||||
err = newFragment.ExtractStructure(ctx, llm, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filteredModels := []ProcessedModel{}
|
||||
for _, m := range searchResult.Models {
|
||||
if slices.Contains(repositories.Repositories, m.ModelID) {
|
||||
filteredModels = append(filteredModels, m)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredModels, nil
|
||||
}
|
||||
|
||||
// ModelMetadata represents extracted metadata from a model
|
||||
type ModelMetadata struct {
|
||||
Tags []string `json:"tags"`
|
||||
License string `json:"license"`
|
||||
}
|
||||
|
||||
// extractModelMetadata extracts tags and license from model README and documentation
|
||||
func extractModelMetadata(ctx context.Context, model ProcessedModel) ([]string, string, error) {
|
||||
// Create a conversation fragment
|
||||
fragment := cogito.NewEmptyFragment().
|
||||
AddMessage("user",
|
||||
`Your task is to extract metadata from an AI model's README and documentation. You will be provided with:
|
||||
1. Model information (ID, author, description)
|
||||
2. README content
|
||||
|
||||
You need to extract:
|
||||
1. **Tags**: An array of relevant tags that describe the model. Use common tags from the gallery such as:
|
||||
- llm, gguf, gpu, cpu, multimodal, image-to-text, text-to-text, text-to-speech, tts
|
||||
- thinking, reasoning, chat, instruction-tuned, code, vision
|
||||
- Model family names (e.g., llama, qwen, mistral, gemma) if applicable
|
||||
- Any other relevant descriptive tags
|
||||
Select 3-8 most relevant tags.
|
||||
|
||||
2. **License**: The license identifier (e.g., "apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", "cc-by-4.0").
|
||||
If no license is found, return an empty string.
|
||||
|
||||
Return the extracted metadata in a structured format.`)
|
||||
|
||||
// Add model information
|
||||
modelInfo := "Model Information:\n"
|
||||
modelInfo += fmt.Sprintf(" ID: %s\n", model.ModelID)
|
||||
modelInfo += fmt.Sprintf(" Author: %s\n", model.Author)
|
||||
modelInfo += fmt.Sprintf(" Downloads: %d\n", model.Downloads)
|
||||
if model.ReadmeContent != "" {
|
||||
modelInfo += fmt.Sprintf(" README Content:\n%s\n", model.ReadmeContent)
|
||||
} else if model.ReadmeContentPreview != "" {
|
||||
modelInfo += fmt.Sprintf(" README Preview: %s\n", model.ReadmeContentPreview)
|
||||
}
|
||||
|
||||
fragment = fragment.AddMessage("user", modelInfo)
|
||||
fragment = fragment.AddMessage("user", "Extract the tags and license from the model information. Return the metadata as a JSON object with 'tags' (array of strings) and 'license' (string).")
|
||||
|
||||
// Get a response
|
||||
newFragment, err := llm.Ask(ctx, fragment)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Extract structured metadata
|
||||
metadata := ModelMetadata{}
|
||||
|
||||
s := structures.Structure{
|
||||
Schema: jsonschema.Definition{
|
||||
Type: jsonschema.Object,
|
||||
AdditionalProperties: false,
|
||||
Properties: map[string]jsonschema.Definition{
|
||||
"tags": {
|
||||
Type: jsonschema.Array,
|
||||
Items: &jsonschema.Definition{Type: jsonschema.String},
|
||||
Description: "Array of relevant tags describing the model",
|
||||
},
|
||||
"license": {
|
||||
Type: jsonschema.String,
|
||||
Description: "License identifier (e.g., apache-2.0, mit, llama2). Empty string if not found.",
|
||||
},
|
||||
},
|
||||
Required: []string{"tags", "license"},
|
||||
},
|
||||
Object: &metadata,
|
||||
}
|
||||
|
||||
err = newFragment.ExtractStructure(ctx, llm, s)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return metadata.Tags, metadata.License, nil
|
||||
}
|
||||
|
||||
// extractIconFromReadme scans the README content for image URLs and returns the first suitable icon URL found
|
||||
func extractIconFromReadme(readmeContent string) string {
|
||||
if readmeContent == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Regular expressions to match image URLs in various formats (case-insensitive)
|
||||
// Match markdown image syntax:  - case insensitive extensions
|
||||
markdownImageRegex := regexp.MustCompile(`(?i)!\[[^\]]*\]\(([^)]+\.(png|jpg|jpeg|svg|webp|gif))\)`)
|
||||
// Match HTML img tags: <img src="url">
|
||||
htmlImageRegex := regexp.MustCompile(`(?i)<img[^>]+src=["']([^"']+\.(png|jpg|jpeg|svg|webp|gif))["']`)
|
||||
// Match plain URLs ending with image extensions
|
||||
plainImageRegex := regexp.MustCompile(`(?i)https?://[^\s<>"']+\.(png|jpg|jpeg|svg|webp|gif)`)
|
||||
|
||||
// Try markdown format first
|
||||
matches := markdownImageRegex.FindStringSubmatch(readmeContent)
|
||||
if len(matches) > 1 && matches[1] != "" {
|
||||
url := strings.TrimSpace(matches[1])
|
||||
// Prefer HuggingFace CDN URLs or absolute URLs
|
||||
if strings.HasPrefix(strings.ToLower(url), "http") {
|
||||
return url
|
||||
}
|
||||
}
|
||||
|
||||
// Try HTML img tags
|
||||
matches = htmlImageRegex.FindStringSubmatch(readmeContent)
|
||||
if len(matches) > 1 && matches[1] != "" {
|
||||
url := strings.TrimSpace(matches[1])
|
||||
if strings.HasPrefix(strings.ToLower(url), "http") {
|
||||
return url
|
||||
}
|
||||
}
|
||||
|
||||
// Try plain URLs
|
||||
matches = plainImageRegex.FindStringSubmatch(readmeContent)
|
||||
if len(matches) > 0 {
|
||||
url := strings.TrimSpace(matches[0])
|
||||
if strings.HasPrefix(strings.ToLower(url), "http") {
|
||||
return url
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// getHuggingFaceAvatarURL attempts to get the HuggingFace avatar URL for a user
|
||||
func getHuggingFaceAvatarURL(author string) string {
|
||||
if author == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to fetch user info from HuggingFace API
|
||||
// HuggingFace API endpoint: https://huggingface.co/api/users/{username}
|
||||
baseURL := "https://huggingface.co"
|
||||
userURL := fmt.Sprintf("%s/api/users/%s", baseURL, author)
|
||||
|
||||
req, err := http.NewRequest("GET", userURL, nil)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Parse the response to get avatar URL
|
||||
var userInfo map[string]any
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to extract avatar URL from response
|
||||
if avatar, ok := userInfo["avatarUrl"].(string); ok && avatar != "" {
|
||||
return avatar
|
||||
}
|
||||
if avatar, ok := userInfo["avatar"].(string); ok && avatar != "" {
|
||||
return avatar
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractModelIcon extracts icon URL from README or falls back to HuggingFace avatar
|
||||
func extractModelIcon(model ProcessedModel) string {
|
||||
// First, try to extract icon from README
|
||||
if icon := extractIconFromReadme(model.ReadmeContent); icon != "" {
|
||||
return icon
|
||||
}
|
||||
|
||||
// Fallback: Try to get HuggingFace user avatar
|
||||
if model.Author != "" {
|
||||
if avatar := getHuggingFaceAvatarURL(model.Author); avatar != "" {
|
||||
return avatar
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
2
.github/gallery-agent/gallery.go
vendored
2
.github/gallery-agent/gallery.go
vendored
@@ -7,8 +7,8 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ghodss/yaml"
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
"sigs.k8s.io/yaml"
|
||||
)
|
||||
|
||||
func formatTextContent(text string) string {
|
||||
|
||||
301
.github/gallery-agent/helpers.go
vendored
Normal file
301
.github/gallery-agent/helpers.go
vendored
Normal file
@@ -0,0 +1,301 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
"sigs.k8s.io/yaml"
|
||||
)
|
||||
|
||||
var galleryIndexPath = os.Getenv("GALLERY_INDEX_PATH")
|
||||
|
||||
// getGalleryIndexPath returns the gallery index file path, with a default fallback
|
||||
func getGalleryIndexPath() string {
|
||||
if galleryIndexPath != "" {
|
||||
return galleryIndexPath
|
||||
}
|
||||
return "gallery/index.yaml"
|
||||
}
|
||||
|
||||
type galleryModel struct {
|
||||
Name string `yaml:"name"`
|
||||
Urls []string `yaml:"urls"`
|
||||
}
|
||||
|
||||
// loadGalleryURLSet parses gallery/index.yaml once and returns the set of
|
||||
// HuggingFace model URLs already present in the gallery.
|
||||
func loadGalleryURLSet() (map[string]struct{}, error) {
|
||||
indexPath := getGalleryIndexPath()
|
||||
content, err := os.ReadFile(indexPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read %s: %w", indexPath, err)
|
||||
}
|
||||
|
||||
var galleryModels []galleryModel
|
||||
if err := yaml.Unmarshal(content, &galleryModels); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal %s: %w", indexPath, err)
|
||||
}
|
||||
|
||||
set := make(map[string]struct{}, len(galleryModels))
|
||||
for _, gm := range galleryModels {
|
||||
for _, u := range gm.Urls {
|
||||
set[u] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Also skip URLs already proposed in open (unmerged) gallery-agent PRs.
|
||||
// The workflow injects these via EXTRA_SKIP_URLS so we don't keep
|
||||
// re-proposing the same model every run while a PR is waiting to merge.
|
||||
for _, line := range strings.FieldsFunc(os.Getenv("EXTRA_SKIP_URLS"), func(r rune) bool {
|
||||
return r == '\n' || r == ',' || r == ' '
|
||||
}) {
|
||||
u := strings.TrimSpace(line)
|
||||
if u != "" {
|
||||
set[u] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
// modelAlreadyInGallery checks whether a HuggingFace model repo is already
|
||||
// referenced in the gallery URL set.
|
||||
func modelAlreadyInGallery(set map[string]struct{}, modelID string) bool {
|
||||
_, ok := set["https://huggingface.co/"+modelID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// baseModelFromTags returns the first `base_model:<repo>` value found in the
|
||||
// tag list, or "" if none is present. HuggingFace surfaces the base model
|
||||
// declared in the model card's YAML frontmatter as such a tag.
|
||||
func baseModelFromTags(tags []string) string {
|
||||
for _, t := range tags {
|
||||
if strings.HasPrefix(t, "base_model:") {
|
||||
return strings.TrimPrefix(t, "base_model:")
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// licenseFromTags returns the `license:<id>` value from the tag list, or "".
|
||||
func licenseFromTags(tags []string) string {
|
||||
for _, t := range tags {
|
||||
if strings.HasPrefix(t, "license:") {
|
||||
return strings.TrimPrefix(t, "license:")
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// curatedTags produces the gallery tag list from HuggingFace's raw tag set.
|
||||
// Always includes llm + gguf, then adds whitelisted family / capability
|
||||
// markers when they appear in the HF tag list.
|
||||
func curatedTags(hfTags []string) []string {
|
||||
whitelist := []string{
|
||||
"gpu", "cpu",
|
||||
"llama", "mistral", "mixtral", "qwen", "qwen2", "qwen3",
|
||||
"gemma", "gemma2", "gemma3", "phi", "phi3", "phi4",
|
||||
"deepseek", "yi", "falcon", "command-r",
|
||||
"vision", "multimodal", "code", "chat",
|
||||
"instruction-tuned", "reasoning", "thinking",
|
||||
}
|
||||
seen := map[string]struct{}{}
|
||||
out := []string{"llm", "gguf"}
|
||||
seen["llm"] = struct{}{}
|
||||
seen["gguf"] = struct{}{}
|
||||
|
||||
hfSet := map[string]struct{}{}
|
||||
for _, t := range hfTags {
|
||||
hfSet[strings.ToLower(t)] = struct{}{}
|
||||
}
|
||||
for _, w := range whitelist {
|
||||
if _, ok := hfSet[w]; ok {
|
||||
if _, dup := seen[w]; !dup {
|
||||
out = append(out, w)
|
||||
seen[w] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// resolveReadme fetches a description-quality README for a (possibly
|
||||
// quantized) repo: if a `base_model:` tag is present, fetch the base repo's
|
||||
// README; otherwise fall back to the repo's own README.
|
||||
func resolveReadme(client *hfapi.Client, modelID string, hfTags []string) (string, error) {
|
||||
if base := baseModelFromTags(hfTags); base != "" && base != modelID {
|
||||
if content, err := client.GetReadmeContent(base, "README.md"); err == nil && strings.TrimSpace(content) != "" {
|
||||
return cleanTextContent(content), nil
|
||||
}
|
||||
}
|
||||
content, err := client.GetReadmeContent(modelID, "README.md")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cleanTextContent(content), nil
|
||||
}
|
||||
|
||||
// extractDescription turns a raw HuggingFace README into a concise plain-text
|
||||
// description suitable for embedding in gallery/index.yaml: strips YAML
|
||||
// frontmatter, HTML tags/comments, markdown images, link URLs (keeping the
|
||||
// link text), markdown tables, and then truncates at a paragraph boundary
|
||||
// around ~1200 characters. Raw README should still be used for icon
|
||||
// extraction — call this only for the `description:` field.
|
||||
func extractDescription(readme string) string {
|
||||
s := readme
|
||||
|
||||
// Strip leading YAML frontmatter: `---\n...\n---\n` at start of file.
|
||||
if strings.HasPrefix(strings.TrimLeft(s, " \t\n"), "---") {
|
||||
trimmed := strings.TrimLeft(s, " \t\n")
|
||||
rest := strings.TrimPrefix(trimmed, "---")
|
||||
if idx := strings.Index(rest, "\n---"); idx >= 0 {
|
||||
after := rest[idx+len("\n---"):]
|
||||
after = strings.TrimPrefix(after, "\n")
|
||||
s = after
|
||||
}
|
||||
}
|
||||
|
||||
// Strip HTML comments and tags.
|
||||
s = regexp.MustCompile(`(?s)<!--.*?-->`).ReplaceAllString(s, "")
|
||||
s = regexp.MustCompile(`(?is)<[^>]+>`).ReplaceAllString(s, "")
|
||||
|
||||
// Strip markdown images entirely.
|
||||
s = regexp.MustCompile(`!\[[^\]]*\]\([^)]*\)`).ReplaceAllString(s, "")
|
||||
// Replace markdown links `[text](url)` with just `text`.
|
||||
s = regexp.MustCompile(`\[([^\]]+)\]\([^)]+\)`).ReplaceAllString(s, "$1")
|
||||
|
||||
// Drop table lines and horizontal rules, and flatten all leading
|
||||
// whitespace: generateYAMLEntry embeds this under a `description: |`
|
||||
// literal block whose indentation is set by the first non-empty line.
|
||||
// If any line has extra leading whitespace (e.g. from an indented
|
||||
// `<p align="center">` block in the original README), YAML will pick
|
||||
// that up as the block's indent and every later line at a smaller
|
||||
// indent blows the block scalar. Stripping leading whitespace here
|
||||
// guarantees uniform 4-space indentation after formatTextContent runs.
|
||||
var kept []string
|
||||
for _, line := range strings.Split(s, "\n") {
|
||||
t := strings.TrimLeft(line, " \t")
|
||||
ts := strings.TrimSpace(t)
|
||||
if strings.HasPrefix(ts, "|") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(ts, ":--") || strings.HasPrefix(ts, "---") || strings.HasPrefix(ts, "===") {
|
||||
continue
|
||||
}
|
||||
kept = append(kept, t)
|
||||
}
|
||||
s = strings.Join(kept, "\n")
|
||||
|
||||
// Normalise whitespace and drop any leading blank lines so the literal
|
||||
// block in YAML doesn't start with a blank first line (which would
|
||||
// break the indentation detector the same way).
|
||||
s = cleanTextContent(s)
|
||||
s = strings.TrimLeft(s, " \t\n")
|
||||
|
||||
// Truncate at a paragraph boundary around maxLen chars.
|
||||
const maxLen = 1200
|
||||
if len(s) > maxLen {
|
||||
cut := strings.LastIndex(s[:maxLen], "\n\n")
|
||||
if cut < maxLen/3 {
|
||||
cut = maxLen
|
||||
}
|
||||
s = strings.TrimRight(s[:cut], " \t\n") + "\n\n..."
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// cleanTextContent removes trailing spaces/tabs and collapses multiple empty
|
||||
// lines so README content embeds cleanly into YAML without lint noise.
|
||||
func cleanTextContent(text string) string {
|
||||
lines := strings.Split(text, "\n")
|
||||
var cleaned []string
|
||||
var prevEmpty bool
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimRight(line, " \t\r")
|
||||
if trimmed == "" {
|
||||
if !prevEmpty {
|
||||
cleaned = append(cleaned, "")
|
||||
}
|
||||
prevEmpty = true
|
||||
} else {
|
||||
cleaned = append(cleaned, trimmed)
|
||||
prevEmpty = false
|
||||
}
|
||||
}
|
||||
return strings.TrimRight(strings.Join(cleaned, "\n"), "\n")
|
||||
}
|
||||
|
||||
// extractIconFromReadme scans README content for an image URL usable as a
|
||||
// gallery entry icon.
|
||||
func extractIconFromReadme(readmeContent string) string {
|
||||
if readmeContent == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
markdownImageRegex := regexp.MustCompile(`(?i)!\[[^\]]*\]\(([^)]+\.(png|jpg|jpeg|svg|webp|gif))\)`)
|
||||
htmlImageRegex := regexp.MustCompile(`(?i)<img[^>]+src=["']([^"']+\.(png|jpg|jpeg|svg|webp|gif))["']`)
|
||||
plainImageRegex := regexp.MustCompile(`(?i)https?://[^\s<>"']+\.(png|jpg|jpeg|svg|webp|gif)`)
|
||||
|
||||
if m := markdownImageRegex.FindStringSubmatch(readmeContent); len(m) > 1 && strings.HasPrefix(strings.ToLower(m[1]), "http") {
|
||||
return strings.TrimSpace(m[1])
|
||||
}
|
||||
if m := htmlImageRegex.FindStringSubmatch(readmeContent); len(m) > 1 && strings.HasPrefix(strings.ToLower(m[1]), "http") {
|
||||
return strings.TrimSpace(m[1])
|
||||
}
|
||||
if m := plainImageRegex.FindStringSubmatch(readmeContent); len(m) > 0 && strings.HasPrefix(strings.ToLower(m[0]), "http") {
|
||||
return strings.TrimSpace(m[0])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getHuggingFaceAvatarURL returns the HF avatar URL for a user, or "".
|
||||
func getHuggingFaceAvatarURL(author string) string {
|
||||
if author == "" {
|
||||
return ""
|
||||
}
|
||||
userURL := fmt.Sprintf("https://huggingface.co/api/users/%s/overview", author)
|
||||
resp, err := http.Get(userURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return ""
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var info map[string]any
|
||||
if err := json.Unmarshal(body, &info); err != nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := info["avatarUrl"].(string); ok && v != "" {
|
||||
return v
|
||||
}
|
||||
if v, ok := info["avatar"].(string); ok && v != "" {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractModelIcon extracts an icon URL from the README, falling back to the
|
||||
// HuggingFace user avatar.
|
||||
func extractModelIcon(model ProcessedModel) string {
|
||||
if icon := extractIconFromReadme(model.ReadmeContent); icon != "" {
|
||||
return icon
|
||||
}
|
||||
if model.Author != "" {
|
||||
if avatar := getHuggingFaceAvatarURL(model.Author); avatar != "" {
|
||||
return avatar
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
409
.github/gallery-agent/main.go
vendored
409
.github/gallery-agent/main.go
vendored
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
@@ -39,16 +38,6 @@ type ProcessedModel struct {
|
||||
Icon string `json:"icon,omitempty"`
|
||||
}
|
||||
|
||||
// SearchResult represents the complete result of searching and processing models
|
||||
type SearchResult struct {
|
||||
SearchTerm string `json:"search_term"`
|
||||
Limit int `json:"limit"`
|
||||
Quantization string `json:"quantization"`
|
||||
TotalModelsFound int `json:"total_models_found"`
|
||||
Models []ProcessedModel `json:"models"`
|
||||
FormattedOutput string `json:"formatted_output"`
|
||||
}
|
||||
|
||||
// AddedModelSummary represents a summary of models added to the gallery
|
||||
type AddedModelSummary struct {
|
||||
SearchTerm string `json:"search_term"`
|
||||
@@ -63,19 +52,16 @@ type AddedModelSummary struct {
|
||||
func main() {
|
||||
startTime := time.Now()
|
||||
|
||||
// Check for synthetic mode
|
||||
syntheticMode := os.Getenv("SYNTHETIC_MODE")
|
||||
if syntheticMode == "true" || syntheticMode == "1" {
|
||||
// Synthetic mode for local testing
|
||||
if sm := os.Getenv("SYNTHETIC_MODE"); sm == "true" || sm == "1" {
|
||||
fmt.Println("Running in SYNTHETIC MODE - generating random test data")
|
||||
err := runSyntheticMode()
|
||||
if err != nil {
|
||||
if err := runSyntheticMode(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error in synthetic mode: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Get configuration from environment variables
|
||||
searchTerm := os.Getenv("SEARCH_TERM")
|
||||
if searchTerm == "" {
|
||||
searchTerm = "GGUF"
|
||||
@@ -83,7 +69,7 @@ func main() {
|
||||
|
||||
limitStr := os.Getenv("LIMIT")
|
||||
if limitStr == "" {
|
||||
limitStr = "5"
|
||||
limitStr = "15"
|
||||
}
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err != nil {
|
||||
@@ -92,287 +78,197 @@ func main() {
|
||||
}
|
||||
|
||||
quantization := os.Getenv("QUANTIZATION")
|
||||
|
||||
maxModels := os.Getenv("MAX_MODELS")
|
||||
if maxModels == "" {
|
||||
maxModels = "1"
|
||||
if quantization == "" {
|
||||
quantization = "Q4_K_M"
|
||||
}
|
||||
maxModelsInt, err := strconv.Atoi(maxModels)
|
||||
|
||||
maxModelsStr := os.Getenv("MAX_MODELS")
|
||||
if maxModelsStr == "" {
|
||||
maxModelsStr = "1"
|
||||
}
|
||||
maxModels, err := strconv.Atoi(maxModelsStr)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing MAX_MODELS: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Print configuration
|
||||
fmt.Printf("Gallery Agent Configuration:\n")
|
||||
fmt.Printf(" Search Term: %s\n", searchTerm)
|
||||
fmt.Printf(" Limit: %d\n", limit)
|
||||
fmt.Printf(" Quantization: %s\n", quantization)
|
||||
fmt.Printf(" Max Models to Add: %d\n", maxModelsInt)
|
||||
fmt.Printf(" Gallery Index Path: %s\n", os.Getenv("GALLERY_INDEX_PATH"))
|
||||
fmt.Printf(" Max Models to Add: %d\n", maxModels)
|
||||
fmt.Printf(" Gallery Index Path: %s\n", getGalleryIndexPath())
|
||||
fmt.Println()
|
||||
|
||||
result, err := searchAndProcessModels(searchTerm, limit, quantization)
|
||||
// Phase 1: load current gallery and query HuggingFace.
|
||||
gallerySet, err := loadGalleryURLSet()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "Error loading gallery index: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("Loaded %d existing gallery entries\n", len(gallerySet))
|
||||
|
||||
fmt.Println(result.FormattedOutput)
|
||||
var models []ProcessedModel
|
||||
|
||||
if len(result.Models) > 1 {
|
||||
fmt.Println("More than one model found (", len(result.Models), "), using AI agent to select the most interesting models")
|
||||
for _, model := range result.Models {
|
||||
fmt.Println("Model: ", model.ModelID)
|
||||
}
|
||||
// Use AI agent to select the most interesting models
|
||||
fmt.Println("Using AI agent to select the most interesting models...")
|
||||
models, err = selectMostInterestingModels(context.Background(), result)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error in model selection: %v\n", err)
|
||||
// Continue with original result if selection fails
|
||||
models = result.Models
|
||||
}
|
||||
} else if len(result.Models) == 1 {
|
||||
models = result.Models
|
||||
fmt.Println("Only one model found, using it directly")
|
||||
}
|
||||
|
||||
fmt.Print(models)
|
||||
|
||||
// Filter out models that already exist in the gallery
|
||||
fmt.Println("Filtering out existing models...")
|
||||
models, err = filterExistingModels(models)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error filtering existing models: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Limit to maxModelsInt after filtering
|
||||
if len(models) > maxModelsInt {
|
||||
models = models[:maxModelsInt]
|
||||
}
|
||||
|
||||
// Track added models for summary
|
||||
var addedModelIDs []string
|
||||
var addedModelURLs []string
|
||||
|
||||
// Generate YAML entries and append to gallery/index.yaml
|
||||
if len(models) > 0 {
|
||||
for _, model := range models {
|
||||
addedModelIDs = append(addedModelIDs, model.ModelID)
|
||||
// Generate Hugging Face URL for the model
|
||||
modelURL := fmt.Sprintf("https://huggingface.co/%s", model.ModelID)
|
||||
addedModelURLs = append(addedModelURLs, modelURL)
|
||||
}
|
||||
fmt.Println("Generating YAML entries for selected models...")
|
||||
err = generateYAMLForModels(context.Background(), models, quantization)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error generating YAML entries: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
fmt.Println("No new models to add to the gallery.")
|
||||
}
|
||||
|
||||
// Create and write summary
|
||||
processingTime := time.Since(startTime).String()
|
||||
summary := AddedModelSummary{
|
||||
SearchTerm: searchTerm,
|
||||
TotalFound: result.TotalModelsFound,
|
||||
ModelsAdded: len(addedModelIDs),
|
||||
AddedModelIDs: addedModelIDs,
|
||||
AddedModelURLs: addedModelURLs,
|
||||
Quantization: quantization,
|
||||
ProcessingTime: processingTime,
|
||||
}
|
||||
|
||||
// Write summary to file
|
||||
summaryData, err := json.MarshalIndent(summary, "", " ")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marshaling summary: %v\n", err)
|
||||
} else {
|
||||
err = os.WriteFile("gallery-agent-summary.json", summaryData, 0644)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error writing summary file: %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("Summary written to gallery-agent-summary.json\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func searchAndProcessModels(searchTerm string, limit int, quantization string) (*SearchResult, error) {
|
||||
client := hfapi.NewClient()
|
||||
var outputBuilder strings.Builder
|
||||
|
||||
fmt.Println("Searching for models...")
|
||||
// Initialize the result struct
|
||||
result := &SearchResult{
|
||||
SearchTerm: searchTerm,
|
||||
Limit: limit,
|
||||
Quantization: quantization,
|
||||
Models: []ProcessedModel{},
|
||||
}
|
||||
|
||||
models, err := client.GetLatest(searchTerm, limit)
|
||||
fmt.Println("Searching for trending models on HuggingFace...")
|
||||
rawModels, err := client.GetTrending(searchTerm, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch models: %w", err)
|
||||
fmt.Fprintf(os.Stderr, "Error fetching models: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("Found %d trending models matching %q\n", len(rawModels), searchTerm)
|
||||
totalFound := len(rawModels)
|
||||
|
||||
// Phase 2: drop anything already in the gallery *before* any expensive
|
||||
// per-model work (GetModelDetails, README fetches, icon lookups).
|
||||
fresh := rawModels[:0]
|
||||
for _, m := range rawModels {
|
||||
if modelAlreadyInGallery(gallerySet, m.ModelID) {
|
||||
fmt.Printf("Skipping existing model: %s\n", m.ModelID)
|
||||
continue
|
||||
}
|
||||
fresh = append(fresh, m)
|
||||
}
|
||||
fmt.Printf("%d candidates after gallery dedup\n", len(fresh))
|
||||
|
||||
// Phase 3: HuggingFace already returned these in trendingScore order —
|
||||
// just cap to MAX_MODELS.
|
||||
if len(fresh) > maxModels {
|
||||
fresh = fresh[:maxModels]
|
||||
}
|
||||
if len(fresh) == 0 {
|
||||
fmt.Println("No new models to add to the gallery.")
|
||||
writeSummary(AddedModelSummary{
|
||||
SearchTerm: searchTerm,
|
||||
TotalFound: totalFound,
|
||||
ModelsAdded: 0,
|
||||
Quantization: quantization,
|
||||
ProcessingTime: time.Since(startTime).String(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Models found:", len(models))
|
||||
result.TotalModelsFound = len(models)
|
||||
// Phase 4: fetch details and build ProcessedModel entries for survivors.
|
||||
var processed []ProcessedModel
|
||||
quantPrefs := []string{quantization, "Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K", "Q8_0"}
|
||||
for _, m := range fresh {
|
||||
fmt.Printf("Processing model: %s (downloads=%d)\n", m.ModelID, m.Downloads)
|
||||
|
||||
if len(models) == 0 {
|
||||
outputBuilder.WriteString("No models found.\n")
|
||||
result.FormattedOutput = outputBuilder.String()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
outputBuilder.WriteString(fmt.Sprintf("Found %d models matching '%s':\n\n", len(models), searchTerm))
|
||||
|
||||
// Process each model
|
||||
for i, model := range models {
|
||||
outputBuilder.WriteString(fmt.Sprintf("%d. Processing Model: %s\n", i+1, model.ModelID))
|
||||
outputBuilder.WriteString(fmt.Sprintf(" Author: %s\n", model.Author))
|
||||
outputBuilder.WriteString(fmt.Sprintf(" Downloads: %d\n", model.Downloads))
|
||||
outputBuilder.WriteString(fmt.Sprintf(" Last Modified: %s\n", model.LastModified))
|
||||
|
||||
// Initialize processed model struct
|
||||
processedModel := ProcessedModel{
|
||||
ModelID: model.ModelID,
|
||||
Author: model.Author,
|
||||
Downloads: model.Downloads,
|
||||
LastModified: model.LastModified,
|
||||
QuantizationPreferences: []string{quantization, "Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"},
|
||||
pm := ProcessedModel{
|
||||
ModelID: m.ModelID,
|
||||
Author: m.Author,
|
||||
Downloads: m.Downloads,
|
||||
LastModified: m.LastModified,
|
||||
QuantizationPreferences: quantPrefs,
|
||||
}
|
||||
|
||||
// Get detailed model information
|
||||
details, err := client.GetModelDetails(model.ModelID)
|
||||
details, err := client.GetModelDetails(m.ModelID)
|
||||
if err != nil {
|
||||
errorMsg := fmt.Sprintf(" Error getting model details: %v\n", err)
|
||||
outputBuilder.WriteString(errorMsg)
|
||||
processedModel.ProcessingError = err.Error()
|
||||
result.Models = append(result.Models, processedModel)
|
||||
fmt.Printf(" Error getting model details: %v (skipping)\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Define quantization preferences (in order of preference)
|
||||
quantizationPreferences := []string{quantization, "Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"}
|
||||
preferred := hfapi.FindPreferredModelFile(details.Files, quantPrefs)
|
||||
if preferred == nil {
|
||||
fmt.Printf(" No GGUF file matching %v — skipping\n", quantPrefs)
|
||||
continue
|
||||
}
|
||||
|
||||
// Find preferred model file
|
||||
preferredModelFile := hfapi.FindPreferredModelFile(details.Files, quantizationPreferences)
|
||||
|
||||
// Process files
|
||||
processedFiles := make([]ProcessedModelFile, len(details.Files))
|
||||
for j, file := range details.Files {
|
||||
pm.Files = make([]ProcessedModelFile, len(details.Files))
|
||||
for j, f := range details.Files {
|
||||
fileType := "other"
|
||||
if file.IsReadme {
|
||||
if f.IsReadme {
|
||||
fileType = "readme"
|
||||
} else if preferredModelFile != nil && file.Path == preferredModelFile.Path {
|
||||
} else if f.Path == preferred.Path {
|
||||
fileType = "model"
|
||||
}
|
||||
|
||||
processedFiles[j] = ProcessedModelFile{
|
||||
Path: file.Path,
|
||||
Size: file.Size,
|
||||
SHA256: file.SHA256,
|
||||
IsReadme: file.IsReadme,
|
||||
pm.Files[j] = ProcessedModelFile{
|
||||
Path: f.Path,
|
||||
Size: f.Size,
|
||||
SHA256: f.SHA256,
|
||||
IsReadme: f.IsReadme,
|
||||
FileType: fileType,
|
||||
}
|
||||
}
|
||||
|
||||
processedModel.Files = processedFiles
|
||||
|
||||
// Set preferred model file
|
||||
if preferredModelFile != nil {
|
||||
for _, file := range processedFiles {
|
||||
if file.Path == preferredModelFile.Path {
|
||||
processedModel.PreferredModelFile = &file
|
||||
break
|
||||
}
|
||||
if f.Path == preferred.Path {
|
||||
copyFile := pm.Files[j]
|
||||
pm.PreferredModelFile = ©File
|
||||
}
|
||||
if f.IsReadme {
|
||||
copyFile := pm.Files[j]
|
||||
pm.ReadmeFile = ©File
|
||||
}
|
||||
}
|
||||
|
||||
// Print file information
|
||||
outputBuilder.WriteString(fmt.Sprintf(" Files found: %d\n", len(details.Files)))
|
||||
// Deterministic README resolution: follow base_model tag if set.
|
||||
// Keep the raw (HTML-bearing) README around while we extract the
|
||||
// icon, then strip it down to a plain-text description for the
|
||||
// `description:` YAML field.
|
||||
readme, err := resolveReadme(client, m.ModelID, m.Tags)
|
||||
if err != nil {
|
||||
fmt.Printf(" Warning: failed to fetch README: %v\n", err)
|
||||
}
|
||||
pm.ReadmeContent = readme
|
||||
|
||||
if preferredModelFile != nil {
|
||||
outputBuilder.WriteString(fmt.Sprintf(" Preferred Model File: %s (SHA256: %s)\n",
|
||||
preferredModelFile.Path,
|
||||
preferredModelFile.SHA256))
|
||||
} else {
|
||||
outputBuilder.WriteString(fmt.Sprintf(" No model file found with quantization preferences: %v\n", quantizationPreferences))
|
||||
pm.License = licenseFromTags(m.Tags)
|
||||
pm.Tags = curatedTags(m.Tags)
|
||||
pm.Icon = extractModelIcon(pm)
|
||||
|
||||
if pm.ReadmeContent != "" {
|
||||
pm.ReadmeContent = extractDescription(pm.ReadmeContent)
|
||||
pm.ReadmeContentPreview = truncateString(pm.ReadmeContent, 200)
|
||||
}
|
||||
|
||||
if details.ReadmeFile != nil {
|
||||
outputBuilder.WriteString(fmt.Sprintf(" README File: %s\n", details.ReadmeFile.Path))
|
||||
|
||||
// Find and set readme file
|
||||
for _, file := range processedFiles {
|
||||
if file.IsReadme {
|
||||
processedModel.ReadmeFile = &file
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("Getting real readme for", model.ModelID, "waiting...")
|
||||
// Use agent to get the real readme and prepare the model description
|
||||
readmeContent, err := getRealReadme(context.Background(), model.ModelID)
|
||||
if err == nil {
|
||||
processedModel.ReadmeContent = readmeContent
|
||||
processedModel.ReadmeContentPreview = truncateString(readmeContent, 200)
|
||||
outputBuilder.WriteString(fmt.Sprintf(" README Content Preview: %s\n",
|
||||
processedModel.ReadmeContentPreview))
|
||||
} else {
|
||||
fmt.Printf(" Warning: Failed to get real readme: %v\n", err)
|
||||
}
|
||||
fmt.Println("Real readme got", readmeContent)
|
||||
|
||||
// Extract metadata (tags, license) from README using LLM
|
||||
fmt.Println("Extracting metadata for", model.ModelID, "waiting...")
|
||||
tags, license, err := extractModelMetadata(context.Background(), processedModel)
|
||||
if err == nil {
|
||||
processedModel.Tags = tags
|
||||
processedModel.License = license
|
||||
outputBuilder.WriteString(fmt.Sprintf(" Tags: %v\n", tags))
|
||||
outputBuilder.WriteString(fmt.Sprintf(" License: %s\n", license))
|
||||
} else {
|
||||
fmt.Printf(" Warning: Failed to extract metadata: %v\n", err)
|
||||
}
|
||||
|
||||
// Extract icon from README or use HuggingFace avatar
|
||||
icon := extractModelIcon(processedModel)
|
||||
if icon != "" {
|
||||
processedModel.Icon = icon
|
||||
outputBuilder.WriteString(fmt.Sprintf(" Icon: %s\n", icon))
|
||||
}
|
||||
// Get README content
|
||||
// readmeContent, err := client.GetReadmeContent(model.ModelID, details.ReadmeFile.Path)
|
||||
// if err == nil {
|
||||
// processedModel.ReadmeContent = readmeContent
|
||||
// processedModel.ReadmeContentPreview = truncateString(readmeContent, 200)
|
||||
// outputBuilder.WriteString(fmt.Sprintf(" README Content Preview: %s\n",
|
||||
// processedModel.ReadmeContentPreview))
|
||||
// }
|
||||
}
|
||||
|
||||
// Print all files with their checksums
|
||||
outputBuilder.WriteString(" All Files:\n")
|
||||
for _, file := range processedFiles {
|
||||
outputBuilder.WriteString(fmt.Sprintf(" - %s (%s, %d bytes", file.Path, file.FileType, file.Size))
|
||||
if file.SHA256 != "" {
|
||||
outputBuilder.WriteString(fmt.Sprintf(", SHA256: %s", file.SHA256))
|
||||
}
|
||||
outputBuilder.WriteString(")\n")
|
||||
}
|
||||
|
||||
outputBuilder.WriteString("\n")
|
||||
result.Models = append(result.Models, processedModel)
|
||||
fmt.Printf(" License: %s, Tags: %v, Icon: %s\n", pm.License, pm.Tags, pm.Icon)
|
||||
processed = append(processed, pm)
|
||||
}
|
||||
|
||||
result.FormattedOutput = outputBuilder.String()
|
||||
return result, nil
|
||||
if len(processed) == 0 {
|
||||
fmt.Println("No processable models after detail fetch.")
|
||||
writeSummary(AddedModelSummary{
|
||||
SearchTerm: searchTerm,
|
||||
TotalFound: totalFound,
|
||||
ModelsAdded: 0,
|
||||
Quantization: quantization,
|
||||
ProcessingTime: time.Since(startTime).String(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Phase 5: write YAML entries.
|
||||
var addedIDs, addedURLs []string
|
||||
for _, pm := range processed {
|
||||
addedIDs = append(addedIDs, pm.ModelID)
|
||||
addedURLs = append(addedURLs, "https://huggingface.co/"+pm.ModelID)
|
||||
}
|
||||
|
||||
fmt.Println("Generating YAML entries for selected models...")
|
||||
if err := generateYAMLForModels(context.Background(), processed, quantization); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error generating YAML entries: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
writeSummary(AddedModelSummary{
|
||||
SearchTerm: searchTerm,
|
||||
TotalFound: totalFound,
|
||||
ModelsAdded: len(addedIDs),
|
||||
AddedModelIDs: addedIDs,
|
||||
AddedModelURLs: addedURLs,
|
||||
Quantization: quantization,
|
||||
ProcessingTime: time.Since(startTime).String(),
|
||||
})
|
||||
}
|
||||
|
||||
func writeSummary(summary AddedModelSummary) {
|
||||
data, err := json.MarshalIndent(summary, "", " ")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marshaling summary: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := os.WriteFile("gallery-agent-summary.json", data, 0644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error writing summary file: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Println("Summary written to gallery-agent-summary.json")
|
||||
}
|
||||
|
||||
func truncateString(s string, maxLen int) string {
|
||||
@@ -381,3 +277,4 @@ func truncateString(s string, maxLen int) string {
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
|
||||
46
.github/gallery-agent/tools.go
vendored
46
.github/gallery-agent/tools.go
vendored
@@ -1,46 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
jsonschema "github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
// Get repository README from HF
|
||||
type HFReadmeTool struct {
|
||||
client *hfapi.Client
|
||||
}
|
||||
|
||||
func (s *HFReadmeTool) Execute(args map[string]any) (string, any, error) {
|
||||
q, ok := args["repository"].(string)
|
||||
if !ok {
|
||||
return "", nil, fmt.Errorf("no query")
|
||||
}
|
||||
readme, err := s.client.GetReadmeContent(q, "README.md")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return readme, nil, nil
|
||||
}
|
||||
|
||||
func (s *HFReadmeTool) Tool() openai.Tool {
|
||||
return openai.Tool{
|
||||
Type: openai.ToolTypeFunction,
|
||||
Function: &openai.FunctionDefinition{
|
||||
Name: "hf_readme",
|
||||
Description: "A tool to get the README content of a huggingface repository",
|
||||
Parameters: jsonschema.Definition{
|
||||
Type: jsonschema.Object,
|
||||
Properties: map[string]jsonschema.Definition{
|
||||
"repository": {
|
||||
Type: jsonschema.String,
|
||||
Description: "The huggingface repository to get the README content of",
|
||||
},
|
||||
},
|
||||
Required: []string{"repository"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
215
.github/workflows/backend.yml
vendored
215
.github/workflows/backend.yml
vendored
@@ -30,6 +30,7 @@ jobs:
|
||||
skip-drivers: ${{ matrix.skip-drivers }}
|
||||
context: ${{ matrix.context }}
|
||||
ubuntu-version: ${{ matrix.ubuntu-version }}
|
||||
amdgpu-targets: ${{ matrix.amdgpu-targets }}
|
||||
secrets:
|
||||
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
@@ -53,6 +54,32 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-vllm'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'true'
|
||||
backend: "vllm"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-sglang'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'true'
|
||||
backend: "sglang"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -92,6 +119,25 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# tinygrad ships a single image — its CPU device uses bundled
|
||||
# libLLVM, and its CUDA / HIP / Metal devices dlopen the host
|
||||
# driver libraries at runtime via tinygrad's ctypes autogen
|
||||
# wrappers. There is no toolkit-version split because tinygrad
|
||||
# generates kernels itself (PTX renderer for CUDA) and never
|
||||
# links against cuDNN/cuBLAS/torch.
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-tinygrad'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'true'
|
||||
backend: "tinygrad"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -340,6 +386,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-turboquant'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -366,6 +425,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-12-sglang'
|
||||
runs-on: 'arc-runner-set'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sglang"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "8"
|
||||
@@ -783,6 +855,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-nvidia-cuda-13-turboquant'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -796,6 +881,19 @@ jobs:
|
||||
backend: "llama-cpp"
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-cuda-13-arm64-turboquant'
|
||||
base-image: "ubuntu:24.04"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
ubuntu-version: '2404'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "13"
|
||||
cuda-minor-version: "0"
|
||||
@@ -1317,6 +1415,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-turboquant'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1343,6 +1454,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-sglang'
|
||||
runs-on: 'arc-runner-set'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "sglang"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1500,19 +1624,6 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-rocm-hipblas-whisperx'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "rocm/dev-ubuntu-24.04:7.2.1"
|
||||
skip-drivers: 'false'
|
||||
backend: "whisperx"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'hipblas'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1553,6 +1664,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f32'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f32-turboquant'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1566,6 +1690,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'sycl_f16'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sycl-f16-turboquant'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1579,6 +1716,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-intel-sglang'
|
||||
runs-on: 'arc-runner-set'
|
||||
base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "sglang"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'intel'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1945,6 +2095,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-turboquant'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1971,6 +2134,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'cublas'
|
||||
cuda-major-version: "12"
|
||||
cuda-minor-version: "0"
|
||||
platforms: 'linux/arm64'
|
||||
skip-drivers: 'false'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-nvidia-l4t-arm64-turboquant'
|
||||
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||
runs-on: 'ubuntu-24.04-arm'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2204'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -1984,6 +2160,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.llama-cpp"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: 'vulkan'
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-gpu-vulkan-turboquant'
|
||||
runs-on: 'bigger-runner'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'false'
|
||||
backend: "turboquant"
|
||||
dockerfile: "./backend/Dockerfile.turboquant"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
# Stablediffusion-ggml
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
|
||||
7
.github/workflows/backend_build.yml
vendored
7
.github/workflows/backend_build.yml
vendored
@@ -58,6 +58,11 @@ on:
|
||||
required: false
|
||||
default: '2204'
|
||||
type: string
|
||||
amdgpu-targets:
|
||||
description: 'AMD GPU targets for ROCm/HIP builds'
|
||||
required: false
|
||||
default: 'gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1102,gfx1151,gfx1200,gfx1201'
|
||||
type: string
|
||||
secrets:
|
||||
dockerUsername:
|
||||
required: false
|
||||
@@ -214,6 +219,7 @@ jobs:
|
||||
BASE_IMAGE=${{ inputs.base-image }}
|
||||
BACKEND=${{ inputs.backend }}
|
||||
UBUNTU_VERSION=${{ inputs.ubuntu-version }}
|
||||
AMDGPU_TARGETS=${{ inputs.amdgpu-targets }}
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.dockerfile }}
|
||||
cache-from: type=gha
|
||||
@@ -235,6 +241,7 @@ jobs:
|
||||
BASE_IMAGE=${{ inputs.base-image }}
|
||||
BACKEND=${{ inputs.backend }}
|
||||
UBUNTU_VERSION=${{ inputs.ubuntu-version }}
|
||||
AMDGPU_TARGETS=${{ inputs.amdgpu-targets }}
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.dockerfile }}
|
||||
cache-from: type=gha
|
||||
|
||||
4
.github/workflows/bump_deps.yaml
vendored
4
.github/workflows/bump_deps.yaml
vendored
@@ -18,6 +18,10 @@ jobs:
|
||||
variable: "IK_LLAMA_VERSION"
|
||||
branch: "main"
|
||||
file: "backend/cpp/ik-llama-cpp/Makefile"
|
||||
- repository: "TheTom/llama-cpp-turboquant"
|
||||
variable: "TURBOQUANT_VERSION"
|
||||
branch: "feature/turboquant-kv-cache"
|
||||
file: "backend/cpp/turboquant/Makefile"
|
||||
- repository: "ggml-org/whisper.cpp"
|
||||
variable: "WHISPER_CPP_VERSION"
|
||||
branch: "master"
|
||||
|
||||
99
.github/workflows/gallery-agent.yaml
vendored
99
.github/workflows/gallery-agent.yaml
vendored
@@ -48,21 +48,88 @@ jobs:
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
|
||||
PATH="$PATH:$HOME/go/bin" make protogen-go
|
||||
- uses: mudler/localai-github-action@v1.1
|
||||
with:
|
||||
model: 'https://huggingface.co/unsloth/Qwen3.5-2B-GGUF'
|
||||
- name: Process gallery-agent PR commands
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.UPDATE_BOT_TOKEN }}
|
||||
REPO: ${{ github.repository }}
|
||||
SEARCH: 'gallery agent in:title'
|
||||
run: |
|
||||
# Walk gallery-agent PRs and act on maintainer comments:
|
||||
# /gallery-agent blacklist → label `gallery-agent/blacklisted` + close (never repropose)
|
||||
# /gallery-agent recreate → close without label (next run may repropose)
|
||||
# Only comments from OWNER / MEMBER / COLLABORATOR are honored so
|
||||
# random users can't drive the bot.
|
||||
#
|
||||
# We scan both open PRs AND recently-closed PRs that don't already
|
||||
# carry the blacklist label. This covers the common flow where a
|
||||
# maintainer writes /gallery-agent blacklist and immediately clicks
|
||||
# Close — without this, the next scheduled run wouldn't see the
|
||||
# command (PR is already closed) and would repropose the model.
|
||||
gh label create gallery-agent/blacklisted \
|
||||
--repo "$REPO" --color ededed \
|
||||
--description "gallery-agent must not repropose this model" 2>/dev/null || true
|
||||
|
||||
prs_open=$(gh pr list --repo "$REPO" --state open --search "$SEARCH" \
|
||||
--json number --jq '.[].number')
|
||||
# Closed PRs from the last 14 days that don't yet have the blacklist label.
|
||||
# Bounded window keeps the scan cheap while covering late-applied commands.
|
||||
since=$(date -u -d '14 days ago' +%Y-%m-%d)
|
||||
prs_closed=$(gh pr list --repo "$REPO" --state closed \
|
||||
--search "$SEARCH closed:>=$since -label:gallery-agent/blacklisted" \
|
||||
--json number --jq '.[].number')
|
||||
prs=$(printf '%s\n%s\n' "$prs_open" "$prs_closed" | sort -u | sed '/^$/d')
|
||||
for pr in $prs; do
|
||||
state=$(gh pr view "$pr" --repo "$REPO" --json state --jq '.state')
|
||||
cmds=$(gh pr view "$pr" --repo "$REPO" --json comments \
|
||||
--jq '.comments[] | select(.authorAssociation=="OWNER" or .authorAssociation=="MEMBER" or .authorAssociation=="COLLABORATOR") | .body')
|
||||
if echo "$cmds" | grep -qE '(^|[[:space:]])/gallery-agent[[:space:]]+blacklist([[:space:]]|$)'; then
|
||||
echo "PR #$pr: blacklist command found (state=$state)"
|
||||
gh pr edit "$pr" --repo "$REPO" --add-label gallery-agent/blacklisted || true
|
||||
if [ "$state" = "OPEN" ]; then
|
||||
gh pr close "$pr" --repo "$REPO" --comment "Blacklisted via \`/gallery-agent blacklist\`. This model will not be reproposed." || true
|
||||
fi
|
||||
elif [ "$state" = "OPEN" ] && echo "$cmds" | grep -qE '(^|[[:space:]])/gallery-agent[[:space:]]+recreate([[:space:]]|$)'; then
|
||||
echo "PR #$pr: recreate command found"
|
||||
gh pr close "$pr" --repo "$REPO" --comment "Closed via \`/gallery-agent recreate\`. The next scheduled run will propose this model again." || true
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Collect skip URLs for the gallery agent
|
||||
id: open_prs
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
REPO: ${{ github.repository }}
|
||||
SEARCH: 'gallery agent in:title'
|
||||
run: |
|
||||
# Skip set =
|
||||
# URLs from any open gallery-agent PR (avoid duplicate PRs for the same model while one is pending)
|
||||
# + URLs from closed PRs carrying the `gallery-agent/blacklisted` label (hard blacklist)
|
||||
# Plain-closed PRs without the label are ignored — closing a PR is
|
||||
# not by itself a "never propose again" signal; maintainers must
|
||||
# opt in via the /gallery-agent blacklist comment command.
|
||||
urls_open=$(gh pr list --repo "$REPO" --state open --search "$SEARCH" \
|
||||
--json body --jq '[.[].body] | join("\n")' \
|
||||
| grep -oE 'https://huggingface\.co/[^ )]+' || true)
|
||||
urls_blacklist=$(gh pr list --repo "$REPO" --state closed --search "$SEARCH" \
|
||||
--label gallery-agent/blacklisted \
|
||||
--json body --jq '[.[].body] | join("\n")' \
|
||||
| grep -oE 'https://huggingface\.co/[^ )]+' || true)
|
||||
urls=$(printf '%s\n%s\n' "$urls_open" "$urls_blacklist" | sort -u | sed '/^$/d')
|
||||
echo "Skip URLs:"
|
||||
echo "$urls"
|
||||
{
|
||||
echo "urls<<EOF"
|
||||
echo "$urls"
|
||||
echo "EOF"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Run gallery agent
|
||||
env:
|
||||
#OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||
OPENAI_MODEL: Qwen3.5-2B-GGUF
|
||||
OPENAI_BASE_URL: "http://localhost:8080"
|
||||
OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
|
||||
#OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }}
|
||||
SEARCH_TERM: ${{ github.event.inputs.search_term || 'GGUF' }}
|
||||
LIMIT: ${{ github.event.inputs.limit || '15' }}
|
||||
QUANTIZATION: ${{ github.event.inputs.quantization || 'Q4_K_M' }}
|
||||
MAX_MODELS: ${{ github.event.inputs.max_models || '1' }}
|
||||
EXTRA_SKIP_URLS: ${{ steps.open_prs.outputs.urls }}
|
||||
run: |
|
||||
export GALLERY_INDEX_PATH=$PWD/gallery/index.yaml
|
||||
go run ./.github/gallery-agent
|
||||
@@ -124,7 +191,21 @@ jobs:
|
||||
|
||||
**Added Models:**
|
||||
${{ steps.read_summary.outputs.added_models || '- No models added' }}
|
||||
|
||||
|
||||
### Bot commands
|
||||
|
||||
Maintainers (owner / member / collaborator) can control this PR
|
||||
by leaving a comment with one of:
|
||||
|
||||
- `/gallery-agent recreate` — close this PR; the next scheduled
|
||||
run will propose this model again (useful if the entry needs
|
||||
to be regenerated with fresh metadata).
|
||||
- `/gallery-agent blacklist` — close this PR and permanently
|
||||
prevent the gallery agent from ever reproposing this model.
|
||||
|
||||
Plain "Close" (without a command) is treated as a no-op: the
|
||||
model may be reproposed by a future run.
|
||||
|
||||
**Workflow Details:**
|
||||
- Triggered by: `${{ github.event_name }}`
|
||||
- Run ID: `${{ github.run_id }}`
|
||||
|
||||
2
.github/workflows/gh-pages.yml
vendored
2
.github/workflows/gh-pages.yml
vendored
@@ -59,7 +59,7 @@ jobs:
|
||||
hugo --minify --baseURL "${{ steps.pages.outputs.base_url }}/"
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-pages-artifact@v4
|
||||
uses: actions/upload-pages-artifact@v5
|
||||
with:
|
||||
path: docs/public
|
||||
|
||||
|
||||
4
.github/workflows/release.yaml
vendored
4
.github/workflows/release.yaml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
run: |
|
||||
make build-launcher-darwin
|
||||
- name: Upload DMG to Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@v3
|
||||
with:
|
||||
files: ./dist/LocalAI.dmg
|
||||
launcher-build-linux:
|
||||
@@ -59,6 +59,6 @@ jobs:
|
||||
sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev
|
||||
make build-launcher-linux
|
||||
- name: Upload Linux launcher artifacts
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@v3
|
||||
with:
|
||||
files: ./local-ai-launcher-linux.tar.xz
|
||||
|
||||
131
.github/workflows/test-extra.yml
vendored
131
.github/workflows/test-extra.yml
vendored
@@ -31,6 +31,9 @@ jobs:
|
||||
llama-cpp-quantization: ${{ steps.detect.outputs.llama-cpp-quantization }}
|
||||
llama-cpp: ${{ steps.detect.outputs.llama-cpp }}
|
||||
ik-llama-cpp: ${{ steps.detect.outputs.ik-llama-cpp }}
|
||||
turboquant: ${{ steps.detect.outputs.turboquant }}
|
||||
vllm: ${{ steps.detect.outputs.vllm }}
|
||||
sglang: ${{ steps.detect.outputs.sglang }}
|
||||
acestep-cpp: ${{ steps.detect.outputs.acestep-cpp }}
|
||||
qwen3-tts-cpp: ${{ steps.detect.outputs.qwen3-tts-cpp }}
|
||||
voxtral: ${{ steps.detect.outputs.voxtral }}
|
||||
@@ -484,6 +487,23 @@ jobs:
|
||||
- name: Build llama-cpp backend image and run gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-llama-cpp
|
||||
tests-llama-cpp-grpc-transcription:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.llama-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25.4'
|
||||
- name: Build llama-cpp backend image and run audio transcription gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-llama-cpp-transcription
|
||||
tests-ik-llama-cpp-grpc:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.ik-llama-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
@@ -501,6 +521,117 @@ jobs:
|
||||
- name: Build ik-llama-cpp backend image and run gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-ik-llama-cpp
|
||||
tests-turboquant-grpc:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.turboquant == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25.4'
|
||||
# Exercises the turboquant (llama.cpp fork) backend with KV-cache
|
||||
# quantization enabled. The convenience target sets
|
||||
# BACKEND_TEST_CACHE_TYPE_K / _V=q8_0, which are plumbed into the
|
||||
# ModelOptions.CacheTypeKey/Value gRPC fields. LoadModel-success +
|
||||
# backend stdout/stderr (captured by the Ginkgo suite) prove the
|
||||
# cache-type config path reaches the fork's KV-cache init.
|
||||
- name: Build turboquant backend image and run gRPC e2e tests
|
||||
run: |
|
||||
make test-extra-backend-turboquant
|
||||
# tests-vllm-grpc is currently disabled in CI.
|
||||
#
|
||||
# The prebuilt vllm CPU wheel is compiled with AVX-512 VNNI/BF16
|
||||
# instructions, and neither ubuntu-latest nor the bigger-runner pool
|
||||
# offers a stable CPU baseline that supports them — runners come
|
||||
# back with different hardware between runs and SIGILL on import of
|
||||
# vllm.model_executor.models.registry. Compiling vllm from source
|
||||
# via FROM_SOURCE=true works on any CPU but takes 30-50 minutes per
|
||||
# run, which is too slow for a smoke test.
|
||||
#
|
||||
# The test itself (tests/e2e-backends + make test-extra-backend-vllm)
|
||||
# is fully working and validated locally on a host with the right
|
||||
# SIMD baseline. Run it manually with:
|
||||
#
|
||||
# make test-extra-backend-vllm
|
||||
#
|
||||
# Re-enable this job once we have a self-hosted runner label with
|
||||
# guaranteed AVX-512 VNNI/BF16 support, or once the vllm project
|
||||
# publishes a CPU wheel with a wider baseline.
|
||||
#
|
||||
# tests-vllm-grpc:
|
||||
# needs: detect-changes
|
||||
# if: needs.detect-changes.outputs.vllm == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
# runs-on: bigger-runner
|
||||
# timeout-minutes: 90
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# uses: actions/checkout@v6
|
||||
# with:
|
||||
# submodules: true
|
||||
# - name: Dependencies
|
||||
# run: |
|
||||
# sudo apt-get update
|
||||
# sudo apt-get install -y --no-install-recommends \
|
||||
# make build-essential curl unzip ca-certificates git tar
|
||||
# - name: Setup Go
|
||||
# uses: actions/setup-go@v5
|
||||
# with:
|
||||
# go-version: '1.25.4'
|
||||
# - name: Free disk space
|
||||
# run: |
|
||||
# sudo rm -rf /usr/share/dotnet /opt/ghc /usr/local/lib/android /opt/hostedtoolcache/CodeQL || true
|
||||
# df -h
|
||||
# - name: Build vllm (cpu) backend image and run gRPC e2e tests
|
||||
# run: |
|
||||
# make test-extra-backend-vllm
|
||||
# tests-sglang-grpc is currently disabled in CI for the same reason as
|
||||
# tests-vllm-grpc: sglang's CPU kernel (sgl-kernel) uses __m512 AVX-512
|
||||
# intrinsics unconditionally in shm.cpp, so the from-source build
|
||||
# requires `-march=sapphirerapids` (already set in install.sh) and the
|
||||
# resulting binary SIGILLs at import on CPUs without AVX-512 VNNI/BF16.
|
||||
# The ubuntu-latest runner pool does not guarantee that ISA baseline.
|
||||
#
|
||||
# The test itself (tests/e2e-backends + make test-extra-backend-sglang)
|
||||
# is fully working and validated locally on a host with the right
|
||||
# SIMD baseline. Run it manually with:
|
||||
#
|
||||
# make test-extra-backend-sglang
|
||||
#
|
||||
# Re-enable this job once we have a self-hosted runner label with
|
||||
# guaranteed AVX-512 VNNI/BF16 support.
|
||||
#
|
||||
# tests-sglang-grpc:
|
||||
# needs: detect-changes
|
||||
# if: needs.detect-changes.outputs.sglang == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
# runs-on: bigger-runner
|
||||
# timeout-minutes: 90
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# uses: actions/checkout@v6
|
||||
# with:
|
||||
# submodules: true
|
||||
# - name: Dependencies
|
||||
# run: |
|
||||
# sudo apt-get update
|
||||
# sudo apt-get install -y --no-install-recommends \
|
||||
# make build-essential curl unzip ca-certificates git tar
|
||||
# - name: Setup Go
|
||||
# uses: actions/setup-go@v5
|
||||
# with:
|
||||
# go-version: '1.25.4'
|
||||
# - name: Free disk space
|
||||
# run: |
|
||||
# sudo rm -rf /usr/share/dotnet /opt/ghc /usr/local/lib/android /opt/hostedtoolcache/CodeQL || true
|
||||
# df -h
|
||||
# - name: Build sglang (cpu) backend image and run gRPC e2e tests
|
||||
# run: |
|
||||
# make test-extra-backend-sglang
|
||||
tests-acestep-cpp:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.acestep-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true'
|
||||
|
||||
15
AGENTS.md
15
AGENTS.md
@@ -1,15 +1,28 @@
|
||||
# LocalAI Agent Instructions
|
||||
|
||||
This file is an index to detailed topic guides in the `.agents/` directory. Read the relevant file(s) for the task at hand — you don't need to load all of them.
|
||||
This file is the entry point for AI coding assistants (Claude Code, Cursor, Copilot, Codex, Aider, etc.) working on LocalAI. It is an index to detailed topic guides in the `.agents/` directory. Read the relevant file(s) for the task at hand — you don't need to load all of them.
|
||||
|
||||
Human contributors: see [CONTRIBUTING.md](CONTRIBUTING.md) for the development workflow.
|
||||
|
||||
## Policy for AI-Assisted Contributions
|
||||
|
||||
LocalAI follows the Linux kernel project's [guidelines for AI coding assistants](https://docs.kernel.org/process/coding-assistants.html). Before submitting AI-assisted code, read [.agents/ai-coding-assistants.md](.agents/ai-coding-assistants.md). Key rules:
|
||||
|
||||
- **No `Signed-off-by` from AI.** Only the human submitter may sign off on the Developer Certificate of Origin.
|
||||
- **No `Co-Authored-By: <AI>` trailers.** The human contributor owns the change.
|
||||
- **Use an `Assisted-by:` trailer** to attribute AI involvement. Format: `Assisted-by: AGENT_NAME:MODEL_VERSION [TOOL1] [TOOL2]`.
|
||||
- **The human submitter is responsible** for reviewing, testing, and understanding every line of generated code.
|
||||
|
||||
## Topics
|
||||
|
||||
| File | When to read |
|
||||
|------|-------------|
|
||||
| [.agents/ai-coding-assistants.md](.agents/ai-coding-assistants.md) | Policy for AI-assisted contributions — licensing, DCO, attribution |
|
||||
| [.agents/building-and-testing.md](.agents/building-and-testing.md) | Building the project, running tests, Docker builds for specific platforms |
|
||||
| [.agents/adding-backends.md](.agents/adding-backends.md) | Adding a new backend (Python, Go, or C++) — full step-by-step checklist |
|
||||
| [.agents/coding-style.md](.agents/coding-style.md) | Code style, editorconfig, logging, documentation conventions |
|
||||
| [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing |
|
||||
| [.agents/vllm-backend.md](.agents/vllm-backend.md) | Working on the vLLM / vLLM-omni backends — native parsers, ChatDelta, CPU build, libnuma packaging, backend hooks |
|
||||
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
||||
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
||||
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||
|
||||
@@ -13,6 +13,7 @@ Thank you for your interest in contributing to LocalAI! We appreciate your time
|
||||
- [Development Workflow](#development-workflow)
|
||||
- [Creating a Pull Request (PR)](#creating-a-pull-request-pr)
|
||||
- [Coding Guidelines](#coding-guidelines)
|
||||
- [AI Coding Assistants](#ai-coding-assistants)
|
||||
- [Testing](#testing)
|
||||
- [Documentation](#documentation)
|
||||
- [Community and Communication](#community-and-communication)
|
||||
@@ -185,7 +186,7 @@ Before jumping into a PR for a massive feature or big change, it is preferred to
|
||||
|
||||
This project uses an [`.editorconfig`](.editorconfig) file to define formatting standards (indentation, line endings, charset, etc.). Please configure your editor to respect it.
|
||||
|
||||
For AI-assisted development, see [`CLAUDE.md`](CLAUDE.md) for agent-specific guidelines including build instructions and backend architecture details.
|
||||
For AI-assisted development, see [`AGENTS.md`](AGENTS.md) (or the equivalent [`CLAUDE.md`](CLAUDE.md) symlink) for agent-specific guidelines including build instructions and backend architecture details. Contributions produced with AI assistance must follow the rules in the [AI Coding Assistants](#ai-coding-assistants) section below.
|
||||
|
||||
### General Principles
|
||||
|
||||
@@ -211,6 +212,26 @@ For AI-assisted development, see [`CLAUDE.md`](CLAUDE.md) for agent-specific gui
|
||||
- Reviewers will check for correctness, test coverage, adherence to these guidelines, and clarity of intent.
|
||||
- Be responsive to review feedback and keep discussions constructive.
|
||||
|
||||
## AI Coding Assistants
|
||||
|
||||
LocalAI follows the **same guidelines as the Linux kernel project** for AI-assisted contributions: <https://docs.kernel.org/process/coding-assistants.html>.
|
||||
|
||||
The full policy for this repository lives in [`.agents/ai-coding-assistants.md`](.agents/ai-coding-assistants.md). Summary:
|
||||
|
||||
- **AI agents MUST NOT add `Signed-off-by` tags.** Only humans can certify the Developer Certificate of Origin.
|
||||
- **AI agents MUST NOT add `Co-Authored-By` trailers** attributing themselves as co-authors.
|
||||
- **Attribute AI involvement with an `Assisted-by` trailer** in the commit message:
|
||||
|
||||
```
|
||||
Assisted-by: AGENT_NAME:MODEL_VERSION [TOOL1] [TOOL2]
|
||||
```
|
||||
|
||||
Example: `Assisted-by: Claude:claude-opus-4-7 golangci-lint`
|
||||
|
||||
Basic development tools (git, go, make, editors) should not be listed.
|
||||
- **The human submitter is responsible** for reviewing, testing, and fully understanding every line of AI-generated code — including verifying that any referenced APIs, flags, or file paths actually exist in the tree.
|
||||
- Contributions must remain compatible with LocalAI's **MIT License**.
|
||||
|
||||
## Testing
|
||||
|
||||
All new features and bug fixes should include test coverage. The project uses [Ginkgo](https://onsi.github.io/ginkgo/) as its test framework.
|
||||
|
||||
152
Makefile
152
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/tinygrad
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -419,6 +419,7 @@ prepare-test-extra: protogen-python
|
||||
$(MAKE) -C backend/python/chatterbox
|
||||
$(MAKE) -C backend/python/vllm
|
||||
$(MAKE) -C backend/python/vllm-omni
|
||||
$(MAKE) -C backend/python/sglang
|
||||
$(MAKE) -C backend/python/vibevoice
|
||||
$(MAKE) -C backend/python/moonshine
|
||||
$(MAKE) -C backend/python/pocket-tts
|
||||
@@ -432,6 +433,7 @@ prepare-test-extra: protogen-python
|
||||
$(MAKE) -C backend/python/whisperx
|
||||
$(MAKE) -C backend/python/ace-step
|
||||
$(MAKE) -C backend/python/trl
|
||||
$(MAKE) -C backend/python/tinygrad
|
||||
$(MAKE) -C backend/rust/kokoros kokoros-grpc
|
||||
|
||||
test-extra: prepare-test-extra
|
||||
@@ -454,6 +456,7 @@ test-extra: prepare-test-extra
|
||||
$(MAKE) -C backend/python/whisperx test
|
||||
$(MAKE) -C backend/python/ace-step test
|
||||
$(MAKE) -C backend/python/trl test
|
||||
$(MAKE) -C backend/python/tinygrad test
|
||||
$(MAKE) -C backend/rust/kokoros test
|
||||
|
||||
##
|
||||
@@ -466,8 +469,14 @@ test-extra: prepare-test-extra
|
||||
## BACKEND_IMAGE Required. Docker image to test, e.g. local-ai-backend:llama-cpp.
|
||||
## BACKEND_TEST_MODEL_URL URL of a model file to download and load.
|
||||
## BACKEND_TEST_MODEL_FILE Path to an already-downloaded model (skips download).
|
||||
## BACKEND_TEST_MODEL_NAME HuggingFace repo id (e.g. Qwen/Qwen2.5-0.5B-Instruct).
|
||||
## Use this instead of MODEL_URL for backends that
|
||||
## resolve HF model ids natively (vllm, vllm-omni).
|
||||
## BACKEND_TEST_CAPS Comma-separated capabilities, default "health,load,predict,stream".
|
||||
## Adds "tools" to exercise ChatDelta tool call extraction.
|
||||
## BACKEND_TEST_PROMPT Override the prompt used in predict/stream specs.
|
||||
## BACKEND_TEST_OPTIONS Comma-separated Options[] entries forwarded to LoadModel,
|
||||
## e.g. "tool_parser:hermes,reasoning_parser:qwen3".
|
||||
##
|
||||
## Direct usage (image already built, no docker-build-* dependency):
|
||||
##
|
||||
@@ -486,9 +495,19 @@ test-extra-backend: protogen-go
|
||||
BACKEND_IMAGE="$$BACKEND_IMAGE" \
|
||||
BACKEND_TEST_MODEL_URL="$${BACKEND_TEST_MODEL_URL:-$(BACKEND_TEST_MODEL_URL)}" \
|
||||
BACKEND_TEST_MODEL_FILE="$$BACKEND_TEST_MODEL_FILE" \
|
||||
BACKEND_TEST_MODEL_NAME="$$BACKEND_TEST_MODEL_NAME" \
|
||||
BACKEND_TEST_MMPROJ_URL="$$BACKEND_TEST_MMPROJ_URL" \
|
||||
BACKEND_TEST_MMPROJ_FILE="$$BACKEND_TEST_MMPROJ_FILE" \
|
||||
BACKEND_TEST_AUDIO_URL="$$BACKEND_TEST_AUDIO_URL" \
|
||||
BACKEND_TEST_AUDIO_FILE="$$BACKEND_TEST_AUDIO_FILE" \
|
||||
BACKEND_TEST_CAPS="$$BACKEND_TEST_CAPS" \
|
||||
BACKEND_TEST_PROMPT="$$BACKEND_TEST_PROMPT" \
|
||||
go test -v -timeout 15m ./tests/e2e-backends/...
|
||||
BACKEND_TEST_OPTIONS="$$BACKEND_TEST_OPTIONS" \
|
||||
BACKEND_TEST_TOOL_PROMPT="$$BACKEND_TEST_TOOL_PROMPT" \
|
||||
BACKEND_TEST_TOOL_NAME="$$BACKEND_TEST_TOOL_NAME" \
|
||||
BACKEND_TEST_CACHE_TYPE_K="$$BACKEND_TEST_CACHE_TYPE_K" \
|
||||
BACKEND_TEST_CACHE_TYPE_V="$$BACKEND_TEST_CACHE_TYPE_V" \
|
||||
go test -v -timeout 30m ./tests/e2e-backends/...
|
||||
|
||||
## Convenience wrappers: build the image, then exercise it.
|
||||
test-extra-backend-llama-cpp: docker-build-llama-cpp
|
||||
@@ -497,6 +516,120 @@ test-extra-backend-llama-cpp: docker-build-llama-cpp
|
||||
test-extra-backend-ik-llama-cpp: docker-build-ik-llama-cpp
|
||||
BACKEND_IMAGE=local-ai-backend:ik-llama-cpp $(MAKE) test-extra-backend
|
||||
|
||||
## turboquant: exercises the llama.cpp-fork backend with the fork's
|
||||
## *TurboQuant-specific* KV-cache types (turbo3 for both K and V). turbo3
|
||||
## is what makes this backend distinct from stock llama-cpp — picking q8_0
|
||||
## here would only test the standard llama.cpp code path that the upstream
|
||||
## llama-cpp backend already covers. The fork auto-enables flash_attention
|
||||
## when turbo3/turbo4 are active, so we don't need to set it explicitly.
|
||||
test-extra-backend-turboquant: docker-build-turboquant
|
||||
BACKEND_IMAGE=local-ai-backend:turboquant \
|
||||
BACKEND_TEST_CACHE_TYPE_K=q8_0 \
|
||||
BACKEND_TEST_CACHE_TYPE_V=turbo3 \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## Audio transcription wrapper for the llama-cpp backend.
|
||||
## Drives the new AudioTranscription / AudioTranscriptionStream RPCs against
|
||||
## ggml-org/Qwen3-ASR-0.6B-GGUF (a small ASR model that requires its mmproj
|
||||
## audio encoder companion). The audio fixture is a short public-domain
|
||||
## "jfk.wav" clip ggml-org bundles with whisper.cpp's CI assets.
|
||||
test-extra-backend-llama-cpp-transcription: docker-build-llama-cpp
|
||||
BACKEND_IMAGE=local-ai-backend:llama-cpp \
|
||||
BACKEND_TEST_MODEL_URL=https://huggingface.co/ggml-org/Qwen3-ASR-0.6B-GGUF/resolve/main/Qwen3-ASR-0.6B-Q8_0.gguf \
|
||||
BACKEND_TEST_MMPROJ_URL=https://huggingface.co/ggml-org/Qwen3-ASR-0.6B-GGUF/resolve/main/mmproj-Qwen3-ASR-0.6B-Q8_0.gguf \
|
||||
BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## vllm is resolved from a HuggingFace model id (no file download) and
|
||||
## exercises Predict + streaming + tool-call extraction via the hermes parser.
|
||||
## Requires a host CPU with the SIMD instructions the prebuilt vllm CPU
|
||||
## wheel was compiled against (AVX-512 VNNI/BF16); older CPUs will SIGILL
|
||||
## on import — on CI this means using the bigger-runner label.
|
||||
test-extra-backend-vllm: docker-build-vllm
|
||||
BACKEND_IMAGE=local-ai-backend:vllm \
|
||||
BACKEND_TEST_MODEL_NAME=Qwen/Qwen2.5-0.5B-Instruct \
|
||||
BACKEND_TEST_CAPS=health,load,predict,stream,tools \
|
||||
BACKEND_TEST_OPTIONS=tool_parser:hermes \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## tinygrad mirrors the vllm target (same model, same caps, same parser) so
|
||||
## the two backends are directly comparable. The LLM path covers Predict,
|
||||
## streaming and native tool-call extraction. Companion targets below cover
|
||||
## embeddings, Stable Diffusion and Whisper — run them individually or via
|
||||
## the `test-extra-backend-tinygrad-all` aggregate.
|
||||
test-extra-backend-tinygrad: docker-build-tinygrad
|
||||
BACKEND_IMAGE=local-ai-backend:tinygrad \
|
||||
BACKEND_TEST_MODEL_NAME=Qwen/Qwen3-0.6B \
|
||||
BACKEND_TEST_CAPS=health,load,predict,stream,tools \
|
||||
BACKEND_TEST_OPTIONS=tool_parser:hermes \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## tinygrad — embeddings via LLM last-hidden-state pooling. Reuses the same
|
||||
## Qwen3-0.6B as the chat target so we don't need a separate BERT vendor;
|
||||
## the Embedding RPC mean-pools and L2-normalizes the last-layer hidden
|
||||
## state.
|
||||
test-extra-backend-tinygrad-embeddings: docker-build-tinygrad
|
||||
BACKEND_IMAGE=local-ai-backend:tinygrad \
|
||||
BACKEND_TEST_MODEL_NAME=Qwen/Qwen3-0.6B \
|
||||
BACKEND_TEST_CAPS=health,load,embeddings \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## tinygrad — Stable Diffusion 1.5. The original CompVis/runwayml repos have
|
||||
## been gated, so we use the community-maintained mirror at
|
||||
## stable-diffusion-v1-5/stable-diffusion-v1-5 with the EMA-only pruned
|
||||
## checkpoint (~4.3GB). Step count is kept low (4) so a CPU-only run finishes
|
||||
## in a few minutes; bump BACKEND_TEST_IMAGE_STEPS for higher quality.
|
||||
test-extra-backend-tinygrad-sd: docker-build-tinygrad
|
||||
BACKEND_IMAGE=local-ai-backend:tinygrad \
|
||||
BACKEND_TEST_MODEL_NAME=stable-diffusion-v1-5/stable-diffusion-v1-5 \
|
||||
BACKEND_TEST_CAPS=health,load,image \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
## tinygrad — Whisper. Loads OpenAI's tiny.en checkpoint (smallest at ~75MB)
|
||||
## from the original azure CDN through tinygrad's `fetch` helper, and
|
||||
## transcribes the canonical jfk.wav fixture from whisper.cpp's CI samples.
|
||||
## Exercises both AudioTranscription and AudioTranscriptionStream.
|
||||
test-extra-backend-tinygrad-whisper: docker-build-tinygrad
|
||||
BACKEND_IMAGE=local-ai-backend:tinygrad \
|
||||
BACKEND_TEST_MODEL_NAME=openai/whisper-tiny.en \
|
||||
BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
|
||||
BACKEND_TEST_CAPS=health,load,transcription \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
test-extra-backend-tinygrad-all: \
|
||||
test-extra-backend-tinygrad \
|
||||
test-extra-backend-tinygrad-embeddings \
|
||||
test-extra-backend-tinygrad-sd \
|
||||
test-extra-backend-tinygrad-whisper
|
||||
|
||||
## sglang mirrors the vllm setup: HuggingFace model id, same tiny Qwen,
|
||||
## tool-call extraction via sglang's native qwen parser. CPU builds use
|
||||
## sglang's upstream pyproject_cpu.toml recipe (see backend/python/sglang/install.sh).
|
||||
test-extra-backend-sglang: docker-build-sglang
|
||||
BACKEND_IMAGE=local-ai-backend:sglang \
|
||||
BACKEND_TEST_MODEL_NAME=Qwen/Qwen2.5-0.5B-Instruct \
|
||||
BACKEND_TEST_CAPS=health,load,predict,stream,tools \
|
||||
BACKEND_TEST_OPTIONS=tool_parser:qwen \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
|
||||
## mlx is Apple-Silicon-first — the MLX backend auto-detects the right tool
|
||||
## parser from the chat template, so no tool_parser: option is needed (it
|
||||
## would be ignored at runtime). Run this on macOS / arm64 with Metal; the
|
||||
## Linux/CPU mlx variant is untested in CI.
|
||||
test-extra-backend-mlx: docker-build-mlx
|
||||
BACKEND_IMAGE=local-ai-backend:mlx \
|
||||
BACKEND_TEST_MODEL_NAME=mlx-community/Qwen2.5-0.5B-Instruct-4bit \
|
||||
BACKEND_TEST_CAPS=health,load,predict,stream,tools \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
test-extra-backend-mlx-vlm: docker-build-mlx-vlm
|
||||
BACKEND_IMAGE=local-ai-backend:mlx-vlm \
|
||||
BACKEND_TEST_MODEL_NAME=mlx-community/Qwen2.5-0.5B-Instruct-4bit \
|
||||
BACKEND_TEST_CAPS=health,load,predict,stream,tools \
|
||||
$(MAKE) test-extra-backend
|
||||
|
||||
DOCKER_IMAGE?=local-ai
|
||||
IMAGE_TYPE?=core
|
||||
BASE_IMAGE?=ubuntu:24.04
|
||||
@@ -592,6 +725,9 @@ backend-images:
|
||||
BACKEND_LLAMA_CPP = llama-cpp|llama-cpp|.|false|false
|
||||
# ik-llama-cpp is a fork of llama.cpp with superior CPU performance
|
||||
BACKEND_IK_LLAMA_CPP = ik-llama-cpp|ik-llama-cpp|.|false|false
|
||||
# turboquant is a llama.cpp fork with TurboQuant KV-cache quantization.
|
||||
# Reuses backend/cpp/llama-cpp grpc-server sources via a thin wrapper Makefile.
|
||||
BACKEND_TURBOQUANT = turboquant|turboquant|.|false|false
|
||||
|
||||
# Golang backends
|
||||
BACKEND_PIPER = piper|golang|.|false|true
|
||||
@@ -617,6 +753,7 @@ BACKEND_NEUTTS = neutts|python|.|false|true
|
||||
BACKEND_KOKORO = kokoro|python|.|false|true
|
||||
BACKEND_VLLM = vllm|python|.|false|true
|
||||
BACKEND_VLLM_OMNI = vllm-omni|python|.|false|true
|
||||
BACKEND_SGLANG = sglang|python|.|false|true
|
||||
BACKEND_DIFFUSERS = diffusers|python|.|--progress=plain|true
|
||||
BACKEND_CHATTERBOX = chatterbox|python|.|false|true
|
||||
BACKEND_VIBEVOICE = vibevoice|python|.|--progress=plain|true
|
||||
@@ -630,9 +767,12 @@ BACKEND_NEMO = nemo|python|.|false|true
|
||||
BACKEND_VOXCPM = voxcpm|python|.|false|true
|
||||
BACKEND_WHISPERX = whisperx|python|.|false|true
|
||||
BACKEND_ACE_STEP = ace-step|python|.|false|true
|
||||
BACKEND_MLX = mlx|python|.|false|true
|
||||
BACKEND_MLX_VLM = mlx-vlm|python|.|false|true
|
||||
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
|
||||
BACKEND_TRL = trl|python|.|false|true
|
||||
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
|
||||
BACKEND_TINYGRAD = tinygrad|python|.|false|true
|
||||
|
||||
# Rust backends
|
||||
BACKEND_KOKOROS = kokoros|rust|.|false|true
|
||||
@@ -650,6 +790,7 @@ define docker-build-backend
|
||||
--build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \
|
||||
--build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \
|
||||
--build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \
|
||||
$(if $(FROM_SOURCE),--build-arg FROM_SOURCE=$(FROM_SOURCE)) \
|
||||
$(if $(filter true,$(5)),--build-arg BACKEND=$(1)) \
|
||||
-t local-ai-backend:$(1) -f backend/Dockerfile.$(2) $(3)
|
||||
endef
|
||||
@@ -663,6 +804,7 @@ endef
|
||||
# Generate all docker-build targets
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_IK_LLAMA_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TURBOQUANT)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_PIPER)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE)))
|
||||
@@ -682,6 +824,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_NEUTTS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_KOKORO)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VLLM)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VLLM_OMNI)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SGLANG)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_DIFFUSERS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_CHATTERBOX)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_VIBEVOICE)))
|
||||
@@ -697,9 +840,12 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_QWEN3_TTS_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_VLM)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TINYGRAD)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
|
||||
|
||||
@@ -707,7 +853,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
|
||||
@@ -58,6 +58,8 @@ ARG CUDA_DOCKER_ARCH
|
||||
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
ARG CMAKE_ARGS
|
||||
ENV CMAKE_ARGS=${CMAKE_ARGS}
|
||||
ARG AMDGPU_TARGETS
|
||||
ENV AMDGPU_TARGETS=${AMDGPU_TARGETS}
|
||||
ARG BACKEND=rerankers
|
||||
ARG BUILD_TYPE
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
|
||||
@@ -29,6 +29,7 @@ RUN apt-get update && \
|
||||
curl python3-pip \
|
||||
python-is-python3 \
|
||||
python3-dev llvm \
|
||||
libnuma1 libgomp1 \
|
||||
python3-venv make cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
@@ -195,6 +196,12 @@ COPY backend/backend.proto /${BACKEND}/backend.proto
|
||||
COPY backend/python/common/ /${BACKEND}/common
|
||||
COPY scripts/build/package-gpu-libs.sh /package-gpu-libs.sh
|
||||
|
||||
# Optional per-backend source build toggle (e.g. vllm on CPU can set
|
||||
# FROM_SOURCE=true to compile against the build host SIMD instead of
|
||||
# pulling a prebuilt wheel). Default empty — most backends ignore it.
|
||||
ARG FROM_SOURCE=""
|
||||
ENV FROM_SOURCE=${FROM_SOURCE}
|
||||
|
||||
RUN cd /${BACKEND} && PORTABLE_PYTHON=true make
|
||||
|
||||
# Package GPU libraries into the backend's lib directory
|
||||
|
||||
290
backend/Dockerfile.turboquant
Normal file
290
backend/Dockerfile.turboquant
Normal file
@@ -0,0 +1,290 @@
|
||||
ARG BASE_IMAGE=ubuntu:24.04
|
||||
ARG GRPC_BASE_IMAGE=${BASE_IMAGE}
|
||||
|
||||
|
||||
# The grpc target does one thing, it builds and installs GRPC. This is in it's own layer so that it can be effectively cached by CI.
|
||||
# You probably don't need to change anything here, and if you do, make sure that CI is adjusted so that the cache continues to work.
|
||||
FROM ${GRPC_BASE_IMAGE} AS grpc
|
||||
|
||||
# This is a bit of a hack, but it's required in order to be able to effectively cache this layer in CI
|
||||
ARG GRPC_MAKEFLAGS="-j4 -Otarget"
|
||||
ARG GRPC_VERSION=v1.65.0
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
# CUDA Toolkit 13.x compatibility: CMake 3.31.9+ fixes toolchain detection/arch table issues
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
|
||||
ENV MAKEFLAGS=${GRPC_MAKEFLAGS}
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
build-essential curl libssl-dev \
|
||||
git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install CMake (the version in 22.04 is too old)
|
||||
RUN <<EOT bash
|
||||
if [ "${CMAKE_FROM_SOURCE}" = "true" ]; then
|
||||
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
|
||||
else
|
||||
apt-get update && \
|
||||
apt-get install -y \
|
||||
cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# We install GRPC to a different prefix here so that we can copy in only the build artifacts later
|
||||
# saves several hundred MB on the final docker image size vs copying in the entire GRPC source tree
|
||||
# and running make install in the target container
|
||||
RUN git clone --recurse-submodules --jobs 4 -b ${GRPC_VERSION} --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
|
||||
mkdir -p /build/grpc/cmake/build && \
|
||||
cd /build/grpc/cmake/build && \
|
||||
sed -i "216i\ TESTONLY" "../../third_party/abseil-cpp/absl/container/CMakeLists.txt" && \
|
||||
cmake -DgRPC_INSTALL=ON -DgRPC_BUILD_TESTS=OFF -DCMAKE_INSTALL_PREFIX:PATH=/opt/grpc ../.. && \
|
||||
make && \
|
||||
make install && \
|
||||
rm -rf /build
|
||||
|
||||
FROM ${BASE_IMAGE} AS builder
|
||||
ARG CMAKE_FROM_SOURCE=false
|
||||
ARG CMAKE_VERSION=3.31.10
|
||||
# We can target specific CUDA ARCHITECTURES like --build-arg CUDA_DOCKER_ARCH='75;86;89;120'
|
||||
ARG CUDA_DOCKER_ARCH
|
||||
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
ARG CMAKE_ARGS
|
||||
ENV CMAKE_ARGS=${CMAKE_ARGS}
|
||||
ARG BACKEND=rerankers
|
||||
ARG BUILD_TYPE
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ARG CUDA_MAJOR_VERSION
|
||||
ARG CUDA_MINOR_VERSION
|
||||
ARG SKIP_DRIVERS=false
|
||||
ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
|
||||
ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
ARG GO_VERSION=1.25.4
|
||||
ARG UBUNTU_VERSION=2404
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
ccache git \
|
||||
ca-certificates \
|
||||
make \
|
||||
pkg-config libcurl4-openssl-dev \
|
||||
curl unzip \
|
||||
libssl-dev wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Cuda
|
||||
ENV PATH=/usr/local/cuda/bin:${PATH}
|
||||
|
||||
# HipBLAS requirements
|
||||
ENV PATH=/opt/rocm/bin:${PATH}
|
||||
|
||||
|
||||
# Vulkan requirements
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "vulkan" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils wget gpg-agent && \
|
||||
apt-get install -y libglm-dev cmake libxcb-dri3-0 libxcb-present0 libpciaccess0 \
|
||||
libpng-dev libxcb-keysyms1-dev libxcb-dri3-dev libx11-dev g++ gcc \
|
||||
libwayland-dev libxrandr-dev libxcb-randr0-dev libxcb-ewmh-dev \
|
||||
git python-is-python3 bison libx11-xcb-dev liblz4-dev libzstd-dev \
|
||||
ocaml-core ninja-build pkg-config libxml2-dev wayland-protocols python3-jsonschema \
|
||||
clang-format qtbase5-dev qt6-base-dev libxcb-glx0-dev sudo xz-utils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
wget "https://sdk.lunarg.com/sdk/download/1.4.335.0/linux/vulkansdk-linux-x86_64-1.4.335.0.tar.xz" && \
|
||||
tar -xf vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
rm vulkansdk-linux-x86_64-1.4.335.0.tar.xz && \
|
||||
mkdir -p /opt/vulkan-sdk && \
|
||||
mv 1.4.335.0 /opt/vulkan-sdk/ && \
|
||||
cd /opt/vulkan-sdk/1.4.335.0 && \
|
||||
./vulkansdk --no-deps --maxjobs \
|
||||
vulkan-loader \
|
||||
vulkan-validationlayers \
|
||||
vulkan-extensionlayer \
|
||||
vulkan-tools \
|
||||
shaderc && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/bin/* /usr/bin/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/lib/* /usr/lib/x86_64-linux-gnu/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/include/* /usr/include/ && \
|
||||
cp -rfv /opt/vulkan-sdk/1.4.335.0/x86_64/share/* /usr/share/ && \
|
||||
rm -rf /opt/vulkan-sdk
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
mkdir vulkan && cd vulkan && \
|
||||
curl -L -o vulkan-sdk.tar.xz https://github.com/mudler/vulkan-sdk-arm/releases/download/1.4.335.0/vulkansdk-ubuntu-24.04-arm-1.4.335.0.tar.xz && \
|
||||
tar -xvf vulkan-sdk.tar.xz && \
|
||||
rm vulkan-sdk.tar.xz && \
|
||||
cd 1.4.335.0 && \
|
||||
cp -rfv aarch64/bin/* /usr/bin/ && \
|
||||
cp -rfv aarch64/lib/* /usr/lib/aarch64-linux-gnu/ && \
|
||||
cp -rfv aarch64/include/* /usr/include/ && \
|
||||
cp -rfv aarch64/share/* /usr/share/ && \
|
||||
cd ../.. && \
|
||||
rm -rf vulkan
|
||||
fi
|
||||
ldconfig && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
# CuBLAS requirements
|
||||
RUN <<EOT bash
|
||||
if ( [ "${BUILD_TYPE}" = "cublas" ] || [ "${BUILD_TYPE}" = "l4t" ] ) && [ "${SKIP_DRIVERS}" = "false" ]; then
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
software-properties-common pciutils
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ]; then
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/sbsa/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/arm64/cuda-keyring_1.1-1_all.deb
|
||||
fi
|
||||
fi
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
||||
rm -f cuda-keyring_1.1-1_all.deb && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
if [ "${CUDA_MAJOR_VERSION}" = "13" ] && [ "arm64" = "$TARGETARCH" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
libcufile-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcudnn9-cuda-${CUDA_MAJOR_VERSION} cuda-cupti-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libnvjitlink-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION}
|
||||
fi
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Isaac-GR00T/issues/343
|
||||
RUN <<EOT bash
|
||||
if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
|
||||
wget https://developer.download.nvidia.com/compute/cudss/0.6.0/local_installers/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
dpkg -i cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0_0.6.0-1_arm64.deb && \
|
||||
cp /var/cudss-local-tegra-repo-ubuntu${UBUNTU_VERSION}-0.6.0/cudss-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get -y install cudss cudss-cuda-${CUDA_MAJOR_VERSION} && \
|
||||
wget https://developer.download.nvidia.com/compute/nvpl/25.5/local_installers/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
dpkg -i nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5_1.0-1_arm64.deb && \
|
||||
cp /var/nvpl-local-repo-ubuntu${UBUNTU_VERSION}-25.5/nvpl-*-keyring.gpg /usr/share/keyrings/ && \
|
||||
apt-get update && apt-get install -y nvpl
|
||||
fi
|
||||
EOT
|
||||
|
||||
# If we are building with clblas support, we need the libraries for the builds
|
||||
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libclblast-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* \
|
||||
; fi
|
||||
|
||||
RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
hipblas-dev \
|
||||
rocblas-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
# I have no idea why, but the ROCM lib packages don't trigger ldconfig after they install, which results in local-ai and others not being able
|
||||
# to locate the libraries. We run ldconfig ourselves to work around this packaging deficiency
|
||||
ldconfig && \
|
||||
# Log which GPU architectures have rocBLAS kernel support
|
||||
echo "rocBLAS library data architectures:" && \
|
||||
(ls /opt/rocm*/lib/rocblas/library/Kernels* 2>/dev/null || ls /opt/rocm*/lib64/rocblas/library/Kernels* 2>/dev/null) | grep -oP 'gfx[0-9a-z+-]+' | sort -u || \
|
||||
echo "WARNING: No rocBLAS kernel data found" \
|
||||
; fi
|
||||
|
||||
RUN echo "TARGETARCH: $TARGETARCH"
|
||||
|
||||
# We need protoc installed, and the version in 22.04 is too old. We will create one as part installing the GRPC build below
|
||||
# but that will also being in a newer version of absl which stablediffusion cannot compile with. This version of protoc is only
|
||||
# here so that we can generate the grpc code for the stablediffusion build
|
||||
RUN <<EOT bash
|
||||
if [ "amd64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-x86_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
if [ "arm64" = "$TARGETARCH" ]; then
|
||||
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v27.1/protoc-27.1-linux-aarch_64.zip -o protoc.zip && \
|
||||
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
|
||||
rm protoc.zip
|
||||
fi
|
||||
EOT
|
||||
|
||||
# Install CMake (the version in 22.04 is too old)
|
||||
RUN <<EOT bash
|
||||
if [ "${CMAKE_FROM_SOURCE}" = "true" ]; then
|
||||
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
|
||||
else
|
||||
apt-get update && \
|
||||
apt-get install -y \
|
||||
cmake && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
fi
|
||||
EOT
|
||||
|
||||
COPY --from=grpc /opt/grpc /usr/local
|
||||
|
||||
|
||||
COPY . /LocalAI
|
||||
|
||||
RUN <<'EOT' bash
|
||||
set -euxo pipefail
|
||||
|
||||
if [[ -n "${CUDA_DOCKER_ARCH:-}" ]]; then
|
||||
CUDA_ARCH_ESC="${CUDA_DOCKER_ARCH//;/\\;}"
|
||||
export CMAKE_ARGS="${CMAKE_ARGS:-} -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCH_ESC}"
|
||||
echo "CMAKE_ARGS(env) = ${CMAKE_ARGS}"
|
||||
rm -rf /LocalAI/backend/cpp/turboquant-*-build
|
||||
fi
|
||||
|
||||
cd /LocalAI/backend/cpp/turboquant
|
||||
|
||||
if [ "${TARGETARCH}" = "arm64" ] || [ "${BUILD_TYPE}" = "hipblas" ]; then
|
||||
make turboquant-fallback
|
||||
make turboquant-grpc
|
||||
make turboquant-rpc-server
|
||||
else
|
||||
make turboquant-avx
|
||||
make turboquant-avx2
|
||||
make turboquant-avx512
|
||||
make turboquant-fallback
|
||||
make turboquant-grpc
|
||||
make turboquant-rpc-server
|
||||
fi
|
||||
EOT
|
||||
|
||||
|
||||
# Copy libraries using a script to handle architecture differences
|
||||
RUN make -BC /LocalAI/backend/cpp/turboquant package
|
||||
|
||||
|
||||
FROM scratch
|
||||
|
||||
|
||||
# Copy all available binaries (the build process only creates the appropriate ones for the target architecture)
|
||||
COPY --from=builder /LocalAI/backend/cpp/turboquant/package/. ./
|
||||
@@ -17,6 +17,7 @@ service Backend {
|
||||
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
|
||||
rpc GenerateVideo(GenerateVideoRequest) returns (Result) {}
|
||||
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
|
||||
rpc AudioTranscriptionStream(TranscriptRequest) returns (stream TranscriptStreamResponse) {}
|
||||
rpc TTS(TTSRequest) returns (Result) {}
|
||||
rpc TTSStream(TTSRequest) returns (stream Reply) {}
|
||||
rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
|
||||
@@ -322,11 +323,21 @@ message TranscriptRequest {
|
||||
bool translate = 5;
|
||||
bool diarize = 6;
|
||||
string prompt = 7;
|
||||
float temperature = 8;
|
||||
repeated string timestamp_granularities = 9;
|
||||
bool stream = 10;
|
||||
}
|
||||
|
||||
message TranscriptResult {
|
||||
repeated TranscriptSegment segments = 1;
|
||||
string text = 2;
|
||||
string language = 3;
|
||||
float duration = 4;
|
||||
}
|
||||
|
||||
message TranscriptStreamResponse {
|
||||
string delta = 1;
|
||||
TranscriptResult final_result = 2;
|
||||
}
|
||||
|
||||
message TranscriptSegment {
|
||||
@@ -546,6 +557,7 @@ message ModelMetadataResponse {
|
||||
bool supports_thinking = 1;
|
||||
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
|
||||
ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis
|
||||
string media_marker = 4; // Marker the backend expects in the prompt for each multimodal input (images/audio/video). Empty when the backend does not use a marker.
|
||||
}
|
||||
|
||||
// Fine-tuning messages
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
IK_LLAMA_VERSION?=08ae48c667e3dcd3025821a8585190b4a46c2f7c
|
||||
IK_LLAMA_VERSION?=d4824131580b94ffa7b0e91c955e2b237c2fe16e
|
||||
LLAMA_REPO?=https://github.com/ikawrakow/ik_llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
|
||||
@@ -62,7 +62,18 @@ add_executable(${TARGET} grpc-server.cpp json.hpp httplib.h)
|
||||
target_include_directories(${TARGET} PRIVATE ../llava)
|
||||
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE common llama mtmd ${CMAKE_THREAD_LIBS_INIT} absl::flags hw_grpc_proto
|
||||
# Upstream llama.cpp renamed the `common` helpers library to `llama-common`.
|
||||
# Forks that branched before the rename (e.g. llama-cpp-turboquant) still
|
||||
# expose it as `common`. Detect which one is present so the same CMakeLists
|
||||
# drives both builds — otherwise an unresolved name silently degrades to a
|
||||
# plain `-l` flag and the PUBLIC include dir (where common.h lives) is lost.
|
||||
if (TARGET llama-common)
|
||||
set(_LLAMA_COMMON_TARGET llama-common)
|
||||
else()
|
||||
set(_LLAMA_COMMON_TARGET common)
|
||||
endif()
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE ${_LLAMA_COMMON_TARGET} llama mtmd ${CMAKE_THREAD_LIBS_INIT} absl::flags hw_grpc_proto
|
||||
absl::flags_parse
|
||||
gRPC::${_REFLECTION}
|
||||
gRPC::${_GRPC_GRPCPP}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
LLAMA_VERSION?=ff5ef8278615a2462b79b50abdf3cc95cfb31c6f
|
||||
LLAMA_VERSION?=cf8b0dbda9ac0eac30ee33f87bc6702ead1c4664
|
||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||
|
||||
CMAKE_ARGS?=
|
||||
@@ -33,7 +33,7 @@ else ifeq ($(BUILD_TYPE),hipblas)
|
||||
ROCM_PATH ?= /opt/rocm
|
||||
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
||||
export CC=$(ROCM_HOME)/llvm/bin/clang
|
||||
AMDGPU_TARGETS?=gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201
|
||||
AMDGPU_TARGETS?=gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1102,gfx1151,gfx1200,gfx1201
|
||||
CMAKE_ARGS+=-DGGML_HIP=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
|
||||
else ifeq ($(BUILD_TYPE),vulkan)
|
||||
CMAKE_ARGS+=-DGGML_VULKAN=1
|
||||
@@ -132,7 +132,7 @@ llama.cpp:
|
||||
cd llama.cpp && \
|
||||
git init && \
|
||||
git remote add origin $(LLAMA_REPO) && \
|
||||
git fetch origin && \
|
||||
git fetch --all --tags && \
|
||||
git checkout -b build $(LLAMA_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
|
||||
@@ -26,6 +26,8 @@
|
||||
#include <regex>
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
#include <mutex>
|
||||
#include <signal.h>
|
||||
#include <thread>
|
||||
@@ -76,6 +78,27 @@ static grpc::Status checkAuth(grpc::ServerContext* context) {
|
||||
return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token");
|
||||
}
|
||||
|
||||
// Minimal base64 encoder. The C++ backend already pulls in base64_decode from
|
||||
// llama.cpp's server-common.cpp, but no encoder is exposed — and we need one to
|
||||
// hand audio bytes to the existing PredictOptions.audios path (which expects
|
||||
// base64-encoded strings, just like images).
|
||||
static std::string base64_encode_bytes(const unsigned char* data, size_t len) {
|
||||
static const char tbl[] =
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||
std::string out;
|
||||
out.reserve(((len + 2) / 3) * 4);
|
||||
for (size_t i = 0; i < len; i += 3) {
|
||||
uint32_t triple = (uint32_t(data[i]) << 16);
|
||||
if (i + 1 < len) triple |= (uint32_t(data[i + 1]) << 8);
|
||||
if (i + 2 < len) triple |= uint32_t(data[i + 2]);
|
||||
out.push_back(tbl[(triple >> 18) & 0x3F]);
|
||||
out.push_back(tbl[(triple >> 12) & 0x3F]);
|
||||
out.push_back(i + 1 < len ? tbl[(triple >> 6) & 0x3F] : '=');
|
||||
out.push_back(i + 2 < len ? tbl[triple & 0x3F] : '=');
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// END LocalAI
|
||||
|
||||
|
||||
@@ -2791,6 +2814,13 @@ public:
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
|
||||
// Report the active multimodal media marker so the Go layer can emit the
|
||||
// same string when rendering prompts outside the tokenizer-template path.
|
||||
// Only meaningful when an mtmd context was initialized (vision/audio models).
|
||||
if (ctx_server.impl->mctx != nullptr) {
|
||||
response->set_media_marker(get_media_marker());
|
||||
}
|
||||
|
||||
// Check if chat templates are initialized
|
||||
if (ctx_server.impl->chat_params.tmpls == nullptr) {
|
||||
// If templates are not initialized, we can't detect thinking support
|
||||
@@ -2931,6 +2961,119 @@ public:
|
||||
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
// runTranscriptionAsCompletion implements OAI /v1/audio/transcriptions on
|
||||
// top of the existing chat-completion + multimodal-audio pipeline, exactly
|
||||
// the way upstream llama.cpp's server does it (see
|
||||
// tools/server/server-context.cpp post_transcriptions_oai → forwards into
|
||||
// handle_completions_impl with a single user message attaching the audio
|
||||
// file via the mtmd marker).
|
||||
//
|
||||
// We synthesize a backend::PredictOptions with one user message
|
||||
// ("Transcribe audio to text" + optional language hint) and the audio
|
||||
// bytes attached via the existing PredictOptions.audios field, then
|
||||
// delegate to our own Predict() handler. This keeps every multimodal
|
||||
// codepath identical to the chat path and avoids duplicating ~700 lines
|
||||
// of task-construction logic.
|
||||
grpc::Status runTranscriptionAsCompletion(grpc::ServerContext* context,
|
||||
const backend::TranscriptRequest* request,
|
||||
backend::Reply* out_reply) {
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
}
|
||||
if (request->dst().empty()) {
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "dst (audio file path) is required");
|
||||
}
|
||||
|
||||
// Read audio bytes from the path LocalAI's HTTP layer wrote.
|
||||
std::ifstream f(request->dst(), std::ios::binary);
|
||||
if (!f.is_open()) {
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "failed to open audio file: " + request->dst());
|
||||
}
|
||||
std::vector<unsigned char> bytes((std::istreambuf_iterator<char>(f)),
|
||||
std::istreambuf_iterator<char>());
|
||||
f.close();
|
||||
if (bytes.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "audio file is empty: " + request->dst());
|
||||
}
|
||||
|
||||
std::string b64 = base64_encode_bytes(bytes.data(), bytes.size());
|
||||
|
||||
// Build the same prompt upstream uses in convert_transcriptions_to_chatcmpl.
|
||||
std::string user_prompt = "Transcribe audio to text";
|
||||
if (!request->language().empty()) {
|
||||
user_prompt += " (language: " + request->language() + ")";
|
||||
}
|
||||
if (!request->prompt().empty()) {
|
||||
// Optional context hint from the caller.
|
||||
user_prompt += "\n" + request->prompt();
|
||||
}
|
||||
|
||||
backend::PredictOptions synthetic;
|
||||
synthetic.set_usetokenizertemplate(true);
|
||||
synthetic.set_temperature(request->temperature());
|
||||
// Generation length: leave at 0 so parse_options uses -1 (model default).
|
||||
// The model's stop tokens / EOS handle termination naturally for ASR.
|
||||
backend::Message* msg = synthetic.add_messages();
|
||||
msg->set_role("user");
|
||||
msg->set_content(user_prompt);
|
||||
synthetic.add_audios(b64);
|
||||
|
||||
return Predict(context, &synthetic, out_reply);
|
||||
}
|
||||
|
||||
grpc::Status AudioTranscription(ServerContext* context,
|
||||
const backend::TranscriptRequest* request,
|
||||
backend::TranscriptResult* response) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
|
||||
backend::Reply reply;
|
||||
grpc::Status st = runTranscriptionAsCompletion(context, request, &reply);
|
||||
if (!st.ok()) {
|
||||
return st;
|
||||
}
|
||||
response->set_text(reply.message());
|
||||
if (!request->language().empty()) {
|
||||
response->set_language(request->language());
|
||||
}
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status AudioTranscriptionStream(ServerContext* context,
|
||||
const backend::TranscriptRequest* request,
|
||||
grpc::ServerWriter<backend::TranscriptStreamResponse>* writer) override {
|
||||
auto auth = checkAuth(context);
|
||||
if (!auth.ok()) return auth;
|
||||
|
||||
// Buffered streaming: run the transcription as a normal chat
|
||||
// completion, then emit one delta + one final event. Real
|
||||
// token-by-token streaming would require refactoring PredictStream's
|
||||
// 700-line writer-coupled body; the HTTP/SSE contract is identical
|
||||
// either way, and clients that only consume the assembled text don't
|
||||
// notice the difference.
|
||||
backend::Reply reply;
|
||||
grpc::Status st = runTranscriptionAsCompletion(context, request, &reply);
|
||||
if (!st.ok()) {
|
||||
return st;
|
||||
}
|
||||
|
||||
const std::string& text = reply.message();
|
||||
if (!text.empty()) {
|
||||
backend::TranscriptStreamResponse delta_chunk;
|
||||
delta_chunk.set_delta(text);
|
||||
writer->Write(delta_chunk);
|
||||
}
|
||||
|
||||
backend::TranscriptStreamResponse final_chunk;
|
||||
backend::TranscriptResult* final_result = final_chunk.mutable_final_result();
|
||||
final_result->set_text(text);
|
||||
if (!request->language().empty()) {
|
||||
final_result->set_language(request->language());
|
||||
}
|
||||
writer->Write(final_chunk);
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
81
backend/cpp/turboquant/Makefile
Normal file
81
backend/cpp/turboquant/Makefile
Normal file
@@ -0,0 +1,81 @@
|
||||
|
||||
# Pinned to the HEAD of feature/turboquant-kv-cache on https://github.com/TheTom/llama-cpp-turboquant.
|
||||
# Auto-bumped nightly by .github/workflows/bump_deps.yaml.
|
||||
TURBOQUANT_VERSION?=4d24ad87b8ed2ad160809af41930f1e04b83f234
|
||||
LLAMA_REPO?=https://github.com/TheTom/llama-cpp-turboquant
|
||||
|
||||
CMAKE_ARGS?=
|
||||
BUILD_TYPE?=
|
||||
NATIVE?=false
|
||||
ONEAPI_VARS?=/opt/intel/oneapi/setvars.sh
|
||||
TARGET?=--target grpc-server
|
||||
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 1)
|
||||
ARCH?=$(shell uname -m)
|
||||
|
||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
LLAMA_CPP_DIR := $(CURRENT_MAKEFILE_DIR)/../llama-cpp
|
||||
|
||||
GREEN := \033[0;32m
|
||||
RESET := \033[0m
|
||||
|
||||
# turboquant is a llama.cpp fork. Rather than duplicating grpc-server.cpp / CMakeLists.txt /
|
||||
# prepare.sh we reuse the ones in backend/cpp/llama-cpp, and only swap which repo+sha the
|
||||
# fetch step pulls. Each flavor target copies ../llama-cpp into a sibling ../turboquant-<flavor>-build
|
||||
# directory, then invokes llama-cpp's own build-llama-cpp-grpc-server with LLAMA_REPO/LLAMA_VERSION
|
||||
# overridden to point at the fork.
|
||||
PATCHES_DIR := $(CURRENT_MAKEFILE_DIR)/patches
|
||||
|
||||
# Each flavor target:
|
||||
# 1. copies backend/cpp/llama-cpp/ (grpc-server.cpp + prepare.sh + CMakeLists.txt + Makefile)
|
||||
# into a sibling turboquant-<flavor>-build directory;
|
||||
# 2. clones the turboquant fork into turboquant-<flavor>-build/llama.cpp via the copy's
|
||||
# own `llama.cpp` target, overriding LLAMA_REPO/LLAMA_VERSION;
|
||||
# 3. applies patches from backend/cpp/turboquant/patches/ to the cloned fork sources
|
||||
# (needed until the fork catches up with upstream server-context.cpp changes);
|
||||
# 4. runs the copy's `grpc-server` target, which produces the binary we copy up as
|
||||
# turboquant-<flavor>.
|
||||
define turboquant-build
|
||||
rm -rf $(CURRENT_MAKEFILE_DIR)/../turboquant-$(1)-build
|
||||
cp -rf $(LLAMA_CPP_DIR) $(CURRENT_MAKEFILE_DIR)/../turboquant-$(1)-build
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../turboquant-$(1)-build purge
|
||||
# Augment the copied grpc-server.cpp's KV-cache allow-list with the
|
||||
# fork's turbo2/turbo3/turbo4 types. We patch the *copy*, never the
|
||||
# original under backend/cpp/llama-cpp/, so the stock llama-cpp build
|
||||
# stays compiling against vanilla upstream.
|
||||
bash $(CURRENT_MAKEFILE_DIR)/patch-grpc-server.sh $(CURRENT_MAKEFILE_DIR)/../turboquant-$(1)-build/grpc-server.cpp
|
||||
$(info $(GREEN)I turboquant build info:$(1)$(RESET))
|
||||
LLAMA_REPO=$(LLAMA_REPO) LLAMA_VERSION=$(TURBOQUANT_VERSION) \
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../turboquant-$(1)-build llama.cpp
|
||||
bash $(CURRENT_MAKEFILE_DIR)/apply-patches.sh $(CURRENT_MAKEFILE_DIR)/../turboquant-$(1)-build/llama.cpp $(PATCHES_DIR)
|
||||
CMAKE_ARGS="$(CMAKE_ARGS) $(2)" TARGET="$(3)" \
|
||||
LLAMA_REPO=$(LLAMA_REPO) LLAMA_VERSION=$(TURBOQUANT_VERSION) \
|
||||
$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../turboquant-$(1)-build grpc-server
|
||||
cp -rfv $(CURRENT_MAKEFILE_DIR)/../turboquant-$(1)-build/grpc-server turboquant-$(1)
|
||||
endef
|
||||
|
||||
turboquant-avx2:
|
||||
$(call turboquant-build,avx2,-DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on,--target grpc-server)
|
||||
|
||||
turboquant-avx512:
|
||||
$(call turboquant-build,avx512,-DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on,--target grpc-server)
|
||||
|
||||
turboquant-avx:
|
||||
$(call turboquant-build,avx,-DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off,--target grpc-server)
|
||||
|
||||
turboquant-fallback:
|
||||
$(call turboquant-build,fallback,-DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off,--target grpc-server)
|
||||
|
||||
turboquant-grpc:
|
||||
$(call turboquant-build,grpc,-DGGML_RPC=ON -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off,--target grpc-server --target rpc-server)
|
||||
|
||||
turboquant-rpc-server: turboquant-grpc
|
||||
cp -rf $(CURRENT_MAKEFILE_DIR)/../turboquant-grpc-build/llama.cpp/build/bin/rpc-server turboquant-rpc-server
|
||||
|
||||
package:
|
||||
bash package.sh
|
||||
|
||||
purge:
|
||||
rm -rf $(CURRENT_MAKEFILE_DIR)/../turboquant-*-build
|
||||
rm -rf turboquant-* package
|
||||
|
||||
clean: purge
|
||||
50
backend/cpp/turboquant/apply-patches.sh
Executable file
50
backend/cpp/turboquant/apply-patches.sh
Executable file
@@ -0,0 +1,50 @@
|
||||
#!/bin/bash
|
||||
# Apply the turboquant patch series to a cloned llama-cpp-turboquant checkout.
|
||||
#
|
||||
# The turboquant fork branched from upstream llama.cpp before a few API changes
|
||||
# that the shared backend/cpp/llama-cpp/grpc-server.cpp depends on. We carry
|
||||
# those upstream commits as patch files under backend/cpp/turboquant/patches/
|
||||
# and apply them here so the reused grpc-server source compiles against the
|
||||
# fork unmodified.
|
||||
#
|
||||
# Drop the corresponding patch from patches/ whenever the fork catches up with
|
||||
# upstream — the build will fail fast if a patch stops applying, which is the
|
||||
# signal to retire it.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [[ $# -ne 2 ]]; then
|
||||
echo "usage: $0 <llama.cpp-src-dir> <patches-dir>" >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
SRC_DIR=$1
|
||||
PATCHES_DIR=$2
|
||||
|
||||
if [[ ! -d "$SRC_DIR" ]]; then
|
||||
echo "source dir does not exist: $SRC_DIR" >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if [[ ! -d "$PATCHES_DIR" ]]; then
|
||||
echo "no patches dir at $PATCHES_DIR, nothing to apply"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
shopt -s nullglob
|
||||
patches=("$PATCHES_DIR"/*.patch)
|
||||
shopt -u nullglob
|
||||
|
||||
if [[ ${#patches[@]} -eq 0 ]]; then
|
||||
echo "no .patch files in $PATCHES_DIR, nothing to apply"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
cd "$SRC_DIR"
|
||||
|
||||
for patch in "${patches[@]}"; do
|
||||
echo "==> applying $patch"
|
||||
git apply --verbose "$patch"
|
||||
done
|
||||
|
||||
echo "all turboquant patches applied successfully"
|
||||
57
backend/cpp/turboquant/package.sh
Executable file
57
backend/cpp/turboquant/package.sh
Executable file
@@ -0,0 +1,57 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to copy the appropriate libraries based on architecture
|
||||
# This script is used in the final stage of the Dockerfile
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
REPO_ROOT="${CURDIR}/../../.."
|
||||
|
||||
# Create lib directory
|
||||
mkdir -p $CURDIR/package/lib
|
||||
|
||||
cp -avrf $CURDIR/turboquant-* $CURDIR/package/
|
||||
cp -rfv $CURDIR/run.sh $CURDIR/package/
|
||||
|
||||
# Detect architecture and copy appropriate libraries
|
||||
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
|
||||
# x86_64 architecture
|
||||
echo "Detected x86_64 architecture, copying x86_64 libraries..."
|
||||
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
|
||||
# ARM64 architecture
|
||||
echo "Detected ARM64 architecture, copying ARM64 libraries..."
|
||||
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2
|
||||
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1
|
||||
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0
|
||||
else
|
||||
echo "Error: Could not detect architecture"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Package GPU libraries based on BUILD_TYPE
|
||||
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
|
||||
if [ -f "$GPU_LIB_SCRIPT" ]; then
|
||||
echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
|
||||
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
|
||||
package_gpu_libs
|
||||
fi
|
||||
|
||||
echo "Packaging completed successfully"
|
||||
ls -liah $CURDIR/package/
|
||||
ls -liah $CURDIR/package/lib/
|
||||
80
backend/cpp/turboquant/patch-grpc-server.sh
Executable file
80
backend/cpp/turboquant/patch-grpc-server.sh
Executable file
@@ -0,0 +1,80 @@
|
||||
#!/bin/bash
|
||||
# Patch the shared backend/cpp/llama-cpp/grpc-server.cpp *copy* used by the
|
||||
# turboquant build to account for two gaps between upstream and the fork:
|
||||
#
|
||||
# 1. Augment the kv_cache_types[] allow-list so `LoadModel` accepts the
|
||||
# fork-specific `turbo2` / `turbo3` / `turbo4` cache types.
|
||||
# 2. Replace `get_media_marker()` (added upstream in ggml-org/llama.cpp#21962,
|
||||
# server-side random per-instance marker) with the legacy "<__media__>"
|
||||
# literal. The fork branched before that PR, so server-common.cpp has no
|
||||
# get_media_marker symbol. The fork's mtmd_default_marker() still returns
|
||||
# "<__media__>", and Go-side tooling falls back to that sentinel when the
|
||||
# backend does not expose media_marker, so substituting the literal keeps
|
||||
# behavior identical on the turboquant path.
|
||||
#
|
||||
# We patch the *copy* sitting in turboquant-<flavor>-build/, never the original
|
||||
# under backend/cpp/llama-cpp/, so the stock llama-cpp build keeps compiling
|
||||
# against vanilla upstream.
|
||||
#
|
||||
# Idempotent: skips each insertion if its marker is already present (so re-runs
|
||||
# of the same build dir don't double-insert).
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [[ $# -ne 1 ]]; then
|
||||
echo "usage: $0 <grpc-server.cpp>" >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
SRC=$1
|
||||
|
||||
if [[ ! -f "$SRC" ]]; then
|
||||
echo "grpc-server.cpp not found at $SRC" >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if grep -q 'GGML_TYPE_TURBO2_0' "$SRC"; then
|
||||
echo "==> $SRC already has TurboQuant cache types, skipping KV allow-list patch"
|
||||
else
|
||||
echo "==> patching $SRC to allow turbo2/turbo3/turbo4 KV-cache types"
|
||||
|
||||
# Insert the three TURBO entries right after the first ` GGML_TYPE_Q5_1,`
|
||||
# line (the kv_cache_types[] allow-list). Using awk because the builder image
|
||||
# does not ship python3, and GNU sed's multi-line `a\` quoting is awkward.
|
||||
awk '
|
||||
/^ GGML_TYPE_Q5_1,$/ && !done {
|
||||
print
|
||||
print " // turboquant fork extras — added by patch-grpc-server.sh"
|
||||
print " GGML_TYPE_TURBO2_0,"
|
||||
print " GGML_TYPE_TURBO3_0,"
|
||||
print " GGML_TYPE_TURBO4_0,"
|
||||
done = 1
|
||||
next
|
||||
}
|
||||
{ print }
|
||||
END {
|
||||
if (!done) {
|
||||
print "patch-grpc-server.sh: anchor ` GGML_TYPE_Q5_1,` not found" > "/dev/stderr"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
|
||||
echo "==> KV allow-list patch OK"
|
||||
fi
|
||||
|
||||
if grep -q 'get_media_marker()' "$SRC"; then
|
||||
echo "==> patching $SRC to replace get_media_marker() with legacy \"<__media__>\" literal"
|
||||
# Only one call site today (ModelMetadata), but replace all occurrences to
|
||||
# stay robust if upstream adds more. Use a temp file to avoid relying on
|
||||
# sed -i portability (the builder image uses GNU sed, but keeping this
|
||||
# consistent with the awk block above).
|
||||
sed 's/get_media_marker()/"<__media__>"/g' "$SRC" > "$SRC.tmp"
|
||||
mv "$SRC.tmp" "$SRC"
|
||||
echo "==> get_media_marker() substitution OK"
|
||||
else
|
||||
echo "==> $SRC has no get_media_marker() call, skipping media-marker patch"
|
||||
fi
|
||||
|
||||
echo "==> all patches applied"
|
||||
@@ -0,0 +1,47 @@
|
||||
From: LocalAI turboquant backend maintainers <noreply@localai.io>
|
||||
Subject: ggml-hip: add F16-K + TURBO-V fattn-vec template instances
|
||||
|
||||
Upstream commit fa4e8be0a0ce ("fix(cuda): add F16-K + TURBO-V dispatch cases
|
||||
in fattn.cu") added three new template instance files under ggml-cuda/:
|
||||
|
||||
- fattn-vec-instance-f16-turbo2_0.cu
|
||||
- fattn-vec-instance-f16-turbo3_0.cu
|
||||
- fattn-vec-instance-f16-turbo4_0.cu
|
||||
|
||||
and registered them in ggml/src/ggml-cuda/CMakeLists.txt. The companion
|
||||
dispatch cases FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TURBO{2,3,4}_0)
|
||||
were added to ggml/src/ggml-cuda/fattn.cu, which is shared with the HIP
|
||||
build path via hipify.
|
||||
|
||||
However, ggml/src/ggml-hip/CMakeLists.txt carries its own explicit list of
|
||||
template instance sources (used when GGML_CUDA_FA_ALL_QUANTS is OFF, which
|
||||
is the default) and was never updated for the new F16-K + TURBO-V combos.
|
||||
The HIP build therefore compiles the dispatch cases (which reference
|
||||
ggml_cuda_flash_attn_ext_vec_case<D, F16, TURBO*>) without ever compiling
|
||||
the matching template instantiations, causing a link-time failure in the
|
||||
-gpu-rocm-hipblas-turboquant CI job.
|
||||
|
||||
Add the three new template instance files to ggml-hip's list so the HIP
|
||||
build links cleanly. Drop this patch once the fork picks up the
|
||||
corresponding upstream sync in ggml-hip/CMakeLists.txt.
|
||||
|
||||
--- a/ggml/src/ggml-hip/CMakeLists.txt
|
||||
+++ b/ggml/src/ggml-hip/CMakeLists.txt
|
||||
@@ -85,14 +85,17 @@ else()
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-q8_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu
|
||||
+ ../ggml-cuda/template-instances/fattn-vec-instance-f16-turbo3_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo2_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-q8_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu
|
||||
+ ../ggml-cuda/template-instances/fattn-vec-instance-f16-turbo2_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo2_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo3_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-q8_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu
|
||||
+ ../ggml-cuda/template-instances/fattn-vec-instance-f16-turbo4_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo3_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo4_0.cu
|
||||
../ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo2_0.cu
|
||||
65
backend/cpp/turboquant/run.sh
Executable file
65
backend/cpp/turboquant/run.sh
Executable file
@@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# Get the absolute current dir where the script is located
|
||||
CURDIR=$(dirname "$(realpath $0)")
|
||||
|
||||
cd /
|
||||
|
||||
echo "CPU info:"
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
|
||||
BINARY=turboquant-fallback
|
||||
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
if [ -e $CURDIR/turboquant-avx ]; then
|
||||
BINARY=turboquant-avx
|
||||
fi
|
||||
fi
|
||||
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
if [ -e $CURDIR/turboquant-avx2 ]; then
|
||||
BINARY=turboquant-avx2
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check avx 512
|
||||
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512F found OK"
|
||||
if [ -e $CURDIR/turboquant-avx512 ]; then
|
||||
BINARY=turboquant-avx512
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMACPP_GRPC_SERVERS" ]; then
|
||||
if [ -e $CURDIR/turboquant-grpc ]; then
|
||||
BINARY=turboquant-grpc
|
||||
fi
|
||||
fi
|
||||
|
||||
# Extend ld library path with the dir where this script is located/lib
|
||||
if [ "$(uname)" == "Darwin" ]; then
|
||||
export DYLD_LIBRARY_PATH=$CURDIR/lib:$DYLD_LIBRARY_PATH
|
||||
else
|
||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||
# Tell rocBLAS where to find TensileLibrary data (GPU kernel tuning files)
|
||||
if [ -d "$CURDIR/lib/rocblas/library" ]; then
|
||||
export ROCBLAS_TENSILE_LIBPATH=$CURDIR/lib/rocblas/library
|
||||
fi
|
||||
fi
|
||||
|
||||
# If there is a lib/ld.so, use it
|
||||
if [ -f $CURDIR/lib/ld.so ]; then
|
||||
echo "Using lib/ld.so"
|
||||
echo "Using binary: $BINARY"
|
||||
exec $CURDIR/lib/ld.so $CURDIR/$BINARY "$@"
|
||||
fi
|
||||
|
||||
echo "Using binary: $BINARY"
|
||||
exec $CURDIR/$BINARY "$@"
|
||||
|
||||
# We should never reach this point, however just in case we do, run fallback
|
||||
exec $CURDIR/turboquant-fallback "$@"
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=6b675a5ede9b0edf0a0f44191e8b79d7ef27615a
|
||||
STABLEDIFFUSION_GGML_VERSION?=44cca3d626d301e2215d5e243277e8f0e65bfa78
|
||||
|
||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||
|
||||
|
||||
@@ -26,6 +26,10 @@
|
||||
#include "stb_image_resize.h"
|
||||
#include <stdlib.h>
|
||||
#include <regex>
|
||||
#include <errno.h>
|
||||
#include <signal.h>
|
||||
#include <unistd.h>
|
||||
#include <sys/wait.h>
|
||||
|
||||
|
||||
|
||||
@@ -980,6 +984,256 @@ int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, cha
|
||||
return !ret;
|
||||
}
|
||||
|
||||
// ---------------- Video generation ----------------
|
||||
|
||||
sd_vid_gen_params_t* sd_vid_gen_params_new(void) {
|
||||
sd_vid_gen_params_t *params = (sd_vid_gen_params_t *)std::malloc(sizeof(sd_vid_gen_params_t));
|
||||
sd_vid_gen_params_init(params);
|
||||
sd_sample_params_init(¶ms->sample_params);
|
||||
sd_sample_params_init(¶ms->high_noise_sample_params);
|
||||
sd_cache_params_init(¶ms->cache);
|
||||
return params;
|
||||
}
|
||||
|
||||
// Persistent storage for cleaned video prompts (kept alive for the duration of generation)
|
||||
static std::string cleaned_vid_prompt_storage;
|
||||
static std::string cleaned_vid_negative_prompt_storage;
|
||||
|
||||
void sd_vid_gen_params_set_prompts(sd_vid_gen_params_t *params, const char *prompt, const char *negative_prompt) {
|
||||
lora_vec.clear();
|
||||
lora_strings.clear();
|
||||
|
||||
std::string prompt_str = prompt ? prompt : "";
|
||||
std::string negative_prompt_str = negative_prompt ? negative_prompt : "";
|
||||
|
||||
const char* lora_dir_to_use = lora_dir_path.empty() ? nullptr : lora_dir_path.c_str();
|
||||
|
||||
auto [loras, cleaned_prompt] = parse_loras_from_prompt(prompt_str, lora_dir_to_use);
|
||||
lora_vec = loras;
|
||||
cleaned_vid_prompt_storage = cleaned_prompt;
|
||||
|
||||
auto [neg_loras, cleaned_negative] = parse_loras_from_prompt(negative_prompt_str, lora_dir_to_use);
|
||||
cleaned_vid_negative_prompt_storage = cleaned_negative;
|
||||
|
||||
params->prompt = cleaned_vid_prompt_storage.c_str();
|
||||
params->negative_prompt = cleaned_vid_negative_prompt_storage.c_str();
|
||||
params->loras = lora_vec.empty() ? nullptr : lora_vec.data();
|
||||
params->lora_count = static_cast<uint32_t>(lora_vec.size());
|
||||
}
|
||||
|
||||
void sd_vid_gen_params_set_dimensions(sd_vid_gen_params_t *params, int width, int height) {
|
||||
params->width = width;
|
||||
params->height = height;
|
||||
}
|
||||
|
||||
void sd_vid_gen_params_set_seed(sd_vid_gen_params_t *params, int64_t seed) {
|
||||
params->seed = seed;
|
||||
}
|
||||
|
||||
void sd_vid_gen_params_set_video_frames(sd_vid_gen_params_t *params, int n) {
|
||||
params->video_frames = n;
|
||||
}
|
||||
|
||||
// Load an image file into an sd_image_t, resizing to target dims if needed.
|
||||
// Returns a heap-allocated buffer the caller must free (or nullptr on failure).
|
||||
static uint8_t* load_and_resize_image(const char* path, int target_width, int target_height, sd_image_t* out) {
|
||||
if (!path || strlen(path) == 0) {
|
||||
*out = {0, 0, 0, nullptr};
|
||||
return nullptr;
|
||||
}
|
||||
int c = 0, img_w = 0, img_h = 0;
|
||||
uint8_t* buf = stbi_load(path, &img_w, &img_h, &c, 3);
|
||||
if (!buf) {
|
||||
fprintf(stderr, "Failed to load image from '%s'\n", path);
|
||||
*out = {0, 0, 0, nullptr};
|
||||
return nullptr;
|
||||
}
|
||||
if (img_w != target_width || img_h != target_height) {
|
||||
fprintf(stderr, "Resizing image from %dx%d to %dx%d\n", img_w, img_h, target_width, target_height);
|
||||
uint8_t* resized = (uint8_t*)malloc((size_t)target_width * target_height * 3);
|
||||
if (!resized) { free(buf); *out = {0, 0, 0, nullptr}; return nullptr; }
|
||||
stbir_resize(buf, img_w, img_h, 0,
|
||||
resized, target_width, target_height, 0, STBIR_TYPE_UINT8,
|
||||
3, STBIR_ALPHA_CHANNEL_NONE, 0,
|
||||
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
|
||||
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
|
||||
STBIR_COLORSPACE_SRGB, nullptr);
|
||||
free(buf);
|
||||
buf = resized;
|
||||
}
|
||||
*out = {(uint32_t)target_width, (uint32_t)target_height, 3, buf};
|
||||
return buf;
|
||||
}
|
||||
|
||||
// Pipe raw RGB/RGBA frames to ffmpeg stdin and let it produce an MP4 at dst.
|
||||
// Uses fork+execvp to avoid shell interpretation of dst.
|
||||
static int ffmpeg_mux_raw_to_mp4(sd_image_t* frames, int num_frames, int fps, const char* dst) {
|
||||
if (num_frames <= 0 || !frames || !frames[0].data) {
|
||||
fprintf(stderr, "ffmpeg_mux: empty frames\n");
|
||||
return 1;
|
||||
}
|
||||
int width = (int)frames[0].width;
|
||||
int height = (int)frames[0].height;
|
||||
int channels = (int)frames[0].channel;
|
||||
const char* pix_fmt_in = (channels == 4) ? "rgba" : "rgb24";
|
||||
|
||||
char size_str[32];
|
||||
char fps_str[32];
|
||||
snprintf(size_str, sizeof(size_str), "%dx%d", width, height);
|
||||
snprintf(fps_str, sizeof(fps_str), "%d", fps);
|
||||
|
||||
int pipefd[2];
|
||||
if (pipe(pipefd) != 0) { perror("pipe"); return 1; }
|
||||
|
||||
pid_t pid = fork();
|
||||
if (pid < 0) { perror("fork"); close(pipefd[0]); close(pipefd[1]); return 1; }
|
||||
|
||||
if (pid == 0) {
|
||||
// child
|
||||
close(pipefd[1]);
|
||||
if (dup2(pipefd[0], STDIN_FILENO) < 0) { perror("dup2"); _exit(127); }
|
||||
close(pipefd[0]);
|
||||
std::vector<char*> argv = {
|
||||
const_cast<char*>("ffmpeg"),
|
||||
const_cast<char*>("-y"),
|
||||
const_cast<char*>("-hide_banner"),
|
||||
const_cast<char*>("-loglevel"), const_cast<char*>("warning"),
|
||||
const_cast<char*>("-f"), const_cast<char*>("rawvideo"),
|
||||
const_cast<char*>("-pix_fmt"), const_cast<char*>(pix_fmt_in),
|
||||
const_cast<char*>("-s"), size_str,
|
||||
const_cast<char*>("-framerate"), fps_str,
|
||||
const_cast<char*>("-i"), const_cast<char*>("-"),
|
||||
const_cast<char*>("-c:v"), const_cast<char*>("libx264"),
|
||||
const_cast<char*>("-pix_fmt"), const_cast<char*>("yuv420p"),
|
||||
const_cast<char*>("-movflags"), const_cast<char*>("+faststart"),
|
||||
// Force MP4 container. Distributed LocalAI hands us a staging
|
||||
// path (e.g. /staging/localai-output-NNN.tmp) with a non-standard
|
||||
// extension; relying on filename suffix makes ffmpeg bail with
|
||||
// "Unable to choose an output format".
|
||||
const_cast<char*>("-f"), const_cast<char*>("mp4"),
|
||||
const_cast<char*>(dst),
|
||||
nullptr
|
||||
};
|
||||
execvp(argv[0], argv.data());
|
||||
perror("execvp ffmpeg");
|
||||
_exit(127);
|
||||
}
|
||||
|
||||
// parent
|
||||
close(pipefd[0]);
|
||||
|
||||
// Ignore SIGPIPE so a dying ffmpeg surfaces via write() errno instead of killing us.
|
||||
signal(SIGPIPE, SIG_IGN);
|
||||
|
||||
for (int i = 0; i < num_frames; i++) {
|
||||
if (!frames[i].data) continue;
|
||||
size_t frame_bytes = (size_t)frames[i].width * frames[i].height * frames[i].channel;
|
||||
const uint8_t* p = frames[i].data;
|
||||
size_t remaining = frame_bytes;
|
||||
while (remaining > 0) {
|
||||
ssize_t n = write(pipefd[1], p, remaining);
|
||||
if (n < 0) {
|
||||
if (errno == EINTR) continue;
|
||||
perror("write frame to ffmpeg");
|
||||
close(pipefd[1]);
|
||||
int status;
|
||||
waitpid(pid, &status, 0);
|
||||
return 1;
|
||||
}
|
||||
p += n;
|
||||
remaining -= (size_t)n;
|
||||
}
|
||||
}
|
||||
close(pipefd[1]);
|
||||
|
||||
int status = 0;
|
||||
while (waitpid(pid, &status, 0) < 0) {
|
||||
if (errno != EINTR) { perror("waitpid"); return 1; }
|
||||
}
|
||||
if (!WIFEXITED(status) || WEXITSTATUS(status) != 0) {
|
||||
fprintf(stderr, "ffmpeg exited with status %d\n", status);
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int fps, char *init_image, char *end_image) {
|
||||
if (!p) return 1;
|
||||
if (!dst || strlen(dst) == 0) {
|
||||
fprintf(stderr, "gen_video: dst is empty\n");
|
||||
std::free(p);
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::vector<int> skip_layers = {7, 8, 9};
|
||||
|
||||
fprintf(stderr, "Generating video: %dx%d, frames=%d, fps=%d, steps=%d, cfg=%.2f\n",
|
||||
p->width, p->height, p->video_frames, fps, steps, cfg_scale);
|
||||
|
||||
// Sample params (shared by both low and high-noise passes — MoE models use the high-noise
|
||||
// set during the first phase; single-model Wan2.1 ignores it. Same defaults for both is fine.)
|
||||
p->sample_params.guidance.txt_cfg = cfg_scale;
|
||||
p->sample_params.guidance.slg.layers = skip_layers.data();
|
||||
p->sample_params.guidance.slg.layer_count = skip_layers.size();
|
||||
p->sample_params.sample_method = sample_method;
|
||||
p->sample_params.sample_steps = steps;
|
||||
p->sample_params.scheduler = scheduler;
|
||||
p->sample_params.flow_shift = flow_shift;
|
||||
|
||||
p->high_noise_sample_params.guidance.txt_cfg = cfg_scale;
|
||||
p->high_noise_sample_params.guidance.slg.layers = skip_layers.data();
|
||||
p->high_noise_sample_params.guidance.slg.layer_count = skip_layers.size();
|
||||
p->high_noise_sample_params.sample_method = sample_method;
|
||||
p->high_noise_sample_params.sample_steps = steps;
|
||||
p->high_noise_sample_params.scheduler = scheduler;
|
||||
p->high_noise_sample_params.flow_shift = flow_shift;
|
||||
|
||||
// Load init/end reference images if provided (resized to output dims).
|
||||
uint8_t* init_buf = nullptr;
|
||||
uint8_t* end_buf = nullptr;
|
||||
sd_image_t init_img = {0, 0, 0, nullptr};
|
||||
sd_image_t end_img = {0, 0, 0, nullptr};
|
||||
if (init_image && strlen(init_image) > 0) {
|
||||
init_buf = load_and_resize_image(init_image, p->width, p->height, &init_img);
|
||||
if (!init_buf) { std::free(p); return 1; }
|
||||
}
|
||||
if (end_image && strlen(end_image) > 0) {
|
||||
end_buf = load_and_resize_image(end_image, p->width, p->height, &end_img);
|
||||
if (!end_buf) { if (init_buf) free(init_buf); std::free(p); return 1; }
|
||||
}
|
||||
p->init_image = init_img;
|
||||
p->end_image = end_img;
|
||||
|
||||
// Generate
|
||||
int num_frames_out = 0;
|
||||
sd_image_t* frames = generate_video(sd_c, p, &num_frames_out);
|
||||
std::free(p);
|
||||
|
||||
if (!frames || num_frames_out == 0) {
|
||||
fprintf(stderr, "generate_video produced no frames\n");
|
||||
if (init_buf) free(init_buf);
|
||||
if (end_buf) free(end_buf);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "Generated %d frames, muxing to %s via ffmpeg\n", num_frames_out, dst);
|
||||
|
||||
int rc = ffmpeg_mux_raw_to_mp4(frames, num_frames_out, fps, dst);
|
||||
|
||||
for (int i = 0; i < num_frames_out; i++) {
|
||||
if (frames[i].data) free(frames[i].data);
|
||||
}
|
||||
free(frames);
|
||||
if (init_buf) free(init_buf);
|
||||
if (end_buf) free(end_buf);
|
||||
|
||||
if (rc == 0) {
|
||||
fprintf(stderr, "gen_video done: %s\n", dst);
|
||||
}
|
||||
fflush(stderr);
|
||||
return rc;
|
||||
}
|
||||
|
||||
int unload() {
|
||||
free_sd_ctx(sd_c);
|
||||
return 0;
|
||||
|
||||
@@ -23,6 +23,7 @@ type SDGGML struct {
|
||||
var (
|
||||
LoadModel func(model, model_apth string, options []uintptr, threads int32, diff int) int
|
||||
GenImage func(params uintptr, steps int, dst string, cfgScale float32, srcImage string, strength float32, maskImage string, refImages []uintptr, refImagesCount int) int
|
||||
GenVideo func(params uintptr, steps int, dst string, cfgScale float32, fps int, initImage string, endImage string) int
|
||||
|
||||
TilingParamsSetEnabled func(params uintptr, enabled bool)
|
||||
TilingParamsSetTileSizes func(params uintptr, tileSizeX int, tileSizeY int)
|
||||
@@ -34,6 +35,12 @@ var (
|
||||
ImgGenParamsSetDimensions func(params uintptr, width int, height int)
|
||||
ImgGenParamsSetSeed func(params uintptr, seed int64)
|
||||
ImgGenParamsGetVaeTilingParams func(params uintptr) uintptr
|
||||
|
||||
VidGenParamsNew func() uintptr
|
||||
VidGenParamsSetPrompts func(params uintptr, prompt string, negativePrompt string)
|
||||
VidGenParamsSetDimensions func(params uintptr, width int, height int)
|
||||
VidGenParamsSetSeed func(params uintptr, seed int64)
|
||||
VidGenParamsSetVideoFrames func(params uintptr, n int)
|
||||
)
|
||||
|
||||
// Copied from Purego internal/strings
|
||||
@@ -153,3 +160,58 @@ func (sd *SDGGML) GenerateImage(opts *pb.GenerateImageRequest) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sd *SDGGML) GenerateVideo(opts *pb.GenerateVideoRequest) error {
|
||||
dst := opts.Dst
|
||||
if dst == "" {
|
||||
return fmt.Errorf("dst is empty")
|
||||
}
|
||||
|
||||
width := int(opts.Width)
|
||||
height := int(opts.Height)
|
||||
if width == 0 {
|
||||
width = 512
|
||||
}
|
||||
if height == 0 {
|
||||
height = 512
|
||||
}
|
||||
|
||||
numFrames := int(opts.NumFrames)
|
||||
if numFrames <= 0 {
|
||||
numFrames = 16
|
||||
}
|
||||
|
||||
fps := int(opts.Fps)
|
||||
if fps <= 0 {
|
||||
fps = 16
|
||||
}
|
||||
|
||||
steps := int(opts.Step)
|
||||
if steps <= 0 {
|
||||
steps = 20
|
||||
}
|
||||
|
||||
cfg := opts.CfgScale
|
||||
if cfg == 0 {
|
||||
cfg = sd.cfgScale
|
||||
}
|
||||
if cfg == 0 {
|
||||
cfg = 5.0
|
||||
}
|
||||
|
||||
// sd_vid_gen_params_new allocates; gen_video frees it after the generation call.
|
||||
p := VidGenParamsNew()
|
||||
VidGenParamsSetPrompts(p, opts.Prompt, opts.NegativePrompt)
|
||||
VidGenParamsSetDimensions(p, width, height)
|
||||
VidGenParamsSetSeed(p, int64(opts.Seed))
|
||||
VidGenParamsSetVideoFrames(p, numFrames)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "GenerateVideo: dst=%s size=%dx%d frames=%d fps=%d steps=%d cfg=%.2f\n",
|
||||
dst, width, height, numFrames, fps, steps, cfg)
|
||||
|
||||
ret := GenVideo(p, steps, dst, cfg, fps, opts.StartImage, opts.EndImage)
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("video inference failed (code %d)", ret)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,6 +18,13 @@ void sd_img_gen_params_set_seed(sd_img_gen_params_t *params, int64_t seed);
|
||||
|
||||
int load_model(const char *model, char *model_path, char* options[], int threads, int diffusionModel);
|
||||
int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char* ref_images[], int ref_images_count);
|
||||
|
||||
sd_vid_gen_params_t* sd_vid_gen_params_new(void);
|
||||
void sd_vid_gen_params_set_prompts(sd_vid_gen_params_t *params, const char *prompt, const char *negative_prompt);
|
||||
void sd_vid_gen_params_set_dimensions(sd_vid_gen_params_t *params, int width, int height);
|
||||
void sd_vid_gen_params_set_seed(sd_vid_gen_params_t *params, int64_t seed);
|
||||
void sd_vid_gen_params_set_video_frames(sd_vid_gen_params_t *params, int n);
|
||||
int gen_video(sd_vid_gen_params_t *p, int steps, char *dst, float cfg_scale, int fps, char *init_image, char *end_image);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -32,6 +32,7 @@ func main() {
|
||||
libFuncs := []LibFuncs{
|
||||
{&LoadModel, "load_model"},
|
||||
{&GenImage, "gen_image"},
|
||||
{&GenVideo, "gen_video"},
|
||||
{&TilingParamsSetEnabled, "sd_tiling_params_set_enabled"},
|
||||
{&TilingParamsSetTileSizes, "sd_tiling_params_set_tile_sizes"},
|
||||
{&TilingParamsSetRelSizes, "sd_tiling_params_set_rel_sizes"},
|
||||
@@ -42,6 +43,12 @@ func main() {
|
||||
{&ImgGenParamsSetDimensions, "sd_img_gen_params_set_dimensions"},
|
||||
{&ImgGenParamsSetSeed, "sd_img_gen_params_set_seed"},
|
||||
{&ImgGenParamsGetVaeTilingParams, "sd_img_gen_params_get_vae_tiling_params"},
|
||||
|
||||
{&VidGenParamsNew, "sd_vid_gen_params_new"},
|
||||
{&VidGenParamsSetPrompts, "sd_vid_gen_params_set_prompts"},
|
||||
{&VidGenParamsSetDimensions, "sd_vid_gen_params_set_dimensions"},
|
||||
{&VidGenParamsSetSeed, "sd_vid_gen_params_set_seed"},
|
||||
{&VidGenParamsSetVideoFrames, "sd_vid_gen_params_set_video_frames"},
|
||||
}
|
||||
|
||||
for _, lf := range libFuncs {
|
||||
|
||||
@@ -56,5 +56,6 @@ func (v *Voxtral) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
|
||||
return pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: text,
|
||||
Language: opts.Language,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||
WHISPER_CPP_VERSION?=95ea8f9bfb03a15db08a8989966fd1ae3361e20d
|
||||
WHISPER_CPP_VERSION?=fc674574ca27cac59a15e5b22a09b9d9ad62aafe
|
||||
SO_TARGET?=libgowhisper.so
|
||||
|
||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
@@ -120,6 +120,12 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
|
||||
}
|
||||
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
// whisper.cpp resamples to 16 kHz internally; this matches buf.Format.SampleRate
|
||||
// for the converted file produced by AudioToWav above.
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
}
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
|
||||
@@ -158,5 +164,7 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
|
||||
return pb.TranscriptResult{
|
||||
Segments: segments,
|
||||
Text: strings.TrimSpace(text),
|
||||
Language: opts.Language,
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -43,6 +43,35 @@
|
||||
- CPU
|
||||
capabilities:
|
||||
default: "cpu-ik-llama-cpp"
|
||||
- &turboquant
|
||||
name: "turboquant"
|
||||
alias: "turboquant"
|
||||
license: mit
|
||||
description: |
|
||||
Fork of llama.cpp adding the TurboQuant KV-cache quantization scheme.
|
||||
Reuses the LocalAI llama.cpp gRPC server sources against the fork's libllama.
|
||||
urls:
|
||||
- https://github.com/TheTom/llama-cpp-turboquant
|
||||
tags:
|
||||
- text-to-text
|
||||
- LLM
|
||||
- CPU
|
||||
- GPU
|
||||
- CUDA
|
||||
- HIP
|
||||
- turboquant
|
||||
- kv-cache
|
||||
capabilities:
|
||||
default: "cpu-turboquant"
|
||||
nvidia: "cuda12-turboquant"
|
||||
intel: "intel-sycl-f16-turboquant"
|
||||
amd: "rocm-turboquant"
|
||||
vulkan: "vulkan-turboquant"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-turboquant"
|
||||
nvidia-cuda-13: "cuda13-turboquant"
|
||||
nvidia-cuda-12: "cuda12-turboquant"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-turboquant"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-turboquant"
|
||||
- &whispercpp
|
||||
name: "whisper"
|
||||
alias: "whisper"
|
||||
@@ -197,6 +226,29 @@
|
||||
amd: "rocm-vllm"
|
||||
intel: "intel-vllm"
|
||||
nvidia-cuda-12: "cuda12-vllm"
|
||||
cpu: "cpu-vllm"
|
||||
- &sglang
|
||||
name: "sglang"
|
||||
license: apache-2.0
|
||||
urls:
|
||||
- https://github.com/sgl-project/sglang
|
||||
tags:
|
||||
- text-to-text
|
||||
- multimodal
|
||||
icon: https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png
|
||||
description: |
|
||||
SGLang is a fast serving framework for large language models and vision language models.
|
||||
It co-designs the backend runtime (RadixAttention, continuous batching, structured
|
||||
decoding) and the frontend language to make interaction with models faster and more
|
||||
controllable. Features include fast backend runtime, flexible frontend language,
|
||||
extensive model support, and an active community.
|
||||
alias: "sglang"
|
||||
capabilities:
|
||||
nvidia: "cuda12-sglang"
|
||||
amd: "rocm-sglang"
|
||||
intel: "intel-sglang"
|
||||
nvidia-cuda-12: "cuda12-sglang"
|
||||
cpu: "cpu-sglang"
|
||||
- &vllm-omni
|
||||
name: "vllm-omni"
|
||||
license: apache-2.0
|
||||
@@ -331,6 +383,34 @@
|
||||
intel: "intel-rerankers"
|
||||
amd: "rocm-rerankers"
|
||||
metal: "metal-rerankers"
|
||||
- &tinygrad
|
||||
name: "tinygrad"
|
||||
alias: "tinygrad"
|
||||
license: MIT
|
||||
description: |
|
||||
tinygrad is a minimalist deep-learning framework with zero runtime
|
||||
dependencies that targets CUDA, ROCm, Metal, WebGPU and CPU (CLANG).
|
||||
The LocalAI tinygrad backend exposes a single multimodal runtime that
|
||||
covers LLM text generation (Llama / Qwen / Mistral via safetensors or
|
||||
GGUF) with native tool-call extraction, BERT-family embeddings,
|
||||
Stable Diffusion 1.x / 2 / XL image generation, and Whisper speech-to-text.
|
||||
|
||||
Single image: tinygrad generates its own GPU kernels and dlopens the
|
||||
host driver libraries at runtime, so there is no per-toolkit build
|
||||
split. The same image runs CPU-only or accelerates against
|
||||
CUDA / ROCm / Metal when the host driver is visible.
|
||||
urls:
|
||||
- https://github.com/tinygrad/tinygrad
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-tinygrad"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-tinygrad
|
||||
tags:
|
||||
- text-to-text
|
||||
- LLM
|
||||
- embeddings
|
||||
- image-generation
|
||||
- transcription
|
||||
- multimodal
|
||||
- &transformers
|
||||
name: "transformers"
|
||||
icon: https://avatars.githubusercontent.com/u/25720743?s=200&v=4
|
||||
@@ -507,7 +587,6 @@
|
||||
alias: "whisperx"
|
||||
capabilities:
|
||||
nvidia: "cuda12-whisperx"
|
||||
amd: "rocm-whisperx"
|
||||
metal: "metal-whisperx"
|
||||
default: "cpu-whisperx"
|
||||
nvidia-cuda-13: "cuda13-whisperx"
|
||||
@@ -915,6 +994,33 @@
|
||||
name: "ik-llama-cpp-development"
|
||||
capabilities:
|
||||
default: "cpu-ik-llama-cpp-development"
|
||||
- !!merge <<: *turboquant
|
||||
name: "turboquant-development"
|
||||
capabilities:
|
||||
default: "cpu-turboquant-development"
|
||||
nvidia: "cuda12-turboquant-development"
|
||||
intel: "intel-sycl-f16-turboquant-development"
|
||||
amd: "rocm-turboquant-development"
|
||||
vulkan: "vulkan-turboquant-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-turboquant-development"
|
||||
nvidia-cuda-13: "cuda13-turboquant-development"
|
||||
nvidia-cuda-12: "cuda12-turboquant-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-turboquant-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-turboquant-development"
|
||||
- !!merge <<: *stablediffusionggml
|
||||
name: "stablediffusion-ggml-development"
|
||||
capabilities:
|
||||
default: "cpu-stablediffusion-ggml-development"
|
||||
nvidia: "cuda12-stablediffusion-ggml-development"
|
||||
intel: "intel-sycl-f16-stablediffusion-ggml-development"
|
||||
# amd: "rocm-stablediffusion-ggml-development"
|
||||
vulkan: "vulkan-stablediffusion-ggml-development"
|
||||
nvidia-l4t: "nvidia-l4t-arm64-stablediffusion-ggml-development"
|
||||
metal: "metal-stablediffusion-ggml-development"
|
||||
nvidia-cuda-13: "cuda13-stablediffusion-ggml-development"
|
||||
nvidia-cuda-12: "cuda12-stablediffusion-ggml-development"
|
||||
nvidia-l4t-cuda-12: "nvidia-l4t-arm64-stablediffusion-ggml-development"
|
||||
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-stablediffusion-ggml-development"
|
||||
- !!merge <<: *neutts
|
||||
name: "cpu-neutts"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-neutts"
|
||||
@@ -1356,6 +1462,97 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-ik-llama-cpp"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-ik-llama-cpp
|
||||
## turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "cpu-turboquant"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "cpu-turboquant-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "cuda12-turboquant"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "cuda12-turboquant-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "cuda13-turboquant"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-13-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "cuda13-turboquant-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-13-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "rocm-turboquant"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "rocm-turboquant-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "intel-sycl-f32-turboquant"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f32-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "intel-sycl-f32-turboquant-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f32-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "intel-sycl-f16-turboquant"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sycl-f16-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "intel-sycl-f16-turboquant-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sycl-f16-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "vulkan-turboquant"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-vulkan-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "vulkan-turboquant-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-vulkan-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "nvidia-l4t-arm64-turboquant"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-arm64-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "nvidia-l4t-arm64-turboquant-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-arm64-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "cuda13-nvidia-l4t-arm64-turboquant"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-turboquant
|
||||
- !!merge <<: *turboquant
|
||||
name: "cuda13-nvidia-l4t-arm64-turboquant-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-turboquant"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-turboquant
|
||||
## whisper
|
||||
- !!merge <<: *whispercpp
|
||||
name: "nvidia-l4t-arm64-whisper"
|
||||
@@ -1563,6 +1760,7 @@
|
||||
nvidia: "cuda12-vllm-development"
|
||||
amd: "rocm-vllm-development"
|
||||
intel: "intel-vllm-development"
|
||||
cpu: "cpu-vllm-development"
|
||||
- !!merge <<: *vllm
|
||||
name: "cuda12-vllm"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-vllm"
|
||||
@@ -1578,6 +1776,11 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-vllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-vllm
|
||||
- !!merge <<: *vllm
|
||||
name: "cpu-vllm"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-vllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-vllm
|
||||
- !!merge <<: *vllm
|
||||
name: "cuda12-vllm-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-vllm"
|
||||
@@ -1593,6 +1796,59 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-vllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-vllm
|
||||
- !!merge <<: *vllm
|
||||
name: "cpu-vllm-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-vllm"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-vllm
|
||||
# sglang
|
||||
- !!merge <<: *sglang
|
||||
name: "sglang-development"
|
||||
capabilities:
|
||||
nvidia: "cuda12-sglang-development"
|
||||
amd: "rocm-sglang-development"
|
||||
intel: "intel-sglang-development"
|
||||
cpu: "cpu-sglang-development"
|
||||
- !!merge <<: *sglang
|
||||
name: "cuda12-sglang"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-sglang"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-nvidia-cuda-12-sglang
|
||||
- !!merge <<: *sglang
|
||||
name: "rocm-sglang"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-sglang"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-sglang
|
||||
- !!merge <<: *sglang
|
||||
name: "intel-sglang"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sglang"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-intel-sglang
|
||||
- !!merge <<: *sglang
|
||||
name: "cpu-sglang"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-sglang"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-sglang
|
||||
- !!merge <<: *sglang
|
||||
name: "cuda12-sglang-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-sglang"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-sglang
|
||||
- !!merge <<: *sglang
|
||||
name: "rocm-sglang-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-sglang"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-sglang
|
||||
- !!merge <<: *sglang
|
||||
name: "intel-sglang-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sglang"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-intel-sglang
|
||||
- !!merge <<: *sglang
|
||||
name: "cpu-sglang-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-sglang"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cpu-sglang
|
||||
# vllm-omni
|
||||
- !!merge <<: *vllm-omni
|
||||
name: "vllm-omni-development"
|
||||
@@ -1848,6 +2104,15 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-rerankers"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-rerankers
|
||||
## tinygrad
|
||||
## Single image — the meta anchor above carries the latest uri directly
|
||||
## since there is only one variant. The development entry below points at
|
||||
## the master tag.
|
||||
- !!merge <<: *tinygrad
|
||||
name: "tinygrad-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-tinygrad"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-tinygrad
|
||||
## Transformers
|
||||
- !!merge <<: *transformers
|
||||
name: "transformers-development"
|
||||
@@ -2479,7 +2744,6 @@
|
||||
name: "whisperx-development"
|
||||
capabilities:
|
||||
nvidia: "cuda12-whisperx-development"
|
||||
amd: "rocm-whisperx-development"
|
||||
metal: "metal-whisperx-development"
|
||||
default: "cpu-whisperx-development"
|
||||
nvidia-cuda-13: "cuda13-whisperx-development"
|
||||
@@ -2505,16 +2769,6 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-whisperx"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-nvidia-cuda-12-whisperx
|
||||
- !!merge <<: *whisperx
|
||||
name: "rocm-whisperx"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-whisperx"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-gpu-rocm-hipblas-whisperx
|
||||
- !!merge <<: *whisperx
|
||||
name: "rocm-whisperx-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-whisperx"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-gpu-rocm-hipblas-whisperx
|
||||
- !!merge <<: *whisperx
|
||||
name: "cuda13-whisperx"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-whisperx"
|
||||
|
||||
@@ -344,7 +344,16 @@ function ensureVenv() {
|
||||
|
||||
if [ ! -d "${EDIR}/venv" ]; then
|
||||
if [ "x${USE_PIP}" == "xtrue" ]; then
|
||||
"${interpreter}" -m venv --copies "${EDIR}/venv"
|
||||
# --copies is only needed when we will later relocate the venv via
|
||||
# _makeVenvPortable (PORTABLE_PYTHON=true). Some Python builds —
|
||||
# notably macOS system Python — refuse to create a venv with
|
||||
# --copies because the build doesn't support it. Fall back to
|
||||
# symlinks in that case.
|
||||
local venv_args=""
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
||||
venv_args="--copies"
|
||||
fi
|
||||
"${interpreter}" -m venv ${venv_args} "${EDIR}/venv"
|
||||
source "${EDIR}/venv/bin/activate"
|
||||
"${interpreter}" -m pip install --upgrade pip
|
||||
else
|
||||
|
||||
100
backend/python/common/mlx_utils.py
Normal file
100
backend/python/common/mlx_utils.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Shared utilities for the mlx and mlx-vlm gRPC backends.
|
||||
|
||||
These helpers wrap mlx-lm's and mlx-vlm's native tool-parser modules, which
|
||||
auto-detect the right parser from the model's chat template. Each tool
|
||||
module exposes ``tool_call_start``, ``tool_call_end`` and
|
||||
``parse_tool_call(text, tools) -> dict | list[dict]``.
|
||||
|
||||
The split-reasoning helper is generic enough to work with any think-start /
|
||||
think-end delimiter pair.
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
|
||||
def split_reasoning(text, think_start, think_end):
|
||||
"""Split ``<think>...</think>`` blocks out of ``text``.
|
||||
|
||||
Returns ``(reasoning_content, remaining_text)``. When ``think_start`` is
|
||||
empty or not found, returns ``("", text)`` unchanged.
|
||||
"""
|
||||
if not think_start or not text or think_start not in text:
|
||||
return "", text
|
||||
pattern = re.compile(
|
||||
re.escape(think_start) + r"(.*?)" + re.escape(think_end or ""),
|
||||
re.DOTALL,
|
||||
)
|
||||
reasoning_parts = pattern.findall(text)
|
||||
if not reasoning_parts:
|
||||
return "", text
|
||||
remaining = pattern.sub("", text).strip()
|
||||
return "\n".join(p.strip() for p in reasoning_parts), remaining
|
||||
|
||||
|
||||
def parse_tool_calls(text, tool_module, tools):
|
||||
"""Extract tool calls from ``text`` using a mlx-lm tool module.
|
||||
|
||||
Ports the ``process_tool_calls`` logic from
|
||||
``mlx_vlm/server.py`` (v0.10 onwards). ``tool_module`` must expose
|
||||
``tool_call_start``, ``tool_call_end`` and ``parse_tool_call``.
|
||||
|
||||
Returns ``(calls, remaining_text)`` where ``calls`` is a list of dicts:
|
||||
|
||||
[{"index": int, "id": str, "name": str, "arguments": str (JSON)}]
|
||||
|
||||
and ``remaining_text`` is the free-form text with the tool call blocks
|
||||
removed. ``(calls, text)`` is returned unchanged if ``tool_module`` is
|
||||
``None`` or the start delimiter isn't present.
|
||||
"""
|
||||
if tool_module is None or not text:
|
||||
return [], text
|
||||
start = getattr(tool_module, "tool_call_start", None)
|
||||
end = getattr(tool_module, "tool_call_end", None)
|
||||
parse_fn = getattr(tool_module, "parse_tool_call", None)
|
||||
if not start or parse_fn is None or start not in text:
|
||||
return [], text
|
||||
|
||||
if end == "" or end is None:
|
||||
pattern = re.compile(
|
||||
re.escape(start) + r".*?(?:\n|$)",
|
||||
re.DOTALL,
|
||||
)
|
||||
else:
|
||||
pattern = re.compile(
|
||||
re.escape(start) + r".*?" + re.escape(end),
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
matches = pattern.findall(text)
|
||||
if not matches:
|
||||
return [], text
|
||||
|
||||
remaining = pattern.sub(" ", text).strip()
|
||||
calls = []
|
||||
for match in matches:
|
||||
call_body = match.strip().removeprefix(start)
|
||||
if end:
|
||||
call_body = call_body.removesuffix(end)
|
||||
call_body = call_body.strip()
|
||||
try:
|
||||
parsed = parse_fn(call_body, tools)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[mlx_utils] Invalid tool call: {call_body!r} ({e})",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
if not isinstance(parsed, list):
|
||||
parsed = [parsed]
|
||||
for tc in parsed:
|
||||
calls.append(
|
||||
{
|
||||
"index": len(calls),
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": (tc.get("name") or "").strip(),
|
||||
"arguments": json.dumps(tc.get("arguments", {}), ensure_ascii=False),
|
||||
}
|
||||
)
|
||||
return calls, remaining
|
||||
65
backend/python/common/python_utils.py
Normal file
65
backend/python/common/python_utils.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Generic utilities shared across Python gRPC backends.
|
||||
|
||||
These helpers don't depend on any specific inference framework and can be
|
||||
imported by any backend that needs to parse LocalAI gRPC options or build a
|
||||
chat-template-compatible message list from proto Message objects.
|
||||
"""
|
||||
import json
|
||||
|
||||
|
||||
def parse_options(options_list):
|
||||
"""Parse Options[] list of ``key:value`` strings into a dict.
|
||||
|
||||
Supports type inference for common cases (bool, int, float). Unknown or
|
||||
mixed-case values are returned as strings.
|
||||
|
||||
Used by LoadModel to extract backend-specific options passed via
|
||||
``ModelOptions.Options`` in ``backend.proto``.
|
||||
"""
|
||||
opts = {}
|
||||
for opt in options_list:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
# Try type conversion
|
||||
if value.lower() in ("true", "false"):
|
||||
opts[key] = value.lower() == "true"
|
||||
else:
|
||||
try:
|
||||
opts[key] = int(value)
|
||||
except ValueError:
|
||||
try:
|
||||
opts[key] = float(value)
|
||||
except ValueError:
|
||||
opts[key] = value
|
||||
return opts
|
||||
|
||||
|
||||
def messages_to_dicts(proto_messages):
|
||||
"""Convert proto ``Message`` objects to dicts suitable for ``apply_chat_template``.
|
||||
|
||||
Handles: ``role``, ``content``, ``name``, ``tool_call_id``,
|
||||
``reasoning_content``, ``tool_calls`` (JSON string → Python list).
|
||||
|
||||
HuggingFace chat templates (and their MLX/vLLM wrappers) expect a list of
|
||||
plain dicts — proto Message objects don't work directly with Jinja, so
|
||||
this conversion is needed before every ``apply_chat_template`` call.
|
||||
"""
|
||||
result = []
|
||||
for msg in proto_messages:
|
||||
d = {"role": msg.role, "content": msg.content or ""}
|
||||
if msg.name:
|
||||
d["name"] = msg.name
|
||||
if msg.tool_call_id:
|
||||
d["tool_call_id"] = msg.tool_call_id
|
||||
if msg.reasoning_content:
|
||||
d["reasoning_content"] = msg.reasoning_content
|
||||
if msg.tool_calls:
|
||||
try:
|
||||
d["tool_calls"] = json.loads(msg.tool_calls)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
result.append(d)
|
||||
return result
|
||||
43
backend/python/common/vllm_utils.py
Normal file
43
backend/python/common/vllm_utils.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""vLLM-specific helpers for the vllm and vllm-omni gRPC backends.
|
||||
|
||||
Generic helpers (``parse_options``, ``messages_to_dicts``) live in
|
||||
``python_utils`` and are re-exported here for backwards compatibility with
|
||||
existing imports in both backends.
|
||||
"""
|
||||
import sys
|
||||
|
||||
from python_utils import messages_to_dicts, parse_options
|
||||
|
||||
__all__ = ["parse_options", "messages_to_dicts", "setup_parsers"]
|
||||
|
||||
|
||||
def setup_parsers(opts):
|
||||
"""Return ``(tool_parser_cls, reasoning_parser_cls)`` from an opts dict.
|
||||
|
||||
Uses vLLM's native ``ToolParserManager`` / ``ReasoningParserManager``.
|
||||
Returns ``(None, None)`` if vLLM isn't installed or the requested
|
||||
parser name can't be resolved.
|
||||
"""
|
||||
tool_parser_cls = None
|
||||
reasoning_parser_cls = None
|
||||
|
||||
tool_parser_name = opts.get("tool_parser")
|
||||
reasoning_parser_name = opts.get("reasoning_parser")
|
||||
|
||||
if tool_parser_name:
|
||||
try:
|
||||
from vllm.tool_parsers import ToolParserManager
|
||||
tool_parser_cls = ToolParserManager.get_tool_parser(tool_parser_name)
|
||||
print(f"[vllm_utils] Loaded tool_parser: {tool_parser_name}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"[vllm_utils] Failed to load tool_parser {tool_parser_name}: {e}", file=sys.stderr)
|
||||
|
||||
if reasoning_parser_name:
|
||||
try:
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
reasoning_parser_cls = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
|
||||
print(f"[vllm_utils] Loaded reasoning_parser: {reasoning_parser_name}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"[vllm_utils] Failed to load reasoning_parser {reasoning_parser_name}: {e}", file=sys.stderr)
|
||||
|
||||
return tool_parser_cls, reasoning_parser_cls
|
||||
@@ -15,17 +15,21 @@ Two startup modes:
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
from typing import List
|
||||
|
||||
import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
from python_utils import messages_to_dicts, parse_options as _shared_parse_options
|
||||
from mlx_utils import parse_tool_calls, split_reasoning
|
||||
|
||||
|
||||
import backend_pb2
|
||||
@@ -62,37 +66,10 @@ def mlx_distributed_init(rank, hostfile, backend="ring", coordinator=None):
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
|
||||
|
||||
def is_float(s):
|
||||
try:
|
||||
float(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def is_int(s):
|
||||
try:
|
||||
int(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def parse_options(options):
|
||||
"""Parse key:value option strings into a dict."""
|
||||
result = {}
|
||||
for opt in options:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1)
|
||||
if is_float(value):
|
||||
value = float(value)
|
||||
elif is_int(value):
|
||||
value = int(value)
|
||||
elif value.lower() in ["true", "false"]:
|
||||
value = value.lower() == "true"
|
||||
result[key] = value
|
||||
return result
|
||||
# Re-export the shared helper under the local name for back-compat with
|
||||
# any callers (and the existing distributed worker tests) that imported
|
||||
# parse_options directly from this module.
|
||||
parse_options = _shared_parse_options
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
@@ -188,6 +165,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
)
|
||||
print("[Rank 0] Model loaded (single-node with prompt cache)", file=sys.stderr)
|
||||
|
||||
# Log auto-detected TokenizerWrapper capabilities. Same shape
|
||||
# as the mlx backend: has_tool_calling / has_thinking from
|
||||
# mlx_lm.tokenizer_utils + the start/end markers it sniffed
|
||||
# from the chat template / vocab.
|
||||
has_tools = bool(getattr(self.tokenizer, "has_tool_calling", False))
|
||||
has_thinking = bool(getattr(self.tokenizer, "has_thinking", False))
|
||||
tcs = getattr(self.tokenizer, "tool_call_start", None)
|
||||
tce = getattr(self.tokenizer, "tool_call_end", None)
|
||||
print(
|
||||
f"[Rank 0] Tokenizer capabilities: has_tool_calling={has_tools} "
|
||||
f"has_thinking={has_thinking} tool_call_start={tcs!r} tool_call_end={tce!r}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
print(f"[Rank 0] Error loading model: {err}", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"Error loading model: {err}")
|
||||
@@ -201,7 +192,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||
|
||||
prompt_text = self._prepare_prompt(request)
|
||||
tokens = self._get_tokens_from_prompt(prompt_text)
|
||||
@@ -211,7 +202,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
|
||||
self.coordinator.broadcast_tokens(tokens)
|
||||
|
||||
max_tokens, sampler_params = self._build_generation_params(request)
|
||||
max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request)
|
||||
|
||||
if self.coordinator:
|
||||
gen_params = self.coordinator.broadcast_generation_params(
|
||||
@@ -222,6 +213,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
max_tokens = gen_params["max_tokens"]
|
||||
|
||||
sampler = make_sampler(**sampler_params)
|
||||
logits_processors = make_logits_processors(**logits_params) if logits_params else None
|
||||
|
||||
# Use prompt cache in single-node mode
|
||||
gen_kwargs = {}
|
||||
@@ -238,22 +230,44 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
tokens = remaining_tokens if remaining_tokens else cache_key
|
||||
|
||||
generated = []
|
||||
last_response = None
|
||||
for response in stream_generate(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
prompt=tokens,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
**gen_kwargs,
|
||||
):
|
||||
generated.append(response.text)
|
||||
last_response = response
|
||||
if cache_key is not None:
|
||||
cache_key.append(response.token)
|
||||
if stop_words and any(s in "".join(generated) for s in stop_words):
|
||||
break
|
||||
|
||||
if self.lru_cache is not None and cache_key is not None:
|
||||
self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
|
||||
|
||||
return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8'))
|
||||
full_text = self._truncate_at_stop("".join(generated), stop_words)
|
||||
content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
|
||||
self._finalize_output(request, full_text, last_response)
|
||||
)
|
||||
|
||||
return backend_pb2.Reply(
|
||||
message=bytes(content, encoding='utf-8'),
|
||||
prompt_tokens=prompt_tokens,
|
||||
tokens=completion_tokens,
|
||||
logprobs=logprobs_bytes,
|
||||
chat_deltas=[
|
||||
backend_pb2.ChatDelta(
|
||||
content=content,
|
||||
reasoning_content=reasoning_content,
|
||||
tool_calls=tool_calls_proto,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Rank 0] Error in Predict: {e}", file=sys.stderr)
|
||||
@@ -268,7 +282,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||
|
||||
prompt_text = self._prepare_prompt(request)
|
||||
tokens = self._get_tokens_from_prompt(prompt_text)
|
||||
@@ -278,7 +292,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
|
||||
self.coordinator.broadcast_tokens(tokens)
|
||||
|
||||
max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)
|
||||
max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(
|
||||
request, default_max_tokens=512
|
||||
)
|
||||
|
||||
if self.coordinator:
|
||||
gen_params = self.coordinator.broadcast_generation_params(
|
||||
@@ -289,6 +305,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
max_tokens = gen_params["max_tokens"]
|
||||
|
||||
sampler = make_sampler(**sampler_params)
|
||||
logits_processors = make_logits_processors(**logits_params) if logits_params else None
|
||||
|
||||
# Use prompt cache in single-node mode
|
||||
gen_kwargs = {}
|
||||
@@ -304,17 +321,45 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
gen_kwargs['prompt_cache'] = prompt_cache
|
||||
tokens = remaining_tokens if remaining_tokens else cache_key
|
||||
|
||||
accumulated = []
|
||||
last_response = None
|
||||
for response in stream_generate(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
prompt=tokens,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
**gen_kwargs,
|
||||
):
|
||||
if cache_key is not None:
|
||||
cache_key.append(response.token)
|
||||
yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
|
||||
accumulated.append(response.text)
|
||||
last_response = response
|
||||
yield backend_pb2.Reply(
|
||||
message=bytes(response.text, encoding='utf-8'),
|
||||
chat_deltas=[backend_pb2.ChatDelta(content=response.text)],
|
||||
)
|
||||
if stop_words and any(s in "".join(accumulated) for s in stop_words):
|
||||
break
|
||||
|
||||
full_text = self._truncate_at_stop("".join(accumulated), stop_words)
|
||||
content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
|
||||
self._finalize_output(request, full_text, last_response)
|
||||
)
|
||||
yield backend_pb2.Reply(
|
||||
message=b"",
|
||||
prompt_tokens=prompt_tokens,
|
||||
tokens=completion_tokens,
|
||||
logprobs=logprobs_bytes,
|
||||
chat_deltas=[
|
||||
backend_pb2.ChatDelta(
|
||||
content="",
|
||||
reasoning_content=reasoning_content,
|
||||
tool_calls=tool_calls_proto,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Rank 0] Error in PredictStream: {e}", file=sys.stderr)
|
||||
@@ -335,12 +380,74 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
context.set_details("Embeddings are not supported in the MLX distributed backend.")
|
||||
return backend_pb2.EmbeddingResult()
|
||||
|
||||
async def TokenizeString(self, request, context):
|
||||
if not hasattr(self, "tokenizer") or self.tokenizer is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("tokenizer not loaded")
|
||||
return backend_pb2.TokenizationResponse()
|
||||
try:
|
||||
tokens = self.tokenizer.encode(request.Prompt)
|
||||
if hasattr(tokens, "tolist"):
|
||||
tokens = tokens.tolist()
|
||||
tokens = list(tokens)
|
||||
return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
|
||||
except Exception as e:
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(str(e))
|
||||
return backend_pb2.TokenizationResponse()
|
||||
|
||||
async def Free(self, request, context):
|
||||
try:
|
||||
# If we're rank 0 of a distributed run, tell workers to shut
|
||||
# down their per-request loops first so they release the model.
|
||||
if self.coordinator is not None:
|
||||
try:
|
||||
from coordinator import CMD_SHUTDOWN
|
||||
self.coordinator.broadcast_command(CMD_SHUTDOWN)
|
||||
except Exception as e:
|
||||
print(f"[Rank 0] failed to broadcast shutdown: {e}", file=sys.stderr)
|
||||
if hasattr(self, "model"):
|
||||
del self.model
|
||||
if hasattr(self, "tokenizer"):
|
||||
del self.tokenizer
|
||||
if self.lru_cache is not None:
|
||||
try:
|
||||
self.lru_cache.clear()
|
||||
except Exception:
|
||||
pass
|
||||
self.lru_cache = None
|
||||
self.coordinator = None
|
||||
self.group = None
|
||||
gc.collect()
|
||||
try:
|
||||
import mlx.core as mx # type: ignore
|
||||
if hasattr(mx, "clear_cache"):
|
||||
mx.clear_cache()
|
||||
elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
|
||||
mx.metal.clear_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return backend_pb2.Result(success=True, message="MLX distributed model freed")
|
||||
except Exception as e:
|
||||
return backend_pb2.Result(success=False, message=str(e))
|
||||
|
||||
def _prepare_prompt(self, request):
|
||||
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
|
||||
messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages]
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
messages = messages_to_dicts(request.Messages)
|
||||
kwargs = {"tokenize": False, "add_generation_prompt": True}
|
||||
if request.Tools:
|
||||
try:
|
||||
kwargs["tools"] = json.loads(request.Tools)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if request.Metadata.get("enable_thinking", "").lower() == "true":
|
||||
kwargs["enable_thinking"] = True
|
||||
try:
|
||||
return self.tokenizer.apply_chat_template(messages, **kwargs)
|
||||
except TypeError:
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
return request.Prompt
|
||||
|
||||
def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
|
||||
@@ -349,6 +456,82 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return tokens.tolist()
|
||||
return list(tokens)
|
||||
|
||||
def _tool_module_from_tokenizer(self):
|
||||
"""Same shim as the mlx backend: fall back to json.loads when the
|
||||
installed mlx-lm doesn't expose a tool_parser callable on the
|
||||
wrapper (true on 0.29.x — only HEAD ships parsers)."""
|
||||
start = getattr(self.tokenizer, "tool_call_start", None)
|
||||
end = getattr(self.tokenizer, "tool_call_end", None)
|
||||
if not start:
|
||||
return None
|
||||
parse_fn = getattr(self.tokenizer, "tool_parser", None)
|
||||
if parse_fn is None:
|
||||
def parse_fn(body, tools): # noqa: E306
|
||||
return json.loads(body.strip())
|
||||
return types.SimpleNamespace(
|
||||
tool_call_start=start,
|
||||
tool_call_end=end or "",
|
||||
parse_tool_call=parse_fn,
|
||||
)
|
||||
|
||||
def _truncate_at_stop(self, text, stop_words):
|
||||
if not stop_words:
|
||||
return text
|
||||
earliest = len(text)
|
||||
for stop in stop_words:
|
||||
if not stop:
|
||||
continue
|
||||
idx = text.find(stop)
|
||||
if idx >= 0 and idx < earliest:
|
||||
earliest = idx
|
||||
return text[:earliest] if earliest < len(text) else text
|
||||
|
||||
def _finalize_output(self, request, generated_text, last_response):
|
||||
content = generated_text
|
||||
reasoning_content = ""
|
||||
if getattr(self.tokenizer, "has_thinking", False):
|
||||
think_start = getattr(self.tokenizer, "think_start", "") or ""
|
||||
think_end = getattr(self.tokenizer, "think_end", "") or ""
|
||||
reasoning_content, content = split_reasoning(content, think_start, think_end)
|
||||
|
||||
tool_calls_proto: List[backend_pb2.ToolCallDelta] = []
|
||||
tool_module = None
|
||||
if getattr(self.tokenizer, "has_tool_calling", False):
|
||||
tool_module = self._tool_module_from_tokenizer()
|
||||
if tool_module is not None:
|
||||
parsed_tools = None
|
||||
if request.Tools:
|
||||
try:
|
||||
parsed_tools = json.loads(request.Tools)
|
||||
except json.JSONDecodeError:
|
||||
parsed_tools = None
|
||||
calls, content = parse_tool_calls(content, tool_module, parsed_tools)
|
||||
for c in calls:
|
||||
tool_calls_proto.append(
|
||||
backend_pb2.ToolCallDelta(
|
||||
index=c["index"], id=c["id"], name=c["name"], arguments=c["arguments"],
|
||||
)
|
||||
)
|
||||
|
||||
prompt_token_count = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0
|
||||
completion_token_count = int(getattr(last_response, "generation_tokens", 0) or 0) if last_response else 0
|
||||
|
||||
logprobs_bytes = b""
|
||||
if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0:
|
||||
try:
|
||||
lp = getattr(last_response, "logprobs", None)
|
||||
if lp is not None:
|
||||
token_id = int(getattr(last_response, "token", 0) or 0)
|
||||
token_text = self.tokenizer.decode([token_id]) if token_id else ""
|
||||
top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0
|
||||
logprobs_bytes = json.dumps(
|
||||
{"content": [{"token": token_text, "logprob": top_logprob}]}
|
||||
).encode("utf-8")
|
||||
except Exception as e:
|
||||
print(f"[Rank 0] Logprobs extraction failed: {e}", file=sys.stderr)
|
||||
|
||||
return content, reasoning_content, tool_calls_proto, prompt_token_count, completion_token_count, logprobs_bytes
|
||||
|
||||
def _build_generation_params(self, request, default_max_tokens=200):
|
||||
import mlx.core as mx
|
||||
|
||||
@@ -373,6 +556,22 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
'xtc_probability': 0.0,
|
||||
}
|
||||
|
||||
# Logits processor parameters — pulled from the request and
|
||||
# forwarded to make_logits_processors. Rank 0 is the only rank
|
||||
# running the sampler so we don't need to broadcast these to
|
||||
# workers (workers participate in the pipeline-parallel forward
|
||||
# pass only).
|
||||
logits_params = {}
|
||||
repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0
|
||||
if repetition_penalty and repetition_penalty != 1.0:
|
||||
logits_params['repetition_penalty'] = repetition_penalty
|
||||
presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0
|
||||
if presence_penalty:
|
||||
logits_params['presence_penalty'] = presence_penalty
|
||||
frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0
|
||||
if frequency_penalty:
|
||||
logits_params['frequency_penalty'] = frequency_penalty
|
||||
|
||||
seed = getattr(request, 'Seed', 0)
|
||||
if seed != 0:
|
||||
mx.random.seed(seed)
|
||||
@@ -392,9 +591,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
for opt_key, param_key in option_mapping.items():
|
||||
if opt_key in self.options:
|
||||
sampler_params[param_key] = self.options[opt_key]
|
||||
for opt_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'):
|
||||
if opt_key in self.options:
|
||||
logits_params[opt_key] = self.options[opt_key]
|
||||
if 'seed' in self.options:
|
||||
mx.random.seed(self.options['seed'])
|
||||
|
||||
stop_words = list(getattr(request, 'StopPrompts', []) or [])
|
||||
return max_tokens, sampler_params, logits_params, stop_words
|
||||
|
||||
# XTC special tokens
|
||||
xtc_special_tokens = []
|
||||
if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
@@ -6,6 +9,12 @@ import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
# Make the shared helpers importable so we can unit-test them without a
|
||||
# running gRPC server.
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
from python_utils import messages_to_dicts, parse_options
|
||||
from mlx_utils import parse_tool_calls, split_reasoning
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -85,3 +94,44 @@ class TestBackendServicer(unittest.TestCase):
|
||||
self.fail("sampling params service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
|
||||
class TestSharedHelpers(unittest.TestCase):
|
||||
"""Server-less unit tests for the helpers the mlx-distributed backend depends on."""
|
||||
|
||||
def test_parse_options_typed(self):
|
||||
opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true"])
|
||||
self.assertEqual(opts["temperature"], 0.7)
|
||||
self.assertEqual(opts["max_tokens"], 128)
|
||||
self.assertIs(opts["trust"], True)
|
||||
|
||||
def test_messages_to_dicts_roundtrip(self):
|
||||
msgs = [
|
||||
backend_pb2.Message(role="user", content="hi"),
|
||||
backend_pb2.Message(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]',
|
||||
),
|
||||
backend_pb2.Message(role="tool", content="42", tool_call_id="call_1", name="f"),
|
||||
]
|
||||
out = messages_to_dicts(msgs)
|
||||
self.assertEqual(out[0], {"role": "user", "content": "hi"})
|
||||
self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f")
|
||||
self.assertEqual(out[2]["tool_call_id"], "call_1")
|
||||
|
||||
def test_split_reasoning(self):
|
||||
r, c = split_reasoning("<think>plan</think>final", "<think>", "</think>")
|
||||
self.assertEqual(r, "plan")
|
||||
self.assertEqual(c, "final")
|
||||
|
||||
def test_parse_tool_calls_with_shim(self):
|
||||
tm = types.SimpleNamespace(
|
||||
tool_call_start="<tool_call>",
|
||||
tool_call_end="</tool_call>",
|
||||
parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}},
|
||||
)
|
||||
calls, remaining = parse_tool_calls("<tool_call>Paris</tool_call>", tm, tools=None)
|
||||
self.assertEqual(len(calls), 1)
|
||||
self.assertEqual(calls[0]["name"], "get_weather")
|
||||
self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}')
|
||||
|
||||
@@ -2,11 +2,14 @@
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import tempfile
|
||||
import types
|
||||
from typing import List
|
||||
import time
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
@@ -15,30 +18,18 @@ import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
from python_utils import messages_to_dicts, parse_options
|
||||
from mlx_utils import parse_tool_calls, split_reasoning
|
||||
|
||||
from mlx_vlm import load, generate, stream_generate
|
||||
from mlx_vlm import load, stream_generate
|
||||
from mlx_vlm.prompt_utils import apply_chat_template
|
||||
from mlx_vlm.utils import load_config, load_image
|
||||
from mlx_vlm.tool_parsers import _infer_tool_parser, load_tool_module
|
||||
from mlx_vlm.utils import load_config
|
||||
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||
import mlx.core as mx
|
||||
import base64
|
||||
import io
|
||||
from PIL import Image
|
||||
import tempfile
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
try:
|
||||
float(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
def is_int(s):
|
||||
"""Check if a string can be converted to int."""
|
||||
try:
|
||||
int(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
@@ -78,36 +69,52 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
try:
|
||||
print(f"Loading MLX-VLM model: {request.Model}", file=sys.stderr)
|
||||
print(f"Request: {request}", file=sys.stderr)
|
||||
|
||||
# Parse options like in the diffusers backend
|
||||
options = request.Options
|
||||
self.options = {}
|
||||
|
||||
# The options are a list of strings in this form optname:optvalue
|
||||
# We store all the options in a dict for later use
|
||||
for opt in options:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
|
||||
|
||||
if is_float(value):
|
||||
value = float(value)
|
||||
elif is_int(value):
|
||||
value = int(value)
|
||||
elif value.lower() in ["true", "false"]:
|
||||
value = value.lower() == "true"
|
||||
|
||||
self.options[key] = value
|
||||
|
||||
|
||||
# Parse Options[] key:value strings into a typed dict
|
||||
self.options = parse_options(request.Options)
|
||||
print(f"Options: {self.options}", file=sys.stderr)
|
||||
|
||||
|
||||
# Load model and processor using MLX-VLM
|
||||
# mlx-vlm load function returns (model, processor) instead of (model, tokenizer)
|
||||
self.model, self.processor = load(request.Model)
|
||||
|
||||
|
||||
# Load model config for chat template support
|
||||
self.config = load_config(request.Model)
|
||||
|
||||
|
||||
# Auto-infer the tool parser from the chat template. mlx-vlm has
|
||||
# its own _infer_tool_parser that falls back to mlx-lm parsers.
|
||||
tokenizer = (
|
||||
self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor
|
||||
)
|
||||
self.tool_module = None
|
||||
if hasattr(tokenizer, "chat_template"):
|
||||
try:
|
||||
parser_type = _infer_tool_parser(tokenizer.chat_template)
|
||||
if parser_type is not None:
|
||||
self.tool_module = load_tool_module(parser_type)
|
||||
print(
|
||||
f"[mlx-vlm] auto-detected tool parser: {parser_type}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"[mlx-vlm] no tool parser matched the chat template",
|
||||
file=sys.stderr,
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[mlx-vlm] failed to load tool parser: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Reasoning tokens — check if the tokenizer advertises thinking
|
||||
# markers. Fall back to empty strings (split_reasoning no-ops).
|
||||
self.think_start = getattr(tokenizer, "think_start", "") or ""
|
||||
self.think_end = getattr(tokenizer, "think_end", "") or ""
|
||||
self.has_thinking = bool(
|
||||
getattr(tokenizer, "has_thinking", False) or self.think_start
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
print(f"Error loading MLX-VLM model {err=}, {type(err)=}", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"Error loading MLX-VLM model: {err}")
|
||||
@@ -128,63 +135,72 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""
|
||||
temp_files = []
|
||||
try:
|
||||
# Process images and audios from request
|
||||
image_paths = []
|
||||
audio_paths = []
|
||||
|
||||
# Process images
|
||||
if request.Images:
|
||||
for img_data in request.Images:
|
||||
img_path = self.load_image_from_base64(img_data)
|
||||
if img_path:
|
||||
image_paths.append(img_path)
|
||||
temp_files.append(img_path)
|
||||
|
||||
# Process audios
|
||||
if request.Audios:
|
||||
for audio_data in request.Audios:
|
||||
audio_path = self.load_audio_from_base64(audio_data)
|
||||
if audio_path:
|
||||
audio_paths.append(audio_path)
|
||||
temp_files.append(audio_path)
|
||||
|
||||
# Prepare the prompt with multimodal information
|
||||
prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths))
|
||||
|
||||
# Build generation parameters using request attributes and options
|
||||
max_tokens, generation_params = self._build_generation_params(request)
|
||||
|
||||
print(f"Generating text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
|
||||
print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)
|
||||
|
||||
# Generate text using MLX-VLM with multimodal inputs
|
||||
response = generate(
|
||||
image_paths, audio_paths = self._collect_media(request, temp_files)
|
||||
|
||||
prompt = self._prepare_prompt(
|
||||
request,
|
||||
num_images=len(image_paths),
|
||||
num_audios=len(audio_paths),
|
||||
)
|
||||
|
||||
max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request)
|
||||
sampler = make_sampler(**sampler_params)
|
||||
logits_processors = make_logits_processors(**logits_params) if logits_params else None
|
||||
|
||||
print(
|
||||
f"Generating text with MLX-VLM - max_tokens: {max_tokens}, "
|
||||
f"images: {len(image_paths)}, audios: {len(audio_paths)}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
accumulated = []
|
||||
last_response = None
|
||||
for response in stream_generate(
|
||||
model=self.model,
|
||||
processor=self.processor,
|
||||
prompt=prompt,
|
||||
image=image_paths if image_paths else None,
|
||||
audio=audio_paths if audio_paths else None,
|
||||
max_tokens=max_tokens,
|
||||
temperature=generation_params.get('temp', 0.6),
|
||||
top_p=generation_params.get('top_p', 1.0),
|
||||
verbose=False
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
):
|
||||
accumulated.append(response.text)
|
||||
last_response = response
|
||||
if stop_words and any(s in "".join(accumulated) for s in stop_words):
|
||||
break
|
||||
|
||||
full_text = self._truncate_at_stop("".join(accumulated), stop_words)
|
||||
content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
|
||||
self._finalize_output(request, full_text, last_response)
|
||||
)
|
||||
|
||||
return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))
|
||||
|
||||
|
||||
return backend_pb2.Reply(
|
||||
message=bytes(content, encoding='utf-8'),
|
||||
prompt_tokens=prompt_tokens,
|
||||
tokens=completion_tokens,
|
||||
logprobs=logprobs_bytes,
|
||||
chat_deltas=[
|
||||
backend_pb2.ChatDelta(
|
||||
content=content,
|
||||
reasoning_content=reasoning_content,
|
||||
tool_calls=tool_calls_proto,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in MLX-VLM Predict: {e}", file=sys.stderr)
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"Generation failed: {str(e)}")
|
||||
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
|
||||
finally:
|
||||
# Clean up temporary files
|
||||
self.cleanup_temp_files(temp_files)
|
||||
|
||||
def Embedding(self, request, context):
|
||||
"""
|
||||
A gRPC method that calculates embeddings for a given sentence.
|
||||
|
||||
|
||||
Note: MLX-VLM doesn't support embeddings directly. This method returns an error.
|
||||
|
||||
Args:
|
||||
@@ -199,6 +215,79 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
context.set_details("Embeddings are not supported in the MLX-VLM backend.")
|
||||
return backend_pb2.EmbeddingResult()
|
||||
|
||||
def _collect_media(self, request, temp_files):
|
||||
"""Decode base64 Images and Audios into temp file paths.
|
||||
|
||||
Appends every temp file to ``temp_files`` so the finally block can
|
||||
clean up even on mid-generation errors.
|
||||
"""
|
||||
image_paths = []
|
||||
audio_paths = []
|
||||
if request.Images:
|
||||
for img_data in request.Images:
|
||||
img_path = self.load_image_from_base64(img_data)
|
||||
if img_path:
|
||||
image_paths.append(img_path)
|
||||
temp_files.append(img_path)
|
||||
if request.Audios:
|
||||
for audio_data in request.Audios:
|
||||
audio_path = self.load_audio_from_base64(audio_data)
|
||||
if audio_path:
|
||||
audio_paths.append(audio_path)
|
||||
temp_files.append(audio_path)
|
||||
return image_paths, audio_paths
|
||||
|
||||
async def TokenizeString(self, request, context):
|
||||
"""Tokenize ``request.Prompt`` via the processor's tokenizer."""
|
||||
if not hasattr(self, "processor") or self.processor is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("processor not loaded")
|
||||
return backend_pb2.TokenizationResponse()
|
||||
try:
|
||||
tokenizer = (
|
||||
self.processor.tokenizer
|
||||
if hasattr(self.processor, "tokenizer")
|
||||
else self.processor
|
||||
)
|
||||
tokens = tokenizer.encode(request.Prompt)
|
||||
if hasattr(tokens, "tolist"):
|
||||
tokens = tokens.tolist()
|
||||
tokens = list(tokens)
|
||||
return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
|
||||
except Exception as e:
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(str(e))
|
||||
return backend_pb2.TokenizationResponse()
|
||||
|
||||
async def Free(self, request, context):
|
||||
"""Drop the loaded model, processor and tool module."""
|
||||
try:
|
||||
if hasattr(self, "model"):
|
||||
del self.model
|
||||
if hasattr(self, "processor"):
|
||||
del self.processor
|
||||
if hasattr(self, "config"):
|
||||
del self.config
|
||||
self.tool_module = None
|
||||
gc.collect()
|
||||
# mlx.clear_cache (mlx >= 0.30) supersedes mlx.metal.clear_cache.
|
||||
try:
|
||||
if hasattr(mx, "clear_cache"):
|
||||
mx.clear_cache()
|
||||
elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
|
||||
mx.metal.clear_cache()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
import torch # type: ignore
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return backend_pb2.Result(success=True, message="MLX-VLM model freed")
|
||||
except Exception as e:
|
||||
return backend_pb2.Result(success=False, message=str(e))
|
||||
|
||||
async def PredictStream(self, request, context):
|
||||
"""
|
||||
Generates text based on the given prompt and sampling parameters, and streams the results using MLX-VLM with multimodal support.
|
||||
@@ -212,36 +301,28 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""
|
||||
temp_files = []
|
||||
try:
|
||||
# Process images and audios from request
|
||||
image_paths = []
|
||||
audio_paths = []
|
||||
|
||||
# Process images
|
||||
if request.Images:
|
||||
for img_data in request.Images:
|
||||
img_path = self.load_image_from_base64(img_data)
|
||||
if img_path:
|
||||
image_paths.append(img_path)
|
||||
temp_files.append(img_path)
|
||||
|
||||
# Process audios
|
||||
if request.Audios:
|
||||
for audio_data in request.Audios:
|
||||
audio_path = self.load_audio_from_base64(audio_data)
|
||||
if audio_path:
|
||||
audio_paths.append(audio_path)
|
||||
temp_files.append(audio_path)
|
||||
|
||||
# Prepare the prompt with multimodal information
|
||||
prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths))
|
||||
|
||||
# Build generation parameters using request attributes and options
|
||||
max_tokens, generation_params = self._build_generation_params(request, default_max_tokens=512)
|
||||
|
||||
print(f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
|
||||
print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)
|
||||
|
||||
# Stream text generation using MLX-VLM with multimodal inputs
|
||||
image_paths, audio_paths = self._collect_media(request, temp_files)
|
||||
|
||||
prompt = self._prepare_prompt(
|
||||
request,
|
||||
num_images=len(image_paths),
|
||||
num_audios=len(audio_paths),
|
||||
)
|
||||
|
||||
max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(
|
||||
request, default_max_tokens=512
|
||||
)
|
||||
sampler = make_sampler(**sampler_params)
|
||||
logits_processors = make_logits_processors(**logits_params) if logits_params else None
|
||||
|
||||
print(
|
||||
f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, "
|
||||
f"images: {len(image_paths)}, audios: {len(audio_paths)}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
accumulated = []
|
||||
last_response = None
|
||||
for response in stream_generate(
|
||||
model=self.model,
|
||||
processor=self.processor,
|
||||
@@ -249,77 +330,91 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
image=image_paths if image_paths else None,
|
||||
audio=audio_paths if audio_paths else None,
|
||||
max_tokens=max_tokens,
|
||||
temperature=generation_params.get('temp', 0.6),
|
||||
top_p=generation_params.get('top_p', 1.0),
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
):
|
||||
yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
|
||||
|
||||
accumulated.append(response.text)
|
||||
last_response = response
|
||||
yield backend_pb2.Reply(
|
||||
message=bytes(response.text, encoding='utf-8'),
|
||||
chat_deltas=[backend_pb2.ChatDelta(content=response.text)],
|
||||
)
|
||||
if stop_words and any(s in "".join(accumulated) for s in stop_words):
|
||||
break
|
||||
|
||||
full_text = self._truncate_at_stop("".join(accumulated), stop_words)
|
||||
content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
|
||||
self._finalize_output(request, full_text, last_response)
|
||||
)
|
||||
yield backend_pb2.Reply(
|
||||
message=b"",
|
||||
prompt_tokens=prompt_tokens,
|
||||
tokens=completion_tokens,
|
||||
logprobs=logprobs_bytes,
|
||||
chat_deltas=[
|
||||
backend_pb2.ChatDelta(
|
||||
content="",
|
||||
reasoning_content=reasoning_content,
|
||||
tool_calls=tool_calls_proto,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in MLX-VLM PredictStream: {e}", file=sys.stderr)
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"Streaming generation failed: {str(e)}")
|
||||
yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
|
||||
finally:
|
||||
# Clean up temporary files
|
||||
self.cleanup_temp_files(temp_files)
|
||||
|
||||
def _build_template_kwargs(self, request, num_images, num_audios):
|
||||
"""Collect kwargs for ``apply_chat_template`` that survive model variants."""
|
||||
kwargs = {"num_images": num_images, "num_audios": num_audios}
|
||||
if request.Tools:
|
||||
try:
|
||||
kwargs["tools"] = json.loads(request.Tools)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if request.Metadata.get("enable_thinking", "").lower() == "true":
|
||||
kwargs["enable_thinking"] = True
|
||||
return kwargs
|
||||
|
||||
def _apply_template(self, request, messages, num_images, num_audios):
|
||||
kwargs = self._build_template_kwargs(request, num_images, num_audios)
|
||||
try:
|
||||
return apply_chat_template(self.processor, self.config, messages, **kwargs)
|
||||
except TypeError:
|
||||
# Fallback for older mlx-vlm versions that reject tools=/enable_thinking=
|
||||
return apply_chat_template(
|
||||
self.processor,
|
||||
self.config,
|
||||
messages,
|
||||
num_images=num_images,
|
||||
num_audios=num_audios,
|
||||
)
|
||||
|
||||
def _prepare_prompt(self, request, num_images=0, num_audios=0):
|
||||
"""
|
||||
Prepare the prompt for MLX-VLM generation, handling chat templates and multimodal inputs.
|
||||
|
||||
Args:
|
||||
request: The gRPC request containing prompt and message information.
|
||||
num_images: Number of images in the request.
|
||||
num_audios: Number of audio files in the request.
|
||||
|
||||
Returns:
|
||||
str: The prepared prompt.
|
||||
Prepare the prompt for MLX-VLM generation, handling chat templates and
|
||||
multimodal inputs. Forwards tool definitions and enable_thinking when
|
||||
present on the request.
|
||||
"""
|
||||
# If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
|
||||
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
|
||||
# Convert gRPC messages to the format expected by apply_chat_template
|
||||
messages = []
|
||||
for msg in request.Messages:
|
||||
messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
# Use mlx-vlm's apply_chat_template which handles multimodal inputs
|
||||
prompt = apply_chat_template(
|
||||
self.processor,
|
||||
self.config,
|
||||
messages,
|
||||
num_images=num_images,
|
||||
num_audios=num_audios
|
||||
)
|
||||
return prompt
|
||||
elif request.Prompt:
|
||||
# If we have a direct prompt but also have images/audio, we need to format it properly
|
||||
messages = messages_to_dicts(request.Messages)
|
||||
return self._apply_template(request, messages, num_images, num_audios)
|
||||
|
||||
if request.Prompt:
|
||||
if num_images > 0 or num_audios > 0:
|
||||
# Create a simple message structure for multimodal prompt
|
||||
messages = [{"role": "user", "content": request.Prompt}]
|
||||
prompt = apply_chat_template(
|
||||
self.processor,
|
||||
self.config,
|
||||
messages,
|
||||
num_images=num_images,
|
||||
num_audios=num_audios
|
||||
)
|
||||
return prompt
|
||||
else:
|
||||
return request.Prompt
|
||||
else:
|
||||
# Fallback to empty prompt with multimodal template if we have media
|
||||
if num_images > 0 or num_audios > 0:
|
||||
messages = [{"role": "user", "content": ""}]
|
||||
prompt = apply_chat_template(
|
||||
self.processor,
|
||||
self.config,
|
||||
messages,
|
||||
num_images=num_images,
|
||||
num_audios=num_audios
|
||||
)
|
||||
return prompt
|
||||
else:
|
||||
return ""
|
||||
return self._apply_template(request, messages, num_images, num_audios)
|
||||
return request.Prompt
|
||||
|
||||
# Fallback to empty prompt with multimodal template if we have media
|
||||
if num_images > 0 or num_audios > 0:
|
||||
messages = [{"role": "user", "content": ""}]
|
||||
return self._apply_template(request, messages, num_images, num_audios)
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
@@ -327,62 +422,122 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
def _build_generation_params(self, request, default_max_tokens=200):
|
||||
"""
|
||||
Build generation parameters from request attributes and options for MLX-VLM.
|
||||
|
||||
Args:
|
||||
request: The gRPC request.
|
||||
default_max_tokens: Default max_tokens if not specified.
|
||||
Build generation parameters from request attributes and options.
|
||||
|
||||
Returns:
|
||||
tuple: (max_tokens, generation_params dict)
|
||||
tuple: (max_tokens, sampler_params, logits_params, stop_words)
|
||||
"""
|
||||
# Extract max_tokens
|
||||
max_tokens = getattr(request, 'Tokens', default_max_tokens)
|
||||
if max_tokens == 0:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
# Extract generation parameters from request attributes
|
||||
temp = getattr(request, 'Temperature', 0.0)
|
||||
if temp == 0.0:
|
||||
temp = 0.6 # Default temperature
|
||||
|
||||
top_p = getattr(request, 'TopP', 0.0)
|
||||
if top_p == 0.0:
|
||||
top_p = 1.0 # Default top_p
|
||||
|
||||
# Initialize generation parameters for MLX-VLM
|
||||
generation_params = {
|
||||
max_tokens = getattr(request, 'Tokens', default_max_tokens) or default_max_tokens
|
||||
|
||||
temp = getattr(request, 'Temperature', 0.0) or 0.6
|
||||
top_p = getattr(request, 'TopP', 0.0) or 1.0
|
||||
min_p = getattr(request, 'MinP', 0.0) or 0.0
|
||||
top_k = getattr(request, 'TopK', 0) or 0
|
||||
|
||||
sampler_params = {
|
||||
'temp': temp,
|
||||
'top_p': top_p,
|
||||
'min_p': min_p,
|
||||
'top_k': top_k,
|
||||
}
|
||||
|
||||
# Add seed if specified
|
||||
|
||||
logits_params = {}
|
||||
repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0
|
||||
if repetition_penalty and repetition_penalty != 1.0:
|
||||
logits_params['repetition_penalty'] = repetition_penalty
|
||||
presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0
|
||||
if presence_penalty:
|
||||
logits_params['presence_penalty'] = presence_penalty
|
||||
frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0
|
||||
if frequency_penalty:
|
||||
logits_params['frequency_penalty'] = frequency_penalty
|
||||
|
||||
seed = getattr(request, 'Seed', 0)
|
||||
if seed != 0:
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Override with options if available
|
||||
|
||||
if hasattr(self, 'options'):
|
||||
# Max tokens from options
|
||||
if 'max_tokens' in self.options:
|
||||
max_tokens = self.options['max_tokens']
|
||||
|
||||
# Generation parameters from options
|
||||
param_option_mapping = {
|
||||
'temp': 'temp',
|
||||
'temperature': 'temp', # alias
|
||||
'top_p': 'top_p',
|
||||
option_mapping = {
|
||||
'temp': 'temp', 'temperature': 'temp',
|
||||
'top_p': 'top_p', 'min_p': 'min_p', 'top_k': 'top_k',
|
||||
}
|
||||
|
||||
for option_key, param_key in param_option_mapping.items():
|
||||
for option_key, param_key in option_mapping.items():
|
||||
if option_key in self.options:
|
||||
generation_params[param_key] = self.options[option_key]
|
||||
|
||||
# Handle seed from options
|
||||
sampler_params[param_key] = self.options[option_key]
|
||||
for option_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'):
|
||||
if option_key in self.options:
|
||||
logits_params[option_key] = self.options[option_key]
|
||||
if 'seed' in self.options:
|
||||
mx.random.seed(self.options['seed'])
|
||||
|
||||
return max_tokens, generation_params
|
||||
|
||||
stop_words = list(getattr(request, 'StopPrompts', []) or [])
|
||||
return max_tokens, sampler_params, logits_params, stop_words
|
||||
|
||||
def _finalize_output(self, request, generated_text, last_response):
|
||||
"""Split reasoning + tool calls out of generated_text and return the
|
||||
tuple consumed by Reply-builders."""
|
||||
content = generated_text
|
||||
reasoning_content = ""
|
||||
|
||||
if getattr(self, "has_thinking", False):
|
||||
reasoning_content, content = split_reasoning(content, self.think_start, self.think_end)
|
||||
|
||||
tool_calls_proto: List[backend_pb2.ToolCallDelta] = []
|
||||
if self.tool_module is not None:
|
||||
parsed_tools = None
|
||||
if request.Tools:
|
||||
try:
|
||||
parsed_tools = json.loads(request.Tools)
|
||||
except json.JSONDecodeError:
|
||||
parsed_tools = None
|
||||
calls, content = parse_tool_calls(content, self.tool_module, parsed_tools)
|
||||
for c in calls:
|
||||
tool_calls_proto.append(
|
||||
backend_pb2.ToolCallDelta(
|
||||
index=c["index"],
|
||||
id=c["id"],
|
||||
name=c["name"],
|
||||
arguments=c["arguments"],
|
||||
)
|
||||
)
|
||||
|
||||
prompt_tokens = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0
|
||||
completion_tokens = int(getattr(last_response, "generation_tokens", 0) or 0) if last_response else 0
|
||||
|
||||
logprobs_bytes = b""
|
||||
if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0:
|
||||
try:
|
||||
lp = getattr(last_response, "logprobs", None)
|
||||
if lp is not None:
|
||||
token_id = int(getattr(last_response, "token", 0) or 0)
|
||||
tokenizer = (
|
||||
self.processor.tokenizer
|
||||
if hasattr(self.processor, "tokenizer")
|
||||
else self.processor
|
||||
)
|
||||
token_text = tokenizer.decode([token_id]) if token_id else ""
|
||||
top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0
|
||||
logprobs_bytes = json.dumps(
|
||||
{"content": [{"token": token_text, "logprob": top_logprob}]}
|
||||
).encode("utf-8")
|
||||
except Exception as e:
|
||||
print(f"[mlx-vlm] Logprobs extraction failed: {e}", file=sys.stderr)
|
||||
|
||||
return content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes
|
||||
|
||||
def _truncate_at_stop(self, text, stop_words):
|
||||
if not stop_words:
|
||||
return text
|
||||
earliest = len(text)
|
||||
for stop in stop_words:
|
||||
if not stop:
|
||||
continue
|
||||
idx = text.find(stop)
|
||||
if idx >= 0 and idx < earliest:
|
||||
earliest = idx
|
||||
return text[:earliest] if earliest < len(text) else text
|
||||
|
||||
def load_image_from_base64(self, image_data: str):
|
||||
"""
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
import grpc
|
||||
import backend_pb2_grpc
|
||||
import backend_pb2
|
||||
# Make the shared helpers importable so we can unit-test them without a
|
||||
# running gRPC server.
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
from python_utils import messages_to_dicts, parse_options
|
||||
from mlx_utils import parse_tool_calls, split_reasoning
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
|
||||
"""
|
||||
@@ -143,4 +145,55 @@ class TestBackendServicer(unittest.TestCase):
|
||||
print(err)
|
||||
self.fail("Embedding service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
self.tearDown()
|
||||
|
||||
|
||||
class TestSharedHelpers(unittest.TestCase):
|
||||
"""Server-less unit tests for the helpers the mlx-vlm backend depends on."""
|
||||
|
||||
def test_parse_options_typed(self):
|
||||
opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true", "name:hello"])
|
||||
self.assertEqual(opts["temperature"], 0.7)
|
||||
self.assertEqual(opts["max_tokens"], 128)
|
||||
self.assertIs(opts["trust"], True)
|
||||
self.assertEqual(opts["name"], "hello")
|
||||
|
||||
def test_messages_to_dicts_roundtrip(self):
|
||||
msgs = [
|
||||
backend_pb2.Message(role="user", content="hi"),
|
||||
backend_pb2.Message(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]',
|
||||
),
|
||||
backend_pb2.Message(
|
||||
role="tool",
|
||||
content="42",
|
||||
tool_call_id="call_1",
|
||||
name="f",
|
||||
),
|
||||
]
|
||||
out = messages_to_dicts(msgs)
|
||||
self.assertEqual(out[0], {"role": "user", "content": "hi"})
|
||||
self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f")
|
||||
self.assertEqual(out[2]["tool_call_id"], "call_1")
|
||||
|
||||
def test_split_reasoning(self):
|
||||
r, c = split_reasoning("<think>plan</think>final", "<think>", "</think>")
|
||||
self.assertEqual(r, "plan")
|
||||
self.assertEqual(c, "final")
|
||||
|
||||
def test_parse_tool_calls_with_shim(self):
|
||||
tm = types.SimpleNamespace(
|
||||
tool_call_start="<tool_call>",
|
||||
tool_call_end="</tool_call>",
|
||||
parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}},
|
||||
)
|
||||
calls, remaining = parse_tool_calls(
|
||||
"<tool_call>Paris</tool_call>",
|
||||
tm,
|
||||
tools=None,
|
||||
)
|
||||
self.assertEqual(len(calls), 1)
|
||||
self.assertEqual(calls[0]["name"], "get_weather")
|
||||
self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}')
|
||||
@@ -2,11 +2,13 @@
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import types
|
||||
from typing import List
|
||||
import time
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
@@ -15,13 +17,13 @@ import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
from python_utils import messages_to_dicts, parse_options
|
||||
from mlx_utils import parse_tool_calls, split_reasoning
|
||||
|
||||
from mlx_lm import load, generate, stream_generate
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm import load, stream_generate
|
||||
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
|
||||
import mlx.core as mx
|
||||
import base64
|
||||
import io
|
||||
|
||||
from mlx_cache import ThreadSafeLRUPromptCache
|
||||
|
||||
@@ -30,21 +32,6 @@ _ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
try:
|
||||
float(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
def is_int(s):
|
||||
"""Check if a string can be converted to int."""
|
||||
try:
|
||||
int(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""
|
||||
@@ -78,46 +65,27 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
try:
|
||||
print(f"Loading MLX model: {request.Model}", file=sys.stderr)
|
||||
print(f"Request: {request}", file=sys.stderr)
|
||||
|
||||
# Parse options like in the diffusers backend
|
||||
options = request.Options
|
||||
self.options = {}
|
||||
|
||||
# The options are a list of strings in this form optname:optvalue
|
||||
# We store all the options in a dict for later use
|
||||
for opt in options:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
|
||||
|
||||
# Convert numeric values to appropriate types
|
||||
if is_float(value):
|
||||
value = float(value)
|
||||
elif is_int(value):
|
||||
value = int(value)
|
||||
elif value.lower() in ["true", "false"]:
|
||||
value = value.lower() == "true"
|
||||
|
||||
self.options[key] = value
|
||||
|
||||
|
||||
# Parse Options[] key:value strings into a typed dict (shared helper)
|
||||
self.options = parse_options(request.Options)
|
||||
print(f"Options: {self.options}", file=sys.stderr)
|
||||
|
||||
|
||||
# Build tokenizer config for MLX using options
|
||||
tokenizer_config = {}
|
||||
|
||||
|
||||
# Handle trust_remote_code from request or options
|
||||
if request.TrustRemoteCode or self.options.get("trust_remote_code", False):
|
||||
tokenizer_config["trust_remote_code"] = True
|
||||
|
||||
|
||||
# Handle EOS token from options
|
||||
if "eos_token" in self.options:
|
||||
tokenizer_config["eos_token"] = self.options["eos_token"]
|
||||
|
||||
|
||||
# Handle other tokenizer config options
|
||||
for key in ["pad_token", "bos_token", "unk_token", "sep_token", "cls_token", "mask_token"]:
|
||||
if key in self.options:
|
||||
tokenizer_config[key] = self.options[key]
|
||||
|
||||
|
||||
# Load model and tokenizer using MLX
|
||||
if tokenizer_config:
|
||||
print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
|
||||
@@ -125,6 +93,21 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
else:
|
||||
self.model, self.tokenizer = load(request.Model)
|
||||
|
||||
# mlx_lm.load() returns a TokenizerWrapper that detects tool
|
||||
# calling and thinking markers from the chat template / vocab.
|
||||
# mlx-lm >= 0.30 also exposes a parser callable on the wrapper;
|
||||
# earlier versions don't (we fall back to json.loads inside
|
||||
# _tool_module_from_tokenizer below).
|
||||
has_tools = bool(getattr(self.tokenizer, "has_tool_calling", False))
|
||||
has_thinking = bool(getattr(self.tokenizer, "has_thinking", False))
|
||||
tcs = getattr(self.tokenizer, "tool_call_start", None)
|
||||
tce = getattr(self.tokenizer, "tool_call_end", None)
|
||||
print(
|
||||
f"MLX tokenizer capabilities: has_tool_calling={has_tools} "
|
||||
f"has_thinking={has_thinking} tool_call_start={tcs!r} tool_call_end={tce!r}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Initialize thread-safe LRU prompt cache for efficient generation
|
||||
max_cache_entries = self.options.get("max_cache_entries", 10)
|
||||
self.max_kv_size = self.options.get("max_kv_size", None)
|
||||
@@ -134,7 +117,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
can_trim_fn=can_trim_prompt_cache,
|
||||
trim_fn=trim_prompt_cache,
|
||||
)
|
||||
|
||||
|
||||
except Exception as err:
|
||||
print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"Error loading MLX model: {err}")
|
||||
@@ -172,30 +155,58 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
remaining_tokens = cache_key
|
||||
|
||||
# Build generation parameters using request attributes and options
|
||||
max_tokens, sampler_params = self._build_generation_params(request)
|
||||
max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request)
|
||||
|
||||
print(f"Generating text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)
|
||||
print(
|
||||
f"Generating text with MLX - max_tokens: {max_tokens}, "
|
||||
f"cache_hit: {len(remaining_tokens) < len(cache_key)}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Create sampler with parameters
|
||||
# Create sampler and optional logits processors (penalties)
|
||||
sampler = make_sampler(**sampler_params)
|
||||
logits_processors = make_logits_processors(**logits_params) if logits_params else None
|
||||
|
||||
# Use stream_generate to track generated tokens for cache key
|
||||
# Use stream_generate to collect text + track tokens for cache key
|
||||
generated_text = []
|
||||
last_response = None
|
||||
for response in stream_generate(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
prompt=remaining_tokens if remaining_tokens else cache_key,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=prompt_cache,
|
||||
):
|
||||
generated_text.append(response.text)
|
||||
cache_key.append(response.token)
|
||||
last_response = response
|
||||
# Early stop on user-provided stop sequences
|
||||
if stop_words and any(s in "".join(generated_text) for s in stop_words):
|
||||
break
|
||||
|
||||
# Insert completed cache
|
||||
self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
|
||||
|
||||
return backend_pb2.Reply(message=bytes(''.join(generated_text), encoding='utf-8'))
|
||||
full_text = self._truncate_at_stop("".join(generated_text), stop_words)
|
||||
content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
|
||||
self._finalize_output(request, full_text, last_response)
|
||||
)
|
||||
|
||||
return backend_pb2.Reply(
|
||||
message=bytes(content, encoding='utf-8'),
|
||||
prompt_tokens=prompt_tokens,
|
||||
tokens=completion_tokens,
|
||||
logprobs=logprobs_bytes,
|
||||
chat_deltas=[
|
||||
backend_pb2.ChatDelta(
|
||||
content=content,
|
||||
reasoning_content=reasoning_content,
|
||||
tool_calls=tool_calls_proto,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in MLX Predict: {e}", file=sys.stderr)
|
||||
@@ -206,7 +217,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def Embedding(self, request, context):
|
||||
"""
|
||||
A gRPC method that calculates embeddings for a given sentence.
|
||||
|
||||
|
||||
Note: MLX-LM doesn't support embeddings directly. This method returns an error.
|
||||
|
||||
Args:
|
||||
@@ -221,6 +232,62 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
context.set_details("Embeddings are not supported in the MLX backend.")
|
||||
return backend_pb2.EmbeddingResult()
|
||||
|
||||
async def TokenizeString(self, request, context):
|
||||
"""Tokenize ``request.Prompt`` using the loaded model's tokenizer."""
|
||||
if not hasattr(self, "tokenizer") or self.tokenizer is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("tokenizer not loaded")
|
||||
return backend_pb2.TokenizationResponse()
|
||||
try:
|
||||
tokens = self.tokenizer.encode(request.Prompt)
|
||||
if hasattr(tokens, "tolist"):
|
||||
tokens = tokens.tolist()
|
||||
tokens = list(tokens)
|
||||
return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
|
||||
except Exception as e:
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(str(e))
|
||||
return backend_pb2.TokenizationResponse()
|
||||
|
||||
async def Free(self, request, context):
|
||||
"""Drop the loaded model, tokenizer and prompt cache.
|
||||
|
||||
Metal / CUDA memory is released via ``gc.collect()`` + the
|
||||
platform-specific cache clear hooks when available.
|
||||
"""
|
||||
try:
|
||||
if hasattr(self, "model"):
|
||||
del self.model
|
||||
if hasattr(self, "tokenizer"):
|
||||
del self.tokenizer
|
||||
if hasattr(self, "lru_cache") and self.lru_cache is not None:
|
||||
try:
|
||||
self.lru_cache.clear()
|
||||
except Exception:
|
||||
pass
|
||||
self.lru_cache = None
|
||||
gc.collect()
|
||||
# Metal: drop the cached allocator. mlx.clear_cache (mlx >= 0.30)
|
||||
# supersedes the now-deprecated mlx.metal.clear_cache.
|
||||
try:
|
||||
if hasattr(mx, "clear_cache"):
|
||||
mx.clear_cache()
|
||||
elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
|
||||
mx.metal.clear_cache()
|
||||
except Exception:
|
||||
pass
|
||||
# CUDA: release the torch cache if a CUDA-backed mlx variant
|
||||
# happens to be installed alongside torch (best-effort).
|
||||
try:
|
||||
import torch # type: ignore
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return backend_pb2.Result(success=True, message="MLX model freed")
|
||||
except Exception as e:
|
||||
return backend_pb2.Result(success=False, message=str(e))
|
||||
|
||||
async def PredictStream(self, request, context):
|
||||
"""
|
||||
Generates text based on the given prompt and sampling parameters, and streams the results using MLX.
|
||||
@@ -251,24 +318,64 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
remaining_tokens = cache_key
|
||||
|
||||
# Build generation parameters using request attributes and options
|
||||
max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)
|
||||
max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(
|
||||
request, default_max_tokens=512
|
||||
)
|
||||
|
||||
print(f"Streaming text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)
|
||||
print(
|
||||
f"Streaming text with MLX - max_tokens: {max_tokens}, "
|
||||
f"cache_hit: {len(remaining_tokens) < len(cache_key)}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Create sampler with parameters
|
||||
# Create sampler and optional logits processors (penalties)
|
||||
sampler = make_sampler(**sampler_params)
|
||||
logits_processors = make_logits_processors(**logits_params) if logits_params else None
|
||||
|
||||
# Stream text generation using MLX with proper parameters
|
||||
accumulated = []
|
||||
last_response = None
|
||||
for response in stream_generate(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
prompt=remaining_tokens if remaining_tokens else cache_key,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=prompt_cache,
|
||||
):
|
||||
cache_key.append(response.token)
|
||||
yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
|
||||
accumulated.append(response.text)
|
||||
last_response = response
|
||||
# Emit a content delta. Structured reasoning / tool parsing
|
||||
# happens on the final chunk so we don't fragment the state
|
||||
# machine in v1.
|
||||
yield backend_pb2.Reply(
|
||||
message=bytes(response.text, encoding='utf-8'),
|
||||
chat_deltas=[backend_pb2.ChatDelta(content=response.text)],
|
||||
)
|
||||
# Early stop on user-provided stop sequences
|
||||
if stop_words and any(s in "".join(accumulated) for s in stop_words):
|
||||
break
|
||||
|
||||
# Final chunk: run reasoning + tool parsing on accumulated text
|
||||
# and emit the structured ChatDelta with token counts + logprobs.
|
||||
full_text = self._truncate_at_stop("".join(accumulated), stop_words)
|
||||
content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
|
||||
self._finalize_output(request, full_text, last_response)
|
||||
)
|
||||
yield backend_pb2.Reply(
|
||||
message=b"",
|
||||
prompt_tokens=prompt_tokens,
|
||||
tokens=completion_tokens,
|
||||
logprobs=logprobs_bytes,
|
||||
chat_deltas=[
|
||||
backend_pb2.ChatDelta(
|
||||
content="",
|
||||
reasoning_content=reasoning_content,
|
||||
tool_calls=tool_calls_proto,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in MLX PredictStream: {e}", file=sys.stderr)
|
||||
@@ -294,21 +401,33 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
Returns:
|
||||
str: The prepared prompt.
|
||||
"""
|
||||
# If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
|
||||
# If tokenizer template is enabled and messages are provided instead
|
||||
# of prompt, apply the tokenizer template (forwards tool definitions
|
||||
# and enable_thinking when the model supports them).
|
||||
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
|
||||
# Convert gRPC messages to the format expected by apply_chat_template
|
||||
messages = []
|
||||
for msg in request.Messages:
|
||||
messages.append({"role": msg.role, "content": msg.content})
|
||||
messages = messages_to_dicts(request.Messages)
|
||||
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
return prompt
|
||||
else:
|
||||
return request.Prompt
|
||||
kwargs = {"tokenize": False, "add_generation_prompt": True}
|
||||
if request.Tools:
|
||||
try:
|
||||
kwargs["tools"] = json.loads(request.Tools)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
enable_thinking = request.Metadata.get("enable_thinking", "").lower()
|
||||
if enable_thinking == "true":
|
||||
kwargs["enable_thinking"] = True
|
||||
|
||||
try:
|
||||
return self.tokenizer.apply_chat_template(messages, **kwargs)
|
||||
except TypeError:
|
||||
# Fallback for tokenizers whose template doesn't accept
|
||||
# tools= or enable_thinking=.
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
return request.Prompt
|
||||
|
||||
def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
|
||||
"""
|
||||
@@ -338,18 +457,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
default_max_tokens: Default max_tokens if not specified.
|
||||
|
||||
Returns:
|
||||
tuple: (max_tokens, sampler_params dict)
|
||||
tuple: (max_tokens, sampler_params dict, logits_processor_params dict,
|
||||
stop_words list)
|
||||
"""
|
||||
# Extract max_tokens
|
||||
max_tokens = getattr(request, 'Tokens', default_max_tokens)
|
||||
if max_tokens == 0:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
|
||||
# Extract sampler parameters from request attributes
|
||||
temp = getattr(request, 'Temperature', 0.0)
|
||||
if temp == 0.0:
|
||||
temp = 0.6 # Default temperature
|
||||
|
||||
|
||||
top_p = getattr(request, 'TopP', 0.0)
|
||||
if top_p == 0.0:
|
||||
top_p = 1.0 # Default top_p
|
||||
@@ -369,18 +489,31 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
'xtc_threshold': 0.0,
|
||||
'xtc_probability': 0.0,
|
||||
}
|
||||
|
||||
|
||||
# Logits processor parameters — only set fields the request actually
|
||||
# provides so we can feed them unconditionally to make_logits_processors.
|
||||
logits_params = {}
|
||||
repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0
|
||||
if repetition_penalty and repetition_penalty != 1.0:
|
||||
logits_params['repetition_penalty'] = repetition_penalty
|
||||
presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0
|
||||
if presence_penalty:
|
||||
logits_params['presence_penalty'] = presence_penalty
|
||||
frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0
|
||||
if frequency_penalty:
|
||||
logits_params['frequency_penalty'] = frequency_penalty
|
||||
|
||||
# Add seed if specified
|
||||
seed = getattr(request, 'Seed', 0)
|
||||
if seed != 0:
|
||||
mx.random.seed(seed)
|
||||
|
||||
|
||||
# Override with options if available
|
||||
if hasattr(self, 'options'):
|
||||
# Max tokens from options
|
||||
if 'max_tokens' in self.options:
|
||||
max_tokens = self.options['max_tokens']
|
||||
|
||||
|
||||
# Sampler parameters from options
|
||||
sampler_option_mapping = {
|
||||
'temp': 'temp',
|
||||
@@ -391,32 +524,142 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
'xtc_threshold': 'xtc_threshold',
|
||||
'xtc_probability': 'xtc_probability',
|
||||
}
|
||||
|
||||
|
||||
for option_key, param_key in sampler_option_mapping.items():
|
||||
if option_key in self.options:
|
||||
sampler_params[param_key] = self.options[option_key]
|
||||
|
||||
|
||||
# Logits processor overrides
|
||||
for option_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'):
|
||||
if option_key in self.options:
|
||||
logits_params[option_key] = self.options[option_key]
|
||||
|
||||
# Handle seed from options
|
||||
if 'seed' in self.options:
|
||||
mx.random.seed(self.options['seed'])
|
||||
|
||||
|
||||
# Special tokens for XTC sampling (if tokenizer has eos_token_ids)
|
||||
xtc_special_tokens = []
|
||||
if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
|
||||
xtc_special_tokens = list(self.tokenizer.eos_token_ids)
|
||||
elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
|
||||
xtc_special_tokens = [self.tokenizer.eos_token_id]
|
||||
|
||||
|
||||
# Add newline token if available
|
||||
try:
|
||||
newline_tokens = self.tokenizer.encode("\n")
|
||||
xtc_special_tokens.extend(newline_tokens)
|
||||
except:
|
||||
except Exception:
|
||||
pass # Skip if encoding fails
|
||||
|
||||
|
||||
sampler_params['xtc_special_tokens'] = xtc_special_tokens
|
||||
|
||||
return max_tokens, sampler_params
|
||||
|
||||
# Stop sequences are applied post-decode (mlx-lm doesn't have a
|
||||
# built-in stop-sequence sampler param). Preserve the list here.
|
||||
stop_words = list(getattr(request, 'StopPrompts', []) or [])
|
||||
|
||||
return max_tokens, sampler_params, logits_params, stop_words
|
||||
|
||||
def _tool_module_from_tokenizer(self):
|
||||
"""Build a duck-typed tool module from the TokenizerWrapper.
|
||||
|
||||
On mlx-lm >= 0.30 the wrapper exposes a ``tool_parser`` callable
|
||||
that's been resolved from the model's chat template. On older
|
||||
releases (e.g. 0.29.x) the wrapper only carries the start/end
|
||||
markers — fall back to ``json.loads`` of the body, which matches
|
||||
what ``mlx_lm.tool_parsers.json_tools.parse_tool_call`` does on
|
||||
HEAD and covers the only format 0.29 detects (``<tool_call>``).
|
||||
"""
|
||||
start = getattr(self.tokenizer, "tool_call_start", None)
|
||||
end = getattr(self.tokenizer, "tool_call_end", None)
|
||||
if not start:
|
||||
return None
|
||||
parse_fn = getattr(self.tokenizer, "tool_parser", None)
|
||||
if parse_fn is None:
|
||||
def parse_fn(body, tools): # noqa: E306 — local fallback
|
||||
return json.loads(body.strip())
|
||||
return types.SimpleNamespace(
|
||||
tool_call_start=start,
|
||||
tool_call_end=end or "",
|
||||
parse_tool_call=parse_fn,
|
||||
)
|
||||
|
||||
def _finalize_output(self, request, generated_text, last_response):
|
||||
"""Build a ChatDelta + token counts + logprobs from accumulated output.
|
||||
|
||||
Returns ``(content, reasoning_content, tool_calls_proto,
|
||||
prompt_token_count, completion_token_count, logprobs_bytes)``.
|
||||
"""
|
||||
content = generated_text
|
||||
reasoning_content = ""
|
||||
|
||||
if getattr(self.tokenizer, "has_thinking", False):
|
||||
think_start = getattr(self.tokenizer, "think_start", "") or ""
|
||||
think_end = getattr(self.tokenizer, "think_end", "") or ""
|
||||
reasoning_content, content = split_reasoning(content, think_start, think_end)
|
||||
|
||||
tool_calls_proto: List[backend_pb2.ToolCallDelta] = []
|
||||
tool_module = None
|
||||
if getattr(self.tokenizer, "has_tool_calling", False):
|
||||
tool_module = self._tool_module_from_tokenizer()
|
||||
if tool_module is not None:
|
||||
parsed_tools = None
|
||||
if request.Tools:
|
||||
try:
|
||||
parsed_tools = json.loads(request.Tools)
|
||||
except json.JSONDecodeError:
|
||||
parsed_tools = None
|
||||
calls, content = parse_tool_calls(content, tool_module, parsed_tools)
|
||||
for c in calls:
|
||||
tool_calls_proto.append(
|
||||
backend_pb2.ToolCallDelta(
|
||||
index=c["index"],
|
||||
id=c["id"],
|
||||
name=c["name"],
|
||||
arguments=c["arguments"],
|
||||
)
|
||||
)
|
||||
|
||||
prompt_token_count = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0
|
||||
completion_token_count = int(getattr(last_response, "generation_tokens", 0) or 0) if last_response else 0
|
||||
|
||||
logprobs_bytes = b""
|
||||
# Logprobs extraction — only when the request asked for them.
|
||||
if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0:
|
||||
try:
|
||||
lp = getattr(last_response, "logprobs", None)
|
||||
if lp is not None:
|
||||
# GenerationResponse.logprobs on the last chunk is the
|
||||
# logprob distribution of the final token. Without a
|
||||
# per-token history we at minimum surface the last token's
|
||||
# top-1 logprob so clients get a non-empty field.
|
||||
token_id = int(getattr(last_response, "token", 0) or 0)
|
||||
token_text = self.tokenizer.decode([token_id]) if token_id else ""
|
||||
top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0
|
||||
logprobs_bytes = json.dumps(
|
||||
{
|
||||
"content": [
|
||||
{"token": token_text, "logprob": top_logprob}
|
||||
]
|
||||
}
|
||||
).encode("utf-8")
|
||||
except Exception as e:
|
||||
print(f"[mlx] Logprobs extraction failed: {e}", file=sys.stderr)
|
||||
|
||||
return content, reasoning_content, tool_calls_proto, prompt_token_count, completion_token_count, logprobs_bytes
|
||||
|
||||
def _truncate_at_stop(self, text, stop_words):
|
||||
"""Truncate ``text`` at the first occurrence of any stop sequence."""
|
||||
if not stop_words:
|
||||
return text
|
||||
earliest = len(text)
|
||||
for stop in stop_words:
|
||||
if not stop:
|
||||
continue
|
||||
idx = text.find(stop)
|
||||
if idx >= 0 and idx < earliest:
|
||||
earliest = idx
|
||||
return text[:earliest] if earliest < len(text) else text
|
||||
|
||||
async def serve(address):
|
||||
# Start asyncio gRPC server
|
||||
|
||||
@@ -1,11 +1,20 @@
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
import types
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
# Make the shared helpers importable so we can unit-test them without a
|
||||
# running gRPC server.
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
from python_utils import messages_to_dicts, parse_options
|
||||
from mlx_utils import parse_tool_calls, split_reasoning
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
|
||||
"""
|
||||
TestBackendServicer is the class that tests the gRPC service.
|
||||
@@ -231,4 +240,104 @@ class TestBackendServicer(unittest.TestCase):
|
||||
self.tearDown()
|
||||
|
||||
|
||||
def test_tokenize_string(self):
|
||||
"""TokenizeString should return a non-empty token list for a known prompt."""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(
|
||||
backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")
|
||||
)
|
||||
self.assertTrue(response.success)
|
||||
resp = stub.TokenizeString(backend_pb2.PredictOptions(Prompt="Hello, world"))
|
||||
self.assertGreater(resp.length, 0)
|
||||
self.assertEqual(len(list(resp.tokens)), resp.length)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("TokenizeString service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_free(self):
|
||||
"""Free should release the model and not crash on subsequent calls."""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(
|
||||
backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")
|
||||
)
|
||||
self.assertTrue(response.success)
|
||||
free_resp = stub.Free(backend_pb2.HealthMessage())
|
||||
self.assertTrue(free_resp.success)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("Free service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
|
||||
class TestSharedHelpers(unittest.TestCase):
|
||||
"""Server-less unit tests for the helpers the mlx backend depends on."""
|
||||
|
||||
def test_parse_options_typed(self):
|
||||
opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true", "name:hello", "no_colon_skipped"])
|
||||
self.assertEqual(opts["temperature"], 0.7)
|
||||
self.assertEqual(opts["max_tokens"], 128)
|
||||
self.assertIs(opts["trust"], True)
|
||||
self.assertEqual(opts["name"], "hello")
|
||||
self.assertNotIn("no_colon_skipped", opts)
|
||||
|
||||
def test_messages_to_dicts_roundtrip(self):
|
||||
# Build proto Message objects (via backend_pb2 to match real gRPC)
|
||||
msgs = [
|
||||
backend_pb2.Message(role="user", content="hi"),
|
||||
backend_pb2.Message(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]',
|
||||
),
|
||||
backend_pb2.Message(
|
||||
role="tool",
|
||||
content="42",
|
||||
tool_call_id="call_1",
|
||||
name="f",
|
||||
),
|
||||
]
|
||||
out = messages_to_dicts(msgs)
|
||||
self.assertEqual(out[0], {"role": "user", "content": "hi"})
|
||||
self.assertEqual(out[1]["role"], "assistant")
|
||||
self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f")
|
||||
self.assertEqual(out[2]["tool_call_id"], "call_1")
|
||||
self.assertEqual(out[2]["name"], "f")
|
||||
|
||||
def test_split_reasoning(self):
|
||||
r, c = split_reasoning("<think>step 1\nstep 2</think>The answer is 42.", "<think>", "</think>")
|
||||
self.assertEqual(r, "step 1\nstep 2")
|
||||
self.assertEqual(c, "The answer is 42.")
|
||||
|
||||
def test_split_reasoning_no_marker(self):
|
||||
r, c = split_reasoning("just text", "<think>", "</think>")
|
||||
self.assertEqual(r, "")
|
||||
self.assertEqual(c, "just text")
|
||||
|
||||
def test_parse_tool_calls_with_shim(self):
|
||||
tm = types.SimpleNamespace(
|
||||
tool_call_start="<tool_call>",
|
||||
tool_call_end="</tool_call>",
|
||||
parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}},
|
||||
)
|
||||
calls, remaining = parse_tool_calls(
|
||||
"Sure: <tool_call>Paris</tool_call>",
|
||||
tm,
|
||||
tools=None,
|
||||
)
|
||||
self.assertEqual(len(calls), 1)
|
||||
self.assertEqual(calls[0]["name"], "get_weather")
|
||||
self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}')
|
||||
self.assertEqual(calls[0]["index"], 0)
|
||||
self.assertNotIn("<tool_call>", remaining)
|
||||
|
||||
|
||||
# Unit tests for ThreadSafeLRUPromptCache are in test_mlx_cache.py
|
||||
17
backend/python/sglang/Makefile
Normal file
17
backend/python/sglang/Makefile
Normal file
@@ -0,0 +1,17 @@
|
||||
.PHONY: sglang
|
||||
sglang:
|
||||
bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: sglang
|
||||
@echo "Running sglang..."
|
||||
bash run.sh
|
||||
@echo "sglang run."
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
502
backend/python/sglang/backend.py
Normal file
502
backend/python/sglang/backend.py
Normal file
@@ -0,0 +1,502 @@
|
||||
#!/usr/bin/env python3
|
||||
"""LocalAI gRPC backend for sglang.
|
||||
|
||||
Wraps sglang's async Engine API behind the Backend gRPC contract defined
|
||||
in backend.proto. Mirrors the structure of backend/python/vllm/backend.py
|
||||
so that the two backends stay behavior-equivalent at the protocol level.
|
||||
|
||||
The streaming path applies sglang's per-request FunctionCallParser and
|
||||
ReasoningParser so tool_calls and reasoning_content are emitted
|
||||
incrementally inside ChatDelta, which is a capability sglang exposes
|
||||
natively and vLLM does not.
|
||||
"""
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import gc
|
||||
import uuid
|
||||
import base64
|
||||
import io
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
|
||||
# sglang imports. Engine is the stable public entry point; parser modules
|
||||
# are wrapped in try/except so older / leaner installs that omit them
|
||||
# still load the backend for plain text generation.
|
||||
from sglang.srt.entrypoints.engine import Engine
|
||||
|
||||
try:
|
||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||
# sglang's FunctionCallParser expects a list of pydantic Tool objects
|
||||
# (protocol.Tool with .function.name), not plain dicts. Wrap at the
|
||||
# request boundary to match.
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool as SglTool
|
||||
HAS_TOOL_PARSERS = True
|
||||
except Exception:
|
||||
FunctionCallParser = None # type: ignore
|
||||
SglTool = None # type: ignore
|
||||
HAS_TOOL_PARSERS = False
|
||||
|
||||
try:
|
||||
from sglang.srt.parser.reasoning_parser import ReasoningParser
|
||||
HAS_REASONING_PARSERS = True
|
||||
except Exception:
|
||||
ReasoningParser = None # type: ignore
|
||||
HAS_REASONING_PARSERS = False
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
HAS_TRANSFORMERS = True
|
||||
except Exception:
|
||||
AutoTokenizer = None # type: ignore
|
||||
HAS_TRANSFORMERS = False
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""gRPC servicer implementing the Backend service for sglang."""
|
||||
|
||||
def _parse_options(self, options_list) -> Dict[str, str]:
|
||||
opts: Dict[str, str] = {}
|
||||
for opt in options_list:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1)
|
||||
opts[key.strip()] = value.strip()
|
||||
return opts
|
||||
|
||||
def _messages_to_dicts(self, messages) -> List[dict]:
|
||||
result: List[dict] = []
|
||||
for msg in messages:
|
||||
d = {"role": msg.role, "content": msg.content or ""}
|
||||
if msg.name:
|
||||
d["name"] = msg.name
|
||||
if msg.tool_call_id:
|
||||
d["tool_call_id"] = msg.tool_call_id
|
||||
if msg.reasoning_content:
|
||||
d["reasoning_content"] = msg.reasoning_content
|
||||
if msg.tool_calls:
|
||||
try:
|
||||
d["tool_calls"] = json.loads(msg.tool_calls)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
result.append(d)
|
||||
return result
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
async def LoadModel(self, request, context):
|
||||
engine_kwargs = {"model_path": request.Model}
|
||||
|
||||
if request.Quantization:
|
||||
engine_kwargs["quantization"] = request.Quantization
|
||||
if request.LoadFormat:
|
||||
engine_kwargs["load_format"] = request.LoadFormat
|
||||
if request.GPUMemoryUtilization:
|
||||
engine_kwargs["mem_fraction_static"] = float(request.GPUMemoryUtilization)
|
||||
if request.TrustRemoteCode:
|
||||
engine_kwargs["trust_remote_code"] = True
|
||||
if request.EnforceEager:
|
||||
engine_kwargs["disable_cuda_graph"] = True
|
||||
if request.TensorParallelSize:
|
||||
engine_kwargs["tp_size"] = int(request.TensorParallelSize)
|
||||
if request.MaxModelLen:
|
||||
engine_kwargs["context_length"] = int(request.MaxModelLen)
|
||||
if request.DType:
|
||||
engine_kwargs["dtype"] = request.DType
|
||||
|
||||
opts = self._parse_options(request.Options)
|
||||
|
||||
# Cache parser names — actual parser instances are created per
|
||||
# request because sglang's parsers are stateful.
|
||||
self.tool_parser_name: Optional[str] = opts.get("tool_parser") or None
|
||||
self.reasoning_parser_name: Optional[str] = opts.get("reasoning_parser") or None
|
||||
|
||||
# Also hand the parser names to sglang's engine so its HTTP/OAI
|
||||
# paths work identically if someone hits the engine directly.
|
||||
if self.tool_parser_name:
|
||||
engine_kwargs["tool_call_parser"] = self.tool_parser_name
|
||||
if self.reasoning_parser_name:
|
||||
engine_kwargs["reasoning_parser"] = self.reasoning_parser_name
|
||||
|
||||
try:
|
||||
self.llm = Engine(**engine_kwargs)
|
||||
except Exception as err:
|
||||
print(f"sglang Engine init failed: {err!r}", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"{err!r}")
|
||||
|
||||
# sglang does not expose a uniform get_tokenizer() off Engine.
|
||||
# Use transformers directly — same path sglang uses internally.
|
||||
self.tokenizer = None
|
||||
if HAS_TRANSFORMERS:
|
||||
try:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
request.Model,
|
||||
trust_remote_code=bool(request.TrustRemoteCode),
|
||||
)
|
||||
except Exception as err:
|
||||
print(f"AutoTokenizer load failed (non-fatal): {err!r}", file=sys.stderr)
|
||||
|
||||
print("Model loaded successfully", file=sys.stderr)
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
async def Predict(self, request, context):
|
||||
gen = self._predict(request, context, streaming=False)
|
||||
res = await gen.__anext__()
|
||||
return res
|
||||
|
||||
async def PredictStream(self, request, context):
|
||||
iterations = self._predict(request, context, streaming=True)
|
||||
try:
|
||||
async for iteration in iterations:
|
||||
yield iteration
|
||||
finally:
|
||||
try:
|
||||
await iterations.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def TokenizeString(self, request, context):
|
||||
if not getattr(self, "tokenizer", None):
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("tokenizer not loaded")
|
||||
return backend_pb2.TokenizationResponse()
|
||||
try:
|
||||
tokens = self.tokenizer.encode(request.Prompt)
|
||||
return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
|
||||
except Exception as e:
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(str(e))
|
||||
return backend_pb2.TokenizationResponse()
|
||||
|
||||
async def Free(self, request, context):
|
||||
try:
|
||||
if hasattr(self, "llm"):
|
||||
try:
|
||||
self.llm.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
del self.llm
|
||||
if hasattr(self, "tokenizer"):
|
||||
del self.tokenizer
|
||||
self.tool_parser_name = None
|
||||
self.reasoning_parser_name = None
|
||||
gc.collect()
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except ImportError:
|
||||
pass
|
||||
return backend_pb2.Result(success=True, message="Model freed")
|
||||
except Exception as e:
|
||||
return backend_pb2.Result(success=False, message=str(e))
|
||||
|
||||
def _build_sampling_params(self, request) -> dict:
|
||||
sampling_params: dict = {"temperature": 0.7, "max_new_tokens": 200}
|
||||
mapping = {
|
||||
"N": "n",
|
||||
"PresencePenalty": "presence_penalty",
|
||||
"FrequencyPenalty": "frequency_penalty",
|
||||
"RepetitionPenalty": "repetition_penalty",
|
||||
"Temperature": "temperature",
|
||||
"TopP": "top_p",
|
||||
"TopK": "top_k",
|
||||
"MinP": "min_p",
|
||||
"Seed": "seed",
|
||||
"StopPrompts": "stop",
|
||||
"StopTokenIds": "stop_token_ids",
|
||||
"IgnoreEOS": "ignore_eos",
|
||||
"Tokens": "max_new_tokens",
|
||||
"MinTokens": "min_new_tokens",
|
||||
"SkipSpecialTokens": "skip_special_tokens",
|
||||
}
|
||||
for proto_field, sgl_key in mapping.items():
|
||||
if not hasattr(request, proto_field):
|
||||
continue
|
||||
value = getattr(request, proto_field)
|
||||
if value in (None, 0, 0.0, [], False, ""):
|
||||
continue
|
||||
# repeated fields come back as RepeatedScalarContainer — convert
|
||||
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
|
||||
value = list(value)
|
||||
if not value:
|
||||
continue
|
||||
sampling_params[sgl_key] = value
|
||||
|
||||
# Grammar → JSON schema or EBNF structured decoding.
|
||||
if getattr(request, "Grammar", ""):
|
||||
grammar = request.Grammar
|
||||
try:
|
||||
json.loads(grammar)
|
||||
sampling_params["json_schema"] = grammar
|
||||
except json.JSONDecodeError:
|
||||
sampling_params["ebnf"] = grammar
|
||||
|
||||
return sampling_params
|
||||
|
||||
def _build_prompt(self, request) -> str:
|
||||
prompt = request.Prompt
|
||||
if prompt or not request.UseTokenizerTemplate or not request.Messages:
|
||||
return prompt
|
||||
|
||||
if self.tokenizer is None:
|
||||
print(
|
||||
"UseTokenizerTemplate requested but tokenizer not loaded; "
|
||||
"falling back to naive concatenation",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return "\n".join(m.content or "" for m in request.Messages)
|
||||
|
||||
messages_dicts = self._messages_to_dicts(request.Messages)
|
||||
template_kwargs: dict = {"tokenize": False, "add_generation_prompt": True}
|
||||
if request.Tools:
|
||||
try:
|
||||
template_kwargs["tools"] = json.loads(request.Tools)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if request.Metadata.get("enable_thinking", "").lower() == "true":
|
||||
template_kwargs["enable_thinking"] = True
|
||||
|
||||
try:
|
||||
return self.tokenizer.apply_chat_template(messages_dicts, **template_kwargs)
|
||||
except TypeError:
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages_dicts, tokenize=False, add_generation_prompt=True,
|
||||
)
|
||||
|
||||
def _make_parsers(self, request):
|
||||
"""Construct fresh per-request parser instances (stateful)."""
|
||||
tool_parser = None
|
||||
reasoning_parser = None
|
||||
|
||||
if HAS_TOOL_PARSERS and self.tool_parser_name and request.Tools:
|
||||
try:
|
||||
tools_raw = json.loads(request.Tools)
|
||||
tools = [SglTool.model_validate(t) for t in tools_raw] if SglTool else tools_raw
|
||||
tool_parser = FunctionCallParser(
|
||||
tools=tools, tool_call_parser=self.tool_parser_name,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"FunctionCallParser init failed: {e!r}", file=sys.stderr)
|
||||
|
||||
if HAS_REASONING_PARSERS and self.reasoning_parser_name:
|
||||
try:
|
||||
reasoning_parser = ReasoningParser(
|
||||
model_type=self.reasoning_parser_name,
|
||||
stream_reasoning=True,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"ReasoningParser init failed: {e!r}", file=sys.stderr)
|
||||
|
||||
return tool_parser, reasoning_parser
|
||||
|
||||
async def _predict(self, request, context, streaming: bool = False):
|
||||
sampling_params = self._build_sampling_params(request)
|
||||
prompt = self._build_prompt(request)
|
||||
|
||||
tool_parser, reasoning_parser = self._make_parsers(request)
|
||||
|
||||
image_data = list(request.Images) if request.Images else None
|
||||
video_data = list(request.Videos) if request.Videos else None
|
||||
|
||||
# Kick off streaming generation. We always use stream=True so the
|
||||
# non-stream path still gets parser coverage on the final text.
|
||||
try:
|
||||
iterator = await self.llm.async_generate(
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
image_data=image_data,
|
||||
video_data=video_data,
|
||||
stream=True,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"sglang async_generate failed: {e!r}", file=sys.stderr)
|
||||
yield backend_pb2.Reply(message=bytes(f"error: {e!r}", "utf-8"))
|
||||
return
|
||||
|
||||
generated_text = ""
|
||||
last_chunk: Optional[dict] = None
|
||||
# Track tool call ids once per (request, tool_index) to match the
|
||||
# OpenAI streaming contract (id sent on first chunk for that tool).
|
||||
tool_ids_seen: Dict[int, str] = {}
|
||||
|
||||
try:
|
||||
async for chunk in iterator:
|
||||
last_chunk = chunk
|
||||
cumulative = chunk.get("text", "") if isinstance(chunk, dict) else ""
|
||||
delta_text = cumulative[len(generated_text):] if cumulative.startswith(generated_text) else cumulative
|
||||
generated_text = cumulative
|
||||
if not delta_text:
|
||||
continue
|
||||
|
||||
reasoning_delta = ""
|
||||
content_delta = delta_text
|
||||
|
||||
if reasoning_parser is not None:
|
||||
try:
|
||||
r, n = reasoning_parser.parse_stream_chunk(delta_text)
|
||||
reasoning_delta = r or ""
|
||||
content_delta = n or ""
|
||||
except Exception as e:
|
||||
print(f"reasoning_parser.parse_stream_chunk: {e!r}", file=sys.stderr)
|
||||
|
||||
tool_call_deltas: List[backend_pb2.ToolCallDelta] = []
|
||||
if tool_parser is not None and content_delta:
|
||||
try:
|
||||
normal_text, calls = tool_parser.parse_stream_chunk(content_delta)
|
||||
content_delta = normal_text or ""
|
||||
for tc in calls:
|
||||
idx = int(getattr(tc, "tool_index", 0) or 0)
|
||||
tc_id = tool_ids_seen.get(idx)
|
||||
if tc_id is None:
|
||||
tc_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
tool_ids_seen[idx] = tc_id
|
||||
tool_call_deltas.append(backend_pb2.ToolCallDelta(
|
||||
index=idx,
|
||||
id=tc_id,
|
||||
name=getattr(tc, "name", "") or "",
|
||||
arguments=getattr(tc, "parameters", "") or "",
|
||||
))
|
||||
except Exception as e:
|
||||
print(f"tool_parser.parse_stream_chunk: {e!r}", file=sys.stderr)
|
||||
|
||||
if streaming and (content_delta or reasoning_delta or tool_call_deltas):
|
||||
yield backend_pb2.Reply(
|
||||
message=bytes(content_delta, "utf-8"),
|
||||
chat_deltas=[backend_pb2.ChatDelta(
|
||||
content=content_delta,
|
||||
reasoning_content=reasoning_delta,
|
||||
tool_calls=tool_call_deltas,
|
||||
)],
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
await iterator.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Extract token counts from the final chunk's meta_info.
|
||||
meta = {}
|
||||
if isinstance(last_chunk, dict):
|
||||
meta = last_chunk.get("meta_info") or {}
|
||||
prompt_tokens = int(meta.get("prompt_tokens", 0) or 0)
|
||||
completion_tokens = int(meta.get("completion_tokens", 0) or 0)
|
||||
|
||||
# Non-streaming path: re-parse the full text with fresh parsers
|
||||
# so we return a clean, complete ChatDelta. Streaming parsers
|
||||
# used above have accumulated state we don't want to reuse.
|
||||
final_content = generated_text
|
||||
final_reasoning = ""
|
||||
final_tool_calls: List[backend_pb2.ToolCallDelta] = []
|
||||
|
||||
if not streaming:
|
||||
final_reasoning_parser = None
|
||||
if HAS_REASONING_PARSERS and self.reasoning_parser_name:
|
||||
try:
|
||||
final_reasoning_parser = ReasoningParser(
|
||||
model_type=self.reasoning_parser_name,
|
||||
stream_reasoning=False,
|
||||
)
|
||||
except Exception:
|
||||
final_reasoning_parser = None
|
||||
|
||||
if final_reasoning_parser is not None:
|
||||
try:
|
||||
r, n = final_reasoning_parser.parse_non_stream(generated_text)
|
||||
final_reasoning = r or ""
|
||||
final_content = n if n is not None else generated_text
|
||||
except Exception as e:
|
||||
print(f"reasoning_parser.parse_non_stream: {e!r}", file=sys.stderr)
|
||||
|
||||
if HAS_TOOL_PARSERS and self.tool_parser_name and request.Tools:
|
||||
try:
|
||||
tools_raw = json.loads(request.Tools)
|
||||
tools = [SglTool.model_validate(t) for t in tools_raw] if SglTool else tools_raw
|
||||
fresh_tool_parser = FunctionCallParser(
|
||||
tools=tools, tool_call_parser=self.tool_parser_name,
|
||||
)
|
||||
normal, calls = fresh_tool_parser.parse_non_stream(final_content)
|
||||
if calls:
|
||||
final_content = normal
|
||||
for tc in calls:
|
||||
idx = int(getattr(tc, "tool_index", 0) or 0)
|
||||
final_tool_calls.append(backend_pb2.ToolCallDelta(
|
||||
index=idx,
|
||||
id=f"call_{uuid.uuid4().hex[:24]}",
|
||||
name=getattr(tc, "name", "") or "",
|
||||
arguments=getattr(tc, "parameters", "") or "",
|
||||
))
|
||||
except Exception as e:
|
||||
print(f"tool_parser.parse_non_stream: {e!r}", file=sys.stderr)
|
||||
|
||||
chat_delta = backend_pb2.ChatDelta(
|
||||
content=final_content if not streaming else "",
|
||||
reasoning_content=final_reasoning,
|
||||
tool_calls=final_tool_calls,
|
||||
)
|
||||
|
||||
if streaming:
|
||||
yield backend_pb2.Reply(
|
||||
message=b"",
|
||||
prompt_tokens=prompt_tokens,
|
||||
tokens=completion_tokens,
|
||||
chat_deltas=[chat_delta],
|
||||
)
|
||||
return
|
||||
|
||||
yield backend_pb2.Reply(
|
||||
message=bytes(final_content or "", "utf-8"),
|
||||
prompt_tokens=prompt_tokens,
|
||||
tokens=completion_tokens,
|
||||
chat_deltas=[chat_delta],
|
||||
)
|
||||
|
||||
|
||||
async def serve(address):
|
||||
server = grpc.aio.server(
|
||||
migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5)))
|
||||
|
||||
await server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
await server.wait_for_termination()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the sglang gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
asyncio.run(serve(args.addr))
|
||||
87
backend/python/sglang/install.sh
Executable file
87
backend/python/sglang/install.sh
Executable file
@@ -0,0 +1,87 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation"
|
||||
|
||||
# Avoid overcommitting the CPU during builds that compile native code.
|
||||
export NVCC_THREADS=2
|
||||
export MAX_JOBS=1
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||
fi
|
||||
|
||||
if [ "x${BUILD_PROFILE}" == "xcpu" ]; then
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy=unsafe-best-match"
|
||||
fi
|
||||
|
||||
# sglang's CPU path has no prebuilt wheel on PyPI — upstream publishes
|
||||
# a separate pyproject_cpu.toml that must be swapped in before `pip install`.
|
||||
# Reference: docker/xeon.Dockerfile in the sglang upstream repo.
|
||||
#
|
||||
# When BUILD_TYPE is empty (CPU profile) or FROM_SOURCE=true is forced,
|
||||
# install torch/transformers/etc from requirements-cpu.txt, then clone
|
||||
# sglang and install its python/ and sgl-kernel/ packages from source
|
||||
# using the CPU pyproject.
|
||||
if [ "x${BUILD_TYPE}" == "x" ] || [ "x${FROM_SOURCE:-}" == "xtrue" ]; then
|
||||
# sgl-kernel's CPU build links against libnuma and libtbb. Install
|
||||
# them here (Docker builder stage) before running the source build.
|
||||
# Harmless no-op on runs outside the docker build since installRequirements
|
||||
# below still needs them only if we reach the source build branch.
|
||||
if command -v apt-get >/dev/null 2>&1 && [ "$(id -u)" = "0" ]; then
|
||||
apt-get update
|
||||
DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
libnuma-dev numactl libtbb-dev libgomp1 libomp-dev google-perftools \
|
||||
build-essential cmake ninja-build
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
# sgl-kernel's pyproject_cpu.toml uses scikit-build-core as its build
|
||||
# backend. With --no-build-isolation, that (and ninja/cmake) must be
|
||||
# present in the venv before we build from source.
|
||||
uv pip install --no-build-isolation "scikit-build-core>=0.10" ninja cmake
|
||||
|
||||
# sgl-kernel's CPU shm.cpp uses __m512 AVX-512 intrinsics unconditionally.
|
||||
# csrc/cpu/CMakeLists.txt hard-codes add_compile_options(-march=native),
|
||||
# which on runners without AVX-512 in /proc/cpuinfo fails with
|
||||
# "__m512 return without 'avx512f' enabled changes the ABI".
|
||||
# CXXFLAGS alone is insufficient because CMake's add_compile_options()
|
||||
# appends -march=native *after* CXXFLAGS, overriding it.
|
||||
# We therefore patch the CMakeLists.txt to replace -march=native with
|
||||
# -march=sapphirerapids so the flag is consistent throughout the build.
|
||||
# The resulting binary still requires an AVX-512 capable CPU at runtime,
|
||||
# same constraint sglang upstream documents in docker/xeon.Dockerfile.
|
||||
|
||||
_sgl_src=$(mktemp -d)
|
||||
trap 'rm -rf "${_sgl_src}"' EXIT
|
||||
git clone --depth 1 https://github.com/sgl-project/sglang "${_sgl_src}/sglang"
|
||||
|
||||
# Patch -march=native → -march=sapphirerapids in the CPU kernel CMakeLists
|
||||
sed -i 's/-march=native/-march=sapphirerapids/g' \
|
||||
"${_sgl_src}/sglang/sgl-kernel/csrc/cpu/CMakeLists.txt"
|
||||
|
||||
pushd "${_sgl_src}/sglang/sgl-kernel"
|
||||
if [ -f pyproject_cpu.toml ]; then
|
||||
cp pyproject_cpu.toml pyproject.toml
|
||||
fi
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} .
|
||||
popd
|
||||
|
||||
pushd "${_sgl_src}/sglang/python"
|
||||
if [ -f pyproject_cpu.toml ]; then
|
||||
cp pyproject_cpu.toml pyproject.toml
|
||||
fi
|
||||
uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} .
|
||||
popd
|
||||
else
|
||||
installRequirements
|
||||
fi
|
||||
63
backend/python/sglang/package.sh
Executable file
63
backend/python/sglang/package.sh
Executable file
@@ -0,0 +1,63 @@
|
||||
#!/bin/bash
|
||||
# Package runtime shared libraries for the sglang backend.
|
||||
#
|
||||
# Dockerfile.python's final stage is FROM scratch — every system library
|
||||
# the backend dlopens at runtime must be explicitly copied into
|
||||
# ${BACKEND}/lib, which libbackend.sh adds to LD_LIBRARY_PATH.
|
||||
#
|
||||
# sglang's CPU kernel links against libnuma and libtbb; torch's CPU
|
||||
# kernels use libgomp; tcmalloc + iomp5 are preloaded per sglang's
|
||||
# docker/xeon.Dockerfile recipe for best CPU throughput. Missing any of
|
||||
# these makes the engine crash on import.
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
LIB_DIR="${CURDIR}/lib"
|
||||
mkdir -p "${LIB_DIR}"
|
||||
|
||||
copy_with_symlinks() {
|
||||
local soname="$1"
|
||||
local hit=""
|
||||
for dir in \
|
||||
/usr/lib/x86_64-linux-gnu \
|
||||
/usr/lib/aarch64-linux-gnu \
|
||||
/lib/x86_64-linux-gnu \
|
||||
/lib/aarch64-linux-gnu \
|
||||
/usr/lib \
|
||||
/lib; do
|
||||
if [ -e "${dir}/${soname}" ]; then
|
||||
hit="${dir}/${soname}"
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ -z "${hit}" ]; then
|
||||
echo "warning: ${soname} not found in standard lib paths" >&2
|
||||
return 0
|
||||
fi
|
||||
local real
|
||||
real=$(readlink -f "${hit}")
|
||||
cp -v "${real}" "${LIB_DIR}/"
|
||||
local real_base
|
||||
real_base=$(basename "${real}")
|
||||
if [ "${real_base}" != "${soname}" ]; then
|
||||
ln -sf "${real_base}" "${LIB_DIR}/${soname}"
|
||||
fi
|
||||
}
|
||||
|
||||
copy_with_symlinks libnuma.so.1
|
||||
copy_with_symlinks libgomp.so.1
|
||||
copy_with_symlinks libtbb.so.12
|
||||
copy_with_symlinks libtbbmalloc.so.2
|
||||
copy_with_symlinks libtcmalloc.so.4
|
||||
|
||||
# intel-openmp ships libiomp5.so inside the venv under venv/lib/ — sglang's
|
||||
# CPU kernel was compiled against its __kmpc_* symbols, so it must be on
|
||||
# LD_LIBRARY_PATH at runtime. Copy it into the backend lib dir where
|
||||
# libbackend.sh will pick it up.
|
||||
if [ -f "${CURDIR}/venv/lib/libiomp5.so" ]; then
|
||||
cp -v "${CURDIR}/venv/lib/libiomp5.so" "${LIB_DIR}/"
|
||||
fi
|
||||
|
||||
echo "sglang packaging completed successfully"
|
||||
ls -liah "${LIB_DIR}/"
|
||||
2
backend/python/sglang/requirements-after.txt
Normal file
2
backend/python/sglang/requirements-after.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# sglang is installed per-acceleration in requirements-{profile}-after.txt
|
||||
# (cublas12, hipblas, intel, cpu)
|
||||
3
backend/python/sglang/requirements-cpu-after.txt
Normal file
3
backend/python/sglang/requirements-cpu-after.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
# sglang has no prebuilt CPU wheel on PyPI. install.sh performs a
|
||||
# from-source build using the upstream pyproject_cpu.toml recipe from
|
||||
# docker/xeon.Dockerfile when BUILD_TYPE is empty (CPU profile).
|
||||
7
backend/python/sglang/requirements-cpu.txt
Normal file
7
backend/python/sglang/requirements-cpu.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
accelerate
|
||||
torch==2.9.0
|
||||
torchvision
|
||||
torchaudio
|
||||
transformers
|
||||
intel-openmp; platform_machine == 'x86_64'
|
||||
3
backend/python/sglang/requirements-cublas12-after.txt
Normal file
3
backend/python/sglang/requirements-cublas12-after.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
# Bump this pin deliberately — sglang releases weekly and API surfaces
|
||||
# (FunctionCallParser, ReasoningParser) move between releases.
|
||||
sglang[all]>=0.4.0
|
||||
5
backend/python/sglang/requirements-cublas12.txt
Normal file
5
backend/python/sglang/requirements-cublas12.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
accelerate
|
||||
torch==2.7.1
|
||||
torchvision
|
||||
torchaudio==2.7.1
|
||||
transformers
|
||||
2
backend/python/sglang/requirements-hipblas-after.txt
Normal file
2
backend/python/sglang/requirements-hipblas-after.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# sglang's ROCm build is installed from source per docker/rocm.Dockerfile
|
||||
# upstream; install.sh handles the source build when BUILD_TYPE=hipblas.
|
||||
5
backend/python/sglang/requirements-hipblas.txt
Normal file
5
backend/python/sglang/requirements-hipblas.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/rocm7.0
|
||||
accelerate
|
||||
torch
|
||||
torchvision
|
||||
transformers
|
||||
6
backend/python/sglang/requirements-install.txt
Normal file
6
backend/python/sglang/requirements-install.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
# sglang and sgl-kernel do not declare full PEP517 build deps; install the
|
||||
# basic build tooling into the venv before pulling the rest of the stack.
|
||||
packaging
|
||||
setuptools
|
||||
wheel
|
||||
setuptools-scm
|
||||
2
backend/python/sglang/requirements-intel-after.txt
Normal file
2
backend/python/sglang/requirements-intel-after.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# sglang's Intel XPU build is installed from source per docker/xpu.Dockerfile
|
||||
# upstream; install.sh handles the source build when BUILD_PROFILE=intel.
|
||||
7
backend/python/sglang/requirements-intel.txt
Normal file
7
backend/python/sglang/requirements-intel.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
accelerate
|
||||
torch
|
||||
torchvision
|
||||
transformers
|
||||
optimum[openvino]
|
||||
setuptools
|
||||
4
backend/python/sglang/requirements.txt
Normal file
4
backend/python/sglang/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
grpcio==1.80.0
|
||||
protobuf
|
||||
certifi
|
||||
setuptools
|
||||
29
backend/python/sglang/run.sh
Executable file
29
backend/python/sglang/run.sh
Executable file
@@ -0,0 +1,29 @@
|
||||
#!/bin/bash
|
||||
|
||||
backend_dir=$(dirname $(realpath $0))
|
||||
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
# sglang's CPU kernel references LLVM OpenMP (__kmpc_*) symbols that are
|
||||
# not declared in its NEEDED list — they get resolved through LD_PRELOAD
|
||||
# of libiomp5.so in sglang's own docker/xeon.Dockerfile. Do the same here.
|
||||
# Harmless on GPU builds where libiomp5.so is absent.
|
||||
if [ -f "${backend_dir}/lib/libiomp5.so" ]; then
|
||||
if [ -n "${LD_PRELOAD:-}" ]; then
|
||||
export LD_PRELOAD="${backend_dir}/lib/libiomp5.so:${LD_PRELOAD}"
|
||||
else
|
||||
export LD_PRELOAD="${backend_dir}/lib/libiomp5.so"
|
||||
fi
|
||||
fi
|
||||
|
||||
# sglang CPU engine requires this env var to switch to the CPU backend.
|
||||
# No-op on GPU builds. See docker/xeon.Dockerfile in sglang upstream.
|
||||
if [ -f "${backend_dir}/lib/libiomp5.so" ]; then
|
||||
export SGLANG_USE_CPU_ENGINE=1
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
25
backend/python/tinygrad/Makefile
Normal file
25
backend/python/tinygrad/Makefile
Normal file
@@ -0,0 +1,25 @@
|
||||
.DEFAULT_GOAL := install
|
||||
|
||||
.PHONY: install
|
||||
install:
|
||||
bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: install
|
||||
@echo "Running tinygrad..."
|
||||
bash run.sh
|
||||
@echo "tinygrad run."
|
||||
|
||||
.PHONY: test
|
||||
test: install
|
||||
@echo "Testing tinygrad..."
|
||||
bash test.sh
|
||||
@echo "tinygrad tested."
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
785
backend/python/tinygrad/backend.py
Normal file
785
backend/python/tinygrad/backend.py
Normal file
@@ -0,0 +1,785 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
LocalAI gRPC backend for tinygrad.
|
||||
|
||||
LLM execution is delegated to `tinygrad.apps.llm.Transformer` — we keep
|
||||
only a thin HF → GGUF-name adapter (vendor/appsllm_adapter.py) for the
|
||||
safetensors path; GGUF models load through `Transformer.from_gguf()`
|
||||
with native Q4/Q6/Q8 support.
|
||||
|
||||
Scope:
|
||||
- LLM text generation via apps.llm (Qwen3 / Qwen3.5 / Llama 3.x /
|
||||
GLM-4 / OLMoE / Kimi-K2 / Moonlight — anything apps.llm supports).
|
||||
- Native tool-call extraction via pluggable parsers (hermes,
|
||||
llama3_json, qwen3_xml, mistral).
|
||||
- Embeddings — mean-pooled last-hidden-state over the block stack.
|
||||
- Stable Diffusion 1.x, Whisper — handled by the vendored paths.
|
||||
|
||||
Sampling is greedy-only because `apps.llm.Transformer.generate` (in the
|
||||
tinygrad 0.12.0 PyPI release) ends with `.argmax(-1)` and takes no
|
||||
temperature / top-k / top-p / repetition-penalty arguments. These
|
||||
request fields are accepted and ignored.
|
||||
|
||||
The heavy imports (tinygrad, tokenizers, tinygrad.apps.llm) are deferred
|
||||
until `LoadModel`, because tinygrad binds its compute device at import
|
||||
time from env vars. `_select_tinygrad_device()` maps LocalAI's BUILD_TYPE
|
||||
onto the corresponding tinygrad env flag before any import happens.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from concurrent import futures
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import grpc
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors # noqa: E402
|
||||
|
||||
from tool_parsers import resolve_parser # noqa: E402
|
||||
from tool_parsers.base import ToolCall # noqa: E402
|
||||
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Device selection — must run BEFORE `import tinygrad` anywhere.
|
||||
#
|
||||
# In production this is set by run.sh based on which driver libraries the
|
||||
# host has injected into the container (libcuda.so.1 → CUDA, libamdhip64
|
||||
# → HIP, otherwise CLANG). This helper is only a fallback for direct
|
||||
# invocations like the unit tests.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _select_tinygrad_device() -> None:
|
||||
if any(os.environ.get(k) == "1" for k in ("CUDA", "HIP", "METAL", "CLANG", "AMD", "NV")):
|
||||
return
|
||||
os.environ["CLANG"] = "1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model asset discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _resolve_model_assets(model_ref: str) -> Path:
|
||||
"""
|
||||
Accept either a local path or a HuggingFace repo id (e.g.
|
||||
"unsloth/Qwen3.5-0.8B-GGUF") and return the local directory / file.
|
||||
HF ids are materialized via `huggingface_hub.snapshot_download` — we
|
||||
pull both safetensors (for fp16 HF repos) and GGUF (for quantized
|
||||
repos) so the same code path handles either.
|
||||
"""
|
||||
p = Path(model_ref)
|
||||
if p.exists():
|
||||
return p
|
||||
if "/" in model_ref and not model_ref.startswith(("/", ".")):
|
||||
from huggingface_hub import snapshot_download
|
||||
local = snapshot_download(
|
||||
repo_id=model_ref,
|
||||
allow_patterns=[
|
||||
"config.json",
|
||||
"tokenizer.json",
|
||||
"tokenizer_config.json",
|
||||
"special_tokens_map.json",
|
||||
"generation_config.json",
|
||||
"*.safetensors",
|
||||
"*.safetensors.index.json",
|
||||
"*.gguf",
|
||||
],
|
||||
)
|
||||
return Path(local)
|
||||
raise FileNotFoundError(f"Model not found: {model_ref}")
|
||||
|
||||
|
||||
def _gguf_path(model_ref: Path) -> Optional[Path]:
|
||||
"""Return the GGUF file to load from a path that may be a file or dir."""
|
||||
if model_ref.is_file() and str(model_ref).endswith(".gguf"):
|
||||
return model_ref
|
||||
if model_ref.is_dir():
|
||||
ggufs = sorted(model_ref.glob("*.gguf"))
|
||||
if ggufs:
|
||||
return ggufs[0]
|
||||
return None
|
||||
|
||||
|
||||
def _load_hf_safetensors(model_dir: Path) -> dict[str, Any]:
|
||||
"""Load sharded or single-file HF safetensors from a directory."""
|
||||
from tinygrad.nn.state import safe_load
|
||||
|
||||
index = model_dir / "model.safetensors.index.json"
|
||||
if index.exists():
|
||||
with open(index) as fp:
|
||||
weight_map = json.load(fp)["weight_map"]
|
||||
shards: dict[str, Any] = {}
|
||||
for shard_name in set(weight_map.values()):
|
||||
shards[shard_name] = safe_load(str(model_dir / shard_name))
|
||||
return {k: shards[n][k] for k, n in weight_map.items()}
|
||||
|
||||
single = model_dir / "model.safetensors"
|
||||
if single.exists():
|
||||
return safe_load(str(single))
|
||||
|
||||
raise FileNotFoundError(f"No safetensors weights found under {model_dir}")
|
||||
|
||||
|
||||
def _auto_tool_parser(model_ref: Optional[str], config: dict) -> Optional[str]:
|
||||
"""Pick a tool parser automatically from model family heuristics.
|
||||
|
||||
Order of precedence: architecture name from config.json, then model ref
|
||||
string. Returns None to fall through to the passthrough parser.
|
||||
"""
|
||||
arches = " ".join(a.lower() for a in config.get("architectures", []))
|
||||
ref = (model_ref or "").lower()
|
||||
blob = f"{arches} {ref}"
|
||||
|
||||
if "qwen3" in blob:
|
||||
return "qwen3_xml"
|
||||
if "hermes" in blob or "qwen2" in blob or "qwen" in blob:
|
||||
return "hermes"
|
||||
if "llama-3" in blob or "llama_3" in blob or "llama3" in blob:
|
||||
return "llama3_json"
|
||||
if "mistral" in blob or "mixtral" in blob:
|
||||
return "mistral"
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Servicer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""gRPC servicer for the tinygrad backend."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._reset_state()
|
||||
|
||||
def _reset_state(self) -> None:
|
||||
self.model_ref: Optional[str] = None
|
||||
self.model_type: str = "llm"
|
||||
self.options: dict[str, str] = {}
|
||||
# LLM state
|
||||
self.llm_model = None
|
||||
self.llm_config: dict = {}
|
||||
self.llm_tokenizer = None
|
||||
self.llm_eos_ids: list[int] = []
|
||||
self.chat_template: Optional[str] = None
|
||||
self.tool_parser = resolve_parser(None)
|
||||
self.max_context = 4096
|
||||
# Stable Diffusion state
|
||||
self.sd_model = None
|
||||
# Whisper state
|
||||
self.whisper_model = None
|
||||
self.whisper_tokenizer = None
|
||||
|
||||
# --------------------- helpers --------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _parse_options(options_list) -> dict[str, str]:
|
||||
opts: dict[str, str] = {}
|
||||
for opt in options_list:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1)
|
||||
opts[key.strip()] = value.strip()
|
||||
return opts
|
||||
|
||||
@staticmethod
|
||||
def _detect_model_type(model_ref: str, explicit: Optional[str]) -> str:
|
||||
if explicit:
|
||||
return explicit
|
||||
name = (model_ref or "").lower()
|
||||
if "whisper" in name:
|
||||
return "whisper"
|
||||
if "sdxl" in name:
|
||||
return "sdxl"
|
||||
if "sd-v1" in name or "v1-5" in name or "stable-diffusion" in name:
|
||||
return "sd15"
|
||||
if any(tag in name for tag in ("bge", "e5", "minilm", "bert")):
|
||||
return "bert"
|
||||
return "llm"
|
||||
|
||||
def _messages_to_dicts(self, messages) -> list[dict]:
|
||||
result = []
|
||||
for msg in messages:
|
||||
d: dict = {"role": msg.role, "content": msg.content or ""}
|
||||
if msg.name:
|
||||
d["name"] = msg.name
|
||||
if msg.tool_call_id:
|
||||
d["tool_call_id"] = msg.tool_call_id
|
||||
if msg.reasoning_content:
|
||||
d["reasoning_content"] = msg.reasoning_content
|
||||
if msg.tool_calls:
|
||||
try:
|
||||
d["tool_calls"] = json.loads(msg.tool_calls)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
result.append(d)
|
||||
return result
|
||||
|
||||
def _render_prompt(self, request) -> str:
|
||||
"""Render messages + tools into the model's chat template, or fall
|
||||
back to the raw Prompt field for models without a template."""
|
||||
if not request.Messages and request.Prompt:
|
||||
return request.Prompt
|
||||
|
||||
if not self.chat_template:
|
||||
# No template known — concatenate role/content lines.
|
||||
lines = []
|
||||
for msg in request.Messages:
|
||||
lines.append(f"{msg.role}: {msg.content or ''}")
|
||||
return "\n".join(lines) + "\nassistant:"
|
||||
|
||||
from jinja2 import Environment
|
||||
|
||||
env = Environment(trim_blocks=True, lstrip_blocks=True)
|
||||
template = env.from_string(self.chat_template)
|
||||
|
||||
tools = None
|
||||
if request.Tools:
|
||||
try:
|
||||
tools = json.loads(request.Tools)
|
||||
except json.JSONDecodeError:
|
||||
tools = None
|
||||
|
||||
return template.render(
|
||||
messages=self._messages_to_dicts(request.Messages),
|
||||
tools=tools,
|
||||
add_generation_prompt=True,
|
||||
# Qwen3's chat template enables <think>...</think> reasoning
|
||||
# by default. On small models (0.6B) that reasoning preamble
|
||||
# eats the whole token budget before a tool call emerges, so
|
||||
# we disable it. Templates that don't know this var ignore it.
|
||||
enable_thinking=False,
|
||||
)
|
||||
|
||||
# --------------------- LLM path -------------------------------------
|
||||
|
||||
def _load_llm(self, model_path: Path) -> None:
|
||||
"""Load an LLM through `tinygrad.apps.llm.Transformer`.
|
||||
|
||||
Two paths:
|
||||
- GGUF file (anywhere in the tree) → `Transformer.from_gguf()`
|
||||
handles config, weight conversion (incl. Q4/Q6/Q8 quantization)
|
||||
and RoPE permute natively.
|
||||
- HF safetensors directory → build `TransformerConfig` from
|
||||
config.json and load weights via a small HF→GGUF-name adapter.
|
||||
"""
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.apps.llm import Transformer
|
||||
from tinygrad.nn.state import load_state_dict
|
||||
|
||||
from vendor.appsllm_adapter import (
|
||||
_hf_to_appsllm_state_dict,
|
||||
_hf_to_transformer_kwargs,
|
||||
)
|
||||
|
||||
max_context_cap = 8192
|
||||
|
||||
gguf_file = _gguf_path(model_path)
|
||||
if gguf_file is not None:
|
||||
# GGUF path: apps.llm handles everything — config, quant, RoPE.
|
||||
gguf_tensor = Tensor.empty(
|
||||
os.stat(gguf_file).st_size, dtype=dtypes.uint8,
|
||||
device=f"disk:{gguf_file}",
|
||||
).to(Device.DEFAULT)
|
||||
model, kv = Transformer.from_gguf(gguf_tensor, max_context=max_context_cap)
|
||||
self.llm_model = model
|
||||
self.max_context = model.max_context
|
||||
# Preserve a config-shaped dict for tool-parser heuristics and
|
||||
# the "loaded" message.
|
||||
arch = kv.get("general.architecture", "")
|
||||
self.llm_config = {
|
||||
"architectures": [kv.get("general.name", arch) or arch],
|
||||
"gguf_kv": kv,
|
||||
}
|
||||
|
||||
# Tokenizer: prefer sidecar tokenizer.json (richer HF Jinja2
|
||||
# templates), fall back to apps.llm's SimpleTokenizer built
|
||||
# from GGUF metadata.
|
||||
self._load_tokenizer_for_dir(model_path if model_path.is_dir() else gguf_file.parent, gguf_kv=kv)
|
||||
else:
|
||||
# HF safetensors path.
|
||||
if not model_path.is_dir():
|
||||
raise FileNotFoundError(f"Expected HF model directory, got file: {model_path}")
|
||||
config_path = model_path / "config.json"
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"config.json not found under {model_path}")
|
||||
with open(config_path) as fp:
|
||||
hf_config = json.load(fp)
|
||||
self.llm_config = hf_config
|
||||
|
||||
raw_weights = _load_hf_safetensors(model_path)
|
||||
n_layers = hf_config["num_hidden_layers"]
|
||||
state_dict = _hf_to_appsllm_state_dict(raw_weights, n_layers)
|
||||
|
||||
kwargs = _hf_to_transformer_kwargs(hf_config, state_dict, max_context_cap)
|
||||
self.max_context = kwargs["max_context"]
|
||||
|
||||
model = Transformer(**kwargs)
|
||||
load_state_dict(model, state_dict, strict=False, consume=True)
|
||||
self.llm_model = model
|
||||
|
||||
self._load_tokenizer_for_dir(model_path, gguf_kv=None)
|
||||
|
||||
# Auto-pick tool parser from options or model family.
|
||||
parser_name = self.options.get("tool_parser") or _auto_tool_parser(self.model_ref, self.llm_config)
|
||||
self.tool_parser = resolve_parser(parser_name)
|
||||
|
||||
def _load_tokenizer_for_dir(self, model_dir: Path, gguf_kv: Optional[dict]) -> None:
|
||||
"""Load HF tokenizer + chat template + EOS ids from a model directory.
|
||||
|
||||
Falls back to apps.llm's `SimpleTokenizer.from_gguf_kv` when there
|
||||
is no `tokenizer.json` sidecar (single-file GGUF, no HF repo).
|
||||
"""
|
||||
tokenizer_json = model_dir / "tokenizer.json"
|
||||
if tokenizer_json.exists():
|
||||
from tokenizers import Tokenizer as HFTokenizer
|
||||
self.llm_tokenizer = HFTokenizer.from_file(str(tokenizer_json))
|
||||
elif gguf_kv is not None:
|
||||
from tinygrad.apps.llm import SimpleTokenizer
|
||||
self.llm_tokenizer = SimpleTokenizer.from_gguf_kv(gguf_kv)
|
||||
else:
|
||||
raise FileNotFoundError(f"tokenizer.json not found under {model_dir}")
|
||||
|
||||
tok_cfg_path = model_dir / "tokenizer_config.json"
|
||||
if tok_cfg_path.exists():
|
||||
with open(tok_cfg_path) as fp:
|
||||
tok_cfg = json.load(fp)
|
||||
self.chat_template = tok_cfg.get("chat_template")
|
||||
|
||||
self.llm_eos_ids = []
|
||||
for cfg_name in ("generation_config.json", "config.json"):
|
||||
cfg_path = model_dir / cfg_name
|
||||
if not cfg_path.exists():
|
||||
continue
|
||||
with open(cfg_path) as fp:
|
||||
cfg = json.load(fp)
|
||||
eos = cfg.get("eos_token_id")
|
||||
if isinstance(eos, list):
|
||||
self.llm_eos_ids.extend(int(x) for x in eos)
|
||||
elif isinstance(eos, int):
|
||||
self.llm_eos_ids.append(eos)
|
||||
if self.llm_eos_ids:
|
||||
break
|
||||
if not self.llm_eos_ids and gguf_kv is not None:
|
||||
eos = gguf_kv.get("tokenizer.ggml.eos_token_id")
|
||||
if isinstance(eos, int):
|
||||
self.llm_eos_ids.append(eos)
|
||||
|
||||
# --------------------- Stable Diffusion path ------------------------
|
||||
|
||||
def _load_sd(self, model_ref: str) -> None:
|
||||
"""Load a Stable Diffusion 1.x checkpoint (CompVis `.ckpt` format)."""
|
||||
from huggingface_hub import hf_hub_download
|
||||
from tinygrad.nn.state import load_state_dict, torch_load
|
||||
|
||||
from vendor.stable_diffusion import StableDiffusion
|
||||
|
||||
ckpt_path = Path(model_ref)
|
||||
if not ckpt_path.exists():
|
||||
# Accept an HF repo id — fetch the canonical v1-5-pruned-emaonly.ckpt
|
||||
# from the requested repo. Common case is runwayml/stable-diffusion-v1-5.
|
||||
repo_id = model_ref if "/" in model_ref else "runwayml/stable-diffusion-v1-5"
|
||||
ckpt_file = self.options.get("sd_ckpt_filename", "v1-5-pruned-emaonly.ckpt")
|
||||
ckpt_path = Path(hf_hub_download(repo_id=repo_id, filename=ckpt_file))
|
||||
|
||||
model = StableDiffusion()
|
||||
state_dict = torch_load(str(ckpt_path))
|
||||
if isinstance(state_dict, dict) and "state_dict" in state_dict:
|
||||
state_dict = state_dict["state_dict"]
|
||||
load_state_dict(model, state_dict, strict=False, verbose=False, realize=False)
|
||||
self.sd_model = model
|
||||
|
||||
# --------------------- Whisper path ---------------------------------
|
||||
|
||||
def _load_whisper(self, model_ref: str) -> None:
|
||||
"""Load a Whisper checkpoint (OpenAI `.pt` format).
|
||||
|
||||
Accepts a model-size alias (tiny / tiny.en / base / base.en / small /
|
||||
small.en) OR an explicit `.pt` file path OR the HF repo id naming
|
||||
convention `openai/whisper-*` (mapped to the matching OpenAI alias).
|
||||
"""
|
||||
from vendor.whisper import init_whisper, MODEL_URLS
|
||||
|
||||
alias = model_ref
|
||||
if "/" in alias and alias.startswith("openai/whisper-"):
|
||||
alias = alias.removeprefix("openai/whisper-")
|
||||
if alias not in MODEL_URLS:
|
||||
# Explicit path to a .pt checkpoint — fall back to size heuristic
|
||||
# via filename.
|
||||
basename = Path(alias).name.lower()
|
||||
for name in MODEL_URLS:
|
||||
if name in basename:
|
||||
alias = name
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown Whisper model_ref={model_ref!r}; expected one of {list(MODEL_URLS)} "
|
||||
f"or an openai/whisper-* HF id"
|
||||
)
|
||||
|
||||
model, enc = init_whisper(alias, batch_size=1)
|
||||
self.whisper_model = model
|
||||
self.whisper_tokenizer = enc
|
||||
|
||||
# --------------------- LLM generation -------------------------------
|
||||
|
||||
def _encode_prompt(self, prompt: str) -> list[int]:
|
||||
"""Normalize tokenizer output: HF `tokenizers.Tokenizer.encode()`
|
||||
returns an `Encoding` with `.ids`; apps.llm's `SimpleTokenizer.encode()`
|
||||
returns `list[int]` directly."""
|
||||
encoded = self.llm_tokenizer.encode(prompt)
|
||||
return list(getattr(encoded, "ids", encoded))
|
||||
|
||||
def _decode_tokens(self, ids: list[int]) -> str:
|
||||
return self.llm_tokenizer.decode(ids)
|
||||
|
||||
def _generate_tokens(self, prompt: str, max_new_tokens: int, temperature: float):
|
||||
"""Yield (token_id, token_text) pairs using `apps.llm.Transformer.generate()`.
|
||||
|
||||
tinygrad 0.12.0's `generate()` is greedy-only (its `forward` ends
|
||||
with `.argmax(-1)` and it takes no temperature / top-k / top-p
|
||||
knobs). We accept `temperature` in the signature for API
|
||||
compatibility but it is ignored.
|
||||
"""
|
||||
del temperature # tinygrad.apps.llm.Transformer.generate is greedy-only
|
||||
ids = self._encode_prompt(prompt)
|
||||
if not ids:
|
||||
return
|
||||
|
||||
count = 0
|
||||
for next_tok in self.llm_model.generate(list(ids)):
|
||||
if next_tok in self.llm_eos_ids:
|
||||
break
|
||||
yield next_tok, self._decode_tokens([next_tok])
|
||||
count += 1
|
||||
if count >= max_new_tokens:
|
||||
break
|
||||
|
||||
# --------------------- gRPC methods ---------------------------------
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
async def LoadModel(self, request, context):
|
||||
try:
|
||||
_select_tinygrad_device()
|
||||
self._reset_state()
|
||||
self.options = self._parse_options(list(request.Options))
|
||||
self.model_ref = request.ModelFile or request.Model
|
||||
self.model_type = self._detect_model_type(self.model_ref, self.options.get("model_type"))
|
||||
|
||||
if self.model_type in ("sd15", "sd", "stable-diffusion"):
|
||||
self._load_sd(self.model_ref)
|
||||
return backend_pb2.Result(
|
||||
success=True, message="tinygrad Stable Diffusion 1.x loaded",
|
||||
)
|
||||
|
||||
if self.model_type == "whisper":
|
||||
self._load_whisper(self.model_ref)
|
||||
return backend_pb2.Result(
|
||||
success=True, message="tinygrad Whisper loaded",
|
||||
)
|
||||
|
||||
if self.model_type != "llm":
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message=f"tinygrad: model_type={self.model_type} not yet implemented",
|
||||
)
|
||||
|
||||
model_path = _resolve_model_assets(self.model_ref)
|
||||
self._load_llm(model_path)
|
||||
|
||||
return backend_pb2.Result(
|
||||
success=True,
|
||||
message=f"tinygrad LLM loaded (arch={self.llm_config.get('architectures', ['?'])[0]}, "
|
||||
f"parser={self.tool_parser.name})",
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return backend_pb2.Result(success=False, message=f"LoadModel failed: {exc}")
|
||||
|
||||
async def Predict(self, request, context):
|
||||
if self.llm_model is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("LLM not loaded")
|
||||
return backend_pb2.Reply()
|
||||
|
||||
try:
|
||||
prompt = self._render_prompt(request)
|
||||
max_new = request.Tokens if request.Tokens > 0 else 256
|
||||
temperature = request.Temperature if request.Temperature > 0 else 0.7
|
||||
|
||||
t0 = time.monotonic()
|
||||
pieces: list[str] = []
|
||||
ntok = 0
|
||||
for _, text in self._generate_tokens(prompt, max_new, temperature):
|
||||
pieces.append(text)
|
||||
ntok += 1
|
||||
elapsed = time.monotonic() - t0
|
||||
|
||||
full = "".join(pieces)
|
||||
from tool_parsers.hermes import HermesToolParser
|
||||
if isinstance(self.tool_parser, HermesToolParser):
|
||||
result = self.tool_parser.parse_full(full)
|
||||
content, calls, reasoning = result.content, result.tool_calls, result.reasoning
|
||||
else:
|
||||
content, calls = self.tool_parser.parse(full)
|
||||
reasoning = ""
|
||||
|
||||
delta = backend_pb2.ChatDelta(
|
||||
content=content,
|
||||
reasoning_content=reasoning,
|
||||
tool_calls=[
|
||||
backend_pb2.ToolCallDelta(index=c.index, id=c.id, name=c.name, arguments=c.arguments)
|
||||
for c in calls
|
||||
],
|
||||
)
|
||||
return backend_pb2.Reply(
|
||||
message=content.encode("utf-8"),
|
||||
tokens=ntok,
|
||||
timing_token_generation=elapsed,
|
||||
chat_deltas=[delta],
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"Predict failed: {exc}")
|
||||
return backend_pb2.Reply()
|
||||
|
||||
async def PredictStream(self, request, context):
|
||||
if self.llm_model is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("LLM not loaded")
|
||||
return
|
||||
|
||||
try:
|
||||
prompt = self._render_prompt(request)
|
||||
max_new = request.Tokens if request.Tokens > 0 else 256
|
||||
temperature = request.Temperature if request.Temperature > 0 else 0.7
|
||||
|
||||
buffer = ""
|
||||
for _, text in self._generate_tokens(prompt, max_new, temperature):
|
||||
buffer += text
|
||||
yield backend_pb2.Reply(
|
||||
message=text.encode("utf-8"),
|
||||
chat_deltas=[backend_pb2.ChatDelta(content=text)],
|
||||
)
|
||||
|
||||
# Final emission carries the extracted tool calls (vLLM semantics).
|
||||
from tool_parsers.hermes import HermesToolParser
|
||||
if isinstance(self.tool_parser, HermesToolParser):
|
||||
result = self.tool_parser.parse_full(buffer)
|
||||
calls = result.tool_calls
|
||||
reasoning = result.reasoning
|
||||
else:
|
||||
_, calls = self.tool_parser.parse(buffer)
|
||||
reasoning = ""
|
||||
|
||||
if calls or reasoning:
|
||||
yield backend_pb2.Reply(
|
||||
chat_deltas=[backend_pb2.ChatDelta(
|
||||
reasoning_content=reasoning,
|
||||
tool_calls=[
|
||||
backend_pb2.ToolCallDelta(index=c.index, id=c.id, name=c.name, arguments=c.arguments)
|
||||
for c in calls
|
||||
],
|
||||
)],
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"PredictStream failed: {exc}")
|
||||
|
||||
async def Embedding(self, request, context):
|
||||
if self.llm_model is None or self.llm_tokenizer is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("No model loaded")
|
||||
return backend_pb2.EmbeddingResult()
|
||||
|
||||
try:
|
||||
text = request.Embeddings
|
||||
if not text:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("Embeddings field is empty")
|
||||
return backend_pb2.EmbeddingResult()
|
||||
|
||||
from tinygrad import Tensor, dtypes
|
||||
from vendor.appsllm_adapter import _embed_hidden
|
||||
|
||||
ids = self._encode_prompt(text)
|
||||
if not ids:
|
||||
return backend_pb2.EmbeddingResult(embeddings=[])
|
||||
|
||||
# Clamp to context window — truncate long inputs rather than blow up.
|
||||
ids = ids[: self.max_context]
|
||||
tokens = Tensor([ids])
|
||||
|
||||
hidden = _embed_hidden(self.llm_model, tokens) # (1, seqlen, dim)
|
||||
# Mean pool over sequence dim
|
||||
pooled = hidden.mean(axis=1).squeeze(0) # (dim,)
|
||||
# L2 normalize
|
||||
norm = pooled.square().sum().sqrt()
|
||||
normalized = (pooled / (norm + 1e-12))
|
||||
vec = normalized.cast(dtypes.float32).tolist()
|
||||
|
||||
return backend_pb2.EmbeddingResult(embeddings=[float(x) for x in vec])
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"Embedding failed: {exc}")
|
||||
return backend_pb2.EmbeddingResult()
|
||||
|
||||
async def GenerateImage(self, request, context):
|
||||
if self.sd_model is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("No Stable Diffusion model loaded")
|
||||
return backend_pb2.Result(success=False, message="not loaded")
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
from vendor.stable_diffusion import run_sd15
|
||||
|
||||
steps = request.step if request.step > 0 else 20
|
||||
guidance = 7.5
|
||||
seed = request.seed if request.seed != 0 else None
|
||||
img_tensor = run_sd15(
|
||||
model=self.sd_model,
|
||||
prompt=request.positive_prompt or "",
|
||||
negative_prompt=request.negative_prompt or "",
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
)
|
||||
arr = img_tensor.numpy()
|
||||
image = Image.fromarray(arr)
|
||||
dst = request.dst or "/tmp/tinygrad_image.png"
|
||||
image.save(dst)
|
||||
return backend_pb2.Result(success=True, message=dst)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return backend_pb2.Result(success=False, message=f"GenerateImage failed: {exc}")
|
||||
|
||||
def _transcribe(self, audio_path: str, language: Optional[str]) -> tuple[str, float]:
|
||||
from vendor.whisper import load_file_waveform, transcribe_waveform
|
||||
|
||||
waveform = load_file_waveform(audio_path)
|
||||
text = transcribe_waveform(
|
||||
self.whisper_model,
|
||||
self.whisper_tokenizer,
|
||||
[waveform],
|
||||
language=language or None,
|
||||
)
|
||||
duration = float(len(waveform)) / 16000.0
|
||||
return text, duration
|
||||
|
||||
async def AudioTranscription(self, request, context):
|
||||
if self.whisper_model is None or self.whisper_tokenizer is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("No Whisper model loaded")
|
||||
return backend_pb2.TranscriptResult()
|
||||
|
||||
try:
|
||||
if not request.dst:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("TranscriptRequest.dst (audio file path) is required")
|
||||
return backend_pb2.TranscriptResult()
|
||||
|
||||
text, duration = self._transcribe(request.dst, request.language)
|
||||
segments = [backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text)]
|
||||
return backend_pb2.TranscriptResult(
|
||||
text=text,
|
||||
language=request.language or "en",
|
||||
duration=duration,
|
||||
segments=segments,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"AudioTranscription failed: {exc}")
|
||||
return backend_pb2.TranscriptResult()
|
||||
|
||||
async def AudioTranscriptionStream(self, request, context):
|
||||
if self.whisper_model is None or self.whisper_tokenizer is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("No Whisper model loaded")
|
||||
return
|
||||
|
||||
try:
|
||||
if not request.dst:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("TranscriptRequest.dst (audio file path) is required")
|
||||
return
|
||||
|
||||
# The vendored tinygrad whisper loop is chunked at the file level
|
||||
# (one inference pass per 30s segment), not token-level. To still
|
||||
# produce a streaming response we run the full transcription and
|
||||
# emit it as a single delta + a final-result envelope so the client
|
||||
# gets both code paths exercised.
|
||||
text, duration = self._transcribe(request.dst, request.language)
|
||||
yield backend_pb2.TranscriptStreamResponse(delta=text)
|
||||
final = backend_pb2.TranscriptResult(
|
||||
text=text,
|
||||
language=request.language or "en",
|
||||
duration=duration,
|
||||
segments=[backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text)],
|
||||
)
|
||||
yield backend_pb2.TranscriptStreamResponse(final_result=final)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"AudioTranscriptionStream failed: {exc}")
|
||||
|
||||
async def Status(self, request, context):
|
||||
return backend_pb2.StatusResponse(state=backend_pb2.StatusResponse.READY)
|
||||
|
||||
async def Free(self, request, context):
|
||||
self._reset_state()
|
||||
return backend_pb2.Result(success=True, message="freed")
|
||||
|
||||
|
||||
async def serve(address):
|
||||
server = grpc.aio.server(
|
||||
migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5)))
|
||||
|
||||
await server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
await server.wait_for_termination()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the tinygrad gRPC backend.")
|
||||
parser.add_argument("--addr", default="localhost:50051", help="Bind address")
|
||||
args = parser.parse_args()
|
||||
asyncio.run(serve(args.addr))
|
||||
17
backend/python/tinygrad/install.sh
Executable file
17
backend/python/tinygrad/install.sh
Executable file
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
# tinygrad >= 0.12 requires Python >= 3.11 (pyproject: `requires-python = ">=3.11"`).
|
||||
# LocalAI's default portable python is 3.10, so we pin to 3.11.x here.
|
||||
PYTHON_VERSION="3.11"
|
||||
PYTHON_PATCH="14"
|
||||
PY_STANDALONE_TAG="20260203"
|
||||
|
||||
installRequirements
|
||||
103
backend/python/tinygrad/package.sh
Executable file
103
backend/python/tinygrad/package.sh
Executable file
@@ -0,0 +1,103 @@
|
||||
#!/bin/bash
|
||||
# Script to package runtime shared libraries for the tinygrad backend.
|
||||
#
|
||||
# The final Dockerfile.python stage is FROM scratch, so system libraries
|
||||
# must be explicitly copied into ${BACKEND}/lib so the backend can run on
|
||||
# any host without installing them. libbackend.sh automatically prepends
|
||||
# that directory to LD_LIBRARY_PATH at run time.
|
||||
#
|
||||
# tinygrad's CPU device (CLANG / LLVM renderer) JIT-compiles kernels at
|
||||
# runtime. The default `CLANG` path invokes the external `clang` binary via
|
||||
# subprocess, which does not exist in the scratch image. We force the
|
||||
# in-process LLVM path (`CPU_LLVM=1` in run.sh) which loads libLLVM.so.*
|
||||
# through ctypes and bundle the library + its runtime dependencies here.
|
||||
#
|
||||
# Also bundle libgomp (pulled by librosa / numpy via numba) and libsndfile
|
||||
# (required by soundfile -> librosa audio I/O for Whisper).
|
||||
|
||||
set -e
|
||||
|
||||
CURDIR=$(dirname "$(realpath "$0")")
|
||||
LIB_DIR="${CURDIR}/lib"
|
||||
mkdir -p "${LIB_DIR}"
|
||||
|
||||
SEARCH_DIRS=(
|
||||
/usr/lib/x86_64-linux-gnu
|
||||
/usr/lib/aarch64-linux-gnu
|
||||
/lib/x86_64-linux-gnu
|
||||
/lib/aarch64-linux-gnu
|
||||
/usr/lib
|
||||
/lib
|
||||
)
|
||||
|
||||
copy_with_symlinks() {
|
||||
local soname="$1"
|
||||
local hit=""
|
||||
for dir in "${SEARCH_DIRS[@]}"; do
|
||||
if [ -e "${dir}/${soname}" ]; then
|
||||
hit="${dir}/${soname}"
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ -z "${hit}" ]; then
|
||||
echo "warning: ${soname} not found in standard lib paths" >&2
|
||||
return 0
|
||||
fi
|
||||
local real
|
||||
real=$(readlink -f "${hit}")
|
||||
cp -v "${real}" "${LIB_DIR}/"
|
||||
local real_base
|
||||
real_base=$(basename "${real}")
|
||||
if [ "${real_base}" != "${soname}" ]; then
|
||||
ln -sf "${real_base}" "${LIB_DIR}/${soname}"
|
||||
fi
|
||||
}
|
||||
|
||||
# tinygrad searches for libLLVM under these sonames (see
|
||||
# tinygrad/runtime/autogen/llvm.py). Ubuntu 24.04's `llvm` metapackage
|
||||
# installs `libLLVM-18.so.1` into `/usr/lib/llvm-18/lib/`. Also scan the
|
||||
# standard lib directories in case a different distro layout puts it in
|
||||
# /usr/lib/x86_64-linux-gnu.
|
||||
llvm_so=""
|
||||
shopt -s nullglob
|
||||
LLVM_EXTRA_DIRS=(/usr/lib/llvm-*/lib /usr/lib/llvm-*)
|
||||
# First try the versioned symlink (libLLVM-18.so) since that's what
|
||||
# tinygrad's DLL loader matches against (see llvm.py DLL name list).
|
||||
for dir in "${SEARCH_DIRS[@]}" "${LLVM_EXTRA_DIRS[@]}"; do
|
||||
for candidate in "${dir}"/libLLVM-[0-9]*.so "${dir}"/libLLVM-[0-9]*.so.[0-9]*; do
|
||||
if [ -e "${candidate}" ]; then
|
||||
llvm_so="${candidate}"
|
||||
break 2
|
||||
fi
|
||||
done
|
||||
done
|
||||
# Fallback: any libLLVM.so file under /usr.
|
||||
if [ -z "${llvm_so}" ]; then
|
||||
llvm_so=$(find /usr -maxdepth 5 -name 'libLLVM*.so*' 2>/dev/null | head -1)
|
||||
fi
|
||||
shopt -u nullglob
|
||||
if [ -z "${llvm_so}" ]; then
|
||||
echo "ERROR: libLLVM not found — tinygrad CPU device needs it." >&2
|
||||
echo "Install the Ubuntu \`llvm\` package in the builder stage." >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "Found libLLVM at: ${llvm_so}"
|
||||
llvm_base=$(basename "${llvm_so}")
|
||||
real_llvm=$(readlink -f "${llvm_so}")
|
||||
cp -v "${real_llvm}" "${LIB_DIR}/"
|
||||
real_base=$(basename "${real_llvm}")
|
||||
if [ "${real_base}" != "${llvm_base}" ]; then
|
||||
ln -sf "${real_base}" "${LIB_DIR}/${llvm_base}"
|
||||
fi
|
||||
|
||||
# libLLVM has soft runtime deps on libedit / libtinfo; pick them up if
|
||||
# present. They're optional but loading without them can fail.
|
||||
copy_with_symlinks libedit.so.2
|
||||
copy_with_symlinks libtinfo.so.6
|
||||
|
||||
# Audio I/O for the Whisper path.
|
||||
copy_with_symlinks libsndfile.so.1
|
||||
copy_with_symlinks libgomp.so.1
|
||||
|
||||
echo "tinygrad packaging completed successfully"
|
||||
ls -liah "${LIB_DIR}/"
|
||||
11
backend/python/tinygrad/protogen.sh
Executable file
11
backend/python/tinygrad/protogen.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runProtogen
|
||||
1
backend/python/tinygrad/requirements-cpu.txt
Normal file
1
backend/python/tinygrad/requirements-cpu.txt
Normal file
@@ -0,0 +1 @@
|
||||
# tinygrad CPU backend uses CLANG device (no extra deps required).
|
||||
2
backend/python/tinygrad/requirements-cublas12.txt
Normal file
2
backend/python/tinygrad/requirements-cublas12.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# tinygrad drives CUDA through its own JIT (CUDA=1 env var).
|
||||
# Requires the CUDA 12 runtime from the base image; no extra Python deps.
|
||||
2
backend/python/tinygrad/requirements-cublas13.txt
Normal file
2
backend/python/tinygrad/requirements-cublas13.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# tinygrad drives CUDA through its own JIT (CUDA=1 env var).
|
||||
# Requires the CUDA 13 runtime from the base image; no extra Python deps.
|
||||
15
backend/python/tinygrad/requirements.txt
Normal file
15
backend/python/tinygrad/requirements.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
grpcio==1.80.0
|
||||
protobuf==6.33.5
|
||||
certifi
|
||||
setuptools
|
||||
numpy>=2.0.0
|
||||
tinygrad>=0.12.0
|
||||
tokenizers>=0.21.0
|
||||
huggingface_hub
|
||||
jinja2>=3.1.0
|
||||
tiktoken
|
||||
sentencepiece
|
||||
safetensors
|
||||
Pillow
|
||||
librosa
|
||||
soundfile
|
||||
55
backend/python/tinygrad/run.sh
Executable file
55
backend/python/tinygrad/run.sh
Executable file
@@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
# tinygrad binds its compute device at import time from a single env var
|
||||
# (CUDA / HIP / METAL / CLANG). We pick one here based on what driver
|
||||
# libraries the host has injected into the container — when a user runs
|
||||
# the image with `--gpus all` (or the equivalent rocm runtime), the
|
||||
# nvidia-container-toolkit / rocm runtime mounts the right libraries
|
||||
# under /usr/lib so we can detect them.
|
||||
#
|
||||
# tinygrad's CUDA path uses two compiler pairs: an NVRTC-backed one and
|
||||
# an in-process PTX renderer. We force the PTX renderer here
|
||||
# (`CUDA_PTX=1`) so the image is independent of the host CUDA toolkit
|
||||
# version — only libcuda.so.1 (the driver) is required.
|
||||
find_lib() {
|
||||
local soname="$1"
|
||||
for dir in /usr/lib/x86_64-linux-gnu /usr/lib64 /usr/lib /lib/x86_64-linux-gnu /lib64 /lib; do
|
||||
if [ -e "${dir}/${soname}" ]; then
|
||||
echo "${dir}/${soname}"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
return 1
|
||||
}
|
||||
|
||||
if [ -z "${CUDA:-}${HIP:-}${METAL:-}${CLANG:-}" ]; then
|
||||
if find_lib libcuda.so.1 >/dev/null; then
|
||||
export CUDA=1
|
||||
export CUDA_PTX=1
|
||||
elif find_lib libamdhip64.so >/dev/null || find_lib libamdhip64.so.6 >/dev/null; then
|
||||
export HIP=1
|
||||
else
|
||||
export CLANG=1
|
||||
fi
|
||||
fi
|
||||
|
||||
# The CPU path (CLANG=1) JIT-compiles via libLLVM. Force tinygrad's
|
||||
# in-process LLVM compiler so we don't need an external `clang` binary
|
||||
# (which is not present in the scratch image).
|
||||
export CPU_LLVM=1
|
||||
if [ -z "${LLVM_PATH:-}" ]; then
|
||||
for candidate in "${EDIR}"/lib/libLLVM-*.so "${EDIR}"/lib/libLLVM-*.so.* "${EDIR}"/lib/libLLVM.so.*; do
|
||||
if [ -e "${candidate}" ]; then
|
||||
export LLVM_PATH="${candidate}"
|
||||
break
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
153
backend/python/tinygrad/test.py
Normal file
153
backend/python/tinygrad/test.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Unit tests for the tinygrad gRPC backend.
|
||||
|
||||
These tests cover the cheap paths that don't need a real model checkpoint:
|
||||
- Health responds OK
|
||||
- Tool-call parsers emit expected ToolCall structures
|
||||
|
||||
The full LLM / embeddings / Stable Diffusion / Whisper paths are exercised by
|
||||
the root-level `make test-extra-backend-tinygrad-all` e2e targets, which boot
|
||||
the containerized backend against real HF checkpoints.
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import grpc
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from tool_parsers.hermes import HermesToolParser # noqa: E402
|
||||
from vendor.appsllm_adapter import _hf_to_appsllm_state_dict # noqa: E402
|
||||
|
||||
|
||||
class TestHealth(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.service = subprocess.Popen(
|
||||
["python3", "backend.py", "--addr", "localhost:50051"]
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
def tearDown(self):
|
||||
self.service.kill()
|
||||
self.service.wait()
|
||||
|
||||
def test_health(self):
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.Health(backend_pb2.HealthMessage())
|
||||
self.assertEqual(response.message, b"OK")
|
||||
|
||||
|
||||
class TestHermesParser(unittest.TestCase):
|
||||
def test_single_tool_call(self):
|
||||
parser = HermesToolParser()
|
||||
text = (
|
||||
"Sure, let me check.\n"
|
||||
"<tool_call>\n"
|
||||
'{"name": "get_weather", "arguments": {"city": "Paris"}}\n'
|
||||
"</tool_call>\n"
|
||||
"Done."
|
||||
)
|
||||
content, calls = parser.parse(text)
|
||||
self.assertIn("Sure", content)
|
||||
self.assertIn("Done", content)
|
||||
self.assertEqual(len(calls), 1)
|
||||
self.assertEqual(calls[0].name, "get_weather")
|
||||
self.assertIn("Paris", calls[0].arguments)
|
||||
|
||||
def test_multi_call_and_thinking(self):
|
||||
parser = HermesToolParser()
|
||||
text = (
|
||||
"<think>I need both.</think>"
|
||||
'<tool_call>{"name":"a","arguments":{"x":1}}</tool_call>'
|
||||
'<tool_call>{"name":"b","arguments":{}}</tool_call>'
|
||||
)
|
||||
result = parser.parse_full(text)
|
||||
self.assertEqual(result.reasoning, "I need both.")
|
||||
self.assertEqual([c.name for c in result.tool_calls], ["a", "b"])
|
||||
self.assertEqual(result.tool_calls[0].index, 0)
|
||||
self.assertEqual(result.tool_calls[1].index, 1)
|
||||
|
||||
def test_no_tool_call_is_passthrough(self):
|
||||
parser = HermesToolParser()
|
||||
text = "plain assistant answer with no tool call"
|
||||
content, calls = parser.parse(text)
|
||||
self.assertEqual(content, text)
|
||||
self.assertEqual(calls, [])
|
||||
|
||||
|
||||
class TestAppsLLMAdapter(unittest.TestCase):
|
||||
"""Smoke tests for the HF → tinygrad.apps.llm state-dict keymap."""
|
||||
|
||||
def _fake_hf_weights(self, n_layers: int = 2, include_lm_head: bool = True):
|
||||
keys = [
|
||||
"model.embed_tokens.weight",
|
||||
"model.norm.weight",
|
||||
]
|
||||
if include_lm_head:
|
||||
keys.append("lm_head.weight")
|
||||
for l in range(n_layers):
|
||||
keys += [
|
||||
f"model.layers.{l}.input_layernorm.weight",
|
||||
f"model.layers.{l}.post_attention_layernorm.weight",
|
||||
f"model.layers.{l}.self_attn.q_proj.weight",
|
||||
f"model.layers.{l}.self_attn.k_proj.weight",
|
||||
f"model.layers.{l}.self_attn.v_proj.weight",
|
||||
f"model.layers.{l}.self_attn.o_proj.weight",
|
||||
f"model.layers.{l}.self_attn.q_norm.weight",
|
||||
f"model.layers.{l}.self_attn.k_norm.weight",
|
||||
f"model.layers.{l}.mlp.gate_proj.weight",
|
||||
f"model.layers.{l}.mlp.up_proj.weight",
|
||||
f"model.layers.{l}.mlp.down_proj.weight",
|
||||
]
|
||||
# sentinel objects so we can verify identity-based aliasing
|
||||
return {k: object() for k in keys}
|
||||
|
||||
def test_keymap_renames_every_hf_key(self):
|
||||
hf = self._fake_hf_weights(n_layers=2)
|
||||
sd = _hf_to_appsllm_state_dict(hf, 2)
|
||||
expected = {
|
||||
"token_embd.weight", "output_norm.weight", "output.weight",
|
||||
"blk.0.attn_norm.weight", "blk.0.ffn_norm.weight",
|
||||
"blk.0.attn_q.weight", "blk.0.attn_k.weight", "blk.0.attn_v.weight",
|
||||
"blk.0.attn_output.weight",
|
||||
"blk.0.attn_q_norm.weight", "blk.0.attn_k_norm.weight",
|
||||
"blk.0.ffn_gate.weight", "blk.0.ffn_up.weight", "blk.0.ffn_down.weight",
|
||||
"blk.1.attn_norm.weight", "blk.1.ffn_norm.weight",
|
||||
"blk.1.attn_q.weight", "blk.1.attn_k.weight", "blk.1.attn_v.weight",
|
||||
"blk.1.attn_output.weight",
|
||||
"blk.1.attn_q_norm.weight", "blk.1.attn_k_norm.weight",
|
||||
"blk.1.ffn_gate.weight", "blk.1.ffn_up.weight", "blk.1.ffn_down.weight",
|
||||
}
|
||||
self.assertEqual(set(sd.keys()), expected)
|
||||
|
||||
def test_tied_embedding_fallback_when_lm_head_missing(self):
|
||||
hf = self._fake_hf_weights(n_layers=1, include_lm_head=False)
|
||||
sd = _hf_to_appsllm_state_dict(hf, 1)
|
||||
self.assertIn("output.weight", sd)
|
||||
self.assertIs(sd["output.weight"], sd["token_embd.weight"])
|
||||
|
||||
def test_unknown_keys_are_skipped(self):
|
||||
hf = self._fake_hf_weights(n_layers=1)
|
||||
hf["model.layers.0.self_attn.rotary_emb.inv_freq"] = object()
|
||||
hf["model.some_unknown.weight"] = object()
|
||||
sd = _hf_to_appsllm_state_dict(hf, 1)
|
||||
self.assertNotIn("model.some_unknown.weight", sd)
|
||||
# Renamed keys still present
|
||||
self.assertIn("blk.0.attn_q.weight", sd)
|
||||
|
||||
def test_qkv_bias_models_rejected(self):
|
||||
hf = self._fake_hf_weights(n_layers=1)
|
||||
hf["model.layers.0.self_attn.q_proj.bias"] = object()
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
_hf_to_appsllm_state_dict(hf, 1)
|
||||
self.assertIn("Qwen3", str(ctx.exception))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
11
backend/python/tinygrad/test.sh
Executable file
11
backend/python/tinygrad/test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runUnittests
|
||||
11
backend/python/tinygrad/tool_parsers/__init__.py
Normal file
11
backend/python/tinygrad/tool_parsers/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Tool-call parsers for the tinygrad backend.
|
||||
|
||||
Each parser takes raw model output and extracts OpenAI-style tool calls so
|
||||
the backend can populate `ChatDelta.tool_calls[]` natively (matching vLLM's
|
||||
behavior, which the Go core prefers over regex fallback parsing).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import ToolCall, ToolParser, resolve_parser
|
||||
|
||||
__all__ = ["ToolCall", "ToolParser", "resolve_parser"]
|
||||
85
backend/python/tinygrad/tool_parsers/base.py
Normal file
85
backend/python/tinygrad/tool_parsers/base.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Common types + parser registry for tool-call extraction."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""One extracted tool call — maps 1:1 to backend_pb2.ToolCallDelta."""
|
||||
index: int
|
||||
name: str
|
||||
arguments: str # JSON string
|
||||
id: str = ""
|
||||
|
||||
|
||||
class ToolParser:
|
||||
"""Parser interface.
|
||||
|
||||
Subclasses implement `parse` (full non-streaming pass) and optionally
|
||||
`parse_stream` (incremental). The default `parse_stream` buffers until a
|
||||
full response is available and then delegates to `parse`.
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._stream_buffer = ""
|
||||
self._stream_index = 0
|
||||
|
||||
def parse(self, text: str) -> tuple[str, list[ToolCall]]:
|
||||
"""Return (content_for_user, tool_calls)."""
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_stream(self, delta: str, finished: bool = False) -> tuple[str, list[ToolCall]]:
|
||||
"""Accumulate a streaming delta. Emits any tool calls that have closed.
|
||||
|
||||
Default behavior: buffer until `finished=True`, then parse once.
|
||||
Subclasses can override to emit mid-stream.
|
||||
"""
|
||||
self._stream_buffer += delta
|
||||
if not finished:
|
||||
return "", []
|
||||
content, calls = self.parse(self._stream_buffer)
|
||||
# Re-index starting from whatever we've already emitted in this stream.
|
||||
reindexed: list[ToolCall] = []
|
||||
for i, c in enumerate(calls):
|
||||
reindexed.append(ToolCall(
|
||||
index=self._stream_index + i,
|
||||
name=c.name,
|
||||
arguments=c.arguments,
|
||||
id=c.id,
|
||||
))
|
||||
self._stream_index += len(reindexed)
|
||||
return content, reindexed
|
||||
|
||||
def reset(self) -> None:
|
||||
self._stream_buffer = ""
|
||||
self._stream_index = 0
|
||||
|
||||
|
||||
_REGISTRY: dict[str, type[ToolParser]] = {}
|
||||
|
||||
|
||||
def register(cls: type[ToolParser]) -> type[ToolParser]:
|
||||
_REGISTRY[cls.name] = cls
|
||||
return cls
|
||||
|
||||
|
||||
def resolve_parser(name: Optional[str]) -> ToolParser:
|
||||
"""Return a parser instance by name, falling back to a no-op passthrough."""
|
||||
# Import for side effects — each module registers itself.
|
||||
from . import hermes, llama3_json, mistral, qwen3_xml # noqa: F401
|
||||
|
||||
if name and name in _REGISTRY:
|
||||
return _REGISTRY[name]()
|
||||
return PassthroughToolParser()
|
||||
|
||||
|
||||
class PassthroughToolParser(ToolParser):
|
||||
"""No-op parser — used when no tool_parser is configured."""
|
||||
name = "passthrough"
|
||||
|
||||
def parse(self, text: str) -> tuple[str, list[ToolCall]]:
|
||||
return text, []
|
||||
74
backend/python/tinygrad/tool_parsers/hermes.py
Normal file
74
backend/python/tinygrad/tool_parsers/hermes.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Hermes-format tool-call parser.
|
||||
|
||||
Hermes 2 / 2.5 / 3 (and Qwen 2.5 Instruct, which adopted the same convention)
|
||||
emit tool calls wrapped in `<tool_call>...</tool_call>` tags, where the inner
|
||||
content is a JSON object with `name` and `arguments` keys:
|
||||
|
||||
<tool_call>
|
||||
{"name": "get_weather", "arguments": {"city": "Paris"}}
|
||||
</tool_call>
|
||||
|
||||
Multiple tool calls may appear back-to-back. Text outside the tags is plain
|
||||
assistant content that should surface to the user.
|
||||
|
||||
This parser also strips `<think>...</think>` reasoning blocks and returns them
|
||||
via the reasoning_content channel (Qwen 3, DeepSeek-R1 distills).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .base import ToolCall, ToolParser, register
|
||||
|
||||
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
|
||||
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HermesParseResult:
|
||||
content: str
|
||||
reasoning: str
|
||||
tool_calls: list[ToolCall]
|
||||
|
||||
|
||||
@register
|
||||
class HermesToolParser(ToolParser):
|
||||
name = "hermes"
|
||||
|
||||
def _parse_full(self, text: str) -> HermesParseResult:
|
||||
reasoning_parts: list[str] = []
|
||||
|
||||
def _capture_reasoning(match: re.Match[str]) -> str:
|
||||
reasoning_parts.append(match.group(1).strip())
|
||||
return ""
|
||||
|
||||
text_wo_think = _THINK_RE.sub(_capture_reasoning, text)
|
||||
|
||||
calls: list[ToolCall] = []
|
||||
for idx, match in enumerate(_TOOL_CALL_RE.finditer(text_wo_think)):
|
||||
raw = match.group(1)
|
||||
try:
|
||||
obj = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if not isinstance(obj, dict):
|
||||
continue
|
||||
name = obj.get("name")
|
||||
if not isinstance(name, str):
|
||||
continue
|
||||
args = obj.get("arguments", {})
|
||||
args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False)
|
||||
calls.append(ToolCall(index=idx, name=name, arguments=args_str))
|
||||
|
||||
content = _TOOL_CALL_RE.sub("", text_wo_think).strip()
|
||||
reasoning = "\n\n".join(reasoning_parts).strip()
|
||||
return HermesParseResult(content=content, reasoning=reasoning, tool_calls=calls)
|
||||
|
||||
def parse(self, text: str) -> tuple[str, list[ToolCall]]:
|
||||
result = self._parse_full(text)
|
||||
return result.content, result.tool_calls
|
||||
|
||||
def parse_full(self, text: str) -> HermesParseResult:
|
||||
return self._parse_full(text)
|
||||
86
backend/python/tinygrad/tool_parsers/llama3_json.py
Normal file
86
backend/python/tinygrad/tool_parsers/llama3_json.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Llama 3.1 / 3.2 / 3.3 JSON tool-call parser.
|
||||
|
||||
Meta's Llama 3.1+ instruct chat templates emit tool calls in two broadly
|
||||
compatible shapes:
|
||||
|
||||
1. With the `<|python_tag|>` lead-in:
|
||||
<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}}
|
||||
2. As a bare JSON object (or list of objects) at the end of the turn.
|
||||
|
||||
We also handle multi-call shapes where the model emits several JSON objects
|
||||
separated by `;` or newlines, and JSON arrays `[{...}, {...}]`. The key field
|
||||
for Llama 3 is historically `parameters` (older docs) but recent checkpoints
|
||||
also emit `arguments` — accept either.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .base import ToolCall, ToolParser, register
|
||||
|
||||
_PYTHON_TAG = "<|python_tag|>"
|
||||
_JSON_OBJECT_RE = re.compile(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", re.DOTALL)
|
||||
|
||||
|
||||
def _coerce_call(obj: object, index: int) -> ToolCall | None:
|
||||
if not isinstance(obj, dict):
|
||||
return None
|
||||
name = obj.get("name")
|
||||
if not isinstance(name, str):
|
||||
return None
|
||||
args = obj.get("arguments", obj.get("parameters", {}))
|
||||
args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False)
|
||||
return ToolCall(index=index, name=name, arguments=args_str)
|
||||
|
||||
|
||||
@register
|
||||
class Llama3JsonToolParser(ToolParser):
|
||||
name = "llama3_json"
|
||||
|
||||
def parse(self, text: str) -> tuple[str, list[ToolCall]]:
|
||||
calls: list[ToolCall] = []
|
||||
|
||||
# Strip <|python_tag|> segments first — each segment is one tool call
|
||||
# body. The content after the final python_tag (if any) is the call.
|
||||
remaining = text
|
||||
if _PYTHON_TAG in text:
|
||||
head, *tails = text.split(_PYTHON_TAG)
|
||||
remaining = head
|
||||
for tail in tails:
|
||||
parsed = _try_parse(tail.strip(), len(calls))
|
||||
calls.extend(parsed)
|
||||
|
||||
# Any JSON objects / arrays left in `remaining` count as tool calls too
|
||||
# if they parse to a {"name": ..., "arguments": ...} shape.
|
||||
for match in _JSON_OBJECT_RE.finditer(remaining):
|
||||
parsed = _try_parse(match.group(0), len(calls))
|
||||
if parsed:
|
||||
calls.extend(parsed)
|
||||
remaining = remaining.replace(match.group(0), "", 1)
|
||||
|
||||
content = remaining.strip()
|
||||
return content, calls
|
||||
|
||||
|
||||
def _try_parse(blob: str, start_index: int) -> list[ToolCall]:
|
||||
"""Parse a fragment that may be a JSON object or a JSON array of objects."""
|
||||
blob = blob.strip().rstrip(";")
|
||||
if not blob:
|
||||
return []
|
||||
try:
|
||||
obj = json.loads(blob)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
if isinstance(obj, dict):
|
||||
call = _coerce_call(obj, start_index)
|
||||
return [call] if call else []
|
||||
if isinstance(obj, list):
|
||||
calls: list[ToolCall] = []
|
||||
for i, item in enumerate(obj):
|
||||
c = _coerce_call(item, start_index + i)
|
||||
if c:
|
||||
calls.append(c)
|
||||
return calls
|
||||
return []
|
||||
56
backend/python/tinygrad/tool_parsers/mistral.py
Normal file
56
backend/python/tinygrad/tool_parsers/mistral.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Mistral / Mixtral tool-call parser.
|
||||
|
||||
Mistral Nemo / Small / Large Instruct emit tool calls prefixed with the
|
||||
`[TOOL_CALLS]` control token, followed by a JSON array:
|
||||
|
||||
[TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}}]
|
||||
|
||||
Multiple calls live inside the same array. Any text before `[TOOL_CALLS]` is
|
||||
normal assistant content and should surface to the user.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from .base import ToolCall, ToolParser, register
|
||||
|
||||
_MARKER = "[TOOL_CALLS]"
|
||||
_JSON_ARRAY_RE = re.compile(r"\[\s*(?:\{.*?\}\s*,?\s*)+\]", re.DOTALL)
|
||||
|
||||
|
||||
@register
|
||||
class MistralToolParser(ToolParser):
|
||||
name = "mistral"
|
||||
|
||||
def parse(self, text: str) -> tuple[str, list[ToolCall]]:
|
||||
if _MARKER not in text:
|
||||
return text.strip(), []
|
||||
|
||||
head, tail = text.split(_MARKER, 1)
|
||||
content = head.strip()
|
||||
|
||||
match = _JSON_ARRAY_RE.search(tail)
|
||||
if not match:
|
||||
return content, []
|
||||
|
||||
try:
|
||||
arr = json.loads(match.group(0))
|
||||
except json.JSONDecodeError:
|
||||
return content, []
|
||||
|
||||
if not isinstance(arr, list):
|
||||
return content, []
|
||||
|
||||
calls: list[ToolCall] = []
|
||||
for i, obj in enumerate(arr):
|
||||
if not isinstance(obj, dict):
|
||||
continue
|
||||
name = obj.get("name")
|
||||
if not isinstance(name, str):
|
||||
continue
|
||||
args = obj.get("arguments", {})
|
||||
args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False)
|
||||
calls.append(ToolCall(index=i, name=name, arguments=args_str))
|
||||
|
||||
return content, calls
|
||||
74
backend/python/tinygrad/tool_parsers/qwen3_xml.py
Normal file
74
backend/python/tinygrad/tool_parsers/qwen3_xml.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Qwen 3 XML tool-call parser.
|
||||
|
||||
Qwen 3 Instruct emits tool calls wrapped in a two-level tag structure:
|
||||
|
||||
<tool_call>
|
||||
<function=get_weather>
|
||||
<parameter=city>
|
||||
Paris
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
celsius
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
|
||||
Parameter values are raw text — we treat them as strings unless they look
|
||||
like JSON (in which case we try to parse so numbers / booleans round-trip
|
||||
cleanly). Qwen 3 also supports `<think>...</think>` reasoning blocks before
|
||||
the tool call — these are captured via the shared Hermes convention.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from .base import ToolCall, ToolParser, register
|
||||
|
||||
_TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
|
||||
_FUNCTION_RE = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
|
||||
_PARAMETER_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
|
||||
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
|
||||
|
||||
|
||||
def _maybe_json(value: str):
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return value
|
||||
if value[0] in "{[\"" or value in ("true", "false", "null") or value.lstrip("-").replace(".", "", 1).isdigit():
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
return value
|
||||
|
||||
|
||||
@register
|
||||
class Qwen3XmlToolParser(ToolParser):
|
||||
name = "qwen3_xml"
|
||||
|
||||
def parse(self, text: str) -> tuple[str, list[ToolCall]]:
|
||||
# Strip reasoning blocks from the user-visible content.
|
||||
stripped = _THINK_RE.sub("", text)
|
||||
|
||||
calls: list[ToolCall] = []
|
||||
for match in _TOOL_CALL_RE.finditer(stripped):
|
||||
body = match.group(1)
|
||||
fn_match = _FUNCTION_RE.search(body)
|
||||
if not fn_match:
|
||||
continue
|
||||
name = fn_match.group(1).strip()
|
||||
params_body = fn_match.group(2)
|
||||
|
||||
params: dict[str, object] = {}
|
||||
for pm in _PARAMETER_RE.finditer(params_body):
|
||||
params[pm.group(1).strip()] = _maybe_json(pm.group(2))
|
||||
|
||||
calls.append(ToolCall(
|
||||
index=len(calls),
|
||||
name=name,
|
||||
arguments=json.dumps(params, ensure_ascii=False),
|
||||
))
|
||||
|
||||
content = _TOOL_CALL_RE.sub("", stripped).strip()
|
||||
return content, calls
|
||||
6
backend/python/tinygrad/vendor/__init__.py
vendored
Normal file
6
backend/python/tinygrad/vendor/__init__.py
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Vendored upstream tinygrad reference code (MIT-licensed).
|
||||
|
||||
Source: https://github.com/tinygrad/tinygrad
|
||||
These files are not part of the `tinygrad` pip package (the `extra/` tree is
|
||||
excluded from `pyproject.toml` `packages`), so we carry a pinned copy here.
|
||||
"""
|
||||
102
backend/python/tinygrad/vendor/appsllm_adapter.py
vendored
Normal file
102
backend/python/tinygrad/vendor/appsllm_adapter.py
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Glue code between LocalAI's HF-shaped model assets and tinygrad.apps.llm.
|
||||
|
||||
apps.llm's `Transformer` uses GGUF-native weight names and consumes a
|
||||
`TransformerConfig` dataclass. LocalAI resolves models from HuggingFace
|
||||
snapshots (HF safetensors + config.json) so we translate both sides here.
|
||||
|
||||
This module does NOT subclass anything from apps.llm. With the Qwen3+
|
||||
scope the backend targets, we can use `apps.llm.Transformer` unchanged
|
||||
(no qkv_bias, no RoPE permute). Everything below is a thin adapter.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _hf_to_appsllm_state_dict(hf_weights: dict[str, Any], n_layers: int) -> dict[str, Any]:
|
||||
"""Rename a HuggingFace-style state dict to the GGUF-native keys that
|
||||
`tinygrad.apps.llm.Transformer` expects.
|
||||
|
||||
HF and apps.llm both store RoPE weights in half-split layout, so no
|
||||
permute is required — only a direct key rename and a tied-embedding
|
||||
fallback for models like Llama 3.2 that drop `lm_head.weight`.
|
||||
"""
|
||||
keymap: dict[str, str] = {
|
||||
"model.embed_tokens.weight": "token_embd.weight",
|
||||
"model.norm.weight": "output_norm.weight",
|
||||
"lm_head.weight": "output.weight",
|
||||
}
|
||||
for layer in range(n_layers):
|
||||
keymap[f"model.layers.{layer}.input_layernorm.weight"] = f"blk.{layer}.attn_norm.weight"
|
||||
keymap[f"model.layers.{layer}.post_attention_layernorm.weight"] = f"blk.{layer}.ffn_norm.weight"
|
||||
for hf_proj, gguf_proj in (("q", "q"), ("k", "k"), ("v", "v"), ("o", "output")):
|
||||
keymap[f"model.layers.{layer}.self_attn.{hf_proj}_proj.weight"] = f"blk.{layer}.attn_{gguf_proj}.weight"
|
||||
keymap[f"model.layers.{layer}.self_attn.q_norm.weight"] = f"blk.{layer}.attn_q_norm.weight"
|
||||
keymap[f"model.layers.{layer}.self_attn.k_norm.weight"] = f"blk.{layer}.attn_k_norm.weight"
|
||||
for hf_name, gguf_name in (("gate", "gate"), ("up", "up"), ("down", "down")):
|
||||
keymap[f"model.layers.{layer}.mlp.{hf_name}_proj.weight"] = f"blk.{layer}.ffn_{gguf_name}.weight"
|
||||
|
||||
# Fail loudly if the model carries Q/K/V projection bias (Qwen2 / 2.5).
|
||||
# apps.llm's `TransformerBlock` hardcodes `bias=False`, so these weights
|
||||
# would be silently dropped by `load_state_dict(strict=False)` and the
|
||||
# model would produce garbage. Supported families (Qwen3, Qwen3.5,
|
||||
# Llama 3.x, GLM-4, Mistral) have no qkv bias.
|
||||
bias_keys = [k for k in hf_weights
|
||||
if k.startswith("model.layers.") and
|
||||
any(k.endswith(f".self_attn.{p}_proj.bias") for p in ("q", "k", "v"))]
|
||||
if bias_keys:
|
||||
raise ValueError(
|
||||
"tinygrad backend: model has Q/K/V projection bias ("
|
||||
f"{bias_keys[0]} etc). Supported families are Qwen3, Qwen3.5, "
|
||||
"Llama 3.x, GLM-4, Mistral. For Qwen2 / 2.5 please use a "
|
||||
"newer model or the vLLM / llama.cpp backends."
|
||||
)
|
||||
|
||||
sd = {dst: hf_weights[src] for src, dst in keymap.items() if src in hf_weights}
|
||||
if "output.weight" not in sd and "token_embd.weight" in sd:
|
||||
sd["output.weight"] = sd["token_embd.weight"]
|
||||
return sd
|
||||
|
||||
|
||||
def _hf_to_transformer_kwargs(hf_config: dict, state_dict: dict[str, Any], max_context: int) -> dict:
|
||||
"""Build the kwargs dict for `tinygrad.apps.llm.Transformer(**kwargs)`.
|
||||
|
||||
Supports dense Qwen3 / Qwen3.5 / Llama 3.x / GLM-4 / Mistral-shaped
|
||||
models. The tinygrad 0.12.0 `Transformer` takes keyword-only args (no
|
||||
`TransformerConfig` dataclass) — so we return a plain dict.
|
||||
"""
|
||||
n_heads = hf_config["num_attention_heads"]
|
||||
head_dim = hf_config.get("head_dim") or (hf_config["hidden_size"] // n_heads)
|
||||
|
||||
# Detect qk_norm presence from the GGUF-shaped state dict (matches
|
||||
# apps.llm's own heuristic in `from_gguf`).
|
||||
qk_norm = 0
|
||||
qn = state_dict.get("blk.0.attn_q_norm.weight")
|
||||
if qn is not None:
|
||||
qk_norm = int(qn.shape[0])
|
||||
|
||||
max_pos = hf_config.get("max_position_embeddings", 4096)
|
||||
|
||||
return dict(
|
||||
num_blocks=hf_config["num_hidden_layers"],
|
||||
dim=hf_config["hidden_size"],
|
||||
hidden_dim=hf_config["intermediate_size"],
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=hf_config.get("num_key_value_heads", n_heads),
|
||||
norm_eps=hf_config.get("rms_norm_eps", 1e-5),
|
||||
vocab_size=hf_config["vocab_size"],
|
||||
head_dim=head_dim,
|
||||
rope_theta=float(hf_config.get("rope_theta", 10000.0)),
|
||||
max_context=min(max_pos, max_context),
|
||||
qk_norm=qk_norm,
|
||||
)
|
||||
|
||||
|
||||
def _embed_hidden(model, tokens):
|
||||
"""Return mean-poolable hidden states by running the block stack
|
||||
without going through the LM head + Gumbel-max sampler baked into
|
||||
`Transformer.forward`."""
|
||||
x = model.token_embd(tokens).float()
|
||||
for blk in model.blk:
|
||||
x = blk(x, 0)
|
||||
return model.output_norm(x)
|
||||
83
backend/python/tinygrad/vendor/audio_helpers.py
vendored
Normal file
83
backend/python/tinygrad/vendor/audio_helpers.py
vendored
Normal file
@@ -0,0 +1,83 @@
|
||||
# Vendored verbatim from tinygrad examples/audio_helpers.py (MIT license).
|
||||
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/audio_helpers.py
|
||||
# Copyright (c) 2023- the tinygrad authors
|
||||
# SPDX-License-Identifier: MIT
|
||||
from typing import Optional
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.dtype import DTypeLike, dtypes
|
||||
import math
|
||||
|
||||
# rewritten from numpy
|
||||
def rfftfreq(n: int, d: float = 1.0, device=None) -> Tensor:
|
||||
val = 1.0 / (n * d)
|
||||
N = n // 2 + 1
|
||||
results = Tensor.arange(N, device=device)
|
||||
return results * val
|
||||
|
||||
# just like in librosa
|
||||
def fft_frequencies(sr: float, n_fft: int) -> Tensor:
|
||||
return rfftfreq(n=n_fft, d=1.0 / sr)
|
||||
|
||||
def hz_to_mel(freq: Tensor) -> Tensor:
|
||||
# linear part
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
mels = (freq - f_min) / f_sp
|
||||
|
||||
# log-scale part
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
mask = freq >= min_log_hz
|
||||
return mask.where(((min_log_hz - f_min) / f_sp) + (freq / min_log_hz).log() / (math.log(6.4) / 27.0), mels)
|
||||
|
||||
def mel_to_hz(mels: Tensor) -> Tensor:
|
||||
# linear scale
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
freqs = f_min + f_sp * mels
|
||||
|
||||
# nonlinear scale
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||
logstep = math.log(6.4) / 27.0 # step size for log region
|
||||
|
||||
log_t = mels >= min_log_mel
|
||||
freqs = log_t.where(min_log_hz * ((logstep * (mels - min_log_mel)).exp()), freqs)
|
||||
return freqs
|
||||
|
||||
def mel_frequencies(n_mels: int = 128, *, fmin: float = 0.0, fmax: float = 11025.0) -> Tensor:
|
||||
# center freqs of mel bands - uniformly spaced between limits
|
||||
min_max_mel = hz_to_mel(Tensor([fmin, fmax]))
|
||||
|
||||
mels = Tensor.linspace(min_max_mel[0], min_max_mel[1], n_mels)
|
||||
hz = mel_to_hz(mels)
|
||||
return hz
|
||||
|
||||
def mel(
|
||||
*,
|
||||
sr: float,
|
||||
n_fft: int,
|
||||
n_mels: int = 128,
|
||||
fmin: float = 0.0,
|
||||
fmax: Optional[float] = None,
|
||||
dtype: DTypeLike = dtypes.default_float,
|
||||
) -> Tensor:
|
||||
if fmax is None:
|
||||
fmax = float(sr) / 2
|
||||
|
||||
n_mels = int(n_mels)
|
||||
|
||||
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft) # center freqs of each FFT bin
|
||||
mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax) # center freqs of mel bands
|
||||
|
||||
fdiff = mel_f[1:] - mel_f[:-1]
|
||||
ramps = mel_f[None].T.expand(-1, fftfreqs.shape[-1]) - fftfreqs
|
||||
|
||||
lower = -ramps[:n_mels] / fdiff[:n_mels][None].T
|
||||
upper = ramps[2 : n_mels + 2] / fdiff[1 : n_mels + 1][None].T
|
||||
weights = lower.minimum(upper).maximum(0)
|
||||
|
||||
# Slaney-style mel is scaled to be approx constant energy per channel
|
||||
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
|
||||
weights *= enorm[:, None]
|
||||
|
||||
return weights
|
||||
484
backend/python/tinygrad/vendor/clip.py
vendored
Normal file
484
backend/python/tinygrad/vendor/clip.py
vendored
Normal file
@@ -0,0 +1,484 @@
|
||||
# Vendored verbatim from tinygrad extra/models/clip.py (MIT license).
|
||||
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/clip.py
|
||||
# Copyright (c) 2023- the tinygrad authors
|
||||
# SPDX-License-Identifier: MIT
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.helpers import fetch
|
||||
from tinygrad.nn import Linear, LayerNorm, Embedding, Conv2d
|
||||
|
||||
from typing import List, Optional, Union, Tuple, Dict
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import lru_cache
|
||||
import numpy as np
|
||||
import re, gzip
|
||||
|
||||
# Allow for monkeypatching for mlperf.
|
||||
gelu = Tensor.gelu
|
||||
|
||||
@lru_cache()
|
||||
def default_bpe():
|
||||
# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
|
||||
return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Namespace for CLIP Text Tokenizer components.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_pairs(word):
|
||||
"""
|
||||
Return set of symbol pairs in a word.
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
return set(zip(word, word[1:]))
|
||||
@staticmethod
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
return text
|
||||
@staticmethod
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a significant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8+n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
class ClipTokenizer:
|
||||
def __init__(self, version=None):
|
||||
self.byte_encoder, self.version = Tokenizer.bytes_to_unicode(), version
|
||||
merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n')
|
||||
merges = merges[1:49152-256-2+1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
vocab = list(Tokenizer.bytes_to_unicode().values())
|
||||
vocab = vocab + [v+'</w>' for v in vocab]
|
||||
for merge in merges:
|
||||
vocab.append(''.join(merge))
|
||||
if self.version == "sd_mlperf_v5_0":
|
||||
import regex
|
||||
vocab.extend(['<start_of_text>', '<end_of_text>'])
|
||||
self.cache = {'<start_of_text>': '<start_of_text>', '<end_of_text>': '<end_of_text>'}
|
||||
self.pat = regex.compile(r"""<start_of_text>|<end_of_text>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
|
||||
else:
|
||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
||||
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
|
||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
||||
pairs = Tokenizer.get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token+'</w>'
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except Exception:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
||||
new_word.append(first+second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
pairs = Tokenizer.get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text:str, pad_with_zeros:bool=False) -> List[int]:
|
||||
bpe_tokens: List[int] = []
|
||||
if self.version == "sd_mlperf_v5_0":
|
||||
import regex, ftfy, html
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text)).strip()
|
||||
text = Tokenizer.whitespace_clean(text).lower()
|
||||
re_module = regex
|
||||
else:
|
||||
text = Tokenizer.whitespace_clean(text.strip()).lower()
|
||||
re_module = re
|
||||
|
||||
for token in re_module.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||
# Truncation, keeping two slots for start and end tokens.
|
||||
if len(bpe_tokens) > 75:
|
||||
bpe_tokens = bpe_tokens[:75]
|
||||
return [49406] + bpe_tokens + [49407] + ([0] if pad_with_zeros else [49407]) * (77 - len(bpe_tokens) - 2)
|
||||
|
||||
|
||||
class Embedder(ABC):
|
||||
input_key: str
|
||||
@abstractmethod
|
||||
def __call__(self, x:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
|
||||
pass
|
||||
|
||||
|
||||
class Closed:
|
||||
"""
|
||||
Namespace for OpenAI CLIP model components.
|
||||
"""
|
||||
class ClipMlp:
|
||||
def __init__(self):
|
||||
self.fc1 = Linear(768, 3072)
|
||||
self.fc2 = Linear(3072, 768)
|
||||
|
||||
def __call__(self, h:Tensor) -> Tensor:
|
||||
h = self.fc1(h)
|
||||
h = h.quick_gelu()
|
||||
h = self.fc2(h)
|
||||
return h
|
||||
|
||||
class ClipAttention:
|
||||
def __init__(self):
|
||||
self.embed_dim = 768
|
||||
self.num_heads = 12
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
self.k_proj = Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
|
||||
bsz, tgt_len, embed_dim = hidden_states.shape
|
||||
q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
|
||||
q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
|
||||
attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
|
||||
return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))
|
||||
|
||||
class ClipEncoderLayer:
|
||||
def __init__(self):
|
||||
self.self_attn = Closed.ClipAttention()
|
||||
self.layer_norm1 = LayerNorm(768)
|
||||
self.mlp = Closed.ClipMlp()
|
||||
self.layer_norm2 = LayerNorm(768)
|
||||
|
||||
def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states, causal_attention_mask)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
class ClipTextEmbeddings:
|
||||
def __init__(self):
|
||||
self.token_embedding = Embedding(49408, 768)
|
||||
self.position_embedding = Embedding(77, 768)
|
||||
|
||||
def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor:
|
||||
return self.token_embedding(input_ids) + self.position_embedding(position_ids)
|
||||
|
||||
class ClipEncoder:
|
||||
def __init__(self, layer_count:int=12):
|
||||
self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)]
|
||||
|
||||
def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor:
|
||||
# the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state
|
||||
layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx]
|
||||
for l in layers:
|
||||
x = l(x, causal_attention_mask)
|
||||
return x
|
||||
|
||||
class ClipTextTransformer:
|
||||
def __init__(self, ret_layer_idx:Optional[int]=None):
|
||||
self.embeddings = Closed.ClipTextEmbeddings()
|
||||
self.encoder = Closed.ClipEncoder()
|
||||
self.final_layer_norm = LayerNorm(768)
|
||||
self.ret_layer_idx = ret_layer_idx
|
||||
|
||||
def __call__(self, input_ids:Tensor) -> Tensor:
|
||||
x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1))
|
||||
x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1), self.ret_layer_idx)
|
||||
return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x
|
||||
|
||||
class ClipTextModel:
|
||||
def __init__(self, ret_layer_idx:Optional[int]):
|
||||
self.text_model = Closed.ClipTextTransformer(ret_layer_idx=ret_layer_idx)
|
||||
|
||||
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331
|
||||
class FrozenClosedClipEmbedder(Embedder):
|
||||
def __init__(self, ret_layer_idx:Optional[int]=None):
|
||||
self.tokenizer = Tokenizer.ClipTokenizer()
|
||||
self.transformer = Closed.ClipTextModel(ret_layer_idx)
|
||||
self.input_key = "txt"
|
||||
|
||||
def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
|
||||
if isinstance(texts, str): texts = [texts]
|
||||
assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
|
||||
tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0)
|
||||
return self.transformer.text_model(tokens.reshape(len(texts),-1))
|
||||
|
||||
|
||||
class Open:
|
||||
"""
|
||||
Namespace for OpenCLIP model components.
|
||||
"""
|
||||
class MultiheadAttention:
|
||||
def __init__(self, dims:int, n_heads:int):
|
||||
self.dims = dims
|
||||
self.n_heads = n_heads
|
||||
self.d_head = self.dims // self.n_heads
|
||||
|
||||
self.in_proj_bias = Tensor.empty(3*dims)
|
||||
self.in_proj_weight = Tensor.empty(3*dims, dims)
|
||||
self.out_proj = Linear(dims, dims)
|
||||
|
||||
def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
|
||||
T,B,C = x.shape
|
||||
|
||||
proj = x.linear(self.in_proj_weight.T, self.in_proj_bias)
|
||||
proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0, -2)
|
||||
|
||||
q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in proj.chunk(3)]
|
||||
|
||||
attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
attn_output = attn_output.permute(2, 0, 1, 3).reshape(T, B, C)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
class Mlp:
|
||||
def __init__(self, dims, hidden_dims):
|
||||
self.c_fc = Linear(dims, hidden_dims)
|
||||
self.c_proj = Linear(hidden_dims, dims)
|
||||
self.gelu = gelu
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return x.sequential([self.c_fc, self.gelu, self.c_proj])
|
||||
|
||||
# https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210
|
||||
class ResidualAttentionBlock:
|
||||
def __init__(self, dims:int, n_heads:int, mlp_ratio:float):
|
||||
self.ln_1 = LayerNorm(dims)
|
||||
self.attn = Open.MultiheadAttention(dims, n_heads)
|
||||
|
||||
self.ln_2 = LayerNorm(dims)
|
||||
self.mlp = Open.Mlp(dims, int(dims * mlp_ratio))
|
||||
|
||||
def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None, transpose:bool=False) -> Tensor:
|
||||
q_x = self.ln_1(x)
|
||||
attn_out = self.attn(q_x.transpose(0, 1) if transpose else q_x, attn_mask=attn_mask)
|
||||
attn_out = attn_out.transpose(0, 1) if transpose else attn_out
|
||||
x = x + attn_out
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
# https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317
|
||||
class ClipTransformer:
|
||||
def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0):
|
||||
self.resblocks = [
|
||||
Open.ResidualAttentionBlock(dims, n_heads, mlp_ratio) for _ in range(layers)
|
||||
]
|
||||
|
||||
def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
|
||||
for r in self.resblocks:
|
||||
x = r(x, attn_mask=attn_mask, transpose=True)
|
||||
return x
|
||||
|
||||
# https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220
|
||||
# https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661
|
||||
class ClipTextTransformer:
|
||||
def __init__(self, width:int, n_heads:int, layers:int, vocab_size:int=49408, ctx_length:int=77):
|
||||
self.token_embedding = Embedding(vocab_size, width)
|
||||
self.positional_embedding = Tensor.empty(ctx_length, width)
|
||||
self.transformer = Open.ClipTransformer(width, layers, n_heads)
|
||||
self.ln_final = LayerNorm(width)
|
||||
self.text_projection = Tensor.empty(width, width)
|
||||
self.attn_mask = Tensor.full((77, 77), float("-inf")).triu(1).realize()
|
||||
|
||||
def __call__(self, text:Tensor) -> Tensor:
|
||||
seq_len = text.shape[1]
|
||||
|
||||
x = self.token_embedding(text)
|
||||
x = x + self.positional_embedding[:seq_len]
|
||||
x = self.transformer(x, attn_mask=self.attn_mask)
|
||||
x = self.ln_final(x)
|
||||
|
||||
pooled = x[:, text.argmax(dim=-1)] @ self.text_projection
|
||||
return pooled
|
||||
|
||||
class ClipVisionTransformer:
|
||||
def __init__(self, width:int, layers:int, d_head:int, image_size:int, patch_size:int):
|
||||
grid_size = image_size // patch_size
|
||||
n_heads = width // d_head
|
||||
assert n_heads * d_head == width
|
||||
|
||||
self.conv1 = Conv2d(3, width, kernel_size=patch_size, stride=patch_size, bias=False)
|
||||
|
||||
self.class_embedding = Tensor.empty(width)
|
||||
self.positional_embedding = Tensor.empty(grid_size * grid_size + 1, width)
|
||||
self.transformer = Open.ClipTransformer(width, layers, n_heads)
|
||||
self.ln_pre = LayerNorm(width)
|
||||
self.ln_post = LayerNorm(width)
|
||||
self.proj = Tensor.empty(width, 1024)
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
x = self.conv1(x)
|
||||
x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)
|
||||
x = self.class_embedding.reshape(1, 1, -1).expand(x.shape[0], 1, -1).cat(x, dim=1)
|
||||
x = x + self.positional_embedding
|
||||
|
||||
x = self.ln_pre(x)
|
||||
x = self.transformer(x)
|
||||
x = self.ln_post(x)
|
||||
|
||||
pooled = x[:, 0] @ self.proj
|
||||
return pooled
|
||||
|
||||
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
|
||||
class FrozenOpenClipEmbedder(Embedder):
|
||||
def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False, clip_tokenizer_version=None):
|
||||
self.tokenizer = Tokenizer.ClipTokenizer(version=clip_tokenizer_version)
|
||||
self.model = Open.ClipTextTransformer(dims, n_heads, layers)
|
||||
self.return_pooled = return_pooled
|
||||
self.input_key = "txt"
|
||||
self.ln_penultimate = ln_penultimate
|
||||
|
||||
def tokenize(self, text:str, device:Optional[str]=None) -> Tensor:
|
||||
return Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int32, device=device).reshape(1,-1)
|
||||
|
||||
def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
|
||||
for r in self.model.transformer.resblocks:
|
||||
x, penultimate = r(x, attn_mask=attn_mask), x
|
||||
return x.permute(1, 0, 2), penultimate.permute(1, 0, 2)
|
||||
|
||||
def embed_tokens(self, tokens:Tensor) -> Union[Tensor,Tuple[Tensor,...]]:
|
||||
x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2)
|
||||
x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
|
||||
if self.ln_penultimate:
|
||||
penultimate = self.model.ln_final(penultimate)
|
||||
|
||||
if self.return_pooled:
|
||||
x = self.model.ln_final(x)
|
||||
index = tokens.argmax(axis=-1).reshape(-1,1,1).expand(x.shape[0],1,x.shape[-1])
|
||||
pooled = x.gather(1, index).squeeze(1) @ self.model.text_projection
|
||||
return penultimate, pooled
|
||||
else:
|
||||
return penultimate
|
||||
|
||||
def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
|
||||
if isinstance(texts, str): texts = [texts]
|
||||
assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
|
||||
tokens = Tensor.cat(*[self.tokenize(text) for text in texts], dim=0)
|
||||
return self.embed_tokens(tokens)
|
||||
|
||||
|
||||
clip_configs: Dict = {
|
||||
"ViT-H-14": {
|
||||
"dims": 1024,
|
||||
"vision_cfg": {
|
||||
"width": 1280,
|
||||
"layers": 32,
|
||||
"d_head": 80,
|
||||
"image_size": 224,
|
||||
"patch_size": 14,
|
||||
},
|
||||
"text_cfg": {
|
||||
"width": 1024,
|
||||
"n_heads": 16,
|
||||
"layers": 24,
|
||||
"ctx_length": 77,
|
||||
"vocab_size": 49408,
|
||||
},
|
||||
"return_pooled": False,
|
||||
"ln_penultimate": True,
|
||||
}
|
||||
}
|
||||
|
||||
class OpenClipEncoder:
|
||||
def __init__(self, dims:int, text_cfg:Dict, vision_cfg:Dict, **_):
|
||||
self.visual = Open.ClipVisionTransformer(**vision_cfg)
|
||||
|
||||
text = Open.ClipTextTransformer(**text_cfg)
|
||||
self.transformer = text.transformer
|
||||
self.token_embedding = text.token_embedding
|
||||
self.positional_embedding = text.positional_embedding
|
||||
self.ln_final = text.ln_final
|
||||
self.text_projection = text.text_projection
|
||||
|
||||
self.attn_mask = Tensor.full((77, 77), float("-inf")).triu(1).realize()
|
||||
self.mean = Tensor([0.48145466, 0.45782750, 0.40821073]).reshape(-1, 1, 1)
|
||||
self.std = Tensor([0.26862954, 0.26130258, 0.27577711]).reshape(-1, 1, 1)
|
||||
|
||||
# TODO:
|
||||
# Should be doable in pure tinygrad, would just require some work and verification.
|
||||
# This is very desirable since it would allow for full generation->evaluation in a single JIT call.
|
||||
def prepare_image(self, image) -> Tensor:
|
||||
from PIL import Image
|
||||
SIZE = 224
|
||||
w, h = image.size
|
||||
scale = min(SIZE / h, SIZE / w)
|
||||
image = image.resize((max(int(w*scale),SIZE),max(int(h*scale),SIZE)), Image.Resampling.BICUBIC)
|
||||
w, h = image.size
|
||||
if w > SIZE:
|
||||
left = (w - SIZE) // 2
|
||||
image = image.crop((left, left+SIZE, 0, SIZE))
|
||||
elif h > SIZE:
|
||||
top = (h - SIZE) // 2
|
||||
image = image.crop((0, SIZE, top, top+SIZE))
|
||||
|
||||
x = Tensor(np.array(image.convert('RGB')), device=self.std.device)
|
||||
x = x.permute(2, 0, 1).cast(dtypes.float32) / 255.0
|
||||
return (x - self.mean) / self.std
|
||||
|
||||
def encode_tokens(self, tokens:Tensor) -> Tensor:
|
||||
x = self.token_embedding(tokens)
|
||||
x = x + self.positional_embedding
|
||||
x = self.transformer(x, attn_mask=self.attn_mask)
|
||||
x = self.ln_final(x)
|
||||
x = x[Tensor.arange(x.shape[0], device=x.device), tokens.argmax(axis=-1)]
|
||||
x = x @ self.text_projection
|
||||
return x
|
||||
|
||||
def get_clip_score(self, tokens:Tensor, image:Tensor) -> Tensor:
|
||||
image_features: Tensor = self.visual(image)
|
||||
image_features /= image_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm
|
||||
|
||||
text_features = self.encode_tokens(tokens)
|
||||
text_features /= text_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm
|
||||
|
||||
return (image_features * text_features).sum(axis=-1)
|
||||
232
backend/python/tinygrad/vendor/stable_diffusion.py
vendored
Normal file
232
backend/python/tinygrad/vendor/stable_diffusion.py
vendored
Normal file
@@ -0,0 +1,232 @@
|
||||
# Adapted from tinygrad examples/stable_diffusion.py (MIT license).
|
||||
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/stable_diffusion.py
|
||||
# Copyright (c) 2023- the tinygrad authors
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
# Local modifications: removed the MLPerf training branch (pulls
|
||||
# examples/mlperf/initializers which we don't vendor) and the __main__
|
||||
# argparse / fetch / profile blocks. Kept the core classes so the LocalAI
|
||||
# tinygrad backend can instantiate and drive Stable Diffusion v1.x from a
|
||||
# single checkpoint path.
|
||||
from collections import namedtuple
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.nn import Conv2d, GroupNorm
|
||||
|
||||
from . import clip as clip_mod
|
||||
from . import unet as unet_mod
|
||||
from .clip import Closed, Tokenizer
|
||||
from .unet import UNetModel
|
||||
|
||||
|
||||
class AttnBlock:
|
||||
def __init__(self, in_channels):
|
||||
self.norm = GroupNorm(32, in_channels)
|
||||
self.q = Conv2d(in_channels, in_channels, 1)
|
||||
self.k = Conv2d(in_channels, in_channels, 1)
|
||||
self.v = Conv2d(in_channels, in_channels, 1)
|
||||
self.proj_out = Conv2d(in_channels, in_channels, 1)
|
||||
|
||||
def __call__(self, x):
|
||||
h_ = self.norm(x)
|
||||
q, k, v = self.q(h_), self.k(h_), self.v(h_)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = [t.reshape(b, c, h * w).transpose(1, 2) for t in (q, k, v)]
|
||||
h_ = Tensor.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(b, c, h, w)
|
||||
return x + self.proj_out(h_)
|
||||
|
||||
|
||||
class ResnetBlock:
|
||||
def __init__(self, in_channels, out_channels=None):
|
||||
self.norm1 = GroupNorm(32, in_channels)
|
||||
self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
|
||||
self.norm2 = GroupNorm(32, out_channels)
|
||||
self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
|
||||
self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else (lambda x: x)
|
||||
|
||||
def __call__(self, x):
|
||||
h = self.conv1(self.norm1(x).swish())
|
||||
h = self.conv2(self.norm2(h).swish())
|
||||
return self.nin_shortcut(x) + h
|
||||
|
||||
|
||||
class Mid:
|
||||
def __init__(self, block_in):
|
||||
self.block_1 = ResnetBlock(block_in, block_in)
|
||||
self.attn_1 = AttnBlock(block_in)
|
||||
self.block_2 = ResnetBlock(block_in, block_in)
|
||||
|
||||
def __call__(self, x):
|
||||
return x.sequential([self.block_1, self.attn_1, self.block_2])
|
||||
|
||||
|
||||
class Decoder:
|
||||
def __init__(self):
|
||||
sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
|
||||
self.conv_in = Conv2d(4, 512, 3, padding=1)
|
||||
self.mid = Mid(512)
|
||||
|
||||
arr = []
|
||||
for i, s in enumerate(sz):
|
||||
arr.append({"block": [ResnetBlock(s[1], s[0]), ResnetBlock(s[0], s[0]), ResnetBlock(s[0], s[0])]})
|
||||
if i != 0:
|
||||
arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
|
||||
self.up = arr
|
||||
|
||||
self.norm_out = GroupNorm(32, 128)
|
||||
self.conv_out = Conv2d(128, 3, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv_in(x)
|
||||
x = self.mid(x)
|
||||
for l in self.up[::-1]:
|
||||
for b in l['block']:
|
||||
x = b(x)
|
||||
if 'upsample' in l:
|
||||
bs, c, py, px = x.shape
|
||||
x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py * 2, px * 2)
|
||||
x = l['upsample']['conv'](x)
|
||||
x.realize()
|
||||
return self.conv_out(self.norm_out(x).swish())
|
||||
|
||||
|
||||
class Encoder:
|
||||
def __init__(self):
|
||||
sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
|
||||
self.conv_in = Conv2d(3, 128, 3, padding=1)
|
||||
|
||||
arr = []
|
||||
for i, s in enumerate(sz):
|
||||
arr.append({"block": [ResnetBlock(s[0], s[1]), ResnetBlock(s[1], s[1])]})
|
||||
if i != 3:
|
||||
arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0, 1, 0, 1))}
|
||||
self.down = arr
|
||||
|
||||
self.mid = Mid(512)
|
||||
self.norm_out = GroupNorm(32, 512)
|
||||
self.conv_out = Conv2d(512, 8, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv_in(x)
|
||||
for l in self.down:
|
||||
for b in l['block']:
|
||||
x = b(x)
|
||||
if 'downsample' in l:
|
||||
x = l['downsample']['conv'](x)
|
||||
x = self.mid(x)
|
||||
return self.conv_out(self.norm_out(x).swish())
|
||||
|
||||
|
||||
class AutoencoderKL:
|
||||
def __init__(self):
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder()
|
||||
self.quant_conv = Conv2d(8, 8, 1)
|
||||
self.post_quant_conv = Conv2d(4, 4, 1)
|
||||
|
||||
def __call__(self, x):
|
||||
latent = self.encoder(x)
|
||||
latent = self.quant_conv(latent)
|
||||
latent = latent[:, 0:4]
|
||||
latent = self.post_quant_conv(latent)
|
||||
return self.decoder(latent)
|
||||
|
||||
|
||||
def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
|
||||
betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_training_steps, dtype=np.float32) ** 2
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
return Tensor(alphas_cumprod)
|
||||
|
||||
|
||||
# SD1.x UNet hyperparameters (same as upstream `unet_params`).
|
||||
UNET_PARAMS_SD1: Dict[str, Any] = {
|
||||
"adm_in_ch": None,
|
||||
"in_ch": 4,
|
||||
"out_ch": 4,
|
||||
"model_ch": 320,
|
||||
"attention_resolutions": [4, 2, 1],
|
||||
"num_res_blocks": 2,
|
||||
"channel_mult": [1, 2, 4, 4],
|
||||
"n_heads": 8,
|
||||
"transformer_depth": [1, 1, 1, 1],
|
||||
"ctx_dim": 768,
|
||||
"use_linear": False,
|
||||
}
|
||||
|
||||
|
||||
class StableDiffusion:
|
||||
"""Stable Diffusion 1.x pipeline, adapted from tinygrad's reference example.
|
||||
|
||||
Drives the native CompVis `sd-v1-*.ckpt` checkpoint format (the only one
|
||||
the vendored weight layout handles). For HuggingFace safetensors pipelines
|
||||
the caller is expected to download / merge the `.ckpt` equivalent before
|
||||
calling LoadModel.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.alphas_cumprod = get_alphas_cumprod()
|
||||
self.first_stage_model = AutoencoderKL()
|
||||
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(
|
||||
transformer=namedtuple("Transformer", ["text_model"])(text_model=Closed.ClipTextTransformer())
|
||||
)
|
||||
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(
|
||||
diffusion_model=UNetModel(**UNET_PARAMS_SD1)
|
||||
)
|
||||
|
||||
# DDIM update step.
|
||||
def _update(self, x, e_t, a_t, a_prev):
|
||||
sqrt_one_minus_at = (1 - a_t).sqrt()
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
dir_xt = (1.0 - a_prev).sqrt() * e_t
|
||||
return a_prev.sqrt() * pred_x0 + dir_xt
|
||||
|
||||
def _model_output(self, uncond, cond, latent, timestep, guidance):
|
||||
latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, uncond.cat(cond, dim=0))
|
||||
uncond_latent, cond_latent = latents[0:1], latents[1:2]
|
||||
return uncond_latent + guidance * (cond_latent - uncond_latent)
|
||||
|
||||
def step(self, uncond, cond, latent, timestep, a_t, a_prev, guidance):
|
||||
e_t = self._model_output(uncond, cond, latent, timestep, guidance)
|
||||
return self._update(latent, e_t, a_t, a_prev).realize()
|
||||
|
||||
def decode(self, x):
|
||||
x = self.first_stage_model.post_quant_conv(1 / 0.18215 * x)
|
||||
x = self.first_stage_model.decoder(x)
|
||||
x = (x + 1.0) / 2.0
|
||||
x = x.reshape(3, 512, 512).permute(1, 2, 0).clip(0, 1) * 255
|
||||
return x.cast(dtypes.uint8)
|
||||
|
||||
def encode_prompt(self, tokenizer, prompt: str):
|
||||
ids = Tensor([tokenizer.encode(prompt)])
|
||||
return self.cond_stage_model.transformer.text_model(ids).realize()
|
||||
|
||||
|
||||
def run_sd15(model: StableDiffusion, prompt: str, negative_prompt: str, steps: int, guidance: float, seed: int):
|
||||
"""Generate a single 512x512 image. Returns a (512,512,3) uint8 tensor."""
|
||||
tokenizer = Tokenizer.ClipTokenizer()
|
||||
|
||||
context = model.encode_prompt(tokenizer, prompt)
|
||||
uncond = model.encode_prompt(tokenizer, negative_prompt)
|
||||
|
||||
timesteps = list(range(1, 1000, 1000 // steps))
|
||||
alphas = model.alphas_cumprod[Tensor(timesteps)]
|
||||
alphas_prev = Tensor([1.0]).cat(alphas[:-1])
|
||||
|
||||
if seed is not None:
|
||||
Tensor.manual_seed(seed)
|
||||
latent = Tensor.randn(1, 4, 64, 64)
|
||||
|
||||
for index in range(len(timesteps) - 1, -1, -1):
|
||||
timestep = timesteps[index]
|
||||
tid = Tensor([index])
|
||||
latent = model.step(
|
||||
uncond, context, latent,
|
||||
Tensor([timestep]),
|
||||
alphas[tid], alphas_prev[tid],
|
||||
Tensor([guidance]),
|
||||
)
|
||||
|
||||
return model.decode(latent).realize()
|
||||
267
backend/python/tinygrad/vendor/unet.py
vendored
Normal file
267
backend/python/tinygrad/vendor/unet.py
vendored
Normal file
@@ -0,0 +1,267 @@
|
||||
# Vendored verbatim from tinygrad extra/models/unet.py (MIT license).
|
||||
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/unet.py
|
||||
# Copyright (c) 2023- the tinygrad authors
|
||||
# SPDX-License-Identifier: MIT
|
||||
from tinygrad import Tensor, dtypes, nn
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from typing import Optional, Union, List, Any, Tuple, Callable
|
||||
import math
|
||||
|
||||
# allow for monkeypatching
|
||||
Linear, Conv2d, GroupNorm, LayerNorm = nn.Linear, nn.Conv2d, nn.GroupNorm, nn.LayerNorm
|
||||
attention, gelu, mixed_precision_dtype = Tensor.scaled_dot_product_attention, Tensor.gelu, dtypes.float16
|
||||
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
|
||||
def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
|
||||
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
|
||||
out = Tensor.cat(args.cos(), args.sin(), dim=-1)
|
||||
return out.cast(mixed_precision_dtype) if is_dtype_supported(mixed_precision_dtype) else out
|
||||
|
||||
class ResBlock:
|
||||
def __init__(self, channels:int, emb_channels:int, out_channels:int, num_groups:int=32):
|
||||
self.in_layers = [
|
||||
GroupNorm(num_groups, channels),
|
||||
Tensor.silu,
|
||||
Conv2d(channels, out_channels, 3, padding=1),
|
||||
]
|
||||
self.emb_layers = [
|
||||
Tensor.silu,
|
||||
Linear(emb_channels, out_channels),
|
||||
]
|
||||
self.out_layers = [
|
||||
GroupNorm(num_groups, out_channels),
|
||||
Tensor.silu,
|
||||
lambda x: x, # needed for weights loading code to work
|
||||
Conv2d(out_channels, out_channels, 3, padding=1),
|
||||
]
|
||||
self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else (lambda x: x)
|
||||
|
||||
def __call__(self, x:Tensor, emb:Tensor) -> Tensor:
|
||||
h = x.sequential(self.in_layers)
|
||||
emb_out = emb.sequential(self.emb_layers)
|
||||
h = h + emb_out.reshape(*emb_out.shape, 1, 1)
|
||||
h = h.sequential(self.out_layers)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
class CrossAttention:
|
||||
def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int):
|
||||
self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
|
||||
self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False)
|
||||
self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False)
|
||||
self.num_heads = n_heads
|
||||
self.head_size = d_head
|
||||
self.attn = attention
|
||||
self.to_out = [Linear(n_heads*d_head, query_dim)]
|
||||
|
||||
def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
|
||||
ctx = x if ctx is None else ctx
|
||||
q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx)
|
||||
q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
|
||||
attention = self.attn(q, k, v).transpose(1,2)
|
||||
h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
|
||||
return h_.sequential(self.to_out)
|
||||
|
||||
class GEGLU:
|
||||
def __init__(self, dim_in:int, dim_out:int):
|
||||
self.proj = Linear(dim_in, dim_out * 2)
|
||||
self.gelu = gelu
|
||||
self.dim_out = dim_out
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * self.gelu(gate)
|
||||
|
||||
class FeedForward:
|
||||
def __init__(self, dim:int, mult:int=4):
|
||||
self.net: tuple[GEGLU, Callable, nn.Linear] = (
|
||||
GEGLU(dim, dim*mult),
|
||||
lambda x: x, # needed for weights loading code to work
|
||||
Linear(dim*mult, dim)
|
||||
)
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return x.sequential(list(self.net))
|
||||
|
||||
class BasicTransformerBlock:
|
||||
def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
|
||||
self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
|
||||
self.ff = FeedForward(dim)
|
||||
self.attn2 = CrossAttention(dim, ctx_dim, n_heads, d_head)
|
||||
self.norm1 = LayerNorm(dim)
|
||||
self.norm2 = LayerNorm(dim)
|
||||
self.norm3 = LayerNorm(dim)
|
||||
|
||||
def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
|
||||
x = x + self.attn1(self.norm1(x))
|
||||
x = x + self.attn2(self.norm2(x), ctx=ctx)
|
||||
x = x + self.ff(self.norm3(x))
|
||||
return x
|
||||
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/attention.py#L619
|
||||
class SpatialTransformer:
|
||||
def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1,
|
||||
norm_eps:float=1e-5):
|
||||
if isinstance(ctx_dim, int):
|
||||
ctx_dim = [ctx_dim]*depth
|
||||
else:
|
||||
assert isinstance(ctx_dim, list) and depth == len(ctx_dim)
|
||||
self.norm = GroupNorm(32, channels, eps=norm_eps)
|
||||
assert channels == n_heads * d_head
|
||||
self.proj_in = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
|
||||
self.transformer_blocks = [BasicTransformerBlock(channels, ctx_dim[d], n_heads, d_head) for d in range(depth)]
|
||||
self.proj_out = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
|
||||
self.use_linear = use_linear
|
||||
|
||||
def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
ops = [ (lambda z: z.reshape(b, c, h*w).permute(0,2,1)), (lambda z: self.proj_in(z)) ]
|
||||
x = x.sequential(ops if self.use_linear else ops[::-1])
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, ctx=ctx)
|
||||
ops = [ (lambda z: self.proj_out(z)), (lambda z: z.permute(0,2,1).reshape(b, c, h, w)) ]
|
||||
x = x.sequential(ops if self.use_linear else ops[::-1])
|
||||
return x + x_in
|
||||
|
||||
class Downsample:
|
||||
def __init__(self, channels:int):
|
||||
self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return self.op(x)
|
||||
|
||||
class Upsample:
|
||||
def __init__(self, channels:int):
|
||||
self.conv = Conv2d(channels, channels, 3, padding=1)
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
bs,c,py,px = x.shape
|
||||
z = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
|
||||
return self.conv(z)
|
||||
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/openaimodel.py#L472
|
||||
class UNetModel:
|
||||
def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], num_res_blocks:int,
|
||||
channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None,
|
||||
n_heads:Optional[int]=None, num_groups:int=32, st_norm_eps:float=1e-5):
|
||||
self.model_ch = model_ch
|
||||
self.num_res_blocks = [num_res_blocks] * len(channel_mult)
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.d_head = d_head
|
||||
self.n_heads = n_heads
|
||||
def get_d_and_n_heads(dims:int) -> Tuple[int,int]:
|
||||
if self.d_head is None:
|
||||
assert self.n_heads is not None, f"d_head and n_heads cannot both be None"
|
||||
return dims // self.n_heads, self.n_heads
|
||||
else:
|
||||
assert self.n_heads is None, f"d_head and n_heads cannot both be non-None"
|
||||
return self.d_head, dims // self.d_head
|
||||
|
||||
time_embed_dim = model_ch * 4
|
||||
self.time_embed = [
|
||||
Linear(model_ch, time_embed_dim),
|
||||
Tensor.silu,
|
||||
Linear(time_embed_dim, time_embed_dim),
|
||||
]
|
||||
|
||||
if adm_in_ch is not None:
|
||||
self.label_emb = [
|
||||
[
|
||||
Linear(adm_in_ch, time_embed_dim),
|
||||
Tensor.silu,
|
||||
Linear(time_embed_dim, time_embed_dim),
|
||||
]
|
||||
]
|
||||
|
||||
self.input_blocks: List[Any] = [
|
||||
[Conv2d(in_ch, model_ch, 3, padding=1)]
|
||||
]
|
||||
input_block_channels = [model_ch]
|
||||
ch = model_ch
|
||||
ds = 1
|
||||
for idx, mult in enumerate(channel_mult):
|
||||
for _ in range(self.num_res_blocks[idx]):
|
||||
layers: List[Any] = [
|
||||
ResBlock(ch, time_embed_dim, model_ch*mult, num_groups),
|
||||
]
|
||||
ch = mult * model_ch
|
||||
if ds in attention_resolutions:
|
||||
d_head, n_heads = get_d_and_n_heads(ch)
|
||||
layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))
|
||||
|
||||
self.input_blocks.append(layers)
|
||||
input_block_channels.append(ch)
|
||||
|
||||
if idx != len(channel_mult) - 1:
|
||||
self.input_blocks.append([
|
||||
Downsample(ch),
|
||||
])
|
||||
input_block_channels.append(ch)
|
||||
ds *= 2
|
||||
|
||||
d_head, n_heads = get_d_and_n_heads(ch)
|
||||
self.middle_block: List = [
|
||||
ResBlock(ch, time_embed_dim, ch, num_groups),
|
||||
SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1], norm_eps=st_norm_eps),
|
||||
ResBlock(ch, time_embed_dim, ch, num_groups),
|
||||
]
|
||||
|
||||
self.output_blocks = []
|
||||
for idx, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(self.num_res_blocks[idx] + 1):
|
||||
ich = input_block_channels.pop()
|
||||
layers = [
|
||||
ResBlock(ch + ich, time_embed_dim, model_ch*mult, num_groups),
|
||||
]
|
||||
ch = model_ch * mult
|
||||
|
||||
if ds in attention_resolutions:
|
||||
d_head, n_heads = get_d_and_n_heads(ch)
|
||||
layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))
|
||||
|
||||
if idx > 0 and i == self.num_res_blocks[idx]:
|
||||
layers.append(Upsample(ch))
|
||||
ds //= 2
|
||||
self.output_blocks.append(layers)
|
||||
|
||||
self.out = [
|
||||
GroupNorm(num_groups, ch),
|
||||
Tensor.silu,
|
||||
Conv2d(model_ch, out_ch, 3, padding=1),
|
||||
]
|
||||
|
||||
def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Optional[Tensor]=None) -> Tensor:
|
||||
t_emb = timestep_embedding(tms, self.model_ch)
|
||||
emb = t_emb.sequential(self.time_embed)
|
||||
|
||||
if y is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + y.sequential(self.label_emb[0])
|
||||
|
||||
if is_dtype_supported(mixed_precision_dtype):
|
||||
emb = emb.cast(mixed_precision_dtype)
|
||||
ctx = ctx.cast(mixed_precision_dtype)
|
||||
x = x .cast(mixed_precision_dtype)
|
||||
|
||||
def run(x:Tensor, bb) -> Tensor:
|
||||
if isinstance(bb, ResBlock): x = bb(x, emb)
|
||||
elif isinstance(bb, SpatialTransformer): x = bb(x, ctx)
|
||||
else: x = bb(x)
|
||||
return x
|
||||
|
||||
saved_inputs = []
|
||||
for b in self.input_blocks:
|
||||
for bb in b:
|
||||
x = run(x, bb)
|
||||
saved_inputs.append(x)
|
||||
for bb in self.middle_block:
|
||||
x = run(x, bb)
|
||||
for b in self.output_blocks:
|
||||
x = x.cat(saved_inputs.pop(), dim=1)
|
||||
for bb in b:
|
||||
x = run(x, bb)
|
||||
return x.sequential(self.out)
|
||||
274
backend/python/tinygrad/vendor/whisper.py
vendored
Normal file
274
backend/python/tinygrad/vendor/whisper.py
vendored
Normal file
@@ -0,0 +1,274 @@
|
||||
# Adapted from tinygrad examples/whisper.py (MIT license).
|
||||
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/whisper.py
|
||||
# Copyright (c) 2023- the tinygrad authors
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
# Local modifications: removed the pyaudio listener / __main__ block; the rest
|
||||
# is the core Whisper model + preprocessing + single-file transcription path.
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import collections
|
||||
import itertools
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, TinyJit, Variable, dtypes, nn
|
||||
from tinygrad.helpers import fetch
|
||||
from tinygrad.nn.state import load_state_dict, torch_load
|
||||
|
||||
from .audio_helpers import mel
|
||||
|
||||
|
||||
class MultiHeadAttention:
|
||||
def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self', None] = None, max_self_attn_cache_len=None):
|
||||
self.n_head = n_head
|
||||
self.query = nn.Linear(n_state, n_state)
|
||||
self.key = nn.Linear(n_state, n_state, bias=False)
|
||||
self.value = nn.Linear(n_state, n_state)
|
||||
self.out = nn.Linear(n_state, n_state)
|
||||
self.kv_caching = kv_caching
|
||||
self.max_self_attn_cache_len = max_self_attn_cache_len
|
||||
|
||||
def __call__(self, x, xa=None, mask=None, len=None):
|
||||
if self.kv_caching == 'cross':
|
||||
if xa is not None:
|
||||
k, v = self.key(xa), self.value(xa)
|
||||
if not hasattr(self, 'cache_k'):
|
||||
self.cache_k, self.cache_v = k, v
|
||||
else:
|
||||
self.cache_k.assign(k).realize()
|
||||
self.cache_v.assign(v).realize()
|
||||
else:
|
||||
k, v = self.cache_k, self.cache_v
|
||||
else:
|
||||
k, v = self.key(x), self.value(x)
|
||||
if self.kv_caching == 'self':
|
||||
if not hasattr(self, 'cache_k'):
|
||||
self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
|
||||
self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
|
||||
k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
|
||||
v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
|
||||
padding = self.max_self_attn_cache_len - len - x.shape[1]
|
||||
self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
|
||||
self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()
|
||||
|
||||
q = self.query(x)
|
||||
n_ctx = q.shape[1]
|
||||
head_dim = q.shape[-1] // self.n_head
|
||||
q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||
k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||
v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||
attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx, :n_ctx] if mask is not None else None)
|
||||
wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
return self.out(wv)
|
||||
|
||||
|
||||
class ResidualAttentionBlock:
|
||||
def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
|
||||
self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
|
||||
self.attn_ln = nn.LayerNorm(n_state)
|
||||
self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
|
||||
self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
|
||||
self.mlp = [nn.Linear(n_state, n_state * 4), Tensor.gelu, nn.Linear(n_state * 4, n_state)]
|
||||
self.mlp_ln = nn.LayerNorm(n_state)
|
||||
|
||||
def __call__(self, x, xa=None, mask=None, len=None):
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa)
|
||||
x = x + self.mlp_ln(x).sequential(self.mlp)
|
||||
return x.realize()
|
||||
|
||||
|
||||
class AudioEncoder:
|
||||
def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
|
||||
self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1)
|
||||
self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
|
||||
self.ln_post = nn.LayerNorm(n_audio_state)
|
||||
self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
|
||||
self.encode = TinyJit(self.__call__)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv1(x).gelu()
|
||||
x = self.conv2(x).gelu()
|
||||
x = x.permute(0, 2, 1)
|
||||
x = x + self.positional_embedding[:x.shape[1]]
|
||||
x = x.sequential(self.blocks)
|
||||
x = self.ln_post(x)
|
||||
return x.realize()
|
||||
|
||||
|
||||
class TextDecoder:
|
||||
def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
|
||||
self.max_tokens_to_sample = n_text_ctx // 2
|
||||
self.max_self_attn_cache_len = n_text_ctx
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
|
||||
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
|
||||
self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
|
||||
self.ln = nn.LayerNorm(n_text_state)
|
||||
self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
|
||||
self.getjitted = collections.defaultdict(lambda: TinyJit(self.forward))
|
||||
|
||||
def __call__(self, x, pos, encoded_audio):
|
||||
pos = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len - 1).bind(pos) if pos else 0
|
||||
return self.getjitted[x.shape](x, pos, encoded_audio)
|
||||
|
||||
def forward(self, x, pos, encoded_audio):
|
||||
seqlen = x.shape[-1]
|
||||
x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos + seqlen), None))
|
||||
for block in self.blocks:
|
||||
x = block(x, xa=encoded_audio, mask=self.mask, len=pos)
|
||||
return self.output_tok(x)
|
||||
|
||||
def output_tok(self, x):
|
||||
return (self.ln(x) @ self.token_embedding.weight.T).realize()
|
||||
|
||||
|
||||
class Whisper:
|
||||
def __init__(self, dims, batch_size=1):
|
||||
self.encoder = AudioEncoder(**dims)
|
||||
self.decoder = TextDecoder(**dims)
|
||||
self.is_multilingual = dims["n_vocab"] == 51865
|
||||
self.batch_size = batch_size
|
||||
|
||||
|
||||
RATE = 16000
|
||||
SEGMENT_SECONDS = 30
|
||||
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS
|
||||
N_FFT = 400
|
||||
HOP_LENGTH = 160
|
||||
N_MELS = 80
|
||||
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH
|
||||
|
||||
|
||||
def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate: bool = False) -> np.ndarray:
|
||||
import librosa
|
||||
|
||||
def pad_or_trim(arr, target_len):
|
||||
if len(arr) == target_len:
|
||||
return arr
|
||||
if len(arr) < target_len:
|
||||
return np.pad(arr, (0, target_len - len(arr)), 'constant')
|
||||
return arr[:target_len]
|
||||
|
||||
max_len = SAMPLES_PER_SEGMENT if truncate else max(len(w) for w in waveforms)
|
||||
if (r := max_len % SAMPLES_PER_SEGMENT) > 0:
|
||||
max_len += SAMPLES_PER_SEGMENT - r
|
||||
|
||||
waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
|
||||
if waveforms.shape[0] < batch_size:
|
||||
waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))
|
||||
|
||||
stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
|
||||
magnitudes = np.absolute(stft[..., :-1]) ** 2
|
||||
mel_spec = mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS).numpy() @ magnitudes
|
||||
log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
|
||||
log_spec = np.maximum(log_spec, log_spec.max((1, 2), keepdims=True) - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec
|
||||
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean",
|
||||
"fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish", "pl": "polish", "it": "italian",
|
||||
}
|
||||
|
||||
|
||||
def get_encoding(encoding_name: str):
|
||||
import tiktoken
|
||||
|
||||
with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f:
|
||||
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
|
||||
n_vocab = len(ranks)
|
||||
specials = [
|
||||
"<|endoftext|>",
|
||||
"<|startoftranscript|>",
|
||||
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
||||
"<|translate|>",
|
||||
"<|transcribe|>",
|
||||
"<|startoflm|>",
|
||||
"<|startofprev|>",
|
||||
"<|nospeech|>",
|
||||
"<|notimestamps|>",
|
||||
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
||||
]
|
||||
special_tokens = dict(zip(specials, itertools.count(n_vocab)))
|
||||
return tiktoken.Encoding(
|
||||
name=encoding_name,
|
||||
explicit_n_vocab=n_vocab + len(specials),
|
||||
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
||||
mergeable_ranks=ranks,
|
||||
special_tokens=special_tokens,
|
||||
)
|
||||
|
||||
|
||||
MODEL_URLS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
}
|
||||
|
||||
|
||||
def init_whisper(model_name: str = "base", batch_size: int = 1):
|
||||
filename = fetch(MODEL_URLS[model_name])
|
||||
state = torch_load(filename)
|
||||
model = Whisper(state['dims'], batch_size)
|
||||
load_state_dict(model, state['model_state_dict'], strict=False)
|
||||
enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
|
||||
return model, enc
|
||||
|
||||
|
||||
def load_file_waveform(filename: str):
|
||||
import librosa
|
||||
waveform, _ = librosa.load(filename, sr=RATE)
|
||||
return waveform
|
||||
|
||||
|
||||
def transcribe_waveform(model: Whisper, enc, waveforms, language: Optional[str] = None, truncate: bool = False) -> str:
|
||||
log_spec = prep_audio(waveforms, model.batch_size, truncate)
|
||||
nsample = model.decoder.max_tokens_to_sample
|
||||
nctx = model.decoder.max_self_attn_cache_len
|
||||
|
||||
start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
|
||||
if model.is_multilingual:
|
||||
lang = language if (language and language in LANGUAGES) else "en"
|
||||
language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index(lang)
|
||||
start_tokens.append(language_token)
|
||||
start_tokens.append(enc._special_tokens["<|transcribe|>"])
|
||||
start_tokens.append(enc._special_tokens["<|notimestamps|>"])
|
||||
|
||||
eot = enc._special_tokens["<|endoftext|>"]
|
||||
|
||||
def inferloop(ctx, encoded_audio):
|
||||
pos, next_tokens = 0, ctx
|
||||
for _ in range(nsample):
|
||||
next_tokens = model.decoder(Tensor(next_tokens, dtype=dtypes.int32), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1)
|
||||
next_tokens[ctx[:, -1] == eot] = eot
|
||||
ctx = np.concatenate((ctx, next_tokens), axis=1)
|
||||
pos = ctx.shape[-1] - 1
|
||||
if (next_tokens == eot).all() or pos == nctx:
|
||||
break
|
||||
return ctx
|
||||
|
||||
ctx = np.tile(start_tokens, (model.batch_size, 1))
|
||||
transcriptions: list[list[int]] = [[] for _ in waveforms]
|
||||
|
||||
for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
|
||||
encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
|
||||
ctx_arr = inferloop(np.array(ctx), encoded_audio)
|
||||
for i, arr in enumerate(ctx_arr):
|
||||
if i >= len(waveforms):
|
||||
break
|
||||
end_idxs = np.where(arr == eot)[0]
|
||||
start_idx = np.where(arr == start_tokens[-1])[0][0] + 1
|
||||
end_idx = end_idxs[0] if len(end_idxs) else None
|
||||
transcriptions[i].extend(arr[start_idx:end_idx])
|
||||
ctx = ctx_arr
|
||||
|
||||
texts = [enc.decode([int(t) for t in toks]).strip() for toks in transcriptions]
|
||||
return texts[0] if len(texts) == 1 else "\n".join(texts)
|
||||
@@ -4,7 +4,7 @@ numba==0.60.0
|
||||
accelerate
|
||||
transformers>=5.0.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
sentence-transformers==5.4.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -4,7 +4,7 @@ llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers>=5.0.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
sentence-transformers==5.4.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -4,7 +4,7 @@ llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers>=5.0.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
sentence-transformers==5.4.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -5,7 +5,7 @@ transformers>=5.0.0
|
||||
llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
sentence-transformers==5.4.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -5,7 +5,7 @@ llvmlite==0.43.0
|
||||
numba==0.60.0
|
||||
transformers>=5.0.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
sentence-transformers==5.4.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
@@ -4,7 +4,7 @@ numba==0.60.0
|
||||
accelerate
|
||||
transformers>=5.0.0
|
||||
bitsandbytes
|
||||
sentence-transformers==5.2.3
|
||||
sentence-transformers==5.4.0
|
||||
diffusers
|
||||
soundfile
|
||||
protobuf==6.33.5
|
||||
|
||||
@@ -17,6 +17,8 @@ import time
|
||||
import os
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import gc
|
||||
|
||||
from PIL import Image
|
||||
import torch
|
||||
@@ -30,6 +32,7 @@ import grpc
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
||||
from grpc_auth import get_auth_interceptors
|
||||
from vllm_utils import parse_options, messages_to_dicts, setup_parsers
|
||||
|
||||
|
||||
from vllm_omni.entrypoints.omni import Omni
|
||||
@@ -148,23 +151,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
try:
|
||||
# CPU detection: if no CUDA, default vLLM target device to CPU.
|
||||
try:
|
||||
if not torch.cuda.is_available():
|
||||
os.environ.setdefault("VLLM_TARGET_DEVICE", "cpu")
|
||||
os.environ.setdefault("VLLM_CPU_KVCACHE_SPACE", "4")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print(f"Loading model {request.Model}...", file=sys.stderr)
|
||||
print(f"Request {request}", file=sys.stderr)
|
||||
|
||||
# Parse options from request.Options (key:value pairs)
|
||||
self.options = {}
|
||||
for opt in request.Options:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1)
|
||||
# Convert value to appropriate type
|
||||
if is_float(value):
|
||||
value = float(value)
|
||||
elif is_int(value):
|
||||
value = int(value)
|
||||
elif value.lower() in ["true", "false"]:
|
||||
value = value.lower() == "true"
|
||||
self.options[key] = value
|
||||
# Parse options from request.Options using shared helper
|
||||
self.options = parse_options(request.Options)
|
||||
opts = self.options
|
||||
|
||||
print(f"Options: {self.options}", file=sys.stderr)
|
||||
|
||||
@@ -244,6 +244,24 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
omni_kwargs["max_model_len"] = request.MaxModelLen
|
||||
|
||||
self.omni = Omni(**omni_kwargs)
|
||||
|
||||
# Load tokenizer for LLM/TTS so chat templates work
|
||||
if self.model_type in ("llm", "tts"):
|
||||
try:
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
self.tokenizer = get_tokenizer(
|
||||
request.Model,
|
||||
trust_remote_code=opts.get("trust_remote_code", False),
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to load tokenizer: {e}", file=sys.stderr)
|
||||
self.tokenizer = None
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
# Setup optional tool / reasoning parsers
|
||||
self.tool_parser_cls, self.reasoning_parser_cls = setup_parsers(opts)
|
||||
|
||||
print("Model loaded successfully", file=sys.stderr)
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
@@ -466,14 +484,32 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# Extract prompt
|
||||
if request.Prompt:
|
||||
prompt = request.Prompt
|
||||
elif request.Messages and request.UseTokenizerTemplate:
|
||||
# Build prompt from messages (simplified - would need tokenizer for full template)
|
||||
prompt = ""
|
||||
for msg in request.Messages:
|
||||
role = msg.role
|
||||
content = msg.content
|
||||
prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
|
||||
prompt += "<|im_start|>assistant\n"
|
||||
elif request.Messages:
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
messages_dicts = messages_to_dicts(request.Messages)
|
||||
template_kwargs = {"tokenize": False, "add_generation_prompt": True}
|
||||
if request.Tools:
|
||||
try:
|
||||
template_kwargs["tools"] = json.loads(request.Tools)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
try:
|
||||
if request.Metadata.get("enable_thinking", "").lower() == "true":
|
||||
template_kwargs["enable_thinking"] = True
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
prompt = self.tokenizer.apply_chat_template(messages_dicts, **template_kwargs)
|
||||
except TypeError:
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
messages_dicts, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
else:
|
||||
# Fallback: basic template
|
||||
prompt = ""
|
||||
for msg in request.Messages:
|
||||
prompt += f"<|im_start|>{msg.role}\n{msg.content}<|im_end|>\n"
|
||||
prompt += "<|im_start|>assistant\n"
|
||||
else:
|
||||
yield backend_pb2.Reply(message=bytes("", 'utf-8'))
|
||||
return
|
||||
@@ -539,20 +575,79 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# Call omni.generate() (returns generator for LLM mode)
|
||||
omni_generator = self.omni.generate([inputs], sampling_params_list)
|
||||
|
||||
# Extract text from outputs
|
||||
# Extract text from outputs and track token usage
|
||||
generated_text = ""
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
for stage_outputs in omni_generator:
|
||||
if stage_outputs.final_output_type == "text":
|
||||
for output in stage_outputs.request_output:
|
||||
text_output = output.outputs[0].text
|
||||
completion = output.outputs[0]
|
||||
text_output = completion.text
|
||||
# Track tokens when available
|
||||
try:
|
||||
if getattr(output, "prompt_token_ids", None) is not None:
|
||||
prompt_tokens = len(output.prompt_token_ids)
|
||||
if getattr(completion, "token_ids", None) is not None:
|
||||
completion_tokens = len(completion.token_ids)
|
||||
except Exception:
|
||||
pass
|
||||
if streaming:
|
||||
# Remove already sent text (vllm concatenates)
|
||||
delta_text = text_output.removeprefix(generated_text)
|
||||
yield backend_pb2.Reply(message=bytes(delta_text, encoding='utf-8'))
|
||||
yield backend_pb2.Reply(
|
||||
message=bytes(delta_text, encoding='utf-8'),
|
||||
tokens=completion_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
)
|
||||
generated_text = text_output
|
||||
|
||||
if not streaming:
|
||||
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
|
||||
# Build optional ChatDelta with parsed reasoning / tool calls
|
||||
chat_deltas = []
|
||||
content_text = generated_text
|
||||
reasoning_text = ""
|
||||
tool_call_deltas = []
|
||||
|
||||
if self.reasoning_parser_cls is not None:
|
||||
try:
|
||||
parser = self.reasoning_parser_cls(self.tokenizer) if self.tokenizer else self.reasoning_parser_cls()
|
||||
reasoning_text, content_text = parser.extract_reasoning_content(content_text, request=None)
|
||||
reasoning_text = reasoning_text or ""
|
||||
content_text = content_text or ""
|
||||
except Exception as e:
|
||||
print(f"reasoning_parser failed: {e}", file=sys.stderr)
|
||||
|
||||
if self.tool_parser_cls is not None:
|
||||
try:
|
||||
parser = self.tool_parser_cls(self.tokenizer) if self.tokenizer else self.tool_parser_cls()
|
||||
tool_info = parser.extract_tool_calls(content_text, request=None)
|
||||
if getattr(tool_info, "tools_called", False):
|
||||
content_text = tool_info.content or ""
|
||||
for tc in tool_info.tool_calls or []:
|
||||
fn = getattr(tc, "function", None)
|
||||
tool_call_deltas.append(backend_pb2.ToolCallDelta(
|
||||
index=getattr(tc, "index", 0) or 0,
|
||||
id=getattr(tc, "id", "") or "",
|
||||
name=getattr(fn, "name", "") if fn else "",
|
||||
arguments=getattr(fn, "arguments", "") if fn else "",
|
||||
))
|
||||
except Exception as e:
|
||||
print(f"tool_parser failed: {e}", file=sys.stderr)
|
||||
|
||||
if self.tool_parser_cls is not None or self.reasoning_parser_cls is not None:
|
||||
chat_deltas.append(backend_pb2.ChatDelta(
|
||||
content=content_text,
|
||||
reasoning_content=reasoning_text,
|
||||
tool_calls=tool_call_deltas,
|
||||
))
|
||||
|
||||
yield backend_pb2.Reply(
|
||||
message=bytes(generated_text, encoding='utf-8'),
|
||||
tokens=completion_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
chat_deltas=chat_deltas,
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
print(f"Error in Predict: {err}", file=sys.stderr)
|
||||
@@ -647,6 +742,37 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
traceback.print_exc()
|
||||
return backend_pb2.Result(success=False, message=f"Error generating TTS: {err}")
|
||||
|
||||
def TokenizeString(self, request, context):
|
||||
if not hasattr(self, 'tokenizer') or self.tokenizer is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("Model/tokenizer not loaded")
|
||||
return backend_pb2.TokenizationResponse()
|
||||
try:
|
||||
tokens = self.tokenizer.encode(request.Prompt)
|
||||
return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
|
||||
except Exception as e:
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(str(e))
|
||||
return backend_pb2.TokenizationResponse()
|
||||
|
||||
def Free(self, request, context):
|
||||
try:
|
||||
if hasattr(self, 'omni'):
|
||||
del self.omni
|
||||
if hasattr(self, 'tokenizer'):
|
||||
del self.tokenizer
|
||||
self.tool_parser_cls = None
|
||||
self.reasoning_parser_cls = None
|
||||
gc.collect()
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return backend_pb2.Result(success=True, message="Model freed")
|
||||
except Exception as e:
|
||||
return backend_pb2.Result(success=False, message=str(e))
|
||||
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
|
||||
@@ -5,6 +5,9 @@ import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import gc
|
||||
from typing import List
|
||||
from PIL import Image
|
||||
|
||||
@@ -26,6 +29,25 @@ from vllm.assets.video import VideoAsset
|
||||
import base64
|
||||
import io
|
||||
|
||||
# Version-compat imports — wrap in try/except for older vLLM versions
|
||||
try:
|
||||
from vllm.tool_parsers import ToolParserManager
|
||||
HAS_TOOL_PARSERS = True
|
||||
except ImportError:
|
||||
HAS_TOOL_PARSERS = False
|
||||
|
||||
try:
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
HAS_REASONING_PARSERS = True
|
||||
except ImportError:
|
||||
HAS_REASONING_PARSERS = False
|
||||
|
||||
try:
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
HAS_GUIDED_DECODING = True
|
||||
except ImportError:
|
||||
HAS_GUIDED_DECODING = False
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
@@ -69,6 +91,35 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
break
|
||||
return decoded_text
|
||||
|
||||
def _parse_options(self, options_list):
|
||||
"""Parse Options[] key:value string list into a dict."""
|
||||
opts = {}
|
||||
for opt in options_list:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1)
|
||||
opts[key.strip()] = value.strip()
|
||||
return opts
|
||||
|
||||
def _messages_to_dicts(self, messages):
|
||||
"""Convert proto Messages to list of dicts suitable for apply_chat_template()."""
|
||||
result = []
|
||||
for msg in messages:
|
||||
d = {"role": msg.role, "content": msg.content or ""}
|
||||
if msg.name:
|
||||
d["name"] = msg.name
|
||||
if msg.tool_call_id:
|
||||
d["tool_call_id"] = msg.tool_call_id
|
||||
if msg.reasoning_content:
|
||||
d["reasoning_content"] = msg.reasoning_content
|
||||
if msg.tool_calls:
|
||||
try:
|
||||
d["tool_calls"] = json.loads(msg.tool_calls)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
result.append(d)
|
||||
return result
|
||||
|
||||
def Health(self, request, context):
|
||||
"""
|
||||
Returns a health check message.
|
||||
@@ -132,15 +183,49 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
|
||||
try:
|
||||
engine_model_config = await self.llm.get_model_config()
|
||||
self.tokenizer = get_tokenizer(
|
||||
engine_model_config.tokenizer,
|
||||
tokenizer_mode=engine_model_config.tokenizer_mode,
|
||||
trust_remote_code=engine_model_config.trust_remote_code,
|
||||
truncation_side="left",
|
||||
)
|
||||
# vLLM >= 0.14 removed get_model_config() on AsyncLLM; the tokenizer
|
||||
# is either already loaded on the engine or can be built from the
|
||||
# Model name directly.
|
||||
tokenizer = None
|
||||
if hasattr(self.llm, "get_tokenizer"):
|
||||
try:
|
||||
tokenizer = await self.llm.get_tokenizer()
|
||||
except TypeError:
|
||||
tokenizer = self.llm.get_tokenizer()
|
||||
except Exception:
|
||||
tokenizer = None
|
||||
if tokenizer is None and hasattr(self.llm, "tokenizer"):
|
||||
tokenizer = self.llm.tokenizer
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer(
|
||||
request.Model,
|
||||
trust_remote_code=bool(request.TrustRemoteCode),
|
||||
truncation_side="left",
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
|
||||
# Parse options for parser selection
|
||||
opts = self._parse_options(request.Options)
|
||||
|
||||
# Instantiate tool/reasoning parser classes (they'll be instantiated per-request with tokenizer)
|
||||
self.tool_parser_cls = None
|
||||
self.reasoning_parser_cls = None
|
||||
if HAS_TOOL_PARSERS and opts.get("tool_parser"):
|
||||
try:
|
||||
self.tool_parser_cls = ToolParserManager.get_tool_parser(opts["tool_parser"])
|
||||
print(f"Loaded tool_parser: {opts['tool_parser']}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"Failed to load tool_parser {opts.get('tool_parser')}: {e}", file=sys.stderr)
|
||||
|
||||
if HAS_REASONING_PARSERS and opts.get("reasoning_parser"):
|
||||
try:
|
||||
self.reasoning_parser_cls = ReasoningParserManager.get_reasoning_parser(opts["reasoning_parser"])
|
||||
print(f"Loaded reasoning_parser: {opts['reasoning_parser']}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"Failed to load reasoning_parser {opts.get('reasoning_parser')}: {e}", file=sys.stderr)
|
||||
|
||||
print("Model loaded successfully", file=sys.stderr)
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
@@ -197,6 +282,38 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
finally:
|
||||
await iterations.aclose()
|
||||
|
||||
async def TokenizeString(self, request, context):
|
||||
if not hasattr(self, 'tokenizer') or self.tokenizer is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("Model/tokenizer not loaded")
|
||||
return backend_pb2.TokenizationResponse()
|
||||
try:
|
||||
tokens = self.tokenizer.encode(request.Prompt)
|
||||
return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
|
||||
except Exception as e:
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(str(e))
|
||||
return backend_pb2.TokenizationResponse()
|
||||
|
||||
async def Free(self, request, context):
|
||||
try:
|
||||
if hasattr(self, 'llm'):
|
||||
del self.llm
|
||||
if hasattr(self, 'tokenizer'):
|
||||
del self.tokenizer
|
||||
self.tool_parser_cls = None
|
||||
self.reasoning_parser_cls = None
|
||||
gc.collect()
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except ImportError:
|
||||
pass
|
||||
return backend_pb2.Result(success=True, message="Model freed")
|
||||
except Exception as e:
|
||||
return backend_pb2.Result(success=False, message=str(e))
|
||||
|
||||
async def _predict(self, request, context, streaming=False):
|
||||
# Build the sampling parameters
|
||||
# NOTE: this must stay in sync with the vllm backend
|
||||
@@ -222,7 +339,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"SkipSpecialTokens": "skip_special_tokens",
|
||||
"SpacesBetweenSpecialTokens": "spaces_between_special_tokens",
|
||||
"TruncatePromptTokens": "truncate_prompt_tokens",
|
||||
"GuidedDecoding": "guided_decoding",
|
||||
}
|
||||
|
||||
sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
|
||||
@@ -233,6 +349,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if value not in (None, 0, [], False):
|
||||
setattr(sampling_params, param_field, value)
|
||||
|
||||
# Guided decoding: use Grammar field to pass JSON schema or BNF
|
||||
if HAS_GUIDED_DECODING and request.Grammar:
|
||||
try:
|
||||
json.loads(request.Grammar) # valid JSON = JSON schema
|
||||
sampling_params.guided_decoding = GuidedDecodingParams(json=request.Grammar)
|
||||
except json.JSONDecodeError:
|
||||
sampling_params.guided_decoding = GuidedDecodingParams(grammar=request.Grammar)
|
||||
|
||||
# Extract image paths and process images
|
||||
prompt = request.Prompt
|
||||
|
||||
@@ -244,7 +368,27 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
# If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
|
||||
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
|
||||
prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
|
||||
messages_dicts = self._messages_to_dicts(request.Messages)
|
||||
template_kwargs = {"tokenize": False, "add_generation_prompt": True}
|
||||
|
||||
# Pass tools for tool calling
|
||||
if request.Tools:
|
||||
try:
|
||||
template_kwargs["tools"] = json.loads(request.Tools)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Enable thinking mode if requested
|
||||
if request.Metadata.get("enable_thinking", "").lower() == "true":
|
||||
template_kwargs["enable_thinking"] = True
|
||||
|
||||
try:
|
||||
prompt = self.tokenizer.apply_chat_template(messages_dicts, **template_kwargs)
|
||||
except TypeError:
|
||||
# Some tokenizers don't support tools/enable_thinking kwargs — retry without them
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
messages_dicts, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Generate text using the LLM engine
|
||||
request_id = random_uuid()
|
||||
@@ -265,25 +409,26 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
# Stream the results
|
||||
generated_text = ""
|
||||
last_output = None
|
||||
try:
|
||||
async for request_output in outputs:
|
||||
iteration_text = request_output.outputs[0].text
|
||||
last_output = request_output
|
||||
|
||||
if streaming:
|
||||
# Remove text already sent as vllm concatenates the text from previous yields
|
||||
delta_iteration_text = iteration_text.removeprefix(generated_text)
|
||||
# Send the partial result
|
||||
yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8'))
|
||||
yield backend_pb2.Reply(
|
||||
message=bytes(delta_iteration_text, encoding='utf-8'),
|
||||
chat_deltas=[backend_pb2.ChatDelta(content=delta_iteration_text)],
|
||||
)
|
||||
|
||||
# Keep track of text generated
|
||||
generated_text = iteration_text
|
||||
finally:
|
||||
await outputs.aclose()
|
||||
|
||||
# If streaming, we already sent everything
|
||||
if streaming:
|
||||
return
|
||||
|
||||
# Remove the image files from /tmp folder
|
||||
for img_path in image_paths:
|
||||
try:
|
||||
@@ -291,8 +436,99 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
except Exception as e:
|
||||
print(f"Error removing image file: {img_path}, {e}", file=sys.stderr)
|
||||
|
||||
# Sending the final generated text
|
||||
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
|
||||
# Parse reasoning and tool calls from final text using vLLM's native parsers
|
||||
content = generated_text
|
||||
reasoning_content = ""
|
||||
tool_calls_proto = []
|
||||
|
||||
if self.reasoning_parser_cls:
|
||||
try:
|
||||
rp = self.reasoning_parser_cls(self.tokenizer)
|
||||
r, c = rp.extract_reasoning(generated_text, request=None)
|
||||
reasoning_content = r or ""
|
||||
content = c if c is not None else generated_text
|
||||
except Exception as e:
|
||||
print(f"Reasoning parser error: {e}", file=sys.stderr)
|
||||
|
||||
if self.tool_parser_cls and request.Tools:
|
||||
try:
|
||||
tools = json.loads(request.Tools)
|
||||
# Some concrete parsers only accept the tokenizer; only the
|
||||
# abstract base declares the tools kwarg. Try with tools first,
|
||||
# fall back to tokenizer-only.
|
||||
try:
|
||||
tp = self.tool_parser_cls(self.tokenizer, tools=tools)
|
||||
except TypeError:
|
||||
tp = self.tool_parser_cls(self.tokenizer)
|
||||
info = tp.extract_tool_calls(content, request=None)
|
||||
if info.tools_called:
|
||||
content = info.content or ""
|
||||
for i, tc in enumerate(info.tool_calls):
|
||||
tool_calls_proto.append(backend_pb2.ToolCallDelta(
|
||||
index=i,
|
||||
id=tc.id,
|
||||
name=tc.function.name,
|
||||
arguments=tc.function.arguments,
|
||||
))
|
||||
except Exception as e:
|
||||
print(f"Tool parser error: {e}", file=sys.stderr)
|
||||
|
||||
# Extract token counts
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
if last_output is not None:
|
||||
try:
|
||||
prompt_tokens = len(last_output.prompt_token_ids or [])
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
completion_tokens = len(last_output.outputs[0].token_ids or [])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Extract logprobs if requested
|
||||
logprobs_bytes = b""
|
||||
if last_output is not None and request.Logprobs > 0:
|
||||
try:
|
||||
lp = last_output.outputs[0].logprobs
|
||||
if lp:
|
||||
logprobs_data = {"content": []}
|
||||
for token_lp_dict in lp:
|
||||
if token_lp_dict:
|
||||
first_tok_id, first_lp = next(iter(token_lp_dict.items()))
|
||||
logprobs_data["content"].append({
|
||||
"token": getattr(first_lp, "decoded_token", str(first_tok_id)),
|
||||
"logprob": first_lp.logprob,
|
||||
})
|
||||
logprobs_bytes = json.dumps(logprobs_data).encode("utf-8")
|
||||
except Exception as e:
|
||||
print(f"Logprobs extraction error: {e}", file=sys.stderr)
|
||||
|
||||
chat_delta = backend_pb2.ChatDelta(
|
||||
content=content,
|
||||
reasoning_content=reasoning_content,
|
||||
tool_calls=tool_calls_proto,
|
||||
)
|
||||
|
||||
if streaming:
|
||||
# Final chunk with structured data
|
||||
yield backend_pb2.Reply(
|
||||
message=b"",
|
||||
prompt_tokens=prompt_tokens,
|
||||
tokens=completion_tokens,
|
||||
chat_deltas=[chat_delta],
|
||||
logprobs=logprobs_bytes,
|
||||
)
|
||||
return
|
||||
|
||||
# Non-streaming: single Reply with everything
|
||||
yield backend_pb2.Reply(
|
||||
message=bytes(content, encoding='utf-8'),
|
||||
prompt_tokens=prompt_tokens,
|
||||
tokens=completion_tokens,
|
||||
chat_deltas=[chat_delta],
|
||||
logprobs=logprobs_bytes,
|
||||
)
|
||||
|
||||
def load_image(self, image_path: str):
|
||||
"""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user