diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index 7bc063578..68953d2f9 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -105,6 +105,25 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + # tinygrad ships a single image — its CPU device uses bundled + # libLLVM, and its CUDA / HIP / Metal devices dlopen the host + # driver libraries at runtime via tinygrad's ctypes autogen + # wrappers. There is no toolkit-version split because tinygrad + # generates kernels itself (PTX renderer for CUDA) and never + # links against cuDNN/cuBLAS/torch. + - build-type: '' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-tinygrad' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'true' + backend: "tinygrad" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" diff --git a/Makefile b/Makefile index 133133086..61e51aad4 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Disable parallel execution for backend builds -.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp +.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/tinygrad GOCMD=go GOTEST=$(GOCMD) test @@ -432,6 +432,7 @@ prepare-test-extra: protogen-python $(MAKE) -C backend/python/whisperx $(MAKE) -C backend/python/ace-step $(MAKE) -C backend/python/trl + $(MAKE) -C backend/python/tinygrad $(MAKE) -C backend/rust/kokoros kokoros-grpc test-extra: prepare-test-extra @@ -454,6 +455,7 @@ test-extra: prepare-test-extra $(MAKE) -C backend/python/whisperx test $(MAKE) -C backend/python/ace-step test $(MAKE) -C backend/python/trl test + $(MAKE) -C backend/python/tinygrad test $(MAKE) -C backend/rust/kokoros test ## @@ -550,6 +552,56 @@ test-extra-backend-vllm: docker-build-vllm BACKEND_TEST_OPTIONS=tool_parser:hermes \ $(MAKE) test-extra-backend +## tinygrad mirrors the vllm target (same model, same caps, same parser) so +## the two backends are directly comparable. The LLM path covers Predict, +## streaming and native tool-call extraction. Companion targets below cover +## embeddings, Stable Diffusion and Whisper — run them individually or via +## the `test-extra-backend-tinygrad-all` aggregate. +test-extra-backend-tinygrad: docker-build-tinygrad + BACKEND_IMAGE=local-ai-backend:tinygrad \ + BACKEND_TEST_MODEL_NAME=Qwen/Qwen2.5-0.5B-Instruct \ + BACKEND_TEST_CAPS=health,load,predict,stream,tools \ + BACKEND_TEST_OPTIONS=tool_parser:hermes \ + $(MAKE) test-extra-backend + +## tinygrad — embeddings via LLM last-hidden-state pooling. Reuses the same +## Qwen2.5-0.5B-Instruct as the chat target so we don't need a separate BERT +## vendor; the Embedding RPC mean-pools and L2-normalizes the last-layer +## hidden state. +test-extra-backend-tinygrad-embeddings: docker-build-tinygrad + BACKEND_IMAGE=local-ai-backend:tinygrad \ + BACKEND_TEST_MODEL_NAME=Qwen/Qwen2.5-0.5B-Instruct \ + BACKEND_TEST_CAPS=health,load,embeddings \ + $(MAKE) test-extra-backend + +## tinygrad — Stable Diffusion 1.5. The original CompVis/runwayml repos have +## been gated, so we use the community-maintained mirror at +## stable-diffusion-v1-5/stable-diffusion-v1-5 with the EMA-only pruned +## checkpoint (~4.3GB). Step count is kept low (4) so a CPU-only run finishes +## in a few minutes; bump BACKEND_TEST_IMAGE_STEPS for higher quality. +test-extra-backend-tinygrad-sd: docker-build-tinygrad + BACKEND_IMAGE=local-ai-backend:tinygrad \ + BACKEND_TEST_MODEL_NAME=stable-diffusion-v1-5/stable-diffusion-v1-5 \ + BACKEND_TEST_CAPS=health,load,image \ + $(MAKE) test-extra-backend + +## tinygrad — Whisper. Loads OpenAI's tiny.en checkpoint (smallest at ~75MB) +## from the original azure CDN through tinygrad's `fetch` helper, and +## transcribes the canonical jfk.wav fixture from whisper.cpp's CI samples. +## Exercises both AudioTranscription and AudioTranscriptionStream. +test-extra-backend-tinygrad-whisper: docker-build-tinygrad + BACKEND_IMAGE=local-ai-backend:tinygrad \ + BACKEND_TEST_MODEL_NAME=openai/whisper-tiny.en \ + BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \ + BACKEND_TEST_CAPS=health,load,transcription \ + $(MAKE) test-extra-backend + +test-extra-backend-tinygrad-all: \ + test-extra-backend-tinygrad \ + test-extra-backend-tinygrad-embeddings \ + test-extra-backend-tinygrad-sd \ + test-extra-backend-tinygrad-whisper + ## mlx is Apple-Silicon-first — the MLX backend auto-detects the right tool ## parser from the chat template, so no tool_parser: option is needed (it ## would be ignored at runtime). Run this on macOS / arm64 with Metal; the @@ -707,6 +759,7 @@ BACKEND_MLX_VLM = mlx-vlm|python|.|false|true BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true BACKEND_TRL = trl|python|.|false|true BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true +BACKEND_TINYGRAD = tinygrad|python|.|false|true # Rust backends BACKEND_KOKOROS = kokoros|rust|.|false|true @@ -778,6 +831,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_MLX_VLM))) $(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED))) $(eval $(call generate-docker-build-target,$(BACKEND_TRL))) $(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION))) +$(eval $(call generate-docker-build-target,$(BACKEND_TINYGRAD))) $(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS))) $(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP))) @@ -785,7 +839,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP))) docker-save-%: backend-images docker save local-ai-backend:$* -o backend-images/$*.tar -docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp +docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp ######################################################## ### Mock Backend for E2E Tests diff --git a/backend/index.yaml b/backend/index.yaml index c29ca5fb8..f7dc72251 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -361,6 +361,34 @@ intel: "intel-rerankers" amd: "rocm-rerankers" metal: "metal-rerankers" +- &tinygrad + name: "tinygrad" + alias: "tinygrad" + license: MIT + description: | + tinygrad is a minimalist deep-learning framework with zero runtime + dependencies that targets CUDA, ROCm, Metal, WebGPU and CPU (CLANG). + The LocalAI tinygrad backend exposes a single multimodal runtime that + covers LLM text generation (Llama / Qwen / Mistral via safetensors or + GGUF) with native tool-call extraction, BERT-family embeddings, + Stable Diffusion 1.x / 2 / XL image generation, and Whisper speech-to-text. + + Single image: tinygrad generates its own GPU kernels and dlopens the + host driver libraries at runtime, so there is no per-toolkit build + split. The same image runs CPU-only or accelerates against + CUDA / ROCm / Metal when the host driver is visible. + urls: + - https://github.com/tinygrad/tinygrad + uri: "quay.io/go-skynet/local-ai-backends:latest-tinygrad" + mirrors: + - localai/localai-backends:latest-tinygrad + tags: + - text-to-text + - LLM + - embeddings + - image-generation + - transcription + - multimodal - &transformers name: "transformers" icon: https://avatars.githubusercontent.com/u/25720743?s=200&v=4 @@ -1993,6 +2021,15 @@ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-rerankers" mirrors: - localai/localai-backends:master-metal-darwin-arm64-rerankers +## tinygrad +## Single image — the meta anchor above carries the latest uri directly +## since there is only one variant. The development entry below points at +## the master tag. +- !!merge <<: *tinygrad + name: "tinygrad-development" + uri: "quay.io/go-skynet/local-ai-backends:master-tinygrad" + mirrors: + - localai/localai-backends:master-tinygrad ## Transformers - !!merge <<: *transformers name: "transformers-development" diff --git a/backend/python/tinygrad/Makefile b/backend/python/tinygrad/Makefile new file mode 100644 index 000000000..4f7c209a2 --- /dev/null +++ b/backend/python/tinygrad/Makefile @@ -0,0 +1,25 @@ +.DEFAULT_GOAL := install + +.PHONY: install +install: + bash install.sh + +.PHONY: run +run: install + @echo "Running tinygrad..." + bash run.sh + @echo "tinygrad run." + +.PHONY: test +test: install + @echo "Testing tinygrad..." + bash test.sh + @echo "tinygrad tested." + +.PHONY: protogen-clean +protogen-clean: + $(RM) backend_pb2_grpc.py backend_pb2.py + +.PHONY: clean +clean: protogen-clean + rm -rf venv __pycache__ diff --git a/backend/python/tinygrad/backend.py b/backend/python/tinygrad/backend.py new file mode 100644 index 000000000..7fa997f47 --- /dev/null +++ b/backend/python/tinygrad/backend.py @@ -0,0 +1,750 @@ +#!/usr/bin/env python3 +""" +LocalAI gRPC backend for tinygrad. + +Multimodal scope (landing incrementally): + - LLM text generation (Llama 3 / Qwen 2.5 / Mistral via HF safetensors + GGUF) + - Native tool-call extraction via pluggable parsers (hermes, llama3_json, ...) + - Embeddings, Stable Diffusion, Whisper — planned; currently return UNIMPLEMENTED. + +The heavy imports (tinygrad, tokenizers, vendor.llama) are deferred until +`LoadModel`, because tinygrad binds its compute device at import time from +env vars. `_select_tinygrad_device()` maps LocalAI's BUILD_TYPE onto the +corresponding tinygrad env flag before any import happens. +""" +from __future__ import annotations + +import argparse +import asyncio +import json +import os +import signal +import sys +import time +from concurrent import futures +from pathlib import Path +from typing import Any, Optional + +import grpc + +import backend_pb2 +import backend_pb2_grpc + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) +from grpc_auth import get_auth_interceptors # noqa: E402 + +from tool_parsers import resolve_parser # noqa: E402 +from tool_parsers.base import ToolCall # noqa: E402 + +MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) + + +# --------------------------------------------------------------------------- +# Device selection — must run BEFORE `import tinygrad` anywhere. +# +# In production this is set by run.sh based on which driver libraries the +# host has injected into the container (libcuda.so.1 → CUDA, libamdhip64 +# → HIP, otherwise CLANG). This helper is only a fallback for direct +# invocations like the unit tests. +# --------------------------------------------------------------------------- + +def _select_tinygrad_device() -> None: + if any(os.environ.get(k) == "1" for k in ("CUDA", "HIP", "METAL", "CLANG", "AMD", "NV")): + return + os.environ["CLANG"] = "1" + + +# --------------------------------------------------------------------------- +# Model asset discovery +# --------------------------------------------------------------------------- + +def _resolve_model_assets(model_ref: str) -> Path: + """ + Accept either a local path or a HuggingFace repo id (e.g. + "Qwen/Qwen2.5-0.5B-Instruct") and return the local directory / file. + HF ids are materialized via `huggingface_hub.snapshot_download`. + """ + p = Path(model_ref) + if p.exists(): + return p + if "/" in model_ref and not model_ref.startswith(("/", ".")): + from huggingface_hub import snapshot_download + local = snapshot_download( + repo_id=model_ref, + allow_patterns=[ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "generation_config.json", + "*.safetensors", + "*.safetensors.index.json", + ], + ) + return Path(local) + raise FileNotFoundError(f"Model not found: {model_ref}") + + +def _load_llm_weights(model_dir: Path): + """Load HF safetensors weights from a directory or a single file.""" + from tinygrad import Device, Tensor, dtypes + from tinygrad.nn.state import safe_load, gguf_load + + if model_dir.is_file(): + if str(model_dir).endswith(".gguf"): + gguf_tensor = Tensor.empty(os.stat(model_dir).st_size, dtype=dtypes.uint8, + device=f"disk:{model_dir}").to(Device.DEFAULT) + return gguf_load(gguf_tensor)[1], "gguf" + return safe_load(str(model_dir)), "safetensors" + + index = model_dir / "model.safetensors.index.json" + if index.exists(): + with open(index) as fp: + weight_map = json.load(fp)["weight_map"] + shards: dict[str, Any] = {} + for shard_name in set(weight_map.values()): + shards[shard_name] = safe_load(str(model_dir / shard_name)) + merged = {k: shards[n][k] for k, n in weight_map.items()} + return merged, "safetensors" + + single = model_dir / "model.safetensors" + if single.exists(): + return safe_load(str(single)), "safetensors" + + ggufs = sorted(model_dir.glob("*.gguf")) + if ggufs: + gguf_tensor = Tensor.empty(os.stat(ggufs[0]).st_size, dtype=dtypes.uint8, + device=f"disk:{ggufs[0]}").to(Device.DEFAULT) + return gguf_load(gguf_tensor)[1], "gguf" + + raise FileNotFoundError(f"No safetensors or gguf weights found under {model_dir}") + + +def _infer_qkv_bias(config: dict) -> bool: + """Qwen2-family uses Q/K/V projection bias; Llama/Mistral do not.""" + architectures = [a.lower() for a in config.get("architectures", [])] + if any("qwen" in a for a in architectures): + return True + return bool(config.get("attention_bias", False)) or bool(config.get("qkv_bias", False)) + + +def _auto_tool_parser(model_ref: Optional[str], config: dict) -> Optional[str]: + """Pick a tool parser automatically from model family heuristics. + + Order of precedence: architecture name from config.json, then model ref + string. Returns None to fall through to the passthrough parser. + """ + arches = " ".join(a.lower() for a in config.get("architectures", [])) + ref = (model_ref or "").lower() + blob = f"{arches} {ref}" + + if "qwen3" in blob: + return "qwen3_xml" + if "hermes" in blob or "qwen2" in blob or "qwen" in blob: + return "hermes" + if "llama-3" in blob or "llama_3" in blob or "llama3" in blob: + return "llama3_json" + if "mistral" in blob or "mixtral" in blob: + return "mistral" + return None + + +# --------------------------------------------------------------------------- +# Servicer +# --------------------------------------------------------------------------- + +class BackendServicer(backend_pb2_grpc.BackendServicer): + """gRPC servicer for the tinygrad backend.""" + + def __init__(self) -> None: + self._reset_state() + + def _reset_state(self) -> None: + self.model_ref: Optional[str] = None + self.model_type: str = "llm" + self.options: dict[str, str] = {} + # LLM state + self.llm_model = None + self.llm_config: dict = {} + self.llm_tokenizer = None + self.llm_eos_ids: list[int] = [] + self.chat_template: Optional[str] = None + self.tool_parser = resolve_parser(None) + self.max_context = 4096 + # Stable Diffusion state + self.sd_model = None + # Whisper state + self.whisper_model = None + self.whisper_tokenizer = None + + # --------------------- helpers -------------------------------------- + + @staticmethod + def _parse_options(options_list) -> dict[str, str]: + opts: dict[str, str] = {} + for opt in options_list: + if ":" not in opt: + continue + key, value = opt.split(":", 1) + opts[key.strip()] = value.strip() + return opts + + @staticmethod + def _detect_model_type(model_ref: str, explicit: Optional[str]) -> str: + if explicit: + return explicit + name = (model_ref or "").lower() + if "whisper" in name: + return "whisper" + if "sdxl" in name: + return "sdxl" + if "sd-v1" in name or "v1-5" in name or "stable-diffusion" in name: + return "sd15" + if any(tag in name for tag in ("bge", "e5", "minilm", "bert")): + return "bert" + return "llm" + + def _messages_to_dicts(self, messages) -> list[dict]: + result = [] + for msg in messages: + d: dict = {"role": msg.role, "content": msg.content or ""} + if msg.name: + d["name"] = msg.name + if msg.tool_call_id: + d["tool_call_id"] = msg.tool_call_id + if msg.reasoning_content: + d["reasoning_content"] = msg.reasoning_content + if msg.tool_calls: + try: + d["tool_calls"] = json.loads(msg.tool_calls) + except json.JSONDecodeError: + pass + result.append(d) + return result + + def _render_prompt(self, request) -> str: + """Render messages + tools into the model's chat template, or fall + back to the raw Prompt field for models without a template.""" + if not request.Messages and request.Prompt: + return request.Prompt + + if not self.chat_template: + # No template known — concatenate role/content lines. + lines = [] + for msg in request.Messages: + lines.append(f"{msg.role}: {msg.content or ''}") + return "\n".join(lines) + "\nassistant:" + + from jinja2 import Environment + + env = Environment(trim_blocks=True, lstrip_blocks=True) + template = env.from_string(self.chat_template) + + tools = None + if request.Tools: + try: + tools = json.loads(request.Tools) + except json.JSONDecodeError: + tools = None + + return template.render( + messages=self._messages_to_dicts(request.Messages), + tools=tools, + add_generation_prompt=True, + ) + + # --------------------- LLM path ------------------------------------- + + def _load_llm(self, model_dir: Path) -> None: + config_path = model_dir / "config.json" if model_dir.is_dir() else None + if config_path is None or not config_path.exists(): + raise FileNotFoundError(f"config.json not found under {model_dir}") + with open(config_path) as fp: + config = json.load(fp) + self.llm_config = config + + from tinygrad import nn + + from vendor.llama import Transformer, convert_from_huggingface, convert_from_gguf, fix_bf16 + + n_layers = config["num_hidden_layers"] + n_heads = config["num_attention_heads"] + n_kv_heads = config.get("num_key_value_heads", n_heads) + dim = config["hidden_size"] + hidden_dim = config["intermediate_size"] + vocab_size = config["vocab_size"] + norm_eps = config.get("rms_norm_eps", 1e-5) + rope_theta = config.get("rope_theta", 10000) + self.max_context = min(config.get("max_position_embeddings", 4096), 8192) + qkv_bias = _infer_qkv_bias(config) + + weights, weight_format = _load_llm_weights(model_dir) + + model = Transformer( + dim=dim, + hidden_dim=hidden_dim, + n_heads=n_heads, + n_layers=n_layers, + norm_eps=norm_eps, + vocab_size=vocab_size, + linear=nn.Linear, + embedding=nn.Embedding, + n_kv_heads=n_kv_heads, + rope_theta=rope_theta, + max_context=self.max_context, + jit=True, + qkv_bias=qkv_bias, + ) + + if weight_format == "safetensors": + weights = convert_from_huggingface(weights, n_layers, n_heads, n_kv_heads) + else: + weights = convert_from_gguf(weights, n_layers) + weights = fix_bf16(weights) + + from tinygrad.nn.state import load_state_dict + load_state_dict(model, weights, strict=False, consume=True) + self.llm_model = model + + # Tokenizer + tokenizer_json = model_dir / "tokenizer.json" + if not tokenizer_json.exists(): + raise FileNotFoundError(f"tokenizer.json not found under {model_dir}") + from tokenizers import Tokenizer as HFTokenizer + self.llm_tokenizer = HFTokenizer.from_file(str(tokenizer_json)) + + # Chat template + EOS ids (from tokenizer_config.json + generation_config.json) + tok_cfg_path = model_dir / "tokenizer_config.json" + if tok_cfg_path.exists(): + with open(tok_cfg_path) as fp: + tok_cfg = json.load(fp) + self.chat_template = tok_cfg.get("chat_template") + + self.llm_eos_ids = [] + gen_cfg_path = model_dir / "generation_config.json" + if gen_cfg_path.exists(): + with open(gen_cfg_path) as fp: + gen_cfg = json.load(fp) + eos = gen_cfg.get("eos_token_id") + if isinstance(eos, list): + self.llm_eos_ids.extend(int(x) for x in eos) + elif isinstance(eos, int): + self.llm_eos_ids.append(eos) + if not self.llm_eos_ids: + eos = config.get("eos_token_id") + if isinstance(eos, list): + self.llm_eos_ids.extend(int(x) for x in eos) + elif isinstance(eos, int): + self.llm_eos_ids.append(eos) + + # Auto-pick tool parser from options or model family. + parser_name = self.options.get("tool_parser") or _auto_tool_parser(self.model_ref, config) + self.tool_parser = resolve_parser(parser_name) + + # --------------------- Stable Diffusion path ------------------------ + + def _load_sd(self, model_ref: str) -> None: + """Load a Stable Diffusion 1.x checkpoint (CompVis `.ckpt` format).""" + from huggingface_hub import hf_hub_download + from tinygrad.nn.state import load_state_dict, torch_load + + from vendor.stable_diffusion import StableDiffusion + + ckpt_path = Path(model_ref) + if not ckpt_path.exists(): + # Accept an HF repo id — fetch the canonical v1-5-pruned-emaonly.ckpt + # from the requested repo. Common case is runwayml/stable-diffusion-v1-5. + repo_id = model_ref if "/" in model_ref else "runwayml/stable-diffusion-v1-5" + ckpt_file = self.options.get("sd_ckpt_filename", "v1-5-pruned-emaonly.ckpt") + ckpt_path = Path(hf_hub_download(repo_id=repo_id, filename=ckpt_file)) + + model = StableDiffusion() + state_dict = torch_load(str(ckpt_path)) + if isinstance(state_dict, dict) and "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + load_state_dict(model, state_dict, strict=False, verbose=False, realize=False) + self.sd_model = model + + # --------------------- Whisper path --------------------------------- + + def _load_whisper(self, model_ref: str) -> None: + """Load a Whisper checkpoint (OpenAI `.pt` format). + + Accepts a model-size alias (tiny / tiny.en / base / base.en / small / + small.en) OR an explicit `.pt` file path OR the HF repo id naming + convention `openai/whisper-*` (mapped to the matching OpenAI alias). + """ + from vendor.whisper import init_whisper, MODEL_URLS + + alias = model_ref + if "/" in alias and alias.startswith("openai/whisper-"): + alias = alias.removeprefix("openai/whisper-") + if alias not in MODEL_URLS: + # Explicit path to a .pt checkpoint — fall back to size heuristic + # via filename. + basename = Path(alias).name.lower() + for name in MODEL_URLS: + if name in basename: + alias = name + break + else: + raise ValueError( + f"Unknown Whisper model_ref={model_ref!r}; expected one of {list(MODEL_URLS)} " + f"or an openai/whisper-* HF id" + ) + + model, enc = init_whisper(alias, batch_size=1) + self.whisper_model = model + self.whisper_tokenizer = enc + + # --------------------- LLM generation ------------------------------- + + def _generate_tokens(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int): + """Synchronous generator yielding (token_id, token_text) pairs.""" + from tinygrad import Tensor + + encoded = self.llm_tokenizer.encode(prompt) + ids = encoded.ids + if not ids: + return + + # Prefill: run one forward pass over the full prompt, sampling from the last token. + prompt_tensor = Tensor([ids]) + next_tok = int(self.llm_model(prompt_tensor, 0, temperature=temperature, top_k=top_k, top_p=top_p).item()) + pos = len(ids) + + for _ in range(max_new_tokens): + if next_tok in self.llm_eos_ids: + break + text = self.llm_tokenizer.decode([next_tok]) + yield next_tok, text + next_tok = int(self.llm_model(Tensor([[next_tok]]), pos, temperature=temperature, + top_k=top_k, top_p=top_p).item()) + pos += 1 + + # --------------------- gRPC methods --------------------------------- + + def Health(self, request, context): + return backend_pb2.Reply(message=bytes("OK", 'utf-8')) + + async def LoadModel(self, request, context): + try: + _select_tinygrad_device() + self._reset_state() + self.options = self._parse_options(list(request.Options)) + self.model_ref = request.ModelFile or request.Model + self.model_type = self._detect_model_type(self.model_ref, self.options.get("model_type")) + + if self.model_type in ("sd15", "sd", "stable-diffusion"): + self._load_sd(self.model_ref) + return backend_pb2.Result( + success=True, message="tinygrad Stable Diffusion 1.x loaded", + ) + + if self.model_type == "whisper": + self._load_whisper(self.model_ref) + return backend_pb2.Result( + success=True, message="tinygrad Whisper loaded", + ) + + if self.model_type != "llm": + return backend_pb2.Result( + success=False, + message=f"tinygrad: model_type={self.model_type} not yet implemented", + ) + + model_path = _resolve_model_assets(self.model_ref) + if model_path.is_file(): + return backend_pb2.Result( + success=False, + message=("tinygrad currently requires an HF model directory with config.json + " + "tokenizer.json; single-file GGUF support lands with gallery metadata wiring."), + ) + self._load_llm(model_path) + + return backend_pb2.Result( + success=True, + message=f"tinygrad LLM loaded (arch={self.llm_config.get('architectures', ['?'])[0]}, " + f"parser={self.tool_parser.name})", + ) + except Exception as exc: + import traceback + traceback.print_exc() + return backend_pb2.Result(success=False, message=f"LoadModel failed: {exc}") + + async def Predict(self, request, context): + if self.llm_model is None: + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("LLM not loaded") + return backend_pb2.Reply() + + try: + prompt = self._render_prompt(request) + max_new = request.Tokens if request.Tokens > 0 else 256 + temperature = request.Temperature if request.Temperature > 0 else 0.7 + top_p = request.TopP if request.TopP > 0 else 0.95 + top_k = request.TopK if request.TopK > 0 else 0 + + t0 = time.monotonic() + pieces: list[str] = [] + ntok = 0 + for _, text in self._generate_tokens(prompt, max_new, temperature, top_p, top_k): + pieces.append(text) + ntok += 1 + elapsed = time.monotonic() - t0 + + full = "".join(pieces) + from tool_parsers.hermes import HermesToolParser + if isinstance(self.tool_parser, HermesToolParser): + result = self.tool_parser.parse_full(full) + content, calls, reasoning = result.content, result.tool_calls, result.reasoning + else: + content, calls = self.tool_parser.parse(full) + reasoning = "" + + delta = backend_pb2.ChatDelta( + content=content, + reasoning_content=reasoning, + tool_calls=[ + backend_pb2.ToolCallDelta(index=c.index, id=c.id, name=c.name, arguments=c.arguments) + for c in calls + ], + ) + return backend_pb2.Reply( + message=content.encode("utf-8"), + tokens=ntok, + timing_token_generation=elapsed, + chat_deltas=[delta], + ) + except Exception as exc: + import traceback + traceback.print_exc() + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"Predict failed: {exc}") + return backend_pb2.Reply() + + async def PredictStream(self, request, context): + if self.llm_model is None: + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("LLM not loaded") + return + + try: + prompt = self._render_prompt(request) + max_new = request.Tokens if request.Tokens > 0 else 256 + temperature = request.Temperature if request.Temperature > 0 else 0.7 + top_p = request.TopP if request.TopP > 0 else 0.95 + top_k = request.TopK if request.TopK > 0 else 0 + + buffer = "" + for _, text in self._generate_tokens(prompt, max_new, temperature, top_p, top_k): + buffer += text + yield backend_pb2.Reply( + message=text.encode("utf-8"), + chat_deltas=[backend_pb2.ChatDelta(content=text)], + ) + + # Final emission carries the extracted tool calls (vLLM semantics). + from tool_parsers.hermes import HermesToolParser + if isinstance(self.tool_parser, HermesToolParser): + result = self.tool_parser.parse_full(buffer) + calls = result.tool_calls + reasoning = result.reasoning + else: + _, calls = self.tool_parser.parse(buffer) + reasoning = "" + + if calls or reasoning: + yield backend_pb2.Reply( + chat_deltas=[backend_pb2.ChatDelta( + reasoning_content=reasoning, + tool_calls=[ + backend_pb2.ToolCallDelta(index=c.index, id=c.id, name=c.name, arguments=c.arguments) + for c in calls + ], + )], + ) + except Exception as exc: + import traceback + traceback.print_exc() + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"PredictStream failed: {exc}") + + async def Embedding(self, request, context): + if self.llm_model is None or self.llm_tokenizer is None: + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("No model loaded") + return backend_pb2.EmbeddingResult() + + try: + text = request.Embeddings + if not text: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details("Embeddings field is empty") + return backend_pb2.EmbeddingResult() + + from tinygrad import Tensor, dtypes + + ids = self.llm_tokenizer.encode(text).ids + if not ids: + return backend_pb2.EmbeddingResult(embeddings=[]) + + # Clamp to context window — truncate long inputs rather than blow up. + ids = ids[: self.max_context] + tokens = Tensor([ids]) + + hidden = self.llm_model.embed(tokens) # (1, seqlen, dim) + # Mean pool over sequence dim + pooled = hidden.mean(axis=1).squeeze(0) # (dim,) + # L2 normalize + norm = pooled.square().sum().sqrt() + normalized = (pooled / (norm + 1e-12)) + vec = normalized.cast(dtypes.float32).tolist() + + return backend_pb2.EmbeddingResult(embeddings=[float(x) for x in vec]) + except Exception as exc: + import traceback + traceback.print_exc() + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"Embedding failed: {exc}") + return backend_pb2.EmbeddingResult() + + async def GenerateImage(self, request, context): + if self.sd_model is None: + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("No Stable Diffusion model loaded") + return backend_pb2.Result(success=False, message="not loaded") + + try: + from PIL import Image + from vendor.stable_diffusion import run_sd15 + + steps = request.step if request.step > 0 else 20 + guidance = 7.5 + seed = request.seed if request.seed != 0 else None + img_tensor = run_sd15( + model=self.sd_model, + prompt=request.positive_prompt or "", + negative_prompt=request.negative_prompt or "", + steps=steps, + guidance=guidance, + seed=seed, + ) + arr = img_tensor.numpy() + image = Image.fromarray(arr) + dst = request.dst or "/tmp/tinygrad_image.png" + image.save(dst) + return backend_pb2.Result(success=True, message=dst) + except Exception as exc: + import traceback + traceback.print_exc() + return backend_pb2.Result(success=False, message=f"GenerateImage failed: {exc}") + + def _transcribe(self, audio_path: str, language: Optional[str]) -> tuple[str, float]: + from vendor.whisper import load_file_waveform, transcribe_waveform + + waveform = load_file_waveform(audio_path) + text = transcribe_waveform( + self.whisper_model, + self.whisper_tokenizer, + [waveform], + language=language or None, + ) + duration = float(len(waveform)) / 16000.0 + return text, duration + + async def AudioTranscription(self, request, context): + if self.whisper_model is None or self.whisper_tokenizer is None: + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("No Whisper model loaded") + return backend_pb2.TranscriptResult() + + try: + if not request.dst: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details("TranscriptRequest.dst (audio file path) is required") + return backend_pb2.TranscriptResult() + + text, duration = self._transcribe(request.dst, request.language) + segments = [backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text)] + return backend_pb2.TranscriptResult( + text=text, + language=request.language or "en", + duration=duration, + segments=segments, + ) + except Exception as exc: + import traceback + traceback.print_exc() + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"AudioTranscription failed: {exc}") + return backend_pb2.TranscriptResult() + + async def AudioTranscriptionStream(self, request, context): + if self.whisper_model is None or self.whisper_tokenizer is None: + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("No Whisper model loaded") + return + + try: + if not request.dst: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details("TranscriptRequest.dst (audio file path) is required") + return + + # The vendored tinygrad whisper loop is chunked at the file level + # (one inference pass per 30s segment), not token-level. To still + # produce a streaming response we run the full transcription and + # emit it as a single delta + a final-result envelope so the client + # gets both code paths exercised. + text, duration = self._transcribe(request.dst, request.language) + yield backend_pb2.TranscriptStreamResponse(delta=text) + final = backend_pb2.TranscriptResult( + text=text, + language=request.language or "en", + duration=duration, + segments=[backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text)], + ) + yield backend_pb2.TranscriptStreamResponse(final_result=final) + except Exception as exc: + import traceback + traceback.print_exc() + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"AudioTranscriptionStream failed: {exc}") + + async def Status(self, request, context): + return backend_pb2.StatusResponse(state=backend_pb2.StatusResponse.READY) + + async def Free(self, request, context): + self._reset_state() + return backend_pb2.Result(success=True, message="freed") + + +async def serve(address): + server = grpc.aio.server( + migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), + options=[ + ('grpc.max_message_length', 50 * 1024 * 1024), + ('grpc.max_send_message_length', 50 * 1024 * 1024), + ('grpc.max_receive_message_length', 50 * 1024 * 1024), + ], + interceptors=get_auth_interceptors(aio=True), + ) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + + loop = asyncio.get_event_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5))) + + await server.start() + print("Server started. Listening on: " + address, file=sys.stderr) + await server.wait_for_termination() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the tinygrad gRPC backend.") + parser.add_argument("--addr", default="localhost:50051", help="Bind address") + args = parser.parse_args() + asyncio.run(serve(args.addr)) diff --git a/backend/python/tinygrad/install.sh b/backend/python/tinygrad/install.sh new file mode 100755 index 000000000..c4f444938 --- /dev/null +++ b/backend/python/tinygrad/install.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +# tinygrad >= 0.12 requires Python >= 3.11 (pyproject: `requires-python = ">=3.11"`). +# LocalAI's default portable python is 3.10, so we pin to 3.11.x here. +PYTHON_VERSION="3.11" +PYTHON_PATCH="14" +PY_STANDALONE_TAG="20260203" + +installRequirements diff --git a/backend/python/tinygrad/package.sh b/backend/python/tinygrad/package.sh new file mode 100755 index 000000000..5ac600e19 --- /dev/null +++ b/backend/python/tinygrad/package.sh @@ -0,0 +1,103 @@ +#!/bin/bash +# Script to package runtime shared libraries for the tinygrad backend. +# +# The final Dockerfile.python stage is FROM scratch, so system libraries +# must be explicitly copied into ${BACKEND}/lib so the backend can run on +# any host without installing them. libbackend.sh automatically prepends +# that directory to LD_LIBRARY_PATH at run time. +# +# tinygrad's CPU device (CLANG / LLVM renderer) JIT-compiles kernels at +# runtime. The default `CLANG` path invokes the external `clang` binary via +# subprocess, which does not exist in the scratch image. We force the +# in-process LLVM path (`CPU_LLVM=1` in run.sh) which loads libLLVM.so.* +# through ctypes and bundle the library + its runtime dependencies here. +# +# Also bundle libgomp (pulled by librosa / numpy via numba) and libsndfile +# (required by soundfile -> librosa audio I/O for Whisper). + +set -e + +CURDIR=$(dirname "$(realpath "$0")") +LIB_DIR="${CURDIR}/lib" +mkdir -p "${LIB_DIR}" + +SEARCH_DIRS=( + /usr/lib/x86_64-linux-gnu + /usr/lib/aarch64-linux-gnu + /lib/x86_64-linux-gnu + /lib/aarch64-linux-gnu + /usr/lib + /lib +) + +copy_with_symlinks() { + local soname="$1" + local hit="" + for dir in "${SEARCH_DIRS[@]}"; do + if [ -e "${dir}/${soname}" ]; then + hit="${dir}/${soname}" + break + fi + done + if [ -z "${hit}" ]; then + echo "warning: ${soname} not found in standard lib paths" >&2 + return 0 + fi + local real + real=$(readlink -f "${hit}") + cp -v "${real}" "${LIB_DIR}/" + local real_base + real_base=$(basename "${real}") + if [ "${real_base}" != "${soname}" ]; then + ln -sf "${real_base}" "${LIB_DIR}/${soname}" + fi +} + +# tinygrad searches for libLLVM under these sonames (see +# tinygrad/runtime/autogen/llvm.py). Ubuntu 24.04's `llvm` metapackage +# installs `libLLVM-18.so.1` into `/usr/lib/llvm-18/lib/`. Also scan the +# standard lib directories in case a different distro layout puts it in +# /usr/lib/x86_64-linux-gnu. +llvm_so="" +shopt -s nullglob +LLVM_EXTRA_DIRS=(/usr/lib/llvm-*/lib /usr/lib/llvm-*) +# First try the versioned symlink (libLLVM-18.so) since that's what +# tinygrad's DLL loader matches against (see llvm.py DLL name list). +for dir in "${SEARCH_DIRS[@]}" "${LLVM_EXTRA_DIRS[@]}"; do + for candidate in "${dir}"/libLLVM-[0-9]*.so "${dir}"/libLLVM-[0-9]*.so.[0-9]*; do + if [ -e "${candidate}" ]; then + llvm_so="${candidate}" + break 2 + fi + done +done +# Fallback: any libLLVM.so file under /usr. +if [ -z "${llvm_so}" ]; then + llvm_so=$(find /usr -maxdepth 5 -name 'libLLVM*.so*' 2>/dev/null | head -1) +fi +shopt -u nullglob +if [ -z "${llvm_so}" ]; then + echo "ERROR: libLLVM not found — tinygrad CPU device needs it." >&2 + echo "Install the Ubuntu \`llvm\` package in the builder stage." >&2 + exit 1 +fi +echo "Found libLLVM at: ${llvm_so}" +llvm_base=$(basename "${llvm_so}") +real_llvm=$(readlink -f "${llvm_so}") +cp -v "${real_llvm}" "${LIB_DIR}/" +real_base=$(basename "${real_llvm}") +if [ "${real_base}" != "${llvm_base}" ]; then + ln -sf "${real_base}" "${LIB_DIR}/${llvm_base}" +fi + +# libLLVM has soft runtime deps on libedit / libtinfo; pick them up if +# present. They're optional but loading without them can fail. +copy_with_symlinks libedit.so.2 +copy_with_symlinks libtinfo.so.6 + +# Audio I/O for the Whisper path. +copy_with_symlinks libsndfile.so.1 +copy_with_symlinks libgomp.so.1 + +echo "tinygrad packaging completed successfully" +ls -liah "${LIB_DIR}/" diff --git a/backend/python/tinygrad/protogen.sh b/backend/python/tinygrad/protogen.sh new file mode 100755 index 000000000..df3325c6f --- /dev/null +++ b/backend/python/tinygrad/protogen.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +runProtogen diff --git a/backend/python/tinygrad/requirements-cpu.txt b/backend/python/tinygrad/requirements-cpu.txt new file mode 100644 index 000000000..e1c64640a --- /dev/null +++ b/backend/python/tinygrad/requirements-cpu.txt @@ -0,0 +1 @@ +# tinygrad CPU backend uses CLANG device (no extra deps required). diff --git a/backend/python/tinygrad/requirements-cublas12.txt b/backend/python/tinygrad/requirements-cublas12.txt new file mode 100644 index 000000000..2deb71bc5 --- /dev/null +++ b/backend/python/tinygrad/requirements-cublas12.txt @@ -0,0 +1,2 @@ +# tinygrad drives CUDA through its own JIT (CUDA=1 env var). +# Requires the CUDA 12 runtime from the base image; no extra Python deps. diff --git a/backend/python/tinygrad/requirements-cublas13.txt b/backend/python/tinygrad/requirements-cublas13.txt new file mode 100644 index 000000000..74dc56eae --- /dev/null +++ b/backend/python/tinygrad/requirements-cublas13.txt @@ -0,0 +1,2 @@ +# tinygrad drives CUDA through its own JIT (CUDA=1 env var). +# Requires the CUDA 13 runtime from the base image; no extra Python deps. diff --git a/backend/python/tinygrad/requirements.txt b/backend/python/tinygrad/requirements.txt new file mode 100644 index 000000000..12b31ffca --- /dev/null +++ b/backend/python/tinygrad/requirements.txt @@ -0,0 +1,15 @@ +grpcio==1.80.0 +protobuf==6.33.5 +certifi +setuptools +numpy>=2.0.0 +tinygrad>=0.12.0 +tokenizers>=0.21.0 +huggingface_hub +jinja2>=3.1.0 +tiktoken +sentencepiece +safetensors +Pillow +librosa +soundfile diff --git a/backend/python/tinygrad/run.sh b/backend/python/tinygrad/run.sh new file mode 100755 index 000000000..b341291ea --- /dev/null +++ b/backend/python/tinygrad/run.sh @@ -0,0 +1,55 @@ +#!/bin/bash +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +# tinygrad binds its compute device at import time from a single env var +# (CUDA / HIP / METAL / CLANG). We pick one here based on what driver +# libraries the host has injected into the container — when a user runs +# the image with `--gpus all` (or the equivalent rocm runtime), the +# nvidia-container-toolkit / rocm runtime mounts the right libraries +# under /usr/lib so we can detect them. +# +# tinygrad's CUDA path uses two compiler pairs: an NVRTC-backed one and +# an in-process PTX renderer. We force the PTX renderer here +# (`CUDA_PTX=1`) so the image is independent of the host CUDA toolkit +# version — only libcuda.so.1 (the driver) is required. +find_lib() { + local soname="$1" + for dir in /usr/lib/x86_64-linux-gnu /usr/lib64 /usr/lib /lib/x86_64-linux-gnu /lib64 /lib; do + if [ -e "${dir}/${soname}" ]; then + echo "${dir}/${soname}" + return 0 + fi + done + return 1 +} + +if [ -z "${CUDA:-}${HIP:-}${METAL:-}${CLANG:-}" ]; then + if find_lib libcuda.so.1 >/dev/null; then + export CUDA=1 + export CUDA_PTX=1 + elif find_lib libamdhip64.so >/dev/null || find_lib libamdhip64.so.6 >/dev/null; then + export HIP=1 + else + export CLANG=1 + fi +fi + +# The CPU path (CLANG=1) JIT-compiles via libLLVM. Force tinygrad's +# in-process LLVM compiler so we don't need an external `clang` binary +# (which is not present in the scratch image). +export CPU_LLVM=1 +if [ -z "${LLVM_PATH:-}" ]; then + for candidate in "${EDIR}"/lib/libLLVM-*.so "${EDIR}"/lib/libLLVM-*.so.* "${EDIR}"/lib/libLLVM.so.*; do + if [ -e "${candidate}" ]; then + export LLVM_PATH="${candidate}" + break + fi + done +fi + +startBackend $@ diff --git a/backend/python/tinygrad/test.py b/backend/python/tinygrad/test.py new file mode 100644 index 000000000..bf47d9bda --- /dev/null +++ b/backend/python/tinygrad/test.py @@ -0,0 +1,84 @@ +""" +Unit tests for the tinygrad gRPC backend. + +These tests cover the cheap paths that don't need a real model checkpoint: + - Health responds OK + - Tool-call parsers emit expected ToolCall structures + +The full LLM / embeddings / Stable Diffusion / Whisper paths are exercised by +the root-level `make test-extra-backend-tinygrad-all` e2e targets, which boot +the containerized backend against real HF checkpoints. +""" +import os +import subprocess +import sys +import time +import unittest + +import grpc + +import backend_pb2 +import backend_pb2_grpc + +sys.path.insert(0, os.path.dirname(__file__)) +from tool_parsers.hermes import HermesToolParser # noqa: E402 + + +class TestHealth(unittest.TestCase): + def setUp(self): + self.service = subprocess.Popen( + ["python3", "backend.py", "--addr", "localhost:50051"] + ) + time.sleep(5) + + def tearDown(self): + self.service.kill() + self.service.wait() + + def test_health(self): + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.Health(backend_pb2.HealthMessage()) + self.assertEqual(response.message, b"OK") + + +class TestHermesParser(unittest.TestCase): + def test_single_tool_call(self): + parser = HermesToolParser() + text = ( + "Sure, let me check.\n" + "\n" + '{"name": "get_weather", "arguments": {"city": "Paris"}}\n' + "\n" + "Done." + ) + content, calls = parser.parse(text) + self.assertIn("Sure", content) + self.assertIn("Done", content) + self.assertEqual(len(calls), 1) + self.assertEqual(calls[0].name, "get_weather") + self.assertIn("Paris", calls[0].arguments) + + def test_multi_call_and_thinking(self): + parser = HermesToolParser() + text = ( + "I need both." + '{"name":"a","arguments":{"x":1}}' + '{"name":"b","arguments":{}}' + ) + result = parser.parse_full(text) + self.assertEqual(result.reasoning, "I need both.") + self.assertEqual([c.name for c in result.tool_calls], ["a", "b"]) + self.assertEqual(result.tool_calls[0].index, 0) + self.assertEqual(result.tool_calls[1].index, 1) + + def test_no_tool_call_is_passthrough(self): + parser = HermesToolParser() + text = "plain assistant answer with no tool call" + content, calls = parser.parse(text) + self.assertEqual(content, text) + self.assertEqual(calls, []) + + +if __name__ == "__main__": + unittest.main() diff --git a/backend/python/tinygrad/test.sh b/backend/python/tinygrad/test.sh new file mode 100755 index 000000000..eb59f2aaf --- /dev/null +++ b/backend/python/tinygrad/test.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +runUnittests diff --git a/backend/python/tinygrad/tool_parsers/__init__.py b/backend/python/tinygrad/tool_parsers/__init__.py new file mode 100644 index 000000000..3f1d6c5b7 --- /dev/null +++ b/backend/python/tinygrad/tool_parsers/__init__.py @@ -0,0 +1,11 @@ +"""Tool-call parsers for the tinygrad backend. + +Each parser takes raw model output and extracts OpenAI-style tool calls so +the backend can populate `ChatDelta.tool_calls[]` natively (matching vLLM's +behavior, which the Go core prefers over regex fallback parsing). +""" +from __future__ import annotations + +from .base import ToolCall, ToolParser, resolve_parser + +__all__ = ["ToolCall", "ToolParser", "resolve_parser"] diff --git a/backend/python/tinygrad/tool_parsers/base.py b/backend/python/tinygrad/tool_parsers/base.py new file mode 100644 index 000000000..5e2fcc910 --- /dev/null +++ b/backend/python/tinygrad/tool_parsers/base.py @@ -0,0 +1,85 @@ +"""Common types + parser registry for tool-call extraction.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ToolCall: + """One extracted tool call — maps 1:1 to backend_pb2.ToolCallDelta.""" + index: int + name: str + arguments: str # JSON string + id: str = "" + + +class ToolParser: + """Parser interface. + + Subclasses implement `parse` (full non-streaming pass) and optionally + `parse_stream` (incremental). The default `parse_stream` buffers until a + full response is available and then delegates to `parse`. + """ + + name: str = "base" + + def __init__(self) -> None: + self._stream_buffer = "" + self._stream_index = 0 + + def parse(self, text: str) -> tuple[str, list[ToolCall]]: + """Return (content_for_user, tool_calls).""" + raise NotImplementedError + + def parse_stream(self, delta: str, finished: bool = False) -> tuple[str, list[ToolCall]]: + """Accumulate a streaming delta. Emits any tool calls that have closed. + + Default behavior: buffer until `finished=True`, then parse once. + Subclasses can override to emit mid-stream. + """ + self._stream_buffer += delta + if not finished: + return "", [] + content, calls = self.parse(self._stream_buffer) + # Re-index starting from whatever we've already emitted in this stream. + reindexed: list[ToolCall] = [] + for i, c in enumerate(calls): + reindexed.append(ToolCall( + index=self._stream_index + i, + name=c.name, + arguments=c.arguments, + id=c.id, + )) + self._stream_index += len(reindexed) + return content, reindexed + + def reset(self) -> None: + self._stream_buffer = "" + self._stream_index = 0 + + +_REGISTRY: dict[str, type[ToolParser]] = {} + + +def register(cls: type[ToolParser]) -> type[ToolParser]: + _REGISTRY[cls.name] = cls + return cls + + +def resolve_parser(name: Optional[str]) -> ToolParser: + """Return a parser instance by name, falling back to a no-op passthrough.""" + # Import for side effects — each module registers itself. + from . import hermes, llama3_json, mistral, qwen3_xml # noqa: F401 + + if name and name in _REGISTRY: + return _REGISTRY[name]() + return PassthroughToolParser() + + +class PassthroughToolParser(ToolParser): + """No-op parser — used when no tool_parser is configured.""" + name = "passthrough" + + def parse(self, text: str) -> tuple[str, list[ToolCall]]: + return text, [] diff --git a/backend/python/tinygrad/tool_parsers/hermes.py b/backend/python/tinygrad/tool_parsers/hermes.py new file mode 100644 index 000000000..078dc9613 --- /dev/null +++ b/backend/python/tinygrad/tool_parsers/hermes.py @@ -0,0 +1,74 @@ +"""Hermes-format tool-call parser. + +Hermes 2 / 2.5 / 3 (and Qwen 2.5 Instruct, which adopted the same convention) +emit tool calls wrapped in `...` tags, where the inner +content is a JSON object with `name` and `arguments` keys: + + + {"name": "get_weather", "arguments": {"city": "Paris"}} + + +Multiple tool calls may appear back-to-back. Text outside the tags is plain +assistant content that should surface to the user. + +This parser also strips `...` reasoning blocks and returns them +via the reasoning_content channel (Qwen 3, DeepSeek-R1 distills). +""" +from __future__ import annotations + +import json +import re +from dataclasses import dataclass + +from .base import ToolCall, ToolParser, register + +_TOOL_CALL_RE = re.compile(r"\s*(\{.*?\})\s*", re.DOTALL) +_THINK_RE = re.compile(r"(.*?)", re.DOTALL) + + +@dataclass +class HermesParseResult: + content: str + reasoning: str + tool_calls: list[ToolCall] + + +@register +class HermesToolParser(ToolParser): + name = "hermes" + + def _parse_full(self, text: str) -> HermesParseResult: + reasoning_parts: list[str] = [] + + def _capture_reasoning(match: re.Match[str]) -> str: + reasoning_parts.append(match.group(1).strip()) + return "" + + text_wo_think = _THINK_RE.sub(_capture_reasoning, text) + + calls: list[ToolCall] = [] + for idx, match in enumerate(_TOOL_CALL_RE.finditer(text_wo_think)): + raw = match.group(1) + try: + obj = json.loads(raw) + except json.JSONDecodeError: + continue + if not isinstance(obj, dict): + continue + name = obj.get("name") + if not isinstance(name, str): + continue + args = obj.get("arguments", {}) + args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False) + calls.append(ToolCall(index=idx, name=name, arguments=args_str)) + + content = _TOOL_CALL_RE.sub("", text_wo_think).strip() + reasoning = "\n\n".join(reasoning_parts).strip() + return HermesParseResult(content=content, reasoning=reasoning, tool_calls=calls) + + def parse(self, text: str) -> tuple[str, list[ToolCall]]: + result = self._parse_full(text) + return result.content, result.tool_calls + + def parse_full(self, text: str) -> HermesParseResult: + return self._parse_full(text) diff --git a/backend/python/tinygrad/tool_parsers/llama3_json.py b/backend/python/tinygrad/tool_parsers/llama3_json.py new file mode 100644 index 000000000..ee357519a --- /dev/null +++ b/backend/python/tinygrad/tool_parsers/llama3_json.py @@ -0,0 +1,86 @@ +"""Llama 3.1 / 3.2 / 3.3 JSON tool-call parser. + +Meta's Llama 3.1+ instruct chat templates emit tool calls in two broadly +compatible shapes: + + 1. With the `<|python_tag|>` lead-in: + <|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}} + 2. As a bare JSON object (or list of objects) at the end of the turn. + +We also handle multi-call shapes where the model emits several JSON objects +separated by `;` or newlines, and JSON arrays `[{...}, {...}]`. The key field +for Llama 3 is historically `parameters` (older docs) but recent checkpoints +also emit `arguments` — accept either. +""" +from __future__ import annotations + +import json +import re +from dataclasses import dataclass + +from .base import ToolCall, ToolParser, register + +_PYTHON_TAG = "<|python_tag|>" +_JSON_OBJECT_RE = re.compile(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", re.DOTALL) + + +def _coerce_call(obj: object, index: int) -> ToolCall | None: + if not isinstance(obj, dict): + return None + name = obj.get("name") + if not isinstance(name, str): + return None + args = obj.get("arguments", obj.get("parameters", {})) + args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False) + return ToolCall(index=index, name=name, arguments=args_str) + + +@register +class Llama3JsonToolParser(ToolParser): + name = "llama3_json" + + def parse(self, text: str) -> tuple[str, list[ToolCall]]: + calls: list[ToolCall] = [] + + # Strip <|python_tag|> segments first — each segment is one tool call + # body. The content after the final python_tag (if any) is the call. + remaining = text + if _PYTHON_TAG in text: + head, *tails = text.split(_PYTHON_TAG) + remaining = head + for tail in tails: + parsed = _try_parse(tail.strip(), len(calls)) + calls.extend(parsed) + + # Any JSON objects / arrays left in `remaining` count as tool calls too + # if they parse to a {"name": ..., "arguments": ...} shape. + for match in _JSON_OBJECT_RE.finditer(remaining): + parsed = _try_parse(match.group(0), len(calls)) + if parsed: + calls.extend(parsed) + remaining = remaining.replace(match.group(0), "", 1) + + content = remaining.strip() + return content, calls + + +def _try_parse(blob: str, start_index: int) -> list[ToolCall]: + """Parse a fragment that may be a JSON object or a JSON array of objects.""" + blob = blob.strip().rstrip(";") + if not blob: + return [] + try: + obj = json.loads(blob) + except json.JSONDecodeError: + return [] + if isinstance(obj, dict): + call = _coerce_call(obj, start_index) + return [call] if call else [] + if isinstance(obj, list): + calls: list[ToolCall] = [] + for i, item in enumerate(obj): + c = _coerce_call(item, start_index + i) + if c: + calls.append(c) + return calls + return [] diff --git a/backend/python/tinygrad/tool_parsers/mistral.py b/backend/python/tinygrad/tool_parsers/mistral.py new file mode 100644 index 000000000..2a8891313 --- /dev/null +++ b/backend/python/tinygrad/tool_parsers/mistral.py @@ -0,0 +1,56 @@ +"""Mistral / Mixtral tool-call parser. + +Mistral Nemo / Small / Large Instruct emit tool calls prefixed with the +`[TOOL_CALLS]` control token, followed by a JSON array: + + [TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}}] + +Multiple calls live inside the same array. Any text before `[TOOL_CALLS]` is +normal assistant content and should surface to the user. +""" +from __future__ import annotations + +import json +import re + +from .base import ToolCall, ToolParser, register + +_MARKER = "[TOOL_CALLS]" +_JSON_ARRAY_RE = re.compile(r"\[\s*(?:\{.*?\}\s*,?\s*)+\]", re.DOTALL) + + +@register +class MistralToolParser(ToolParser): + name = "mistral" + + def parse(self, text: str) -> tuple[str, list[ToolCall]]: + if _MARKER not in text: + return text.strip(), [] + + head, tail = text.split(_MARKER, 1) + content = head.strip() + + match = _JSON_ARRAY_RE.search(tail) + if not match: + return content, [] + + try: + arr = json.loads(match.group(0)) + except json.JSONDecodeError: + return content, [] + + if not isinstance(arr, list): + return content, [] + + calls: list[ToolCall] = [] + for i, obj in enumerate(arr): + if not isinstance(obj, dict): + continue + name = obj.get("name") + if not isinstance(name, str): + continue + args = obj.get("arguments", {}) + args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False) + calls.append(ToolCall(index=i, name=name, arguments=args_str)) + + return content, calls diff --git a/backend/python/tinygrad/tool_parsers/qwen3_xml.py b/backend/python/tinygrad/tool_parsers/qwen3_xml.py new file mode 100644 index 000000000..e0dcbf848 --- /dev/null +++ b/backend/python/tinygrad/tool_parsers/qwen3_xml.py @@ -0,0 +1,74 @@ +"""Qwen 3 XML tool-call parser. + +Qwen 3 Instruct emits tool calls wrapped in a two-level tag structure: + + + + + Paris + + + celsius + + + + +Parameter values are raw text — we treat them as strings unless they look +like JSON (in which case we try to parse so numbers / booleans round-trip +cleanly). Qwen 3 also supports `...` reasoning blocks before +the tool call — these are captured via the shared Hermes convention. +""" +from __future__ import annotations + +import json +import re + +from .base import ToolCall, ToolParser, register + +_TOOL_CALL_RE = re.compile(r"(.*?)", re.DOTALL) +_FUNCTION_RE = re.compile(r"]+)>(.*?)", re.DOTALL) +_PARAMETER_RE = re.compile(r"]+)>(.*?)", re.DOTALL) +_THINK_RE = re.compile(r"(.*?)", re.DOTALL) + + +def _maybe_json(value: str): + value = value.strip() + if not value: + return value + if value[0] in "{[\"" or value in ("true", "false", "null") or value.lstrip("-").replace(".", "", 1).isdigit(): + try: + return json.loads(value) + except json.JSONDecodeError: + return value + return value + + +@register +class Qwen3XmlToolParser(ToolParser): + name = "qwen3_xml" + + def parse(self, text: str) -> tuple[str, list[ToolCall]]: + # Strip reasoning blocks from the user-visible content. + stripped = _THINK_RE.sub("", text) + + calls: list[ToolCall] = [] + for match in _TOOL_CALL_RE.finditer(stripped): + body = match.group(1) + fn_match = _FUNCTION_RE.search(body) + if not fn_match: + continue + name = fn_match.group(1).strip() + params_body = fn_match.group(2) + + params: dict[str, object] = {} + for pm in _PARAMETER_RE.finditer(params_body): + params[pm.group(1).strip()] = _maybe_json(pm.group(2)) + + calls.append(ToolCall( + index=len(calls), + name=name, + arguments=json.dumps(params, ensure_ascii=False), + )) + + content = _TOOL_CALL_RE.sub("", stripped).strip() + return content, calls diff --git a/backend/python/tinygrad/vendor/__init__.py b/backend/python/tinygrad/vendor/__init__.py new file mode 100644 index 000000000..ea1cc0be6 --- /dev/null +++ b/backend/python/tinygrad/vendor/__init__.py @@ -0,0 +1,6 @@ +"""Vendored upstream tinygrad reference code (MIT-licensed). + +Source: https://github.com/tinygrad/tinygrad +These files are not part of the `tinygrad` pip package (the `extra/` tree is +excluded from `pyproject.toml` `packages`), so we carry a pinned copy here. +""" diff --git a/backend/python/tinygrad/vendor/audio_helpers.py b/backend/python/tinygrad/vendor/audio_helpers.py new file mode 100644 index 000000000..b6951eb66 --- /dev/null +++ b/backend/python/tinygrad/vendor/audio_helpers.py @@ -0,0 +1,83 @@ +# Vendored verbatim from tinygrad examples/audio_helpers.py (MIT license). +# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/audio_helpers.py +# Copyright (c) 2023- the tinygrad authors +# SPDX-License-Identifier: MIT +from typing import Optional +from tinygrad import Tensor +from tinygrad.dtype import DTypeLike, dtypes +import math + +# rewritten from numpy +def rfftfreq(n: int, d: float = 1.0, device=None) -> Tensor: + val = 1.0 / (n * d) + N = n // 2 + 1 + results = Tensor.arange(N, device=device) + return results * val + +# just like in librosa +def fft_frequencies(sr: float, n_fft: int) -> Tensor: + return rfftfreq(n=n_fft, d=1.0 / sr) + +def hz_to_mel(freq: Tensor) -> Tensor: + # linear part + f_min = 0.0 + f_sp = 200.0 / 3 + mels = (freq - f_min) / f_sp + + # log-scale part + min_log_hz = 1000.0 # beginning of log region (Hz) + mask = freq >= min_log_hz + return mask.where(((min_log_hz - f_min) / f_sp) + (freq / min_log_hz).log() / (math.log(6.4) / 27.0), mels) + +def mel_to_hz(mels: Tensor) -> Tensor: + # linear scale + f_min = 0.0 + f_sp = 200.0 / 3 + freqs = f_min + f_sp * mels + + # nonlinear scale + min_log_hz = 1000.0 # beginning of log region (Hz) + min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) + logstep = math.log(6.4) / 27.0 # step size for log region + + log_t = mels >= min_log_mel + freqs = log_t.where(min_log_hz * ((logstep * (mels - min_log_mel)).exp()), freqs) + return freqs + +def mel_frequencies(n_mels: int = 128, *, fmin: float = 0.0, fmax: float = 11025.0) -> Tensor: + # center freqs of mel bands - uniformly spaced between limits + min_max_mel = hz_to_mel(Tensor([fmin, fmax])) + + mels = Tensor.linspace(min_max_mel[0], min_max_mel[1], n_mels) + hz = mel_to_hz(mels) + return hz + +def mel( + *, + sr: float, + n_fft: int, + n_mels: int = 128, + fmin: float = 0.0, + fmax: Optional[float] = None, + dtype: DTypeLike = dtypes.default_float, +) -> Tensor: + if fmax is None: + fmax = float(sr) / 2 + + n_mels = int(n_mels) + + fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft) # center freqs of each FFT bin + mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax) # center freqs of mel bands + + fdiff = mel_f[1:] - mel_f[:-1] + ramps = mel_f[None].T.expand(-1, fftfreqs.shape[-1]) - fftfreqs + + lower = -ramps[:n_mels] / fdiff[:n_mels][None].T + upper = ramps[2 : n_mels + 2] / fdiff[1 : n_mels + 1][None].T + weights = lower.minimum(upper).maximum(0) + + # Slaney-style mel is scaled to be approx constant energy per channel + enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels]) + weights *= enorm[:, None] + + return weights diff --git a/backend/python/tinygrad/vendor/clip.py b/backend/python/tinygrad/vendor/clip.py new file mode 100644 index 000000000..3c7da3141 --- /dev/null +++ b/backend/python/tinygrad/vendor/clip.py @@ -0,0 +1,484 @@ +# Vendored verbatim from tinygrad extra/models/clip.py (MIT license). +# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/clip.py +# Copyright (c) 2023- the tinygrad authors +# SPDX-License-Identifier: MIT +from tinygrad import Tensor, dtypes +from tinygrad.helpers import fetch +from tinygrad.nn import Linear, LayerNorm, Embedding, Conv2d + +from typing import List, Optional, Union, Tuple, Dict +from abc import ABC, abstractmethod +from functools import lru_cache +import numpy as np +import re, gzip + +# Allow for monkeypatching for mlperf. +gelu = Tensor.gelu + +@lru_cache() +def default_bpe(): + # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license) + return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz") + +class Tokenizer: + """ + Namespace for CLIP Text Tokenizer components. + """ + + @staticmethod + def get_pairs(word): + """ + Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + return set(zip(word, word[1:])) + @staticmethod + def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + @staticmethod + def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + class ClipTokenizer: + def __init__(self, version=None): + self.byte_encoder, self.version = Tokenizer.bytes_to_unicode(), version + merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(Tokenizer.bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if self.version == "sd_mlperf_v5_0": + import regex + vocab.extend(['', '']) + self.cache = {'': '', '': ''} + self.pat = regex.compile(r"""||'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE) + else: + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = Tokenizer.get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + pairs = Tokenizer.get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text:str, pad_with_zeros:bool=False) -> List[int]: + bpe_tokens: List[int] = [] + if self.version == "sd_mlperf_v5_0": + import regex, ftfy, html + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)).strip() + text = Tokenizer.whitespace_clean(text).lower() + re_module = regex + else: + text = Tokenizer.whitespace_clean(text.strip()).lower() + re_module = re + + for token in re_module.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + # Truncation, keeping two slots for start and end tokens. + if len(bpe_tokens) > 75: + bpe_tokens = bpe_tokens[:75] + return [49406] + bpe_tokens + [49407] + ([0] if pad_with_zeros else [49407]) * (77 - len(bpe_tokens) - 2) + + +class Embedder(ABC): + input_key: str + @abstractmethod + def __call__(self, x:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]: + pass + + +class Closed: + """ + Namespace for OpenAI CLIP model components. + """ + class ClipMlp: + def __init__(self): + self.fc1 = Linear(768, 3072) + self.fc2 = Linear(3072, 768) + + def __call__(self, h:Tensor) -> Tensor: + h = self.fc1(h) + h = h.quick_gelu() + h = self.fc2(h) + return h + + class ClipAttention: + def __init__(self): + self.embed_dim = 768 + self.num_heads = 12 + self.head_dim = self.embed_dim // self.num_heads + self.k_proj = Linear(self.embed_dim, self.embed_dim) + self.v_proj = Linear(self.embed_dim, self.embed_dim) + self.q_proj = Linear(self.embed_dim, self.embed_dim) + self.out_proj = Linear(self.embed_dim, self.embed_dim) + + def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor: + bsz, tgt_len, embed_dim = hidden_states.shape + q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)] + attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask) + return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)) + + class ClipEncoderLayer: + def __init__(self): + self.self_attn = Closed.ClipAttention() + self.layer_norm1 = LayerNorm(768) + self.mlp = Closed.ClipMlp() + self.layer_norm2 = LayerNorm(768) + + def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor: + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states, causal_attention_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + class ClipTextEmbeddings: + def __init__(self): + self.token_embedding = Embedding(49408, 768) + self.position_embedding = Embedding(77, 768) + + def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor: + return self.token_embedding(input_ids) + self.position_embedding(position_ids) + + class ClipEncoder: + def __init__(self, layer_count:int=12): + self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)] + + def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor: + # the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state + layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx] + for l in layers: + x = l(x, causal_attention_mask) + return x + + class ClipTextTransformer: + def __init__(self, ret_layer_idx:Optional[int]=None): + self.embeddings = Closed.ClipTextEmbeddings() + self.encoder = Closed.ClipEncoder() + self.final_layer_norm = LayerNorm(768) + self.ret_layer_idx = ret_layer_idx + + def __call__(self, input_ids:Tensor) -> Tensor: + x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1)) + x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1), self.ret_layer_idx) + return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x + + class ClipTextModel: + def __init__(self, ret_layer_idx:Optional[int]): + self.text_model = Closed.ClipTextTransformer(ret_layer_idx=ret_layer_idx) + + +# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331 +class FrozenClosedClipEmbedder(Embedder): + def __init__(self, ret_layer_idx:Optional[int]=None): + self.tokenizer = Tokenizer.ClipTokenizer() + self.transformer = Closed.ClipTextModel(ret_layer_idx) + self.input_key = "txt" + + def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]: + if isinstance(texts, str): texts = [texts] + assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}" + tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0) + return self.transformer.text_model(tokens.reshape(len(texts),-1)) + + +class Open: + """ + Namespace for OpenCLIP model components. + """ + class MultiheadAttention: + def __init__(self, dims:int, n_heads:int): + self.dims = dims + self.n_heads = n_heads + self.d_head = self.dims // self.n_heads + + self.in_proj_bias = Tensor.empty(3*dims) + self.in_proj_weight = Tensor.empty(3*dims, dims) + self.out_proj = Linear(dims, dims) + + def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor: + T,B,C = x.shape + + proj = x.linear(self.in_proj_weight.T, self.in_proj_bias) + proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0, -2) + + q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in proj.chunk(3)] + + attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + attn_output = attn_output.permute(2, 0, 1, 3).reshape(T, B, C) + attn_output = self.out_proj(attn_output) + + return attn_output + + class Mlp: + def __init__(self, dims, hidden_dims): + self.c_fc = Linear(dims, hidden_dims) + self.c_proj = Linear(hidden_dims, dims) + self.gelu = gelu + + def __call__(self, x:Tensor) -> Tensor: + return x.sequential([self.c_fc, self.gelu, self.c_proj]) + + # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210 + class ResidualAttentionBlock: + def __init__(self, dims:int, n_heads:int, mlp_ratio:float): + self.ln_1 = LayerNorm(dims) + self.attn = Open.MultiheadAttention(dims, n_heads) + + self.ln_2 = LayerNorm(dims) + self.mlp = Open.Mlp(dims, int(dims * mlp_ratio)) + + def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None, transpose:bool=False) -> Tensor: + q_x = self.ln_1(x) + attn_out = self.attn(q_x.transpose(0, 1) if transpose else q_x, attn_mask=attn_mask) + attn_out = attn_out.transpose(0, 1) if transpose else attn_out + x = x + attn_out + x = x + self.mlp(self.ln_2(x)) + return x + + # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317 + class ClipTransformer: + def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0): + self.resblocks = [ + Open.ResidualAttentionBlock(dims, n_heads, mlp_ratio) for _ in range(layers) + ] + + def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor: + for r in self.resblocks: + x = r(x, attn_mask=attn_mask, transpose=True) + return x + + # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220 + # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661 + class ClipTextTransformer: + def __init__(self, width:int, n_heads:int, layers:int, vocab_size:int=49408, ctx_length:int=77): + self.token_embedding = Embedding(vocab_size, width) + self.positional_embedding = Tensor.empty(ctx_length, width) + self.transformer = Open.ClipTransformer(width, layers, n_heads) + self.ln_final = LayerNorm(width) + self.text_projection = Tensor.empty(width, width) + self.attn_mask = Tensor.full((77, 77), float("-inf")).triu(1).realize() + + def __call__(self, text:Tensor) -> Tensor: + seq_len = text.shape[1] + + x = self.token_embedding(text) + x = x + self.positional_embedding[:seq_len] + x = self.transformer(x, attn_mask=self.attn_mask) + x = self.ln_final(x) + + pooled = x[:, text.argmax(dim=-1)] @ self.text_projection + return pooled + + class ClipVisionTransformer: + def __init__(self, width:int, layers:int, d_head:int, image_size:int, patch_size:int): + grid_size = image_size // patch_size + n_heads = width // d_head + assert n_heads * d_head == width + + self.conv1 = Conv2d(3, width, kernel_size=patch_size, stride=patch_size, bias=False) + + self.class_embedding = Tensor.empty(width) + self.positional_embedding = Tensor.empty(grid_size * grid_size + 1, width) + self.transformer = Open.ClipTransformer(width, layers, n_heads) + self.ln_pre = LayerNorm(width) + self.ln_post = LayerNorm(width) + self.proj = Tensor.empty(width, 1024) + + def __call__(self, x:Tensor) -> Tensor: + x = self.conv1(x) + x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1) + x = self.class_embedding.reshape(1, 1, -1).expand(x.shape[0], 1, -1).cat(x, dim=1) + x = x + self.positional_embedding + + x = self.ln_pre(x) + x = self.transformer(x) + x = self.ln_post(x) + + pooled = x[:, 0] @ self.proj + return pooled + + +# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396 +# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498 +class FrozenOpenClipEmbedder(Embedder): + def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False, clip_tokenizer_version=None): + self.tokenizer = Tokenizer.ClipTokenizer(version=clip_tokenizer_version) + self.model = Open.ClipTextTransformer(dims, n_heads, layers) + self.return_pooled = return_pooled + self.input_key = "txt" + self.ln_penultimate = ln_penultimate + + def tokenize(self, text:str, device:Optional[str]=None) -> Tensor: + return Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int32, device=device).reshape(1,-1) + + def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None): + for r in self.model.transformer.resblocks: + x, penultimate = r(x, attn_mask=attn_mask), x + return x.permute(1, 0, 2), penultimate.permute(1, 0, 2) + + def embed_tokens(self, tokens:Tensor) -> Union[Tensor,Tuple[Tensor,...]]: + x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2) + x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + + if self.ln_penultimate: + penultimate = self.model.ln_final(penultimate) + + if self.return_pooled: + x = self.model.ln_final(x) + index = tokens.argmax(axis=-1).reshape(-1,1,1).expand(x.shape[0],1,x.shape[-1]) + pooled = x.gather(1, index).squeeze(1) @ self.model.text_projection + return penultimate, pooled + else: + return penultimate + + def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]: + if isinstance(texts, str): texts = [texts] + assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}" + tokens = Tensor.cat(*[self.tokenize(text) for text in texts], dim=0) + return self.embed_tokens(tokens) + + +clip_configs: Dict = { + "ViT-H-14": { + "dims": 1024, + "vision_cfg": { + "width": 1280, + "layers": 32, + "d_head": 80, + "image_size": 224, + "patch_size": 14, + }, + "text_cfg": { + "width": 1024, + "n_heads": 16, + "layers": 24, + "ctx_length": 77, + "vocab_size": 49408, + }, + "return_pooled": False, + "ln_penultimate": True, + } +} + +class OpenClipEncoder: + def __init__(self, dims:int, text_cfg:Dict, vision_cfg:Dict, **_): + self.visual = Open.ClipVisionTransformer(**vision_cfg) + + text = Open.ClipTextTransformer(**text_cfg) + self.transformer = text.transformer + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + + self.attn_mask = Tensor.full((77, 77), float("-inf")).triu(1).realize() + self.mean = Tensor([0.48145466, 0.45782750, 0.40821073]).reshape(-1, 1, 1) + self.std = Tensor([0.26862954, 0.26130258, 0.27577711]).reshape(-1, 1, 1) + + # TODO: + # Should be doable in pure tinygrad, would just require some work and verification. + # This is very desirable since it would allow for full generation->evaluation in a single JIT call. + def prepare_image(self, image) -> Tensor: + from PIL import Image + SIZE = 224 + w, h = image.size + scale = min(SIZE / h, SIZE / w) + image = image.resize((max(int(w*scale),SIZE),max(int(h*scale),SIZE)), Image.Resampling.BICUBIC) + w, h = image.size + if w > SIZE: + left = (w - SIZE) // 2 + image = image.crop((left, left+SIZE, 0, SIZE)) + elif h > SIZE: + top = (h - SIZE) // 2 + image = image.crop((0, SIZE, top, top+SIZE)) + + x = Tensor(np.array(image.convert('RGB')), device=self.std.device) + x = x.permute(2, 0, 1).cast(dtypes.float32) / 255.0 + return (x - self.mean) / self.std + + def encode_tokens(self, tokens:Tensor) -> Tensor: + x = self.token_embedding(tokens) + x = x + self.positional_embedding + x = self.transformer(x, attn_mask=self.attn_mask) + x = self.ln_final(x) + x = x[Tensor.arange(x.shape[0], device=x.device), tokens.argmax(axis=-1)] + x = x @ self.text_projection + return x + + def get_clip_score(self, tokens:Tensor, image:Tensor) -> Tensor: + image_features: Tensor = self.visual(image) + image_features /= image_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm + + text_features = self.encode_tokens(tokens) + text_features /= text_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm + + return (image_features * text_features).sum(axis=-1) diff --git a/backend/python/tinygrad/vendor/llama.py b/backend/python/tinygrad/vendor/llama.py new file mode 100644 index 000000000..445f68b9c --- /dev/null +++ b/backend/python/tinygrad/vendor/llama.py @@ -0,0 +1,294 @@ +# Vendored from tinygrad extra/models/llama.py (MIT license). +# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/llama.py +# +# Local modification: Attention / TransformerBlock / Transformer accept an +# optional `qkv_bias` flag so the same module can host Qwen2-style models that +# use bias on the Q/K/V projections (Llama 3 has no bias). Changes are marked +# with `# LOCALAI`. +# +# Copyright (c) 2023- the tinygrad authors +# SPDX-License-Identifier: MIT +from typing import Union, Optional, Any +import collections, math +from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device +from tinygrad.helpers import getenv, DEBUG + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor: + freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim)) + freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0) + return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).reshape(1, end, 1, dim//2, 2) + + +def complex_mult(A, c, d): + a, b = A[..., 0:1], A[..., 1:2] + ro = a * c - b * d + co = a * d + b * c + return ro.cat(co, dim=-1) + + +def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}" + xq = xq.reshape(*xq.shape[0:-1], -1, 2) + xk = xk.reshape(*xk.shape[0:-1], -1, 2) + assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5 + c, d = freqs_cis[..., 0:1], freqs_cis[..., 1:2] + xq_out = complex_mult(xq, c, d) + xk_out = complex_mult(xk, c, d) + return xq_out.flatten(3), xk_out.flatten(3) + + +def repeat_kv(x: Tensor, n_rep: int) -> Tensor: + bs, seqlen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim) + + +class Attention: + # LOCALAI: added qkv_bias + def __init__(self, dim, n_heads, n_kv_heads=None, max_context=0, linear=nn.Linear, qk_norm: float | None = None, qkv_bias: bool = False): + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads + self.head_dim = dim // n_heads + self.n_rep = self.n_heads // self.n_kv_heads + self.max_context = max_context + + self.wq = linear(dim, self.n_heads * self.head_dim, bias=qkv_bias) + self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=qkv_bias) + self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=qkv_bias) + self.wo = linear(self.n_heads * self.head_dim, dim, bias=False) + + self.q_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None + self.k_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None + + def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor] = None) -> Tensor: + xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x) + + if self.q_norm is not None and self.k_norm is not None: + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + if x.dtype == dtypes.bfloat16: + xq, xk = xq.contiguous_backward(), xk.contiguous_backward() + + xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim) + xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim) + xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis) + bsz, seqlen, _, _ = xq.shape + + if self.max_context: + if not hasattr(self, "cache_kv"): + self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize() + if isinstance(x.device, tuple): + self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize() + + assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}" + self.cache_kv[:, :, start_pos:start_pos + seqlen, :, :].assign(Tensor.stack(xk, xv)).realize() + + keys = self.cache_kv[0, :, 0:start_pos + seqlen, :, :] + values = self.cache_kv[1, :, 0:start_pos + seqlen, :, :] + else: + assert start_pos == 0 + keys, values = xk, xv + + if self.max_context: + keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep) + xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2) + attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2) + else: + xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2) + attn = xq.scaled_dot_product_attention(keys, values, is_causal=True, enable_gqa=True).transpose(1, 2) + + attn = attn.reshape(bsz, seqlen, -1) + return self.wo(attn) + + +class FeedForward: + def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear): + self.w1 = linear(dim, hidden_dim, bias=False) + self.w2 = linear(hidden_dim, dim, bias=False) + self.w3 = linear(dim, hidden_dim, bias=False) + + def __call__(self, x: Tensor) -> Tensor: + w1 = self.w1(x).silu() + w3 = self.w3(x.contiguous_backward()) + return self.w2(w1 * w3) + + +class TransformerBlock: + # LOCALAI: added qkv_bias + def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, + linear=nn.Linear, feed_forward=FeedForward, qk_norm=None, qkv_bias: bool = False): + self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear, qk_norm, qkv_bias=qkv_bias) + self.feed_forward = feed_forward(dim, hidden_dim, linear) + self.attention_norm = nn.RMSNorm(dim, norm_eps) + self.ffn_norm = nn.RMSNorm(dim, norm_eps) + + def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]): + h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) + return (h + self.feed_forward(self.ffn_norm(h))).contiguous().contiguous_backward() + + +def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float): + assert logits.ndim == 1, "only works on 1d tensors" + assert 0 <= p <= 1, "p must be between 0 and 1" + assert 0 <= k <= logits.numel(), "k must be between 0 and numel" + + if temp < 1e-6: + return logits.argmax() + + logits = logits.to(Device.DEFAULT) + + if af or ap: + if not hasattr(sample, "alpha_counter"): + setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous()) + logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0) * ap) + + logits = (logits != logits).where(-float("inf"), logits) + + t = (logits / temp).softmax() + + counter = Tensor.arange(t.numel(), device=logits.device).contiguous() + counter2 = Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous() + + if k: + output = Tensor.zeros(k, device=logits.device).contiguous() + output_indices = Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous() + for i in range(k): + t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int) + output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),)) + output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),)) + t = (counter == t_argmax).where(0, t) + + output_cumsum = output[::-1].cumsum()[::-1] + t.sum() + output = (output_cumsum >= (1 - p)) * output + output_indices = (output_cumsum >= (1 - p)) * output_indices + + output_idx = output.multinomial() + output_token = output_indices[output_idx] + else: + output_token = t.multinomial() + + if af or ap: + sample.alpha_counter = (counter == output_token).where(sample.alpha_counter + 1, sample.alpha_counter) + + return output_token + + +class Transformer: + # LOCALAI: added qkv_bias + def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_layers: int, norm_eps: float, vocab_size, + linear=nn.Linear, embedding=nn.Embedding, n_kv_heads=None, rope_theta=10000, + max_context=1024, jit=True, feed_forward=FeedForward, qk_norm=None, disable_kv_cache=False, + qkv_bias: bool = False): + self.layers = [ + TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, + 0 if disable_kv_cache else max_context, linear, + feed_forward=feed_forward, qk_norm=qk_norm, qkv_bias=qkv_bias) + for _ in range(n_layers) + ] + self.norm = nn.RMSNorm(dim, norm_eps) + self.tok_embeddings = embedding(vocab_size, dim) + self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False) + self.max_context = max_context + self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous().requires_grad_(False) + self.forward_jit = TinyJit(self.forward) if jit else None + + def forward(self, tokens: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float): + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens).contiguous() + freqs_cis = self.freqs_cis.cast(h.dtype)[:, start_pos:start_pos + seqlen, :, :, :] + + if self.max_context != 0 and seqlen > 1: + mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos + 1) + else: + mask = None + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + logits = self.output(self.norm(h).contiguous().contiguous_backward()).contiguous_backward() + if math.isnan(temperature): + return logits + + return sample(logits[:, -1, :].flatten(), temperature, top_k, top_p, alpha_f, alpha_p) + + def __call__(self, tokens: Tensor, start_pos: int, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0): + if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0: + return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context - 1).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p) + return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p) + + # LOCALAI: extract last hidden state for embeddings. Skips the LM head and + # the causal-mask branch is left intact so the pooling sees the full sequence. + def embed(self, tokens: Tensor) -> Tensor: + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens).contiguous() + freqs_cis = self.freqs_cis.cast(h.dtype)[:, 0:seqlen, :, :, :] + mask = Tensor.full((1, 1, seqlen, seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(1) if seqlen > 1 else None + for layer in self.layers: + h = layer(h, 0, freqs_cis, mask) + return self.norm(h) + + +def convert_from_huggingface(weights: dict[str, Tensor], n_layers: int, n_heads: int, n_kv_heads: int, permute_layers: bool = True): + def permute(v: Tensor, n_heads: int): + return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1] if len(v.shape) > 1 else 1).transpose(1, 2).reshape(*v.shape[:2]) + + keymap = { + "model.embed_tokens.weight": "tok_embeddings.weight", + **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(n_layers)}, + **{f"model.layers.{l}.self_attn.{x}_norm.weight": f"layers.{l}.attention.{x}_norm.weight" for x in ["q", "k"] for l in range(n_layers)}, + **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(n_layers)}, + **{f"model.layers.{l}.self_attn.{x}_proj.bias": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v", "o"] for l in range(n_layers)}, + **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(n_layers)}, + **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(n_layers)}, + **{f"model.layers.{l}.mlp.gate.weight": f"layers.{l}.feed_forward.gate.weight" for l in range(n_layers)}, + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + sd = {} + experts = collections.defaultdict(dict) + for k, v in weights.items(): + if ".rotary_emb." in k: + continue + v = v.to(Device.DEFAULT) + if "model.layers" in k: + if ("q_proj" in k or "q_norm" in k) and permute_layers: + v = permute(v, n_heads) + elif ("k_proj" in k or "k_norm" in k) and permute_layers: + v = permute(v, n_kv_heads) + if '.mlp.experts.' in k: + _, _, layer, _, _, expert, name, _ = k.split('.') + experts[f'layers.{layer}.feed_forward.{name}'][int(expert)] = v + continue + sd[keymap[k]] = v + for k, v in experts.items(): + sd[k] = Tensor.stack(*[v[i] for i in range(len(v))]) + + if "output.weight" not in sd and "tok_embeddings.weight" in sd: + sd["output.weight"] = sd["tok_embeddings.weight"] + + return sd + + +def convert_from_gguf(weights: dict[str, Tensor], n_layers: int): + keymap = { + "token_embd.weight": "tok_embeddings.weight", + **{f"blk.{l}.attn_norm.weight": f"layers.{l}.attention_norm.weight" for l in range(n_layers)}, + **{f"blk.{l}.attn_{x}.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v"] for l in range(n_layers)}, + **{f"blk.{l}.attn_{x}.bias": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v"] for l in range(n_layers)}, + **{f"blk.{l}.attn_output.weight": f"layers.{l}.attention.wo.weight" for l in range(n_layers)}, + **{f"blk.{l}.ffn_norm.weight": f"layers.{l}.ffn_norm.weight" for l in range(n_layers)}, + **{f"blk.{l}.ffn_{x}.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(n_layers)}, + "output_norm.weight": "norm.weight", + "rope_freqs.weight": "rope_freqs.weight", + } + sd = {keymap[k]: v for k, v in weights.items() if k in keymap} + if "output.weight" not in sd and "token_embd.weight" in weights: + sd["output.weight"] = weights["token_embd.weight"] + return sd + + +def fix_bf16(weights: dict[Any, Tensor]): + return {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()} diff --git a/backend/python/tinygrad/vendor/stable_diffusion.py b/backend/python/tinygrad/vendor/stable_diffusion.py new file mode 100644 index 000000000..5a993332d --- /dev/null +++ b/backend/python/tinygrad/vendor/stable_diffusion.py @@ -0,0 +1,232 @@ +# Adapted from tinygrad examples/stable_diffusion.py (MIT license). +# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/stable_diffusion.py +# Copyright (c) 2023- the tinygrad authors +# SPDX-License-Identifier: MIT +# +# Local modifications: removed the MLPerf training branch (pulls +# examples/mlperf/initializers which we don't vendor) and the __main__ +# argparse / fetch / profile blocks. Kept the core classes so the LocalAI +# tinygrad backend can instantiate and drive Stable Diffusion v1.x from a +# single checkpoint path. +from collections import namedtuple +from typing import Any, Dict + +import numpy as np +from tinygrad import Tensor, dtypes +from tinygrad.nn import Conv2d, GroupNorm + +from . import clip as clip_mod +from . import unet as unet_mod +from .clip import Closed, Tokenizer +from .unet import UNetModel + + +class AttnBlock: + def __init__(self, in_channels): + self.norm = GroupNorm(32, in_channels) + self.q = Conv2d(in_channels, in_channels, 1) + self.k = Conv2d(in_channels, in_channels, 1) + self.v = Conv2d(in_channels, in_channels, 1) + self.proj_out = Conv2d(in_channels, in_channels, 1) + + def __call__(self, x): + h_ = self.norm(x) + q, k, v = self.q(h_), self.k(h_), self.v(h_) + b, c, h, w = q.shape + q, k, v = [t.reshape(b, c, h * w).transpose(1, 2) for t in (q, k, v)] + h_ = Tensor.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(b, c, h, w) + return x + self.proj_out(h_) + + +class ResnetBlock: + def __init__(self, in_channels, out_channels=None): + self.norm1 = GroupNorm(32, in_channels) + self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1) + self.norm2 = GroupNorm(32, out_channels) + self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1) + self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else (lambda x: x) + + def __call__(self, x): + h = self.conv1(self.norm1(x).swish()) + h = self.conv2(self.norm2(h).swish()) + return self.nin_shortcut(x) + h + + +class Mid: + def __init__(self, block_in): + self.block_1 = ResnetBlock(block_in, block_in) + self.attn_1 = AttnBlock(block_in) + self.block_2 = ResnetBlock(block_in, block_in) + + def __call__(self, x): + return x.sequential([self.block_1, self.attn_1, self.block_2]) + + +class Decoder: + def __init__(self): + sz = [(128, 256), (256, 512), (512, 512), (512, 512)] + self.conv_in = Conv2d(4, 512, 3, padding=1) + self.mid = Mid(512) + + arr = [] + for i, s in enumerate(sz): + arr.append({"block": [ResnetBlock(s[1], s[0]), ResnetBlock(s[0], s[0]), ResnetBlock(s[0], s[0])]}) + if i != 0: + arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)} + self.up = arr + + self.norm_out = GroupNorm(32, 128) + self.conv_out = Conv2d(128, 3, 3, padding=1) + + def __call__(self, x): + x = self.conv_in(x) + x = self.mid(x) + for l in self.up[::-1]: + for b in l['block']: + x = b(x) + if 'upsample' in l: + bs, c, py, px = x.shape + x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py * 2, px * 2) + x = l['upsample']['conv'](x) + x.realize() + return self.conv_out(self.norm_out(x).swish()) + + +class Encoder: + def __init__(self): + sz = [(128, 128), (128, 256), (256, 512), (512, 512)] + self.conv_in = Conv2d(3, 128, 3, padding=1) + + arr = [] + for i, s in enumerate(sz): + arr.append({"block": [ResnetBlock(s[0], s[1]), ResnetBlock(s[1], s[1])]}) + if i != 3: + arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0, 1, 0, 1))} + self.down = arr + + self.mid = Mid(512) + self.norm_out = GroupNorm(32, 512) + self.conv_out = Conv2d(512, 8, 3, padding=1) + + def __call__(self, x): + x = self.conv_in(x) + for l in self.down: + for b in l['block']: + x = b(x) + if 'downsample' in l: + x = l['downsample']['conv'](x) + x = self.mid(x) + return self.conv_out(self.norm_out(x).swish()) + + +class AutoencoderKL: + def __init__(self): + self.encoder = Encoder() + self.decoder = Decoder() + self.quant_conv = Conv2d(8, 8, 1) + self.post_quant_conv = Conv2d(4, 4, 1) + + def __call__(self, x): + latent = self.encoder(x) + latent = self.quant_conv(latent) + latent = latent[:, 0:4] + latent = self.post_quant_conv(latent) + return self.decoder(latent) + + +def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000): + betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_training_steps, dtype=np.float32) ** 2 + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + return Tensor(alphas_cumprod) + + +# SD1.x UNet hyperparameters (same as upstream `unet_params`). +UNET_PARAMS_SD1: Dict[str, Any] = { + "adm_in_ch": None, + "in_ch": 4, + "out_ch": 4, + "model_ch": 320, + "attention_resolutions": [4, 2, 1], + "num_res_blocks": 2, + "channel_mult": [1, 2, 4, 4], + "n_heads": 8, + "transformer_depth": [1, 1, 1, 1], + "ctx_dim": 768, + "use_linear": False, +} + + +class StableDiffusion: + """Stable Diffusion 1.x pipeline, adapted from tinygrad's reference example. + + Drives the native CompVis `sd-v1-*.ckpt` checkpoint format (the only one + the vendored weight layout handles). For HuggingFace safetensors pipelines + the caller is expected to download / merge the `.ckpt` equivalent before + calling LoadModel. + """ + + def __init__(self): + self.alphas_cumprod = get_alphas_cumprod() + self.first_stage_model = AutoencoderKL() + self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])( + transformer=namedtuple("Transformer", ["text_model"])(text_model=Closed.ClipTextTransformer()) + ) + self.model = namedtuple("DiffusionModel", ["diffusion_model"])( + diffusion_model=UNetModel(**UNET_PARAMS_SD1) + ) + + # DDIM update step. + def _update(self, x, e_t, a_t, a_prev): + sqrt_one_minus_at = (1 - a_t).sqrt() + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + dir_xt = (1.0 - a_prev).sqrt() * e_t + return a_prev.sqrt() * pred_x0 + dir_xt + + def _model_output(self, uncond, cond, latent, timestep, guidance): + latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, uncond.cat(cond, dim=0)) + uncond_latent, cond_latent = latents[0:1], latents[1:2] + return uncond_latent + guidance * (cond_latent - uncond_latent) + + def step(self, uncond, cond, latent, timestep, a_t, a_prev, guidance): + e_t = self._model_output(uncond, cond, latent, timestep, guidance) + return self._update(latent, e_t, a_t, a_prev).realize() + + def decode(self, x): + x = self.first_stage_model.post_quant_conv(1 / 0.18215 * x) + x = self.first_stage_model.decoder(x) + x = (x + 1.0) / 2.0 + x = x.reshape(3, 512, 512).permute(1, 2, 0).clip(0, 1) * 255 + return x.cast(dtypes.uint8) + + def encode_prompt(self, tokenizer, prompt: str): + ids = Tensor([tokenizer.encode(prompt)]) + return self.cond_stage_model.transformer.text_model(ids).realize() + + +def run_sd15(model: StableDiffusion, prompt: str, negative_prompt: str, steps: int, guidance: float, seed: int): + """Generate a single 512x512 image. Returns a (512,512,3) uint8 tensor.""" + tokenizer = Tokenizer.ClipTokenizer() + + context = model.encode_prompt(tokenizer, prompt) + uncond = model.encode_prompt(tokenizer, negative_prompt) + + timesteps = list(range(1, 1000, 1000 // steps)) + alphas = model.alphas_cumprod[Tensor(timesteps)] + alphas_prev = Tensor([1.0]).cat(alphas[:-1]) + + if seed is not None: + Tensor.manual_seed(seed) + latent = Tensor.randn(1, 4, 64, 64) + + for index in range(len(timesteps) - 1, -1, -1): + timestep = timesteps[index] + tid = Tensor([index]) + latent = model.step( + uncond, context, latent, + Tensor([timestep]), + alphas[tid], alphas_prev[tid], + Tensor([guidance]), + ) + + return model.decode(latent).realize() diff --git a/backend/python/tinygrad/vendor/unet.py b/backend/python/tinygrad/vendor/unet.py new file mode 100644 index 000000000..22a07d170 --- /dev/null +++ b/backend/python/tinygrad/vendor/unet.py @@ -0,0 +1,267 @@ +# Vendored verbatim from tinygrad extra/models/unet.py (MIT license). +# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/unet.py +# Copyright (c) 2023- the tinygrad authors +# SPDX-License-Identifier: MIT +from tinygrad import Tensor, dtypes, nn +from tinygrad.device import is_dtype_supported +from typing import Optional, Union, List, Any, Tuple, Callable +import math + +# allow for monkeypatching +Linear, Conv2d, GroupNorm, LayerNorm = nn.Linear, nn.Conv2d, nn.GroupNorm, nn.LayerNorm +attention, gelu, mixed_precision_dtype = Tensor.scaled_dot_product_attention, Tensor.gelu, dtypes.float16 + +# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207 +def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000): + half = dim // 2 + freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp() + args = timesteps.unsqueeze(1) * freqs.unsqueeze(0) + out = Tensor.cat(args.cos(), args.sin(), dim=-1) + return out.cast(mixed_precision_dtype) if is_dtype_supported(mixed_precision_dtype) else out + +class ResBlock: + def __init__(self, channels:int, emb_channels:int, out_channels:int, num_groups:int=32): + self.in_layers = [ + GroupNorm(num_groups, channels), + Tensor.silu, + Conv2d(channels, out_channels, 3, padding=1), + ] + self.emb_layers = [ + Tensor.silu, + Linear(emb_channels, out_channels), + ] + self.out_layers = [ + GroupNorm(num_groups, out_channels), + Tensor.silu, + lambda x: x, # needed for weights loading code to work + Conv2d(out_channels, out_channels, 3, padding=1), + ] + self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else (lambda x: x) + + def __call__(self, x:Tensor, emb:Tensor) -> Tensor: + h = x.sequential(self.in_layers) + emb_out = emb.sequential(self.emb_layers) + h = h + emb_out.reshape(*emb_out.shape, 1, 1) + h = h.sequential(self.out_layers) + return self.skip_connection(x) + h + +class CrossAttention: + def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int): + self.to_q = Linear(query_dim, n_heads*d_head, bias=False) + self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False) + self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False) + self.num_heads = n_heads + self.head_size = d_head + self.attn = attention + self.to_out = [Linear(n_heads*d_head, query_dim)] + + def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor: + ctx = x if ctx is None else ctx + q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx) + q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)] + attention = self.attn(q, k, v).transpose(1,2) + h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size) + return h_.sequential(self.to_out) + +class GEGLU: + def __init__(self, dim_in:int, dim_out:int): + self.proj = Linear(dim_in, dim_out * 2) + self.gelu = gelu + self.dim_out = dim_out + + def __call__(self, x:Tensor) -> Tensor: + x, gate = self.proj(x).chunk(2, dim=-1) + return x * self.gelu(gate) + +class FeedForward: + def __init__(self, dim:int, mult:int=4): + self.net: tuple[GEGLU, Callable, nn.Linear] = ( + GEGLU(dim, dim*mult), + lambda x: x, # needed for weights loading code to work + Linear(dim*mult, dim) + ) + + def __call__(self, x:Tensor) -> Tensor: + return x.sequential(list(self.net)) + +class BasicTransformerBlock: + def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int): + self.attn1 = CrossAttention(dim, dim, n_heads, d_head) + self.ff = FeedForward(dim) + self.attn2 = CrossAttention(dim, ctx_dim, n_heads, d_head) + self.norm1 = LayerNorm(dim) + self.norm2 = LayerNorm(dim) + self.norm3 = LayerNorm(dim) + + def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor: + x = x + self.attn1(self.norm1(x)) + x = x + self.attn2(self.norm2(x), ctx=ctx) + x = x + self.ff(self.norm3(x)) + return x + +# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/attention.py#L619 +class SpatialTransformer: + def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1, + norm_eps:float=1e-5): + if isinstance(ctx_dim, int): + ctx_dim = [ctx_dim]*depth + else: + assert isinstance(ctx_dim, list) and depth == len(ctx_dim) + self.norm = GroupNorm(32, channels, eps=norm_eps) + assert channels == n_heads * d_head + self.proj_in = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1) + self.transformer_blocks = [BasicTransformerBlock(channels, ctx_dim[d], n_heads, d_head) for d in range(depth)] + self.proj_out = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1) + self.use_linear = use_linear + + def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor: + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + ops = [ (lambda z: z.reshape(b, c, h*w).permute(0,2,1)), (lambda z: self.proj_in(z)) ] + x = x.sequential(ops if self.use_linear else ops[::-1]) + for block in self.transformer_blocks: + x = block(x, ctx=ctx) + ops = [ (lambda z: self.proj_out(z)), (lambda z: z.permute(0,2,1).reshape(b, c, h, w)) ] + x = x.sequential(ops if self.use_linear else ops[::-1]) + return x + x_in + +class Downsample: + def __init__(self, channels:int): + self.op = Conv2d(channels, channels, 3, stride=2, padding=1) + + def __call__(self, x:Tensor) -> Tensor: + return self.op(x) + +class Upsample: + def __init__(self, channels:int): + self.conv = Conv2d(channels, channels, 3, padding=1) + + def __call__(self, x:Tensor) -> Tensor: + bs,c,py,px = x.shape + z = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2) + return self.conv(z) + +# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/openaimodel.py#L472 +class UNetModel: + def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], num_res_blocks:int, + channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None, + n_heads:Optional[int]=None, num_groups:int=32, st_norm_eps:float=1e-5): + self.model_ch = model_ch + self.num_res_blocks = [num_res_blocks] * len(channel_mult) + + self.attention_resolutions = attention_resolutions + self.d_head = d_head + self.n_heads = n_heads + def get_d_and_n_heads(dims:int) -> Tuple[int,int]: + if self.d_head is None: + assert self.n_heads is not None, f"d_head and n_heads cannot both be None" + return dims // self.n_heads, self.n_heads + else: + assert self.n_heads is None, f"d_head and n_heads cannot both be non-None" + return self.d_head, dims // self.d_head + + time_embed_dim = model_ch * 4 + self.time_embed = [ + Linear(model_ch, time_embed_dim), + Tensor.silu, + Linear(time_embed_dim, time_embed_dim), + ] + + if adm_in_ch is not None: + self.label_emb = [ + [ + Linear(adm_in_ch, time_embed_dim), + Tensor.silu, + Linear(time_embed_dim, time_embed_dim), + ] + ] + + self.input_blocks: List[Any] = [ + [Conv2d(in_ch, model_ch, 3, padding=1)] + ] + input_block_channels = [model_ch] + ch = model_ch + ds = 1 + for idx, mult in enumerate(channel_mult): + for _ in range(self.num_res_blocks[idx]): + layers: List[Any] = [ + ResBlock(ch, time_embed_dim, model_ch*mult, num_groups), + ] + ch = mult * model_ch + if ds in attention_resolutions: + d_head, n_heads = get_d_and_n_heads(ch) + layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps)) + + self.input_blocks.append(layers) + input_block_channels.append(ch) + + if idx != len(channel_mult) - 1: + self.input_blocks.append([ + Downsample(ch), + ]) + input_block_channels.append(ch) + ds *= 2 + + d_head, n_heads = get_d_and_n_heads(ch) + self.middle_block: List = [ + ResBlock(ch, time_embed_dim, ch, num_groups), + SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1], norm_eps=st_norm_eps), + ResBlock(ch, time_embed_dim, ch, num_groups), + ] + + self.output_blocks = [] + for idx, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[idx] + 1): + ich = input_block_channels.pop() + layers = [ + ResBlock(ch + ich, time_embed_dim, model_ch*mult, num_groups), + ] + ch = model_ch * mult + + if ds in attention_resolutions: + d_head, n_heads = get_d_and_n_heads(ch) + layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps)) + + if idx > 0 and i == self.num_res_blocks[idx]: + layers.append(Upsample(ch)) + ds //= 2 + self.output_blocks.append(layers) + + self.out = [ + GroupNorm(num_groups, ch), + Tensor.silu, + Conv2d(model_ch, out_ch, 3, padding=1), + ] + + def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Optional[Tensor]=None) -> Tensor: + t_emb = timestep_embedding(tms, self.model_ch) + emb = t_emb.sequential(self.time_embed) + + if y is not None: + assert y.shape[0] == x.shape[0] + emb = emb + y.sequential(self.label_emb[0]) + + if is_dtype_supported(mixed_precision_dtype): + emb = emb.cast(mixed_precision_dtype) + ctx = ctx.cast(mixed_precision_dtype) + x = x .cast(mixed_precision_dtype) + + def run(x:Tensor, bb) -> Tensor: + if isinstance(bb, ResBlock): x = bb(x, emb) + elif isinstance(bb, SpatialTransformer): x = bb(x, ctx) + else: x = bb(x) + return x + + saved_inputs = [] + for b in self.input_blocks: + for bb in b: + x = run(x, bb) + saved_inputs.append(x) + for bb in self.middle_block: + x = run(x, bb) + for b in self.output_blocks: + x = x.cat(saved_inputs.pop(), dim=1) + for bb in b: + x = run(x, bb) + return x.sequential(self.out) diff --git a/backend/python/tinygrad/vendor/whisper.py b/backend/python/tinygrad/vendor/whisper.py new file mode 100644 index 000000000..70aeb03a4 --- /dev/null +++ b/backend/python/tinygrad/vendor/whisper.py @@ -0,0 +1,274 @@ +# Adapted from tinygrad examples/whisper.py (MIT license). +# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/whisper.py +# Copyright (c) 2023- the tinygrad authors +# SPDX-License-Identifier: MIT +# +# Local modifications: removed the pyaudio listener / __main__ block; the rest +# is the core Whisper model + preprocessing + single-file transcription path. +from __future__ import annotations + +import base64 +import collections +import itertools +from typing import List, Literal, Optional, Union + +import numpy as np +from tinygrad import Tensor, TinyJit, Variable, dtypes, nn +from tinygrad.helpers import fetch +from tinygrad.nn.state import load_state_dict, torch_load + +from .audio_helpers import mel + + +class MultiHeadAttention: + def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self', None] = None, max_self_attn_cache_len=None): + self.n_head = n_head + self.query = nn.Linear(n_state, n_state) + self.key = nn.Linear(n_state, n_state, bias=False) + self.value = nn.Linear(n_state, n_state) + self.out = nn.Linear(n_state, n_state) + self.kv_caching = kv_caching + self.max_self_attn_cache_len = max_self_attn_cache_len + + def __call__(self, x, xa=None, mask=None, len=None): + if self.kv_caching == 'cross': + if xa is not None: + k, v = self.key(xa), self.value(xa) + if not hasattr(self, 'cache_k'): + self.cache_k, self.cache_v = k, v + else: + self.cache_k.assign(k).realize() + self.cache_v.assign(v).realize() + else: + k, v = self.cache_k, self.cache_v + else: + k, v = self.key(x), self.value(x) + if self.kv_caching == 'self': + if not hasattr(self, 'cache_k'): + self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2]) + self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2]) + k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1) + v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1) + padding = self.max_self_attn_cache_len - len - x.shape[1] + self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize() + self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize() + + q = self.query(x) + n_ctx = q.shape[1] + head_dim = q.shape[-1] // self.n_head + q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3) + k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3) + v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3) + attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx, :n_ctx] if mask is not None else None) + wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2) + return self.out(wv) + + +class ResidualAttentionBlock: + def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None): + self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len) + self.attn_ln = nn.LayerNorm(n_state) + self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None + self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None + self.mlp = [nn.Linear(n_state, n_state * 4), Tensor.gelu, nn.Linear(n_state * 4, n_state)] + self.mlp_ln = nn.LayerNorm(n_state) + + def __call__(self, x, xa=None, mask=None, len=None): + x = x + self.attn(self.attn_ln(x), mask=mask, len=len) + if self.cross_attn: + x = x + self.cross_attn(self.cross_attn_ln(x), xa) + x = x + self.mlp_ln(x).sequential(self.mlp) + return x.realize() + + +class AudioEncoder: + def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_): + self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1) + self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)] + self.ln_post = nn.LayerNorm(n_audio_state) + self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state) + self.encode = TinyJit(self.__call__) + + def __call__(self, x): + x = self.conv1(x).gelu() + x = self.conv2(x).gelu() + x = x.permute(0, 2, 1) + x = x + self.positional_embedding[:x.shape[1]] + x = x.sequential(self.blocks) + x = self.ln_post(x) + return x.realize() + + +class TextDecoder: + def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_): + self.max_tokens_to_sample = n_text_ctx // 2 + self.max_self_attn_cache_len = n_text_ctx + self.token_embedding = nn.Embedding(n_vocab, n_text_state) + self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state) + self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)] + self.ln = nn.LayerNorm(n_text_state) + self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize() + self.getjitted = collections.defaultdict(lambda: TinyJit(self.forward)) + + def __call__(self, x, pos, encoded_audio): + pos = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len - 1).bind(pos) if pos else 0 + return self.getjitted[x.shape](x, pos, encoded_audio) + + def forward(self, x, pos, encoded_audio): + seqlen = x.shape[-1] + x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos + seqlen), None)) + for block in self.blocks: + x = block(x, xa=encoded_audio, mask=self.mask, len=pos) + return self.output_tok(x) + + def output_tok(self, x): + return (self.ln(x) @ self.token_embedding.weight.T).realize() + + +class Whisper: + def __init__(self, dims, batch_size=1): + self.encoder = AudioEncoder(**dims) + self.decoder = TextDecoder(**dims) + self.is_multilingual = dims["n_vocab"] == 51865 + self.batch_size = batch_size + + +RATE = 16000 +SEGMENT_SECONDS = 30 +SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS +N_FFT = 400 +HOP_LENGTH = 160 +N_MELS = 80 +FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH + + +def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate: bool = False) -> np.ndarray: + import librosa + + def pad_or_trim(arr, target_len): + if len(arr) == target_len: + return arr + if len(arr) < target_len: + return np.pad(arr, (0, target_len - len(arr)), 'constant') + return arr[:target_len] + + max_len = SAMPLES_PER_SEGMENT if truncate else max(len(w) for w in waveforms) + if (r := max_len % SAMPLES_PER_SEGMENT) > 0: + max_len += SAMPLES_PER_SEGMENT - r + + waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms))) + if waveforms.shape[0] < batch_size: + waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0))) + + stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle) + magnitudes = np.absolute(stft[..., :-1]) ** 2 + mel_spec = mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS).numpy() @ magnitudes + log_spec = np.log10(np.clip(mel_spec, 1e-10, None)) + log_spec = np.maximum(log_spec, log_spec.max((1, 2), keepdims=True) - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec + + +LANGUAGES = { + "en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", + "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish", "pl": "polish", "it": "italian", +} + + +def get_encoding(encoding_name: str): + import tiktoken + + with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f: + ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)} + n_vocab = len(ranks) + specials = [ + "<|endoftext|>", + "<|startoftranscript|>", + *[f"<|{lang}|>" for lang in LANGUAGES.keys()], + "<|translate|>", + "<|transcribe|>", + "<|startoflm|>", + "<|startofprev|>", + "<|nospeech|>", + "<|notimestamps|>", + *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], + ] + special_tokens = dict(zip(specials, itertools.count(n_vocab))) + return tiktoken.Encoding( + name=encoding_name, + explicit_n_vocab=n_vocab + len(specials), + pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", + mergeable_ranks=ranks, + special_tokens=special_tokens, + ) + + +MODEL_URLS = { + "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", + "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", + "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", + "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", + "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", + "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", +} + + +def init_whisper(model_name: str = "base", batch_size: int = 1): + filename = fetch(MODEL_URLS[model_name]) + state = torch_load(filename) + model = Whisper(state['dims'], batch_size) + load_state_dict(model, state['model_state_dict'], strict=False) + enc = get_encoding("multilingual" if model.is_multilingual else "gpt2") + return model, enc + + +def load_file_waveform(filename: str): + import librosa + waveform, _ = librosa.load(filename, sr=RATE) + return waveform + + +def transcribe_waveform(model: Whisper, enc, waveforms, language: Optional[str] = None, truncate: bool = False) -> str: + log_spec = prep_audio(waveforms, model.batch_size, truncate) + nsample = model.decoder.max_tokens_to_sample + nctx = model.decoder.max_self_attn_cache_len + + start_tokens = [enc._special_tokens["<|startoftranscript|>"]] + if model.is_multilingual: + lang = language if (language and language in LANGUAGES) else "en" + language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index(lang) + start_tokens.append(language_token) + start_tokens.append(enc._special_tokens["<|transcribe|>"]) + start_tokens.append(enc._special_tokens["<|notimestamps|>"]) + + eot = enc._special_tokens["<|endoftext|>"] + + def inferloop(ctx, encoded_audio): + pos, next_tokens = 0, ctx + for _ in range(nsample): + next_tokens = model.decoder(Tensor(next_tokens, dtype=dtypes.int32), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1) + next_tokens[ctx[:, -1] == eot] = eot + ctx = np.concatenate((ctx, next_tokens), axis=1) + pos = ctx.shape[-1] - 1 + if (next_tokens == eot).all() or pos == nctx: + break + return ctx + + ctx = np.tile(start_tokens, (model.batch_size, 1)) + transcriptions: list[list[int]] = [[] for _ in waveforms] + + for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT): + encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT])) + ctx_arr = inferloop(np.array(ctx), encoded_audio) + for i, arr in enumerate(ctx_arr): + if i >= len(waveforms): + break + end_idxs = np.where(arr == eot)[0] + start_idx = np.where(arr == start_tokens[-1])[0][0] + 1 + end_idx = end_idxs[0] if len(end_idxs) else None + transcriptions[i].extend(arr[start_idx:end_idx]) + ctx = ctx_arr + + texts = [enc.decode([int(t) for t in toks]).strip() for toks in transcriptions] + return texts[0] if len(texts) == 1 else "\n".join(texts) diff --git a/tests/e2e-backends/backend_test.go b/tests/e2e-backends/backend_test.go index 6b1eb49ec..f36fd02ae 100644 --- a/tests/e2e-backends/backend_test.go +++ b/tests/e2e-backends/backend_test.go @@ -44,13 +44,20 @@ import ( // BACKEND_TEST_AUDIO_FILE Path to an already-available sample audio file. // BACKEND_TEST_CAPS Comma-separated list of capabilities to exercise. // Supported values: health, load, predict, stream, -// embeddings, tools, transcription. +// embeddings, tools, transcription, image. // Defaults to "health,load,predict,stream". // A backend that only does embeddings would set this to -// "health,load,embeddings"; an image/TTS backend that cannot -// be driven by a text prompt can set it to "health,load". +// "health,load,embeddings"; an image-generation backend +// that cannot be driven by a text prompt can set it to +// "health,load,image". // "tools" asks the backend to extract a tool call from the // model output into ChatDelta.tool_calls. +// "image" exercises the GenerateImage RPC and asserts a +// non-empty file is written to the requested dst path. +// BACKEND_TEST_IMAGE_PROMPT Override the positive prompt for the image spec +// (default: "a photograph of an astronaut riding a horse"). +// BACKEND_TEST_IMAGE_STEPS Override the diffusion step count for the image spec +// (default: 4 — keeps CPU-only runs under a few minutes). // BACKEND_TEST_PROMPT Override the prompt used by predict/stream specs. // BACKEND_TEST_CTX_SIZE Override the context size passed to LoadModel (default 512). // BACKEND_TEST_THREADS Override Threads passed to LoadModel (default 4). @@ -75,11 +82,14 @@ const ( capEmbeddings = "embeddings" capTools = "tools" capTranscription = "transcription" + capImage = "image" - defaultPrompt = "The capital of France is" - streamPrompt = "Once upon a time" - defaultToolPrompt = "What's the weather like in Paris, France?" - defaultToolName = "get_weather" + defaultPrompt = "The capital of France is" + streamPrompt = "Once upon a time" + defaultToolPrompt = "What's the weather like in Paris, France?" + defaultToolName = "get_weather" + defaultImagePrompt = "a photograph of an astronaut riding a horse" + defaultImageSteps = 4 ) func defaultCaps() map[string]bool { @@ -349,6 +359,40 @@ var _ = Describe("Backend container", Ordered, func() { GinkgoWriter.Printf("Embedding: %d dims\n", len(res.GetEmbeddings())) }) + It("generates an image via GenerateImage", func() { + if !caps[capImage] { + Skip("image capability not enabled") + } + + imgPrompt := os.Getenv("BACKEND_TEST_IMAGE_PROMPT") + if imgPrompt == "" { + imgPrompt = defaultImagePrompt + } + steps := envInt32("BACKEND_TEST_IMAGE_STEPS", defaultImageSteps) + + dst := filepath.Join(workDir, "generated.png") + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + defer cancel() + res, err := client.GenerateImage(ctx, &pb.GenerateImageRequest{ + PositivePrompt: imgPrompt, + NegativePrompt: "", + Width: 512, + Height: 512, + Step: steps, + Seed: 42, + Dst: dst, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(res.GetSuccess()).To(BeTrue(), "GenerateImage failed: %s", res.GetMessage()) + + info, err := os.Stat(dst) + Expect(err).NotTo(HaveOccurred(), "GenerateImage did not write a file at %s", dst) + Expect(info.Size()).To(BeNumerically(">", int64(0)), + "GenerateImage wrote an empty file at %s", dst) + GinkgoWriter.Printf("GenerateImage: wrote %s (%d bytes)\n", dst, info.Size()) + }) + It("extracts tool calls into ChatDelta", func() { if !caps[capTools] { Skip("tools capability not enabled")