feat(backend): add tinygrad multimodal backend (experimental) (#9364)

* feat(backend): add tinygrad multimodal backend

Wire tinygrad as a new Python backend covering LLM text generation with
native tool-call extraction, embeddings, Stable Diffusion 1.x image
generation, and Whisper speech-to-text from a single self-contained
container.

Backend (`backend/python/tinygrad/`):
- `backend.py` gRPC servicer with LLM Predict/PredictStream (auto-detects
  Llama / Qwen2 / Mistral architecture from `config.json`, supports
  safetensors and GGUF), Embedding via mean-pooled last hidden state,
  GenerateImage via the vendored SD1.x pipeline, AudioTranscription +
  AudioTranscriptionStream via the vendored Whisper inference loop, plus
  Tokenize / ModelMetadata / Status / Free.
- Vendored upstream model code under `vendor/` (MIT, headers preserved):
  llama.py with an added `qkv_bias` flag for Qwen2-family bias support
  and an `embed()` method that returns the last hidden state, plus
  clip.py, unet.py, stable_diffusion.py (trimmed to drop the MLPerf
  training branch that pulls `mlperf.initializers`), audio_helpers.py
  and whisper.py (trimmed to drop the pyaudio listener).
- Pluggable tool-call parsers under `tool_parsers/`: hermes (Qwen2.5 /
  Hermes), llama3_json (Llama 3.1+), qwen3_xml (Qwen 3), mistral
  (Mistral / Mixtral). Auto-selected from model architecture or `Options`.
- `install.sh` pins Python 3.11.14 (tinygrad >=0.12 needs >=3.11; the
  default portable python is 3.10).
- `package.sh` bundles libLLVM.so.1 + libedit/libtinfo/libgomp/libsndfile
  into the scratch image. `run.sh` sets `CPU_LLVM=1` and `LLVM_PATH` so
  tinygrad's CPU device uses the in-process libLLVM JIT instead of
  shelling out to the missing `clang` binary.
- Local unit tests for Health and the four parsers in `test.py`.
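The hermes-style extraction that these parsers perform can be sketched as follows. This is a minimal illustration only, assuming the Qwen2.5/Hermes `<tool_call>{...}</tool_call>` convention; the function name and return shape are hypothetical, not the backend's actual `tool_parsers` API:

```python
import json
import re

def parse_hermes_tool_calls(text: str):
    """Split generated text into plain content and JSON tool calls.

    Assumes the Hermes/Qwen2.5 convention where each call is emitted as
    <tool_call>{"name": ..., "arguments": {...}}</tool_call>.
    """
    calls = []
    for i, m in enumerate(re.finditer(r"<tool_call>\s*(\{.*?\})\s*</tool_call>",
                                      text, re.DOTALL)):
        try:
            obj = json.loads(m.group(1))
        except json.JSONDecodeError:
            continue  # malformed JSON: skip rather than guess
        calls.append({"index": i, "name": obj.get("name", ""),
                      "arguments": json.dumps(obj.get("arguments", {}))})
    # Whatever is left outside the tags is the plain assistant content.
    content = re.sub(r"<tool_call>.*?</tool_call>", "", text, flags=re.DOTALL).strip()
    return content, calls

content, calls = parse_hermes_tool_calls(
    'Sure.<tool_call>{"name": "get_weather", "arguments": {"city": "Rome"}}</tool_call>')
# content is "Sure.", calls holds one get_weather call
```

The real parsers additionally handle reasoning blocks and per-family tag formats (JSON for Llama 3.1+, XML for Qwen 3, bracketed calls for Mistral).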

Build wiring:
- Root `Makefile`: `.NOTPARALLEL`, `prepare-test-extra`, `test-extra`,
  `BACKEND_TINYGRAD = tinygrad|python|.|false|true`,
  docker-build-target eval, and `docker-build-backends` aggregator.
- `.github/workflows/backend.yml`: cpu / cuda12 / cuda13 build matrix
  entries (mirrors the transformers backend placement).
- `backend/index.yaml`: `&tinygrad` meta + cpu/cuda12/cuda13 image
  entries (latest + development).

E2E test wiring:
- `tests/e2e-backends/backend_test.go` gains an `image` capability that
  exercises GenerateImage and asserts a non-empty PNG is written to
  `dst`. New `BACKEND_TEST_IMAGE_PROMPT` / `BACKEND_TEST_IMAGE_STEPS`
  knobs.
- Five new make targets next to `test-extra-backend-vllm`:
  - `test-extra-backend-tinygrad` — Qwen2.5-0.5B-Instruct + hermes,
    mirrors the vllm target 1:1 (5/9 specs in ~57s).
  - `test-extra-backend-tinygrad-embeddings` — same model, embeddings
    via LLM hidden state (3/9 in ~10s).
  - `test-extra-backend-tinygrad-sd` — stable-diffusion-v1-5 mirror,
    health/load/image (3/9 in ~10min, 4 diffusion steps on CPU).
  - `test-extra-backend-tinygrad-whisper` — openai/whisper-tiny.en
    against jfk.wav from whisper.cpp samples (4/9 in ~49s).
  - `test-extra-backend-tinygrad-all` aggregate.
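The embeddings path exercised above (mean-pooled, L2-normalized last hidden state) boils down to the following; shown here as a small NumPy sketch for clarity, not the backend's actual tinygrad code:

```python
import numpy as np

def pool_embedding(hidden: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Mean-pool a (seq_len, dim) last-hidden-state matrix over the
    non-padding positions, then L2-normalize the pooled vector."""
    mask = mask.astype(hidden.dtype)[:, None]          # (seq_len, 1)
    pooled = (hidden * mask).sum(axis=0) / mask.sum()  # mean over real tokens
    norm = np.linalg.norm(pooled)
    return pooled / norm if norm > 0 else pooled

hidden = np.array([[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]])  # 3 tokens, dim 2
mask = np.array([1, 1, 0])                               # last token is padding
vec = pool_embedding(hidden, mask)
# mean of the two real rows is [0.5, 0.5]; after L2 norm both components are 1/sqrt(2)
```

L2-normalizing makes cosine similarity between two embeddings a plain dot product, which is why most embedding consumers expect it.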

All four individual targets land green on the first MVP pass: 15 specs
total, 0 failures across LLM+tools, embeddings, image generation, and
speech transcription.

* refactor(tinygrad): collapse to a single backend image

tinygrad generates its own GPU kernels (PTX renderer for CUDA, the
autogen ctypes wrappers for HIP / Metal / WebGPU) and never links
against cuDNN, cuBLAS, or any toolkit-version-tied library. The only
runtime dependency that varies across hosts is the driver's libcuda.so.1
/ libamdhip64.so, which are injected into the container at run time by
the nvidia-container / rocm runtimes. So unlike torch- or vLLM-based
backends, there is no reason to ship per-CUDA-version images.

- Drop the cuda12-tinygrad and cuda13-tinygrad build-matrix entries
  from .github/workflows/backend.yml. The sole remaining entry is
  renamed to -tinygrad (from -cpu-tinygrad) since it is no longer
  CPU-only.
- Collapse backend/index.yaml to a single meta + development pair.
  The meta anchor carries the latest uri directly; the development
  entry points at the master tag.
- run.sh picks the tinygrad device at launch time by probing
  /usr/lib/... for libcuda.so.1 / libamdhip64.so. When libcuda is
  visible we set CUDA=1 + CUDA_PTX=1 so tinygrad uses its own PTX
  renderer (avoids any nvrtc/toolkit dependency); otherwise we fall
  back to HIP or CLANG. CPU_LLVM=1 + LLVM_PATH keep the in-process
  libLLVM JIT for the CLANG path.
- backend.py's _select_tinygrad_device() is trimmed to a CLANG-only
  fallback since production device selection happens in run.sh.

Re-ran test-extra-backend-tinygrad after the change:
  Ran 5 of 9 Specs in 56.541 seconds — 5 Passed, 0 Failed
This commit is contained in:
Ettore Di Giacinto
2026-04-15 19:48:23 +02:00
committed by GitHub
parent 8487058673
commit 6f0051301b
29 changed files with 3265 additions and 9 deletions


@@ -105,6 +105,25 @@ jobs:
          dockerfile: "./backend/Dockerfile.python"
          context: "./"
          ubuntu-version: '2404'
        # tinygrad ships a single image — its CPU device uses bundled
        # libLLVM, and its CUDA / HIP / Metal devices dlopen the host
        # driver libraries at runtime via tinygrad's ctypes autogen
        # wrappers. There is no toolkit-version split because tinygrad
        # generates kernels itself (PTX renderer for CUDA) and never
        # links against cuDNN/cuBLAS/torch.
        - build-type: ''
          cuda-major-version: ""
          cuda-minor-version: ""
          platforms: 'linux/amd64'
          tag-latest: 'auto'
          tag-suffix: '-tinygrad'
          runs-on: 'ubuntu-latest'
          base-image: "ubuntu:24.04"
          skip-drivers: 'true'
          backend: "tinygrad"
          dockerfile: "./backend/Dockerfile.python"
          context: "./"
          ubuntu-version: '2404'
        - build-type: ''
          cuda-major-version: ""
          cuda-minor-version: ""

@@ -1,5 +1,5 @@
# Disable parallel execution for backend builds
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/tinygrad
GOCMD=go
GOTEST=$(GOCMD) test
@@ -432,6 +432,7 @@ prepare-test-extra: protogen-python
	$(MAKE) -C backend/python/whisperx
	$(MAKE) -C backend/python/ace-step
	$(MAKE) -C backend/python/trl
	$(MAKE) -C backend/python/tinygrad
	$(MAKE) -C backend/rust/kokoros kokoros-grpc
test-extra: prepare-test-extra
@@ -454,6 +455,7 @@ test-extra: prepare-test-extra
	$(MAKE) -C backend/python/whisperx test
	$(MAKE) -C backend/python/ace-step test
	$(MAKE) -C backend/python/trl test
	$(MAKE) -C backend/python/tinygrad test
	$(MAKE) -C backend/rust/kokoros test
##
@@ -550,6 +552,56 @@ test-extra-backend-vllm: docker-build-vllm
	BACKEND_TEST_OPTIONS=tool_parser:hermes \
	$(MAKE) test-extra-backend
## tinygrad mirrors the vllm target (same model, same caps, same parser) so
## the two backends are directly comparable. The LLM path covers Predict,
## streaming and native tool-call extraction. Companion targets below cover
## embeddings, Stable Diffusion and Whisper — run them individually or via
## the `test-extra-backend-tinygrad-all` aggregate.
test-extra-backend-tinygrad: docker-build-tinygrad
	BACKEND_IMAGE=local-ai-backend:tinygrad \
	BACKEND_TEST_MODEL_NAME=Qwen/Qwen2.5-0.5B-Instruct \
	BACKEND_TEST_CAPS=health,load,predict,stream,tools \
	BACKEND_TEST_OPTIONS=tool_parser:hermes \
	$(MAKE) test-extra-backend
## tinygrad — embeddings via LLM last-hidden-state pooling. Reuses the same
## Qwen2.5-0.5B-Instruct as the chat target so we don't need a separate BERT
## vendor; the Embedding RPC mean-pools and L2-normalizes the last-layer
## hidden state.
test-extra-backend-tinygrad-embeddings: docker-build-tinygrad
	BACKEND_IMAGE=local-ai-backend:tinygrad \
	BACKEND_TEST_MODEL_NAME=Qwen/Qwen2.5-0.5B-Instruct \
	BACKEND_TEST_CAPS=health,load,embeddings \
	$(MAKE) test-extra-backend
## tinygrad — Stable Diffusion 1.5. The original CompVis/runwayml repos have
## been gated, so we use the community-maintained mirror at
## stable-diffusion-v1-5/stable-diffusion-v1-5 with the EMA-only pruned
## checkpoint (~4.3GB). Step count is kept low (4) so a CPU-only run finishes
## in a few minutes; bump BACKEND_TEST_IMAGE_STEPS for higher quality.
test-extra-backend-tinygrad-sd: docker-build-tinygrad
	BACKEND_IMAGE=local-ai-backend:tinygrad \
	BACKEND_TEST_MODEL_NAME=stable-diffusion-v1-5/stable-diffusion-v1-5 \
	BACKEND_TEST_CAPS=health,load,image \
	$(MAKE) test-extra-backend
## tinygrad — Whisper. Loads OpenAI's tiny.en checkpoint (smallest at ~75MB)
## from the original azure CDN through tinygrad's `fetch` helper, and
## transcribes the canonical jfk.wav fixture from whisper.cpp's CI samples.
## Exercises both AudioTranscription and AudioTranscriptionStream.
test-extra-backend-tinygrad-whisper: docker-build-tinygrad
	BACKEND_IMAGE=local-ai-backend:tinygrad \
	BACKEND_TEST_MODEL_NAME=openai/whisper-tiny.en \
	BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
	BACKEND_TEST_CAPS=health,load,transcription \
	$(MAKE) test-extra-backend
test-extra-backend-tinygrad-all: \
	test-extra-backend-tinygrad \
	test-extra-backend-tinygrad-embeddings \
	test-extra-backend-tinygrad-sd \
	test-extra-backend-tinygrad-whisper
## mlx is Apple-Silicon-first — the MLX backend auto-detects the right tool
## parser from the chat template, so no tool_parser: option is needed (it
## would be ignored at runtime). Run this on macOS / arm64 with Metal; the
@@ -707,6 +759,7 @@ BACKEND_MLX_VLM = mlx-vlm|python|.|false|true
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
BACKEND_TRL = trl|python|.|false|true
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
BACKEND_TINYGRAD = tinygrad|python|.|false|true
# Rust backends
BACKEND_KOKOROS = kokoros|rust|.|false|true
@@ -778,6 +831,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_MLX_VLM)))
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
$(eval $(call generate-docker-build-target,$(BACKEND_TINYGRAD)))
$(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS)))
$(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
@@ -785,7 +839,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
docker-save-%: backend-images
	docker save local-ai-backend:$* -o backend-images/$*.tar
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp
docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp
########################################################
### Mock Backend for E2E Tests


@@ -361,6 +361,34 @@
intel: "intel-rerankers"
amd: "rocm-rerankers"
metal: "metal-rerankers"
- &tinygrad
  name: "tinygrad"
  alias: "tinygrad"
  license: MIT
  description: |
    tinygrad is a minimalist deep-learning framework with zero runtime
    dependencies that targets CUDA, ROCm, Metal, WebGPU and CPU (CLANG).
    The LocalAI tinygrad backend exposes a single multimodal runtime that
    covers LLM text generation (Llama / Qwen / Mistral via safetensors or
    GGUF) with native tool-call extraction, embeddings via the mean-pooled
    LLM hidden state, Stable Diffusion 1.x image generation, and Whisper
    speech-to-text.
    Single image: tinygrad generates its own GPU kernels and dlopens the
    host driver libraries at runtime, so there is no per-toolkit build
    split. The same image runs CPU-only or accelerates against
    CUDA / ROCm / Metal when the host driver is visible.
  urls:
    - https://github.com/tinygrad/tinygrad
  uri: "quay.io/go-skynet/local-ai-backends:latest-tinygrad"
  mirrors:
    - localai/localai-backends:latest-tinygrad
  tags:
    - text-to-text
    - LLM
    - embeddings
    - image-generation
    - transcription
    - multimodal
- &transformers
name: "transformers"
icon: https://avatars.githubusercontent.com/u/25720743?s=200&v=4
@@ -1993,6 +2021,15 @@
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-rerankers"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-rerankers
## tinygrad
## Single image — the meta anchor above carries the latest uri directly
## since there is only one variant. The development entry below points at
## the master tag.
- !!merge <<: *tinygrad
  name: "tinygrad-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-tinygrad"
  mirrors:
    - localai/localai-backends:master-tinygrad
## Transformers
- !!merge <<: *transformers
name: "transformers-development"


@@ -0,0 +1,25 @@
.DEFAULT_GOAL := install

.PHONY: install
install:
	bash install.sh

.PHONY: run
run: install
	@echo "Running tinygrad..."
	bash run.sh
	@echo "tinygrad run."

.PHONY: test
test: install
	@echo "Testing tinygrad..."
	bash test.sh
	@echo "tinygrad tested."

.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__


@@ -0,0 +1,750 @@
#!/usr/bin/env python3
"""
LocalAI gRPC backend for tinygrad.
Multimodal scope (landing incrementally):
- LLM text generation (Llama 3 / Qwen 2.5 / Mistral via HF safetensors + GGUF)
- Native tool-call extraction via pluggable parsers (hermes, llama3_json, ...)
- Embeddings, Stable Diffusion, Whisper — planned; currently return UNIMPLEMENTED.
The heavy imports (tinygrad, tokenizers, vendor.llama) are deferred until
`LoadModel`, because tinygrad binds its compute device at import time from
env vars. `_select_tinygrad_device()` maps LocalAI's BUILD_TYPE onto the
corresponding tinygrad env flag before any import happens.
"""
from __future__ import annotations

import argparse
import asyncio
import json
import os
import signal
import sys
import time
from concurrent import futures
from pathlib import Path
from typing import Any, Optional

import grpc

import backend_pb2
import backend_pb2_grpc

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))

from grpc_auth import get_auth_interceptors  # noqa: E402
from tool_parsers import resolve_parser  # noqa: E402
from tool_parsers.base import ToolCall  # noqa: E402

MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# ---------------------------------------------------------------------------
# Device selection — must run BEFORE `import tinygrad` anywhere.
#
# In production this is set by run.sh based on which driver libraries the
# host has injected into the container (libcuda.so.1 → CUDA, libamdhip64
# → HIP, otherwise CLANG). This helper is only a fallback for direct
# invocations like the unit tests.
# ---------------------------------------------------------------------------
def _select_tinygrad_device() -> None:
    if any(os.environ.get(k) == "1" for k in ("CUDA", "HIP", "METAL", "CLANG", "AMD", "NV")):
        return
    os.environ["CLANG"] = "1"
# ---------------------------------------------------------------------------
# Model asset discovery
# ---------------------------------------------------------------------------
def _resolve_model_assets(model_ref: str) -> Path:
    """
    Accept either a local path or a HuggingFace repo id (e.g.
    "Qwen/Qwen2.5-0.5B-Instruct") and return the local directory / file.
    HF ids are materialized via `huggingface_hub.snapshot_download`.
    """
    p = Path(model_ref)
    if p.exists():
        return p
    if "/" in model_ref and not model_ref.startswith(("/", ".")):
        from huggingface_hub import snapshot_download
        local = snapshot_download(
            repo_id=model_ref,
            allow_patterns=[
                "config.json",
                "tokenizer.json",
                "tokenizer_config.json",
                "special_tokens_map.json",
                "generation_config.json",
                "*.safetensors",
                "*.safetensors.index.json",
            ],
        )
        return Path(local)
    raise FileNotFoundError(f"Model not found: {model_ref}")
def _load_llm_weights(model_dir: Path):
    """Load HF safetensors or GGUF weights from a directory or a single file."""
    from tinygrad import Device, Tensor, dtypes
    from tinygrad.nn.state import safe_load, gguf_load

    if model_dir.is_file():
        if str(model_dir).endswith(".gguf"):
            gguf_tensor = Tensor.empty(os.stat(model_dir).st_size, dtype=dtypes.uint8,
                                       device=f"disk:{model_dir}").to(Device.DEFAULT)
            return gguf_load(gguf_tensor)[1], "gguf"
        return safe_load(str(model_dir)), "safetensors"
    index = model_dir / "model.safetensors.index.json"
    if index.exists():
        with open(index) as fp:
            weight_map = json.load(fp)["weight_map"]
        shards: dict[str, Any] = {}
        for shard_name in set(weight_map.values()):
            shards[shard_name] = safe_load(str(model_dir / shard_name))
        merged = {k: shards[n][k] for k, n in weight_map.items()}
        return merged, "safetensors"
    single = model_dir / "model.safetensors"
    if single.exists():
        return safe_load(str(single)), "safetensors"
    ggufs = sorted(model_dir.glob("*.gguf"))
    if ggufs:
        gguf_tensor = Tensor.empty(os.stat(ggufs[0]).st_size, dtype=dtypes.uint8,
                                   device=f"disk:{ggufs[0]}").to(Device.DEFAULT)
        return gguf_load(gguf_tensor)[1], "gguf"
    raise FileNotFoundError(f"No safetensors or gguf weights found under {model_dir}")
def _infer_qkv_bias(config: dict) -> bool:
    """Qwen2-family uses Q/K/V projection bias; Llama/Mistral do not."""
    architectures = [a.lower() for a in config.get("architectures", [])]
    if any("qwen" in a for a in architectures):
        return True
    return bool(config.get("attention_bias", False)) or bool(config.get("qkv_bias", False))
def _auto_tool_parser(model_ref: Optional[str], config: dict) -> Optional[str]:
    """Pick a tool parser automatically from model family heuristics.

    Order of precedence: architecture name from config.json, then model ref
    string. Returns None to fall through to the passthrough parser.
    """
    arches = " ".join(a.lower() for a in config.get("architectures", []))
    ref = (model_ref or "").lower()
    blob = f"{arches} {ref}"
    if "qwen3" in blob:
        return "qwen3_xml"
    if "hermes" in blob or "qwen2" in blob or "qwen" in blob:
        return "hermes"
    if "llama-3" in blob or "llama_3" in blob or "llama3" in blob:
        return "llama3_json"
    if "mistral" in blob or "mixtral" in blob:
        return "mistral"
    return None
# ---------------------------------------------------------------------------
# Servicer
# ---------------------------------------------------------------------------
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer for the tinygrad backend."""

    def __init__(self) -> None:
        self._reset_state()

    def _reset_state(self) -> None:
        self.model_ref: Optional[str] = None
        self.model_type: str = "llm"
        self.options: dict[str, str] = {}
        # LLM state
        self.llm_model = None
        self.llm_config: dict = {}
        self.llm_tokenizer = None
        self.llm_eos_ids: list[int] = []
        self.chat_template: Optional[str] = None
        self.tool_parser = resolve_parser(None)
        self.max_context = 4096
        # Stable Diffusion state
        self.sd_model = None
        # Whisper state
        self.whisper_model = None
        self.whisper_tokenizer = None

    # --------------------- helpers --------------------------------------
    @staticmethod
    def _parse_options(options_list) -> dict[str, str]:
        opts: dict[str, str] = {}
        for opt in options_list:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            opts[key.strip()] = value.strip()
        return opts
    @staticmethod
    def _detect_model_type(model_ref: str, explicit: Optional[str]) -> str:
        if explicit:
            return explicit
        name = (model_ref or "").lower()
        if "whisper" in name:
            return "whisper"
        if "sdxl" in name:
            return "sdxl"
        if "sd-v1" in name or "v1-5" in name or "stable-diffusion" in name:
            return "sd15"
        if any(tag in name for tag in ("bge", "e5", "minilm", "bert")):
            return "bert"
        return "llm"
    def _messages_to_dicts(self, messages) -> list[dict]:
        result = []
        for msg in messages:
            d: dict = {"role": msg.role, "content": msg.content or ""}
            if msg.name:
                d["name"] = msg.name
            if msg.tool_call_id:
                d["tool_call_id"] = msg.tool_call_id
            if msg.reasoning_content:
                d["reasoning_content"] = msg.reasoning_content
            if msg.tool_calls:
                try:
                    d["tool_calls"] = json.loads(msg.tool_calls)
                except json.JSONDecodeError:
                    pass
            result.append(d)
        return result
    def _render_prompt(self, request) -> str:
        """Render messages + tools into the model's chat template, or fall
        back to the raw Prompt field for models without a template."""
        if not request.Messages and request.Prompt:
            return request.Prompt
        if not self.chat_template:
            # No template known — concatenate role/content lines.
            lines = []
            for msg in request.Messages:
                lines.append(f"{msg.role}: {msg.content or ''}")
            return "\n".join(lines) + "\nassistant:"
        from jinja2 import Environment
        env = Environment(trim_blocks=True, lstrip_blocks=True)
        template = env.from_string(self.chat_template)
        tools = None
        if request.Tools:
            try:
                tools = json.loads(request.Tools)
            except json.JSONDecodeError:
                tools = None
        return template.render(
            messages=self._messages_to_dicts(request.Messages),
            tools=tools,
            add_generation_prompt=True,
        )
    # --------------------- LLM path -------------------------------------
    def _load_llm(self, model_dir: Path) -> None:
        config_path = model_dir / "config.json" if model_dir.is_dir() else None
        if config_path is None or not config_path.exists():
            raise FileNotFoundError(f"config.json not found under {model_dir}")
        with open(config_path) as fp:
            config = json.load(fp)
        self.llm_config = config

        from tinygrad import nn
        from vendor.llama import Transformer, convert_from_huggingface, convert_from_gguf, fix_bf16

        n_layers = config["num_hidden_layers"]
        n_heads = config["num_attention_heads"]
        n_kv_heads = config.get("num_key_value_heads", n_heads)
        dim = config["hidden_size"]
        hidden_dim = config["intermediate_size"]
        vocab_size = config["vocab_size"]
        norm_eps = config.get("rms_norm_eps", 1e-5)
        rope_theta = config.get("rope_theta", 10000)
        self.max_context = min(config.get("max_position_embeddings", 4096), 8192)
        qkv_bias = _infer_qkv_bias(config)

        weights, weight_format = _load_llm_weights(model_dir)
        model = Transformer(
            dim=dim,
            hidden_dim=hidden_dim,
            n_heads=n_heads,
            n_layers=n_layers,
            norm_eps=norm_eps,
            vocab_size=vocab_size,
            linear=nn.Linear,
            embedding=nn.Embedding,
            n_kv_heads=n_kv_heads,
            rope_theta=rope_theta,
            max_context=self.max_context,
            jit=True,
            qkv_bias=qkv_bias,
        )
        if weight_format == "safetensors":
            weights = convert_from_huggingface(weights, n_layers, n_heads, n_kv_heads)
        else:
            weights = convert_from_gguf(weights, n_layers)
        weights = fix_bf16(weights)

        from tinygrad.nn.state import load_state_dict
        load_state_dict(model, weights, strict=False, consume=True)
        self.llm_model = model

        # Tokenizer
        tokenizer_json = model_dir / "tokenizer.json"
        if not tokenizer_json.exists():
            raise FileNotFoundError(f"tokenizer.json not found under {model_dir}")
        from tokenizers import Tokenizer as HFTokenizer
        self.llm_tokenizer = HFTokenizer.from_file(str(tokenizer_json))

        # Chat template + EOS ids (from tokenizer_config.json + generation_config.json)
        tok_cfg_path = model_dir / "tokenizer_config.json"
        if tok_cfg_path.exists():
            with open(tok_cfg_path) as fp:
                tok_cfg = json.load(fp)
            self.chat_template = tok_cfg.get("chat_template")
        self.llm_eos_ids = []
        gen_cfg_path = model_dir / "generation_config.json"
        if gen_cfg_path.exists():
            with open(gen_cfg_path) as fp:
                gen_cfg = json.load(fp)
            eos = gen_cfg.get("eos_token_id")
            if isinstance(eos, list):
                self.llm_eos_ids.extend(int(x) for x in eos)
            elif isinstance(eos, int):
                self.llm_eos_ids.append(eos)
        if not self.llm_eos_ids:
            eos = config.get("eos_token_id")
            if isinstance(eos, list):
                self.llm_eos_ids.extend(int(x) for x in eos)
            elif isinstance(eos, int):
                self.llm_eos_ids.append(eos)

        # Auto-pick tool parser from options or model family.
        parser_name = self.options.get("tool_parser") or _auto_tool_parser(self.model_ref, config)
        self.tool_parser = resolve_parser(parser_name)
    # --------------------- Stable Diffusion path ------------------------
    def _load_sd(self, model_ref: str) -> None:
        """Load a Stable Diffusion 1.x checkpoint (CompVis `.ckpt` format)."""
        from huggingface_hub import hf_hub_download
        from tinygrad.nn.state import load_state_dict, torch_load
        from vendor.stable_diffusion import StableDiffusion

        ckpt_path = Path(model_ref)
        if not ckpt_path.exists():
            # Accept an HF repo id — fetch the canonical v1-5-pruned-emaonly.ckpt
            # from the requested repo. Common case is runwayml/stable-diffusion-v1-5.
            repo_id = model_ref if "/" in model_ref else "runwayml/stable-diffusion-v1-5"
            ckpt_file = self.options.get("sd_ckpt_filename", "v1-5-pruned-emaonly.ckpt")
            ckpt_path = Path(hf_hub_download(repo_id=repo_id, filename=ckpt_file))
        model = StableDiffusion()
        state_dict = torch_load(str(ckpt_path))
        if isinstance(state_dict, dict) and "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        load_state_dict(model, state_dict, strict=False, verbose=False, realize=False)
        self.sd_model = model
    # --------------------- Whisper path ---------------------------------
    def _load_whisper(self, model_ref: str) -> None:
        """Load a Whisper checkpoint (OpenAI `.pt` format).

        Accepts a model-size alias (tiny / tiny.en / base / base.en / small /
        small.en) OR an explicit `.pt` file path OR the HF repo id naming
        convention `openai/whisper-*` (mapped to the matching OpenAI alias).
        """
        from vendor.whisper import init_whisper, MODEL_URLS

        alias = model_ref
        if "/" in alias and alias.startswith("openai/whisper-"):
            alias = alias.removeprefix("openai/whisper-")
        if alias not in MODEL_URLS:
            # Explicit path to a .pt checkpoint — fall back to size heuristic
            # via filename.
            basename = Path(alias).name.lower()
            for name in MODEL_URLS:
                if name in basename:
                    alias = name
                    break
            else:
                raise ValueError(
                    f"Unknown Whisper model_ref={model_ref!r}; expected one of {list(MODEL_URLS)} "
                    f"or an openai/whisper-* HF id"
                )
        model, enc = init_whisper(alias, batch_size=1)
        self.whisper_model = model
        self.whisper_tokenizer = enc
    # --------------------- LLM generation -------------------------------
    def _generate_tokens(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int):
        """Synchronous generator yielding (token_id, token_text) pairs."""
        from tinygrad import Tensor

        encoded = self.llm_tokenizer.encode(prompt)
        ids = encoded.ids
        if not ids:
            return
        # Prefill: run one forward pass over the full prompt, sampling from the last token.
        prompt_tensor = Tensor([ids])
        next_tok = int(self.llm_model(prompt_tensor, 0, temperature=temperature, top_k=top_k, top_p=top_p).item())
        pos = len(ids)
        for _ in range(max_new_tokens):
            if next_tok in self.llm_eos_ids:
                break
            text = self.llm_tokenizer.decode([next_tok])
            yield next_tok, text
            next_tok = int(self.llm_model(Tensor([[next_tok]]), pos, temperature=temperature,
                                          top_k=top_k, top_p=top_p).item())
            pos += 1
    # --------------------- gRPC methods ---------------------------------
    def Health(self, request, context):
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        try:
            _select_tinygrad_device()
            self._reset_state()
            self.options = self._parse_options(list(request.Options))
            self.model_ref = request.ModelFile or request.Model
            self.model_type = self._detect_model_type(self.model_ref, self.options.get("model_type"))
            if self.model_type in ("sd15", "sd", "stable-diffusion"):
                self._load_sd(self.model_ref)
                return backend_pb2.Result(
                    success=True, message="tinygrad Stable Diffusion 1.x loaded",
                )
            if self.model_type == "whisper":
                self._load_whisper(self.model_ref)
                return backend_pb2.Result(
                    success=True, message="tinygrad Whisper loaded",
                )
            if self.model_type != "llm":
                return backend_pb2.Result(
                    success=False,
                    message=f"tinygrad: model_type={self.model_type} not yet implemented",
                )
            model_path = _resolve_model_assets(self.model_ref)
            if model_path.is_file():
                return backend_pb2.Result(
                    success=False,
                    message=("tinygrad currently requires an HF model directory with config.json + "
                             "tokenizer.json; single-file GGUF support lands with gallery metadata wiring."),
                )
            self._load_llm(model_path)
            return backend_pb2.Result(
                success=True,
                message=f"tinygrad LLM loaded (arch={self.llm_config.get('architectures', ['?'])[0]}, "
                        f"parser={self.tool_parser.name})",
            )
        except Exception as exc:
            import traceback
            traceback.print_exc()
            return backend_pb2.Result(success=False, message=f"LoadModel failed: {exc}")
    async def Predict(self, request, context):
        if self.llm_model is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("LLM not loaded")
            return backend_pb2.Reply()
        try:
            prompt = self._render_prompt(request)
            max_new = request.Tokens if request.Tokens > 0 else 256
            temperature = request.Temperature if request.Temperature > 0 else 0.7
            top_p = request.TopP if request.TopP > 0 else 0.95
            top_k = request.TopK if request.TopK > 0 else 0
            t0 = time.monotonic()
            pieces: list[str] = []
            ntok = 0
            for _, text in self._generate_tokens(prompt, max_new, temperature, top_p, top_k):
                pieces.append(text)
                ntok += 1
            elapsed = time.monotonic() - t0
            full = "".join(pieces)
            from tool_parsers.hermes import HermesToolParser
            if isinstance(self.tool_parser, HermesToolParser):
                result = self.tool_parser.parse_full(full)
                content, calls, reasoning = result.content, result.tool_calls, result.reasoning
            else:
                content, calls = self.tool_parser.parse(full)
                reasoning = ""
            delta = backend_pb2.ChatDelta(
                content=content,
                reasoning_content=reasoning,
                tool_calls=[
                    backend_pb2.ToolCallDelta(index=c.index, id=c.id, name=c.name, arguments=c.arguments)
                    for c in calls
                ],
            )
            return backend_pb2.Reply(
                message=content.encode("utf-8"),
                tokens=ntok,
                timing_token_generation=elapsed,
                chat_deltas=[delta],
            )
        except Exception as exc:
            import traceback
            traceback.print_exc()
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Predict failed: {exc}")
            return backend_pb2.Reply()
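Editorial note: the sampling defaults above (temperature 0.7, top-p 0.95, optional top-k) are passed straight into `_generate_tokens`, which is not shown in this hunk. A minimal sketch of how such knobs are conventionally applied to a single logits vector — function and variable names here are illustrative, not the backend's actual sampling code:

```python
import numpy as np

def sample_logits(logits, temperature=0.7, top_k=0, top_p=0.95, rng=None):
    """Temperature scaling, then optional top-k and nucleus (top-p) filtering."""
    rng = rng or np.random.default_rng(0)
    logits = np.asarray(logits, dtype=np.float64) / max(temperature, 1e-6)
    if top_k > 0:
        # keep only the k highest-scoring tokens
        cutoff = np.sort(logits)[-top_k]
        logits = np.where(logits < cutoff, -np.inf, logits)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    if 0 < top_p < 1:
        # nucleus: smallest prefix (by descending prob) with mass >= top_p
        order = np.argsort(probs)[::-1]
        csum = np.cumsum(probs[order])
        keep = order[: int(np.searchsorted(csum, top_p)) + 1]
        mask = np.zeros_like(probs)
        mask[keep] = probs[keep]
        probs = mask / mask.sum()
    return int(rng.choice(len(probs), p=probs))
```

With `top_k=1` the sampler degenerates to greedy decoding, which is a handy sanity check.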
    async def PredictStream(self, request, context):
        if self.llm_model is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("LLM not loaded")
            return
        try:
            prompt = self._render_prompt(request)
            max_new = request.Tokens if request.Tokens > 0 else 256
            temperature = request.Temperature if request.Temperature > 0 else 0.7
            top_p = request.TopP if request.TopP > 0 else 0.95
            top_k = request.TopK if request.TopK > 0 else 0
            buffer = ""
            for _, text in self._generate_tokens(prompt, max_new, temperature, top_p, top_k):
                buffer += text
                yield backend_pb2.Reply(
                    message=text.encode("utf-8"),
                    chat_deltas=[backend_pb2.ChatDelta(content=text)],
                )
            # Final emission carries the extracted tool calls (vLLM semantics).
            from tool_parsers.hermes import HermesToolParser
            if isinstance(self.tool_parser, HermesToolParser):
                result = self.tool_parser.parse_full(buffer)
                calls = result.tool_calls
                reasoning = result.reasoning
            else:
                _, calls = self.tool_parser.parse(buffer)
                reasoning = ""
            if calls or reasoning:
                yield backend_pb2.Reply(
                    chat_deltas=[backend_pb2.ChatDelta(
                        reasoning_content=reasoning,
                        tool_calls=[
                            backend_pb2.ToolCallDelta(index=c.index, id=c.id, name=c.name, arguments=c.arguments)
                            for c in calls
                        ],
                    )],
                )
        except Exception as exc:
            import traceback
            traceback.print_exc()
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"PredictStream failed: {exc}")
    async def Embedding(self, request, context):
        if self.llm_model is None or self.llm_tokenizer is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("No model loaded")
            return backend_pb2.EmbeddingResult()
        try:
            text = request.Embeddings
            if not text:
                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
                context.set_details("Embeddings field is empty")
                return backend_pb2.EmbeddingResult()
            from tinygrad import Tensor, dtypes
            ids = self.llm_tokenizer.encode(text).ids
            if not ids:
                return backend_pb2.EmbeddingResult(embeddings=[])
            # Clamp to context window — truncate long inputs rather than blow up.
            ids = ids[: self.max_context]
            tokens = Tensor([ids])
            hidden = self.llm_model.embed(tokens)  # (1, seqlen, dim)
            # Mean pool over sequence dim
            pooled = hidden.mean(axis=1).squeeze(0)  # (dim,)
            # L2 normalize
            norm = pooled.square().sum().sqrt()
            normalized = pooled / (norm + 1e-12)
            vec = normalized.cast(dtypes.float32).tolist()
            return backend_pb2.EmbeddingResult(embeddings=[float(x) for x in vec])
        except Exception as exc:
            import traceback
            traceback.print_exc()
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Embedding failed: {exc}")
            return backend_pb2.EmbeddingResult()
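Editorial note: the pooling above is the standard mean-pool-then-L2-normalize recipe. A numpy sketch of the same arithmetic, independent of tinygrad (names are illustrative):

```python
import numpy as np

def pool_and_normalize(hidden):
    """hidden: (1, seqlen, dim) last hidden state -> unit-norm (dim,) vector."""
    pooled = hidden.mean(axis=1).squeeze(0)   # mean over the sequence axis
    norm = np.sqrt((pooled ** 2).sum())
    return pooled / (norm + 1e-12)            # epsilon guards the zero vector

# two tokens with hidden states [3, 0] and [1, 0]:
vec = pool_and_normalize(np.array([[[3.0, 0.0], [1.0, 0.0]]]))
# mean over seq -> [2, 0]; L2 norm 2 -> [1, 0]
```

The unit norm means downstream cosine similarity reduces to a plain dot product.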
    async def GenerateImage(self, request, context):
        if self.sd_model is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("No Stable Diffusion model loaded")
            return backend_pb2.Result(success=False, message="not loaded")
        try:
            from PIL import Image
            from vendor.stable_diffusion import run_sd15
            steps = request.step if request.step > 0 else 20
            guidance = 7.5
            seed = request.seed if request.seed != 0 else None
            img_tensor = run_sd15(
                model=self.sd_model,
                prompt=request.positive_prompt or "",
                negative_prompt=request.negative_prompt or "",
                steps=steps,
                guidance=guidance,
                seed=seed,
            )
            arr = img_tensor.numpy()
            image = Image.fromarray(arr)
            dst = request.dst or "/tmp/tinygrad_image.png"
            image.save(dst)
            return backend_pb2.Result(success=True, message=dst)
        except Exception as exc:
            import traceback
            traceback.print_exc()
            return backend_pb2.Result(success=False, message=f"GenerateImage failed: {exc}")
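Editorial note: `Image.fromarray` expects a uint8 HWC array, so `run_sd15` is assumed to hand back one already. When a pipeline instead yields floats (a common convention is [0, 1]), the usual conversion step looks like this — a sketch under that assumption, not the vendored pipeline's code:

```python
import numpy as np

def to_uint8_image(x):
    # clamp to [0, 1], scale to [0, 255], round to nearest integer
    return (np.clip(x, 0.0, 1.0) * 255.0).round().astype(np.uint8)
```

Clamping first matters: samplers can overshoot the nominal range, and casting an out-of-range float straight to uint8 wraps around instead of saturating.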
    def _transcribe(self, audio_path: str, language: Optional[str]) -> tuple[str, float]:
        from vendor.whisper import load_file_waveform, transcribe_waveform
        waveform = load_file_waveform(audio_path)
        text = transcribe_waveform(
            self.whisper_model,
            self.whisper_tokenizer,
            [waveform],
            language=language or None,
        )
        duration = float(len(waveform)) / 16000.0
        return text, duration
    async def AudioTranscription(self, request, context):
        if self.whisper_model is None or self.whisper_tokenizer is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("No Whisper model loaded")
            return backend_pb2.TranscriptResult()
        try:
            if not request.dst:
                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
                context.set_details("TranscriptRequest.dst (audio file path) is required")
                return backend_pb2.TranscriptResult()
            text, duration = self._transcribe(request.dst, request.language)
            segments = [backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text)]
            return backend_pb2.TranscriptResult(
                text=text,
                language=request.language or "en",
                duration=duration,
                segments=segments,
            )
        except Exception as exc:
            import traceback
            traceback.print_exc()
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"AudioTranscription failed: {exc}")
            return backend_pb2.TranscriptResult()
    async def AudioTranscriptionStream(self, request, context):
        if self.whisper_model is None or self.whisper_tokenizer is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("No Whisper model loaded")
            return
        try:
            if not request.dst:
                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
                context.set_details("TranscriptRequest.dst (audio file path) is required")
                return
            # The vendored tinygrad whisper loop is chunked at the file level
            # (one inference pass per 30s segment), not token-level. To still
            # produce a streaming response we run the full transcription and
            # emit it as a single delta + a final-result envelope so the client
            # gets both code paths exercised.
            text, duration = self._transcribe(request.dst, request.language)
            yield backend_pb2.TranscriptStreamResponse(delta=text)
            final = backend_pb2.TranscriptResult(
                text=text,
                language=request.language or "en",
                duration=duration,
                segments=[backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text)],
            )
            yield backend_pb2.TranscriptStreamResponse(final_result=final)
        except Exception as exc:
            import traceback
            traceback.print_exc()
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"AudioTranscriptionStream failed: {exc}")
    async def Status(self, request, context):
        return backend_pb2.StatusResponse(state=backend_pb2.StatusResponse.READY)

    async def Free(self, request, context):
        self._reset_state()
        return backend_pb2.Result(success=True, message="freed")

async def serve(address):
    server = grpc.aio.server(
        migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
        ],
        interceptors=get_auth_interceptors(aio=True),
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    loop = asyncio.get_event_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5)))
    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    await server.wait_for_termination()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the tinygrad gRPC backend.")
    parser.add_argument("--addr", default="localhost:50051", help="Bind address")
    args = parser.parse_args()
    asyncio.run(serve(args.addr))


@@ -0,0 +1,17 @@
#!/bin/bash
set -e
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi
# tinygrad >= 0.12 requires Python >= 3.11 (pyproject: `requires-python = ">=3.11"`).
# LocalAI's default portable python is 3.10, so we pin to 3.11.x here.
PYTHON_VERSION="3.11"
PYTHON_PATCH="14"
PY_STANDALONE_TAG="20260203"
installRequirements


@@ -0,0 +1,103 @@
#!/bin/bash
# Script to package runtime shared libraries for the tinygrad backend.
#
# The final Dockerfile.python stage is FROM scratch, so system libraries
# must be explicitly copied into ${BACKEND}/lib so the backend can run on
# any host without installing them. libbackend.sh automatically prepends
# that directory to LD_LIBRARY_PATH at run time.
#
# tinygrad's CPU device (CLANG / LLVM renderer) JIT-compiles kernels at
# runtime. The default `CLANG` path invokes the external `clang` binary via
# subprocess, which does not exist in the scratch image. We force the
# in-process LLVM path (`CPU_LLVM=1` in run.sh) which loads libLLVM.so.*
# through ctypes and bundle the library + its runtime dependencies here.
#
# Also bundle libgomp (pulled by librosa / numpy via numba) and libsndfile
# (required by soundfile -> librosa audio I/O for Whisper).
set -e
CURDIR=$(dirname "$(realpath "$0")")
LIB_DIR="${CURDIR}/lib"
mkdir -p "${LIB_DIR}"
SEARCH_DIRS=(
/usr/lib/x86_64-linux-gnu
/usr/lib/aarch64-linux-gnu
/lib/x86_64-linux-gnu
/lib/aarch64-linux-gnu
/usr/lib
/lib
)
copy_with_symlinks() {
    local soname="$1"
    local hit=""
    for dir in "${SEARCH_DIRS[@]}"; do
        if [ -e "${dir}/${soname}" ]; then
            hit="${dir}/${soname}"
            break
        fi
    done
    if [ -z "${hit}" ]; then
        echo "warning: ${soname} not found in standard lib paths" >&2
        return 0
    fi
    local real
    real=$(readlink -f "${hit}")
    cp -v "${real}" "${LIB_DIR}/"
    local real_base
    real_base=$(basename "${real}")
    if [ "${real_base}" != "${soname}" ]; then
        ln -sf "${real_base}" "${LIB_DIR}/${soname}"
    fi
}
# tinygrad searches for libLLVM under these sonames (see
# tinygrad/runtime/autogen/llvm.py). Ubuntu 24.04's `llvm` metapackage
# installs `libLLVM-18.so.1` into `/usr/lib/llvm-18/lib/`. Also scan the
# standard lib directories in case a different distro layout puts it in
# /usr/lib/x86_64-linux-gnu.
llvm_so=""
shopt -s nullglob
LLVM_EXTRA_DIRS=(/usr/lib/llvm-*/lib /usr/lib/llvm-*)
# First try the versioned symlink (libLLVM-18.so) since that's what
# tinygrad's DLL loader matches against (see llvm.py DLL name list).
for dir in "${SEARCH_DIRS[@]}" "${LLVM_EXTRA_DIRS[@]}"; do
    for candidate in "${dir}"/libLLVM-[0-9]*.so "${dir}"/libLLVM-[0-9]*.so.[0-9]*; do
        if [ -e "${candidate}" ]; then
            llvm_so="${candidate}"
            break 2
        fi
    done
done
# Fallback: any libLLVM.so file under /usr.
if [ -z "${llvm_so}" ]; then
    llvm_so=$(find /usr -maxdepth 5 -name 'libLLVM*.so*' 2>/dev/null | head -1)
fi
shopt -u nullglob
if [ -z "${llvm_so}" ]; then
    echo "ERROR: libLLVM not found — tinygrad CPU device needs it." >&2
    echo "Install the Ubuntu \`llvm\` package in the builder stage." >&2
    exit 1
fi
echo "Found libLLVM at: ${llvm_so}"
llvm_base=$(basename "${llvm_so}")
real_llvm=$(readlink -f "${llvm_so}")
cp -v "${real_llvm}" "${LIB_DIR}/"
real_base=$(basename "${real_llvm}")
if [ "${real_base}" != "${llvm_base}" ]; then
    ln -sf "${real_base}" "${LIB_DIR}/${llvm_base}"
fi
# libLLVM has soft runtime deps on libedit / libtinfo; pick them up if
# present. They're optional but loading without them can fail.
copy_with_symlinks libedit.so.2
copy_with_symlinks libtinfo.so.6
# Audio I/O for the Whisper path.
copy_with_symlinks libsndfile.so.1
copy_with_symlinks libgomp.so.1
echo "tinygrad packaging completed successfully"
ls -liah "${LIB_DIR}/"


@@ -0,0 +1,11 @@
#!/bin/bash
set -e
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi
runProtogen


@@ -0,0 +1 @@
# tinygrad CPU backend uses CLANG device (no extra deps required).


@@ -0,0 +1,2 @@
# tinygrad drives CUDA through its own JIT (CUDA=1 env var).
# Requires the CUDA 12 runtime from the base image; no extra Python deps.


@@ -0,0 +1,2 @@
# tinygrad drives CUDA through its own JIT (CUDA=1 env var).
# Requires the CUDA 13 runtime from the base image; no extra Python deps.


@@ -0,0 +1,15 @@
grpcio==1.80.0
protobuf==6.33.5
certifi
setuptools
numpy>=2.0.0
tinygrad>=0.12.0
tokenizers>=0.21.0
huggingface_hub
jinja2>=3.1.0
tiktoken
sentencepiece
safetensors
Pillow
librosa
soundfile

backend/python/tinygrad/run.sh Executable file

@@ -0,0 +1,55 @@
#!/bin/bash
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
# tinygrad binds its compute device at import time from a single env var
# (CUDA / HIP / METAL / CLANG). We pick one here based on what driver
# libraries the host has injected into the container — when a user runs
# the image with `--gpus all` (or the equivalent rocm runtime), the
# nvidia-container-toolkit / rocm runtime mounts the right libraries
# under /usr/lib so we can detect them.
#
# tinygrad's CUDA path uses two compiler pairs: an NVRTC-backed one and
# an in-process PTX renderer. We force the PTX renderer here
# (`CUDA_PTX=1`) so the image is independent of the host CUDA toolkit
# version — only libcuda.so.1 (the driver) is required.
find_lib() {
    local soname="$1"
    for dir in /usr/lib/x86_64-linux-gnu /usr/lib64 /usr/lib /lib/x86_64-linux-gnu /lib64 /lib; do
        if [ -e "${dir}/${soname}" ]; then
            echo "${dir}/${soname}"
            return 0
        fi
    done
    return 1
}

if [ -z "${CUDA:-}${HIP:-}${METAL:-}${CLANG:-}" ]; then
    if find_lib libcuda.so.1 >/dev/null; then
        export CUDA=1
        export CUDA_PTX=1
    elif find_lib libamdhip64.so >/dev/null || find_lib libamdhip64.so.6 >/dev/null; then
        export HIP=1
    else
        export CLANG=1
    fi
fi

# The CPU path (CLANG=1) JIT-compiles via libLLVM. Force tinygrad's
# in-process LLVM compiler so we don't need an external `clang` binary
# (which is not present in the scratch image).
export CPU_LLVM=1
if [ -z "${LLVM_PATH:-}" ]; then
    for candidate in "${EDIR}"/lib/libLLVM-*.so "${EDIR}"/lib/libLLVM-*.so.* "${EDIR}"/lib/libLLVM.so.*; do
        if [ -e "${candidate}" ]; then
            export LLVM_PATH="${candidate}"
            break
        fi
    done
fi

startBackend "$@"


@@ -0,0 +1,84 @@
"""
Unit tests for the tinygrad gRPC backend.
These tests cover the cheap paths that don't need a real model checkpoint:
- Health responds OK
- Tool-call parsers emit expected ToolCall structures
The full LLM / embeddings / Stable Diffusion / Whisper paths are exercised by
the root-level `make test-extra-backend-tinygrad-all` e2e targets, which boot
the containerized backend against real HF checkpoints.
"""
import os
import subprocess
import sys
import time
import unittest

import grpc

import backend_pb2
import backend_pb2_grpc

sys.path.insert(0, os.path.dirname(__file__))
from tool_parsers.hermes import HermesToolParser  # noqa: E402


class TestHealth(unittest.TestCase):
    def setUp(self):
        self.service = subprocess.Popen(
            ["python3", "backend.py", "--addr", "localhost:50051"]
        )
        time.sleep(5)

    def tearDown(self):
        self.service.kill()
        self.service.wait()

    def test_health(self):
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = stub.Health(backend_pb2.HealthMessage())
            self.assertEqual(response.message, b"OK")


class TestHermesParser(unittest.TestCase):
    def test_single_tool_call(self):
        parser = HermesToolParser()
        text = (
            "Sure, let me check.\n"
            "<tool_call>\n"
            '{"name": "get_weather", "arguments": {"city": "Paris"}}\n'
            "</tool_call>\n"
            "Done."
        )
        content, calls = parser.parse(text)
        self.assertIn("Sure", content)
        self.assertIn("Done", content)
        self.assertEqual(len(calls), 1)
        self.assertEqual(calls[0].name, "get_weather")
        self.assertIn("Paris", calls[0].arguments)

    def test_multi_call_and_thinking(self):
        parser = HermesToolParser()
        text = (
            "<think>I need both.</think>"
            '<tool_call>{"name":"a","arguments":{"x":1}}</tool_call>'
            '<tool_call>{"name":"b","arguments":{}}</tool_call>'
        )
        result = parser.parse_full(text)
        self.assertEqual(result.reasoning, "I need both.")
        self.assertEqual([c.name for c in result.tool_calls], ["a", "b"])
        self.assertEqual(result.tool_calls[0].index, 0)
        self.assertEqual(result.tool_calls[1].index, 1)

    def test_no_tool_call_is_passthrough(self):
        parser = HermesToolParser()
        text = "plain assistant answer with no tool call"
        content, calls = parser.parse(text)
        self.assertEqual(content, text)
        self.assertEqual(calls, [])


if __name__ == "__main__":
    unittest.main()

backend/python/tinygrad/test.sh Executable file

@@ -0,0 +1,11 @@
#!/bin/bash
set -e
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi
runUnittests


@@ -0,0 +1,11 @@
"""Tool-call parsers for the tinygrad backend.
Each parser takes raw model output and extracts OpenAI-style tool calls so
the backend can populate `ChatDelta.tool_calls[]` natively (matching vLLM's
behavior, which the Go core prefers over regex fallback parsing).
"""
from __future__ import annotations
from .base import ToolCall, ToolParser, resolve_parser
__all__ = ["ToolCall", "ToolParser", "resolve_parser"]


@@ -0,0 +1,85 @@
"""Common types + parser registry for tool-call extraction."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class ToolCall:
    """One extracted tool call — maps 1:1 to backend_pb2.ToolCallDelta."""
    index: int
    name: str
    arguments: str  # JSON string
    id: str = ""


class ToolParser:
    """Parser interface.

    Subclasses implement `parse` (full non-streaming pass) and optionally
    `parse_stream` (incremental). The default `parse_stream` buffers until a
    full response is available and then delegates to `parse`.
    """
    name: str = "base"

    def __init__(self) -> None:
        self._stream_buffer = ""
        self._stream_index = 0

    def parse(self, text: str) -> tuple[str, list[ToolCall]]:
        """Return (content_for_user, tool_calls)."""
        raise NotImplementedError

    def parse_stream(self, delta: str, finished: bool = False) -> tuple[str, list[ToolCall]]:
        """Accumulate a streaming delta. Emits any tool calls that have closed.

        Default behavior: buffer until `finished=True`, then parse once.
        Subclasses can override to emit mid-stream.
        """
        self._stream_buffer += delta
        if not finished:
            return "", []
        content, calls = self.parse(self._stream_buffer)
        # Re-index starting from whatever we've already emitted in this stream.
        reindexed: list[ToolCall] = []
        for i, c in enumerate(calls):
            reindexed.append(ToolCall(
                index=self._stream_index + i,
                name=c.name,
                arguments=c.arguments,
                id=c.id,
            ))
        self._stream_index += len(reindexed)
        return content, reindexed

    def reset(self) -> None:
        self._stream_buffer = ""
        self._stream_index = 0


_REGISTRY: dict[str, type[ToolParser]] = {}


def register(cls: type[ToolParser]) -> type[ToolParser]:
    _REGISTRY[cls.name] = cls
    return cls


def resolve_parser(name: Optional[str]) -> ToolParser:
    """Return a parser instance by name, falling back to a no-op passthrough."""
    # Import for side effects — each module registers itself.
    from . import hermes, llama3_json, mistral, qwen3_xml  # noqa: F401
    if name and name in _REGISTRY:
        return _REGISTRY[name]()
    return PassthroughToolParser()


class PassthroughToolParser(ToolParser):
    """No-op parser — used when no tool_parser is configured."""
    name = "passthrough"

    def parse(self, text: str) -> tuple[str, list[ToolCall]]:
        return text, []
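Editorial note: the buffer-then-reindex contract of `parse_stream` is easiest to see in isolation. A self-contained toy parser (not the backend's classes) that follows the same protocol — return nothing until `finished=True`, then number calls from where the stream left off:

```python
import re

class ToyParser:
    """Buffers streamed deltas; on finish, extracts [[name]] markers and
    re-indexes them from what this stream has already emitted."""
    def __init__(self):
        self._buf, self._idx = "", 0

    def parse(self, text):
        names = re.findall(r"\[\[(\w+)\]\]", text)
        return re.sub(r"\[\[\w+\]\]", "", text).strip(), names

    def parse_stream(self, delta, finished=False):
        self._buf += delta
        if not finished:
            return "", []          # still buffering
        content, names = self.parse(self._buf)
        calls = [(self._idx + i, n) for i, n in enumerate(names)]
        self._idx += len(calls)    # next stream segment continues numbering
        return content, calls
```

The index offset matters when a caller reuses one parser instance across several finished segments of the same response.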


@@ -0,0 +1,74 @@
"""Hermes-format tool-call parser.
Hermes 2 / 2.5 / 3 (and Qwen 2.5 Instruct, which adopted the same convention)
emit tool calls wrapped in `<tool_call>...</tool_call>` tags, where the inner
content is a JSON object with `name` and `arguments` keys:
<tool_call>
{"name": "get_weather", "arguments": {"city": "Paris"}}
</tool_call>
Multiple tool calls may appear back-to-back. Text outside the tags is plain
assistant content that should surface to the user.
This parser also strips `<think>...</think>` reasoning blocks and returns them
via the reasoning_content channel (Qwen 3, DeepSeek-R1 distills).
"""
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from .base import ToolCall, ToolParser, register
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
@dataclass
class HermesParseResult:
    content: str
    reasoning: str
    tool_calls: list[ToolCall]


@register
class HermesToolParser(ToolParser):
    name = "hermes"

    def _parse_full(self, text: str) -> HermesParseResult:
        reasoning_parts: list[str] = []

        def _capture_reasoning(match: re.Match[str]) -> str:
            reasoning_parts.append(match.group(1).strip())
            return ""

        text_wo_think = _THINK_RE.sub(_capture_reasoning, text)
        calls: list[ToolCall] = []
        for idx, match in enumerate(_TOOL_CALL_RE.finditer(text_wo_think)):
            raw = match.group(1)
            try:
                obj = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if not isinstance(obj, dict):
                continue
            name = obj.get("name")
            if not isinstance(name, str):
                continue
            args = obj.get("arguments", {})
            args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False)
            calls.append(ToolCall(index=idx, name=name, arguments=args_str))
        content = _TOOL_CALL_RE.sub("", text_wo_think).strip()
        reasoning = "\n\n".join(reasoning_parts).strip()
        return HermesParseResult(content=content, reasoning=reasoning, tool_calls=calls)

    def parse(self, text: str) -> tuple[str, list[ToolCall]]:
        result = self._parse_full(text)
        return result.content, result.tool_calls

    def parse_full(self, text: str) -> HermesParseResult:
        return self._parse_full(text)
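Editorial note: the core of the Hermes format can be exercised standalone. A minimal sketch using the same `<tool_call>` regex shape as above (function name is illustrative; the real parser additionally handles `<think>` blocks and re-serializes arguments):

```python
import json, re

TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def extract_hermes_calls(text):
    """Return (user-visible content, [(name, arguments), ...])."""
    calls = []
    for m in TOOL_CALL_RE.finditer(text):
        try:
            obj = json.loads(m.group(1))
        except json.JSONDecodeError:
            continue  # malformed JSON inside the tags is ignored, not fatal
        if isinstance(obj, dict) and isinstance(obj.get("name"), str):
            calls.append((obj["name"], obj.get("arguments", {})))
    return TOOL_CALL_RE.sub("", text).strip(), calls
```

The non-greedy `\{.*?\}` with `re.DOTALL` is what lets multiple back-to-back calls match separately instead of as one giant blob.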


@@ -0,0 +1,86 @@
"""Llama 3.1 / 3.2 / 3.3 JSON tool-call parser.
Meta's Llama 3.1+ instruct chat templates emit tool calls in two broadly
compatible shapes:
1. With the `<|python_tag|>` lead-in:
<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}}
2. As a bare JSON object (or list of objects) at the end of the turn.
We also handle multi-call shapes where the model emits several JSON objects
separated by `;` or newlines, and JSON arrays `[{...}, {...}]`. The key field
for Llama 3 is historically `parameters` (older docs) but recent checkpoints
also emit `arguments` — accept either.
"""
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from .base import ToolCall, ToolParser, register
_PYTHON_TAG = "<|python_tag|>"
_JSON_OBJECT_RE = re.compile(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", re.DOTALL)
def _coerce_call(obj: object, index: int) -> ToolCall | None:
    if not isinstance(obj, dict):
        return None
    name = obj.get("name")
    if not isinstance(name, str):
        return None
    args = obj.get("arguments", obj.get("parameters", {}))
    args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False)
    return ToolCall(index=index, name=name, arguments=args_str)


@register
class Llama3JsonToolParser(ToolParser):
    name = "llama3_json"

    def parse(self, text: str) -> tuple[str, list[ToolCall]]:
        calls: list[ToolCall] = []
        # Strip <|python_tag|> segments first — each segment is one tool call
        # body. The content after the final python_tag (if any) is the call.
        remaining = text
        if _PYTHON_TAG in text:
            head, *tails = text.split(_PYTHON_TAG)
            remaining = head
            for tail in tails:
                parsed = _try_parse(tail.strip(), len(calls))
                calls.extend(parsed)
        # Any JSON objects / arrays left in `remaining` count as tool calls too
        # if they parse to a {"name": ..., "arguments": ...} shape.
        for match in _JSON_OBJECT_RE.finditer(remaining):
            parsed = _try_parse(match.group(0), len(calls))
            if parsed:
                calls.extend(parsed)
                remaining = remaining.replace(match.group(0), "", 1)
        content = remaining.strip()
        return content, calls


def _try_parse(blob: str, start_index: int) -> list[ToolCall]:
    """Parse a fragment that may be a JSON object or a JSON array of objects."""
    blob = blob.strip().rstrip(";")
    if not blob:
        return []
    try:
        obj = json.loads(blob)
    except json.JSONDecodeError:
        return []
    if isinstance(obj, dict):
        call = _coerce_call(obj, start_index)
        return [call] if call else []
    if isinstance(obj, list):
        calls: list[ToolCall] = []
        for i, item in enumerate(obj):
            c = _coerce_call(item, start_index + i)
            if c:
                calls.append(c)
        return calls
    return []
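Editorial note: the `<|python_tag|>` splitting above is the format's distinctive part. A standalone sketch of that path only (illustrative names; the real parser also scans `remaining` for bare JSON and handles arrays):

```python
import json

PYTHON_TAG = "<|python_tag|>"

def extract_llama3_calls(text):
    """Split on the python_tag lead-in; each tail is one JSON tool-call body."""
    head, *tails = text.split(PYTHON_TAG)
    calls = []
    for tail in tails:
        try:
            obj = json.loads(tail.strip().rstrip(";"))
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict) and isinstance(obj.get("name"), str):
            # older Llama 3 docs use "parameters"; newer checkpoints emit "arguments"
            calls.append((obj["name"], obj.get("arguments", obj.get("parameters", {}))))
    return head.strip(), calls
```

Accepting both `parameters` and `arguments` keeps the parser working across checkpoint generations without a version switch.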


@@ -0,0 +1,56 @@
"""Mistral / Mixtral tool-call parser.
Mistral Nemo / Small / Large Instruct emit tool calls prefixed with the
`[TOOL_CALLS]` control token, followed by a JSON array:
[TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}}]
Multiple calls live inside the same array. Any text before `[TOOL_CALLS]` is
normal assistant content and should surface to the user.
"""
from __future__ import annotations
import json
import re
from .base import ToolCall, ToolParser, register
_MARKER = "[TOOL_CALLS]"
_JSON_ARRAY_RE = re.compile(r"\[\s*(?:\{.*?\}\s*,?\s*)+\]", re.DOTALL)
@register
class MistralToolParser(ToolParser):
    name = "mistral"

    def parse(self, text: str) -> tuple[str, list[ToolCall]]:
        if _MARKER not in text:
            return text.strip(), []
        head, tail = text.split(_MARKER, 1)
        content = head.strip()
        match = _JSON_ARRAY_RE.search(tail)
        if not match:
            return content, []
        try:
            arr = json.loads(match.group(0))
        except json.JSONDecodeError:
            return content, []
        if not isinstance(arr, list):
            return content, []
        calls: list[ToolCall] = []
        for i, obj in enumerate(arr):
            if not isinstance(obj, dict):
                continue
            name = obj.get("name")
            if not isinstance(name, str):
                continue
            args = obj.get("arguments", {})
            args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False)
            calls.append(ToolCall(index=i, name=name, arguments=args_str))
        return content, calls
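Editorial note: the Mistral convention is the simplest of the four — one marker, one JSON array. A standalone sketch of the split-and-decode step (function name is illustrative; the real parser uses a regex to locate the array rather than decoding the whole tail):

```python
import json

MARKER = "[TOOL_CALLS]"

def extract_mistral_calls(text):
    """Everything before the marker is content; after it, a JSON array of calls."""
    if MARKER not in text:
        return text.strip(), []
    head, tail = text.split(MARKER, 1)
    try:
        arr = json.loads(tail.strip())
    except json.JSONDecodeError:
        return head.strip(), []
    calls = [(o["name"], o.get("arguments", {}))
             for o in arr if isinstance(o, dict) and isinstance(o.get("name"), str)]
    return head.strip(), calls
```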


@@ -0,0 +1,74 @@
"""Qwen 3 XML tool-call parser.
Qwen 3 Instruct emits tool calls wrapped in a two-level tag structure:
<tool_call>
<function=get_weather>
<parameter=city>
Paris
</parameter>
<parameter=unit>
celsius
</parameter>
</function>
</tool_call>
Parameter values are raw text — we treat them as strings unless they look
like JSON (in which case we try to parse so numbers / booleans round-trip
cleanly). Qwen 3 also supports `<think>...</think>` reasoning blocks before
the tool call — these are captured via the shared Hermes convention.
"""
from __future__ import annotations
import json
import re
from .base import ToolCall, ToolParser, register
_TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
_FUNCTION_RE = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
_PARAMETER_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
def _maybe_json(value: str):
    value = value.strip()
    if not value:
        return value
    if value[0] in "{[\"" or value in ("true", "false", "null") or value.lstrip("-").replace(".", "", 1).isdigit():
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return value
    return value


@register
class Qwen3XmlToolParser(ToolParser):
    name = "qwen3_xml"

    def parse(self, text: str) -> tuple[str, list[ToolCall]]:
        # Strip reasoning blocks from the user-visible content.
        stripped = _THINK_RE.sub("", text)
        calls: list[ToolCall] = []
        for match in _TOOL_CALL_RE.finditer(stripped):
            body = match.group(1)
            fn_match = _FUNCTION_RE.search(body)
            if not fn_match:
                continue
            name = fn_match.group(1).strip()
            params_body = fn_match.group(2)
            params: dict[str, object] = {}
            for pm in _PARAMETER_RE.finditer(params_body):
                params[pm.group(1).strip()] = _maybe_json(pm.group(2))
            calls.append(ToolCall(
                index=len(calls),
                name=name,
                arguments=json.dumps(params, ensure_ascii=False),
            ))
        content = _TOOL_CALL_RE.sub("", stripped).strip()
        return content, calls
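Editorial note: the two-level `<function=...>` / `<parameter=...>` structure reduces to two nested regex passes. A standalone sketch of just that extraction (illustrative name; the real parser also strips `<tool_call>` and `<think>` wrappers and coerces parameter types):

```python
import json, re

FUNCTION_RE = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
PARAMETER_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)

def extract_qwen3_call(body):
    """Return (function name, JSON arguments string) or None."""
    m = FUNCTION_RE.search(body)
    if not m:
        return None
    # parameter values are raw text between the tags; strip surrounding whitespace
    params = {k.strip(): v.strip() for k, v in PARAMETER_RE.findall(m.group(2))}
    return m.group(1).strip(), json.dumps(params)
```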


@@ -0,0 +1,6 @@
"""Vendored upstream tinygrad reference code (MIT-licensed).
Source: https://github.com/tinygrad/tinygrad
These files are not part of the `tinygrad` pip package (the `extra/` tree is
excluded from `pyproject.toml` `packages`), so we carry a pinned copy here.
"""


@@ -0,0 +1,83 @@
# Vendored verbatim from tinygrad examples/audio_helpers.py (MIT license).
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/audio_helpers.py
# Copyright (c) 2023- the tinygrad authors
# SPDX-License-Identifier: MIT
from typing import Optional
from tinygrad import Tensor
from tinygrad.dtype import DTypeLike, dtypes
import math
# rewritten from numpy
def rfftfreq(n: int, d: float = 1.0, device=None) -> Tensor:
  val = 1.0 / (n * d)
  N = n // 2 + 1
  results = Tensor.arange(N, device=device)
  return results * val

# just like in librosa
def fft_frequencies(sr: float, n_fft: int) -> Tensor:
  return rfftfreq(n=n_fft, d=1.0 / sr)

def hz_to_mel(freq: Tensor) -> Tensor:
  # linear part
  f_min = 0.0
  f_sp = 200.0 / 3
  mels = (freq - f_min) / f_sp
  # log-scale part
  min_log_hz = 1000.0  # beginning of log region (Hz)
  mask = freq >= min_log_hz
  return mask.where(((min_log_hz - f_min) / f_sp) + (freq / min_log_hz).log() / (math.log(6.4) / 27.0), mels)

def mel_to_hz(mels: Tensor) -> Tensor:
  # linear scale
  f_min = 0.0
  f_sp = 200.0 / 3
  freqs = f_min + f_sp * mels
  # nonlinear scale
  min_log_hz = 1000.0  # beginning of log region (Hz)
  min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
  logstep = math.log(6.4) / 27.0  # step size for log region
  log_t = mels >= min_log_mel
  freqs = log_t.where(min_log_hz * ((logstep * (mels - min_log_mel)).exp()), freqs)
  return freqs

def mel_frequencies(n_mels: int = 128, *, fmin: float = 0.0, fmax: float = 11025.0) -> Tensor:
  # center freqs of mel bands - uniformly spaced between limits
  min_max_mel = hz_to_mel(Tensor([fmin, fmax]))
  mels = Tensor.linspace(min_max_mel[0], min_max_mel[1], n_mels)
  hz = mel_to_hz(mels)
  return hz

def mel(
  *,
  sr: float,
  n_fft: int,
  n_mels: int = 128,
  fmin: float = 0.0,
  fmax: Optional[float] = None,
  dtype: DTypeLike = dtypes.default_float,
) -> Tensor:
  if fmax is None:
    fmax = float(sr) / 2
  n_mels = int(n_mels)
  fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)  # center freqs of each FFT bin
  mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax)  # center freqs of mel bands
  fdiff = mel_f[1:] - mel_f[:-1]
  ramps = mel_f[None].T.expand(-1, fftfreqs.shape[-1]) - fftfreqs
  lower = -ramps[:n_mels] / fdiff[:n_mels][None].T
  upper = ramps[2 : n_mels + 2] / fdiff[1 : n_mels + 1][None].T
  weights = lower.minimum(upper).maximum(0)
  # Slaney-style mel is scaled to be approx constant energy per channel
  enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
  weights *= enorm[:, None]
  return weights
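Editorial note: `hz_to_mel` / `mel_to_hz` above are exact inverses of each other (linear below 1 kHz, logarithmic above — the Slaney convention). A scalar re-statement of the same formulas, useful for sanity-checking the round trip without tinygrad:

```python
import math

def hz_to_mel(freq):
    f_sp = 200.0 / 3                       # linear region: 66.67 Hz per mel
    if freq < 1000.0:
        return freq / f_sp
    return 1000.0 / f_sp + math.log(freq / 1000.0) / (math.log(6.4) / 27.0)

def mel_to_hz(mel):
    f_sp = 200.0 / 3
    min_log_mel = 1000.0 / f_sp            # the 1 kHz breakpoint, in mels
    if mel < min_log_mel:
        return f_sp * mel
    return 1000.0 * math.exp((math.log(6.4) / 27.0) * (mel - min_log_mel))
```

The constant `log(6.4)/27` places 27 log-spaced bands between 1 kHz and 6.4 kHz, matching librosa's `htk=False` behavior.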

backend/python/tinygrad/vendor/clip.py vendored Normal file

@@ -0,0 +1,484 @@
# Vendored verbatim from tinygrad extra/models/clip.py (MIT license).
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/clip.py
# Copyright (c) 2023- the tinygrad authors
# SPDX-License-Identifier: MIT
from tinygrad import Tensor, dtypes
from tinygrad.helpers import fetch
from tinygrad.nn import Linear, LayerNorm, Embedding, Conv2d
from typing import List, Optional, Union, Tuple, Dict
from abc import ABC, abstractmethod
from functools import lru_cache
import numpy as np
import re, gzip
# Allow for monkeypatching for mlperf.
gelu = Tensor.gelu
@lru_cache()
def default_bpe():
  # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
  return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")

class Tokenizer:
  """
  Namespace for CLIP Text Tokenizer components.
  """
  @staticmethod
  def get_pairs(word):
    """
    Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    return set(zip(word, word[1:]))

  @staticmethod
  def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text

  @staticmethod
  def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
      if b not in bs:
        bs.append(b)
        cs.append(2**8+n)
        n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

  class ClipTokenizer:
    def __init__(self, version=None):
      self.byte_encoder, self.version = Tokenizer.bytes_to_unicode(), version
      merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n')
      merges = merges[1:49152-256-2+1]
      merges = [tuple(merge.split()) for merge in merges]
      vocab = list(Tokenizer.bytes_to_unicode().values())
      vocab = vocab + [v+'</w>' for v in vocab]
      for merge in merges:
        vocab.append(''.join(merge))
      if self.version == "sd_mlperf_v5_0":
        import regex
        vocab.extend(['<start_of_text>', '<end_of_text>'])
        self.cache = {'<start_of_text>': '<start_of_text>', '<end_of_text>': '<end_of_text>'}
        self.pat = regex.compile(r"""<start_of_text>|<end_of_text>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
      else:
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
      self.encoder = dict(zip(vocab, range(len(vocab))))
      self.bpe_ranks = dict(zip(merges, range(len(merges))))

    def bpe(self, token):
      if token in self.cache:
        return self.cache[token]
      word = tuple(token[:-1]) + ( token[-1] + '</w>',)
      pairs = Tokenizer.get_pairs(word)
      if not pairs:
        return token+'</w>'
      while True:
        bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
        if bigram not in self.bpe_ranks:
          break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
          try:
            j = word.index(first, i)
            new_word.extend(word[i:j])
            i = j
          except Exception:
            new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
pairs = Tokenizer.get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text:str, pad_with_zeros:bool=False) -> List[int]:
bpe_tokens: List[int] = []
if self.version == "sd_mlperf_v5_0":
import regex, ftfy, html
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text)).strip()
text = Tokenizer.whitespace_clean(text).lower()
re_module = regex
else:
text = Tokenizer.whitespace_clean(text.strip()).lower()
re_module = re
for token in re_module.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
# Truncation, keeping two slots for start and end tokens.
if len(bpe_tokens) > 75:
bpe_tokens = bpe_tokens[:75]
return [49406] + bpe_tokens + [49407] + ([0] if pad_with_zeros else [49407]) * (77 - len(bpe_tokens) - 2)
class Embedder(ABC):
input_key: str
@abstractmethod
def __call__(self, x:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
pass
class Closed:
"""
Namespace for OpenAI CLIP model components.
"""
class ClipMlp:
def __init__(self):
self.fc1 = Linear(768, 3072)
self.fc2 = Linear(3072, 768)
def __call__(self, h:Tensor) -> Tensor:
h = self.fc1(h)
h = h.quick_gelu()
h = self.fc2(h)
return h
class ClipAttention:
def __init__(self):
self.embed_dim = 768
self.num_heads = 12
self.head_dim = self.embed_dim // self.num_heads
self.k_proj = Linear(self.embed_dim, self.embed_dim)
self.v_proj = Linear(self.embed_dim, self.embed_dim)
self.q_proj = Linear(self.embed_dim, self.embed_dim)
self.out_proj = Linear(self.embed_dim, self.embed_dim)
def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
bsz, tgt_len, embed_dim = hidden_states.shape
q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))
class ClipEncoderLayer:
def __init__(self):
self.self_attn = Closed.ClipAttention()
self.layer_norm1 = LayerNorm(768)
self.mlp = Closed.ClipMlp()
self.layer_norm2 = LayerNorm(768)
def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states, causal_attention_mask)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class ClipTextEmbeddings:
def __init__(self):
self.token_embedding = Embedding(49408, 768)
self.position_embedding = Embedding(77, 768)
def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor:
return self.token_embedding(input_ids) + self.position_embedding(position_ids)
class ClipEncoder:
def __init__(self, layer_count:int=12):
self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)]
def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor:
# the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state
layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx]
for l in layers:
x = l(x, causal_attention_mask)
return x
class ClipTextTransformer:
def __init__(self, ret_layer_idx:Optional[int]=None):
self.embeddings = Closed.ClipTextEmbeddings()
self.encoder = Closed.ClipEncoder()
self.final_layer_norm = LayerNorm(768)
self.ret_layer_idx = ret_layer_idx
def __call__(self, input_ids:Tensor) -> Tensor:
x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1))
x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1), self.ret_layer_idx)
return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x
class ClipTextModel:
def __init__(self, ret_layer_idx:Optional[int]):
self.text_model = Closed.ClipTextTransformer(ret_layer_idx=ret_layer_idx)
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331
class FrozenClosedClipEmbedder(Embedder):
def __init__(self, ret_layer_idx:Optional[int]=None):
self.tokenizer = Tokenizer.ClipTokenizer()
self.transformer = Closed.ClipTextModel(ret_layer_idx)
self.input_key = "txt"
def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
if isinstance(texts, str): texts = [texts]
assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0)
return self.transformer.text_model(tokens.reshape(len(texts),-1))
class Open:
"""
Namespace for OpenCLIP model components.
"""
class MultiheadAttention:
def __init__(self, dims:int, n_heads:int):
self.dims = dims
self.n_heads = n_heads
self.d_head = self.dims // self.n_heads
self.in_proj_bias = Tensor.empty(3*dims)
self.in_proj_weight = Tensor.empty(3*dims, dims)
self.out_proj = Linear(dims, dims)
def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
T,B,C = x.shape
proj = x.linear(self.in_proj_weight.T, self.in_proj_bias)
proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0, -2)
q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in proj.chunk(3)]
attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
attn_output = attn_output.permute(2, 0, 1, 3).reshape(T, B, C)
attn_output = self.out_proj(attn_output)
return attn_output
class Mlp:
def __init__(self, dims, hidden_dims):
self.c_fc = Linear(dims, hidden_dims)
self.c_proj = Linear(hidden_dims, dims)
self.gelu = gelu
def __call__(self, x:Tensor) -> Tensor:
return x.sequential([self.c_fc, self.gelu, self.c_proj])
# https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210
class ResidualAttentionBlock:
def __init__(self, dims:int, n_heads:int, mlp_ratio:float):
self.ln_1 = LayerNorm(dims)
self.attn = Open.MultiheadAttention(dims, n_heads)
self.ln_2 = LayerNorm(dims)
self.mlp = Open.Mlp(dims, int(dims * mlp_ratio))
def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None, transpose:bool=False) -> Tensor:
q_x = self.ln_1(x)
attn_out = self.attn(q_x.transpose(0, 1) if transpose else q_x, attn_mask=attn_mask)
attn_out = attn_out.transpose(0, 1) if transpose else attn_out
x = x + attn_out
x = x + self.mlp(self.ln_2(x))
return x
# https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317
class ClipTransformer:
def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0):
self.resblocks = [
Open.ResidualAttentionBlock(dims, n_heads, mlp_ratio) for _ in range(layers)
]
def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
for r in self.resblocks:
x = r(x, attn_mask=attn_mask, transpose=True)
return x
# https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220
# https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661
class ClipTextTransformer:
def __init__(self, width:int, n_heads:int, layers:int, vocab_size:int=49408, ctx_length:int=77):
self.token_embedding = Embedding(vocab_size, width)
self.positional_embedding = Tensor.empty(ctx_length, width)
self.transformer = Open.ClipTransformer(width, layers, n_heads)
self.ln_final = LayerNorm(width)
self.text_projection = Tensor.empty(width, width)
self.attn_mask = Tensor.full((77, 77), float("-inf")).triu(1).realize()
def __call__(self, text:Tensor) -> Tensor:
seq_len = text.shape[1]
x = self.token_embedding(text)
x = x + self.positional_embedding[:seq_len]
x = self.transformer(x, attn_mask=self.attn_mask)
x = self.ln_final(x)
pooled = x[:, text.argmax(dim=-1)] @ self.text_projection
return pooled
class ClipVisionTransformer:
def __init__(self, width:int, layers:int, d_head:int, image_size:int, patch_size:int):
grid_size = image_size // patch_size
n_heads = width // d_head
assert n_heads * d_head == width
self.conv1 = Conv2d(3, width, kernel_size=patch_size, stride=patch_size, bias=False)
self.class_embedding = Tensor.empty(width)
self.positional_embedding = Tensor.empty(grid_size * grid_size + 1, width)
self.transformer = Open.ClipTransformer(width, layers, n_heads)
self.ln_pre = LayerNorm(width)
self.ln_post = LayerNorm(width)
self.proj = Tensor.empty(width, 1024)
def __call__(self, x:Tensor) -> Tensor:
x = self.conv1(x)
x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)
x = self.class_embedding.reshape(1, 1, -1).expand(x.shape[0], 1, -1).cat(x, dim=1)
x = x + self.positional_embedding
x = self.ln_pre(x)
x = self.transformer(x)
x = self.ln_post(x)
pooled = x[:, 0] @ self.proj
return pooled
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
class FrozenOpenClipEmbedder(Embedder):
def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False, clip_tokenizer_version=None):
self.tokenizer = Tokenizer.ClipTokenizer(version=clip_tokenizer_version)
self.model = Open.ClipTextTransformer(dims, n_heads, layers)
self.return_pooled = return_pooled
self.input_key = "txt"
self.ln_penultimate = ln_penultimate
def tokenize(self, text:str, device:Optional[str]=None) -> Tensor:
return Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int32, device=device).reshape(1,-1)
def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
for r in self.model.transformer.resblocks:
x, penultimate = r(x, attn_mask=attn_mask), x
return x.permute(1, 0, 2), penultimate.permute(1, 0, 2)
def embed_tokens(self, tokens:Tensor) -> Union[Tensor,Tuple[Tensor,...]]:
x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2)
x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
if self.ln_penultimate:
penultimate = self.model.ln_final(penultimate)
if self.return_pooled:
x = self.model.ln_final(x)
index = tokens.argmax(axis=-1).reshape(-1,1,1).expand(x.shape[0],1,x.shape[-1])
pooled = x.gather(1, index).squeeze(1) @ self.model.text_projection
return penultimate, pooled
else:
return penultimate
def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
if isinstance(texts, str): texts = [texts]
assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
tokens = Tensor.cat(*[self.tokenize(text) for text in texts], dim=0)
return self.embed_tokens(tokens)
clip_configs: Dict = {
"ViT-H-14": {
"dims": 1024,
"vision_cfg": {
"width": 1280,
"layers": 32,
"d_head": 80,
"image_size": 224,
"patch_size": 14,
},
"text_cfg": {
"width": 1024,
"n_heads": 16,
"layers": 24,
"ctx_length": 77,
"vocab_size": 49408,
},
"return_pooled": False,
"ln_penultimate": True,
}
}
class OpenClipEncoder:
def __init__(self, dims:int, text_cfg:Dict, vision_cfg:Dict, **_):
self.visual = Open.ClipVisionTransformer(**vision_cfg)
text = Open.ClipTextTransformer(**text_cfg)
self.transformer = text.transformer
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.attn_mask = Tensor.full((77, 77), float("-inf")).triu(1).realize()
self.mean = Tensor([0.48145466, 0.45782750, 0.40821073]).reshape(-1, 1, 1)
self.std = Tensor([0.26862954, 0.26130258, 0.27577711]).reshape(-1, 1, 1)
# TODO:
# Should be doable in pure tinygrad, would just require some work and verification.
# This is very desirable since it would allow for full generation->evaluation in a single JIT call.
def prepare_image(self, image) -> Tensor:
from PIL import Image
SIZE = 224
w, h = image.size
scale = min(SIZE / h, SIZE / w)
image = image.resize((max(int(w*scale),SIZE),max(int(h*scale),SIZE)), Image.Resampling.BICUBIC)
w, h = image.size
if w > SIZE:
left = (w - SIZE) // 2
image = image.crop((left, 0, left+SIZE, SIZE))
elif h > SIZE:
top = (h - SIZE) // 2
image = image.crop((0, top, SIZE, top+SIZE))
x = Tensor(np.array(image.convert('RGB')), device=self.std.device)
x = x.permute(2, 0, 1).cast(dtypes.float32) / 255.0
return (x - self.mean) / self.std
def encode_tokens(self, tokens:Tensor) -> Tensor:
x = self.token_embedding(tokens)
x = x + self.positional_embedding
x = self.transformer(x, attn_mask=self.attn_mask)
x = self.ln_final(x)
x = x[Tensor.arange(x.shape[0], device=x.device), tokens.argmax(axis=-1)]
x = x @ self.text_projection
return x
def get_clip_score(self, tokens:Tensor, image:Tensor) -> Tensor:
image_features: Tensor = self.visual(image)
image_features /= image_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm
text_features = self.encode_tokens(tokens)
text_features /= text_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm
return (image_features * text_features).sum(axis=-1)
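The normalization in `get_clip_score` above divides each feature row by its L2 norm and then takes a per-row dot product, i.e. cosine similarity. A minimal NumPy sketch of the same arithmetic (the feature vectors here are made up for illustration):

```python
import numpy as np

def clip_score(image_features: np.ndarray, text_features: np.ndarray) -> np.ndarray:
    # Same normalization as get_clip_score: divide each row by its L2 norm...
    img = image_features / np.sqrt((image_features ** 2).sum(-1, keepdims=True))
    txt = text_features / np.sqrt((text_features ** 2).sum(-1, keepdims=True))
    # ...then the per-row dot product, i.e. cosine similarity in [-1, 1].
    return (img * txt).sum(axis=-1)

img = np.array([[3.0, 4.0]])
print(np.allclose(clip_score(img, img), 1.0))    # True: identical directions
print(np.allclose(clip_score(img, -img), -1.0))  # True: opposite directions
```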

backend/python/tinygrad/vendor/llama.py vendored Normal file

@@ -0,0 +1,294 @@
# Vendored from tinygrad extra/models/llama.py (MIT license).
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/llama.py
#
# Local modification: Attention / TransformerBlock / Transformer accept an
# optional `qkv_bias` flag so the same module can host Qwen2-style models that
# use bias on the Q/K/V projections (Llama 3 has no bias). Changes are marked
# with `# LOCALAI`.
#
# Copyright (c) 2023- the tinygrad authors
# SPDX-License-Identifier: MIT
from typing import Union, Optional, Any
import collections, math
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv, DEBUG
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).reshape(1, end, 1, dim//2, 2)
def complex_mult(A, c, d):
a, b = A[..., 0:1], A[..., 1:2]
ro = a * c - b * d
co = a * d + b * c
return ro.cat(co, dim=-1)
def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
xq = xq.reshape(*xq.shape[0:-1], -1, 2)
xk = xk.reshape(*xk.shape[0:-1], -1, 2)
assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
c, d = freqs_cis[..., 0:1], freqs_cis[..., 1:2]
xq_out = complex_mult(xq, c, d)
xk_out = complex_mult(xk, c, d)
return xq_out.flatten(3), xk_out.flatten(3)
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
bs, seqlen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
class Attention:
# LOCALAI: added qkv_bias
def __init__(self, dim, n_heads, n_kv_heads=None, max_context=0, linear=nn.Linear, qk_norm: float | None = None, qkv_bias: bool = False):
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
self.head_dim = dim // n_heads
self.n_rep = self.n_heads // self.n_kv_heads
self.max_context = max_context
self.wq = linear(dim, self.n_heads * self.head_dim, bias=qkv_bias)
self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=qkv_bias)
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=qkv_bias)
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
self.q_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
self.k_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor] = None) -> Tensor:
xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)
if self.q_norm is not None and self.k_norm is not None:
xq = self.q_norm(xq)
xk = self.k_norm(xk)
if x.dtype == dtypes.bfloat16:
xq, xk = xq.contiguous_backward(), xk.contiguous_backward()
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
bsz, seqlen, _, _ = xq.shape
if self.max_context:
if not hasattr(self, "cache_kv"):
self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
if isinstance(x.device, tuple):
self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
self.cache_kv[:, :, start_pos:start_pos + seqlen, :, :].assign(Tensor.stack(xk, xv)).realize()
keys = self.cache_kv[0, :, 0:start_pos + seqlen, :, :]
values = self.cache_kv[1, :, 0:start_pos + seqlen, :, :]
else:
assert start_pos == 0
keys, values = xk, xv
if self.max_context:
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
else:
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
attn = xq.scaled_dot_product_attention(keys, values, is_causal=True, enable_gqa=True).transpose(1, 2)
attn = attn.reshape(bsz, seqlen, -1)
return self.wo(attn)
class FeedForward:
def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
self.w1 = linear(dim, hidden_dim, bias=False)
self.w2 = linear(hidden_dim, dim, bias=False)
self.w3 = linear(dim, hidden_dim, bias=False)
def __call__(self, x: Tensor) -> Tensor:
w1 = self.w1(x).silu()
w3 = self.w3(x.contiguous_backward())
return self.w2(w1 * w3)
class TransformerBlock:
# LOCALAI: added qkv_bias
def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int,
linear=nn.Linear, feed_forward=FeedForward, qk_norm=None, qkv_bias: bool = False):
self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear, qk_norm, qkv_bias=qkv_bias)
self.feed_forward = feed_forward(dim, hidden_dim, linear)
self.attention_norm = nn.RMSNorm(dim, norm_eps)
self.ffn_norm = nn.RMSNorm(dim, norm_eps)
def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
return (h + self.feed_forward(self.ffn_norm(h))).contiguous().contiguous_backward()
def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
assert logits.ndim == 1, "only works on 1d tensors"
assert 0 <= p <= 1, "p must be between 0 and 1"
assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
if temp < 1e-6:
return logits.argmax()
logits = logits.to(Device.DEFAULT)
if af or ap:
if not hasattr(sample, "alpha_counter"):
setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0) * ap)
logits = (logits != logits).where(-float("inf"), logits)
t = (logits / temp).softmax()
counter = Tensor.arange(t.numel(), device=logits.device).contiguous()
counter2 = Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
if k:
output = Tensor.zeros(k, device=logits.device).contiguous()
output_indices = Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
for i in range(k):
t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
t = (counter == t_argmax).where(0, t)
output_cumsum = output[::-1].cumsum()[::-1] + t.sum()
output = (output_cumsum >= (1 - p)) * output
output_indices = (output_cumsum >= (1 - p)) * output_indices
output_idx = output.multinomial()
output_token = output_indices[output_idx]
else:
output_token = t.multinomial()
if af or ap:
sample.alpha_counter = (counter == output_token).where(sample.alpha_counter + 1, sample.alpha_counter)
return output_token
class Transformer:
# LOCALAI: added qkv_bias
def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_layers: int, norm_eps: float, vocab_size,
linear=nn.Linear, embedding=nn.Embedding, n_kv_heads=None, rope_theta=10000,
max_context=1024, jit=True, feed_forward=FeedForward, qk_norm=None, disable_kv_cache=False,
qkv_bias: bool = False):
self.layers = [
TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps,
0 if disable_kv_cache else max_context, linear,
feed_forward=feed_forward, qk_norm=qk_norm, qkv_bias=qkv_bias)
for _ in range(n_layers)
]
self.norm = nn.RMSNorm(dim, norm_eps)
self.tok_embeddings = embedding(vocab_size, dim)
self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False)
self.max_context = max_context
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous().requires_grad_(False)
self.forward_jit = TinyJit(self.forward) if jit else None
def forward(self, tokens: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens).contiguous()
freqs_cis = self.freqs_cis.cast(h.dtype)[:, start_pos:start_pos + seqlen, :, :, :]
if self.max_context != 0 and seqlen > 1:
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos + 1)
else:
mask = None
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
logits = self.output(self.norm(h).contiguous().contiguous_backward()).contiguous_backward()
if math.isnan(temperature):
return logits
return sample(logits[:, -1, :].flatten(), temperature, top_k, top_p, alpha_f, alpha_p)
def __call__(self, tokens: Tensor, start_pos: int, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context - 1).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
# LOCALAI: extract the last hidden state for embeddings. Skips the LM head;
# the causal-mask branch is kept intact so pooling runs over the full sequence.
def embed(self, tokens: Tensor) -> Tensor:
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens).contiguous()
freqs_cis = self.freqs_cis.cast(h.dtype)[:, 0:seqlen, :, :, :]
mask = Tensor.full((1, 1, seqlen, seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(1) if seqlen > 1 else None
for layer in self.layers:
h = layer(h, 0, freqs_cis, mask)
return self.norm(h)
def convert_from_huggingface(weights: dict[str, Tensor], n_layers: int, n_heads: int, n_kv_heads: int, permute_layers: bool = True):
def permute(v: Tensor, n_heads: int):
return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1] if len(v.shape) > 1 else 1).transpose(1, 2).reshape(*v.shape[:2])
keymap = {
"model.embed_tokens.weight": "tok_embeddings.weight",
**{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(n_layers)},
**{f"model.layers.{l}.self_attn.{x}_norm.weight": f"layers.{l}.attention.{x}_norm.weight" for x in ["q", "k"] for l in range(n_layers)},
**{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(n_layers)},
**{f"model.layers.{l}.self_attn.{x}_proj.bias": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v", "o"] for l in range(n_layers)},
**{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(n_layers)},
**{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(n_layers)},
**{f"model.layers.{l}.mlp.gate.weight": f"layers.{l}.feed_forward.gate.weight" for l in range(n_layers)},
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
sd = {}
experts = collections.defaultdict(dict)
for k, v in weights.items():
if ".rotary_emb." in k:
continue
v = v.to(Device.DEFAULT)
if "model.layers" in k:
if ("q_proj" in k or "q_norm" in k) and permute_layers:
v = permute(v, n_heads)
elif ("k_proj" in k or "k_norm" in k) and permute_layers:
v = permute(v, n_kv_heads)
if '.mlp.experts.' in k:
_, _, layer, _, _, expert, name, _ = k.split('.')
experts[f'layers.{layer}.feed_forward.{name}'][int(expert)] = v
continue
sd[keymap[k]] = v
for k, v in experts.items():
sd[k] = Tensor.stack(*[v[i] for i in range(len(v))])
if "output.weight" not in sd and "tok_embeddings.weight" in sd:
sd["output.weight"] = sd["tok_embeddings.weight"]
return sd
def convert_from_gguf(weights: dict[str, Tensor], n_layers: int):
keymap = {
"token_embd.weight": "tok_embeddings.weight",
**{f"blk.{l}.attn_norm.weight": f"layers.{l}.attention_norm.weight" for l in range(n_layers)},
**{f"blk.{l}.attn_{x}.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v"] for l in range(n_layers)},
**{f"blk.{l}.attn_{x}.bias": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v"] for l in range(n_layers)},
**{f"blk.{l}.attn_output.weight": f"layers.{l}.attention.wo.weight" for l in range(n_layers)},
**{f"blk.{l}.ffn_norm.weight": f"layers.{l}.ffn_norm.weight" for l in range(n_layers)},
**{f"blk.{l}.ffn_{x}.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(n_layers)},
"output_norm.weight": "norm.weight",
"rope_freqs.weight": "rope_freqs.weight",
}
sd = {keymap[k]: v for k, v in weights.items() if k in keymap}
if "output.weight" not in sd and "token_embd.weight" in weights:
sd["output.weight"] = weights["token_embd.weight"]
return sd
def fix_bf16(weights: dict[Any, Tensor]):
return {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
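`precompute_freqs_cis` / `apply_rotary_emb` in llama.py rotate each (even, odd) feature pair by a position-dependent angle; since the transform is a pure rotation, it preserves the norm of each pair. A minimal NumPy sketch of the `complex_mult` arithmetic on a single pair (shapes are simplified; the vendored code operates on 5-D tensors):

```python
import numpy as np

def complex_mult(a, b, c, d):
    # (a + bi) * (c + di): same real/imaginary arithmetic as the vendored helper.
    return a * c - b * d, a * d + b * c

theta = 0.3
c, d = np.cos(theta), np.sin(theta)  # one (cos, sin) entry of freqs_cis
a, b = 1.0, 2.0                      # one (even, odd) feature pair
ro, co = complex_mult(a, b, c, d)

# A rotation preserves the pair's squared norm.
print(np.isclose(ro**2 + co**2, a**2 + b**2))  # True
```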

backend/python/tinygrad/vendor/stable_diffusion.py vendored Normal file

@@ -0,0 +1,232 @@
# Adapted from tinygrad examples/stable_diffusion.py (MIT license).
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/stable_diffusion.py
# Copyright (c) 2023- the tinygrad authors
# SPDX-License-Identifier: MIT
#
# Local modifications: removed the MLPerf training branch (pulls
# examples/mlperf/initializers which we don't vendor) and the __main__
# argparse / fetch / profile blocks. Kept the core classes so the LocalAI
# tinygrad backend can instantiate and drive Stable Diffusion v1.x from a
# single checkpoint path.
from collections import namedtuple
from typing import Any, Dict
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.nn import Conv2d, GroupNorm
from . import clip as clip_mod
from . import unet as unet_mod
from .clip import Closed, Tokenizer
from .unet import UNetModel
class AttnBlock:
def __init__(self, in_channels):
self.norm = GroupNorm(32, in_channels)
self.q = Conv2d(in_channels, in_channels, 1)
self.k = Conv2d(in_channels, in_channels, 1)
self.v = Conv2d(in_channels, in_channels, 1)
self.proj_out = Conv2d(in_channels, in_channels, 1)
def __call__(self, x):
h_ = self.norm(x)
q, k, v = self.q(h_), self.k(h_), self.v(h_)
b, c, h, w = q.shape
q, k, v = [t.reshape(b, c, h * w).transpose(1, 2) for t in (q, k, v)]
h_ = Tensor.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(b, c, h, w)
return x + self.proj_out(h_)
class ResnetBlock:
def __init__(self, in_channels, out_channels=None):
self.norm1 = GroupNorm(32, in_channels)
self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
self.norm2 = GroupNorm(32, out_channels)
self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else (lambda x: x)
def __call__(self, x):
h = self.conv1(self.norm1(x).swish())
h = self.conv2(self.norm2(h).swish())
return self.nin_shortcut(x) + h
class Mid:
def __init__(self, block_in):
self.block_1 = ResnetBlock(block_in, block_in)
self.attn_1 = AttnBlock(block_in)
self.block_2 = ResnetBlock(block_in, block_in)
def __call__(self, x):
return x.sequential([self.block_1, self.attn_1, self.block_2])
class Decoder:
def __init__(self):
sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
self.conv_in = Conv2d(4, 512, 3, padding=1)
self.mid = Mid(512)
arr = []
for i, s in enumerate(sz):
arr.append({"block": [ResnetBlock(s[1], s[0]), ResnetBlock(s[0], s[0]), ResnetBlock(s[0], s[0])]})
if i != 0:
arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
self.up = arr
self.norm_out = GroupNorm(32, 128)
self.conv_out = Conv2d(128, 3, 3, padding=1)
def __call__(self, x):
x = self.conv_in(x)
x = self.mid(x)
for l in self.up[::-1]:
for b in l['block']:
x = b(x)
if 'upsample' in l:
bs, c, py, px = x.shape
x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py * 2, px * 2)
x = l['upsample']['conv'](x)
x.realize()
return self.conv_out(self.norm_out(x).swish())
class Encoder:
def __init__(self):
sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
self.conv_in = Conv2d(3, 128, 3, padding=1)
arr = []
for i, s in enumerate(sz):
arr.append({"block": [ResnetBlock(s[0], s[1]), ResnetBlock(s[1], s[1])]})
if i != 3:
arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0, 1, 0, 1))}
self.down = arr
self.mid = Mid(512)
self.norm_out = GroupNorm(32, 512)
self.conv_out = Conv2d(512, 8, 3, padding=1)
def __call__(self, x):
x = self.conv_in(x)
for l in self.down:
for b in l['block']:
x = b(x)
if 'downsample' in l:
x = l['downsample']['conv'](x)
x = self.mid(x)
return self.conv_out(self.norm_out(x).swish())
class AutoencoderKL:
def __init__(self):
self.encoder = Encoder()
self.decoder = Decoder()
self.quant_conv = Conv2d(8, 8, 1)
self.post_quant_conv = Conv2d(4, 4, 1)
def __call__(self, x):
latent = self.encoder(x)
latent = self.quant_conv(latent)
latent = latent[:, 0:4]
latent = self.post_quant_conv(latent)
return self.decoder(latent)
def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_training_steps, dtype=np.float32) ** 2
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
return Tensor(alphas_cumprod)
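The schedule above can be sanity-checked in plain NumPy; this mirrors the function body exactly (no tinygrad needed):

```python
import numpy as np

# Recompute the beta schedule used above: betas are linear in sqrt-space
# between sqrt(beta_start) and sqrt(beta_end), then squared ("scaled linear"),
# and alphas_cumprod is the running product of (1 - beta).
beta_start, beta_end, n_steps = 0.00085, 0.0120, 1000
betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_steps, dtype=np.float32) ** 2
alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
# alphas_cumprod is a strictly decreasing curve starting at 1 - beta_start.
```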
# SD1.x UNet hyperparameters (same as upstream `unet_params`).
UNET_PARAMS_SD1: Dict[str, Any] = {
"adm_in_ch": None,
"in_ch": 4,
"out_ch": 4,
"model_ch": 320,
"attention_resolutions": [4, 2, 1],
"num_res_blocks": 2,
"channel_mult": [1, 2, 4, 4],
"n_heads": 8,
"transformer_depth": [1, 1, 1, 1],
"ctx_dim": 768,
"use_linear": False,
}
class StableDiffusion:
"""Stable Diffusion 1.x pipeline, adapted from tinygrad's reference example.
Drives the native CompVis `sd-v1-*.ckpt` checkpoint format (the only one
the vendored weight layout handles). For HuggingFace safetensors pipelines
the caller is expected to download / merge the `.ckpt` equivalent before
calling LoadModel.
"""
def __init__(self):
self.alphas_cumprod = get_alphas_cumprod()
self.first_stage_model = AutoencoderKL()
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(
transformer=namedtuple("Transformer", ["text_model"])(text_model=Closed.ClipTextTransformer())
)
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(
diffusion_model=UNetModel(**UNET_PARAMS_SD1)
)
# DDIM update step.
def _update(self, x, e_t, a_t, a_prev):
sqrt_one_minus_at = (1 - a_t).sqrt()
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
dir_xt = (1.0 - a_prev).sqrt() * e_t
return a_prev.sqrt() * pred_x0 + dir_xt
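A NumPy mirror of the deterministic (eta = 0) DDIM update in `_update`, useful for checking the algebra: recover the predicted clean latent from the noise estimate, then re-noise it toward the previous alpha level. When `a_prev == 1` the update collapses to `pred_x0`:

```python
import numpy as np

def ddim_update(x, e_t, a_t, a_prev):
    # pred_x0: invert the forward-noising relation x = sqrt(a_t)*x0 + sqrt(1-a_t)*eps
    pred_x0 = (x - np.sqrt(1 - a_t) * e_t) / np.sqrt(a_t)
    # dir_xt: re-inject the same noise estimate at the previous noise level
    dir_xt = np.sqrt(1.0 - a_prev) * e_t
    return np.sqrt(a_prev) * pred_x0 + dir_xt

x = np.array([1.0])
e_t = np.array([0.5])
a_t = 0.25
out = ddim_update(x, e_t, a_t, 1.0)
pred_x0 = (x - np.sqrt(1 - a_t) * e_t) / np.sqrt(a_t)
```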
def _model_output(self, uncond, cond, latent, timestep, guidance):
latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, uncond.cat(cond, dim=0))
uncond_latent, cond_latent = latents[0:1], latents[1:2]
return uncond_latent + guidance * (cond_latent - uncond_latent)
def step(self, uncond, cond, latent, timestep, a_t, a_prev, guidance):
e_t = self._model_output(uncond, cond, latent, timestep, guidance)
return self._update(latent, e_t, a_t, a_prev).realize()
def decode(self, x):
x = self.first_stage_model.post_quant_conv(1 / 0.18215 * x)
x = self.first_stage_model.decoder(x)
x = (x + 1.0) / 2.0
x = x.reshape(3, 512, 512).permute(1, 2, 0).clip(0, 1) * 255
return x.cast(dtypes.uint8)
def encode_prompt(self, tokenizer, prompt: str):
ids = Tensor([tokenizer.encode(prompt)])
return self.cond_stage_model.transformer.text_model(ids).realize()
def run_sd15(model: StableDiffusion, prompt: str, negative_prompt: str, steps: int, guidance: float, seed: int):
"""Generate a single 512x512 image. Returns a (512,512,3) uint8 tensor."""
tokenizer = Tokenizer.ClipTokenizer()
context = model.encode_prompt(tokenizer, prompt)
uncond = model.encode_prompt(tokenizer, negative_prompt)
timesteps = list(range(1, 1000, 1000 // steps))
alphas = model.alphas_cumprod[Tensor(timesteps)]
alphas_prev = Tensor([1.0]).cat(alphas[:-1])
if seed is not None:
Tensor.manual_seed(seed)
latent = Tensor.randn(1, 4, 64, 64)
for index in range(len(timesteps) - 1, -1, -1):
timestep = timesteps[index]
tid = Tensor([index])
latent = model.step(
uncond, context, latent,
Tensor([timestep]),
alphas[tid], alphas_prev[tid],
Tensor([guidance]),
)
return model.decode(latent).realize()
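The timestep selection in `run_sd15` picks `steps` evenly spaced training timesteps out of the 1000-step schedule and walks them in reverse (noisiest first); a quick replay of that indexing:

```python
# Same expressions as in run_sd15, for the default test step count of 4.
steps = 4
timesteps = list(range(1, 1000, 1000 // steps))
# The loop visits index 3, 2, 1, 0 — i.e. timestep 751 down to 1.
visit_order = [timesteps[i] for i in range(len(timesteps) - 1, -1, -1)]
```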

backend/python/tinygrad/vendor/unet.py (vendored, new file)
# Vendored verbatim from tinygrad extra/models/unet.py (MIT license).
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/unet.py
# Copyright (c) 2023- the tinygrad authors
# SPDX-License-Identifier: MIT
from tinygrad import Tensor, dtypes, nn
from tinygrad.device import is_dtype_supported
from typing import Optional, Union, List, Any, Tuple, Callable
import math
# allow for monkeypatching
Linear, Conv2d, GroupNorm, LayerNorm = nn.Linear, nn.Conv2d, nn.GroupNorm, nn.LayerNorm
attention, gelu, mixed_precision_dtype = Tensor.scaled_dot_product_attention, Tensor.gelu, dtypes.float16
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
half = dim // 2
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
out = Tensor.cat(args.cos(), args.sin(), dim=-1)
return out.cast(mixed_precision_dtype) if is_dtype_supported(mixed_precision_dtype) else out
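The sinusoidal embedding above is easy to verify in NumPy; this mirrors the formula (geometric frequency ladder, cos half then sin half) without the dtype cast:

```python
import numpy as np

def timestep_embedding_np(timesteps, dim, max_period=10000):
    # Frequencies fall geometrically from 1 down toward 1/max_period.
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = timesteps[:, None] * freqs[None, :]
    # cos features first, then sin features, matching the vendored code.
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)

emb = timestep_embedding_np(np.array([0.0, 500.0]), 320)
```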
class ResBlock:
def __init__(self, channels:int, emb_channels:int, out_channels:int, num_groups:int=32):
self.in_layers = [
GroupNorm(num_groups, channels),
Tensor.silu,
Conv2d(channels, out_channels, 3, padding=1),
]
self.emb_layers = [
Tensor.silu,
Linear(emb_channels, out_channels),
]
self.out_layers = [
GroupNorm(num_groups, out_channels),
Tensor.silu,
lambda x: x, # needed for weights loading code to work
Conv2d(out_channels, out_channels, 3, padding=1),
]
self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else (lambda x: x)
def __call__(self, x:Tensor, emb:Tensor) -> Tensor:
h = x.sequential(self.in_layers)
emb_out = emb.sequential(self.emb_layers)
h = h + emb_out.reshape(*emb_out.shape, 1, 1)
h = h.sequential(self.out_layers)
return self.skip_connection(x) + h
class CrossAttention:
def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int):
self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False)
self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False)
self.num_heads = n_heads
self.head_size = d_head
self.attn = attention
self.to_out = [Linear(n_heads*d_head, query_dim)]
def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
ctx = x if ctx is None else ctx
q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx)
q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
attention = self.attn(q, k, v).transpose(1,2)
h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
return h_.sequential(self.to_out)
class GEGLU:
def __init__(self, dim_in:int, dim_out:int):
self.proj = Linear(dim_in, dim_out * 2)
self.gelu = gelu
self.dim_out = dim_out
def __call__(self, x:Tensor) -> Tensor:
x, gate = self.proj(x).chunk(2, dim=-1)
return x * self.gelu(gate)
class FeedForward:
def __init__(self, dim:int, mult:int=4):
self.net: tuple[GEGLU, Callable, nn.Linear] = (
GEGLU(dim, dim*mult),
lambda x: x, # needed for weights loading code to work
Linear(dim*mult, dim)
)
def __call__(self, x:Tensor) -> Tensor:
return x.sequential(list(self.net))
class BasicTransformerBlock:
def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
self.ff = FeedForward(dim)
self.attn2 = CrossAttention(dim, ctx_dim, n_heads, d_head)
self.norm1 = LayerNorm(dim)
self.norm2 = LayerNorm(dim)
self.norm3 = LayerNorm(dim)
def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
x = x + self.attn1(self.norm1(x))
x = x + self.attn2(self.norm2(x), ctx=ctx)
x = x + self.ff(self.norm3(x))
return x
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/attention.py#L619
class SpatialTransformer:
def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1,
norm_eps:float=1e-5):
if isinstance(ctx_dim, int):
ctx_dim = [ctx_dim]*depth
else:
assert isinstance(ctx_dim, list) and depth == len(ctx_dim)
self.norm = GroupNorm(32, channels, eps=norm_eps)
assert channels == n_heads * d_head
self.proj_in = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
self.transformer_blocks = [BasicTransformerBlock(channels, ctx_dim[d], n_heads, d_head) for d in range(depth)]
self.proj_out = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
self.use_linear = use_linear
def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
ops = [ (lambda z: z.reshape(b, c, h*w).permute(0,2,1)), (lambda z: self.proj_in(z)) ]
x = x.sequential(ops if self.use_linear else ops[::-1])
for block in self.transformer_blocks:
x = block(x, ctx=ctx)
ops = [ (lambda z: self.proj_out(z)), (lambda z: z.permute(0,2,1).reshape(b, c, h, w)) ]
x = x.sequential(ops if self.use_linear else ops[::-1])
return x + x_in
class Downsample:
def __init__(self, channels:int):
self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
def __call__(self, x:Tensor) -> Tensor:
return self.op(x)
class Upsample:
def __init__(self, channels:int):
self.conv = Conv2d(channels, channels, 3, padding=1)
def __call__(self, x:Tensor) -> Tensor:
bs,c,py,px = x.shape
z = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
return self.conv(z)
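The reshape/expand/reshape in `Upsample` (and in `Decoder.__call__`) is a 2x nearest-neighbour upsample: each pixel is duplicated along both spatial axes before the smoothing conv. A NumPy mirror:

```python
import numpy as np

# 1x1 channel, 2x2 image: [[0, 1], [2, 3]]
x = np.arange(4, dtype=np.float32).reshape(1, 1, 2, 2)
bs, c, py, px = x.shape
# Insert singleton axes after each spatial dim, broadcast them to 2, flatten.
up = np.broadcast_to(
    x.reshape(bs, c, py, 1, px, 1),
    (bs, c, py, 2, px, 2),
).reshape(bs, c, py * 2, px * 2)
```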
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/openaimodel.py#L472
class UNetModel:
def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], num_res_blocks:int,
channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None,
n_heads:Optional[int]=None, num_groups:int=32, st_norm_eps:float=1e-5):
self.model_ch = model_ch
self.num_res_blocks = [num_res_blocks] * len(channel_mult)
self.attention_resolutions = attention_resolutions
self.d_head = d_head
self.n_heads = n_heads
def get_d_and_n_heads(dims:int) -> Tuple[int,int]:
if self.d_head is None:
assert self.n_heads is not None, f"d_head and n_heads cannot both be None"
return dims // self.n_heads, self.n_heads
else:
assert self.n_heads is None, f"d_head and n_heads cannot both be non-None"
return self.d_head, dims // self.d_head
time_embed_dim = model_ch * 4
self.time_embed = [
Linear(model_ch, time_embed_dim),
Tensor.silu,
Linear(time_embed_dim, time_embed_dim),
]
if adm_in_ch is not None:
self.label_emb = [
[
Linear(adm_in_ch, time_embed_dim),
Tensor.silu,
Linear(time_embed_dim, time_embed_dim),
]
]
self.input_blocks: List[Any] = [
[Conv2d(in_ch, model_ch, 3, padding=1)]
]
input_block_channels = [model_ch]
ch = model_ch
ds = 1
for idx, mult in enumerate(channel_mult):
for _ in range(self.num_res_blocks[idx]):
layers: List[Any] = [
ResBlock(ch, time_embed_dim, model_ch*mult, num_groups),
]
ch = mult * model_ch
if ds in attention_resolutions:
d_head, n_heads = get_d_and_n_heads(ch)
layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))
self.input_blocks.append(layers)
input_block_channels.append(ch)
if idx != len(channel_mult) - 1:
self.input_blocks.append([
Downsample(ch),
])
input_block_channels.append(ch)
ds *= 2
d_head, n_heads = get_d_and_n_heads(ch)
self.middle_block: List = [
ResBlock(ch, time_embed_dim, ch, num_groups),
SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1], norm_eps=st_norm_eps),
ResBlock(ch, time_embed_dim, ch, num_groups),
]
self.output_blocks = []
for idx, mult in list(enumerate(channel_mult))[::-1]:
for i in range(self.num_res_blocks[idx] + 1):
ich = input_block_channels.pop()
layers = [
ResBlock(ch + ich, time_embed_dim, model_ch*mult, num_groups),
]
ch = model_ch * mult
if ds in attention_resolutions:
d_head, n_heads = get_d_and_n_heads(ch)
layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))
if idx > 0 and i == self.num_res_blocks[idx]:
layers.append(Upsample(ch))
ds //= 2
self.output_blocks.append(layers)
self.out = [
GroupNorm(num_groups, ch),
Tensor.silu,
Conv2d(model_ch, out_ch, 3, padding=1),
]
def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Optional[Tensor]=None) -> Tensor:
t_emb = timestep_embedding(tms, self.model_ch)
emb = t_emb.sequential(self.time_embed)
if y is not None:
assert y.shape[0] == x.shape[0]
emb = emb + y.sequential(self.label_emb[0])
if is_dtype_supported(mixed_precision_dtype):
emb = emb.cast(mixed_precision_dtype)
ctx = ctx.cast(mixed_precision_dtype)
x = x .cast(mixed_precision_dtype)
def run(x:Tensor, bb) -> Tensor:
if isinstance(bb, ResBlock): x = bb(x, emb)
elif isinstance(bb, SpatialTransformer): x = bb(x, ctx)
else: x = bb(x)
return x
saved_inputs = []
for b in self.input_blocks:
for bb in b:
x = run(x, bb)
saved_inputs.append(x)
for bb in self.middle_block:
x = run(x, bb)
for b in self.output_blocks:
x = x.cat(saved_inputs.pop(), dim=1)
for bb in b:
x = run(x, bb)
return x.sequential(self.out)
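The skip-connection bookkeeping in `UNetModel.__init__` can be replayed in plain Python for the SD1.x hyperparameters above (`model_ch=320`, `channel_mult=[1,2,4,4]`, `num_res_blocks=2`): the encoder pushes one channel count per input block, and the decoder pops them in reverse for the `cat(..., dim=1)` skips.

```python
# Replays only the channel arithmetic from the constructor; no layers built.
model_ch, channel_mult, num_res_blocks = 320, [1, 2, 4, 4], 2
input_block_channels, ch = [model_ch], model_ch
for idx, mult in enumerate(channel_mult):
    for _ in range(num_res_blocks):
        ch = mult * model_ch
        input_block_channels.append(ch)   # one per ResBlock input block
    if idx != len(channel_mult) - 1:
        input_block_channels.append(ch)   # one per Downsample block
```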

backend/python/tinygrad/vendor/whisper.py (vendored, new file)
# Adapted from tinygrad examples/whisper.py (MIT license).
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/whisper.py
# Copyright (c) 2023- the tinygrad authors
# SPDX-License-Identifier: MIT
#
# Local modifications: removed the pyaudio listener / __main__ block; the rest
# is the core Whisper model + preprocessing + single-file transcription path.
from __future__ import annotations
import base64
import collections
import itertools
from typing import List, Literal, Optional, Union
import numpy as np
from tinygrad import Tensor, TinyJit, Variable, dtypes, nn
from tinygrad.helpers import fetch
from tinygrad.nn.state import load_state_dict, torch_load
from .audio_helpers import mel
class MultiHeadAttention:
def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self', None] = None, max_self_attn_cache_len=None):
self.n_head = n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)
self.kv_caching = kv_caching
self.max_self_attn_cache_len = max_self_attn_cache_len
def __call__(self, x, xa=None, mask=None, len=None):
if self.kv_caching == 'cross':
if xa is not None:
k, v = self.key(xa), self.value(xa)
if not hasattr(self, 'cache_k'):
self.cache_k, self.cache_v = k, v
else:
self.cache_k.assign(k).realize()
self.cache_v.assign(v).realize()
else:
k, v = self.cache_k, self.cache_v
else:
k, v = self.key(x), self.value(x)
if self.kv_caching == 'self':
if not hasattr(self, 'cache_k'):
self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
padding = self.max_self_attn_cache_len - len - x.shape[1]
self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()
q = self.query(x)
n_ctx = q.shape[1]
head_dim = q.shape[-1] // self.n_head
q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx, :n_ctx] if mask is not None else None)
wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
return self.out(wv)
class ResidualAttentionBlock:
def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
self.attn_ln = nn.LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
self.mlp = [nn.Linear(n_state, n_state * 4), Tensor.gelu, nn.Linear(n_state * 4, n_state)]
self.mlp_ln = nn.LayerNorm(n_state)
def __call__(self, x, xa=None, mask=None, len=None):
x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa)
x = x + self.mlp_ln(x).sequential(self.mlp)
return x.realize()
class AudioEncoder:
def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1)
self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
self.ln_post = nn.LayerNorm(n_audio_state)
self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
self.encode = TinyJit(self.__call__)
def __call__(self, x):
x = self.conv1(x).gelu()
x = self.conv2(x).gelu()
x = x.permute(0, 2, 1)
x = x + self.positional_embedding[:x.shape[1]]
x = x.sequential(self.blocks)
x = self.ln_post(x)
return x.realize()
class TextDecoder:
def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
self.max_tokens_to_sample = n_text_ctx // 2
self.max_self_attn_cache_len = n_text_ctx
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
self.ln = nn.LayerNorm(n_text_state)
self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
self.getjitted = collections.defaultdict(lambda: TinyJit(self.forward))
def __call__(self, x, pos, encoded_audio):
pos = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len - 1).bind(pos) if pos else 0
return self.getjitted[x.shape](x, pos, encoded_audio)
def forward(self, x, pos, encoded_audio):
seqlen = x.shape[-1]
x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos + seqlen), None))
for block in self.blocks:
x = block(x, xa=encoded_audio, mask=self.mask, len=pos)
return self.output_tok(x)
def output_tok(self, x):
return (self.ln(x) @ self.token_embedding.weight.T).realize()
class Whisper:
def __init__(self, dims, batch_size=1):
self.encoder = AudioEncoder(**dims)
self.decoder = TextDecoder(**dims)
self.is_multilingual = dims["n_vocab"] == 51865
self.batch_size = batch_size
RATE = 16000
SEGMENT_SECONDS = 30
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH
def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate: bool = False) -> np.ndarray:
import librosa
def pad_or_trim(arr, target_len):
if len(arr) == target_len:
return arr
if len(arr) < target_len:
return np.pad(arr, (0, target_len - len(arr)), 'constant')
return arr[:target_len]
max_len = SAMPLES_PER_SEGMENT if truncate else max(len(w) for w in waveforms)
if (r := max_len % SAMPLES_PER_SEGMENT) > 0:
max_len += SAMPLES_PER_SEGMENT - r
waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
if waveforms.shape[0] < batch_size:
waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))
stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
magnitudes = np.absolute(stft[..., :-1]) ** 2
mel_spec = mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS).numpy() @ magnitudes
log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
log_spec = np.maximum(log_spec, log_spec.max((1, 2), keepdims=True) - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
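The frame bookkeeping behind `prep_audio`: 30 s of 16 kHz audio is one segment of 480000 samples, and a hop of 160 samples yields 3000 mel frames per segment (the `[..., :-1]` drops librosa's trailing STFT frame). Clips are padded up to the next whole segment:

```python
# Same constants as the module-level ones above.
RATE, SEGMENT_SECONDS, HOP_LENGTH = 16000, 30, 160
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH

# A 70 s clip is padded up to the next segment boundary (90 s).
clip_len = 70 * RATE
if (r := clip_len % SAMPLES_PER_SEGMENT) > 0:
    clip_len += SAMPLES_PER_SEGMENT - r
```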
LANGUAGES = {
"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean",
"fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish", "pl": "polish", "it": "italian",
}
def get_encoding(encoding_name: str):
import tiktoken
with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f:
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
n_vocab = len(ranks)
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
]
special_tokens = dict(zip(specials, itertools.count(n_vocab)))
return tiktoken.Encoding(
name=encoding_name,
explicit_n_vocab=n_vocab + len(specials),
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
mergeable_ranks=ranks,
special_tokens=special_tokens,
)
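Because `get_encoding` assigns special-token ids sequentially after the BPE ranks, each language token sits at a fixed offset from `<|startoftranscript|>` — which is exactly what `transcribe_waveform` exploits when it computes `language_token`. A small sketch of that layout (the `n_vocab` value and the trimmed language list here are stand-ins, not the real vocabulary size):

```python
import itertools

LANGUAGES = {"en": "english", "zh": "chinese", "de": "german"}
specials = ["<|endoftext|>", "<|startoftranscript|>",
            *[f"<|{lang}|>" for lang in LANGUAGES]]
n_vocab = 50000  # hypothetical rank count, for illustration only
special_tokens = dict(zip(specials, itertools.count(n_vocab)))

sot = special_tokens["<|startoftranscript|>"]
# The id of a language token is sot + 1 + its position in LANGUAGES.
de_token = sot + 1 + tuple(LANGUAGES).index("de")
```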
MODEL_URLS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
}
def init_whisper(model_name: str = "base", batch_size: int = 1):
filename = fetch(MODEL_URLS[model_name])
state = torch_load(filename)
model = Whisper(state['dims'], batch_size)
load_state_dict(model, state['model_state_dict'], strict=False)
enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
return model, enc
def load_file_waveform(filename: str):
import librosa
waveform, _ = librosa.load(filename, sr=RATE)
return waveform
def transcribe_waveform(model: Whisper, enc, waveforms, language: Optional[str] = None, truncate: bool = False) -> str:
log_spec = prep_audio(waveforms, model.batch_size, truncate)
nsample = model.decoder.max_tokens_to_sample
nctx = model.decoder.max_self_attn_cache_len
start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
if model.is_multilingual:
lang = language if (language and language in LANGUAGES) else "en"
language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index(lang)
start_tokens.append(language_token)
start_tokens.append(enc._special_tokens["<|transcribe|>"])
start_tokens.append(enc._special_tokens["<|notimestamps|>"])
eot = enc._special_tokens["<|endoftext|>"]
def inferloop(ctx, encoded_audio):
pos, next_tokens = 0, ctx
for _ in range(nsample):
next_tokens = model.decoder(Tensor(next_tokens, dtype=dtypes.int32), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1)
next_tokens[ctx[:, -1] == eot] = eot
ctx = np.concatenate((ctx, next_tokens), axis=1)
pos = ctx.shape[-1] - 1
if (next_tokens == eot).all() or pos == nctx:
break
return ctx
ctx = np.tile(start_tokens, (model.batch_size, 1))
transcriptions: list[list[int]] = [[] for _ in waveforms]
for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
ctx_arr = inferloop(np.array(ctx), encoded_audio)
for i, arr in enumerate(ctx_arr):
if i >= len(waveforms):
break
end_idxs = np.where(arr == eot)[0]
start_idx = np.where(arr == start_tokens[-1])[0][0] + 1
end_idx = end_idxs[0] if len(end_idxs) else None
transcriptions[i].extend(arr[start_idx:end_idx])
ctx = ctx_arr
texts = [enc.decode([int(t) for t in toks]).strip() for toks in transcriptions]
return texts[0] if len(texts) == 1 else "\n".join(texts)


@@ -44,13 +44,20 @@ import (
// BACKEND_TEST_AUDIO_FILE Path to an already-available sample audio file.
// BACKEND_TEST_CAPS Comma-separated list of capabilities to exercise.
// Supported values: health, load, predict, stream,
// embeddings, tools, transcription, image.
// Defaults to "health,load,predict,stream".
// A backend that only does embeddings would set this to
// "health,load,embeddings"; an image-generation backend
// that cannot be driven by a text prompt can set it to
// "health,load,image".
// "tools" asks the backend to extract a tool call from the
// model output into ChatDelta.tool_calls.
// "image" exercises the GenerateImage RPC and asserts a
// non-empty file is written to the requested dst path.
// BACKEND_TEST_IMAGE_PROMPT Override the positive prompt for the image spec
// (default: "a photograph of an astronaut riding a horse").
// BACKEND_TEST_IMAGE_STEPS Override the diffusion step count for the image spec
// (default: 4 — keeps CPU-only runs under a few minutes).
// BACKEND_TEST_PROMPT Override the prompt used by predict/stream specs.
// BACKEND_TEST_CTX_SIZE Override the context size passed to LoadModel (default 512).
// BACKEND_TEST_THREADS Override Threads passed to LoadModel (default 4).
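The `BACKEND_TEST_CAPS` semantics described above can be sketched as follows (in Python for brevity — the actual Go helper is not shown in this hunk): an empty value means the default set, otherwise a comma-separated allow-list, and each spec skips itself unless its capability is present.

```python
def parse_caps(env_value: str) -> set[str]:
    # Empty or whitespace-only value falls back to the documented default.
    default = {"health", "load", "predict", "stream"}
    if not env_value.strip():
        return default
    return {c.strip() for c in env_value.split(",") if c.strip()}

caps = parse_caps("health,load,embeddings")
```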
@@ -75,11 +82,14 @@ const (
capEmbeddings = "embeddings"
capTools = "tools"
capTranscription = "transcription"
capImage = "image"
defaultPrompt = "The capital of France is"
streamPrompt = "Once upon a time"
defaultToolPrompt = "What's the weather like in Paris, France?"
defaultToolName = "get_weather"
defaultImagePrompt = "a photograph of an astronaut riding a horse"
defaultImageSteps = 4
)
func defaultCaps() map[string]bool {
@@ -349,6 +359,40 @@ var _ = Describe("Backend container", Ordered, func() {
GinkgoWriter.Printf("Embedding: %d dims\n", len(res.GetEmbeddings()))
})
It("generates an image via GenerateImage", func() {
if !caps[capImage] {
Skip("image capability not enabled")
}
imgPrompt := os.Getenv("BACKEND_TEST_IMAGE_PROMPT")
if imgPrompt == "" {
imgPrompt = defaultImagePrompt
}
steps := envInt32("BACKEND_TEST_IMAGE_STEPS", defaultImageSteps)
dst := filepath.Join(workDir, "generated.png")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
res, err := client.GenerateImage(ctx, &pb.GenerateImageRequest{
PositivePrompt: imgPrompt,
NegativePrompt: "",
Width: 512,
Height: 512,
Step: steps,
Seed: 42,
Dst: dst,
})
Expect(err).NotTo(HaveOccurred())
Expect(res.GetSuccess()).To(BeTrue(), "GenerateImage failed: %s", res.GetMessage())
info, err := os.Stat(dst)
Expect(err).NotTo(HaveOccurred(), "GenerateImage did not write a file at %s", dst)
Expect(info.Size()).To(BeNumerically(">", int64(0)),
"GenerateImage wrote an empty file at %s", dst)
GinkgoWriter.Printf("GenerateImage: wrote %s (%d bytes)\n", dst, info.Size())
})
It("extracts tool calls into ChatDelta", func() {
if !caps[capTools] {
Skip("tools capability not enabled")