* feat(backend): add tinygrad multimodal backend
Wire tinygrad as a new Python backend covering LLM text generation with
native tool-call extraction, embeddings, Stable Diffusion 1.x image
generation, and Whisper speech-to-text from a single self-contained
container.
Backend (`backend/python/tinygrad/`):
- `backend.py` gRPC servicer with LLM Predict/PredictStream (auto-detects
Llama / Qwen2 / Mistral architecture from `config.json`, supports
safetensors and GGUF), Embedding via mean-pooled last hidden state,
GenerateImage via the vendored SD1.x pipeline, AudioTranscription +
AudioTranscriptionStream via the vendored Whisper inference loop, plus
Tokenize / ModelMetadata / Status / Free.
- Vendored upstream model code under `vendor/` (MIT, headers preserved):
llama.py with an added `qkv_bias` flag for Qwen2-family bias support
and an `embed()` method that returns the last hidden state, plus
clip.py, unet.py, stable_diffusion.py (trimmed to drop the MLPerf
training branch that pulls `mlperf.initializers`), audio_helpers.py
and whisper.py (trimmed to drop the pyaudio listener).
- Pluggable tool-call parsers under `tool_parsers/`: hermes (Qwen2.5 /
  Hermes), llama3_json (Llama 3.1+), qwen3_xml (Qwen 3), mistral
  (Mistral / Mixtral). Auto-selected from model architecture or `Options`;
  a usage sketch follows this list.
- `install.sh` pins Python 3.11.14 (tinygrad >=0.12 needs >=3.11; the
default portable python is 3.10).
- `package.sh` bundles libLLVM.so.1 + libedit/libtinfo/libgomp/libsndfile
into the scratch image. `run.sh` sets `CPU_LLVM=1` and `LLVM_PATH` so
tinygrad's CPU device uses the in-process libLLVM JIT instead of
shelling out to the missing `clang` binary.
- Local unit tests for Health and the four parsers in `test.py`.
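The parser contract, as consumed by `backend.py`, is deliberately small:
`resolve_parser(name)` returns a parser whose `parse(text)` yields
`(content, tool_calls)`, and the hermes parser additionally exposes
`parse_full(text)` with `content` / `tool_calls` / `reasoning` fields.
A minimal usage sketch (the Hermes-style model output below is a made-up
example, not taken from the repo):

    from tool_parsers import resolve_parser
    from tool_parsers.hermes import HermesToolParser

    raw = '<tool_call>{"name": "get_weather", "arguments": {"city": "Rome"}}</tool_call>'

    parser = resolve_parser("hermes")  # resolve_parser(None) gives the passthrough parser
    if isinstance(parser, HermesToolParser):
        result = parser.parse_full(raw)
        content, calls = result.content, result.tool_calls
    else:
        content, calls = parser.parse(raw)

    for c in calls:
        print(c.name, c.arguments)  # forwarded as backend_pb2.ToolCallDelta fields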
Build wiring:
- Root `Makefile`: `.NOTPARALLEL`, `prepare-test-extra`, `test-extra`,
`BACKEND_TINYGRAD = tinygrad|python|.|false|true`,
docker-build-target eval, and `docker-build-backends` aggregator.
- `.github/workflows/backend.yml`: cpu / cuda12 / cuda13 build-matrix
  entries (mirroring the transformers backend placement).
- `backend/index.yaml`: `&tinygrad` meta + cpu/cuda12/cuda13 image
entries (latest + development).
E2E test wiring:
- `tests/e2e-backends/backend_test.go` gains an `image` capability that
exercises GenerateImage and asserts a non-empty PNG is written to
`dst`. New `BACKEND_TEST_IMAGE_PROMPT` / `BACKEND_TEST_IMAGE_STEPS`
knobs.
- Five new make targets next to `test-extra-backend-vllm`:
- `test-extra-backend-tinygrad` — Qwen2.5-0.5B-Instruct + hermes,
mirrors the vllm target 1:1 (5/9 specs in ~57s).
- `test-extra-backend-tinygrad-embeddings` — same model, embeddings
via LLM hidden state (3/9 in ~10s).
- `test-extra-backend-tinygrad-sd` — stable-diffusion-v1-5 mirror,
health/load/image (3/9 in ~10min, 4 diffusion steps on CPU).
- `test-extra-backend-tinygrad-whisper` — openai/whisper-tiny.en
against jfk.wav from whisper.cpp samples (4/9 in ~49s).
- `test-extra-backend-tinygrad-all` aggregate.
All four targets land green on the first MVP pass: 15 specs total, 0
failures across LLM+tools, embeddings, image generation, and speech
transcription.
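The embeddings path exercised above (mean-pooled last hidden state) is just
pooling plus L2 normalization. A minimal tinygrad sketch, with shapes assumed
from the Embedding handler in `backend.py` and a random tensor standing in
for `model.embed(tokens)`:

    from tinygrad import Tensor, dtypes

    hidden = Tensor.randn(1, 12, 64)          # stand-in for the (1, seqlen, dim) hidden state
    pooled = hidden.mean(axis=1).squeeze(0)   # (dim,)
    norm = pooled.square().sum().sqrt()
    vec = (pooled / (norm + 1e-12)).cast(dtypes.float32).tolist()
    print(len(vec))                           # dim-sized, unit-length embedding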
* refactor(tinygrad): collapse to a single backend image
tinygrad generates its own GPU kernels (PTX renderer for CUDA, the
autogen ctypes wrappers for HIP / Metal / WebGPU) and never links
against cuDNN, cuBLAS, or any toolkit-version-tied library. The only
runtime dependency that varies across hosts is the driver's libcuda.so.1
/ libamdhip64.so, which are injected into the container at run time by
the nvidia-container / rocm runtimes. So unlike torch- or vLLM-based
backends, there is no reason to ship per-CUDA-version images.
- Drop the cuda12-tinygrad and cuda13-tinygrad build-matrix entries
from .github/workflows/backend.yml. The sole remaining entry is
renamed to -tinygrad (from -cpu-tinygrad) since it is no longer
CPU-only.
- Collapse backend/index.yaml to a single meta + development pair.
The meta anchor carries the latest uri directly; the development
entry points at the master tag.
- run.sh picks the tinygrad device at launch time by probing
  /usr/lib/... for libcuda.so.1 / libamdhip64.so. When libcuda is
  visible we set CUDA=1 + CUDA_PTX=1 so tinygrad uses its own PTX
  renderer (avoids any nvrtc/toolkit dependency); otherwise we fall
  back to HIP or CLANG. CPU_LLVM=1 + LLVM_PATH keep the in-process
  libLLVM JIT for the CLANG path. A Python sketch of this probe follows
  the list.
- backend.py's _select_tinygrad_device() is trimmed to a CLANG-only
fallback since production device selection happens in run.sh.
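To make the probe order concrete, here is a rough Python transcription of
what run.sh does in shell; the library search path and the LLVM_PATH value
are illustrative assumptions, not the exact ones shipped in the image:

    import glob
    import os

    def pick_tinygrad_device(env=os.environ) -> None:
        def visible(lib: str) -> bool:
            # run.sh checks the usual driver install locations; a recursive glob is a stand-in.
            return bool(glob.glob(f"/usr/lib/**/{lib}", recursive=True))

        if visible("libcuda.so.1"):
            env["CUDA"] = "1"
            env["CUDA_PTX"] = "1"   # tinygrad's own PTX renderer, no nvrtc/toolkit needed
        elif visible("libamdhip64.so"):
            env["HIP"] = "1"
        else:
            env["CLANG"] = "1"
            env["CPU_LLVM"] = "1"   # in-process libLLVM JIT instead of shelling out to clang
            env["LLVM_PATH"] = "/usr/lib/libLLVM.so.1"  # hypothetical path

    pick_tinygrad_device()  # must run before `import tinygrad`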
Re-ran test-extra-backend-tinygrad after the change:
Ran 5 of 9 Specs in 56.541 seconds — 5 Passed, 0 Failed
backend/python/tinygrad/backend.py (751 lines, 30 KiB, Python):
#!/usr/bin/env python3
"""
LocalAI gRPC backend for tinygrad.

Multimodal scope:
- LLM text generation (Llama 3 / Qwen 2.5 / Mistral via HF safetensors + GGUF)
- Native tool-call extraction via pluggable parsers (hermes, llama3_json, ...)
- Embeddings (mean-pooled last hidden state), Stable Diffusion 1.x image
  generation, and Whisper speech-to-text.

The heavy imports (tinygrad, tokenizers, vendor.llama) are deferred until
`LoadModel`, because tinygrad binds its compute device at import time from
env vars. In production run.sh picks the device; `_select_tinygrad_device()`
is only a CLANG fallback for direct invocations such as the unit tests.
"""
from __future__ import annotations

import argparse
import asyncio
import json
import os
import signal
import sys
import time
from concurrent import futures
from pathlib import Path
from typing import Any, Optional

import grpc

import backend_pb2
import backend_pb2_grpc

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors  # noqa: E402

from tool_parsers import resolve_parser  # noqa: E402
from tool_parsers.base import ToolCall  # noqa: E402

MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


# ---------------------------------------------------------------------------
# Device selection — must run BEFORE `import tinygrad` anywhere.
#
# In production this is set by run.sh based on which driver libraries the
# host has injected into the container (libcuda.so.1 → CUDA, libamdhip64
# → HIP, otherwise CLANG). This helper is only a fallback for direct
# invocations like the unit tests.
# ---------------------------------------------------------------------------

def _select_tinygrad_device() -> None:
    if any(os.environ.get(k) == "1" for k in ("CUDA", "HIP", "METAL", "CLANG", "AMD", "NV")):
        return
    os.environ["CLANG"] = "1"


# ---------------------------------------------------------------------------
# Model asset discovery
# ---------------------------------------------------------------------------

def _resolve_model_assets(model_ref: str) -> Path:
    """
    Accept either a local path or a HuggingFace repo id (e.g.
    "Qwen/Qwen2.5-0.5B-Instruct") and return the local directory / file.
    HF ids are materialized via `huggingface_hub.snapshot_download`.
    """
    p = Path(model_ref)
    if p.exists():
        return p
    if "/" in model_ref and not model_ref.startswith(("/", ".")):
        from huggingface_hub import snapshot_download
        local = snapshot_download(
            repo_id=model_ref,
            allow_patterns=[
                "config.json",
                "tokenizer.json",
                "tokenizer_config.json",
                "special_tokens_map.json",
                "generation_config.json",
                "*.safetensors",
                "*.safetensors.index.json",
            ],
        )
        return Path(local)
    raise FileNotFoundError(f"Model not found: {model_ref}")


def _load_llm_weights(model_dir: Path):
    """Load HF safetensors weights from a directory or a single file."""
    from tinygrad import Device, Tensor, dtypes
    from tinygrad.nn.state import safe_load, gguf_load

    if model_dir.is_file():
        if str(model_dir).endswith(".gguf"):
            gguf_tensor = Tensor.empty(os.stat(model_dir).st_size, dtype=dtypes.uint8,
                                       device=f"disk:{model_dir}").to(Device.DEFAULT)
            return gguf_load(gguf_tensor)[1], "gguf"
        return safe_load(str(model_dir)), "safetensors"

    index = model_dir / "model.safetensors.index.json"
    if index.exists():
        with open(index) as fp:
            weight_map = json.load(fp)["weight_map"]
        shards: dict[str, Any] = {}
        for shard_name in set(weight_map.values()):
            shards[shard_name] = safe_load(str(model_dir / shard_name))
        merged = {k: shards[n][k] for k, n in weight_map.items()}
        return merged, "safetensors"

    single = model_dir / "model.safetensors"
    if single.exists():
        return safe_load(str(single)), "safetensors"

    ggufs = sorted(model_dir.glob("*.gguf"))
    if ggufs:
        gguf_tensor = Tensor.empty(os.stat(ggufs[0]).st_size, dtype=dtypes.uint8,
                                   device=f"disk:{ggufs[0]}").to(Device.DEFAULT)
        return gguf_load(gguf_tensor)[1], "gguf"

    raise FileNotFoundError(f"No safetensors or gguf weights found under {model_dir}")


def _infer_qkv_bias(config: dict) -> bool:
    """Qwen2-family uses Q/K/V projection bias; Llama/Mistral do not."""
    architectures = [a.lower() for a in config.get("architectures", [])]
    if any("qwen" in a for a in architectures):
        return True
    return bool(config.get("attention_bias", False)) or bool(config.get("qkv_bias", False))


def _auto_tool_parser(model_ref: Optional[str], config: dict) -> Optional[str]:
    """Pick a tool parser automatically from model family heuristics.

    Order of precedence: architecture name from config.json, then model ref
    string. Returns None to fall through to the passthrough parser.
    """
    arches = " ".join(a.lower() for a in config.get("architectures", []))
    ref = (model_ref or "").lower()
    blob = f"{arches} {ref}"

    if "qwen3" in blob:
        return "qwen3_xml"
    if "hermes" in blob or "qwen2" in blob or "qwen" in blob:
        return "hermes"
    if "llama-3" in blob or "llama_3" in blob or "llama3" in blob:
        return "llama3_json"
    if "mistral" in blob or "mixtral" in blob:
        return "mistral"
    return None


# ---------------------------------------------------------------------------
# Servicer
# ---------------------------------------------------------------------------

class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer for the tinygrad backend."""

    def __init__(self) -> None:
        self._reset_state()

    def _reset_state(self) -> None:
        self.model_ref: Optional[str] = None
        self.model_type: str = "llm"
        self.options: dict[str, str] = {}
        # LLM state
        self.llm_model = None
        self.llm_config: dict = {}
        self.llm_tokenizer = None
        self.llm_eos_ids: list[int] = []
        self.chat_template: Optional[str] = None
        self.tool_parser = resolve_parser(None)
        self.max_context = 4096
        # Stable Diffusion state
        self.sd_model = None
        # Whisper state
        self.whisper_model = None
        self.whisper_tokenizer = None

    # --------------------- helpers --------------------------------------

    @staticmethod
    def _parse_options(options_list) -> dict[str, str]:
        opts: dict[str, str] = {}
        for opt in options_list:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            opts[key.strip()] = value.strip()
        return opts

    @staticmethod
    def _detect_model_type(model_ref: str, explicit: Optional[str]) -> str:
        if explicit:
            return explicit
        name = (model_ref or "").lower()
        if "whisper" in name:
            return "whisper"
        if "sdxl" in name:
            return "sdxl"
        if "sd-v1" in name or "v1-5" in name or "stable-diffusion" in name:
            return "sd15"
        if any(tag in name for tag in ("bge", "e5", "minilm", "bert")):
            return "bert"
        return "llm"

    def _messages_to_dicts(self, messages) -> list[dict]:
        result = []
        for msg in messages:
            d: dict = {"role": msg.role, "content": msg.content or ""}
            if msg.name:
                d["name"] = msg.name
            if msg.tool_call_id:
                d["tool_call_id"] = msg.tool_call_id
            if msg.reasoning_content:
                d["reasoning_content"] = msg.reasoning_content
            if msg.tool_calls:
                try:
                    d["tool_calls"] = json.loads(msg.tool_calls)
                except json.JSONDecodeError:
                    pass
            result.append(d)
        return result

    def _render_prompt(self, request) -> str:
        """Render messages + tools into the model's chat template, or fall
        back to the raw Prompt field for models without a template."""
        if not request.Messages and request.Prompt:
            return request.Prompt

        if not self.chat_template:
            # No template known — concatenate role/content lines.
            lines = []
            for msg in request.Messages:
                lines.append(f"{msg.role}: {msg.content or ''}")
            return "\n".join(lines) + "\nassistant:"

        from jinja2 import Environment

        env = Environment(trim_blocks=True, lstrip_blocks=True)
        template = env.from_string(self.chat_template)

        tools = None
        if request.Tools:
            try:
                tools = json.loads(request.Tools)
            except json.JSONDecodeError:
                tools = None

        return template.render(
            messages=self._messages_to_dicts(request.Messages),
            tools=tools,
            add_generation_prompt=True,
        )

    # --------------------- LLM path -------------------------------------

    def _load_llm(self, model_dir: Path) -> None:
        config_path = model_dir / "config.json" if model_dir.is_dir() else None
        if config_path is None or not config_path.exists():
            raise FileNotFoundError(f"config.json not found under {model_dir}")
        with open(config_path) as fp:
            config = json.load(fp)
        self.llm_config = config

        from tinygrad import nn

        from vendor.llama import Transformer, convert_from_huggingface, convert_from_gguf, fix_bf16

        n_layers = config["num_hidden_layers"]
        n_heads = config["num_attention_heads"]
        n_kv_heads = config.get("num_key_value_heads", n_heads)
        dim = config["hidden_size"]
        hidden_dim = config["intermediate_size"]
        vocab_size = config["vocab_size"]
        norm_eps = config.get("rms_norm_eps", 1e-5)
        rope_theta = config.get("rope_theta", 10000)
        self.max_context = min(config.get("max_position_embeddings", 4096), 8192)
        qkv_bias = _infer_qkv_bias(config)

        weights, weight_format = _load_llm_weights(model_dir)

        model = Transformer(
            dim=dim,
            hidden_dim=hidden_dim,
            n_heads=n_heads,
            n_layers=n_layers,
            norm_eps=norm_eps,
            vocab_size=vocab_size,
            linear=nn.Linear,
            embedding=nn.Embedding,
            n_kv_heads=n_kv_heads,
            rope_theta=rope_theta,
            max_context=self.max_context,
            jit=True,
            qkv_bias=qkv_bias,
        )

        if weight_format == "safetensors":
            weights = convert_from_huggingface(weights, n_layers, n_heads, n_kv_heads)
        else:
            weights = convert_from_gguf(weights, n_layers)
        weights = fix_bf16(weights)

        from tinygrad.nn.state import load_state_dict
        load_state_dict(model, weights, strict=False, consume=True)
        self.llm_model = model

        # Tokenizer
        tokenizer_json = model_dir / "tokenizer.json"
        if not tokenizer_json.exists():
            raise FileNotFoundError(f"tokenizer.json not found under {model_dir}")
        from tokenizers import Tokenizer as HFTokenizer
        self.llm_tokenizer = HFTokenizer.from_file(str(tokenizer_json))

        # Chat template + EOS ids (from tokenizer_config.json + generation_config.json)
        tok_cfg_path = model_dir / "tokenizer_config.json"
        if tok_cfg_path.exists():
            with open(tok_cfg_path) as fp:
                tok_cfg = json.load(fp)
            self.chat_template = tok_cfg.get("chat_template")

        self.llm_eos_ids = []
        gen_cfg_path = model_dir / "generation_config.json"
        if gen_cfg_path.exists():
            with open(gen_cfg_path) as fp:
                gen_cfg = json.load(fp)
            eos = gen_cfg.get("eos_token_id")
            if isinstance(eos, list):
                self.llm_eos_ids.extend(int(x) for x in eos)
            elif isinstance(eos, int):
                self.llm_eos_ids.append(eos)
        if not self.llm_eos_ids:
            eos = config.get("eos_token_id")
            if isinstance(eos, list):
                self.llm_eos_ids.extend(int(x) for x in eos)
            elif isinstance(eos, int):
                self.llm_eos_ids.append(eos)

        # Auto-pick tool parser from options or model family.
        parser_name = self.options.get("tool_parser") or _auto_tool_parser(self.model_ref, config)
        self.tool_parser = resolve_parser(parser_name)

    # --------------------- Stable Diffusion path ------------------------

    def _load_sd(self, model_ref: str) -> None:
        """Load a Stable Diffusion 1.x checkpoint (CompVis `.ckpt` format)."""
        from huggingface_hub import hf_hub_download
        from tinygrad.nn.state import load_state_dict, torch_load

        from vendor.stable_diffusion import StableDiffusion

        ckpt_path = Path(model_ref)
        if not ckpt_path.exists():
            # Accept an HF repo id — fetch the canonical v1-5-pruned-emaonly.ckpt
            # from the requested repo. Common case is runwayml/stable-diffusion-v1-5.
            repo_id = model_ref if "/" in model_ref else "runwayml/stable-diffusion-v1-5"
            ckpt_file = self.options.get("sd_ckpt_filename", "v1-5-pruned-emaonly.ckpt")
            ckpt_path = Path(hf_hub_download(repo_id=repo_id, filename=ckpt_file))

        model = StableDiffusion()
        state_dict = torch_load(str(ckpt_path))
        if isinstance(state_dict, dict) and "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        load_state_dict(model, state_dict, strict=False, verbose=False, realize=False)
        self.sd_model = model

    # --------------------- Whisper path ---------------------------------

    def _load_whisper(self, model_ref: str) -> None:
        """Load a Whisper checkpoint (OpenAI `.pt` format).

        Accepts a model-size alias (tiny / tiny.en / base / base.en / small /
        small.en) OR an explicit `.pt` file path OR the HF repo id naming
        convention `openai/whisper-*` (mapped to the matching OpenAI alias).
        """
        from vendor.whisper import init_whisper, MODEL_URLS

        alias = model_ref
        if "/" in alias and alias.startswith("openai/whisper-"):
            alias = alias.removeprefix("openai/whisper-")
        if alias not in MODEL_URLS:
            # Explicit path to a .pt checkpoint — fall back to size heuristic
            # via filename.
            basename = Path(alias).name.lower()
            for name in MODEL_URLS:
                if name in basename:
                    alias = name
                    break
            else:
                raise ValueError(
                    f"Unknown Whisper model_ref={model_ref!r}; expected one of {list(MODEL_URLS)} "
                    f"or an openai/whisper-* HF id"
                )

        model, enc = init_whisper(alias, batch_size=1)
        self.whisper_model = model
        self.whisper_tokenizer = enc

    # --------------------- LLM generation -------------------------------

    def _generate_tokens(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int):
        """Synchronous generator yielding (token_id, token_text) pairs."""
        from tinygrad import Tensor

        encoded = self.llm_tokenizer.encode(prompt)
        ids = encoded.ids
        if not ids:
            return

        # Prefill: run one forward pass over the full prompt, sampling from the last token.
        prompt_tensor = Tensor([ids])
        next_tok = int(self.llm_model(prompt_tensor, 0, temperature=temperature, top_k=top_k, top_p=top_p).item())
        pos = len(ids)

        for _ in range(max_new_tokens):
            if next_tok in self.llm_eos_ids:
                break
            text = self.llm_tokenizer.decode([next_tok])
            yield next_tok, text
            next_tok = int(self.llm_model(Tensor([[next_tok]]), pos, temperature=temperature,
                                          top_k=top_k, top_p=top_p).item())
            pos += 1

    # --------------------- gRPC methods ---------------------------------

    def Health(self, request, context):
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        try:
            _select_tinygrad_device()
            self._reset_state()
            self.options = self._parse_options(list(request.Options))
            self.model_ref = request.ModelFile or request.Model
            self.model_type = self._detect_model_type(self.model_ref, self.options.get("model_type"))

            if self.model_type in ("sd15", "sd", "stable-diffusion"):
                self._load_sd(self.model_ref)
                return backend_pb2.Result(
                    success=True, message="tinygrad Stable Diffusion 1.x loaded",
                )

            if self.model_type == "whisper":
                self._load_whisper(self.model_ref)
                return backend_pb2.Result(
                    success=True, message="tinygrad Whisper loaded",
                )

            if self.model_type != "llm":
                return backend_pb2.Result(
                    success=False,
                    message=f"tinygrad: model_type={self.model_type} not yet implemented",
                )

            model_path = _resolve_model_assets(self.model_ref)
            if model_path.is_file():
                return backend_pb2.Result(
                    success=False,
                    message=("tinygrad currently requires an HF model directory with config.json + "
                             "tokenizer.json; single-file GGUF support lands with gallery metadata wiring."),
                )
            self._load_llm(model_path)

            return backend_pb2.Result(
                success=True,
                message=f"tinygrad LLM loaded (arch={self.llm_config.get('architectures', ['?'])[0]}, "
                        f"parser={self.tool_parser.name})",
            )
        except Exception as exc:
            import traceback
            traceback.print_exc()
            return backend_pb2.Result(success=False, message=f"LoadModel failed: {exc}")

    async def Predict(self, request, context):
        if self.llm_model is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("LLM not loaded")
            return backend_pb2.Reply()

        try:
            prompt = self._render_prompt(request)
            max_new = request.Tokens if request.Tokens > 0 else 256
            temperature = request.Temperature if request.Temperature > 0 else 0.7
            top_p = request.TopP if request.TopP > 0 else 0.95
            top_k = request.TopK if request.TopK > 0 else 0

            t0 = time.monotonic()
            pieces: list[str] = []
            ntok = 0
            for _, text in self._generate_tokens(prompt, max_new, temperature, top_p, top_k):
                pieces.append(text)
                ntok += 1
            elapsed = time.monotonic() - t0

            full = "".join(pieces)
            from tool_parsers.hermes import HermesToolParser
            if isinstance(self.tool_parser, HermesToolParser):
                result = self.tool_parser.parse_full(full)
                content, calls, reasoning = result.content, result.tool_calls, result.reasoning
            else:
                content, calls = self.tool_parser.parse(full)
                reasoning = ""

            delta = backend_pb2.ChatDelta(
                content=content,
                reasoning_content=reasoning,
                tool_calls=[
                    backend_pb2.ToolCallDelta(index=c.index, id=c.id, name=c.name, arguments=c.arguments)
                    for c in calls
                ],
            )
            return backend_pb2.Reply(
                message=content.encode("utf-8"),
                tokens=ntok,
                timing_token_generation=elapsed,
                chat_deltas=[delta],
            )
        except Exception as exc:
            import traceback
            traceback.print_exc()
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Predict failed: {exc}")
            return backend_pb2.Reply()

    async def PredictStream(self, request, context):
        if self.llm_model is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("LLM not loaded")
            return

        try:
            prompt = self._render_prompt(request)
            max_new = request.Tokens if request.Tokens > 0 else 256
            temperature = request.Temperature if request.Temperature > 0 else 0.7
            top_p = request.TopP if request.TopP > 0 else 0.95
            top_k = request.TopK if request.TopK > 0 else 0

            buffer = ""
            for _, text in self._generate_tokens(prompt, max_new, temperature, top_p, top_k):
                buffer += text
                yield backend_pb2.Reply(
                    message=text.encode("utf-8"),
                    chat_deltas=[backend_pb2.ChatDelta(content=text)],
                )

            # Final emission carries the extracted tool calls (vLLM semantics).
            from tool_parsers.hermes import HermesToolParser
            if isinstance(self.tool_parser, HermesToolParser):
                result = self.tool_parser.parse_full(buffer)
                calls = result.tool_calls
                reasoning = result.reasoning
            else:
                _, calls = self.tool_parser.parse(buffer)
                reasoning = ""

            if calls or reasoning:
                yield backend_pb2.Reply(
                    chat_deltas=[backend_pb2.ChatDelta(
                        reasoning_content=reasoning,
                        tool_calls=[
                            backend_pb2.ToolCallDelta(index=c.index, id=c.id, name=c.name, arguments=c.arguments)
                            for c in calls
                        ],
                    )],
                )
        except Exception as exc:
            import traceback
            traceback.print_exc()
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"PredictStream failed: {exc}")

    async def Embedding(self, request, context):
        if self.llm_model is None or self.llm_tokenizer is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("No model loaded")
            return backend_pb2.EmbeddingResult()

        try:
            text = request.Embeddings
            if not text:
                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
                context.set_details("Embeddings field is empty")
                return backend_pb2.EmbeddingResult()

            from tinygrad import Tensor, dtypes

            ids = self.llm_tokenizer.encode(text).ids
            if not ids:
                return backend_pb2.EmbeddingResult(embeddings=[])

            # Clamp to context window — truncate long inputs rather than blow up.
            ids = ids[: self.max_context]
            tokens = Tensor([ids])

            hidden = self.llm_model.embed(tokens)  # (1, seqlen, dim)
            # Mean pool over sequence dim
            pooled = hidden.mean(axis=1).squeeze(0)  # (dim,)
            # L2 normalize
            norm = pooled.square().sum().sqrt()
            normalized = (pooled / (norm + 1e-12))
            vec = normalized.cast(dtypes.float32).tolist()

            return backend_pb2.EmbeddingResult(embeddings=[float(x) for x in vec])
        except Exception as exc:
            import traceback
            traceback.print_exc()
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Embedding failed: {exc}")
            return backend_pb2.EmbeddingResult()

    async def GenerateImage(self, request, context):
        if self.sd_model is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("No Stable Diffusion model loaded")
            return backend_pb2.Result(success=False, message="not loaded")

        try:
            from PIL import Image
            from vendor.stable_diffusion import run_sd15

            steps = request.step if request.step > 0 else 20
            guidance = 7.5
            seed = request.seed if request.seed != 0 else None
            img_tensor = run_sd15(
                model=self.sd_model,
                prompt=request.positive_prompt or "",
                negative_prompt=request.negative_prompt or "",
                steps=steps,
                guidance=guidance,
                seed=seed,
            )
            arr = img_tensor.numpy()
            image = Image.fromarray(arr)
            dst = request.dst or "/tmp/tinygrad_image.png"
            image.save(dst)
            return backend_pb2.Result(success=True, message=dst)
        except Exception as exc:
            import traceback
            traceback.print_exc()
            return backend_pb2.Result(success=False, message=f"GenerateImage failed: {exc}")

    def _transcribe(self, audio_path: str, language: Optional[str]) -> tuple[str, float]:
        from vendor.whisper import load_file_waveform, transcribe_waveform

        waveform = load_file_waveform(audio_path)
        text = transcribe_waveform(
            self.whisper_model,
            self.whisper_tokenizer,
            [waveform],
            language=language or None,
        )
        duration = float(len(waveform)) / 16000.0
        return text, duration

    async def AudioTranscription(self, request, context):
        if self.whisper_model is None or self.whisper_tokenizer is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("No Whisper model loaded")
            return backend_pb2.TranscriptResult()

        try:
            if not request.dst:
                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
                context.set_details("TranscriptRequest.dst (audio file path) is required")
                return backend_pb2.TranscriptResult()

            text, duration = self._transcribe(request.dst, request.language)
            segments = [backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text)]
            return backend_pb2.TranscriptResult(
                text=text,
                language=request.language or "en",
                duration=duration,
                segments=segments,
            )
        except Exception as exc:
            import traceback
            traceback.print_exc()
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"AudioTranscription failed: {exc}")
            return backend_pb2.TranscriptResult()

    async def AudioTranscriptionStream(self, request, context):
        if self.whisper_model is None or self.whisper_tokenizer is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("No Whisper model loaded")
            return

        try:
            if not request.dst:
                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
                context.set_details("TranscriptRequest.dst (audio file path) is required")
                return

            # The vendored tinygrad whisper loop is chunked at the file level
            # (one inference pass per 30s segment), not token-level. To still
            # produce a streaming response we run the full transcription and
            # emit it as a single delta + a final-result envelope so the client
            # gets both code paths exercised.
            text, duration = self._transcribe(request.dst, request.language)
            yield backend_pb2.TranscriptStreamResponse(delta=text)
            final = backend_pb2.TranscriptResult(
                text=text,
                language=request.language or "en",
                duration=duration,
                segments=[backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text)],
            )
            yield backend_pb2.TranscriptStreamResponse(final_result=final)
        except Exception as exc:
            import traceback
            traceback.print_exc()
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"AudioTranscriptionStream failed: {exc}")

    async def Status(self, request, context):
        return backend_pb2.StatusResponse(state=backend_pb2.StatusResponse.READY)

    async def Free(self, request, context):
        self._reset_state()
        return backend_pb2.Result(success=True, message="freed")


async def serve(address):
    server = grpc.aio.server(
        migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
        ],
        interceptors=get_auth_interceptors(aio=True),
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)

    loop = asyncio.get_event_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5)))

    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    await server.wait_for_termination()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the tinygrad gRPC backend.")
    parser.add_argument("--addr", default="localhost:50051", help="Bind address")
    args = parser.parse_args()
    asyncio.run(serve(args.addr))