diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml
index 7bc063578..68953d2f9 100644
--- a/.github/workflows/backend.yml
+++ b/.github/workflows/backend.yml
@@ -105,6 +105,25 @@ jobs:
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
+ # tinygrad ships a single image — its CPU device uses bundled
+ # libLLVM, and its CUDA / HIP / Metal devices dlopen the host
+ # driver libraries at runtime via tinygrad's ctypes autogen
+ # wrappers. There is no toolkit-version split because tinygrad
+ # generates kernels itself (PTX renderer for CUDA) and never
+ # links against cuDNN/cuBLAS/torch.
+ - build-type: ''
+ cuda-major-version: ""
+ cuda-minor-version: ""
+ platforms: 'linux/amd64'
+ tag-latest: 'auto'
+ tag-suffix: '-tinygrad'
+ runs-on: 'ubuntu-latest'
+ base-image: "ubuntu:24.04"
+ skip-drivers: 'true'
+ backend: "tinygrad"
+ dockerfile: "./backend/Dockerfile.python"
+ context: "./"
+ ubuntu-version: '2404'
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
diff --git a/Makefile b/Makefile
index 133133086..61e51aad4 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
# Disable parallel execution for backend builds
-.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp
+.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/tinygrad
GOCMD=go
GOTEST=$(GOCMD) test
@@ -432,6 +432,7 @@ prepare-test-extra: protogen-python
$(MAKE) -C backend/python/whisperx
$(MAKE) -C backend/python/ace-step
$(MAKE) -C backend/python/trl
+ $(MAKE) -C backend/python/tinygrad
$(MAKE) -C backend/rust/kokoros kokoros-grpc
test-extra: prepare-test-extra
@@ -454,6 +455,7 @@ test-extra: prepare-test-extra
$(MAKE) -C backend/python/whisperx test
$(MAKE) -C backend/python/ace-step test
$(MAKE) -C backend/python/trl test
+ $(MAKE) -C backend/python/tinygrad test
$(MAKE) -C backend/rust/kokoros test
##
@@ -550,6 +552,56 @@ test-extra-backend-vllm: docker-build-vllm
BACKEND_TEST_OPTIONS=tool_parser:hermes \
$(MAKE) test-extra-backend
+## tinygrad mirrors the vllm target (same model, same caps, same parser) so
+## the two backends are directly comparable. The LLM path covers Predict,
+## streaming and native tool-call extraction. Companion targets below cover
+## embeddings, Stable Diffusion and Whisper — run them individually or via
+## the `test-extra-backend-tinygrad-all` aggregate.
+test-extra-backend-tinygrad: docker-build-tinygrad
+ BACKEND_IMAGE=local-ai-backend:tinygrad \
+ BACKEND_TEST_MODEL_NAME=Qwen/Qwen2.5-0.5B-Instruct \
+ BACKEND_TEST_CAPS=health,load,predict,stream,tools \
+ BACKEND_TEST_OPTIONS=tool_parser:hermes \
+ $(MAKE) test-extra-backend
+
+## tinygrad — embeddings via LLM last-hidden-state pooling. Reuses the same
+## Qwen2.5-0.5B-Instruct as the chat target so we don't need a separate BERT
+## vendor; the Embedding RPC mean-pools and L2-normalizes the last-layer
+## hidden state.
+test-extra-backend-tinygrad-embeddings: docker-build-tinygrad
+ BACKEND_IMAGE=local-ai-backend:tinygrad \
+ BACKEND_TEST_MODEL_NAME=Qwen/Qwen2.5-0.5B-Instruct \
+ BACKEND_TEST_CAPS=health,load,embeddings \
+ $(MAKE) test-extra-backend
+
+## tinygrad — Stable Diffusion 1.5. The original CompVis/runwayml repos have
+## been gated, so we use the community-maintained mirror at
+## stable-diffusion-v1-5/stable-diffusion-v1-5 with the EMA-only pruned
+## checkpoint (~4.3GB). Step count is kept low (4) so a CPU-only run finishes
+## in a few minutes; bump BACKEND_TEST_IMAGE_STEPS for higher quality.
+test-extra-backend-tinygrad-sd: docker-build-tinygrad
+ BACKEND_IMAGE=local-ai-backend:tinygrad \
+ BACKEND_TEST_MODEL_NAME=stable-diffusion-v1-5/stable-diffusion-v1-5 \
+ BACKEND_TEST_CAPS=health,load,image \
+ $(MAKE) test-extra-backend
+
+## tinygrad — Whisper. Loads OpenAI's tiny.en checkpoint (smallest at ~75MB)
+## from the original Azure CDN through tinygrad's `fetch` helper, and
+## transcribes the canonical jfk.wav fixture from whisper.cpp's CI samples.
+## Exercises both AudioTranscription and AudioTranscriptionStream.
+test-extra-backend-tinygrad-whisper: docker-build-tinygrad
+ BACKEND_IMAGE=local-ai-backend:tinygrad \
+ BACKEND_TEST_MODEL_NAME=openai/whisper-tiny.en \
+ BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
+ BACKEND_TEST_CAPS=health,load,transcription \
+ $(MAKE) test-extra-backend
+
+test-extra-backend-tinygrad-all: \
+ test-extra-backend-tinygrad \
+ test-extra-backend-tinygrad-embeddings \
+ test-extra-backend-tinygrad-sd \
+ test-extra-backend-tinygrad-whisper
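+
+## To run the whole tinygrad matrix locally (assumes a working Docker daemon,
+## CPU-only is fine; each target builds the image first):
+##   make test-extra-backend-tinygrad-all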
+
## mlx is Apple-Silicon-first — the MLX backend auto-detects the right tool
## parser from the chat template, so no tool_parser: option is needed (it
## would be ignored at runtime). Run this on macOS / arm64 with Metal; the
@@ -707,6 +759,7 @@ BACKEND_MLX_VLM = mlx-vlm|python|.|false|true
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
BACKEND_TRL = trl|python|.|false|true
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
+BACKEND_TINYGRAD = tinygrad|python|.|false|true
# Rust backends
BACKEND_KOKOROS = kokoros|rust|.|false|true
@@ -778,6 +831,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_MLX_VLM)))
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
+$(eval $(call generate-docker-build-target,$(BACKEND_TINYGRAD)))
$(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS)))
$(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
@@ -785,7 +839,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
docker-save-%: backend-images
docker save local-ai-backend:$* -o backend-images/$*.tar
-docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp
+docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp
########################################################
### Mock Backend for E2E Tests
diff --git a/backend/index.yaml b/backend/index.yaml
index c29ca5fb8..f7dc72251 100644
--- a/backend/index.yaml
+++ b/backend/index.yaml
@@ -361,6 +361,34 @@
intel: "intel-rerankers"
amd: "rocm-rerankers"
metal: "metal-rerankers"
+- &tinygrad
+ name: "tinygrad"
+ alias: "tinygrad"
+ license: MIT
+ description: |
+ tinygrad is a minimalist deep-learning framework with zero runtime
+ dependencies that targets CUDA, ROCm, Metal, WebGPU and CPU (CLANG).
+ The LocalAI tinygrad backend exposes a single multimodal runtime that
+ covers LLM text generation (Llama / Qwen / Mistral via safetensors or
+ GGUF) with native tool-call extraction, BERT-family embeddings,
+ Stable Diffusion 1.x / 2 / XL image generation, and Whisper speech-to-text.
+
+ Single image: tinygrad generates its own GPU kernels and dlopens the
+ host driver libraries at runtime, so there is no per-toolkit build
+ split. The same image runs CPU-only or accelerates against
+ CUDA / ROCm / Metal when the host driver is visible.
+ urls:
+ - https://github.com/tinygrad/tinygrad
+ uri: "quay.io/go-skynet/local-ai-backends:latest-tinygrad"
+ mirrors:
+ - localai/localai-backends:latest-tinygrad
+ tags:
+ - text-to-text
+ - LLM
+ - embeddings
+ - image-generation
+ - transcription
+ - multimodal
- &transformers
name: "transformers"
icon: https://avatars.githubusercontent.com/u/25720743?s=200&v=4
@@ -1993,6 +2021,15 @@
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-rerankers"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-rerankers
+## tinygrad
+## Single image — the meta anchor above carries the latest uri directly
+## since there is only one variant. The development entry below points at
+## the master tag.
+- !!merge <<: *tinygrad
+ name: "tinygrad-development"
+ uri: "quay.io/go-skynet/local-ai-backends:master-tinygrad"
+ mirrors:
+ - localai/localai-backends:master-tinygrad
## Transformers
- !!merge <<: *transformers
name: "transformers-development"
diff --git a/backend/python/tinygrad/Makefile b/backend/python/tinygrad/Makefile
new file mode 100644
index 000000000..4f7c209a2
--- /dev/null
+++ b/backend/python/tinygrad/Makefile
@@ -0,0 +1,25 @@
+.DEFAULT_GOAL := install
+
+.PHONY: install
+install:
+ bash install.sh
+
+.PHONY: run
+run: install
+ @echo "Running tinygrad..."
+ bash run.sh
+ @echo "tinygrad run."
+
+.PHONY: test
+test: install
+ @echo "Testing tinygrad..."
+ bash test.sh
+ @echo "tinygrad tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ rm -rf venv __pycache__
diff --git a/backend/python/tinygrad/backend.py b/backend/python/tinygrad/backend.py
new file mode 100644
index 000000000..7fa997f47
--- /dev/null
+++ b/backend/python/tinygrad/backend.py
@@ -0,0 +1,750 @@
+#!/usr/bin/env python3
+"""
+LocalAI gRPC backend for tinygrad.
+
+Multimodal scope (landing incrementally):
+ - LLM text generation (Llama 3 / Qwen 2.5 / Mistral via HF safetensors + GGUF)
+ - Native tool-call extraction via pluggable parsers (hermes, llama3_json, ...)
+ - Embeddings (last-hidden-state pooling), Stable Diffusion 1.x image generation, and Whisper speech-to-text.
+
+The heavy imports (tinygrad, tokenizers, vendor.llama) are deferred until
+`LoadModel`, because tinygrad binds its compute device at import time from
+env vars. Device selection normally happens in run.sh (which exports the flag
+matching the driver libraries visible in the container); `_select_tinygrad_device()`
+below is only a fallback that defaults to CLANG when no flag is set, e.g. for
+direct invocations like the unit tests.
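+
+A model definition selects this backend with the usual LocalAI YAML; a sketch
+(field values are illustrative, the option keys are the ones parsed below):
+
+    name: qwen2.5-0.5b
+    backend: tinygrad
+    parameters:
+      model: Qwen/Qwen2.5-0.5B-Instruct
+    options:
+      - tool_parser:hermes
+      - model_type:llm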
+"""
+from __future__ import annotations
+
+import argparse
+import asyncio
+import json
+import os
+import signal
+import sys
+import time
+from concurrent import futures
+from pathlib import Path
+from typing import Any, Optional
+
+import grpc
+
+import backend_pb2
+import backend_pb2_grpc
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
+from grpc_auth import get_auth_interceptors # noqa: E402
+
+from tool_parsers import resolve_parser # noqa: E402
+from tool_parsers.base import ToolCall # noqa: E402
+
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+
+# ---------------------------------------------------------------------------
+# Device selection — must run BEFORE `import tinygrad` anywhere.
+#
+# In production this is set by run.sh based on which driver libraries the
+# host has injected into the container (libcuda.so.1 → CUDA, libamdhip64
+# → HIP, otherwise CLANG). This helper is only a fallback for direct
+# invocations like the unit tests.
+# ---------------------------------------------------------------------------
+
+def _select_tinygrad_device() -> None:
+ if any(os.environ.get(k) == "1" for k in ("CUDA", "HIP", "METAL", "CLANG", "AMD", "NV")):
+ return
+ os.environ["CLANG"] = "1"
+
+
+# ---------------------------------------------------------------------------
+# Model asset discovery
+# ---------------------------------------------------------------------------
+
+def _resolve_model_assets(model_ref: str) -> Path:
+ """
+ Accept either a local path or a HuggingFace repo id (e.g.
+ "Qwen/Qwen2.5-0.5B-Instruct") and return the local directory / file.
+ HF ids are materialized via `huggingface_hub.snapshot_download`.
+ """
+ p = Path(model_ref)
+ if p.exists():
+ return p
+ if "/" in model_ref and not model_ref.startswith(("/", ".")):
+ from huggingface_hub import snapshot_download
+ local = snapshot_download(
+ repo_id=model_ref,
+ allow_patterns=[
+ "config.json",
+ "tokenizer.json",
+ "tokenizer_config.json",
+ "special_tokens_map.json",
+ "generation_config.json",
+ "*.safetensors",
+ "*.safetensors.index.json",
+ ],
+ )
+ return Path(local)
+ raise FileNotFoundError(f"Model not found: {model_ref}")
+
+
+def _load_llm_weights(model_dir: Path):
+ """Load HF safetensors weights from a directory or a single file."""
+ from tinygrad import Device, Tensor, dtypes
+ from tinygrad.nn.state import safe_load, gguf_load
+
+ if model_dir.is_file():
+ if str(model_dir).endswith(".gguf"):
+ gguf_tensor = Tensor.empty(os.stat(model_dir).st_size, dtype=dtypes.uint8,
+ device=f"disk:{model_dir}").to(Device.DEFAULT)
+ return gguf_load(gguf_tensor)[1], "gguf"
+ return safe_load(str(model_dir)), "safetensors"
+
+ index = model_dir / "model.safetensors.index.json"
+ if index.exists():
+ with open(index) as fp:
+ weight_map = json.load(fp)["weight_map"]
+ shards: dict[str, Any] = {}
+ for shard_name in set(weight_map.values()):
+ shards[shard_name] = safe_load(str(model_dir / shard_name))
+ merged = {k: shards[n][k] for k, n in weight_map.items()}
+ return merged, "safetensors"
+
+ single = model_dir / "model.safetensors"
+ if single.exists():
+ return safe_load(str(single)), "safetensors"
+
+ ggufs = sorted(model_dir.glob("*.gguf"))
+ if ggufs:
+ gguf_tensor = Tensor.empty(os.stat(ggufs[0]).st_size, dtype=dtypes.uint8,
+ device=f"disk:{ggufs[0]}").to(Device.DEFAULT)
+ return gguf_load(gguf_tensor)[1], "gguf"
+
+ raise FileNotFoundError(f"No safetensors or gguf weights found under {model_dir}")
+
+
+def _infer_qkv_bias(config: dict) -> bool:
+ """Qwen2-family uses Q/K/V projection bias; Llama/Mistral do not."""
+ architectures = [a.lower() for a in config.get("architectures", [])]
+ if any("qwen" in a for a in architectures):
+ return True
+ return bool(config.get("attention_bias", False)) or bool(config.get("qkv_bias", False))
+
+
+def _auto_tool_parser(model_ref: Optional[str], config: dict) -> Optional[str]:
+ """Pick a tool parser automatically from model family heuristics.
+
+ Order of precedence: architecture name from config.json, then model ref
+ string. Returns None to fall through to the passthrough parser.
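+
+ For example, given the heuristics below: "Qwen/Qwen2.5-0.5B-Instruct" maps
+ to "hermes", "meta-llama/Llama-3.2-1B-Instruct" maps to "llama3_json", and
+ an unrecognized ref returns None (passthrough).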
+ """
+ arches = " ".join(a.lower() for a in config.get("architectures", []))
+ ref = (model_ref or "").lower()
+ blob = f"{arches} {ref}"
+
+ if "qwen3" in blob:
+ return "qwen3_xml"
+ if "hermes" in blob or "qwen2" in blob or "qwen" in blob:
+ return "hermes"
+ if "llama-3" in blob or "llama_3" in blob or "llama3" in blob:
+ return "llama3_json"
+ if "mistral" in blob or "mixtral" in blob:
+ return "mistral"
+ return None
+
+
+# ---------------------------------------------------------------------------
+# Servicer
+# ---------------------------------------------------------------------------
+
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ """gRPC servicer for the tinygrad backend."""
+
+ def __init__(self) -> None:
+ self._reset_state()
+
+ def _reset_state(self) -> None:
+ self.model_ref: Optional[str] = None
+ self.model_type: str = "llm"
+ self.options: dict[str, str] = {}
+ # LLM state
+ self.llm_model = None
+ self.llm_config: dict = {}
+ self.llm_tokenizer = None
+ self.llm_eos_ids: list[int] = []
+ self.chat_template: Optional[str] = None
+ self.tool_parser = resolve_parser(None)
+ self.max_context = 4096
+ # Stable Diffusion state
+ self.sd_model = None
+ # Whisper state
+ self.whisper_model = None
+ self.whisper_tokenizer = None
+
+ # --------------------- helpers --------------------------------------
+
+ @staticmethod
+ def _parse_options(options_list) -> dict[str, str]:
+ opts: dict[str, str] = {}
+ for opt in options_list:
+ if ":" not in opt:
+ continue
+ key, value = opt.split(":", 1)
+ opts[key.strip()] = value.strip()
+ return opts
+
+ @staticmethod
+ def _detect_model_type(model_ref: str, explicit: Optional[str]) -> str:
+ if explicit:
+ return explicit
+ name = (model_ref or "").lower()
+ if "whisper" in name:
+ return "whisper"
+ if "sdxl" in name:
+ return "sdxl"
+ if "sd-v1" in name or "v1-5" in name or "stable-diffusion" in name:
+ return "sd15"
+ if any(tag in name for tag in ("bge", "e5", "minilm", "bert")):
+ return "bert"
+ return "llm"
+
+ def _messages_to_dicts(self, messages) -> list[dict]:
+ result = []
+ for msg in messages:
+ d: dict = {"role": msg.role, "content": msg.content or ""}
+ if msg.name:
+ d["name"] = msg.name
+ if msg.tool_call_id:
+ d["tool_call_id"] = msg.tool_call_id
+ if msg.reasoning_content:
+ d["reasoning_content"] = msg.reasoning_content
+ if msg.tool_calls:
+ try:
+ d["tool_calls"] = json.loads(msg.tool_calls)
+ except json.JSONDecodeError:
+ pass
+ result.append(d)
+ return result
+
+ def _render_prompt(self, request) -> str:
+ """Render messages + tools into the model's chat template, or fall
+ back to the raw Prompt field for models without a template."""
+ if not request.Messages and request.Prompt:
+ return request.Prompt
+
+ if not self.chat_template:
+ # No template known — concatenate role/content lines.
+ lines = []
+ for msg in request.Messages:
+ lines.append(f"{msg.role}: {msg.content or ''}")
+ return "\n".join(lines) + "\nassistant:"
+
+ from jinja2 import Environment
+
+ env = Environment(trim_blocks=True, lstrip_blocks=True)
+ template = env.from_string(self.chat_template)
+
+ tools = None
+ if request.Tools:
+ try:
+ tools = json.loads(request.Tools)
+ except json.JSONDecodeError:
+ tools = None
+
+ return template.render(
+ messages=self._messages_to_dicts(request.Messages),
+ tools=tools,
+ add_generation_prompt=True,
+ )
+
+ # --------------------- LLM path -------------------------------------
+
+ def _load_llm(self, model_dir: Path) -> None:
+ config_path = model_dir / "config.json" if model_dir.is_dir() else None
+ if config_path is None or not config_path.exists():
+ raise FileNotFoundError(f"config.json not found under {model_dir}")
+ with open(config_path) as fp:
+ config = json.load(fp)
+ self.llm_config = config
+
+ from tinygrad import nn
+
+ from vendor.llama import Transformer, convert_from_huggingface, convert_from_gguf, fix_bf16
+
+ n_layers = config["num_hidden_layers"]
+ n_heads = config["num_attention_heads"]
+ n_kv_heads = config.get("num_key_value_heads", n_heads)
+ dim = config["hidden_size"]
+ hidden_dim = config["intermediate_size"]
+ vocab_size = config["vocab_size"]
+ norm_eps = config.get("rms_norm_eps", 1e-5)
+ rope_theta = config.get("rope_theta", 10000)
+ self.max_context = min(config.get("max_position_embeddings", 4096), 8192)
+ qkv_bias = _infer_qkv_bias(config)
+
+ weights, weight_format = _load_llm_weights(model_dir)
+
+ model = Transformer(
+ dim=dim,
+ hidden_dim=hidden_dim,
+ n_heads=n_heads,
+ n_layers=n_layers,
+ norm_eps=norm_eps,
+ vocab_size=vocab_size,
+ linear=nn.Linear,
+ embedding=nn.Embedding,
+ n_kv_heads=n_kv_heads,
+ rope_theta=rope_theta,
+ max_context=self.max_context,
+ jit=True,
+ qkv_bias=qkv_bias,
+ )
+
+ if weight_format == "safetensors":
+ weights = convert_from_huggingface(weights, n_layers, n_heads, n_kv_heads)
+ else:
+ weights = convert_from_gguf(weights, n_layers)
+ weights = fix_bf16(weights)
+
+ from tinygrad.nn.state import load_state_dict
+ load_state_dict(model, weights, strict=False, consume=True)
+ self.llm_model = model
+
+ # Tokenizer
+ tokenizer_json = model_dir / "tokenizer.json"
+ if not tokenizer_json.exists():
+ raise FileNotFoundError(f"tokenizer.json not found under {model_dir}")
+ from tokenizers import Tokenizer as HFTokenizer
+ self.llm_tokenizer = HFTokenizer.from_file(str(tokenizer_json))
+
+ # Chat template + EOS ids (from tokenizer_config.json + generation_config.json)
+ tok_cfg_path = model_dir / "tokenizer_config.json"
+ if tok_cfg_path.exists():
+ with open(tok_cfg_path) as fp:
+ tok_cfg = json.load(fp)
+ self.chat_template = tok_cfg.get("chat_template")
+
+ self.llm_eos_ids = []
+ gen_cfg_path = model_dir / "generation_config.json"
+ if gen_cfg_path.exists():
+ with open(gen_cfg_path) as fp:
+ gen_cfg = json.load(fp)
+ eos = gen_cfg.get("eos_token_id")
+ if isinstance(eos, list):
+ self.llm_eos_ids.extend(int(x) for x in eos)
+ elif isinstance(eos, int):
+ self.llm_eos_ids.append(eos)
+ if not self.llm_eos_ids:
+ eos = config.get("eos_token_id")
+ if isinstance(eos, list):
+ self.llm_eos_ids.extend(int(x) for x in eos)
+ elif isinstance(eos, int):
+ self.llm_eos_ids.append(eos)
+
+ # Auto-pick tool parser from options or model family.
+ parser_name = self.options.get("tool_parser") or _auto_tool_parser(self.model_ref, config)
+ self.tool_parser = resolve_parser(parser_name)
+
+ # --------------------- Stable Diffusion path ------------------------
+
+ def _load_sd(self, model_ref: str) -> None:
+ """Load a Stable Diffusion 1.x checkpoint (CompVis `.ckpt` format)."""
+ from huggingface_hub import hf_hub_download
+ from tinygrad.nn.state import load_state_dict, torch_load
+
+ from vendor.stable_diffusion import StableDiffusion
+
+ ckpt_path = Path(model_ref)
+ if not ckpt_path.exists():
+ # Accept an HF repo id — fetch the canonical v1-5-pruned-emaonly.ckpt
+ # from the requested repo. Common case is runwayml/stable-diffusion-v1-5.
+ repo_id = model_ref if "/" in model_ref else "runwayml/stable-diffusion-v1-5"
+ ckpt_file = self.options.get("sd_ckpt_filename", "v1-5-pruned-emaonly.ckpt")
+ ckpt_path = Path(hf_hub_download(repo_id=repo_id, filename=ckpt_file))
+
+ model = StableDiffusion()
+ state_dict = torch_load(str(ckpt_path))
+ if isinstance(state_dict, dict) and "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+ load_state_dict(model, state_dict, strict=False, verbose=False, realize=False)
+ self.sd_model = model
+
+ # --------------------- Whisper path ---------------------------------
+
+ def _load_whisper(self, model_ref: str) -> None:
+ """Load a Whisper checkpoint (OpenAI `.pt` format).
+
+ Accepts a model-size alias (tiny / tiny.en / base / base.en / small /
+ small.en) OR an explicit `.pt` file path OR the HF repo id naming
+ convention `openai/whisper-*` (mapped to the matching OpenAI alias).
+ """
+ from vendor.whisper import init_whisper, MODEL_URLS
+
+ alias = model_ref
+ if "/" in alias and alias.startswith("openai/whisper-"):
+ alias = alias.removeprefix("openai/whisper-")
+ if alias not in MODEL_URLS:
+ # Explicit path to a .pt checkpoint — fall back to size heuristic
+ # via filename.
+ basename = Path(alias).name.lower()
+ for name in MODEL_URLS:
+ if name in basename:
+ alias = name
+ break
+ else:
+ raise ValueError(
+ f"Unknown Whisper model_ref={model_ref!r}; expected one of {list(MODEL_URLS)} "
+ f"or an openai/whisper-* HF id"
+ )
+
+ model, enc = init_whisper(alias, batch_size=1)
+ self.whisper_model = model
+ self.whisper_tokenizer = enc
+
+ # --------------------- LLM generation -------------------------------
+
+ def _generate_tokens(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int):
+ """Synchronous generator yielding (token_id, token_text) pairs."""
+ from tinygrad import Tensor
+
+ encoded = self.llm_tokenizer.encode(prompt)
+ ids = encoded.ids
+ if not ids:
+ return
+
+ # Prefill: run one forward pass over the full prompt, sampling from the last token.
+ prompt_tensor = Tensor([ids])
+ next_tok = int(self.llm_model(prompt_tensor, 0, temperature=temperature, top_k=top_k, top_p=top_p).item())
+ pos = len(ids)
+
+ for _ in range(max_new_tokens):
+ if next_tok in self.llm_eos_ids:
+ break
+ text = self.llm_tokenizer.decode([next_tok])
+ yield next_tok, text
+ next_tok = int(self.llm_model(Tensor([[next_tok]]), pos, temperature=temperature,
+ top_k=top_k, top_p=top_p).item())
+ pos += 1
+
+ # --------------------- gRPC methods ---------------------------------
+
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ async def LoadModel(self, request, context):
+ try:
+ _select_tinygrad_device()
+ self._reset_state()
+ self.options = self._parse_options(list(request.Options))
+ self.model_ref = request.ModelFile or request.Model
+ self.model_type = self._detect_model_type(self.model_ref, self.options.get("model_type"))
+
+ if self.model_type in ("sd15", "sd", "stable-diffusion"):
+ self._load_sd(self.model_ref)
+ return backend_pb2.Result(
+ success=True, message="tinygrad Stable Diffusion 1.x loaded",
+ )
+
+ if self.model_type == "whisper":
+ self._load_whisper(self.model_ref)
+ return backend_pb2.Result(
+ success=True, message="tinygrad Whisper loaded",
+ )
+
+ if self.model_type != "llm":
+ return backend_pb2.Result(
+ success=False,
+ message=f"tinygrad: model_type={self.model_type} not yet implemented",
+ )
+
+ model_path = _resolve_model_assets(self.model_ref)
+ if model_path.is_file():
+ return backend_pb2.Result(
+ success=False,
+ message=("tinygrad currently requires an HF model directory with config.json + "
+ "tokenizer.json; single-file GGUF support lands with gallery metadata wiring."),
+ )
+ self._load_llm(model_path)
+
+ return backend_pb2.Result(
+ success=True,
+ message=f"tinygrad LLM loaded (arch={self.llm_config.get('architectures', ['?'])[0]}, "
+ f"parser={self.tool_parser.name})",
+ )
+ except Exception as exc:
+ import traceback
+ traceback.print_exc()
+ return backend_pb2.Result(success=False, message=f"LoadModel failed: {exc}")
+
+ async def Predict(self, request, context):
+ if self.llm_model is None:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ context.set_details("LLM not loaded")
+ return backend_pb2.Reply()
+
+ try:
+ prompt = self._render_prompt(request)
+ max_new = request.Tokens if request.Tokens > 0 else 256
+ temperature = request.Temperature if request.Temperature > 0 else 0.7
+ top_p = request.TopP if request.TopP > 0 else 0.95
+ top_k = request.TopK if request.TopK > 0 else 0
+
+ t0 = time.monotonic()
+ pieces: list[str] = []
+ ntok = 0
+ for _, text in self._generate_tokens(prompt, max_new, temperature, top_p, top_k):
+ pieces.append(text)
+ ntok += 1
+ elapsed = time.monotonic() - t0
+
+ full = "".join(pieces)
+ from tool_parsers.hermes import HermesToolParser
+ if isinstance(self.tool_parser, HermesToolParser):
+ result = self.tool_parser.parse_full(full)
+ content, calls, reasoning = result.content, result.tool_calls, result.reasoning
+ else:
+ content, calls = self.tool_parser.parse(full)
+ reasoning = ""
+
+ delta = backend_pb2.ChatDelta(
+ content=content,
+ reasoning_content=reasoning,
+ tool_calls=[
+ backend_pb2.ToolCallDelta(index=c.index, id=c.id, name=c.name, arguments=c.arguments)
+ for c in calls
+ ],
+ )
+ return backend_pb2.Reply(
+ message=content.encode("utf-8"),
+ tokens=ntok,
+ timing_token_generation=elapsed,
+ chat_deltas=[delta],
+ )
+ except Exception as exc:
+ import traceback
+ traceback.print_exc()
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(f"Predict failed: {exc}")
+ return backend_pb2.Reply()
+
+ async def PredictStream(self, request, context):
+ if self.llm_model is None:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ context.set_details("LLM not loaded")
+ return
+
+ try:
+ prompt = self._render_prompt(request)
+ max_new = request.Tokens if request.Tokens > 0 else 256
+ temperature = request.Temperature if request.Temperature > 0 else 0.7
+ top_p = request.TopP if request.TopP > 0 else 0.95
+ top_k = request.TopK if request.TopK > 0 else 0
+
+ buffer = ""
+ for _, text in self._generate_tokens(prompt, max_new, temperature, top_p, top_k):
+ buffer += text
+ yield backend_pb2.Reply(
+ message=text.encode("utf-8"),
+ chat_deltas=[backend_pb2.ChatDelta(content=text)],
+ )
+
+ # Final emission carries the extracted tool calls (vLLM semantics).
+ from tool_parsers.hermes import HermesToolParser
+ if isinstance(self.tool_parser, HermesToolParser):
+ result = self.tool_parser.parse_full(buffer)
+ calls = result.tool_calls
+ reasoning = result.reasoning
+ else:
+ _, calls = self.tool_parser.parse(buffer)
+ reasoning = ""
+
+ if calls or reasoning:
+ yield backend_pb2.Reply(
+ chat_deltas=[backend_pb2.ChatDelta(
+ reasoning_content=reasoning,
+ tool_calls=[
+ backend_pb2.ToolCallDelta(index=c.index, id=c.id, name=c.name, arguments=c.arguments)
+ for c in calls
+ ],
+ )],
+ )
+ except Exception as exc:
+ import traceback
+ traceback.print_exc()
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(f"PredictStream failed: {exc}")
+
+ async def Embedding(self, request, context):
+ if self.llm_model is None or self.llm_tokenizer is None:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ context.set_details("No model loaded")
+ return backend_pb2.EmbeddingResult()
+
+ try:
+ text = request.Embeddings
+ if not text:
+ context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+ context.set_details("Embeddings field is empty")
+ return backend_pb2.EmbeddingResult()
+
+ from tinygrad import Tensor, dtypes
+
+ ids = self.llm_tokenizer.encode(text).ids
+ if not ids:
+ return backend_pb2.EmbeddingResult(embeddings=[])
+
+ # Clamp to context window — truncate long inputs rather than blow up.
+ ids = ids[: self.max_context]
+ tokens = Tensor([ids])
+
+ hidden = self.llm_model.embed(tokens) # (1, seqlen, dim)
+ # Mean pool over sequence dim
+ pooled = hidden.mean(axis=1).squeeze(0) # (dim,)
+ # L2 normalize
+ norm = pooled.square().sum().sqrt()
+ normalized = (pooled / (norm + 1e-12))
+ vec = normalized.cast(dtypes.float32).tolist()
+
+ return backend_pb2.EmbeddingResult(embeddings=[float(x) for x in vec])
+ except Exception as exc:
+ import traceback
+ traceback.print_exc()
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(f"Embedding failed: {exc}")
+ return backend_pb2.EmbeddingResult()
+
+ async def GenerateImage(self, request, context):
+ if self.sd_model is None:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ context.set_details("No Stable Diffusion model loaded")
+ return backend_pb2.Result(success=False, message="not loaded")
+
+ try:
+ from PIL import Image
+ from vendor.stable_diffusion import run_sd15
+
+ steps = request.step if request.step > 0 else 20
+ guidance = 7.5
+ seed = request.seed if request.seed != 0 else None
+ img_tensor = run_sd15(
+ model=self.sd_model,
+ prompt=request.positive_prompt or "",
+ negative_prompt=request.negative_prompt or "",
+ steps=steps,
+ guidance=guidance,
+ seed=seed,
+ )
+ arr = img_tensor.numpy()
+ image = Image.fromarray(arr)
+ dst = request.dst or "/tmp/tinygrad_image.png"
+ image.save(dst)
+ return backend_pb2.Result(success=True, message=dst)
+ except Exception as exc:
+ import traceback
+ traceback.print_exc()
+ return backend_pb2.Result(success=False, message=f"GenerateImage failed: {exc}")
+
+ def _transcribe(self, audio_path: str, language: Optional[str]) -> tuple[str, float]:
+ from vendor.whisper import load_file_waveform, transcribe_waveform
+
+ waveform = load_file_waveform(audio_path)
+ text = transcribe_waveform(
+ self.whisper_model,
+ self.whisper_tokenizer,
+ [waveform],
+ language=language or None,
+ )
+ duration = float(len(waveform)) / 16000.0
+ return text, duration
+
+ async def AudioTranscription(self, request, context):
+ if self.whisper_model is None or self.whisper_tokenizer is None:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ context.set_details("No Whisper model loaded")
+ return backend_pb2.TranscriptResult()
+
+ try:
+ if not request.dst:
+ context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+ context.set_details("TranscriptRequest.dst (audio file path) is required")
+ return backend_pb2.TranscriptResult()
+
+ text, duration = self._transcribe(request.dst, request.language)
+ segments = [backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text)]
+ return backend_pb2.TranscriptResult(
+ text=text,
+ language=request.language or "en",
+ duration=duration,
+ segments=segments,
+ )
+ except Exception as exc:
+ import traceback
+ traceback.print_exc()
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(f"AudioTranscription failed: {exc}")
+ return backend_pb2.TranscriptResult()
+
+ async def AudioTranscriptionStream(self, request, context):
+ if self.whisper_model is None or self.whisper_tokenizer is None:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ context.set_details("No Whisper model loaded")
+ return
+
+ try:
+ if not request.dst:
+ context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+ context.set_details("TranscriptRequest.dst (audio file path) is required")
+ return
+
+ # The vendored tinygrad whisper loop is chunked at the file level
+ # (one inference pass per 30s segment), not token-level. To still
+ # produce a streaming response we run the full transcription and
+ # emit it as a single delta + a final-result envelope so the client
+ # gets both code paths exercised.
+ text, duration = self._transcribe(request.dst, request.language)
+ yield backend_pb2.TranscriptStreamResponse(delta=text)
+ final = backend_pb2.TranscriptResult(
+ text=text,
+ language=request.language or "en",
+ duration=duration,
+ segments=[backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text)],
+ )
+ yield backend_pb2.TranscriptStreamResponse(final_result=final)
+ except Exception as exc:
+ import traceback
+ traceback.print_exc()
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(f"AudioTranscriptionStream failed: {exc}")
+
+ async def Status(self, request, context):
+ return backend_pb2.StatusResponse(state=backend_pb2.StatusResponse.READY)
+
+ async def Free(self, request, context):
+ self._reset_state()
+ return backend_pb2.Result(success=True, message="freed")
+
+
+async def serve(address):
+ server = grpc.aio.server(
+ migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+ options=[
+ ('grpc.max_message_length', 50 * 1024 * 1024),
+ ('grpc.max_send_message_length', 50 * 1024 * 1024),
+ ('grpc.max_receive_message_length', 50 * 1024 * 1024),
+ ],
+ interceptors=get_auth_interceptors(aio=True),
+ )
+ backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+ server.add_insecure_port(address)
+
+ loop = asyncio.get_event_loop()
+ for sig in (signal.SIGINT, signal.SIGTERM):
+ loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5)))
+
+ await server.start()
+ print("Server started. Listening on: " + address, file=sys.stderr)
+ await server.wait_for_termination()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run the tinygrad gRPC backend.")
+ parser.add_argument("--addr", default="localhost:50051", help="Bind address")
+ args = parser.parse_args()
+ asyncio.run(serve(args.addr))
diff --git a/backend/python/tinygrad/install.sh b/backend/python/tinygrad/install.sh
new file mode 100755
index 000000000..c4f444938
--- /dev/null
+++ b/backend/python/tinygrad/install.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# tinygrad >= 0.12 requires Python >= 3.11 (pyproject: `requires-python = ">=3.11"`).
+# LocalAI's default portable python is 3.10, so we pin to 3.11.x here.
+PYTHON_VERSION="3.11"
+PYTHON_PATCH="14"
+PY_STANDALONE_TAG="20260203"
+
+installRequirements
diff --git a/backend/python/tinygrad/package.sh b/backend/python/tinygrad/package.sh
new file mode 100755
index 000000000..5ac600e19
--- /dev/null
+++ b/backend/python/tinygrad/package.sh
@@ -0,0 +1,103 @@
+#!/bin/bash
+# Script to package runtime shared libraries for the tinygrad backend.
+#
+# The final Dockerfile.python stage is FROM scratch, so system libraries
+# must be explicitly copied into ${BACKEND}/lib so the backend can run on
+# any host without installing them. libbackend.sh automatically prepends
+# that directory to LD_LIBRARY_PATH at run time.
+#
+# tinygrad's CPU device (CLANG / LLVM renderer) JIT-compiles kernels at
+# runtime. The default `CLANG` path invokes the external `clang` binary via
+# subprocess, which does not exist in the scratch image. We force the
+# in-process LLVM path (`CPU_LLVM=1` in run.sh) which loads libLLVM.so.*
+# through ctypes and bundle the library + its runtime dependencies here.
+#
+# Also bundle libgomp (pulled by librosa / numpy via numba) and libsndfile
+# (required by soundfile -> librosa audio I/O for Whisper).
+
+set -e
+
+CURDIR=$(dirname "$(realpath "$0")")
+LIB_DIR="${CURDIR}/lib"
+mkdir -p "${LIB_DIR}"
+
+SEARCH_DIRS=(
+ /usr/lib/x86_64-linux-gnu
+ /usr/lib/aarch64-linux-gnu
+ /lib/x86_64-linux-gnu
+ /lib/aarch64-linux-gnu
+ /usr/lib
+ /lib
+)
+
+copy_with_symlinks() {
+ local soname="$1"
+ local hit=""
+ for dir in "${SEARCH_DIRS[@]}"; do
+ if [ -e "${dir}/${soname}" ]; then
+ hit="${dir}/${soname}"
+ break
+ fi
+ done
+ if [ -z "${hit}" ]; then
+ echo "warning: ${soname} not found in standard lib paths" >&2
+ return 0
+ fi
+ local real
+ real=$(readlink -f "${hit}")
+ cp -v "${real}" "${LIB_DIR}/"
+ local real_base
+ real_base=$(basename "${real}")
+ if [ "${real_base}" != "${soname}" ]; then
+ ln -sf "${real_base}" "${LIB_DIR}/${soname}"
+ fi
+}
+
+# tinygrad searches for libLLVM under these sonames (see
+# tinygrad/runtime/autogen/llvm.py). Ubuntu 24.04's `llvm` metapackage
+# installs `libLLVM-18.so.1` into `/usr/lib/llvm-18/lib/`. Also scan the
+# standard lib directories in case a different distro layout puts it in
+# /usr/lib/x86_64-linux-gnu.
+llvm_so=""
+shopt -s nullglob
+LLVM_EXTRA_DIRS=(/usr/lib/llvm-*/lib /usr/lib/llvm-*)
+# First try the versioned symlink (libLLVM-18.so) since that's what
+# tinygrad's DLL loader matches against (see llvm.py DLL name list).
+for dir in "${SEARCH_DIRS[@]}" "${LLVM_EXTRA_DIRS[@]}"; do
+ for candidate in "${dir}"/libLLVM-[0-9]*.so "${dir}"/libLLVM-[0-9]*.so.[0-9]*; do
+ if [ -e "${candidate}" ]; then
+ llvm_so="${candidate}"
+ break 2
+ fi
+ done
+done
+# Fallback: any libLLVM.so file under /usr.
+if [ -z "${llvm_so}" ]; then
+ llvm_so=$(find /usr -maxdepth 5 -name 'libLLVM*.so*' 2>/dev/null | head -1)
+fi
+shopt -u nullglob
+if [ -z "${llvm_so}" ]; then
+ echo "ERROR: libLLVM not found — tinygrad CPU device needs it." >&2
+ echo "Install the Ubuntu \`llvm\` package in the builder stage." >&2
+ exit 1
+fi
+echo "Found libLLVM at: ${llvm_so}"
+llvm_base=$(basename "${llvm_so}")
+real_llvm=$(readlink -f "${llvm_so}")
+cp -v "${real_llvm}" "${LIB_DIR}/"
+real_base=$(basename "${real_llvm}")
+if [ "${real_base}" != "${llvm_base}" ]; then
+ ln -sf "${real_base}" "${LIB_DIR}/${llvm_base}"
+fi
+
+# libLLVM has soft runtime deps on libedit / libtinfo; pick them up if
+# present. dlopen of libLLVM can fail when they are missing.
+copy_with_symlinks libedit.so.2
+copy_with_symlinks libtinfo.so.6
+
+# Audio I/O for the Whisper path.
+copy_with_symlinks libsndfile.so.1
+copy_with_symlinks libgomp.so.1
+
+echo "tinygrad packaging completed successfully"
+ls -liah "${LIB_DIR}/"
diff --git a/backend/python/tinygrad/protogen.sh b/backend/python/tinygrad/protogen.sh
new file mode 100755
index 000000000..df3325c6f
--- /dev/null
+++ b/backend/python/tinygrad/protogen.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runProtogen
diff --git a/backend/python/tinygrad/requirements-cpu.txt b/backend/python/tinygrad/requirements-cpu.txt
new file mode 100644
index 000000000..e1c64640a
--- /dev/null
+++ b/backend/python/tinygrad/requirements-cpu.txt
@@ -0,0 +1 @@
+# tinygrad CPU backend uses CLANG device (no extra deps required).
diff --git a/backend/python/tinygrad/requirements-cublas12.txt b/backend/python/tinygrad/requirements-cublas12.txt
new file mode 100644
index 000000000..2deb71bc5
--- /dev/null
+++ b/backend/python/tinygrad/requirements-cublas12.txt
@@ -0,0 +1,2 @@
+# tinygrad drives CUDA through its own JIT (CUDA=1 env var).
+# Requires the CUDA 12 runtime from the base image; no extra Python deps.
diff --git a/backend/python/tinygrad/requirements-cublas13.txt b/backend/python/tinygrad/requirements-cublas13.txt
new file mode 100644
index 000000000..74dc56eae
--- /dev/null
+++ b/backend/python/tinygrad/requirements-cublas13.txt
@@ -0,0 +1,2 @@
+# tinygrad drives CUDA through its own JIT (CUDA=1 env var).
+# Requires the CUDA 13 runtime from the base image; no extra Python deps.
diff --git a/backend/python/tinygrad/requirements.txt b/backend/python/tinygrad/requirements.txt
new file mode 100644
index 000000000..12b31ffca
--- /dev/null
+++ b/backend/python/tinygrad/requirements.txt
@@ -0,0 +1,15 @@
+grpcio==1.80.0
+protobuf==6.33.5
+certifi
+setuptools
+numpy>=2.0.0
+tinygrad>=0.12.0
+tokenizers>=0.21.0
+huggingface_hub
+jinja2>=3.1.0
+tiktoken
+sentencepiece
+safetensors
+Pillow
+librosa
+soundfile
diff --git a/backend/python/tinygrad/run.sh b/backend/python/tinygrad/run.sh
new file mode 100755
index 000000000..b341291ea
--- /dev/null
+++ b/backend/python/tinygrad/run.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+# tinygrad binds its compute device at import time from a single env var
+# (CUDA / HIP / METAL / CLANG). We pick one here based on what driver
+# libraries the host has injected into the container — when a user runs
+# the image with `--gpus all` (or the equivalent rocm runtime), the
+# nvidia-container-toolkit / rocm runtime mounts the right libraries
+# under /usr/lib so we can detect them.
+#
+# tinygrad's CUDA device offers two compilers: an NVRTC-backed one and
+# an in-process PTX renderer. We force the PTX renderer here
+# (`CUDA_PTX=1`) so the image is independent of the host CUDA toolkit
+# version — only libcuda.so.1 (the driver) is required.
+find_lib() {
+ local soname="$1"
+ for dir in /usr/lib/x86_64-linux-gnu /usr/lib64 /usr/lib /lib/x86_64-linux-gnu /lib64 /lib; do
+ if [ -e "${dir}/${soname}" ]; then
+ echo "${dir}/${soname}"
+ return 0
+ fi
+ done
+ return 1
+}
+
+if [ -z "${CUDA:-}${HIP:-}${METAL:-}${CLANG:-}" ]; then
+ if find_lib libcuda.so.1 >/dev/null; then
+ export CUDA=1
+ export CUDA_PTX=1
+ elif find_lib libamdhip64.so >/dev/null || find_lib libamdhip64.so.6 >/dev/null; then
+ export HIP=1
+ else
+ export CLANG=1
+ fi
+fi
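+
+# To pin a device manually, export the flag before launch (e.g. `CLANG=1` to
+# stay on CPU even when a GPU driver is present); the detection above only
+# runs when no device flag is already set.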
+
+# The CPU path (CLANG=1) JIT-compiles via libLLVM. Force tinygrad's
+# in-process LLVM compiler so we don't need an external `clang` binary
+# (which is not present in the scratch image).
+export CPU_LLVM=1
+if [ -z "${LLVM_PATH:-}" ]; then
+ for candidate in "${EDIR}"/lib/libLLVM-*.so "${EDIR}"/lib/libLLVM-*.so.* "${EDIR}"/lib/libLLVM.so.*; do
+ if [ -e "${candidate}" ]; then
+ export LLVM_PATH="${candidate}"
+ break
+ fi
+ done
+fi
+
+startBackend $@
diff --git a/backend/python/tinygrad/test.py b/backend/python/tinygrad/test.py
new file mode 100644
index 000000000..bf47d9bda
--- /dev/null
+++ b/backend/python/tinygrad/test.py
@@ -0,0 +1,84 @@
+"""
+Unit tests for the tinygrad gRPC backend.
+
+These tests cover the cheap paths that don't need a real model checkpoint:
+ - Health responds OK
+ - Tool-call parsers emit expected ToolCall structures
+
+The full LLM / embeddings / Stable Diffusion / Whisper paths are exercised by
+the root-level `make test-extra-backend-tinygrad-all` e2e targets, which boot
+the containerized backend against real HF checkpoints.
+"""
+import os
+import subprocess
+import sys
+import time
+import unittest
+
+import grpc
+
+import backend_pb2
+import backend_pb2_grpc
+
+sys.path.insert(0, os.path.dirname(__file__))
+from tool_parsers.hermes import HermesToolParser # noqa: E402
+
+
+class TestHealth(unittest.TestCase):
+ def setUp(self):
+ self.service = subprocess.Popen(
+ ["python3", "backend.py", "--addr", "localhost:50051"]
+ )
+ time.sleep(5)
+
+ def tearDown(self):
+ self.service.kill()
+ self.service.wait()
+
+ def test_health(self):
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.Health(backend_pb2.HealthMessage())
+ self.assertEqual(response.message, b"OK")
+
+
+class TestHermesParser(unittest.TestCase):
+ def test_single_tool_call(self):
+ parser = HermesToolParser()
+ text = (
+ "Sure, let me check.\n"
+ "<tool_call>\n"
+ '{"name": "get_weather", "arguments": {"city": "Paris"}}\n'
+ "</tool_call>\n"
+ "Done."
+ )
+ content, calls = parser.parse(text)
+ self.assertIn("Sure", content)
+ self.assertIn("Done", content)
+ self.assertEqual(len(calls), 1)
+ self.assertEqual(calls[0].name, "get_weather")
+ self.assertIn("Paris", calls[0].arguments)
+
+ def test_multi_call_and_thinking(self):
+ parser = HermesToolParser()
+ text = (
+ "<think>I need both.</think>"
+ '<tool_call>{"name":"a","arguments":{"x":1}}</tool_call>'
+ '<tool_call>{"name":"b","arguments":{}}</tool_call>'
+ )
+ result = parser.parse_full(text)
+ self.assertEqual(result.reasoning, "I need both.")
+ self.assertEqual([c.name for c in result.tool_calls], ["a", "b"])
+ self.assertEqual(result.tool_calls[0].index, 0)
+ self.assertEqual(result.tool_calls[1].index, 1)
+
+ def test_no_tool_call_is_passthrough(self):
+ parser = HermesToolParser()
+ text = "plain assistant answer with no tool call"
+ content, calls = parser.parse(text)
+ self.assertEqual(content, text)
+ self.assertEqual(calls, [])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/backend/python/tinygrad/test.sh b/backend/python/tinygrad/test.sh
new file mode 100755
index 000000000..eb59f2aaf
--- /dev/null
+++ b/backend/python/tinygrad/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+ source $backend_dir/common/libbackend.sh
+else
+ source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests
diff --git a/backend/python/tinygrad/tool_parsers/__init__.py b/backend/python/tinygrad/tool_parsers/__init__.py
new file mode 100644
index 000000000..3f1d6c5b7
--- /dev/null
+++ b/backend/python/tinygrad/tool_parsers/__init__.py
@@ -0,0 +1,11 @@
+"""Tool-call parsers for the tinygrad backend.
+
+Each parser takes raw model output and extracts OpenAI-style tool calls so
+the backend can populate `ChatDelta.tool_calls[]` natively (matching vLLM's
+behavior, which the Go core prefers over regex fallback parsing).
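+
+A minimal usage sketch, assuming Hermes-style model output (see hermes.py):
+
+    from tool_parsers import resolve_parser
+
+    parser = resolve_parser("hermes")
+    content, calls = parser.parse(
+        '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
+    )
+    # content == ""  (nothing outside the tag is left for the user)
+    # calls[0].name == "get_weather"
+    # calls[0].arguments == '{"city": "Paris"}'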
+"""
+from __future__ import annotations
+
+from .base import ToolCall, ToolParser, resolve_parser
+
+__all__ = ["ToolCall", "ToolParser", "resolve_parser"]
diff --git a/backend/python/tinygrad/tool_parsers/base.py b/backend/python/tinygrad/tool_parsers/base.py
new file mode 100644
index 000000000..5e2fcc910
--- /dev/null
+++ b/backend/python/tinygrad/tool_parsers/base.py
@@ -0,0 +1,85 @@
+"""Common types + parser registry for tool-call extraction."""
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class ToolCall:
+ """One extracted tool call — maps 1:1 to backend_pb2.ToolCallDelta."""
+ index: int
+ name: str
+ arguments: str # JSON string
+ id: str = ""
+
+
+class ToolParser:
+ """Parser interface.
+
+ Subclasses implement `parse` (full non-streaming pass) and optionally
+ `parse_stream` (incremental). The default `parse_stream` buffers until a
+ full response is available and then delegates to `parse`.
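+
+ A sketch of the buffering default (`model_stream()` is a hypothetical token
+ source, not part of this module):
+
+     parser = resolve_parser("hermes")
+     for delta in model_stream():
+         parser.parse_stream(delta)               # buffers, returns ("", [])
+     content, calls = parser.parse_stream("", finished=True)
+     parser.reset()                               # reuse for the next request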
+ """
+
+ name: str = "base"
+
+ def __init__(self) -> None:
+ self._stream_buffer = ""
+ self._stream_index = 0
+
+ def parse(self, text: str) -> tuple[str, list[ToolCall]]:
+ """Return (content_for_user, tool_calls)."""
+ raise NotImplementedError
+
+ def parse_stream(self, delta: str, finished: bool = False) -> tuple[str, list[ToolCall]]:
+ """Accumulate a streaming delta. Emits any tool calls that have closed.
+
+ Default behavior: buffer until `finished=True`, then parse once.
+ Subclasses can override to emit mid-stream.
+ """
+ self._stream_buffer += delta
+ if not finished:
+ return "", []
+ content, calls = self.parse(self._stream_buffer)
+ # Re-index starting from whatever we've already emitted in this stream.
+ reindexed: list[ToolCall] = []
+ for i, c in enumerate(calls):
+ reindexed.append(ToolCall(
+ index=self._stream_index + i,
+ name=c.name,
+ arguments=c.arguments,
+ id=c.id,
+ ))
+ self._stream_index += len(reindexed)
+ return content, reindexed
+
+ def reset(self) -> None:
+ self._stream_buffer = ""
+ self._stream_index = 0
+
+
+_REGISTRY: dict[str, type[ToolParser]] = {}
+
+
+def register(cls: type[ToolParser]) -> type[ToolParser]:
+ _REGISTRY[cls.name] = cls
+ return cls
+
+
+def resolve_parser(name: Optional[str]) -> ToolParser:
+ """Return a parser instance by name, falling back to a no-op passthrough."""
+ # Import for side effects — each module registers itself.
+ from . import hermes, llama3_json, mistral, qwen3_xml # noqa: F401
+
+ if name and name in _REGISTRY:
+ return _REGISTRY[name]()
+ return PassthroughToolParser()
+
+
+class PassthroughToolParser(ToolParser):
+ """No-op parser — used when no tool_parser is configured."""
+ name = "passthrough"
+
+ def parse(self, text: str) -> tuple[str, list[ToolCall]]:
+ return text, []
diff --git a/backend/python/tinygrad/tool_parsers/hermes.py b/backend/python/tinygrad/tool_parsers/hermes.py
new file mode 100644
index 000000000..078dc9613
--- /dev/null
+++ b/backend/python/tinygrad/tool_parsers/hermes.py
@@ -0,0 +1,74 @@
+"""Hermes-format tool-call parser.
+
+Hermes 2 / 2.5 / 3 (and Qwen 2.5 Instruct, which adopted the same convention)
+emit tool calls wrapped in `<tool_call>...</tool_call>` tags, where the inner
+content is a JSON object with `name` and `arguments` keys:
+
+<tool_call>
+ {"name": "get_weather", "arguments": {"city": "Paris"}}
+</tool_call>
+
+Multiple tool calls may appear back-to-back. Text outside the tags is plain
+assistant content that should surface to the user.
+
+This parser also strips `<think>...</think>` reasoning blocks and returns them
+via the reasoning_content channel (Qwen 3, DeepSeek-R1 distills).
+"""
+from __future__ import annotations
+
+import json
+import re
+from dataclasses import dataclass
+
+from .base import ToolCall, ToolParser, register
+
+_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
+_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
+
+
+@dataclass
+class HermesParseResult:
+ content: str
+ reasoning: str
+ tool_calls: list[ToolCall]
+
+
+@register
+class HermesToolParser(ToolParser):
+ name = "hermes"
+
+ def _parse_full(self, text: str) -> HermesParseResult:
+ reasoning_parts: list[str] = []
+
+ def _capture_reasoning(match: re.Match[str]) -> str:
+ reasoning_parts.append(match.group(1).strip())
+ return ""
+
+ text_wo_think = _THINK_RE.sub(_capture_reasoning, text)
+
+ calls: list[ToolCall] = []
+ for idx, match in enumerate(_TOOL_CALL_RE.finditer(text_wo_think)):
+ raw = match.group(1)
+ try:
+ obj = json.loads(raw)
+ except json.JSONDecodeError:
+ continue
+ if not isinstance(obj, dict):
+ continue
+ name = obj.get("name")
+ if not isinstance(name, str):
+ continue
+ args = obj.get("arguments", {})
+ args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False)
+ calls.append(ToolCall(index=idx, name=name, arguments=args_str))
+
+ content = _TOOL_CALL_RE.sub("", text_wo_think).strip()
+ reasoning = "\n\n".join(reasoning_parts).strip()
+ return HermesParseResult(content=content, reasoning=reasoning, tool_calls=calls)
+
+ def parse(self, text: str) -> tuple[str, list[ToolCall]]:
+ result = self._parse_full(text)
+ return result.content, result.tool_calls
+
+ def parse_full(self, text: str) -> HermesParseResult:
+ return self._parse_full(text)
diff --git a/backend/python/tinygrad/tool_parsers/llama3_json.py b/backend/python/tinygrad/tool_parsers/llama3_json.py
new file mode 100644
index 000000000..ee357519a
--- /dev/null
+++ b/backend/python/tinygrad/tool_parsers/llama3_json.py
@@ -0,0 +1,86 @@
+"""Llama 3.1 / 3.2 / 3.3 JSON tool-call parser.
+
+Meta's Llama 3.1+ instruct chat templates emit tool calls in two broadly
+compatible shapes:
+
+ 1. With the `<|python_tag|>` lead-in:
+ <|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}}
+ 2. As a bare JSON object (or list of objects) at the end of the turn.
+
+We also handle multi-call shapes where the model emits several JSON objects
+separated by `;` or newlines, and JSON arrays `[{...}, {...}]`. The key field
+for Llama 3 is historically `parameters` (older docs) but recent checkpoints
+also emit `arguments` — accept either.
+"""
+from __future__ import annotations
+
+import json
+import re
+from dataclasses import dataclass
+
+from .base import ToolCall, ToolParser, register
+
+_PYTHON_TAG = "<|python_tag|>"
+_JSON_OBJECT_RE = re.compile(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", re.DOTALL)
+
+
+def _coerce_call(obj: object, index: int) -> ToolCall | None:
+ if not isinstance(obj, dict):
+ return None
+ name = obj.get("name")
+ if not isinstance(name, str):
+ return None
+ args = obj.get("arguments", obj.get("parameters", {}))
+ args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False)
+ return ToolCall(index=index, name=name, arguments=args_str)
+
+
+@register
+class Llama3JsonToolParser(ToolParser):
+ name = "llama3_json"
+
+ def parse(self, text: str) -> tuple[str, list[ToolCall]]:
+ calls: list[ToolCall] = []
+
+ # Strip <|python_tag|> segments first — each segment is one tool call
+ # body. The content after the final python_tag (if any) is the call.
+ remaining = text
+ if _PYTHON_TAG in text:
+ head, *tails = text.split(_PYTHON_TAG)
+ remaining = head
+ for tail in tails:
+ parsed = _try_parse(tail.strip(), len(calls))
+ calls.extend(parsed)
+
+ # Any JSON objects / arrays left in `remaining` count as tool calls too
+ # if they parse to a {"name": ..., "arguments": ...} shape.
+ for match in _JSON_OBJECT_RE.finditer(remaining):
+ parsed = _try_parse(match.group(0), len(calls))
+ if parsed:
+ calls.extend(parsed)
+ remaining = remaining.replace(match.group(0), "", 1)
+
+ content = remaining.strip()
+ return content, calls
+
+
+def _try_parse(blob: str, start_index: int) -> list[ToolCall]:
+ """Parse a fragment that may be a JSON object or a JSON array of objects."""
+ blob = blob.strip().rstrip(";")
+ if not blob:
+ return []
+ try:
+ obj = json.loads(blob)
+ except json.JSONDecodeError:
+ return []
+ if isinstance(obj, dict):
+ call = _coerce_call(obj, start_index)
+ return [call] if call else []
+ if isinstance(obj, list):
+ calls: list[ToolCall] = []
+ for i, item in enumerate(obj):
+ c = _coerce_call(item, start_index + i)
+ if c:
+ calls.append(c)
+ return calls
+ return []
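+
+# Illustrative sketch only (hypothetical model output, not a fixture from this
+# repo): both supported shapes reduce to the same ToolCall list, e.g.
+#   Llama3JsonToolParser().parse(
+#       '<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}}')
+#   -> ("", [ToolCall(index=0, name="get_weather", arguments='{"city": "Paris"}')])
+# and the bare-JSON shape
+#   Llama3JsonToolParser().parse('{"name": "get_weather", "arguments": {"city": "Paris"}}')
+#   -> ("", [ToolCall(index=0, name="get_weather", arguments='{"city": "Paris"}')])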
diff --git a/backend/python/tinygrad/tool_parsers/mistral.py b/backend/python/tinygrad/tool_parsers/mistral.py
new file mode 100644
index 000000000..2a8891313
--- /dev/null
+++ b/backend/python/tinygrad/tool_parsers/mistral.py
@@ -0,0 +1,56 @@
+"""Mistral / Mixtral tool-call parser.
+
+Mistral Nemo / Small / Large Instruct emit tool calls prefixed with the
+`[TOOL_CALLS]` control token, followed by a JSON array:
+
+ [TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}}]
+
+Multiple calls live inside the same array. Any text before `[TOOL_CALLS]` is
+normal assistant content and should surface to the user.
+"""
+from __future__ import annotations
+
+import json
+import re
+
+from .base import ToolCall, ToolParser, register
+
+_MARKER = "[TOOL_CALLS]"
+_JSON_ARRAY_RE = re.compile(r"\[\s*(?:\{.*?\}\s*,?\s*)+\]", re.DOTALL)
+
+
+@register
+class MistralToolParser(ToolParser):
+ name = "mistral"
+
+ def parse(self, text: str) -> tuple[str, list[ToolCall]]:
+ if _MARKER not in text:
+ return text.strip(), []
+
+ head, tail = text.split(_MARKER, 1)
+ content = head.strip()
+
+ match = _JSON_ARRAY_RE.search(tail)
+ if not match:
+ return content, []
+
+ try:
+ arr = json.loads(match.group(0))
+ except json.JSONDecodeError:
+ return content, []
+
+ if not isinstance(arr, list):
+ return content, []
+
+ calls: list[ToolCall] = []
+ for i, obj in enumerate(arr):
+ if not isinstance(obj, dict):
+ continue
+ name = obj.get("name")
+ if not isinstance(name, str):
+ continue
+ args = obj.get("arguments", {})
+ args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False)
+ calls.append(ToolCall(index=i, name=name, arguments=args_str))
+
+ return content, calls
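+
+# Illustrative sketch only (hypothetical model output, not a fixture from this
+# repo):
+#   MistralToolParser().parse(
+#       'Let me check.[TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}}]')
+#   -> ("Let me check.",
+#       [ToolCall(index=0, name="get_weather", arguments='{"city": "Paris"}')])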
diff --git a/backend/python/tinygrad/tool_parsers/qwen3_xml.py b/backend/python/tinygrad/tool_parsers/qwen3_xml.py
new file mode 100644
index 000000000..e0dcbf848
--- /dev/null
+++ b/backend/python/tinygrad/tool_parsers/qwen3_xml.py
@@ -0,0 +1,74 @@
+"""Qwen 3 XML tool-call parser.
+
+Qwen 3 Instruct emits tool calls wrapped in a two-level tag structure:
+
+ <tool_call>
+ <function=get_weather>
+ <parameter=city>
+ Paris
+ </parameter>
+ <parameter=unit>
+ celsius
+ </parameter>
+ </function>
+ </tool_call>
+
+Parameter values are raw text — we treat them as strings unless they look
+like JSON (in which case we try to parse so numbers / booleans round-trip
+cleanly). Qwen 3 also emits `<think>...</think>` reasoning blocks before
+the tool call — these follow the shared Hermes `<think>` convention.
+"""
+from __future__ import annotations
+
+import json
+import re
+
+from .base import ToolCall, ToolParser, register
+
+_TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
+_FUNCTION_RE = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
+_PARAMETER_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
+_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
+
+
+def _maybe_json(value: str):
+ value = value.strip()
+ if not value:
+ return value
+ if value[0] in "{[\"" or value in ("true", "false", "null") or value.lstrip("-").replace(".", "", 1).isdigit():
+ try:
+ return json.loads(value)
+ except json.JSONDecodeError:
+ return value
+ return value
+
+
+@register
+class Qwen3XmlToolParser(ToolParser):
+ name = "qwen3_xml"
+
+ def parse(self, text: str) -> tuple[str, list[ToolCall]]:
+ # Strip reasoning blocks from the user-visible content.
+ stripped = _THINK_RE.sub("", text)
+
+ calls: list[ToolCall] = []
+ for match in _TOOL_CALL_RE.finditer(stripped):
+ body = match.group(1)
+ fn_match = _FUNCTION_RE.search(body)
+ if not fn_match:
+ continue
+ name = fn_match.group(1).strip()
+ params_body = fn_match.group(2)
+
+ params: dict[str, object] = {}
+ for pm in _PARAMETER_RE.finditer(params_body):
+ params[pm.group(1).strip()] = _maybe_json(pm.group(2))
+
+ calls.append(ToolCall(
+ index=len(calls),
+ name=name,
+ arguments=json.dumps(params, ensure_ascii=False),
+ ))
+
+ content = _TOOL_CALL_RE.sub("", stripped).strip()
+ return content, calls
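+
+# Illustrative sketch only (hypothetical model output, not a fixture from this
+# repo):
+#   Qwen3XmlToolParser().parse(
+#       '<tool_call><function=get_weather><parameter=city>Paris</parameter>'
+#       '</function></tool_call>')
+#   -> ("", [ToolCall(index=0, name="get_weather", arguments='{"city": "Paris"}')])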
diff --git a/backend/python/tinygrad/vendor/__init__.py b/backend/python/tinygrad/vendor/__init__.py
new file mode 100644
index 000000000..ea1cc0be6
--- /dev/null
+++ b/backend/python/tinygrad/vendor/__init__.py
@@ -0,0 +1,6 @@
+"""Vendored upstream tinygrad reference code (MIT-licensed).
+
+Source: https://github.com/tinygrad/tinygrad
+These files are not part of the `tinygrad` pip package (the `extra/` tree is
+excluded from `pyproject.toml` `packages`), so we carry a pinned copy here.
+"""
diff --git a/backend/python/tinygrad/vendor/audio_helpers.py b/backend/python/tinygrad/vendor/audio_helpers.py
new file mode 100644
index 000000000..b6951eb66
--- /dev/null
+++ b/backend/python/tinygrad/vendor/audio_helpers.py
@@ -0,0 +1,83 @@
+# Vendored verbatim from tinygrad examples/audio_helpers.py (MIT license).
+# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/audio_helpers.py
+# Copyright (c) 2023- the tinygrad authors
+# SPDX-License-Identifier: MIT
+from typing import Optional
+from tinygrad import Tensor
+from tinygrad.dtype import DTypeLike, dtypes
+import math
+
+# rewritten from numpy
+def rfftfreq(n: int, d: float = 1.0, device=None) -> Tensor:
+ val = 1.0 / (n * d)
+ N = n // 2 + 1
+ results = Tensor.arange(N, device=device)
+ return results * val
+
+# just like in librosa
+def fft_frequencies(sr: float, n_fft: int) -> Tensor:
+ return rfftfreq(n=n_fft, d=1.0 / sr)
+
+def hz_to_mel(freq: Tensor) -> Tensor:
+ # linear part
+ f_min = 0.0
+ f_sp = 200.0 / 3
+ mels = (freq - f_min) / f_sp
+
+ # log-scale part
+ min_log_hz = 1000.0 # beginning of log region (Hz)
+ mask = freq >= min_log_hz
+ return mask.where(((min_log_hz - f_min) / f_sp) + (freq / min_log_hz).log() / (math.log(6.4) / 27.0), mels)
+
+def mel_to_hz(mels: Tensor) -> Tensor:
+ # linear scale
+ f_min = 0.0
+ f_sp = 200.0 / 3
+ freqs = f_min + f_sp * mels
+
+ # nonlinear scale
+ min_log_hz = 1000.0 # beginning of log region (Hz)
+ min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
+ logstep = math.log(6.4) / 27.0 # step size for log region
+
+ log_t = mels >= min_log_mel
+ freqs = log_t.where(min_log_hz * ((logstep * (mels - min_log_mel)).exp()), freqs)
+ return freqs
+
+def mel_frequencies(n_mels: int = 128, *, fmin: float = 0.0, fmax: float = 11025.0) -> Tensor:
+ # center freqs of mel bands - uniformly spaced between limits
+ min_max_mel = hz_to_mel(Tensor([fmin, fmax]))
+
+ mels = Tensor.linspace(min_max_mel[0], min_max_mel[1], n_mels)
+ hz = mel_to_hz(mels)
+ return hz
+
+def mel(
+ *,
+ sr: float,
+ n_fft: int,
+ n_mels: int = 128,
+ fmin: float = 0.0,
+ fmax: Optional[float] = None,
+ dtype: DTypeLike = dtypes.default_float,
+) -> Tensor:
+ if fmax is None:
+ fmax = float(sr) / 2
+
+ n_mels = int(n_mels)
+
+ fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft) # center freqs of each FFT bin
+ mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax) # center freqs of mel bands
+
+ fdiff = mel_f[1:] - mel_f[:-1]
+ ramps = mel_f[None].T.expand(-1, fftfreqs.shape[-1]) - fftfreqs
+
+ lower = -ramps[:n_mels] / fdiff[:n_mels][None].T
+ upper = ramps[2 : n_mels + 2] / fdiff[1 : n_mels + 1][None].T
+ weights = lower.minimum(upper).maximum(0)
+
+ # Slaney-style mel is scaled to be approx constant energy per channel
+ enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
+ weights *= enorm[:, None]
+
+ return weights
diff --git a/backend/python/tinygrad/vendor/clip.py b/backend/python/tinygrad/vendor/clip.py
new file mode 100644
index 000000000..3c7da3141
--- /dev/null
+++ b/backend/python/tinygrad/vendor/clip.py
@@ -0,0 +1,484 @@
+# Vendored verbatim from tinygrad extra/models/clip.py (MIT license).
+# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/clip.py
+# Copyright (c) 2023- the tinygrad authors
+# SPDX-License-Identifier: MIT
+from tinygrad import Tensor, dtypes
+from tinygrad.helpers import fetch
+from tinygrad.nn import Linear, LayerNorm, Embedding, Conv2d
+
+from typing import List, Optional, Union, Tuple, Dict
+from abc import ABC, abstractmethod
+from functools import lru_cache
+import numpy as np
+import re, gzip
+
+# Allow for monkeypatching for mlperf.
+gelu = Tensor.gelu
+
+@lru_cache()
+def default_bpe():
+ # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
+ return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")
+
+class Tokenizer:
+ """
+ Namespace for CLIP Text Tokenizer components.
+ """
+
+ @staticmethod
+ def get_pairs(word):
+ """
+ Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ return set(zip(word, word[1:]))
+ @staticmethod
+ def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+ @staticmethod
+ def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a significant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+ class ClipTokenizer:
+ def __init__(self, version=None):
+ self.byte_encoder, self.version = Tokenizer.bytes_to_unicode(), version
+ merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(Tokenizer.bytes_to_unicode().values())
+ vocab = vocab + [v+'</w>' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ if self.version == "sd_mlperf_v5_0":
+ import regex
+ vocab.extend(['<start_of_text>', '<end_of_text>'])
+ self.cache = {'<start_of_text>': '<start_of_text>', '<end_of_text>': '<end_of_text>'}
+ self.pat = regex.compile(r"""<start_of_text>|<end_of_text>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
+ else:
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+ pairs = Tokenizer.get_pairs(word)
+
+ if not pairs:
+ return token+'</w>'
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except Exception:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ pairs = Tokenizer.get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text:str, pad_with_zeros:bool=False) -> List[int]:
+ bpe_tokens: List[int] = []
+ if self.version == "sd_mlperf_v5_0":
+ import regex, ftfy, html
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text)).strip()
+ text = Tokenizer.whitespace_clean(text).lower()
+ re_module = regex
+ else:
+ text = Tokenizer.whitespace_clean(text.strip()).lower()
+ re_module = re
+
+ for token in re_module.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ # Truncation, keeping two slots for start and end tokens.
+ if len(bpe_tokens) > 75:
+ bpe_tokens = bpe_tokens[:75]
+ return [49406] + bpe_tokens + [49407] + ([0] if pad_with_zeros else [49407]) * (77 - len(bpe_tokens) - 2)
+
+
+class Embedder(ABC):
+ input_key: str
+ @abstractmethod
+ def __call__(self, x:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
+ pass
+
+
+class Closed:
+ """
+ Namespace for OpenAI CLIP model components.
+ """
+ class ClipMlp:
+ def __init__(self):
+ self.fc1 = Linear(768, 3072)
+ self.fc2 = Linear(3072, 768)
+
+ def __call__(self, h:Tensor) -> Tensor:
+ h = self.fc1(h)
+ h = h.quick_gelu()
+ h = self.fc2(h)
+ return h
+
+ class ClipAttention:
+ def __init__(self):
+ self.embed_dim = 768
+ self.num_heads = 12
+ self.head_dim = self.embed_dim // self.num_heads
+ self.k_proj = Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = Linear(self.embed_dim, self.embed_dim)
+
+ def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
+ bsz, tgt_len, embed_dim = hidden_states.shape
+ q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
+ q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
+ attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
+ return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))
+
+ class ClipEncoderLayer:
+ def __init__(self):
+ self.self_attn = Closed.ClipAttention()
+ self.layer_norm1 = LayerNorm(768)
+ self.mlp = Closed.ClipMlp()
+ self.layer_norm2 = LayerNorm(768)
+
+ def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
+ residual = hidden_states
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.self_attn(hidden_states, causal_attention_mask)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+ class ClipTextEmbeddings:
+ def __init__(self):
+ self.token_embedding = Embedding(49408, 768)
+ self.position_embedding = Embedding(77, 768)
+
+ def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor:
+ return self.token_embedding(input_ids) + self.position_embedding(position_ids)
+
+ class ClipEncoder:
+ def __init__(self, layer_count:int=12):
+ self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)]
+
+ def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor:
+ # the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state
+ layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx]
+ for l in layers:
+ x = l(x, causal_attention_mask)
+ return x
+
+ class ClipTextTransformer:
+ def __init__(self, ret_layer_idx:Optional[int]=None):
+ self.embeddings = Closed.ClipTextEmbeddings()
+ self.encoder = Closed.ClipEncoder()
+ self.final_layer_norm = LayerNorm(768)
+ self.ret_layer_idx = ret_layer_idx
+
+ def __call__(self, input_ids:Tensor) -> Tensor:
+ x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1))
+ x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1), self.ret_layer_idx)
+ return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x
+
+ class ClipTextModel:
+ def __init__(self, ret_layer_idx:Optional[int]):
+ self.text_model = Closed.ClipTextTransformer(ret_layer_idx=ret_layer_idx)
+
+
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331
+class FrozenClosedClipEmbedder(Embedder):
+ def __init__(self, ret_layer_idx:Optional[int]=None):
+ self.tokenizer = Tokenizer.ClipTokenizer()
+ self.transformer = Closed.ClipTextModel(ret_layer_idx)
+ self.input_key = "txt"
+
+ def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
+ if isinstance(texts, str): texts = [texts]
+ assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
+ tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0)
+ return self.transformer.text_model(tokens.reshape(len(texts),-1))
+
+
+class Open:
+ """
+ Namespace for OpenCLIP model components.
+ """
+ class MultiheadAttention:
+ def __init__(self, dims:int, n_heads:int):
+ self.dims = dims
+ self.n_heads = n_heads
+ self.d_head = self.dims // self.n_heads
+
+ self.in_proj_bias = Tensor.empty(3*dims)
+ self.in_proj_weight = Tensor.empty(3*dims, dims)
+ self.out_proj = Linear(dims, dims)
+
+ def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
+ T,B,C = x.shape
+
+ proj = x.linear(self.in_proj_weight.T, self.in_proj_bias)
+ proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0, -2)
+
+ q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in proj.chunk(3)]
+
+ attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+ attn_output = attn_output.permute(2, 0, 1, 3).reshape(T, B, C)
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+ class Mlp:
+ def __init__(self, dims, hidden_dims):
+ self.c_fc = Linear(dims, hidden_dims)
+ self.c_proj = Linear(hidden_dims, dims)
+ self.gelu = gelu
+
+ def __call__(self, x:Tensor) -> Tensor:
+ return x.sequential([self.c_fc, self.gelu, self.c_proj])
+
+ # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210
+ class ResidualAttentionBlock:
+ def __init__(self, dims:int, n_heads:int, mlp_ratio:float):
+ self.ln_1 = LayerNorm(dims)
+ self.attn = Open.MultiheadAttention(dims, n_heads)
+
+ self.ln_2 = LayerNorm(dims)
+ self.mlp = Open.Mlp(dims, int(dims * mlp_ratio))
+
+ def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None, transpose:bool=False) -> Tensor:
+ q_x = self.ln_1(x)
+ attn_out = self.attn(q_x.transpose(0, 1) if transpose else q_x, attn_mask=attn_mask)
+ attn_out = attn_out.transpose(0, 1) if transpose else attn_out
+ x = x + attn_out
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+ # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317
+ class ClipTransformer:
+ def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0):
+ self.resblocks = [
+ Open.ResidualAttentionBlock(dims, n_heads, mlp_ratio) for _ in range(layers)
+ ]
+
+ def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
+ for r in self.resblocks:
+ x = r(x, attn_mask=attn_mask, transpose=True)
+ return x
+
+ # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220
+ # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661
+ class ClipTextTransformer:
+ def __init__(self, width:int, n_heads:int, layers:int, vocab_size:int=49408, ctx_length:int=77):
+ self.token_embedding = Embedding(vocab_size, width)
+ self.positional_embedding = Tensor.empty(ctx_length, width)
+ self.transformer = Open.ClipTransformer(width, layers, n_heads)
+ self.ln_final = LayerNorm(width)
+ self.text_projection = Tensor.empty(width, width)
+ self.attn_mask = Tensor.full((77, 77), float("-inf")).triu(1).realize()
+
+ def __call__(self, text:Tensor) -> Tensor:
+ seq_len = text.shape[1]
+
+ x = self.token_embedding(text)
+ x = x + self.positional_embedding[:seq_len]
+ x = self.transformer(x, attn_mask=self.attn_mask)
+ x = self.ln_final(x)
+
+ pooled = x[:, text.argmax(dim=-1)] @ self.text_projection
+ return pooled
+
+ class ClipVisionTransformer:
+ def __init__(self, width:int, layers:int, d_head:int, image_size:int, patch_size:int):
+ grid_size = image_size // patch_size
+ n_heads = width // d_head
+ assert n_heads * d_head == width
+
+ self.conv1 = Conv2d(3, width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ self.class_embedding = Tensor.empty(width)
+ self.positional_embedding = Tensor.empty(grid_size * grid_size + 1, width)
+ self.transformer = Open.ClipTransformer(width, layers, n_heads)
+ self.ln_pre = LayerNorm(width)
+ self.ln_post = LayerNorm(width)
+ self.proj = Tensor.empty(width, 1024)
+
+ def __call__(self, x:Tensor) -> Tensor:
+ x = self.conv1(x)
+ x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)
+ x = self.class_embedding.reshape(1, 1, -1).expand(x.shape[0], 1, -1).cat(x, dim=1)
+ x = x + self.positional_embedding
+
+ x = self.ln_pre(x)
+ x = self.transformer(x)
+ x = self.ln_post(x)
+
+ pooled = x[:, 0] @ self.proj
+ return pooled
+
+
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
+class FrozenOpenClipEmbedder(Embedder):
+ def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False, clip_tokenizer_version=None):
+ self.tokenizer = Tokenizer.ClipTokenizer(version=clip_tokenizer_version)
+ self.model = Open.ClipTextTransformer(dims, n_heads, layers)
+ self.return_pooled = return_pooled
+ self.input_key = "txt"
+ self.ln_penultimate = ln_penultimate
+
+ def tokenize(self, text:str, device:Optional[str]=None) -> Tensor:
+ return Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int32, device=device).reshape(1,-1)
+
+ def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
+ for r in self.model.transformer.resblocks:
+ x, penultimate = r(x, attn_mask=attn_mask), x
+ return x.permute(1, 0, 2), penultimate.permute(1, 0, 2)
+
+ def embed_tokens(self, tokens:Tensor) -> Union[Tensor,Tuple[Tensor,...]]:
+ x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2)
+ x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+
+ if self.ln_penultimate:
+ penultimate = self.model.ln_final(penultimate)
+
+ if self.return_pooled:
+ x = self.model.ln_final(x)
+ index = tokens.argmax(axis=-1).reshape(-1,1,1).expand(x.shape[0],1,x.shape[-1])
+ pooled = x.gather(1, index).squeeze(1) @ self.model.text_projection
+ return penultimate, pooled
+ else:
+ return penultimate
+
+ def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
+ if isinstance(texts, str): texts = [texts]
+ assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
+ tokens = Tensor.cat(*[self.tokenize(text) for text in texts], dim=0)
+ return self.embed_tokens(tokens)
+
+
+clip_configs: Dict = {
+ "ViT-H-14": {
+ "dims": 1024,
+ "vision_cfg": {
+ "width": 1280,
+ "layers": 32,
+ "d_head": 80,
+ "image_size": 224,
+ "patch_size": 14,
+ },
+ "text_cfg": {
+ "width": 1024,
+ "n_heads": 16,
+ "layers": 24,
+ "ctx_length": 77,
+ "vocab_size": 49408,
+ },
+ "return_pooled": False,
+ "ln_penultimate": True,
+ }
+}
+
+class OpenClipEncoder:
+ def __init__(self, dims:int, text_cfg:Dict, vision_cfg:Dict, **_):
+ self.visual = Open.ClipVisionTransformer(**vision_cfg)
+
+ text = Open.ClipTextTransformer(**text_cfg)
+ self.transformer = text.transformer
+ self.token_embedding = text.token_embedding
+ self.positional_embedding = text.positional_embedding
+ self.ln_final = text.ln_final
+ self.text_projection = text.text_projection
+
+ self.attn_mask = Tensor.full((77, 77), float("-inf")).triu(1).realize()
+ self.mean = Tensor([0.48145466, 0.45782750, 0.40821073]).reshape(-1, 1, 1)
+ self.std = Tensor([0.26862954, 0.26130258, 0.27577711]).reshape(-1, 1, 1)
+
+ # TODO:
+ # Should be doable in pure tinygrad, would just require some work and verification.
+ # This is very desirable since it would allow for full generation->evaluation in a single JIT call.
+ def prepare_image(self, image) -> Tensor:
+ from PIL import Image
+ SIZE = 224
+ w, h = image.size
+ scale = min(SIZE / h, SIZE / w)
+ image = image.resize((max(int(w*scale),SIZE),max(int(h*scale),SIZE)), Image.Resampling.BICUBIC)
+ w, h = image.size
+ if w > SIZE:
+ left = (w - SIZE) // 2
+ image = image.crop((left, left+SIZE, 0, SIZE))
+ elif h > SIZE:
+ top = (h - SIZE) // 2
+ image = image.crop((0, SIZE, top, top+SIZE))
+
+ x = Tensor(np.array(image.convert('RGB')), device=self.std.device)
+ x = x.permute(2, 0, 1).cast(dtypes.float32) / 255.0
+ return (x - self.mean) / self.std
+
+ def encode_tokens(self, tokens:Tensor) -> Tensor:
+ x = self.token_embedding(tokens)
+ x = x + self.positional_embedding
+ x = self.transformer(x, attn_mask=self.attn_mask)
+ x = self.ln_final(x)
+ x = x[Tensor.arange(x.shape[0], device=x.device), tokens.argmax(axis=-1)]
+ x = x @ self.text_projection
+ return x
+
+ def get_clip_score(self, tokens:Tensor, image:Tensor) -> Tensor:
+ image_features: Tensor = self.visual(image)
+ image_features /= image_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm
+
+ text_features = self.encode_tokens(tokens)
+ text_features /= text_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm
+
+ return (image_features * text_features).sum(axis=-1)
diff --git a/backend/python/tinygrad/vendor/llama.py b/backend/python/tinygrad/vendor/llama.py
new file mode 100644
index 000000000..445f68b9c
--- /dev/null
+++ b/backend/python/tinygrad/vendor/llama.py
@@ -0,0 +1,294 @@
+# Vendored from tinygrad extra/models/llama.py (MIT license).
+# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/llama.py
+#
+# Local modification: Attention / TransformerBlock / Transformer accept an
+# optional `qkv_bias` flag so the same module can host Qwen2-style models that
+# use bias on the Q/K/V projections (Llama 3 has no bias). Changes are marked
+# with `# LOCALAI`.
+#
+# Copyright (c) 2023- the tinygrad authors
+# SPDX-License-Identifier: MIT
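+#
+# Illustrative only (the hyperparameters below are assumptions, not a config
+# shipped in this repo): a Qwen2-style model would be instantiated with bias
+# on the Q/K/V projections, e.g.
+#   Transformer(dim=896, hidden_dim=4864, n_layers=24, n_heads=14,
+#               n_kv_heads=2, norm_eps=1e-6, vocab_size=151936,
+#               rope_theta=1_000_000, max_context=4096, qkv_bias=True)
+# while Llama 3 configs keep the default qkv_bias=False.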
+from typing import Union, Optional, Any
+import collections, math
+from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
+from tinygrad.helpers import getenv, DEBUG
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
+ freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
+ freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
+ return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).reshape(1, end, 1, dim//2, 2)
+
+
+def complex_mult(A, c, d):
+ a, b = A[..., 0:1], A[..., 1:2]
+ ro = a * c - b * d
+ co = a * d + b * c
+ return ro.cat(co, dim=-1)
+
+
+def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
+ assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
+ xq = xq.reshape(*xq.shape[0:-1], -1, 2)
+ xk = xk.reshape(*xk.shape[0:-1], -1, 2)
+ assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
+ c, d = freqs_cis[..., 0:1], freqs_cis[..., 1:2]
+ xq_out = complex_mult(xq, c, d)
+ xk_out = complex_mult(xk, c, d)
+ return xq_out.flatten(3), xk_out.flatten(3)
+
+
+def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
+ bs, seqlen, n_kv_heads, head_dim = x.shape
+ if n_rep == 1:
+ return x
+ return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
+
+
+class Attention:
+ # LOCALAI: added qkv_bias
+ def __init__(self, dim, n_heads, n_kv_heads=None, max_context=0, linear=nn.Linear, qk_norm: float | None = None, qkv_bias: bool = False):
+ self.n_heads = n_heads
+ self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
+ self.head_dim = dim // n_heads
+ self.n_rep = self.n_heads // self.n_kv_heads
+ self.max_context = max_context
+
+ self.wq = linear(dim, self.n_heads * self.head_dim, bias=qkv_bias)
+ self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=qkv_bias)
+ self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=qkv_bias)
+ self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
+
+ self.q_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
+ self.k_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
+
+ def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+ xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)
+
+ if self.q_norm is not None and self.k_norm is not None:
+ xq = self.q_norm(xq)
+ xk = self.k_norm(xk)
+
+ if x.dtype == dtypes.bfloat16:
+ xq, xk = xq.contiguous_backward(), xk.contiguous_backward()
+
+ xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
+ xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
+ xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
+
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
+ bsz, seqlen, _, _ = xq.shape
+
+ if self.max_context:
+ if not hasattr(self, "cache_kv"):
+ self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
+ if isinstance(x.device, tuple):
+ self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
+
+ assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
+ self.cache_kv[:, :, start_pos:start_pos + seqlen, :, :].assign(Tensor.stack(xk, xv)).realize()
+
+ keys = self.cache_kv[0, :, 0:start_pos + seqlen, :, :]
+ values = self.cache_kv[1, :, 0:start_pos + seqlen, :, :]
+ else:
+ assert start_pos == 0
+ keys, values = xk, xv
+
+ if self.max_context:
+ keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
+ xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
+ attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
+ else:
+ xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
+ attn = xq.scaled_dot_product_attention(keys, values, is_causal=True, enable_gqa=True).transpose(1, 2)
+
+ attn = attn.reshape(bsz, seqlen, -1)
+ return self.wo(attn)
+
+
+class FeedForward:
+ def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
+ self.w1 = linear(dim, hidden_dim, bias=False)
+ self.w2 = linear(hidden_dim, dim, bias=False)
+ self.w3 = linear(dim, hidden_dim, bias=False)
+
+ def __call__(self, x: Tensor) -> Tensor:
+ w1 = self.w1(x).silu()
+ w3 = self.w3(x.contiguous_backward())
+ return self.w2(w1 * w3)
+
+
+class TransformerBlock:
+ # LOCALAI: added qkv_bias
+ def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int,
+ linear=nn.Linear, feed_forward=FeedForward, qk_norm=None, qkv_bias: bool = False):
+ self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear, qk_norm, qkv_bias=qkv_bias)
+ self.feed_forward = feed_forward(dim, hidden_dim, linear)
+ self.attention_norm = nn.RMSNorm(dim, norm_eps)
+ self.ffn_norm = nn.RMSNorm(dim, norm_eps)
+
+ def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
+ h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
+ return (h + self.feed_forward(self.ffn_norm(h))).contiguous().contiguous_backward()
+
+
+def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
+ assert logits.ndim == 1, "only works on 1d tensors"
+ assert 0 <= p <= 1, "p must be between 0 and 1"
+ assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
+
+ if temp < 1e-6:
+ return logits.argmax()
+
+ logits = logits.to(Device.DEFAULT)
+
+ if af or ap:
+ if not hasattr(sample, "alpha_counter"):
+ setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
+ logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0) * ap)
+
+ logits = (logits != logits).where(-float("inf"), logits)
+
+ t = (logits / temp).softmax()
+
+ counter = Tensor.arange(t.numel(), device=logits.device).contiguous()
+ counter2 = Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
+
+ if k:
+ output = Tensor.zeros(k, device=logits.device).contiguous()
+ output_indices = Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
+ for i in range(k):
+ t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
+ output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
+ output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
+ t = (counter == t_argmax).where(0, t)
+
+ output_cumsum = output[::-1].cumsum()[::-1] + t.sum()
+ output = (output_cumsum >= (1 - p)) * output
+ output_indices = (output_cumsum >= (1 - p)) * output_indices
+
+ output_idx = output.multinomial()
+ output_token = output_indices[output_idx]
+ else:
+ output_token = t.multinomial()
+
+ if af or ap:
+ sample.alpha_counter = (counter == output_token).where(sample.alpha_counter + 1, sample.alpha_counter)
+
+ return output_token
+
+
+class Transformer:
+ # LOCALAI: added qkv_bias
+ def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_layers: int, norm_eps: float, vocab_size,
+ linear=nn.Linear, embedding=nn.Embedding, n_kv_heads=None, rope_theta=10000,
+ max_context=1024, jit=True, feed_forward=FeedForward, qk_norm=None, disable_kv_cache=False,
+ qkv_bias: bool = False):
+ self.layers = [
+ TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps,
+ 0 if disable_kv_cache else max_context, linear,
+ feed_forward=feed_forward, qk_norm=qk_norm, qkv_bias=qkv_bias)
+ for _ in range(n_layers)
+ ]
+ self.norm = nn.RMSNorm(dim, norm_eps)
+ self.tok_embeddings = embedding(vocab_size, dim)
+ self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False)
+ self.max_context = max_context
+ self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous().requires_grad_(False)
+ self.forward_jit = TinyJit(self.forward) if jit else None
+
+ def forward(self, tokens: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
+ _bsz, seqlen = tokens.shape
+ h = self.tok_embeddings(tokens).contiguous()
+ freqs_cis = self.freqs_cis.cast(h.dtype)[:, start_pos:start_pos + seqlen, :, :, :]
+
+ if self.max_context != 0 and seqlen > 1:
+ mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos + 1)
+ else:
+ mask = None
+ for layer in self.layers:
+ h = layer(h, start_pos, freqs_cis, mask)
+ logits = self.output(self.norm(h).contiguous().contiguous_backward()).contiguous_backward()
+ if math.isnan(temperature):
+ return logits
+
+ return sample(logits[:, -1, :].flatten(), temperature, top_k, top_p, alpha_f, alpha_p)
+
+ def __call__(self, tokens: Tensor, start_pos: int, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
+ if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
+ return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context - 1).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
+ return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
+
+ # LOCALAI: extract the last hidden states for embeddings. Skips the LM head;
+ # the causal-mask branch is left intact so pooling sees the full sequence.
+ def embed(self, tokens: Tensor) -> Tensor:
+ _bsz, seqlen = tokens.shape
+ h = self.tok_embeddings(tokens).contiguous()
+ freqs_cis = self.freqs_cis.cast(h.dtype)[:, 0:seqlen, :, :, :]
+ mask = Tensor.full((1, 1, seqlen, seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(1) if seqlen > 1 else None
+ for layer in self.layers:
+ h = layer(h, 0, freqs_cis, mask)
+ return self.norm(h)
+
+
+def convert_from_huggingface(weights: dict[str, Tensor], n_layers: int, n_heads: int, n_kv_heads: int, permute_layers: bool = True):
+ def permute(v: Tensor, n_heads: int):
+ return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1] if len(v.shape) > 1 else 1).transpose(1, 2).reshape(*v.shape[:2])
+
+ keymap = {
+ "model.embed_tokens.weight": "tok_embeddings.weight",
+ **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(n_layers)},
+ **{f"model.layers.{l}.self_attn.{x}_norm.weight": f"layers.{l}.attention.{x}_norm.weight" for x in ["q", "k"] for l in range(n_layers)},
+ **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(n_layers)},
+ **{f"model.layers.{l}.self_attn.{x}_proj.bias": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v", "o"] for l in range(n_layers)},
+ **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(n_layers)},
+ **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(n_layers)},
+ **{f"model.layers.{l}.mlp.gate.weight": f"layers.{l}.feed_forward.gate.weight" for l in range(n_layers)},
+ "model.norm.weight": "norm.weight",
+ "lm_head.weight": "output.weight",
+ }
+ sd = {}
+ experts = collections.defaultdict(dict)
+ for k, v in weights.items():
+ if ".rotary_emb." in k:
+ continue
+ v = v.to(Device.DEFAULT)
+ if "model.layers" in k:
+ if ("q_proj" in k or "q_norm" in k) and permute_layers:
+ v = permute(v, n_heads)
+ elif ("k_proj" in k or "k_norm" in k) and permute_layers:
+ v = permute(v, n_kv_heads)
+ if '.mlp.experts.' in k:
+ _, _, layer, _, _, expert, name, _ = k.split('.')
+ experts[f'layers.{layer}.feed_forward.{name}'][int(expert)] = v
+ continue
+ sd[keymap[k]] = v
+ for k, v in experts.items():
+ sd[k] = Tensor.stack(*[v[i] for i in range(len(v))])
+
+ if "output.weight" not in sd and "tok_embeddings.weight" in sd:
+ sd["output.weight"] = sd["tok_embeddings.weight"]
+
+ return sd
+
+
+def convert_from_gguf(weights: dict[str, Tensor], n_layers: int):
+ keymap = {
+ "token_embd.weight": "tok_embeddings.weight",
+ **{f"blk.{l}.attn_norm.weight": f"layers.{l}.attention_norm.weight" for l in range(n_layers)},
+ **{f"blk.{l}.attn_{x}.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v"] for l in range(n_layers)},
+ **{f"blk.{l}.attn_{x}.bias": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v"] for l in range(n_layers)},
+ **{f"blk.{l}.attn_output.weight": f"layers.{l}.attention.wo.weight" for l in range(n_layers)},
+ **{f"blk.{l}.ffn_norm.weight": f"layers.{l}.ffn_norm.weight" for l in range(n_layers)},
+ **{f"blk.{l}.ffn_{x}.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(n_layers)},
+ "output_norm.weight": "norm.weight",
+ "rope_freqs.weight": "rope_freqs.weight",
+ }
+ sd = {keymap[k]: v for k, v in weights.items() if k in keymap}
+ if "output.weight" not in sd and "token_embd.weight" in weights:
+ sd["output.weight"] = weights["token_embd.weight"]
+ return sd
+
+
+def fix_bf16(weights: dict[Any, Tensor]):
+ return {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
diff --git a/backend/python/tinygrad/vendor/stable_diffusion.py b/backend/python/tinygrad/vendor/stable_diffusion.py
new file mode 100644
index 000000000..5a993332d
--- /dev/null
+++ b/backend/python/tinygrad/vendor/stable_diffusion.py
@@ -0,0 +1,232 @@
+# Adapted from tinygrad examples/stable_diffusion.py (MIT license).
+# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/stable_diffusion.py
+# Copyright (c) 2023- the tinygrad authors
+# SPDX-License-Identifier: MIT
+#
+# Local modifications: removed the MLPerf training branch (pulls
+# examples/mlperf/initializers which we don't vendor) and the __main__
+# argparse / fetch / profile blocks. Kept the core classes so the LocalAI
+# tinygrad backend can instantiate and drive Stable Diffusion v1.x from a
+# single checkpoint path.
+from collections import namedtuple
+from typing import Any, Dict
+
+import numpy as np
+from tinygrad import Tensor, dtypes
+from tinygrad.nn import Conv2d, GroupNorm
+
+from . import clip as clip_mod
+from . import unet as unet_mod
+from .clip import Closed, Tokenizer
+from .unet import UNetModel
+
+
+class AttnBlock:
+ def __init__(self, in_channels):
+ self.norm = GroupNorm(32, in_channels)
+ self.q = Conv2d(in_channels, in_channels, 1)
+ self.k = Conv2d(in_channels, in_channels, 1)
+ self.v = Conv2d(in_channels, in_channels, 1)
+ self.proj_out = Conv2d(in_channels, in_channels, 1)
+
+ def __call__(self, x):
+ h_ = self.norm(x)
+ q, k, v = self.q(h_), self.k(h_), self.v(h_)
+ b, c, h, w = q.shape
+ q, k, v = [t.reshape(b, c, h * w).transpose(1, 2) for t in (q, k, v)]
+ h_ = Tensor.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(b, c, h, w)
+ return x + self.proj_out(h_)
+
+
+class ResnetBlock:
+ def __init__(self, in_channels, out_channels=None):
+ self.norm1 = GroupNorm(32, in_channels)
+ self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
+ self.norm2 = GroupNorm(32, out_channels)
+ self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
+ self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else (lambda x: x)
+
+ def __call__(self, x):
+ h = self.conv1(self.norm1(x).swish())
+ h = self.conv2(self.norm2(h).swish())
+ return self.nin_shortcut(x) + h
+
+
+class Mid:
+ def __init__(self, block_in):
+ self.block_1 = ResnetBlock(block_in, block_in)
+ self.attn_1 = AttnBlock(block_in)
+ self.block_2 = ResnetBlock(block_in, block_in)
+
+ def __call__(self, x):
+ return x.sequential([self.block_1, self.attn_1, self.block_2])
+
+
+class Decoder:
+ def __init__(self):
+ sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
+ self.conv_in = Conv2d(4, 512, 3, padding=1)
+ self.mid = Mid(512)
+
+ arr = []
+ for i, s in enumerate(sz):
+ arr.append({"block": [ResnetBlock(s[1], s[0]), ResnetBlock(s[0], s[0]), ResnetBlock(s[0], s[0])]})
+ if i != 0:
+ arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
+ self.up = arr
+
+ self.norm_out = GroupNorm(32, 128)
+ self.conv_out = Conv2d(128, 3, 3, padding=1)
+
+ def __call__(self, x):
+ x = self.conv_in(x)
+ x = self.mid(x)
+ for l in self.up[::-1]:
+ for b in l['block']:
+ x = b(x)
+ if 'upsample' in l:
+ bs, c, py, px = x.shape
+ x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py * 2, px * 2)
+ x = l['upsample']['conv'](x)
+ x.realize()
+ return self.conv_out(self.norm_out(x).swish())
+
+
+class Encoder:
+ def __init__(self):
+ sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
+ self.conv_in = Conv2d(3, 128, 3, padding=1)
+
+ arr = []
+ for i, s in enumerate(sz):
+ arr.append({"block": [ResnetBlock(s[0], s[1]), ResnetBlock(s[1], s[1])]})
+ if i != 3:
+ arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0, 1, 0, 1))}
+ self.down = arr
+
+ self.mid = Mid(512)
+ self.norm_out = GroupNorm(32, 512)
+ self.conv_out = Conv2d(512, 8, 3, padding=1)
+
+ def __call__(self, x):
+ x = self.conv_in(x)
+ for l in self.down:
+ for b in l['block']:
+ x = b(x)
+ if 'downsample' in l:
+ x = l['downsample']['conv'](x)
+ x = self.mid(x)
+ return self.conv_out(self.norm_out(x).swish())
+
+
+class AutoencoderKL:
+ def __init__(self):
+ self.encoder = Encoder()
+ self.decoder = Decoder()
+ self.quant_conv = Conv2d(8, 8, 1)
+ self.post_quant_conv = Conv2d(4, 4, 1)
+
+ def __call__(self, x):
+ latent = self.encoder(x)
+ latent = self.quant_conv(latent)
+ latent = latent[:, 0:4]
+ latent = self.post_quant_conv(latent)
+ return self.decoder(latent)
+
+
+def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
+ betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_training_steps, dtype=np.float32) ** 2
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ return Tensor(alphas_cumprod)
+
+
+# SD1.x UNet hyperparameters (same as upstream `unet_params`).
+UNET_PARAMS_SD1: Dict[str, Any] = {
+ "adm_in_ch": None,
+ "in_ch": 4,
+ "out_ch": 4,
+ "model_ch": 320,
+ "attention_resolutions": [4, 2, 1],
+ "num_res_blocks": 2,
+ "channel_mult": [1, 2, 4, 4],
+ "n_heads": 8,
+ "transformer_depth": [1, 1, 1, 1],
+ "ctx_dim": 768,
+ "use_linear": False,
+}
+
+
+class StableDiffusion:
+ """Stable Diffusion 1.x pipeline, adapted from tinygrad's reference example.
+
+ Drives the native CompVis `sd-v1-*.ckpt` checkpoint format (the only one
+ the vendored weight layout handles). For HuggingFace safetensors pipelines
+ the caller is expected to download / merge the `.ckpt` equivalent before
+ calling LoadModel.
+ """
+
+ def __init__(self):
+ self.alphas_cumprod = get_alphas_cumprod()
+ self.first_stage_model = AutoencoderKL()
+ self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(
+ transformer=namedtuple("Transformer", ["text_model"])(text_model=Closed.ClipTextTransformer())
+ )
+ self.model = namedtuple("DiffusionModel", ["diffusion_model"])(
+ diffusion_model=UNetModel(**UNET_PARAMS_SD1)
+ )
+
+ # DDIM update step.
+ def _update(self, x, e_t, a_t, a_prev):
+ sqrt_one_minus_at = (1 - a_t).sqrt()
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ dir_xt = (1.0 - a_prev).sqrt() * e_t
+ return a_prev.sqrt() * pred_x0 + dir_xt
+
+ def _model_output(self, uncond, cond, latent, timestep, guidance):
+ latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, uncond.cat(cond, dim=0))
+ uncond_latent, cond_latent = latents[0:1], latents[1:2]
+ return uncond_latent + guidance * (cond_latent - uncond_latent)
+
+ def step(self, uncond, cond, latent, timestep, a_t, a_prev, guidance):
+ e_t = self._model_output(uncond, cond, latent, timestep, guidance)
+ return self._update(latent, e_t, a_t, a_prev).realize()
+
+ def decode(self, x):
+ x = self.first_stage_model.post_quant_conv(1 / 0.18215 * x)
+ x = self.first_stage_model.decoder(x)
+ x = (x + 1.0) / 2.0
+ x = x.reshape(3, 512, 512).permute(1, 2, 0).clip(0, 1) * 255
+ return x.cast(dtypes.uint8)
+
+ def encode_prompt(self, tokenizer, prompt: str):
+ ids = Tensor([tokenizer.encode(prompt)])
+ return self.cond_stage_model.transformer.text_model(ids).realize()
+
+
+def run_sd15(model: StableDiffusion, prompt: str, negative_prompt: str, steps: int, guidance: float, seed: int):
+ """Generate a single 512x512 image. Returns a (512,512,3) uint8 tensor."""
+ tokenizer = Tokenizer.ClipTokenizer()
+
+ context = model.encode_prompt(tokenizer, prompt)
+ uncond = model.encode_prompt(tokenizer, negative_prompt)
+
+ timesteps = list(range(1, 1000, 1000 // steps))
+ alphas = model.alphas_cumprod[Tensor(timesteps)]
+ alphas_prev = Tensor([1.0]).cat(alphas[:-1])
+
+ if seed is not None:
+ Tensor.manual_seed(seed)
+ latent = Tensor.randn(1, 4, 64, 64)
+
+ for index in range(len(timesteps) - 1, -1, -1):
+ timestep = timesteps[index]
+ tid = Tensor([index])
+ latent = model.step(
+ uncond, context, latent,
+ Tensor([timestep]),
+ alphas[tid], alphas_prev[tid],
+ Tensor([guidance]),
+ )
+
+ return model.decode(latent).realize()
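+
+# Illustrative usage sketch (weight loading is the caller's job and is elided
+# here; the prompt and the idea of saving via PIL are hypothetical, not wired
+# up in this module):
+#   model = StableDiffusion()
+#   # ... load the sd-v1-* .ckpt state dict into `model` ...
+#   out = run_sd15(model, "a photo of a cat", "", steps=20, guidance=7.5, seed=42)
+#   # `out` is a (512, 512, 3) uint8 tinygrad Tensor; out.numpy() feeds PIL directly.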
diff --git a/backend/python/tinygrad/vendor/unet.py b/backend/python/tinygrad/vendor/unet.py
new file mode 100644
index 000000000..22a07d170
--- /dev/null
+++ b/backend/python/tinygrad/vendor/unet.py
@@ -0,0 +1,267 @@
+# Vendored verbatim from tinygrad extra/models/unet.py (MIT license).
+# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/unet.py
+# Copyright (c) 2023- the tinygrad authors
+# SPDX-License-Identifier: MIT
+from tinygrad import Tensor, dtypes, nn
+from tinygrad.device import is_dtype_supported
+from typing import Optional, Union, List, Any, Tuple, Callable
+import math
+
+# allow for monkeypatching
+Linear, Conv2d, GroupNorm, LayerNorm = nn.Linear, nn.Conv2d, nn.GroupNorm, nn.LayerNorm
+attention, gelu, mixed_precision_dtype = Tensor.scaled_dot_product_attention, Tensor.gelu, dtypes.float16
+
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
+def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
+ half = dim // 2
+ freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
+ args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
+ out = Tensor.cat(args.cos(), args.sin(), dim=-1)
+ return out.cast(mixed_precision_dtype) if is_dtype_supported(mixed_precision_dtype) else out
+
+class ResBlock:
+ def __init__(self, channels:int, emb_channels:int, out_channels:int, num_groups:int=32):
+ self.in_layers = [
+ GroupNorm(num_groups, channels),
+ Tensor.silu,
+ Conv2d(channels, out_channels, 3, padding=1),
+ ]
+ self.emb_layers = [
+ Tensor.silu,
+ Linear(emb_channels, out_channels),
+ ]
+ self.out_layers = [
+ GroupNorm(num_groups, out_channels),
+ Tensor.silu,
+ lambda x: x, # needed for weights loading code to work
+ Conv2d(out_channels, out_channels, 3, padding=1),
+ ]
+ self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else (lambda x: x)
+
+ def __call__(self, x:Tensor, emb:Tensor) -> Tensor:
+ h = x.sequential(self.in_layers)
+ emb_out = emb.sequential(self.emb_layers)
+ h = h + emb_out.reshape(*emb_out.shape, 1, 1)
+ h = h.sequential(self.out_layers)
+ return self.skip_connection(x) + h
+
+class CrossAttention:
+ def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int):
+ self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
+ self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False)
+ self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False)
+ self.num_heads = n_heads
+ self.head_size = d_head
+ self.attn = attention
+ self.to_out = [Linear(n_heads*d_head, query_dim)]
+
+ def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
+ ctx = x if ctx is None else ctx
+ q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx)
+ q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
+ attention = self.attn(q, k, v).transpose(1,2)
+ h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
+ return h_.sequential(self.to_out)
+
+class GEGLU:
+ def __init__(self, dim_in:int, dim_out:int):
+ self.proj = Linear(dim_in, dim_out * 2)
+ self.gelu = gelu
+ self.dim_out = dim_out
+
+ def __call__(self, x:Tensor) -> Tensor:
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * self.gelu(gate)
+
+class FeedForward:
+ def __init__(self, dim:int, mult:int=4):
+ self.net: tuple[GEGLU, Callable, nn.Linear] = (
+ GEGLU(dim, dim*mult),
+ lambda x: x, # needed for weights loading code to work
+ Linear(dim*mult, dim)
+ )
+
+ def __call__(self, x:Tensor) -> Tensor:
+ return x.sequential(list(self.net))
+
+class BasicTransformerBlock:
+ def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
+ self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
+ self.ff = FeedForward(dim)
+ self.attn2 = CrossAttention(dim, ctx_dim, n_heads, d_head)
+ self.norm1 = LayerNorm(dim)
+ self.norm2 = LayerNorm(dim)
+ self.norm3 = LayerNorm(dim)
+
+ def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
+ x = x + self.attn1(self.norm1(x))
+ x = x + self.attn2(self.norm2(x), ctx=ctx)
+ x = x + self.ff(self.norm3(x))
+ return x
+
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/attention.py#L619
+class SpatialTransformer:
+ def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1,
+ norm_eps:float=1e-5):
+ if isinstance(ctx_dim, int):
+ ctx_dim = [ctx_dim]*depth
+ else:
+ assert isinstance(ctx_dim, list) and depth == len(ctx_dim)
+ self.norm = GroupNorm(32, channels, eps=norm_eps)
+ assert channels == n_heads * d_head
+ self.proj_in = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
+ self.transformer_blocks = [BasicTransformerBlock(channels, ctx_dim[d], n_heads, d_head) for d in range(depth)]
+ self.proj_out = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
+ self.use_linear = use_linear
+
+ def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ ops = [ (lambda z: z.reshape(b, c, h*w).permute(0,2,1)), (lambda z: self.proj_in(z)) ]
+ x = x.sequential(ops if self.use_linear else ops[::-1])
+ for block in self.transformer_blocks:
+ x = block(x, ctx=ctx)
+ ops = [ (lambda z: self.proj_out(z)), (lambda z: z.permute(0,2,1).reshape(b, c, h, w)) ]
+ x = x.sequential(ops if self.use_linear else ops[::-1])
+ return x + x_in
+
+class Downsample:
+ def __init__(self, channels:int):
+ self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
+
+ def __call__(self, x:Tensor) -> Tensor:
+ return self.op(x)
+
+class Upsample:
+ def __init__(self, channels:int):
+ self.conv = Conv2d(channels, channels, 3, padding=1)
+
+ def __call__(self, x:Tensor) -> Tensor:
+ bs,c,py,px = x.shape
+ z = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
+ return self.conv(z)
+
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/openaimodel.py#L472
+class UNetModel:
+ def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], num_res_blocks:int,
+ channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None,
+ n_heads:Optional[int]=None, num_groups:int=32, st_norm_eps:float=1e-5):
+ self.model_ch = model_ch
+ self.num_res_blocks = [num_res_blocks] * len(channel_mult)
+
+ self.attention_resolutions = attention_resolutions
+ self.d_head = d_head
+ self.n_heads = n_heads
+ def get_d_and_n_heads(dims:int) -> Tuple[int,int]:
+ if self.d_head is None:
+ assert self.n_heads is not None, f"d_head and n_heads cannot both be None"
+ return dims // self.n_heads, self.n_heads
+ else:
+ assert self.n_heads is None, f"d_head and n_heads cannot both be non-None"
+ return self.d_head, dims // self.d_head
+
+ time_embed_dim = model_ch * 4
+ self.time_embed = [
+ Linear(model_ch, time_embed_dim),
+ Tensor.silu,
+ Linear(time_embed_dim, time_embed_dim),
+ ]
+
+ if adm_in_ch is not None:
+ self.label_emb = [
+ [
+ Linear(adm_in_ch, time_embed_dim),
+ Tensor.silu,
+ Linear(time_embed_dim, time_embed_dim),
+ ]
+ ]
+
+ self.input_blocks: List[Any] = [
+ [Conv2d(in_ch, model_ch, 3, padding=1)]
+ ]
+ input_block_channels = [model_ch]
+ ch = model_ch
+ ds = 1
+ for idx, mult in enumerate(channel_mult):
+ for _ in range(self.num_res_blocks[idx]):
+ layers: List[Any] = [
+ ResBlock(ch, time_embed_dim, model_ch*mult, num_groups),
+ ]
+ ch = mult * model_ch
+ if ds in attention_resolutions:
+ d_head, n_heads = get_d_and_n_heads(ch)
+ layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))
+
+ self.input_blocks.append(layers)
+ input_block_channels.append(ch)
+
+ if idx != len(channel_mult) - 1:
+ self.input_blocks.append([
+ Downsample(ch),
+ ])
+ input_block_channels.append(ch)
+ ds *= 2
+
+ d_head, n_heads = get_d_and_n_heads(ch)
+ self.middle_block: List = [
+ ResBlock(ch, time_embed_dim, ch, num_groups),
+ SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1], norm_eps=st_norm_eps),
+ ResBlock(ch, time_embed_dim, ch, num_groups),
+ ]
+
+ self.output_blocks = []
+ for idx, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[idx] + 1):
+ ich = input_block_channels.pop()
+ layers = [
+ ResBlock(ch + ich, time_embed_dim, model_ch*mult, num_groups),
+ ]
+ ch = model_ch * mult
+
+ if ds in attention_resolutions:
+ d_head, n_heads = get_d_and_n_heads(ch)
+ layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))
+
+ if idx > 0 and i == self.num_res_blocks[idx]:
+ layers.append(Upsample(ch))
+ ds //= 2
+ self.output_blocks.append(layers)
+
+ self.out = [
+ GroupNorm(num_groups, ch),
+ Tensor.silu,
+ Conv2d(model_ch, out_ch, 3, padding=1),
+ ]
+
+ def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Optional[Tensor]=None) -> Tensor:
+ t_emb = timestep_embedding(tms, self.model_ch)
+ emb = t_emb.sequential(self.time_embed)
+
+ if y is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + y.sequential(self.label_emb[0])
+
+ if is_dtype_supported(mixed_precision_dtype):
+ emb = emb.cast(mixed_precision_dtype)
+ ctx = ctx.cast(mixed_precision_dtype)
+ x = x.cast(mixed_precision_dtype)
+
+ def run(x:Tensor, bb) -> Tensor:
+ if isinstance(bb, ResBlock): x = bb(x, emb)
+ elif isinstance(bb, SpatialTransformer): x = bb(x, ctx)
+ else: x = bb(x)
+ return x
+
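+ # U-Net skip connections: every input block's activation is saved, then concatenated
+ # channel-wise with the matching output block's input on the way back up.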
+ saved_inputs = []
+ for b in self.input_blocks:
+ for bb in b:
+ x = run(x, bb)
+ saved_inputs.append(x)
+ for bb in self.middle_block:
+ x = run(x, bb)
+ for b in self.output_blocks:
+ x = x.cat(saved_inputs.pop(), dim=1)
+ for bb in b:
+ x = run(x, bb)
+ return x.sequential(self.out)
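+
+# Rough usage sketch -- the hyperparameters below are illustrative SDXL-base-style
+# values, not read from anything in this repo:
+#   unet = UNetModel(adm_in_ch=2816, in_ch=4, out_ch=4, model_ch=320,
+#                    attention_resolutions=[4, 2], num_res_blocks=2, channel_mult=[1, 2, 4],
+#                    transformer_depth=[1, 2, 10], ctx_dim=2048, use_linear=True, d_head=64)
+#   eps = unet(latents, timesteps, text_ctx, y=pooled_cond)  # predicted noise residual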
diff --git a/backend/python/tinygrad/vendor/whisper.py b/backend/python/tinygrad/vendor/whisper.py
new file mode 100644
index 000000000..70aeb03a4
--- /dev/null
+++ b/backend/python/tinygrad/vendor/whisper.py
@@ -0,0 +1,274 @@
+# Adapted from tinygrad examples/whisper.py (MIT license).
+# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/whisper.py
+# Copyright (c) 2023- the tinygrad authors
+# SPDX-License-Identifier: MIT
+#
+# Local modifications: removed the pyaudio listener / __main__ block; the rest
+# is the core Whisper model + preprocessing + single-file transcription path.
+from __future__ import annotations
+
+import base64
+import collections
+import itertools
+from typing import List, Literal, Optional, Union
+
+import numpy as np
+from tinygrad import Tensor, TinyJit, Variable, dtypes, nn
+from tinygrad.helpers import fetch
+from tinygrad.nn.state import load_state_dict, torch_load
+
+from .audio_helpers import mel
+
+
+class MultiHeadAttention:
+ def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self', None] = None, max_self_attn_cache_len=None):
+ self.n_head = n_head
+ self.query = nn.Linear(n_state, n_state)
+ self.key = nn.Linear(n_state, n_state, bias=False)
+ self.value = nn.Linear(n_state, n_state)
+ self.out = nn.Linear(n_state, n_state)
+ self.kv_caching = kv_caching
+ self.max_self_attn_cache_len = max_self_attn_cache_len
+
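+ # kv_caching='cross': encoder keys/values are computed once per audio segment and then
+ # reused for every decode step. kv_caching='self': past keys/values live in a fixed-size
+ # buffer so jitted shapes stay constant; `len` is the length of the valid cached prefix.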
+ def __call__(self, x, xa=None, mask=None, len=None):
+ if self.kv_caching == 'cross':
+ if xa is not None:
+ k, v = self.key(xa), self.value(xa)
+ if not hasattr(self, 'cache_k'):
+ self.cache_k, self.cache_v = k, v
+ else:
+ self.cache_k.assign(k).realize()
+ self.cache_v.assign(v).realize()
+ else:
+ k, v = self.cache_k, self.cache_v
+ else:
+ k, v = self.key(x), self.value(x)
+ if self.kv_caching == 'self':
+ if not hasattr(self, 'cache_k'):
+ self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
+ self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
+ k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
+ v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
+ padding = self.max_self_attn_cache_len - len - x.shape[1]
+ self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
+ self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()
+
+ q = self.query(x)
+ n_ctx = q.shape[1]
+ head_dim = q.shape[-1] // self.n_head
+ q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
+ k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
+ v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
+ attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx, :n_ctx] if mask is not None else None)
+ wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
+ return self.out(wv)
+
+
+class ResidualAttentionBlock:
+ def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
+ self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
+ self.attn_ln = nn.LayerNorm(n_state)
+ self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
+ self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
+ self.mlp = [nn.Linear(n_state, n_state * 4), Tensor.gelu, nn.Linear(n_state * 4, n_state)]
+ self.mlp_ln = nn.LayerNorm(n_state)
+
+ def __call__(self, x, xa=None, mask=None, len=None):
+ x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
+ if self.cross_attn:
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa)
+ x = x + self.mlp_ln(x).sequential(self.mlp)
+ return x.realize()
+
+
+class AudioEncoder:
+ def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
+ self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
+ self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1)
+ self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
+ self.ln_post = nn.LayerNorm(n_audio_state)
+ self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
+ self.encode = TinyJit(self.__call__)
+
+ def __call__(self, x):
+ x = self.conv1(x).gelu()
+ x = self.conv2(x).gelu()
+ x = x.permute(0, 2, 1)
+ x = x + self.positional_embedding[:x.shape[1]]
+ x = x.sequential(self.blocks)
+ x = self.ln_post(x)
+ return x.realize()
+
+
+class TextDecoder:
+ def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
+ self.max_tokens_to_sample = n_text_ctx // 2
+ self.max_self_attn_cache_len = n_text_ctx
+ self.token_embedding = nn.Embedding(n_vocab, n_text_state)
+ self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
+ self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
+ self.ln = nn.LayerNorm(n_text_state)
+ self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
+ self.getjitted = collections.defaultdict(lambda: TinyJit(self.forward))
+
+ def __call__(self, x, pos, encoded_audio):
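+ # Bind pos to a symbolic Variable so one jitted forward covers any cache length;
+ # self.getjitted keeps a separate TinyJit per token-block shape.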
+ pos = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len - 1).bind(pos) if pos else 0
+ return self.getjitted[x.shape](x, pos, encoded_audio)
+
+ def forward(self, x, pos, encoded_audio):
+ seqlen = x.shape[-1]
+ x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos + seqlen), None))
+ for block in self.blocks:
+ x = block(x, xa=encoded_audio, mask=self.mask, len=pos)
+ return self.output_tok(x)
+
+ def output_tok(self, x):
+ return (self.ln(x) @ self.token_embedding.weight.T).realize()
+
+
+class Whisper:
+ def __init__(self, dims, batch_size=1):
+ self.encoder = AudioEncoder(**dims)
+ self.decoder = TextDecoder(**dims)
+ self.is_multilingual = dims["n_vocab"] == 51865
+ self.batch_size = batch_size
+
+
+RATE = 16000
+SEGMENT_SECONDS = 30
+SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS
+N_FFT = 400
+HOP_LENGTH = 160
+N_MELS = 80
+FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH
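+# i.e. 30 s at 16 kHz is 480000 samples, and with a 160-sample hop that is 3000 mel frames per segment.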
+
+
+def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate: bool = False) -> np.ndarray:
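+ # Returns a (batch, N_MELS, frames) log-mel spectrogram with Whisper's usual
+ # normalization: log10 clipped to 8 units below the max, then (x + 4) / 4.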
+ import librosa
+
+ def pad_or_trim(arr, target_len):
+ if len(arr) == target_len:
+ return arr
+ if len(arr) < target_len:
+ return np.pad(arr, (0, target_len - len(arr)), 'constant')
+ return arr[:target_len]
+
+ max_len = SAMPLES_PER_SEGMENT if truncate else max(len(w) for w in waveforms)
+ if (r := max_len % SAMPLES_PER_SEGMENT) > 0:
+ max_len += SAMPLES_PER_SEGMENT - r
+
+ waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
+ if waveforms.shape[0] < batch_size:
+ waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))
+
+ stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
+ magnitudes = np.absolute(stft[..., :-1]) ** 2
+ mel_spec = mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS).numpy() @ magnitudes
+ log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
+ log_spec = np.maximum(log_spec, log_spec.max((1, 2), keepdims=True) - 8.0)
+ log_spec = (log_spec + 4.0) / 4.0
+ return log_spec
+
+
+LANGUAGES = {
+ "en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean",
+ "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish", "pl": "polish", "it": "italian",
+}
+
+
+def get_encoding(encoding_name: str):
+ import tiktoken
+
+ with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f:
+ ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
+ n_vocab = len(ranks)
+ specials = [
+ "<|endoftext|>",
+ "<|startoftranscript|>",
+ *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
+ "<|translate|>",
+ "<|transcribe|>",
+ "<|startoflm|>",
+ "<|startofprev|>",
+ "<|nospeech|>",
+ "<|notimestamps|>",
+ *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
+ ]
+ special_tokens = dict(zip(specials, itertools.count(n_vocab)))
+ return tiktoken.Encoding(
+ name=encoding_name,
+ explicit_n_vocab=n_vocab + len(specials),
+ pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+ mergeable_ranks=ranks,
+ special_tokens=special_tokens,
+ )
+
+
+MODEL_URLS = {
+ "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
+ "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
+ "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
+ "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
+ "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
+ "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
+}
+
+
+def init_whisper(model_name: str = "base", batch_size: int = 1):
+ filename = fetch(MODEL_URLS[model_name])
+ state = torch_load(filename)
+ model = Whisper(state['dims'], batch_size)
+ load_state_dict(model, state['model_state_dict'], strict=False)
+ enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
+ return model, enc
+
+
+def load_file_waveform(filename: str):
+ import librosa
+ waveform, _ = librosa.load(filename, sr=RATE)
+ return waveform
+
+
+def transcribe_waveform(model: Whisper, enc, waveforms, language: Optional[str] = None, truncate: bool = False) -> str:
+ log_spec = prep_audio(waveforms, model.batch_size, truncate)
+ nsample = model.decoder.max_tokens_to_sample
+ nctx = model.decoder.max_self_attn_cache_len
+
+ start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
+ if model.is_multilingual:
+ lang = language if (language and language in LANGUAGES) else "en"
+ language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index(lang)
+ start_tokens.append(language_token)
+ start_tokens.append(enc._special_tokens["<|transcribe|>"])
+ start_tokens.append(enc._special_tokens["<|notimestamps|>"])
+
+ eot = enc._special_tokens["<|endoftext|>"]
+
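+ # Greedy decode: append the argmax token for each batch row at every step, stopping
+ # once all rows have emitted <|endoftext|> or the self-attention cache is full.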
+ def inferloop(ctx, encoded_audio):
+ pos, next_tokens = 0, ctx
+ for _ in range(nsample):
+ next_tokens = model.decoder(Tensor(next_tokens, dtype=dtypes.int32), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1)
+ next_tokens[ctx[:, -1] == eot] = eot
+ ctx = np.concatenate((ctx, next_tokens), axis=1)
+ pos = ctx.shape[-1] - 1
+ if (next_tokens == eot).all() or pos == nctx:
+ break
+ return ctx
+
+ ctx = np.tile(start_tokens, (model.batch_size, 1))
+ transcriptions: list[list[int]] = [[] for _ in waveforms]
+
+ for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
+ encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
+ ctx_arr = inferloop(np.array(ctx), encoded_audio)
+ for i, arr in enumerate(ctx_arr):
+ if i >= len(waveforms):
+ break
+ end_idxs = np.where(arr == eot)[0]
+ start_idx = np.where(arr == start_tokens[-1])[0][0] + 1
+ end_idx = end_idxs[0] if len(end_idxs) else None
+ transcriptions[i].extend(arr[start_idx:end_idx])
+ ctx = ctx_arr
+
+ texts = [enc.decode([int(t) for t in toks]).strip() for toks in transcriptions]
+ return texts[0] if len(texts) == 1 else "\n".join(texts)
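+
+# Minimal usage sketch (stands in for the removed upstream __main__ block; the
+# "audio.wav" path is a placeholder):
+#   model, enc = init_whisper("base", batch_size=1)
+#   print(transcribe_waveform(model, enc, [load_file_waveform("audio.wav")]))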
diff --git a/tests/e2e-backends/backend_test.go b/tests/e2e-backends/backend_test.go
index 6b1eb49ec..f36fd02ae 100644
--- a/tests/e2e-backends/backend_test.go
+++ b/tests/e2e-backends/backend_test.go
@@ -44,13 +44,20 @@ import (
// BACKEND_TEST_AUDIO_FILE Path to an already-available sample audio file.
// BACKEND_TEST_CAPS Comma-separated list of capabilities to exercise.
// Supported values: health, load, predict, stream,
-// embeddings, tools, transcription.
+// embeddings, tools, transcription, image.
// Defaults to "health,load,predict,stream".
// A backend that only does embeddings would set this to
-// "health,load,embeddings"; an image/TTS backend that cannot
-// be driven by a text prompt can set it to "health,load".
+// "health,load,embeddings"; an image-generation backend
+// that cannot be driven by a text prompt can set it to
+// "health,load,image".
// "tools" asks the backend to extract a tool call from the
// model output into ChatDelta.tool_calls.
+// "image" exercises the GenerateImage RPC and asserts a
+// non-empty file is written to the requested dst path.
+// BACKEND_TEST_IMAGE_PROMPT Override the positive prompt for the image spec
+// (default: "a photograph of an astronaut riding a horse").
+// BACKEND_TEST_IMAGE_STEPS Override the diffusion step count for the image spec
+// (default: 4 — keeps CPU-only runs under a few minutes).
// BACKEND_TEST_PROMPT Override the prompt used by predict/stream specs.
// BACKEND_TEST_CTX_SIZE Override the context size passed to LoadModel (default 512).
// BACKEND_TEST_THREADS Override Threads passed to LoadModel (default 4).
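+// Example (hypothetical invocation; adjust the package path and any extra flags to the
+// local harness):
+//   BACKEND_TEST_CAPS=health,load,image BACKEND_TEST_IMAGE_STEPS=2 \
+//     go test ./tests/e2e-backends/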
@@ -75,11 +82,14 @@ const (
capEmbeddings = "embeddings"
capTools = "tools"
capTranscription = "transcription"
+ capImage = "image"
- defaultPrompt = "The capital of France is"
- streamPrompt = "Once upon a time"
- defaultToolPrompt = "What's the weather like in Paris, France?"
- defaultToolName = "get_weather"
+ defaultPrompt = "The capital of France is"
+ streamPrompt = "Once upon a time"
+ defaultToolPrompt = "What's the weather like in Paris, France?"
+ defaultToolName = "get_weather"
+ defaultImagePrompt = "a photograph of an astronaut riding a horse"
+ defaultImageSteps = 4
)
func defaultCaps() map[string]bool {
@@ -349,6 +359,40 @@ var _ = Describe("Backend container", Ordered, func() {
GinkgoWriter.Printf("Embedding: %d dims\n", len(res.GetEmbeddings()))
})
+ It("generates an image via GenerateImage", func() {
+ if !caps[capImage] {
+ Skip("image capability not enabled")
+ }
+
+ imgPrompt := os.Getenv("BACKEND_TEST_IMAGE_PROMPT")
+ if imgPrompt == "" {
+ imgPrompt = defaultImagePrompt
+ }
+ steps := envInt32("BACKEND_TEST_IMAGE_STEPS", defaultImageSteps)
+
+ dst := filepath.Join(workDir, "generated.png")
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
+ defer cancel()
+ res, err := client.GenerateImage(ctx, &pb.GenerateImageRequest{
+ PositivePrompt: imgPrompt,
+ NegativePrompt: "",
+ Width: 512,
+ Height: 512,
+ Step: steps,
+ Seed: 42,
+ Dst: dst,
+ })
+ Expect(err).NotTo(HaveOccurred())
+ Expect(res.GetSuccess()).To(BeTrue(), "GenerateImage failed: %s", res.GetMessage())
+
+ info, err := os.Stat(dst)
+ Expect(err).NotTo(HaveOccurred(), "GenerateImage did not write a file at %s", dst)
+ Expect(info.Size()).To(BeNumerically(">", int64(0)),
+ "GenerateImage wrote an empty file at %s", dst)
+ GinkgoWriter.Printf("GenerateImage: wrote %s (%d bytes)\n", dst, info.Size())
+ })
+
It("extracts tool calls into ChatDelta", func() {
if !caps[capTools] {
Skip("tools capability not enabled")