Mirror of https://github.com/mudler/LocalAI.git (synced 2026-04-16 12:59:33 -04:00)
feat(backend): add tinygrad multimodal backend (experimental) (#9364)
* feat(backend): add tinygrad multimodal backend
Wire tinygrad as a new Python backend covering LLM text generation with
native tool-call extraction, embeddings, Stable Diffusion 1.x image
generation, and Whisper speech-to-text from a single self-contained
container.
Backend (`backend/python/tinygrad/`):
- `backend.py` gRPC servicer with LLM Predict/PredictStream (auto-detects
Llama / Qwen2 / Mistral architecture from `config.json`, supports
safetensors and GGUF), Embedding via mean-pooled last hidden state,
GenerateImage via the vendored SD1.x pipeline, AudioTranscription +
AudioTranscriptionStream via the vendored Whisper inference loop, plus
Tokenize / ModelMetadata / Status / Free.
- Vendored upstream model code under `vendor/` (MIT, headers preserved):
llama.py with an added `qkv_bias` flag for Qwen2-family bias support
and an `embed()` method that returns the last hidden state, plus
clip.py, unet.py, stable_diffusion.py (trimmed to drop the MLPerf
training branch that pulls `mlperf.initializers`), audio_helpers.py
and whisper.py (trimmed to drop the pyaudio listener).
- Pluggable tool-call parsers under `tool_parsers/`: hermes (Qwen2.5 /
Hermes), llama3_json (Llama 3.1+), qwen3_xml (Qwen 3), mistral
(Mistral / Mixtral). Auto-selected from model architecture or `Options`.
- `install.sh` pins Python 3.11.14 (tinygrad >=0.12 needs >=3.11; the
default portable python is 3.10).
- `package.sh` bundles libLLVM.so.1 + libedit/libtinfo/libgomp/libsndfile
into the scratch image. `run.sh` sets `CPU_LLVM=1` and `LLVM_PATH` so
tinygrad's CPU device uses the in-process libLLVM JIT instead of
shelling out to the missing `clang` binary.
- Local unit tests for Health and the four parsers in `test.py`.
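The mean-pooled, L2-normalized embedding path described above can be sketched in a few lines (a minimal illustration with plain Python lists standing in for tinygrad tensors; the `embed` helper and its shapes are hypothetical, not the backend's actual API):

```python
import math

def embed(hidden: list[list[float]]) -> list[float]:
    """Mean-pool the per-token vectors of the last hidden state,
    then L2-normalize the pooled vector."""
    dim = len(hidden[0])
    pooled = [sum(tok[d] for tok in hidden) / len(hidden) for d in range(dim)]
    norm = math.sqrt(sum(x * x for x in pooled)) or 1.0
    return [x / norm for x in pooled]

# 3 tokens x 2 dims of a made-up last hidden state
vec = embed([[1.0, 0.0], [3.0, 0.0], [2.0, 0.0]])
```

The pooled vector here is [2.0, 0.0] with norm 2, so the normalized result is a unit vector.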
Build wiring:
- Root `Makefile`: `.NOTPARALLEL`, `prepare-test-extra`, `test-extra`,
`BACKEND_TINYGRAD = tinygrad|python|.|false|true`,
docker-build-target eval, and `docker-build-backends` aggregator.
- `.github/workflows/backend.yml`: cpu / cuda12 / cuda13 build matrix
entries (mirrors the transformers backend placement).
- `backend/index.yaml`: `&tinygrad` meta + cpu/cuda12/cuda13 image
entries (latest + development).
E2E test wiring:
- `tests/e2e-backends/backend_test.go` gains an `image` capability that
exercises GenerateImage and asserts a non-empty PNG is written to
`dst`. New `BACKEND_TEST_IMAGE_PROMPT` / `BACKEND_TEST_IMAGE_STEPS`
knobs.
- Five new make targets next to `test-extra-backend-vllm`:
- `test-extra-backend-tinygrad` — Qwen2.5-0.5B-Instruct + hermes,
mirrors the vllm target 1:1 (5/9 specs in ~57s).
- `test-extra-backend-tinygrad-embeddings` — same model, embeddings
via LLM hidden state (3/9 in ~10s).
- `test-extra-backend-tinygrad-sd` — stable-diffusion-v1-5 mirror,
health/load/image (3/9 in ~10min, 4 diffusion steps on CPU).
- `test-extra-backend-tinygrad-whisper` — openai/whisper-tiny.en
against jfk.wav from whisper.cpp samples (4/9 in ~49s).
- `test-extra-backend-tinygrad-all` aggregate.
All four targets land green on the first MVP pass: 15 specs total, 0
failures across LLM+tools, embeddings, image generation, and speech
transcription.
* refactor(tinygrad): collapse to a single backend image
tinygrad generates its own GPU kernels (PTX renderer for CUDA, the
autogen ctypes wrappers for HIP / Metal / WebGPU) and never links
against cuDNN, cuBLAS, or any toolkit-version-tied library. The only
runtime dependency that varies across hosts is the driver's libcuda.so.1
/ libamdhip64.so, which are injected into the container at run time by
the nvidia-container / rocm runtimes. So unlike torch- or vLLM-based
backends, there is no reason to ship per-CUDA-version images.
- Drop the cuda12-tinygrad and cuda13-tinygrad build-matrix entries
from .github/workflows/backend.yml. The sole remaining entry is
renamed to -tinygrad (from -cpu-tinygrad) since it is no longer
CPU-only.
- Collapse backend/index.yaml to a single meta + development pair.
The meta anchor carries the latest uri directly; the development
entry points at the master tag.
- run.sh picks the tinygrad device at launch time by probing
/usr/lib/... for libcuda.so.1 / libamdhip64.so. When libcuda is
visible we set CUDA=1 + CUDA_PTX=1 so tinygrad uses its own PTX
renderer (avoids any nvrtc/toolkit dependency); otherwise we fall
back to HIP or CLANG. CPU_LLVM=1 + LLVM_PATH keep the in-process
libLLVM JIT for the CLANG path.
- backend.py's _select_tinygrad_device() is trimmed to a CLANG-only
fallback since production device selection happens in run.sh.
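The launch-time probing described above can be sketched as a small shell function (a simplified sketch: the probe directory, function name, and exact env-var strings are assumptions; the real run.sh path list and wiring may differ):

```shell
# Sketch of run.sh's device probe: prefer the NVIDIA driver, then HIP,
# else fall back to the CPU path (CLANG + in-process libLLVM JIT).
pick_tinygrad_env() {
  libdir="$1"  # e.g. /usr/lib/x86_64-linux-gnu (hypothetical probe dir)
  if [ -e "$libdir/libcuda.so.1" ]; then
    echo "CUDA=1 CUDA_PTX=1"   # tinygrad's own PTX renderer, no nvrtc/toolkit
  elif [ -e "$libdir/libamdhip64.so" ]; then
    echo "HIP=1"
  else
    echo "CLANG=1 CPU_LLVM=1"  # in-process libLLVM instead of shelling to clang
  fi
}
```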
Re-ran test-extra-backend-tinygrad after the change:
Ran 5 of 9 Specs in 56.541 seconds — 5 Passed, 0 Failed
Committed by GitHub
parent 8487058673
commit 6f0051301b

19 .github/workflows/backend.yml (vendored)
@@ -105,6 +105,25 @@ jobs:
          dockerfile: "./backend/Dockerfile.python"
          context: "./"
          ubuntu-version: '2404'
+        # tinygrad ships a single image — its CPU device uses bundled
+        # libLLVM, and its CUDA / HIP / Metal devices dlopen the host
+        # driver libraries at runtime via tinygrad's ctypes autogen
+        # wrappers. There is no toolkit-version split because tinygrad
+        # generates kernels itself (PTX renderer for CUDA) and never
+        # links against cuDNN/cuBLAS/torch.
+        - build-type: ''
+          cuda-major-version: ""
+          cuda-minor-version: ""
+          platforms: 'linux/amd64'
+          tag-latest: 'auto'
+          tag-suffix: '-tinygrad'
+          runs-on: 'ubuntu-latest'
+          base-image: "ubuntu:24.04"
+          skip-drivers: 'true'
+          backend: "tinygrad"
+          dockerfile: "./backend/Dockerfile.python"
+          context: "./"
+          ubuntu-version: '2404'
        - build-type: ''
          cuda-major-version: ""
          cuda-minor-version: ""
58 Makefile
@@ -1,5 +1,5 @@
# Disable parallel execution for backend builds
-.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp
+.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/tinygrad

GOCMD=go
GOTEST=$(GOCMD) test
@@ -432,6 +432,7 @@ prepare-test-extra: protogen-python
	$(MAKE) -C backend/python/whisperx
	$(MAKE) -C backend/python/ace-step
	$(MAKE) -C backend/python/trl
+	$(MAKE) -C backend/python/tinygrad
	$(MAKE) -C backend/rust/kokoros kokoros-grpc

test-extra: prepare-test-extra
@@ -454,6 +455,7 @@ test-extra: prepare-test-extra
	$(MAKE) -C backend/python/whisperx test
	$(MAKE) -C backend/python/ace-step test
	$(MAKE) -C backend/python/trl test
+	$(MAKE) -C backend/python/tinygrad test
	$(MAKE) -C backend/rust/kokoros test

##
@@ -550,6 +552,56 @@ test-extra-backend-vllm: docker-build-vllm
	BACKEND_TEST_OPTIONS=tool_parser:hermes \
	$(MAKE) test-extra-backend

+## tinygrad mirrors the vllm target (same model, same caps, same parser) so
+## the two backends are directly comparable. The LLM path covers Predict,
+## streaming and native tool-call extraction. Companion targets below cover
+## embeddings, Stable Diffusion and Whisper — run them individually or via
+## the `test-extra-backend-tinygrad-all` aggregate.
+test-extra-backend-tinygrad: docker-build-tinygrad
+	BACKEND_IMAGE=local-ai-backend:tinygrad \
+	BACKEND_TEST_MODEL_NAME=Qwen/Qwen2.5-0.5B-Instruct \
+	BACKEND_TEST_CAPS=health,load,predict,stream,tools \
+	BACKEND_TEST_OPTIONS=tool_parser:hermes \
+	$(MAKE) test-extra-backend

+## tinygrad — embeddings via LLM last-hidden-state pooling. Reuses the same
+## Qwen2.5-0.5B-Instruct as the chat target so we don't need a separate BERT
+## vendor; the Embedding RPC mean-pools and L2-normalizes the last-layer
+## hidden state.
+test-extra-backend-tinygrad-embeddings: docker-build-tinygrad
+	BACKEND_IMAGE=local-ai-backend:tinygrad \
+	BACKEND_TEST_MODEL_NAME=Qwen/Qwen2.5-0.5B-Instruct \
+	BACKEND_TEST_CAPS=health,load,embeddings \
+	$(MAKE) test-extra-backend

+## tinygrad — Stable Diffusion 1.5. The original CompVis/runwayml repos have
+## been gated, so we use the community-maintained mirror at
+## stable-diffusion-v1-5/stable-diffusion-v1-5 with the EMA-only pruned
+## checkpoint (~4.3GB). Step count is kept low (4) so a CPU-only run finishes
+## in a few minutes; bump BACKEND_TEST_IMAGE_STEPS for higher quality.
+test-extra-backend-tinygrad-sd: docker-build-tinygrad
+	BACKEND_IMAGE=local-ai-backend:tinygrad \
+	BACKEND_TEST_MODEL_NAME=stable-diffusion-v1-5/stable-diffusion-v1-5 \
+	BACKEND_TEST_CAPS=health,load,image \
+	$(MAKE) test-extra-backend

+## tinygrad — Whisper. Loads OpenAI's tiny.en checkpoint (smallest at ~75MB)
+## from the original azure CDN through tinygrad's `fetch` helper, and
+## transcribes the canonical jfk.wav fixture from whisper.cpp's CI samples.
+## Exercises both AudioTranscription and AudioTranscriptionStream.
+test-extra-backend-tinygrad-whisper: docker-build-tinygrad
+	BACKEND_IMAGE=local-ai-backend:tinygrad \
+	BACKEND_TEST_MODEL_NAME=openai/whisper-tiny.en \
+	BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \
+	BACKEND_TEST_CAPS=health,load,transcription \
+	$(MAKE) test-extra-backend

+test-extra-backend-tinygrad-all: \
+	test-extra-backend-tinygrad \
+	test-extra-backend-tinygrad-embeddings \
+	test-extra-backend-tinygrad-sd \
+	test-extra-backend-tinygrad-whisper

## mlx is Apple-Silicon-first — the MLX backend auto-detects the right tool
## parser from the chat template, so no tool_parser: option is needed (it
## would be ignored at runtime). Run this on macOS / arm64 with Metal; the
@@ -707,6 +759,7 @@ BACKEND_MLX_VLM = mlx-vlm|python|.|false|true
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
BACKEND_TRL = trl|python|.|false|true
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
+BACKEND_TINYGRAD = tinygrad|python|.|false|true

# Rust backends
BACKEND_KOKOROS = kokoros|rust|.|false|true
@@ -778,6 +831,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_MLX_VLM)))
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
+$(eval $(call generate-docker-build-target,$(BACKEND_TINYGRAD)))
$(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS)))
$(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))

@@ -785,7 +839,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP)))
docker-save-%: backend-images
	docker save local-ai-backend:$* -o backend-images/$*.tar

-docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp
+docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp

########################################################
### Mock Backend for E2E Tests
backend/index.yaml
@@ -361,6 +361,34 @@
      intel: "intel-rerankers"
      amd: "rocm-rerankers"
      metal: "metal-rerankers"
+  - &tinygrad
+    name: "tinygrad"
+    alias: "tinygrad"
+    license: MIT
+    description: |
+      tinygrad is a minimalist deep-learning framework with zero runtime
+      dependencies that targets CUDA, ROCm, Metal, WebGPU and CPU (CLANG).
+      The LocalAI tinygrad backend exposes a single multimodal runtime that
+      covers LLM text generation (Llama / Qwen / Mistral via safetensors or
+      GGUF) with native tool-call extraction, BERT-family embeddings,
+      Stable Diffusion 1.x / 2 / XL image generation, and Whisper speech-to-text.
+
+      Single image: tinygrad generates its own GPU kernels and dlopens the
+      host driver libraries at runtime, so there is no per-toolkit build
+      split. The same image runs CPU-only or accelerates against
+      CUDA / ROCm / Metal when the host driver is visible.
+    urls:
+      - https://github.com/tinygrad/tinygrad
+    uri: "quay.io/go-skynet/local-ai-backends:latest-tinygrad"
+    mirrors:
+      - localai/localai-backends:latest-tinygrad
+    tags:
+      - text-to-text
+      - LLM
+      - embeddings
+      - image-generation
+      - transcription
+      - multimodal
  - &transformers
    name: "transformers"
    icon: https://avatars.githubusercontent.com/u/25720743?s=200&v=4
@@ -1993,6 +2021,15 @@
    uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-rerankers"
    mirrors:
      - localai/localai-backends:master-metal-darwin-arm64-rerankers
+  ## tinygrad
+  ## Single image — the meta anchor above carries the latest uri directly
+  ## since there is only one variant. The development entry below points at
+  ## the master tag.
+  - !!merge <<: *tinygrad
+    name: "tinygrad-development"
+    uri: "quay.io/go-skynet/local-ai-backends:master-tinygrad"
+    mirrors:
+      - localai/localai-backends:master-tinygrad
  ## Transformers
  - !!merge <<: *transformers
    name: "transformers-development"
25 backend/python/tinygrad/Makefile (Normal file)
@@ -0,0 +1,25 @@
.DEFAULT_GOAL := install

.PHONY: install
install:
	bash install.sh

.PHONY: run
run: install
	@echo "Running tinygrad..."
	bash run.sh
	@echo "tinygrad run."

.PHONY: test
test: install
	@echo "Testing tinygrad..."
	bash test.sh
	@echo "tinygrad tested."

.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__
750 backend/python/tinygrad/backend.py (Normal file)
@@ -0,0 +1,750 @@
#!/usr/bin/env python3
"""
LocalAI gRPC backend for tinygrad.

Multimodal scope:
- LLM text generation (Llama 3 / Qwen 2.5 / Mistral via HF safetensors + GGUF)
- Native tool-call extraction via pluggable parsers (hermes, llama3_json, ...)
- Embeddings (mean-pooled last hidden state), Stable Diffusion 1.x image
  generation, and Whisper speech-to-text.

The heavy imports (tinygrad, tokenizers, vendor.llama) are deferred until
`LoadModel`, because tinygrad binds its compute device at import time from
env vars. In production run.sh exports the right device flag before launch;
`_select_tinygrad_device()` is only a CLANG fallback for direct invocations.
"""
from __future__ import annotations

import argparse
import asyncio
import json
import os
import signal
import sys
import time
from concurrent import futures
from pathlib import Path
from typing import Any, Optional

import grpc

import backend_pb2
import backend_pb2_grpc

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors  # noqa: E402

from tool_parsers import resolve_parser  # noqa: E402
from tool_parsers.base import ToolCall  # noqa: E402

MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

# ---------------------------------------------------------------------------
# Device selection — must run BEFORE `import tinygrad` anywhere.
#
# In production this is set by run.sh based on which driver libraries the
# host has injected into the container (libcuda.so.1 → CUDA, libamdhip64
# → HIP, otherwise CLANG). This helper is only a fallback for direct
# invocations like the unit tests.
# ---------------------------------------------------------------------------

def _select_tinygrad_device() -> None:
    if any(os.environ.get(k) == "1" for k in ("CUDA", "HIP", "METAL", "CLANG", "AMD", "NV")):
        return
    os.environ["CLANG"] = "1"

# ---------------------------------------------------------------------------
# Model asset discovery
# ---------------------------------------------------------------------------

def _resolve_model_assets(model_ref: str) -> Path:
    """
    Accept either a local path or a HuggingFace repo id (e.g.
    "Qwen/Qwen2.5-0.5B-Instruct") and return the local directory / file.
    HF ids are materialized via `huggingface_hub.snapshot_download`.
    """
    p = Path(model_ref)
    if p.exists():
        return p
    if "/" in model_ref and not model_ref.startswith(("/", ".")):
        from huggingface_hub import snapshot_download
        local = snapshot_download(
            repo_id=model_ref,
            allow_patterns=[
                "config.json",
                "tokenizer.json",
                "tokenizer_config.json",
                "special_tokens_map.json",
                "generation_config.json",
                "*.safetensors",
                "*.safetensors.index.json",
            ],
        )
        return Path(local)
    raise FileNotFoundError(f"Model not found: {model_ref}")


def _load_llm_weights(model_dir: Path):
    """Load HF safetensors or GGUF weights from a directory or a single file."""
    from tinygrad import Device, Tensor, dtypes
    from tinygrad.nn.state import safe_load, gguf_load

    if model_dir.is_file():
        if str(model_dir).endswith(".gguf"):
            gguf_tensor = Tensor.empty(os.stat(model_dir).st_size, dtype=dtypes.uint8,
                                       device=f"disk:{model_dir}").to(Device.DEFAULT)
            return gguf_load(gguf_tensor)[1], "gguf"
        return safe_load(str(model_dir)), "safetensors"

    index = model_dir / "model.safetensors.index.json"
    if index.exists():
        with open(index) as fp:
            weight_map = json.load(fp)["weight_map"]
        shards: dict[str, Any] = {}
        for shard_name in set(weight_map.values()):
            shards[shard_name] = safe_load(str(model_dir / shard_name))
        merged = {k: shards[n][k] for k, n in weight_map.items()}
        return merged, "safetensors"

    single = model_dir / "model.safetensors"
    if single.exists():
        return safe_load(str(single)), "safetensors"

    ggufs = sorted(model_dir.glob("*.gguf"))
    if ggufs:
        gguf_tensor = Tensor.empty(os.stat(ggufs[0]).st_size, dtype=dtypes.uint8,
                                   device=f"disk:{ggufs[0]}").to(Device.DEFAULT)
        return gguf_load(gguf_tensor)[1], "gguf"

    raise FileNotFoundError(f"No safetensors or gguf weights found under {model_dir}")

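The sharded-safetensors merge above can be exercised standalone (a sketch with plain dicts standing in for `safe_load` results; the tensor and shard names are made up for illustration):

```python
import json

# A minimal model.safetensors.index.json-style weight map.
index = json.loads("""{"weight_map": {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors"}}""")
weight_map = index["weight_map"]

# Pretend each shard was already loaded with safe_load(): name -> tensor.
shards = {
    "model-00001-of-00002.safetensors": {"model.embed_tokens.weight": "E"},
    "model-00002-of-00002.safetensors": {"lm_head.weight": "H"},
}

# Same comprehension as the backend: route every tensor name through the
# shard that the index says contains it.
merged = {k: shards[n][k] for k, n in weight_map.items()}
```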
def _infer_qkv_bias(config: dict) -> bool:
    """Qwen2-family uses Q/K/V projection bias; Llama/Mistral do not."""
    architectures = [a.lower() for a in config.get("architectures", [])]
    if any("qwen" in a for a in architectures):
        return True
    return bool(config.get("attention_bias", False)) or bool(config.get("qkv_bias", False))


def _auto_tool_parser(model_ref: Optional[str], config: dict) -> Optional[str]:
    """Pick a tool parser automatically from model family heuristics.

    Order of precedence: architecture name from config.json, then model ref
    string. Returns None to fall through to the passthrough parser.
    """
    arches = " ".join(a.lower() for a in config.get("architectures", []))
    ref = (model_ref or "").lower()
    blob = f"{arches} {ref}"

    if "qwen3" in blob:
        return "qwen3_xml"
    if "hermes" in blob or "qwen2" in blob or "qwen" in blob:
        return "hermes"
    if "llama-3" in blob or "llama_3" in blob or "llama3" in blob:
        return "llama3_json"
    if "mistral" in blob or "mixtral" in blob:
        return "mistral"
    return None

# ---------------------------------------------------------------------------
# Servicer
# ---------------------------------------------------------------------------

class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer for the tinygrad backend."""

    def __init__(self) -> None:
        self._reset_state()

    def _reset_state(self) -> None:
        self.model_ref: Optional[str] = None
        self.model_type: str = "llm"
        self.options: dict[str, str] = {}
        # LLM state
        self.llm_model = None
        self.llm_config: dict = {}
        self.llm_tokenizer = None
        self.llm_eos_ids: list[int] = []
        self.chat_template: Optional[str] = None
        self.tool_parser = resolve_parser(None)
        self.max_context = 4096
        # Stable Diffusion state
        self.sd_model = None
        # Whisper state
        self.whisper_model = None
        self.whisper_tokenizer = None

    # --------------------- helpers --------------------------------------

    @staticmethod
    def _parse_options(options_list) -> dict[str, str]:
        opts: dict[str, str] = {}
        for opt in options_list:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            opts[key.strip()] = value.strip()
        return opts

    @staticmethod
    def _detect_model_type(model_ref: str, explicit: Optional[str]) -> str:
        if explicit:
            return explicit
        name = (model_ref or "").lower()
        if "whisper" in name:
            return "whisper"
        if "sdxl" in name:
            return "sdxl"
        if "sd-v1" in name or "v1-5" in name or "stable-diffusion" in name:
            return "sd15"
        if any(tag in name for tag in ("bge", "e5", "minilm", "bert")):
            return "bert"
        return "llm"

    def _messages_to_dicts(self, messages) -> list[dict]:
        result = []
        for msg in messages:
            d: dict = {"role": msg.role, "content": msg.content or ""}
            if msg.name:
                d["name"] = msg.name
            if msg.tool_call_id:
                d["tool_call_id"] = msg.tool_call_id
            if msg.reasoning_content:
                d["reasoning_content"] = msg.reasoning_content
            if msg.tool_calls:
                try:
                    d["tool_calls"] = json.loads(msg.tool_calls)
                except json.JSONDecodeError:
                    pass
            result.append(d)
        return result

    def _render_prompt(self, request) -> str:
        """Render messages + tools into the model's chat template, or fall
        back to the raw Prompt field for models without a template."""
        if not request.Messages and request.Prompt:
            return request.Prompt

        if not self.chat_template:
            # No template known — concatenate role/content lines.
            lines = []
            for msg in request.Messages:
                lines.append(f"{msg.role}: {msg.content or ''}")
            return "\n".join(lines) + "\nassistant:"

        from jinja2 import Environment

        env = Environment(trim_blocks=True, lstrip_blocks=True)
        template = env.from_string(self.chat_template)

        tools = None
        if request.Tools:
            try:
                tools = json.loads(request.Tools)
            except json.JSONDecodeError:
                tools = None

        return template.render(
            messages=self._messages_to_dicts(request.Messages),
            tools=tools,
            add_generation_prompt=True,
        )

    # --------------------- LLM path -------------------------------------

    def _load_llm(self, model_dir: Path) -> None:
        config_path = model_dir / "config.json" if model_dir.is_dir() else None
        if config_path is None or not config_path.exists():
            raise FileNotFoundError(f"config.json not found under {model_dir}")
        with open(config_path) as fp:
            config = json.load(fp)
        self.llm_config = config

        from tinygrad import nn

        from vendor.llama import Transformer, convert_from_huggingface, convert_from_gguf, fix_bf16

        n_layers = config["num_hidden_layers"]
        n_heads = config["num_attention_heads"]
        n_kv_heads = config.get("num_key_value_heads", n_heads)
        dim = config["hidden_size"]
        hidden_dim = config["intermediate_size"]
        vocab_size = config["vocab_size"]
        norm_eps = config.get("rms_norm_eps", 1e-5)
        rope_theta = config.get("rope_theta", 10000)
        self.max_context = min(config.get("max_position_embeddings", 4096), 8192)
        qkv_bias = _infer_qkv_bias(config)

        weights, weight_format = _load_llm_weights(model_dir)

        model = Transformer(
            dim=dim,
            hidden_dim=hidden_dim,
            n_heads=n_heads,
            n_layers=n_layers,
            norm_eps=norm_eps,
            vocab_size=vocab_size,
            linear=nn.Linear,
            embedding=nn.Embedding,
            n_kv_heads=n_kv_heads,
            rope_theta=rope_theta,
            max_context=self.max_context,
            jit=True,
            qkv_bias=qkv_bias,
        )

        if weight_format == "safetensors":
            weights = convert_from_huggingface(weights, n_layers, n_heads, n_kv_heads)
        else:
            weights = convert_from_gguf(weights, n_layers)
        weights = fix_bf16(weights)

        from tinygrad.nn.state import load_state_dict
        load_state_dict(model, weights, strict=False, consume=True)
        self.llm_model = model

        # Tokenizer
        tokenizer_json = model_dir / "tokenizer.json"
        if not tokenizer_json.exists():
            raise FileNotFoundError(f"tokenizer.json not found under {model_dir}")
        from tokenizers import Tokenizer as HFTokenizer
        self.llm_tokenizer = HFTokenizer.from_file(str(tokenizer_json))

        # Chat template + EOS ids (from tokenizer_config.json + generation_config.json)
        tok_cfg_path = model_dir / "tokenizer_config.json"
        if tok_cfg_path.exists():
            with open(tok_cfg_path) as fp:
                tok_cfg = json.load(fp)
            self.chat_template = tok_cfg.get("chat_template")

        self.llm_eos_ids = []
        gen_cfg_path = model_dir / "generation_config.json"
        if gen_cfg_path.exists():
            with open(gen_cfg_path) as fp:
                gen_cfg = json.load(fp)
            eos = gen_cfg.get("eos_token_id")
            if isinstance(eos, list):
                self.llm_eos_ids.extend(int(x) for x in eos)
            elif isinstance(eos, int):
                self.llm_eos_ids.append(eos)
        if not self.llm_eos_ids:
            eos = config.get("eos_token_id")
            if isinstance(eos, list):
                self.llm_eos_ids.extend(int(x) for x in eos)
            elif isinstance(eos, int):
                self.llm_eos_ids.append(eos)

        # Auto-pick tool parser from options or model family.
        parser_name = self.options.get("tool_parser") or _auto_tool_parser(self.model_ref, config)
        self.tool_parser = resolve_parser(parser_name)

    # --------------------- Stable Diffusion path ------------------------

    def _load_sd(self, model_ref: str) -> None:
        """Load a Stable Diffusion 1.x checkpoint (CompVis `.ckpt` format)."""
        from huggingface_hub import hf_hub_download
        from tinygrad.nn.state import load_state_dict, torch_load

        from vendor.stable_diffusion import StableDiffusion

        ckpt_path = Path(model_ref)
        if not ckpt_path.exists():
            # Accept an HF repo id — fetch the canonical v1-5-pruned-emaonly.ckpt
            # from the requested repo. Common case is runwayml/stable-diffusion-v1-5.
            repo_id = model_ref if "/" in model_ref else "runwayml/stable-diffusion-v1-5"
            ckpt_file = self.options.get("sd_ckpt_filename", "v1-5-pruned-emaonly.ckpt")
            ckpt_path = Path(hf_hub_download(repo_id=repo_id, filename=ckpt_file))

        model = StableDiffusion()
        state_dict = torch_load(str(ckpt_path))
        if isinstance(state_dict, dict) and "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        load_state_dict(model, state_dict, strict=False, verbose=False, realize=False)
        self.sd_model = model

    # --------------------- Whisper path ---------------------------------

    def _load_whisper(self, model_ref: str) -> None:
        """Load a Whisper checkpoint (OpenAI `.pt` format).

        Accepts a model-size alias (tiny / tiny.en / base / base.en / small /
        small.en) OR an explicit `.pt` file path OR the HF repo id naming
        convention `openai/whisper-*` (mapped to the matching OpenAI alias).
        """
        from vendor.whisper import init_whisper, MODEL_URLS

        alias = model_ref
        if "/" in alias and alias.startswith("openai/whisper-"):
            alias = alias.removeprefix("openai/whisper-")
        if alias not in MODEL_URLS:
            # Explicit path to a .pt checkpoint — fall back to size heuristic
            # via filename.
            basename = Path(alias).name.lower()
            for name in MODEL_URLS:
                if name in basename:
                    alias = name
                    break
            else:
                raise ValueError(
                    f"Unknown Whisper model_ref={model_ref!r}; expected one of {list(MODEL_URLS)} "
                    f"or an openai/whisper-* HF id"
                )

        model, enc = init_whisper(alias, batch_size=1)
        self.whisper_model = model
        self.whisper_tokenizer = enc

# --------------------- LLM generation -------------------------------
|
||||
|
||||
def _generate_tokens(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int):
|
||||
"""Synchronous generator yielding (token_id, token_text) pairs."""
|
||||
from tinygrad import Tensor
|
||||
|
||||
encoded = self.llm_tokenizer.encode(prompt)
|
||||
ids = encoded.ids
|
||||
if not ids:
|
||||
return
|
||||
|
||||
# Prefill: run one forward pass over the full prompt, sampling from the last token.
|
||||
prompt_tensor = Tensor([ids])
|
||||
next_tok = int(self.llm_model(prompt_tensor, 0, temperature=temperature, top_k=top_k, top_p=top_p).item())
|
||||
pos = len(ids)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
if next_tok in self.llm_eos_ids:
|
||||
break
|
||||
text = self.llm_tokenizer.decode([next_tok])
|
||||
yield next_tok, text
|
||||
next_tok = int(self.llm_model(Tensor([[next_tok]]), pos, temperature=temperature,
|
||||
top_k=top_k, top_p=top_p).item())
|
||||
pos += 1
|
||||
|
||||
# --------------------- gRPC methods ---------------------------------
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
async def LoadModel(self, request, context):
|
||||
try:
|
||||
_select_tinygrad_device()
|
||||
self._reset_state()
|
||||
self.options = self._parse_options(list(request.Options))
|
||||
self.model_ref = request.ModelFile or request.Model
|
||||
self.model_type = self._detect_model_type(self.model_ref, self.options.get("model_type"))
|
||||
|
||||
if self.model_type in ("sd15", "sd", "stable-diffusion"):
|
||||
self._load_sd(self.model_ref)
|
||||
return backend_pb2.Result(
|
||||
success=True, message="tinygrad Stable Diffusion 1.x loaded",
|
||||
)
|
||||
|
||||
if self.model_type == "whisper":
|
||||
self._load_whisper(self.model_ref)
|
||||
return backend_pb2.Result(
|
||||
success=True, message="tinygrad Whisper loaded",
|
||||
)
|
||||
|
||||
if self.model_type != "llm":
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message=f"tinygrad: model_type={self.model_type} not yet implemented",
|
||||
)
|
||||
|
||||
model_path = _resolve_model_assets(self.model_ref)
|
||||
if model_path.is_file():
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message=("tinygrad currently requires an HF model directory with config.json + "
|
||||
"tokenizer.json; single-file GGUF support lands with gallery metadata wiring."),
|
||||
)
|
||||
self._load_llm(model_path)
|
||||
|
||||
return backend_pb2.Result(
|
||||
success=True,
|
||||
message=f"tinygrad LLM loaded (arch={self.llm_config.get('architectures', ['?'])[0]}, "
|
||||
f"parser={self.tool_parser.name})",
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return backend_pb2.Result(success=False, message=f"LoadModel failed: {exc}")
|
||||
|
||||
async def Predict(self, request, context):
|
||||
if self.llm_model is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("LLM not loaded")
|
||||
return backend_pb2.Reply()
|
||||
|
||||
try:
|
||||
prompt = self._render_prompt(request)
|
||||
max_new = request.Tokens if request.Tokens > 0 else 256
|
||||
temperature = request.Temperature if request.Temperature > 0 else 0.7
|
||||
top_p = request.TopP if request.TopP > 0 else 0.95
|
||||
top_k = request.TopK if request.TopK > 0 else 0
|
||||
|
||||
t0 = time.monotonic()
|
||||
pieces: list[str] = []
|
||||
ntok = 0
|
||||
for _, text in self._generate_tokens(prompt, max_new, temperature, top_p, top_k):
|
||||
pieces.append(text)
|
||||
ntok += 1
|
||||
elapsed = time.monotonic() - t0
|
||||
|
||||
full = "".join(pieces)
|
||||
from tool_parsers.hermes import HermesToolParser
|
||||
if isinstance(self.tool_parser, HermesToolParser):
|
||||
result = self.tool_parser.parse_full(full)
|
||||
content, calls, reasoning = result.content, result.tool_calls, result.reasoning
|
||||
else:
|
||||
content, calls = self.tool_parser.parse(full)
|
||||
reasoning = ""
|
||||
|
||||
delta = backend_pb2.ChatDelta(
|
||||
content=content,
|
||||
reasoning_content=reasoning,
|
||||
tool_calls=[
|
||||
backend_pb2.ToolCallDelta(index=c.index, id=c.id, name=c.name, arguments=c.arguments)
|
||||
for c in calls
|
||||
],
|
||||
)
|
||||
return backend_pb2.Reply(
|
||||
message=content.encode("utf-8"),
|
||||
tokens=ntok,
|
||||
timing_token_generation=elapsed,
|
||||
chat_deltas=[delta],
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"Predict failed: {exc}")
|
||||
return backend_pb2.Reply()
|
||||
|
||||
async def PredictStream(self, request, context):
|
||||
if self.llm_model is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("LLM not loaded")
|
||||
return
|
||||
|
||||
try:
|
||||
prompt = self._render_prompt(request)
|
||||
max_new = request.Tokens if request.Tokens > 0 else 256
|
||||
temperature = request.Temperature if request.Temperature > 0 else 0.7
|
||||
top_p = request.TopP if request.TopP > 0 else 0.95
|
||||
top_k = request.TopK if request.TopK > 0 else 0
|
||||
|
||||
buffer = ""
|
||||
for _, text in self._generate_tokens(prompt, max_new, temperature, top_p, top_k):
|
||||
buffer += text
|
||||
yield backend_pb2.Reply(
|
||||
message=text.encode("utf-8"),
|
||||
chat_deltas=[backend_pb2.ChatDelta(content=text)],
|
||||
)
|
||||
|
||||
# Final emission carries the extracted tool calls (vLLM semantics).
|
||||
from tool_parsers.hermes import HermesToolParser
|
||||
if isinstance(self.tool_parser, HermesToolParser):
|
||||
result = self.tool_parser.parse_full(buffer)
|
||||
calls = result.tool_calls
|
||||
reasoning = result.reasoning
|
||||
else:
|
||||
_, calls = self.tool_parser.parse(buffer)
|
||||
reasoning = ""
|
||||
|
||||
if calls or reasoning:
|
||||
yield backend_pb2.Reply(
|
||||
chat_deltas=[backend_pb2.ChatDelta(
|
||||
reasoning_content=reasoning,
|
||||
tool_calls=[
|
||||
backend_pb2.ToolCallDelta(index=c.index, id=c.id, name=c.name, arguments=c.arguments)
|
||||
for c in calls
|
||||
],
|
||||
)],
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"PredictStream failed: {exc}")
|
||||
|
||||
async def Embedding(self, request, context):
|
||||
if self.llm_model is None or self.llm_tokenizer is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("No model loaded")
|
||||
return backend_pb2.EmbeddingResult()
|
||||
|
||||
try:
|
||||
text = request.Embeddings
|
||||
if not text:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("Embeddings field is empty")
|
||||
return backend_pb2.EmbeddingResult()
|
||||
|
||||
from tinygrad import Tensor, dtypes
|
||||
|
||||
ids = self.llm_tokenizer.encode(text).ids
|
||||
if not ids:
|
||||
return backend_pb2.EmbeddingResult(embeddings=[])
|
||||
|
||||
# Clamp to context window — truncate long inputs rather than blow up.
|
||||
ids = ids[: self.max_context]
|
||||
tokens = Tensor([ids])
|
||||
|
||||
hidden = self.llm_model.embed(tokens) # (1, seqlen, dim)
|
||||
# Mean pool over sequence dim
|
||||
pooled = hidden.mean(axis=1).squeeze(0) # (dim,)
|
||||
# L2 normalize
|
||||
norm = pooled.square().sum().sqrt()
|
||||
normalized = (pooled / (norm + 1e-12))
|
||||
vec = normalized.cast(dtypes.float32).tolist()
|
||||
|
||||
return backend_pb2.EmbeddingResult(embeddings=[float(x) for x in vec])
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"Embedding failed: {exc}")
|
||||
return backend_pb2.EmbeddingResult()
|
||||
|
||||
async def GenerateImage(self, request, context):
|
||||
if self.sd_model is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("No Stable Diffusion model loaded")
|
||||
return backend_pb2.Result(success=False, message="not loaded")
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
from vendor.stable_diffusion import run_sd15
|
||||
|
||||
steps = request.step if request.step > 0 else 20
|
||||
guidance = 7.5
|
||||
seed = request.seed if request.seed != 0 else None
|
||||
img_tensor = run_sd15(
|
||||
model=self.sd_model,
|
||||
prompt=request.positive_prompt or "",
|
||||
negative_prompt=request.negative_prompt or "",
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
)
|
||||
arr = img_tensor.numpy()
|
||||
image = Image.fromarray(arr)
|
||||
dst = request.dst or "/tmp/tinygrad_image.png"
|
||||
image.save(dst)
|
||||
return backend_pb2.Result(success=True, message=dst)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return backend_pb2.Result(success=False, message=f"GenerateImage failed: {exc}")
|
||||
|
||||
def _transcribe(self, audio_path: str, language: Optional[str]) -> tuple[str, float]:
|
||||
from vendor.whisper import load_file_waveform, transcribe_waveform
|
||||
|
||||
waveform = load_file_waveform(audio_path)
|
||||
text = transcribe_waveform(
|
||||
self.whisper_model,
|
||||
self.whisper_tokenizer,
|
||||
[waveform],
|
||||
language=language or None,
|
||||
)
|
||||
duration = float(len(waveform)) / 16000.0
|
||||
return text, duration
|
||||
|
||||
async def AudioTranscription(self, request, context):
|
||||
if self.whisper_model is None or self.whisper_tokenizer is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("No Whisper model loaded")
|
||||
return backend_pb2.TranscriptResult()
|
||||
|
||||
try:
|
||||
if not request.dst:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("TranscriptRequest.dst (audio file path) is required")
|
||||
return backend_pb2.TranscriptResult()
|
||||
|
||||
text, duration = self._transcribe(request.dst, request.language)
|
||||
segments = [backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text)]
|
||||
return backend_pb2.TranscriptResult(
|
||||
text=text,
|
||||
language=request.language or "en",
|
||||
duration=duration,
|
||||
segments=segments,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"AudioTranscription failed: {exc}")
|
||||
return backend_pb2.TranscriptResult()
|
||||
|
||||
async def AudioTranscriptionStream(self, request, context):
|
||||
if self.whisper_model is None or self.whisper_tokenizer is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("No Whisper model loaded")
|
||||
return
|
||||
|
||||
try:
|
||||
if not request.dst:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("TranscriptRequest.dst (audio file path) is required")
|
||||
return
|
||||
|
||||
# The vendored tinygrad whisper loop is chunked at the file level
|
||||
# (one inference pass per 30s segment), not token-level. To still
|
||||
# produce a streaming response we run the full transcription and
|
||||
# emit it as a single delta + a final-result envelope so the client
|
||||
# gets both code paths exercised.
|
||||
text, duration = self._transcribe(request.dst, request.language)
|
||||
yield backend_pb2.TranscriptStreamResponse(delta=text)
|
||||
final = backend_pb2.TranscriptResult(
|
||||
text=text,
|
||||
language=request.language or "en",
|
||||
duration=duration,
|
||||
segments=[backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text)],
|
||||
)
|
||||
yield backend_pb2.TranscriptStreamResponse(final_result=final)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"AudioTranscriptionStream failed: {exc}")
|
||||
|
||||
async def Status(self, request, context):
|
||||
return backend_pb2.StatusResponse(state=backend_pb2.StatusResponse.READY)
|
||||
|
||||
async def Free(self, request, context):
|
||||
self._reset_state()
|
||||
return backend_pb2.Result(success=True, message="freed")
|
||||
|
||||
|
||||
async def serve(address):
|
||||
server = grpc.aio.server(
|
||||
migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||
],
|
||||
interceptors=get_auth_interceptors(aio=True),
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5)))
|
||||
|
||||
await server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
await server.wait_for_termination()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the tinygrad gRPC backend.")
|
||||
parser.add_argument("--addr", default="localhost:50051", help="Bind address")
|
||||
args = parser.parse_args()
|
||||
asyncio.run(serve(args.addr))
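
The Embedding handler above reduces the model's last hidden state to a single vector by mean-pooling over the sequence dimension and L2-normalizing the result. A minimal NumPy sketch of that arithmetic (the random `hidden` array is a stand-in for `self.llm_model.embed(tokens)`; it is illustrative only):

```python
import numpy as np

# Stand-in for the (1, seqlen, dim) hidden state returned by embed().
rng = np.random.default_rng(0)
hidden = rng.standard_normal((1, 7, 16))

pooled = hidden.mean(axis=1).squeeze(0)  # mean pool over sequence -> (dim,)
norm = np.sqrt(np.square(pooled).sum())  # L2 norm
vec = pooled / (norm + 1e-12)            # epsilon guards an all-zero vector
```

Because the output has unit length, cosine similarity between two such embeddings reduces to a plain dot product.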

backend/python/tinygrad/install.sh (new executable file, 17 lines)
@@ -0,0 +1,17 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

# tinygrad >= 0.12 requires Python >= 3.11 (pyproject: `requires-python = ">=3.11"`).
# LocalAI's default portable python is 3.10, so we pin to 3.11.x here.
PYTHON_VERSION="3.11"
PYTHON_PATCH="14"
PY_STANDALONE_TAG="20260203"

installRequirements

backend/python/tinygrad/package.sh (new executable file, 103 lines)
@@ -0,0 +1,103 @@
#!/bin/bash
# Script to package runtime shared libraries for the tinygrad backend.
#
# The final Dockerfile.python stage is FROM scratch, so system libraries
# must be explicitly copied into ${BACKEND}/lib so the backend can run on
# any host without installing them. libbackend.sh automatically prepends
# that directory to LD_LIBRARY_PATH at run time.
#
# tinygrad's CPU device (CLANG / LLVM renderer) JIT-compiles kernels at
# runtime. The default `CLANG` path invokes the external `clang` binary via
# subprocess, which does not exist in the scratch image. We force the
# in-process LLVM path (`CPU_LLVM=1` in run.sh), which loads libLLVM.so.*
# through ctypes, and bundle the library + its runtime dependencies here.
#
# Also bundle libgomp (pulled in by librosa / numpy via numba) and libsndfile
# (required by soundfile -> librosa audio I/O for Whisper).

set -e

CURDIR=$(dirname "$(realpath "$0")")
LIB_DIR="${CURDIR}/lib"
mkdir -p "${LIB_DIR}"

SEARCH_DIRS=(
    /usr/lib/x86_64-linux-gnu
    /usr/lib/aarch64-linux-gnu
    /lib/x86_64-linux-gnu
    /lib/aarch64-linux-gnu
    /usr/lib
    /lib
)

copy_with_symlinks() {
    local soname="$1"
    local hit=""
    for dir in "${SEARCH_DIRS[@]}"; do
        if [ -e "${dir}/${soname}" ]; then
            hit="${dir}/${soname}"
            break
        fi
    done
    if [ -z "${hit}" ]; then
        echo "warning: ${soname} not found in standard lib paths" >&2
        return 0
    fi
    local real
    real=$(readlink -f "${hit}")
    cp -v "${real}" "${LIB_DIR}/"
    local real_base
    real_base=$(basename "${real}")
    if [ "${real_base}" != "${soname}" ]; then
        ln -sf "${real_base}" "${LIB_DIR}/${soname}"
    fi
}

# tinygrad searches for libLLVM under these sonames (see
# tinygrad/runtime/autogen/llvm.py). Ubuntu 24.04's `llvm` metapackage
# installs `libLLVM-18.so.1` into `/usr/lib/llvm-18/lib/`. Also scan the
# standard lib directories in case a different distro layout puts it in
# /usr/lib/x86_64-linux-gnu.
llvm_so=""
shopt -s nullglob
LLVM_EXTRA_DIRS=(/usr/lib/llvm-*/lib /usr/lib/llvm-*)
# First try the versioned symlink (libLLVM-18.so) since that's what
# tinygrad's DLL loader matches against (see llvm.py DLL name list).
for dir in "${SEARCH_DIRS[@]}" "${LLVM_EXTRA_DIRS[@]}"; do
    for candidate in "${dir}"/libLLVM-[0-9]*.so "${dir}"/libLLVM-[0-9]*.so.[0-9]*; do
        if [ -e "${candidate}" ]; then
            llvm_so="${candidate}"
            break 2
        fi
    done
done
# Fallback: any libLLVM.so file under /usr.
if [ -z "${llvm_so}" ]; then
    llvm_so=$(find /usr -maxdepth 5 -name 'libLLVM*.so*' 2>/dev/null | head -1)
fi
shopt -u nullglob
if [ -z "${llvm_so}" ]; then
    echo "ERROR: libLLVM not found — tinygrad CPU device needs it." >&2
    echo "Install the Ubuntu \`llvm\` package in the builder stage." >&2
    exit 1
fi
echo "Found libLLVM at: ${llvm_so}"
llvm_base=$(basename "${llvm_so}")
real_llvm=$(readlink -f "${llvm_so}")
cp -v "${real_llvm}" "${LIB_DIR}/"
real_base=$(basename "${real_llvm}")
if [ "${real_base}" != "${llvm_base}" ]; then
    ln -sf "${real_base}" "${LIB_DIR}/${llvm_base}"
fi

# libLLVM has soft runtime deps on libedit / libtinfo; pick them up if
# present. They're optional, but loading without them can fail.
copy_with_symlinks libedit.so.2
copy_with_symlinks libtinfo.so.6

# Audio I/O for the Whisper path.
copy_with_symlinks libsndfile.so.1
copy_with_symlinks libgomp.so.1

echo "tinygrad packaging completed successfully"
ls -liah "${LIB_DIR}/"
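
The soname-preserving copy that copy_with_symlinks performs (resolve the symlink, copy the real file, then re-create the soname link next to it) can be demonstrated in isolation with a throwaway directory; all paths below are made up for the demo:

```shell
set -e
tmp=$(mktemp -d)
mkdir -p "$tmp/src" "$tmp/lib"
echo fake > "$tmp/src/libdemo.so.1.2.3"          # the "real" library file
ln -s libdemo.so.1.2.3 "$tmp/src/libdemo.so.1"   # the soname symlink

real=$(readlink -f "$tmp/src/libdemo.so.1")      # resolve to the real file
cp "$real" "$tmp/lib/"
ln -sf "$(basename "$real")" "$tmp/lib/libdemo.so.1"

cat "$tmp/lib/libdemo.so.1"                      # resolves through the new symlink
```

This keeps the on-disk layout the dynamic loader expects (soname pointing at the versioned file) without copying the symlink itself.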

backend/python/tinygrad/protogen.sh (new executable file, 11 lines)
@@ -0,0 +1,11 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

runProtogen

backend/python/tinygrad/requirements-cpu.txt (new file, 1 line)
@@ -0,0 +1 @@
# tinygrad CPU backend uses CLANG device (no extra deps required).

backend/python/tinygrad/requirements-cublas12.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
# tinygrad drives CUDA through its own JIT (CUDA=1 env var).
# Requires the CUDA 12 runtime from the base image; no extra Python deps.

backend/python/tinygrad/requirements-cublas13.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
# tinygrad drives CUDA through its own JIT (CUDA=1 env var).
# Requires the CUDA 13 runtime from the base image; no extra Python deps.

backend/python/tinygrad/requirements.txt (new file, 15 lines)
@@ -0,0 +1,15 @@
grpcio==1.80.0
protobuf==6.33.5
certifi
setuptools
numpy>=2.0.0
tinygrad>=0.12.0
tokenizers>=0.21.0
huggingface_hub
jinja2>=3.1.0
tiktoken
sentencepiece
safetensors
Pillow
librosa
soundfile

backend/python/tinygrad/run.sh (new executable file, 55 lines)
@@ -0,0 +1,55 @@
#!/bin/bash
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

# tinygrad binds its compute device at import time from a single env var
# (CUDA / HIP / METAL / CLANG). We pick one here based on what driver
# libraries the host has injected into the container — when a user runs
# the image with `--gpus all` (or the equivalent rocm runtime), the
# nvidia-container-toolkit / rocm runtime mounts the right libraries
# under /usr/lib so we can detect them.
#
# tinygrad's CUDA path uses two compiler pairs: an NVRTC-backed one and
# an in-process PTX renderer. We force the PTX renderer here
# (`CUDA_PTX=1`) so the image is independent of the host CUDA toolkit
# version — only libcuda.so.1 (the driver) is required.
find_lib() {
    local soname="$1"
    for dir in /usr/lib/x86_64-linux-gnu /usr/lib64 /usr/lib /lib/x86_64-linux-gnu /lib64 /lib; do
        if [ -e "${dir}/${soname}" ]; then
            echo "${dir}/${soname}"
            return 0
        fi
    done
    return 1
}

if [ -z "${CUDA:-}${HIP:-}${METAL:-}${CLANG:-}" ]; then
    if find_lib libcuda.so.1 >/dev/null; then
        export CUDA=1
        export CUDA_PTX=1
    elif find_lib libamdhip64.so >/dev/null || find_lib libamdhip64.so.6 >/dev/null; then
        export HIP=1
    else
        export CLANG=1
    fi
fi

# The CPU path (CLANG=1) JIT-compiles via libLLVM. Force tinygrad's
# in-process LLVM compiler so we don't need an external `clang` binary
# (which is not present in the scratch image).
export CPU_LLVM=1
if [ -z "${LLVM_PATH:-}" ]; then
    for candidate in "${EDIR}"/lib/libLLVM-*.so "${EDIR}"/lib/libLLVM-*.so.* "${EDIR}"/lib/libLLVM.so.*; do
        if [ -e "${candidate}" ]; then
            export LLVM_PATH="${candidate}"
            break
        fi
    done
fi

startBackend $@

backend/python/tinygrad/test.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""
Unit tests for the tinygrad gRPC backend.

These tests cover the cheap paths that don't need a real model checkpoint:
- Health responds OK
- Tool-call parsers emit expected ToolCall structures

The full LLM / embeddings / Stable Diffusion / Whisper paths are exercised by
the root-level `make test-extra-backend-tinygrad-all` e2e targets, which boot
the containerized backend against real HF checkpoints.
"""
import os
import subprocess
import sys
import time
import unittest

import grpc

import backend_pb2
import backend_pb2_grpc

sys.path.insert(0, os.path.dirname(__file__))
from tool_parsers.hermes import HermesToolParser  # noqa: E402


class TestHealth(unittest.TestCase):
    def setUp(self):
        self.service = subprocess.Popen(
            ["python3", "backend.py", "--addr", "localhost:50051"]
        )
        time.sleep(5)

    def tearDown(self):
        self.service.kill()
        self.service.wait()

    def test_health(self):
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = stub.Health(backend_pb2.HealthMessage())
            self.assertEqual(response.message, b"OK")


class TestHermesParser(unittest.TestCase):
    def test_single_tool_call(self):
        parser = HermesToolParser()
        text = (
            "Sure, let me check.\n"
            "<tool_call>\n"
            '{"name": "get_weather", "arguments": {"city": "Paris"}}\n'
            "</tool_call>\n"
            "Done."
        )
        content, calls = parser.parse(text)
        self.assertIn("Sure", content)
        self.assertIn("Done", content)
        self.assertEqual(len(calls), 1)
        self.assertEqual(calls[0].name, "get_weather")
        self.assertIn("Paris", calls[0].arguments)

    def test_multi_call_and_thinking(self):
        parser = HermesToolParser()
        text = (
            "<think>I need both.</think>"
            '<tool_call>{"name":"a","arguments":{"x":1}}</tool_call>'
            '<tool_call>{"name":"b","arguments":{}}</tool_call>'
        )
        result = parser.parse_full(text)
        self.assertEqual(result.reasoning, "I need both.")
        self.assertEqual([c.name for c in result.tool_calls], ["a", "b"])
        self.assertEqual(result.tool_calls[0].index, 0)
        self.assertEqual(result.tool_calls[1].index, 1)

    def test_no_tool_call_is_passthrough(self):
        parser = HermesToolParser()
        text = "plain assistant answer with no tool call"
        content, calls = parser.parse(text)
        self.assertEqual(content, text)
        self.assertEqual(calls, [])


if __name__ == "__main__":
    unittest.main()

backend/python/tinygrad/test.sh (new executable file, 11 lines)
@@ -0,0 +1,11 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

runUnittests

backend/python/tinygrad/tool_parsers/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
"""Tool-call parsers for the tinygrad backend.

Each parser takes raw model output and extracts OpenAI-style tool calls so
the backend can populate `ChatDelta.tool_calls[]` natively (matching vLLM's
behavior, which the Go core prefers over regex fallback parsing).
"""
from __future__ import annotations

from .base import ToolCall, ToolParser, resolve_parser

__all__ = ["ToolCall", "ToolParser", "resolve_parser"]

backend/python/tinygrad/tool_parsers/base.py (new file, 85 lines)
@@ -0,0 +1,85 @@
"""Common types + parser registry for tool-call extraction."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToolCall:
    """One extracted tool call — maps 1:1 to backend_pb2.ToolCallDelta."""
    index: int
    name: str
    arguments: str  # JSON string
    id: str = ""


class ToolParser:
    """Parser interface.

    Subclasses implement `parse` (full non-streaming pass) and optionally
    `parse_stream` (incremental). The default `parse_stream` buffers until a
    full response is available and then delegates to `parse`.
    """

    name: str = "base"

    def __init__(self) -> None:
        self._stream_buffer = ""
        self._stream_index = 0

    def parse(self, text: str) -> tuple[str, list[ToolCall]]:
        """Return (content_for_user, tool_calls)."""
        raise NotImplementedError

    def parse_stream(self, delta: str, finished: bool = False) -> tuple[str, list[ToolCall]]:
        """Accumulate a streaming delta. Emits any tool calls that have closed.

        Default behavior: buffer until `finished=True`, then parse once.
        Subclasses can override to emit mid-stream.
        """
        self._stream_buffer += delta
        if not finished:
            return "", []
        content, calls = self.parse(self._stream_buffer)
        # Re-index starting from whatever we've already emitted in this stream.
        reindexed: list[ToolCall] = []
        for i, c in enumerate(calls):
            reindexed.append(ToolCall(
                index=self._stream_index + i,
                name=c.name,
                arguments=c.arguments,
                id=c.id,
            ))
        self._stream_index += len(reindexed)
        return content, reindexed

    def reset(self) -> None:
        self._stream_buffer = ""
        self._stream_index = 0


_REGISTRY: dict[str, type[ToolParser]] = {}


def register(cls: type[ToolParser]) -> type[ToolParser]:
    _REGISTRY[cls.name] = cls
    return cls


def resolve_parser(name: Optional[str]) -> ToolParser:
    """Return a parser instance by name, falling back to a no-op passthrough."""
    # Import for side effects — each module registers itself.
    from . import hermes, llama3_json, mistral, qwen3_xml  # noqa: F401

    if name and name in _REGISTRY:
        return _REGISTRY[name]()
    return PassthroughToolParser()


class PassthroughToolParser(ToolParser):
    """No-op parser — used when no tool_parser is configured."""
    name = "passthrough"

    def parse(self, text: str) -> tuple[str, list[ToolCall]]:
        return text, []
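
The registry pattern that `resolve_parser` relies on (decorator registration at import time, plus a passthrough fallback for unknown names instead of raising) can be sketched standalone. The names below are illustrative only, not the backend's API:

```python
from typing import Callable, Dict, Optional

REGISTRY: Dict[str, Callable[[str], str]] = {}

def register(name: str):
    """Record a parser function under its name when the module is imported."""
    def deco(fn: Callable[[str], str]) -> Callable[[str], str]:
        REGISTRY[name] = fn
        return fn
    return deco

@register("shout")
def shout(text: str) -> str:
    return text.upper()

def resolve(name: Optional[str]) -> Callable[[str], str]:
    # Unknown or missing names fall back to a passthrough rather than raising,
    # so a misconfigured tool_parser option degrades gracefully.
    return REGISTRY.get(name or "", lambda t: t)
```

The fallback keeps the caller's code path uniform: it always gets a callable back, whether or not the requested parser exists.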

backend/python/tinygrad/tool_parsers/hermes.py (new file, 74 lines)
@@ -0,0 +1,74 @@
"""Hermes-format tool-call parser.

Hermes 2 / 2.5 / 3 (and Qwen 2.5 Instruct, which adopted the same convention)
emit tool calls wrapped in `<tool_call>...</tool_call>` tags, where the inner
content is a JSON object with `name` and `arguments` keys:

    <tool_call>
    {"name": "get_weather", "arguments": {"city": "Paris"}}
    </tool_call>

Multiple tool calls may appear back-to-back. Text outside the tags is plain
assistant content that should surface to the user.

This parser also strips `<think>...</think>` reasoning blocks and returns them
via the reasoning_content channel (Qwen 3, DeepSeek-R1 distills).
"""
from __future__ import annotations

import json
import re
from dataclasses import dataclass

from .base import ToolCall, ToolParser, register

_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


@dataclass
class HermesParseResult:
    content: str
    reasoning: str
    tool_calls: list[ToolCall]


@register
class HermesToolParser(ToolParser):
    name = "hermes"

    def _parse_full(self, text: str) -> HermesParseResult:
        reasoning_parts: list[str] = []

        def _capture_reasoning(match: re.Match[str]) -> str:
            reasoning_parts.append(match.group(1).strip())
            return ""

        text_wo_think = _THINK_RE.sub(_capture_reasoning, text)

        calls: list[ToolCall] = []
        for match in _TOOL_CALL_RE.finditer(text_wo_think):
            raw = match.group(1)
            try:
                obj = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if not isinstance(obj, dict):
                continue
            name = obj.get("name")
            if not isinstance(name, str):
                continue
            args = obj.get("arguments", {})
            args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False)
            # Index only successfully parsed calls so indices stay contiguous
            # even when a malformed block is skipped.
            calls.append(ToolCall(index=len(calls), name=name, arguments=args_str))

        content = _TOOL_CALL_RE.sub("", text_wo_think).strip()
        reasoning = "\n\n".join(reasoning_parts).strip()
        return HermesParseResult(content=content, reasoning=reasoning, tool_calls=calls)

    def parse(self, text: str) -> tuple[str, list[ToolCall]]:
        result = self._parse_full(text)
        return result.content, result.tool_calls

    def parse_full(self, text: str) -> HermesParseResult:
        return self._parse_full(text)
|
||||
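Reviewer note: the tag-extraction logic above can be exercised standalone. The sketch below uses only the stdlib and the same `_TOOL_CALL_RE` pattern; the function name and sample text are illustrative, not part of the backend.

```python
import json
import re

# Same tag pattern the parser uses: non-greedy JSON object between tags.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def extract_hermes_calls(text: str):
    """Return (plain content, [(name, arguments dict)]) from Hermes-style output."""
    calls = []
    for m in TOOL_CALL_RE.finditer(text):
        try:
            obj = json.loads(m.group(1))
        except json.JSONDecodeError:
            continue  # malformed JSON inside the tags is skipped, as in the parser
        if isinstance(obj, dict) and isinstance(obj.get("name"), str):
            calls.append((obj["name"], obj.get("arguments", {})))
    return TOOL_CALL_RE.sub("", text).strip(), calls

content, calls = extract_hermes_calls(
    'Checking.\n<tool_call>\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n</tool_call>'
)
print(content)  # Checking.
print(calls)    # [('get_weather', {'city': 'Paris'})]
```

Note that the lazy `\{.*?\}` still captures the full object even with nested braces, because the match must be followed by `</tool_call>`.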
backend/python/tinygrad/tool_parsers/llama3_json.py (new file, 86 lines)
@@ -0,0 +1,86 @@
"""Llama 3.1 / 3.2 / 3.3 JSON tool-call parser.

Meta's Llama 3.1+ instruct chat templates emit tool calls in two broadly
compatible shapes:

1. With the `<|python_tag|>` lead-in:
   <|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}}
2. As a bare JSON object (or list of objects) at the end of the turn.

We also handle multi-call shapes where the model emits several JSON objects
separated by `;` or newlines, and JSON arrays `[{...}, {...}]`. The key field
for Llama 3 is historically `parameters` (older docs) but recent checkpoints
also emit `arguments` — accept either.
"""
from __future__ import annotations

import json
import re
from dataclasses import dataclass

from .base import ToolCall, ToolParser, register

_PYTHON_TAG = "<|python_tag|>"
_JSON_OBJECT_RE = re.compile(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", re.DOTALL)


def _coerce_call(obj: object, index: int) -> ToolCall | None:
    if not isinstance(obj, dict):
        return None
    name = obj.get("name")
    if not isinstance(name, str):
        return None
    args = obj.get("arguments", obj.get("parameters", {}))
    args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False)
    return ToolCall(index=index, name=name, arguments=args_str)


@register
class Llama3JsonToolParser(ToolParser):
    name = "llama3_json"

    def parse(self, text: str) -> tuple[str, list[ToolCall]]:
        calls: list[ToolCall] = []

        # Strip <|python_tag|> segments first — each segment is one tool call
        # body. The content after the final python_tag (if any) is the call.
        remaining = text
        if _PYTHON_TAG in text:
            head, *tails = text.split(_PYTHON_TAG)
            remaining = head
            for tail in tails:
                parsed = _try_parse(tail.strip(), len(calls))
                calls.extend(parsed)

        # Any JSON objects / arrays left in `remaining` count as tool calls too
        # if they parse to a {"name": ..., "arguments": ...} shape.
        for match in _JSON_OBJECT_RE.finditer(remaining):
            parsed = _try_parse(match.group(0), len(calls))
            if parsed:
                calls.extend(parsed)
                remaining = remaining.replace(match.group(0), "", 1)

        content = remaining.strip()
        return content, calls


def _try_parse(blob: str, start_index: int) -> list[ToolCall]:
    """Parse a fragment that may be a JSON object or a JSON array of objects."""
    blob = blob.strip().rstrip(";")
    if not blob:
        return []
    try:
        obj = json.loads(blob)
    except json.JSONDecodeError:
        return []
    if isinstance(obj, dict):
        call = _coerce_call(obj, start_index)
        return [call] if call else []
    if isinstance(obj, list):
        calls: list[ToolCall] = []
        for i, item in enumerate(obj):
            c = _coerce_call(item, start_index + i)
            if c:
                calls.append(c)
        return calls
    return []
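Reviewer note: `_JSON_OBJECT_RE` deliberately handles only one level of brace nesting, which is enough for the `{"name": ..., "parameters": {...}}` shape. A standalone check of the pattern and of the `parameters`/`arguments` fallback (sample text is illustrative):

```python
import json
import re

# Same balanced-brace pattern as _JSON_OBJECT_RE: one nesting level deep.
JSON_OBJECT_RE = re.compile(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", re.DOTALL)

text = 'Let me check. {"name": "get_weather", "parameters": {"city": "Paris"}}'
m = JSON_OBJECT_RE.search(text)
obj = json.loads(m.group(0))

# Llama 3 historically uses "parameters"; newer checkpoints use "arguments".
args = obj.get("arguments", obj.get("parameters", {}))
print(obj["name"], args)  # get_weather {'city': 'Paris'}
```

Doubly-nested argument objects would defeat this regex; for the tool-call shapes Llama 3 templates actually emit, one level suffices.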
backend/python/tinygrad/tool_parsers/mistral.py (new file, 56 lines)
@@ -0,0 +1,56 @@
"""Mistral / Mixtral tool-call parser.

Mistral Nemo / Small / Large Instruct emit tool calls prefixed with the
`[TOOL_CALLS]` control token, followed by a JSON array:

    [TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}}]

Multiple calls live inside the same array. Any text before `[TOOL_CALLS]` is
normal assistant content and should surface to the user.
"""
from __future__ import annotations

import json
import re

from .base import ToolCall, ToolParser, register

_MARKER = "[TOOL_CALLS]"
_JSON_ARRAY_RE = re.compile(r"\[\s*(?:\{.*?\}\s*,?\s*)+\]", re.DOTALL)


@register
class MistralToolParser(ToolParser):
    name = "mistral"

    def parse(self, text: str) -> tuple[str, list[ToolCall]]:
        if _MARKER not in text:
            return text.strip(), []

        head, tail = text.split(_MARKER, 1)
        content = head.strip()

        match = _JSON_ARRAY_RE.search(tail)
        if not match:
            return content, []

        try:
            arr = json.loads(match.group(0))
        except json.JSONDecodeError:
            return content, []

        if not isinstance(arr, list):
            return content, []

        calls: list[ToolCall] = []
        for i, obj in enumerate(arr):
            if not isinstance(obj, dict):
                continue
            name = obj.get("name")
            if not isinstance(name, str):
                continue
            args = obj.get("arguments", {})
            args_str = args if isinstance(args, str) else json.dumps(args, ensure_ascii=False)
            calls.append(ToolCall(index=i, name=name, arguments=args_str))

        return content, calls
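Reviewer note: the marker-split plus array-regex flow above can be checked in isolation. A minimal sketch with only the stdlib (sample text is illustrative):

```python
import json
import re

MARKER = "[TOOL_CALLS]"
# Same array pattern as _JSON_ARRAY_RE: one or more {...} objects inside [].
JSON_ARRAY_RE = re.compile(r"\[\s*(?:\{.*?\}\s*,?\s*)+\]", re.DOTALL)

text = 'One moment.[TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}}]'

# Text before the control token is plain assistant content.
head, tail = text.split(MARKER, 1)
arr = json.loads(JSON_ARRAY_RE.search(tail).group(0))
print(head.strip())    # One moment.
print(arr[0]["name"])  # get_weather
```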
backend/python/tinygrad/tool_parsers/qwen3_xml.py (new file, 74 lines)
@@ -0,0 +1,74 @@
"""Qwen 3 XML tool-call parser.

Qwen 3 Instruct emits tool calls wrapped in a two-level tag structure:

    <tool_call>
    <function=get_weather>
    <parameter=city>
    Paris
    </parameter>
    <parameter=unit>
    celsius
    </parameter>
    </function>
    </tool_call>

Parameter values are raw text — we treat them as strings unless they look
like JSON (in which case we try to parse so numbers / booleans round-trip
cleanly). Qwen 3 also supports `<think>...</think>` reasoning blocks before
the tool call — these are captured via the shared Hermes convention.
"""
from __future__ import annotations

import json
import re

from .base import ToolCall, ToolParser, register

_TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
_FUNCTION_RE = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
_PARAMETER_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


def _maybe_json(value: str):
    value = value.strip()
    if not value:
        return value
    if value[0] in "{[\"" or value in ("true", "false", "null") or value.lstrip("-").replace(".", "", 1).isdigit():
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return value
    return value


@register
class Qwen3XmlToolParser(ToolParser):
    name = "qwen3_xml"

    def parse(self, text: str) -> tuple[str, list[ToolCall]]:
        # Strip reasoning blocks from the user-visible content.
        stripped = _THINK_RE.sub("", text)

        calls: list[ToolCall] = []
        for match in _TOOL_CALL_RE.finditer(stripped):
            body = match.group(1)
            fn_match = _FUNCTION_RE.search(body)
            if not fn_match:
                continue
            name = fn_match.group(1).strip()
            params_body = fn_match.group(2)

            params: dict[str, object] = {}
            for pm in _PARAMETER_RE.finditer(params_body):
                params[pm.group(1).strip()] = _maybe_json(pm.group(2))

            calls.append(ToolCall(
                index=len(calls),
                name=name,
                arguments=json.dumps(params, ensure_ascii=False),
            ))

        content = _TOOL_CALL_RE.sub("", stripped).strip()
        return content, calls
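Reviewer note: the two-level tag extraction and the `_maybe_json` coercion can be exercised standalone. This sketch mirrors the regexes and coercion rule above; the sample function body is illustrative.

```python
import json
import re

FUNCTION_RE = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
PARAMETER_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)

def maybe_json(value: str):
    # Mirror of _maybe_json: keep raw strings, but let numbers / booleans /
    # JSON literals round-trip as their typed values.
    value = value.strip()
    if not value:
        return value
    looks_typed = (value[0] in "{[\"" or value in ("true", "false", "null")
                   or value.lstrip("-").replace(".", "", 1).isdigit())
    if looks_typed:
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return value
    return value

body = ("<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n"
        "<parameter=days>\n3\n</parameter>\n</function>")
fn = FUNCTION_RE.search(body)
params = {m.group(1).strip(): maybe_json(m.group(2)) for m in PARAMETER_RE.finditer(fn.group(2))}
print(fn.group(1), params)  # get_weather {'city': 'Paris', 'days': 3}
```

Note `days` comes back as the integer `3` while `Paris` stays a string, which is exactly the round-trip behavior the docstring promises.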
backend/python/tinygrad/vendor/__init__.py (new vendored file, 6 lines)
@@ -0,0 +1,6 @@
"""Vendored upstream tinygrad reference code (MIT-licensed).

Source: https://github.com/tinygrad/tinygrad
These files are not part of the `tinygrad` pip package (the `extra/` tree is
excluded from `pyproject.toml` `packages`), so we carry a pinned copy here.
"""
backend/python/tinygrad/vendor/audio_helpers.py (new vendored file, 83 lines)
@@ -0,0 +1,83 @@
# Vendored verbatim from tinygrad examples/audio_helpers.py (MIT license).
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/audio_helpers.py
# Copyright (c) 2023- the tinygrad authors
# SPDX-License-Identifier: MIT
from typing import Optional
from tinygrad import Tensor
from tinygrad.dtype import DTypeLike, dtypes
import math

# rewritten from numpy
def rfftfreq(n: int, d: float = 1.0, device=None) -> Tensor:
  val = 1.0 / (n * d)
  N = n // 2 + 1
  results = Tensor.arange(N, device=device)
  return results * val

# just like in librosa
def fft_frequencies(sr: float, n_fft: int) -> Tensor:
  return rfftfreq(n=n_fft, d=1.0 / sr)

def hz_to_mel(freq: Tensor) -> Tensor:
  # linear part
  f_min = 0.0
  f_sp = 200.0 / 3
  mels = (freq - f_min) / f_sp

  # log-scale part
  min_log_hz = 1000.0  # beginning of log region (Hz)
  mask = freq >= min_log_hz
  return mask.where(((min_log_hz - f_min) / f_sp) + (freq / min_log_hz).log() / (math.log(6.4) / 27.0), mels)

def mel_to_hz(mels: Tensor) -> Tensor:
  # linear scale
  f_min = 0.0
  f_sp = 200.0 / 3
  freqs = f_min + f_sp * mels

  # nonlinear scale
  min_log_hz = 1000.0  # beginning of log region (Hz)
  min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
  logstep = math.log(6.4) / 27.0  # step size for log region

  log_t = mels >= min_log_mel
  freqs = log_t.where(min_log_hz * ((logstep * (mels - min_log_mel)).exp()), freqs)
  return freqs

def mel_frequencies(n_mels: int = 128, *, fmin: float = 0.0, fmax: float = 11025.0) -> Tensor:
  # center freqs of mel bands - uniformly spaced between limits
  min_max_mel = hz_to_mel(Tensor([fmin, fmax]))

  mels = Tensor.linspace(min_max_mel[0], min_max_mel[1], n_mels)
  hz = mel_to_hz(mels)
  return hz

def mel(
  *,
  sr: float,
  n_fft: int,
  n_mels: int = 128,
  fmin: float = 0.0,
  fmax: Optional[float] = None,
  dtype: DTypeLike = dtypes.default_float,
) -> Tensor:
  if fmax is None:
    fmax = float(sr) / 2

  n_mels = int(n_mels)

  fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)  # center freqs of each FFT bin
  mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax)  # center freqs of mel bands

  fdiff = mel_f[1:] - mel_f[:-1]
  ramps = mel_f[None].T.expand(-1, fftfreqs.shape[-1]) - fftfreqs

  lower = -ramps[:n_mels] / fdiff[:n_mels][None].T
  upper = ramps[2 : n_mels + 2] / fdiff[1 : n_mels + 1][None].T
  weights = lower.minimum(upper).maximum(0)

  # Slaney-style mel is scaled to be approx constant energy per channel
  enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
  weights *= enorm[:, None]

  return weights
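Reviewer note: the Slaney-style mel mapping above is linear below 1 kHz (about 66.7 Hz per mel) and logarithmic above. The inverse relationship between `hz_to_mel` and `mel_to_hz` can be verified with scalar math, without tinygrad; this sketch mirrors the same constants.

```python
import math

F_SP = 200.0 / 3                 # linear region: ~66.7 Hz per mel
MIN_LOG_HZ = 1000.0              # log region starts at 1 kHz
MIN_LOG_MEL = MIN_LOG_HZ / F_SP  # = 15 mels
LOGSTEP = math.log(6.4) / 27.0   # log-region step size

def hz_to_mel(freq: float) -> float:
    if freq < MIN_LOG_HZ:
        return freq / F_SP
    return MIN_LOG_MEL + math.log(freq / MIN_LOG_HZ) / LOGSTEP

def mel_to_hz(mel: float) -> float:
    if mel < MIN_LOG_MEL:
        return F_SP * mel
    return MIN_LOG_HZ * math.exp(LOGSTEP * (mel - MIN_LOG_MEL))

# The two maps are inverses in both regions:
assert abs(mel_to_hz(hz_to_mel(440.0)) - 440.0) < 1e-6
assert abs(mel_to_hz(hz_to_mel(4000.0)) - 4000.0) < 1e-6
```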
backend/python/tinygrad/vendor/clip.py (new vendored file, 484 lines)
@@ -0,0 +1,484 @@
# Vendored verbatim from tinygrad extra/models/clip.py (MIT license).
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/clip.py
# Copyright (c) 2023- the tinygrad authors
# SPDX-License-Identifier: MIT
from tinygrad import Tensor, dtypes
from tinygrad.helpers import fetch
from tinygrad.nn import Linear, LayerNorm, Embedding, Conv2d

from typing import List, Optional, Union, Tuple, Dict
from abc import ABC, abstractmethod
from functools import lru_cache
import numpy as np
import re, gzip

# Allow for monkeypatching for mlperf.
gelu = Tensor.gelu

@lru_cache()
def default_bpe():
  # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
  return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")

class Tokenizer:
  """
  Namespace for CLIP Text Tokenizer components.
  """

  @staticmethod
  def get_pairs(word):
    """
    Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    return set(zip(word, word[1:]))
  @staticmethod
  def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text
  @staticmethod
  def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
      if b not in bs:
        bs.append(b)
        cs.append(2**8+n)
        n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
  class ClipTokenizer:
    def __init__(self, version=None):
      self.byte_encoder, self.version = Tokenizer.bytes_to_unicode(), version
      merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n')
      merges = merges[1:49152-256-2+1]
      merges = [tuple(merge.split()) for merge in merges]
      vocab = list(Tokenizer.bytes_to_unicode().values())
      vocab = vocab + [v+'</w>' for v in vocab]
      for merge in merges:
        vocab.append(''.join(merge))
      if self.version == "sd_mlperf_v5_0":
        import regex
        vocab.extend(['<start_of_text>', '<end_of_text>'])
        self.cache = {'<start_of_text>': '<start_of_text>', '<end_of_text>': '<end_of_text>'}
        self.pat = regex.compile(r"""<start_of_text>|<end_of_text>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
      else:
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
      self.encoder = dict(zip(vocab, range(len(vocab))))
      self.bpe_ranks = dict(zip(merges, range(len(merges))))

    def bpe(self, token):
      if token in self.cache:
        return self.cache[token]
      word = tuple(token[:-1]) + ( token[-1] + '</w>',)
      pairs = Tokenizer.get_pairs(word)

      if not pairs:
        return token+'</w>'

      while True:
        bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
        if bigram not in self.bpe_ranks:
          break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
          try:
            j = word.index(first, i)
            new_word.extend(word[i:j])
            i = j
          except Exception:
            new_word.extend(word[i:])
            break

          if word[i] == first and i < len(word)-1 and word[i+1] == second:
            new_word.append(first+second)
            i += 2
          else:
            new_word.append(word[i])
            i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
          break
        pairs = Tokenizer.get_pairs(word)
      word = ' '.join(word)
      self.cache[token] = word
      return word

    def encode(self, text:str, pad_with_zeros:bool=False) -> List[int]:
      bpe_tokens: List[int] = []
      if self.version == "sd_mlperf_v5_0":
        import regex, ftfy, html
        text = ftfy.fix_text(text)
        text = html.unescape(html.unescape(text)).strip()
        text = Tokenizer.whitespace_clean(text).lower()
        re_module = regex
      else:
        text = Tokenizer.whitespace_clean(text.strip()).lower()
        re_module = re

      for token in re_module.findall(self.pat, text):
        token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
        bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
      # Truncation, keeping two slots for start and end tokens.
      if len(bpe_tokens) > 75:
        bpe_tokens = bpe_tokens[:75]
      return [49406] + bpe_tokens + [49407] + ([0] if pad_with_zeros else [49407]) * (77 - len(bpe_tokens) - 2)

class Embedder(ABC):
  input_key: str
  @abstractmethod
  def __call__(self, x:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
    pass


class Closed:
  """
  Namespace for OpenAI CLIP model components.
  """
  class ClipMlp:
    def __init__(self):
      self.fc1 = Linear(768, 3072)
      self.fc2 = Linear(3072, 768)

    def __call__(self, h:Tensor) -> Tensor:
      h = self.fc1(h)
      h = h.quick_gelu()
      h = self.fc2(h)
      return h

  class ClipAttention:
    def __init__(self):
      self.embed_dim = 768
      self.num_heads = 12
      self.head_dim = self.embed_dim // self.num_heads
      self.k_proj = Linear(self.embed_dim, self.embed_dim)
      self.v_proj = Linear(self.embed_dim, self.embed_dim)
      self.q_proj = Linear(self.embed_dim, self.embed_dim)
      self.out_proj = Linear(self.embed_dim, self.embed_dim)

    def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
      bsz, tgt_len, embed_dim = hidden_states.shape
      q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
      q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
      attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
      return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))

  class ClipEncoderLayer:
    def __init__(self):
      self.self_attn = Closed.ClipAttention()
      self.layer_norm1 = LayerNorm(768)
      self.mlp = Closed.ClipMlp()
      self.layer_norm2 = LayerNorm(768)

    def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
      residual = hidden_states
      hidden_states = self.layer_norm1(hidden_states)
      hidden_states = self.self_attn(hidden_states, causal_attention_mask)
      hidden_states = residual + hidden_states

      residual = hidden_states
      hidden_states = self.layer_norm2(hidden_states)
      hidden_states = self.mlp(hidden_states)
      hidden_states = residual + hidden_states

      return hidden_states

  class ClipTextEmbeddings:
    def __init__(self):
      self.token_embedding = Embedding(49408, 768)
      self.position_embedding = Embedding(77, 768)

    def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor:
      return self.token_embedding(input_ids) + self.position_embedding(position_ids)

  class ClipEncoder:
    def __init__(self, layer_count:int=12):
      self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)]

    def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor:
      # the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state
      layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx]
      for l in layers:
        x = l(x, causal_attention_mask)
      return x

  class ClipTextTransformer:
    def __init__(self, ret_layer_idx:Optional[int]=None):
      self.embeddings = Closed.ClipTextEmbeddings()
      self.encoder = Closed.ClipEncoder()
      self.final_layer_norm = LayerNorm(768)
      self.ret_layer_idx = ret_layer_idx

    def __call__(self, input_ids:Tensor) -> Tensor:
      x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1))
      x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1), self.ret_layer_idx)
      return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x

  class ClipTextModel:
    def __init__(self, ret_layer_idx:Optional[int]):
      self.text_model = Closed.ClipTextTransformer(ret_layer_idx=ret_layer_idx)


# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331
class FrozenClosedClipEmbedder(Embedder):
  def __init__(self, ret_layer_idx:Optional[int]=None):
    self.tokenizer = Tokenizer.ClipTokenizer()
    self.transformer = Closed.ClipTextModel(ret_layer_idx)
    self.input_key = "txt"

  def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
    if isinstance(texts, str): texts = [texts]
    assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
    tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0)
    return self.transformer.text_model(tokens.reshape(len(texts),-1))


class Open:
  """
  Namespace for OpenCLIP model components.
  """
  class MultiheadAttention:
    def __init__(self, dims:int, n_heads:int):
      self.dims = dims
      self.n_heads = n_heads
      self.d_head = self.dims // self.n_heads

      self.in_proj_bias = Tensor.empty(3*dims)
      self.in_proj_weight = Tensor.empty(3*dims, dims)
      self.out_proj = Linear(dims, dims)

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      T,B,C = x.shape

      proj = x.linear(self.in_proj_weight.T, self.in_proj_bias)
      proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0, -2)

      q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in proj.chunk(3)]

      attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
      attn_output = attn_output.permute(2, 0, 1, 3).reshape(T, B, C)
      attn_output = self.out_proj(attn_output)

      return attn_output

  class Mlp:
    def __init__(self, dims, hidden_dims):
      self.c_fc = Linear(dims, hidden_dims)
      self.c_proj = Linear(hidden_dims, dims)
      self.gelu = gelu

    def __call__(self, x:Tensor) -> Tensor:
      return x.sequential([self.c_fc, self.gelu, self.c_proj])

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210
  class ResidualAttentionBlock:
    def __init__(self, dims:int, n_heads:int, mlp_ratio:float):
      self.ln_1 = LayerNorm(dims)
      self.attn = Open.MultiheadAttention(dims, n_heads)

      self.ln_2 = LayerNorm(dims)
      self.mlp = Open.Mlp(dims, int(dims * mlp_ratio))

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None, transpose:bool=False) -> Tensor:
      q_x = self.ln_1(x)
      attn_out = self.attn(q_x.transpose(0, 1) if transpose else q_x, attn_mask=attn_mask)
      attn_out = attn_out.transpose(0, 1) if transpose else attn_out
      x = x + attn_out
      x = x + self.mlp(self.ln_2(x))
      return x

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317
  class ClipTransformer:
    def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0):
      self.resblocks = [
        Open.ResidualAttentionBlock(dims, n_heads, mlp_ratio) for _ in range(layers)
      ]

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      for r in self.resblocks:
        x = r(x, attn_mask=attn_mask, transpose=True)
      return x

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220
  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661
  class ClipTextTransformer:
    def __init__(self, width:int, n_heads:int, layers:int, vocab_size:int=49408, ctx_length:int=77):
      self.token_embedding = Embedding(vocab_size, width)
      self.positional_embedding = Tensor.empty(ctx_length, width)
      self.transformer = Open.ClipTransformer(width, layers, n_heads)
      self.ln_final = LayerNorm(width)
      self.text_projection = Tensor.empty(width, width)
      self.attn_mask = Tensor.full((77, 77), float("-inf")).triu(1).realize()

    def __call__(self, text:Tensor) -> Tensor:
      seq_len = text.shape[1]

      x = self.token_embedding(text)
      x = x + self.positional_embedding[:seq_len]
      x = self.transformer(x, attn_mask=self.attn_mask)
      x = self.ln_final(x)

      pooled = x[:, text.argmax(dim=-1)] @ self.text_projection
      return pooled

  class ClipVisionTransformer:
    def __init__(self, width:int, layers:int, d_head:int, image_size:int, patch_size:int):
      grid_size = image_size // patch_size
      n_heads = width // d_head
      assert n_heads * d_head == width

      self.conv1 = Conv2d(3, width, kernel_size=patch_size, stride=patch_size, bias=False)

      self.class_embedding = Tensor.empty(width)
      self.positional_embedding = Tensor.empty(grid_size * grid_size + 1, width)
      self.transformer = Open.ClipTransformer(width, layers, n_heads)
      self.ln_pre = LayerNorm(width)
      self.ln_post = LayerNorm(width)
      self.proj = Tensor.empty(width, 1024)

    def __call__(self, x:Tensor) -> Tensor:
      x = self.conv1(x)
      x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)
      x = self.class_embedding.reshape(1, 1, -1).expand(x.shape[0], 1, -1).cat(x, dim=1)
      x = x + self.positional_embedding

      x = self.ln_pre(x)
      x = self.transformer(x)
      x = self.ln_post(x)

      pooled = x[:, 0] @ self.proj
      return pooled


# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
class FrozenOpenClipEmbedder(Embedder):
  def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False, clip_tokenizer_version=None):
    self.tokenizer = Tokenizer.ClipTokenizer(version=clip_tokenizer_version)
    self.model = Open.ClipTextTransformer(dims, n_heads, layers)
    self.return_pooled = return_pooled
    self.input_key = "txt"
    self.ln_penultimate = ln_penultimate

  def tokenize(self, text:str, device:Optional[str]=None) -> Tensor:
    return Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int32, device=device).reshape(1,-1)

  def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
    for r in self.model.transformer.resblocks:
      x, penultimate = r(x, attn_mask=attn_mask), x
    return x.permute(1, 0, 2), penultimate.permute(1, 0, 2)

  def embed_tokens(self, tokens:Tensor) -> Union[Tensor,Tuple[Tensor,...]]:
    x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2)
    x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)

    if self.ln_penultimate:
      penultimate = self.model.ln_final(penultimate)

    if self.return_pooled:
      x = self.model.ln_final(x)
      index = tokens.argmax(axis=-1).reshape(-1,1,1).expand(x.shape[0],1,x.shape[-1])
      pooled = x.gather(1, index).squeeze(1) @ self.model.text_projection
      return penultimate, pooled
    else:
      return penultimate

  def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
    if isinstance(texts, str): texts = [texts]
    assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
    tokens = Tensor.cat(*[self.tokenize(text) for text in texts], dim=0)
    return self.embed_tokens(tokens)


clip_configs: Dict = {
|
||||
"ViT-H-14": {
|
||||
"dims": 1024,
|
||||
"vision_cfg": {
|
||||
"width": 1280,
|
||||
"layers": 32,
|
||||
"d_head": 80,
|
||||
"image_size": 224,
|
||||
"patch_size": 14,
|
||||
},
|
||||
"text_cfg": {
|
||||
"width": 1024,
|
||||
"n_heads": 16,
|
||||
"layers": 24,
|
||||
"ctx_length": 77,
|
||||
"vocab_size": 49408,
|
||||
},
|
||||
"return_pooled": False,
|
||||
"ln_penultimate": True,
|
||||
}
|
||||
}
|
||||
|
||||
class OpenClipEncoder:
|
||||
def __init__(self, dims:int, text_cfg:Dict, vision_cfg:Dict, **_):
|
||||
self.visual = Open.ClipVisionTransformer(**vision_cfg)
|
||||
|
||||
text = Open.ClipTextTransformer(**text_cfg)
|
||||
self.transformer = text.transformer
|
||||
self.token_embedding = text.token_embedding
|
||||
self.positional_embedding = text.positional_embedding
|
||||
self.ln_final = text.ln_final
|
||||
self.text_projection = text.text_projection
|
||||
|
||||
self.attn_mask = Tensor.full((77, 77), float("-inf")).triu(1).realize()
|
||||
self.mean = Tensor([0.48145466, 0.45782750, 0.40821073]).reshape(-1, 1, 1)
|
||||
self.std = Tensor([0.26862954, 0.26130258, 0.27577711]).reshape(-1, 1, 1)
|
||||
|
||||
# TODO:
|
||||
# Should be doable in pure tinygrad, would just require some work and verification.
|
||||
# This is very desirable since it would allow for full generation->evaluation in a single JIT call.
|
||||
def prepare_image(self, image) -> Tensor:
|
||||
from PIL import Image
|
||||
SIZE = 224
|
||||
w, h = image.size
|
||||
scale = min(SIZE / h, SIZE / w)
|
||||
image = image.resize((max(int(w*scale),SIZE),max(int(h*scale),SIZE)), Image.Resampling.BICUBIC)
|
||||
w, h = image.size
|
||||
if w > SIZE:
|
||||
left = (w - SIZE) // 2
|
||||
image = image.crop((left, left+SIZE, 0, SIZE))
|
||||
elif h > SIZE:
|
||||
top = (h - SIZE) // 2
|
||||
image = image.crop((0, SIZE, top, top+SIZE))
|
||||
|
||||
x = Tensor(np.array(image.convert('RGB')), device=self.std.device)
|
||||
x = x.permute(2, 0, 1).cast(dtypes.float32) / 255.0
|
||||
return (x - self.mean) / self.std
|
||||
|
||||
def encode_tokens(self, tokens:Tensor) -> Tensor:
|
||||
x = self.token_embedding(tokens)
|
||||
x = x + self.positional_embedding
|
||||
x = self.transformer(x, attn_mask=self.attn_mask)
|
||||
x = self.ln_final(x)
|
||||
x = x[Tensor.arange(x.shape[0], device=x.device), tokens.argmax(axis=-1)]
|
||||
x = x @ self.text_projection
|
||||
return x
|
||||
|
||||
def get_clip_score(self, tokens:Tensor, image:Tensor) -> Tensor:
|
||||
image_features: Tensor = self.visual(image)
|
||||
image_features /= image_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm
|
||||
|
||||
text_features = self.encode_tokens(tokens)
|
||||
text_features /= text_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm
|
||||
|
||||
return (image_features * text_features).sum(axis=-1)
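`get_clip_score` above normalizes both feature vectors to unit length and takes their dot product, i.e. cosine similarity. A minimal numpy sketch of that computation (the vectors here are hypothetical toy values, not real CLIP features):

```python
import numpy as np

def clip_score(text_features: np.ndarray, image_features: np.ndarray) -> float:
    # Normalize each vector to unit L2 length (the .square().sum().sqrt()
    # in the vendored code), then take the dot product.
    t = text_features / np.sqrt((text_features ** 2).sum())
    i = image_features / np.sqrt((image_features ** 2).sum())
    return float((t * i).sum())

# identical direction -> score 1, orthogonal -> score 0
same = clip_score(np.array([3.0, 4.0]), np.array([3.0, 4.0]))
orth = clip_score(np.array([1.0, 0.0]), np.array([0.0, 1.0]))
```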

backend/python/tinygrad/vendor/llama.py (vendored, new file, 294 lines)
@@ -0,0 +1,294 @@
# Vendored from tinygrad extra/models/llama.py (MIT license).
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/llama.py
#
# Local modification: Attention / TransformerBlock / Transformer accept an
# optional `qkv_bias` flag so the same module can host Qwen2-style models that
# use bias on the Q/K/V projections (Llama 3 has no bias). Changes are marked
# with `# LOCALAI`.
#
# Copyright (c) 2023- the tinygrad authors
# SPDX-License-Identifier: MIT
from typing import Union, Optional, Any
import collections, math
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv, DEBUG


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
  freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
  return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).reshape(1, end, 1, dim//2, 2)


def complex_mult(A, c, d):
  a, b = A[..., 0:1], A[..., 1:2]
  ro = a * c - b * d
  co = a * d + b * c
  return ro.cat(co, dim=-1)


def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
  assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
  xq = xq.reshape(*xq.shape[0:-1], -1, 2)
  xk = xk.reshape(*xk.shape[0:-1], -1, 2)
  assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
  c, d = freqs_cis[..., 0:1], freqs_cis[..., 1:2]
  xq_out = complex_mult(xq, c, d)
  xk_out = complex_mult(xk, c, d)
  return xq_out.flatten(3), xk_out.flatten(3)
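`complex_mult` treats each trailing pair of values as (real, imaginary) and rotates it by the cached cos/sin angle, which is exactly multiplication by e^(i*theta). A small numpy check of that identity on toy values:

```python
import numpy as np

def rotate_pair(a: float, b: float, theta: float):
    # Same formula as complex_mult with c=cos(theta), d=sin(theta):
    # (a + ib) * (cos + i*sin) = (a*cos - b*sin) + i*(a*sin + b*cos)
    c, d = np.cos(theta), np.sin(theta)
    return (a * c - b * d, a * d + b * c)

expected = (1 + 2j) * np.exp(1j * 0.3)
got = rotate_pair(1.0, 2.0, 0.3)
```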


def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
  bs, seqlen, n_kv_heads, head_dim = x.shape
  if n_rep == 1:
    return x
  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
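`repeat_kv` expands grouped-query KV heads so every query head has a matching KV head: query head `r` reads from KV head `r // n_rep`. The same tile-then-reshape in numpy with toy shapes:

```python
import numpy as np

def repeat_kv_np(x: np.ndarray, n_rep: int) -> np.ndarray:
    bs, seqlen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # tile along head_dim, then fold the copies into new head slots,
    # mirroring x.repeat((1, 1, 1, n_rep)).reshape(...)
    return np.tile(x, (1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)

x = np.arange(2 * 3 * 2 * 4, dtype=np.float32).reshape(2, 3, 2, 4)
y = repeat_kv_np(x, 3)  # 2 KV heads -> 6 heads
```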


class Attention:
  # LOCALAI: added qkv_bias
  def __init__(self, dim, n_heads, n_kv_heads=None, max_context=0, linear=nn.Linear, qk_norm: float | None = None, qkv_bias: bool = False):
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
    self.head_dim = dim // n_heads
    self.n_rep = self.n_heads // self.n_kv_heads
    self.max_context = max_context

    self.wq = linear(dim, self.n_heads * self.head_dim, bias=qkv_bias)
    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=qkv_bias)
    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=qkv_bias)
    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

    self.q_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
    self.k_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None

  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor] = None) -> Tensor:
    xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)

    if self.q_norm is not None and self.k_norm is not None:
      xq = self.q_norm(xq)
      xk = self.k_norm(xk)

    if x.dtype == dtypes.bfloat16:
      xq, xk = xq.contiguous_backward(), xk.contiguous_backward()

    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)

    xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
    bsz, seqlen, _, _ = xq.shape

    if self.max_context:
      if not hasattr(self, "cache_kv"):
        self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
        if isinstance(x.device, tuple):
          self.cache_kv.shard_(x.device, axis=3 if getenv("SHARD_KVCACHE") else None).realize()

      assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
      self.cache_kv[:, :, start_pos:start_pos + seqlen, :, :].assign(Tensor.stack(xk, xv)).realize()

      keys = self.cache_kv[0, :, 0:start_pos + seqlen, :, :]
      values = self.cache_kv[1, :, 0:start_pos + seqlen, :, :]
    else:
      assert start_pos == 0
      keys, values = xk, xv

    if self.max_context:
      keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
      xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
      attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
    else:
      xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
      attn = xq.scaled_dot_product_attention(keys, values, is_causal=True, enable_gqa=True).transpose(1, 2)

    attn = attn.reshape(bsz, seqlen, -1)
    return self.wo(attn)


class FeedForward:
  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
    self.w1 = linear(dim, hidden_dim, bias=False)
    self.w2 = linear(hidden_dim, dim, bias=False)
    self.w3 = linear(dim, hidden_dim, bias=False)

  def __call__(self, x: Tensor) -> Tensor:
    w1 = self.w1(x).silu()
    w3 = self.w3(x.contiguous_backward())
    return self.w2(w1 * w3)


class TransformerBlock:
  # LOCALAI: added qkv_bias
  def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int,
               linear=nn.Linear, feed_forward=FeedForward, qk_norm=None, qkv_bias: bool = False):
    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear, qk_norm, qkv_bias=qkv_bias)
    self.feed_forward = feed_forward(dim, hidden_dim, linear)
    self.attention_norm = nn.RMSNorm(dim, norm_eps)
    self.ffn_norm = nn.RMSNorm(dim, norm_eps)

  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
    return (h + self.feed_forward(self.ffn_norm(h))).contiguous().contiguous_backward()


def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
  assert logits.ndim == 1, "only works on 1d tensors"
  assert 0 <= p <= 1, "p must be between 0 and 1"
  assert 0 <= k <= logits.numel(), "k must be between 0 and numel"

  # greedy argmax at (near-)zero temperature
  if temp < 1e-6:
    return logits.argmax()

  logits = logits.to(Device.DEFAULT)

  # alpha sampling (frequency / presence penalties)
  if af or ap:
    if not hasattr(sample, "alpha_counter"):
      setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
    logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0) * ap)

  # replace NaNs with -inf
  logits = (logits != logits).where(-float("inf"), logits)

  t = (logits / temp).softmax()

  counter = Tensor.arange(t.numel(), device=logits.device).contiguous()
  counter2 = Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()

  # top-k: peel off the k largest probabilities, then apply top-p on that set
  if k:
    output = Tensor.zeros(k, device=logits.device).contiguous()
    output_indices = Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
    for i in range(k):
      t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
      t = (counter == t_argmax).where(0, t)

    output_cumsum = output[::-1].cumsum()[::-1] + t.sum()
    output = (output_cumsum >= (1 - p)) * output
    output_indices = (output_cumsum >= (1 - p)) * output_indices

    output_idx = output.multinomial()
    output_token = output_indices[output_idx]
  else:
    output_token = t.multinomial()

  # increase alpha counter for the sampled token
  if af or ap:
    sample.alpha_counter = (counter == output_token).where(sample.alpha_counter + 1, sample.alpha_counter)

  return output_token
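The vendored `sample` keeps everything on-device, so its top-p filtering is expressed with cumsum masks. A simplified host-side numpy version of nucleus (top-p) filtering, for illustration only:

```python
import numpy as np

def top_p_filter(probs: np.ndarray, p: float) -> np.ndarray:
    # Keep the smallest set of highest-probability tokens whose cumulative
    # mass reaches p; zero out the rest and renormalize.
    order = np.argsort(probs)[::-1]
    cum = np.cumsum(probs[order])
    keep = order[: int(np.searchsorted(cum, p) + 1)]
    out = np.zeros_like(probs)
    out[keep] = probs[keep]
    return out / out.sum()

filtered = top_p_filter(np.array([0.5, 0.3, 0.15, 0.05]), 0.75)
```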


class Transformer:
  # LOCALAI: added qkv_bias
  def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_layers: int, norm_eps: float, vocab_size,
               linear=nn.Linear, embedding=nn.Embedding, n_kv_heads=None, rope_theta=10000,
               max_context=1024, jit=True, feed_forward=FeedForward, qk_norm=None, disable_kv_cache=False,
               qkv_bias: bool = False):
    self.layers = [
      TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps,
                       0 if disable_kv_cache else max_context, linear,
                       feed_forward=feed_forward, qk_norm=qk_norm, qkv_bias=qkv_bias)
      for _ in range(n_layers)
    ]
    self.norm = nn.RMSNorm(dim, norm_eps)
    self.tok_embeddings = embedding(vocab_size, dim)
    self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False)
    self.max_context = max_context
    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous().requires_grad_(False)
    self.forward_jit = TinyJit(self.forward) if jit else None

  def forward(self, tokens: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
    _bsz, seqlen = tokens.shape
    h = self.tok_embeddings(tokens).contiguous()
    freqs_cis = self.freqs_cis.cast(h.dtype)[:, start_pos:start_pos + seqlen, :, :, :]

    if self.max_context != 0 and seqlen > 1:
      mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos + 1)
    else:
      mask = None
    for layer in self.layers:
      h = layer(h, start_pos, freqs_cis, mask)
    logits = self.output(self.norm(h).contiguous().contiguous_backward()).contiguous_backward()
    if math.isnan(temperature):
      return logits

    return sample(logits[:, -1, :].flatten(), temperature, top_k, top_p, alpha_f, alpha_p)

  def __call__(self, tokens: Tensor, start_pos: int, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
    if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
      return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context - 1).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
    return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)

  # LOCALAI: extract the last hidden state for embeddings. Skips the LM head;
  # the causal mask is left intact so the pooling sees the full sequence.
  def embed(self, tokens: Tensor) -> Tensor:
    _bsz, seqlen = tokens.shape
    h = self.tok_embeddings(tokens).contiguous()
    freqs_cis = self.freqs_cis.cast(h.dtype)[:, 0:seqlen, :, :, :]
    mask = Tensor.full((1, 1, seqlen, seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(1) if seqlen > 1 else None
    for layer in self.layers:
      h = layer(h, 0, freqs_cis, mask)
    return self.norm(h)


def convert_from_huggingface(weights: dict[str, Tensor], n_layers: int, n_heads: int, n_kv_heads: int, permute_layers: bool = True):
  # undo the interleaved rotary layout HF uses for the q/k projections
  def permute(v: Tensor, n_heads: int):
    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1] if len(v.shape) > 1 else 1).transpose(1, 2).reshape(*v.shape[:2])

  keymap = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(n_layers)},
    **{f"model.layers.{l}.self_attn.{x}_norm.weight": f"layers.{l}.attention.{x}_norm.weight" for x in ["q", "k"] for l in range(n_layers)},
    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(n_layers)},
    **{f"model.layers.{l}.self_attn.{x}_proj.bias": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v", "o"] for l in range(n_layers)},
    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(n_layers)},
    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(n_layers)},
    **{f"model.layers.{l}.mlp.gate.weight": f"layers.{l}.feed_forward.gate.weight" for l in range(n_layers)},
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
  }
  sd = {}
  experts = collections.defaultdict(dict)
  for k, v in weights.items():
    if ".rotary_emb." in k:
      continue
    v = v.to(Device.DEFAULT)
    if "model.layers" in k:
      if ("q_proj" in k or "q_norm" in k) and permute_layers:
        v = permute(v, n_heads)
      elif ("k_proj" in k or "k_norm" in k) and permute_layers:
        v = permute(v, n_kv_heads)
      if '.mlp.experts.' in k:
        # MoE expert weights are stacked into one tensor per layer below
        _, _, layer, _, _, expert, name, _ = k.split('.')
        experts[f'layers.{layer}.feed_forward.{name}'][int(expert)] = v
        continue
    sd[keymap[k]] = v
  for k, v in experts.items():
    sd[k] = Tensor.stack(*[v[i] for i in range(len(v))])

  # tied embeddings: reuse the input embedding as the LM head
  if "output.weight" not in sd and "tok_embeddings.weight" in sd:
    sd["output.weight"] = sd["tok_embeddings.weight"]

  return sd
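Most of `convert_from_huggingface` is a static key rename; the permute/expert handling only applies to attention and MoE tensors. A tiny pure-Python check of the rename table for a hypothetical one-layer model (keys only, no weights):

```python
# Rebuild a slice of the keymap for n_layers=1, mirroring the dict
# comprehensions in convert_from_huggingface.
n_layers = 1
keymap = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
       for x in ["q", "k", "v", "o"] for l in range(n_layers)},
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
}
renamed = keymap["model.layers.0.self_attn.q_proj.weight"]
```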


def convert_from_gguf(weights: dict[str, Tensor], n_layers: int):
  keymap = {
    "token_embd.weight": "tok_embeddings.weight",
    **{f"blk.{l}.attn_norm.weight": f"layers.{l}.attention_norm.weight" for l in range(n_layers)},
    **{f"blk.{l}.attn_{x}.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v"] for l in range(n_layers)},
    **{f"blk.{l}.attn_{x}.bias": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v"] for l in range(n_layers)},
    **{f"blk.{l}.attn_output.weight": f"layers.{l}.attention.wo.weight" for l in range(n_layers)},
    **{f"blk.{l}.ffn_norm.weight": f"layers.{l}.ffn_norm.weight" for l in range(n_layers)},
    **{f"blk.{l}.ffn_{x}.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(n_layers)},
    "output_norm.weight": "norm.weight",
    "rope_freqs.weight": "rope_freqs.weight",
  }
  sd = {keymap[k]: v for k, v in weights.items() if k in keymap}
  if "output.weight" not in sd and "token_embd.weight" in weights:
    sd["output.weight"] = weights["token_embd.weight"]
  return sd


def fix_bf16(weights: dict[Any, Tensor]):
  return {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}

backend/python/tinygrad/vendor/stable_diffusion.py (vendored, new file, 232 lines)
@@ -0,0 +1,232 @@
# Adapted from tinygrad examples/stable_diffusion.py (MIT license).
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/stable_diffusion.py
# Copyright (c) 2023- the tinygrad authors
# SPDX-License-Identifier: MIT
#
# Local modifications: removed the MLPerf training branch (pulls
# examples/mlperf/initializers which we don't vendor) and the __main__
# argparse / fetch / profile blocks. Kept the core classes so the LocalAI
# tinygrad backend can instantiate and drive Stable Diffusion v1.x from a
# single checkpoint path.
from collections import namedtuple
from typing import Any, Dict

import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.nn import Conv2d, GroupNorm

from . import clip as clip_mod
from . import unet as unet_mod
from .clip import Closed, Tokenizer
from .unet import UNetModel


class AttnBlock:
  def __init__(self, in_channels):
    self.norm = GroupNorm(32, in_channels)
    self.q = Conv2d(in_channels, in_channels, 1)
    self.k = Conv2d(in_channels, in_channels, 1)
    self.v = Conv2d(in_channels, in_channels, 1)
    self.proj_out = Conv2d(in_channels, in_channels, 1)

  def __call__(self, x):
    h_ = self.norm(x)
    q, k, v = self.q(h_), self.k(h_), self.v(h_)
    b, c, h, w = q.shape
    q, k, v = [t.reshape(b, c, h * w).transpose(1, 2) for t in (q, k, v)]
    h_ = Tensor.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(b, c, h, w)
    return x + self.proj_out(h_)


class ResnetBlock:
  def __init__(self, in_channels, out_channels=None):
    self.norm1 = GroupNorm(32, in_channels)
    self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
    self.norm2 = GroupNorm(32, out_channels)
    self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
    self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else (lambda x: x)

  def __call__(self, x):
    h = self.conv1(self.norm1(x).swish())
    h = self.conv2(self.norm2(h).swish())
    return self.nin_shortcut(x) + h


class Mid:
  def __init__(self, block_in):
    self.block_1 = ResnetBlock(block_in, block_in)
    self.attn_1 = AttnBlock(block_in)
    self.block_2 = ResnetBlock(block_in, block_in)

  def __call__(self, x):
    return x.sequential([self.block_1, self.attn_1, self.block_2])


class Decoder:
  def __init__(self):
    sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
    self.conv_in = Conv2d(4, 512, 3, padding=1)
    self.mid = Mid(512)

    arr = []
    for i, s in enumerate(sz):
      arr.append({"block": [ResnetBlock(s[1], s[0]), ResnetBlock(s[0], s[0]), ResnetBlock(s[0], s[0])]})
      if i != 0:
        arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
    self.up = arr

    self.norm_out = GroupNorm(32, 128)
    self.conv_out = Conv2d(128, 3, 3, padding=1)

  def __call__(self, x):
    x = self.conv_in(x)
    x = self.mid(x)
    for l in self.up[::-1]:
      for b in l['block']:
        x = b(x)
      if 'upsample' in l:
        bs, c, py, px = x.shape
        x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py * 2, px * 2)
        x = l['upsample']['conv'](x)
      x.realize()
    return self.conv_out(self.norm_out(x).swish())


class Encoder:
  def __init__(self):
    sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
    self.conv_in = Conv2d(3, 128, 3, padding=1)

    arr = []
    for i, s in enumerate(sz):
      arr.append({"block": [ResnetBlock(s[0], s[1]), ResnetBlock(s[1], s[1])]})
      if i != 3:
        arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0, 1, 0, 1))}
    self.down = arr

    self.mid = Mid(512)
    self.norm_out = GroupNorm(32, 512)
    self.conv_out = Conv2d(512, 8, 3, padding=1)

  def __call__(self, x):
    x = self.conv_in(x)
    for l in self.down:
      for b in l['block']:
        x = b(x)
      if 'downsample' in l:
        x = l['downsample']['conv'](x)
    x = self.mid(x)
    return self.conv_out(self.norm_out(x).swish())


class AutoencoderKL:
  def __init__(self):
    self.encoder = Encoder()
    self.decoder = Decoder()
    self.quant_conv = Conv2d(8, 8, 1)
    self.post_quant_conv = Conv2d(4, 4, 1)

  def __call__(self, x):
    latent = self.encoder(x)
    latent = self.quant_conv(latent)
    latent = latent[:, 0:4]  # keep only the means
    latent = self.post_quant_conv(latent)
    return self.decoder(latent)


def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
  betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_training_steps, dtype=np.float32) ** 2
  alphas = 1.0 - betas
  alphas_cumprod = np.cumprod(alphas, axis=0)
  return Tensor(alphas_cumprod)
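`get_alphas_cumprod` builds the SD1.x "scaled linear" beta schedule (linear in sqrt(beta), then squared); the cumulative product decays monotonically from near 1 toward 0. The same schedule in plain numpy, with the expected endpoints checked:

```python
import numpy as np

def alphas_cumprod(beta_start: float = 0.00085, beta_end: float = 0.0120, n: int = 1000) -> np.ndarray:
    # Linear in sqrt(beta), squared back: the CompVis SD1.x schedule.
    betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n, dtype=np.float32) ** 2
    return np.cumprod(1.0 - betas, axis=0)

ac = alphas_cumprod()
```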


# SD1.x UNet hyperparameters (same as upstream `unet_params`).
UNET_PARAMS_SD1: Dict[str, Any] = {
  "adm_in_ch": None,
  "in_ch": 4,
  "out_ch": 4,
  "model_ch": 320,
  "attention_resolutions": [4, 2, 1],
  "num_res_blocks": 2,
  "channel_mult": [1, 2, 4, 4],
  "n_heads": 8,
  "transformer_depth": [1, 1, 1, 1],
  "ctx_dim": 768,
  "use_linear": False,
}


class StableDiffusion:
  """Stable Diffusion 1.x pipeline, adapted from tinygrad's reference example.

  Drives the native CompVis `sd-v1-*.ckpt` checkpoint format (the only one
  the vendored weight layout handles). For HuggingFace safetensors pipelines
  the caller is expected to download / merge the `.ckpt` equivalent before
  calling LoadModel.
  """

  def __init__(self):
    self.alphas_cumprod = get_alphas_cumprod()
    self.first_stage_model = AutoencoderKL()
    self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(
      transformer=namedtuple("Transformer", ["text_model"])(text_model=Closed.ClipTextTransformer())
    )
    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(
      diffusion_model=UNetModel(**UNET_PARAMS_SD1)
    )

  # DDIM update step.
  def _update(self, x, e_t, a_t, a_prev):
    sqrt_one_minus_at = (1 - a_t).sqrt()
    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
    dir_xt = (1.0 - a_prev).sqrt() * e_t
    return a_prev.sqrt() * pred_x0 + dir_xt

  # Classifier-free guidance: run the UNet on the unconditional and
  # conditional context in one batch, then blend the two predictions.
  def _model_output(self, uncond, cond, latent, timestep, guidance):
    latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, uncond.cat(cond, dim=0))
    uncond_latent, cond_latent = latents[0:1], latents[1:2]
    return uncond_latent + guidance * (cond_latent - uncond_latent)

  def step(self, uncond, cond, latent, timestep, a_t, a_prev, guidance):
    e_t = self._model_output(uncond, cond, latent, timestep, guidance)
    return self._update(latent, e_t, a_t, a_prev).realize()

  def decode(self, x):
    x = self.first_stage_model.post_quant_conv(1 / 0.18215 * x)
    x = self.first_stage_model.decoder(x)
    x = (x + 1.0) / 2.0
    x = x.reshape(3, 512, 512).permute(1, 2, 0).clip(0, 1) * 255
    return x.cast(dtypes.uint8)

  def encode_prompt(self, tokenizer, prompt: str):
    ids = Tensor([tokenizer.encode(prompt)])
    return self.cond_stage_model.transformer.text_model(ids).realize()


def run_sd15(model: StableDiffusion, prompt: str, negative_prompt: str, steps: int, guidance: float, seed: int):
  """Generate a single 512x512 image. Returns a (512,512,3) uint8 tensor."""
  tokenizer = Tokenizer.ClipTokenizer()

  context = model.encode_prompt(tokenizer, prompt)
  uncond = model.encode_prompt(tokenizer, negative_prompt)

  timesteps = list(range(1, 1000, 1000 // steps))
  alphas = model.alphas_cumprod[Tensor(timesteps)]
  alphas_prev = Tensor([1.0]).cat(alphas[:-1])

  if seed is not None:
    Tensor.manual_seed(seed)
  latent = Tensor.randn(1, 4, 64, 64)

  for index in range(len(timesteps) - 1, -1, -1):
    timestep = timesteps[index]
    tid = Tensor([index])
    latent = model.step(
      uncond, context, latent,
      Tensor([timestep]),
      alphas[tid], alphas_prev[tid],
      Tensor([guidance]),
    )

  return model.decode(latent).realize()
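`run_sd15` derives its DDIM schedule from `range(1, 1000, 1000 // steps)`, which yields exactly `steps` timesteps, and pairs each with a shifted `alphas_prev`. A quick check of that arithmetic:

```python
def ddim_timesteps(steps: int) -> list:
    # Same expression as in run_sd15: stride through the 1000 training steps.
    return list(range(1, 1000, 1000 // steps))

ts = ddim_timesteps(20)
```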

backend/python/tinygrad/vendor/unet.py (vendored, new file, 267 lines)
@@ -0,0 +1,267 @@
# Vendored verbatim from tinygrad extra/models/unet.py (MIT license).
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/unet.py
# Copyright (c) 2023- the tinygrad authors
# SPDX-License-Identifier: MIT
from tinygrad import Tensor, dtypes, nn
from tinygrad.device import is_dtype_supported
from typing import Optional, Union, List, Any, Tuple, Callable
import math

# allow for monkeypatching
Linear, Conv2d, GroupNorm, LayerNorm = nn.Linear, nn.Conv2d, nn.GroupNorm, nn.LayerNorm
attention, gelu, mixed_precision_dtype = Tensor.scaled_dot_product_attention, Tensor.gelu, dtypes.float16

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
  half = dim // 2
  freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
  args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
  out = Tensor.cat(args.cos(), args.sin(), dim=-1)
  return out.cast(mixed_precision_dtype) if is_dtype_supported(mixed_precision_dtype) else out
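`timestep_embedding` is the standard transformer sinusoidal embedding over log-spaced frequencies. The same formula rendered in numpy (float32 only, no mixed-precision cast):

```python
import math
import numpy as np

def timestep_embedding_np(timesteps: np.ndarray, dim: int, max_period: int = 10000) -> np.ndarray:
    half = dim // 2
    # log-spaced frequencies from 1 down to 1/max_period
    freqs = np.exp(-math.log(max_period) * np.arange(half) / half)
    args = timesteps[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)

emb = timestep_embedding_np(np.array([0.0, 500.0]), 8)
```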

class ResBlock:
  def __init__(self, channels:int, emb_channels:int, out_channels:int, num_groups:int=32):
    self.in_layers = [
      GroupNorm(num_groups, channels),
      Tensor.silu,
      Conv2d(channels, out_channels, 3, padding=1),
    ]
    self.emb_layers = [
      Tensor.silu,
      Linear(emb_channels, out_channels),
    ]
    self.out_layers = [
      GroupNorm(num_groups, out_channels),
      Tensor.silu,
      lambda x: x,  # needed for weights loading code to work
      Conv2d(out_channels, out_channels, 3, padding=1),
    ]
    self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else (lambda x: x)

  def __call__(self, x:Tensor, emb:Tensor) -> Tensor:
    h = x.sequential(self.in_layers)
    emb_out = emb.sequential(self.emb_layers)
    h = h + emb_out.reshape(*emb_out.shape, 1, 1)
    h = h.sequential(self.out_layers)
    return self.skip_connection(x) + h

class CrossAttention:
  def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int):
    self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
    self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False)
    self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False)
    self.num_heads = n_heads
    self.head_size = d_head
    self.attn = attention
    self.to_out = [Linear(n_heads*d_head, query_dim)]

  def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
    ctx = x if ctx is None else ctx
    q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx)
    q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
    attention = self.attn(q, k, v).transpose(1,2)
    h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
    return h_.sequential(self.to_out)

class GEGLU:
  def __init__(self, dim_in:int, dim_out:int):
    self.proj = Linear(dim_in, dim_out * 2)
    self.gelu = gelu
    self.dim_out = dim_out

  def __call__(self, x:Tensor) -> Tensor:
    x, gate = self.proj(x).chunk(2, dim=-1)
    return x * self.gelu(gate)
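GEGLU projects to twice the output width, splits the result, and gates one half by GELU of the other. A numpy sketch of the split-and-gate step using the exact (erf-based) GELU, where `projected` stands in for the output of `self.proj(x)`:

```python
import math
import numpy as np

def gelu(x: np.ndarray) -> np.ndarray:
    # Exact GELU via the error function.
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

def geglu(projected: np.ndarray) -> np.ndarray:
    # `projected` plays the role of self.proj(x): split in two, gate one half.
    x, gate = np.split(projected, 2, axis=-1)
    return x * gelu(gate)

out = geglu(np.array([2.0, 3.0, 0.0, 10.0]))  # x=[2,3], gate=[0,10]
```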
|
||||
|
||||
class FeedForward:
|
||||
def __init__(self, dim:int, mult:int=4):
|
||||
self.net: tuple[GEGLU, Callable, nn.Linear] = (
|
||||
GEGLU(dim, dim*mult),
|
||||
lambda x: x, # needed for weights loading code to work
|
||||
Linear(dim*mult, dim)
|
||||
)
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return x.sequential(list(self.net))
|
||||
|
||||
class BasicTransformerBlock:
|
||||
def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
|
||||
self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
|
||||
self.ff = FeedForward(dim)
|
||||
self.attn2 = CrossAttention(dim, ctx_dim, n_heads, d_head)
|
||||
self.norm1 = LayerNorm(dim)
|
||||
self.norm2 = LayerNorm(dim)
|
||||
self.norm3 = LayerNorm(dim)
|
||||
|
||||
def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
|
||||
x = x + self.attn1(self.norm1(x))
|
||||
x = x + self.attn2(self.norm2(x), ctx=ctx)
|
||||
x = x + self.ff(self.norm3(x))
|
||||
return x
|
||||
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/attention.py#L619
|
||||
class SpatialTransformer:
|
||||
def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1,
|
||||
norm_eps:float=1e-5):
|
||||
if isinstance(ctx_dim, int):
|
||||
ctx_dim = [ctx_dim]*depth
|
||||
else:
|
||||
assert isinstance(ctx_dim, list) and depth == len(ctx_dim)
|
||||
self.norm = GroupNorm(32, channels, eps=norm_eps)
|
||||
assert channels == n_heads * d_head
|
||||
self.proj_in = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
|
||||
self.transformer_blocks = [BasicTransformerBlock(channels, ctx_dim[d], n_heads, d_head) for d in range(depth)]
|
||||
self.proj_out = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
|
||||
self.use_linear = use_linear
|
||||
|
||||
def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
ops = [ (lambda z: z.reshape(b, c, h*w).permute(0,2,1)), (lambda z: self.proj_in(z)) ]
|
||||
x = x.sequential(ops if self.use_linear else ops[::-1])
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, ctx=ctx)
|
||||
ops = [ (lambda z: self.proj_out(z)), (lambda z: z.permute(0,2,1).reshape(b, c, h, w)) ]
|
||||
x = x.sequential(ops if self.use_linear else ops[::-1])
|
||||
return x + x_in
|
||||
|
||||
class Downsample:
|
||||
def __init__(self, channels:int):
|
||||
self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return self.op(x)
|
||||
|
||||
class Upsample:
  def __init__(self, channels:int):
    self.conv = Conv2d(channels, channels, 3, padding=1)

  def __call__(self, x:Tensor) -> Tensor:
    bs,c,py,px = x.shape
    z = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
    return self.conv(z)

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/openaimodel.py#L472
class UNetModel:
  def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], num_res_blocks:int,
               channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None,
               n_heads:Optional[int]=None, num_groups:int=32, st_norm_eps:float=1e-5):
    self.model_ch = model_ch
    self.num_res_blocks = [num_res_blocks] * len(channel_mult)

    self.attention_resolutions = attention_resolutions
    self.d_head = d_head
    self.n_heads = n_heads
    def get_d_and_n_heads(dims:int) -> Tuple[int,int]:
      if self.d_head is None:
        assert self.n_heads is not None, f"d_head and n_heads cannot both be None"
        return dims // self.n_heads, self.n_heads
      else:
        assert self.n_heads is None, f"d_head and n_heads cannot both be non-None"
        return self.d_head, dims // self.d_head

    time_embed_dim = model_ch * 4
    self.time_embed = [
      Linear(model_ch, time_embed_dim),
      Tensor.silu,
      Linear(time_embed_dim, time_embed_dim),
    ]

    if adm_in_ch is not None:
      self.label_emb = [
        [
          Linear(adm_in_ch, time_embed_dim),
          Tensor.silu,
          Linear(time_embed_dim, time_embed_dim),
        ]
      ]

    self.input_blocks: List[Any] = [
      [Conv2d(in_ch, model_ch, 3, padding=1)]
    ]
    input_block_channels = [model_ch]
    ch = model_ch
    ds = 1
    for idx, mult in enumerate(channel_mult):
      for _ in range(self.num_res_blocks[idx]):
        layers: List[Any] = [
          ResBlock(ch, time_embed_dim, model_ch*mult, num_groups),
        ]
        ch = mult * model_ch
        if ds in attention_resolutions:
          d_head, n_heads = get_d_and_n_heads(ch)
          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))

        self.input_blocks.append(layers)
        input_block_channels.append(ch)

      if idx != len(channel_mult) - 1:
        self.input_blocks.append([
          Downsample(ch),
        ])
        input_block_channels.append(ch)
        ds *= 2

    d_head, n_heads = get_d_and_n_heads(ch)
    self.middle_block: List = [
      ResBlock(ch, time_embed_dim, ch, num_groups),
      SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1], norm_eps=st_norm_eps),
      ResBlock(ch, time_embed_dim, ch, num_groups),
    ]

    self.output_blocks = []
    for idx, mult in list(enumerate(channel_mult))[::-1]:
      for i in range(self.num_res_blocks[idx] + 1):
        ich = input_block_channels.pop()
        layers = [
          ResBlock(ch + ich, time_embed_dim, model_ch*mult, num_groups),
        ]
        ch = model_ch * mult

        if ds in attention_resolutions:
          d_head, n_heads = get_d_and_n_heads(ch)
          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))

        if idx > 0 and i == self.num_res_blocks[idx]:
          layers.append(Upsample(ch))
          ds //= 2
        self.output_blocks.append(layers)

    self.out = [
      GroupNorm(num_groups, ch),
      Tensor.silu,
      Conv2d(model_ch, out_ch, 3, padding=1),
    ]

  def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Optional[Tensor]=None) -> Tensor:
    t_emb = timestep_embedding(tms, self.model_ch)
    emb = t_emb.sequential(self.time_embed)

    if y is not None:
      assert y.shape[0] == x.shape[0]
      emb = emb + y.sequential(self.label_emb[0])

    if is_dtype_supported(mixed_precision_dtype):
      emb = emb.cast(mixed_precision_dtype)
      ctx = ctx.cast(mixed_precision_dtype)
      x   = x  .cast(mixed_precision_dtype)

    def run(x:Tensor, bb) -> Tensor:
      if isinstance(bb, ResBlock): x = bb(x, emb)
      elif isinstance(bb, SpatialTransformer): x = bb(x, ctx)
      else: x = bb(x)
      return x

    saved_inputs = []
    for b in self.input_blocks:
      for bb in b:
        x = run(x, bb)
      saved_inputs.append(x)
    for bb in self.middle_block:
      x = run(x, bb)
    for b in self.output_blocks:
      x = x.cat(saved_inputs.pop(), dim=1)
      for bb in b:
        x = run(x, bb)
    return x.sequential(self.out)
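The `Upsample` block above doubles spatial resolution without an interpolation kernel: a reshape/expand turns each pixel into a 2x2 block before the 3x3 conv smooths the result. A minimal NumPy sketch of that reshape/expand trick (illustrative only, not part of the vendored file):

```python
import numpy as np

def upsample2x(x: np.ndarray) -> np.ndarray:
    # (bs, c, py, px) -> (bs, c, py, 1, px, 1) -> broadcast to (bs, c, py, 2, px, 2)
    # -> (bs, c, 2*py, 2*px): nearest-neighbour 2x upsample, mirroring Upsample.__call__.
    bs, c, py, px = x.shape
    z = x.reshape(bs, c, py, 1, px, 1)
    z = np.broadcast_to(z, (bs, c, py, 2, px, 2))
    return z.reshape(bs, c, py * 2, px * 2)

x = np.arange(4, dtype=np.float32).reshape(1, 1, 2, 2)
print(upsample2x(x)[0, 0])
# each input pixel becomes a 2x2 block:
# [[0. 0. 1. 1.]
#  [0. 0. 1. 1.]
#  [2. 2. 3. 3.]
#  [2. 2. 3. 3.]]
```

tinygrad's `expand` plays the role of `np.broadcast_to` here; the repeated axes come for free as a broadcast, so no data is copied until the final reshape.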
274  backend/python/tinygrad/vendor/whisper.py  vendored  Normal file
@@ -0,0 +1,274 @@
# Adapted from tinygrad examples/whisper.py (MIT license).
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/whisper.py
# Copyright (c) 2023- the tinygrad authors
# SPDX-License-Identifier: MIT
#
# Local modifications: removed the pyaudio listener / __main__ block; the rest
# is the core Whisper model + preprocessing + single-file transcription path.
from __future__ import annotations

import base64
import collections
import itertools
from typing import List, Literal, Optional, Union

import numpy as np
from tinygrad import Tensor, TinyJit, Variable, dtypes, nn
from tinygrad.helpers import fetch
from tinygrad.nn.state import load_state_dict, torch_load

from .audio_helpers import mel


class MultiHeadAttention:
  def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self', None] = None, max_self_attn_cache_len=None):
    self.n_head = n_head
    self.query = nn.Linear(n_state, n_state)
    self.key = nn.Linear(n_state, n_state, bias=False)
    self.value = nn.Linear(n_state, n_state)
    self.out = nn.Linear(n_state, n_state)
    self.kv_caching = kv_caching
    self.max_self_attn_cache_len = max_self_attn_cache_len

  def __call__(self, x, xa=None, mask=None, len=None):
    if self.kv_caching == 'cross':
      if xa is not None:
        k, v = self.key(xa), self.value(xa)
        if not hasattr(self, 'cache_k'):
          self.cache_k, self.cache_v = k, v
        else:
          self.cache_k.assign(k).realize()
          self.cache_v.assign(v).realize()
      else:
        k, v = self.cache_k, self.cache_v
    else:
      k, v = self.key(x), self.value(x)
      if self.kv_caching == 'self':
        if not hasattr(self, 'cache_k'):
          self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
          self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
        k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
        v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
        padding = self.max_self_attn_cache_len - len - x.shape[1]
        self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
        self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()

    q = self.query(x)
    n_ctx = q.shape[1]
    head_dim = q.shape[-1] // self.n_head
    q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
    k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
    v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
    attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx, :n_ctx] if mask is not None else None)
    wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
    return self.out(wv)


class ResidualAttentionBlock:
  def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
    self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
    self.attn_ln = nn.LayerNorm(n_state)
    self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
    self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
    self.mlp = [nn.Linear(n_state, n_state * 4), Tensor.gelu, nn.Linear(n_state * 4, n_state)]
    self.mlp_ln = nn.LayerNorm(n_state)

  def __call__(self, x, xa=None, mask=None, len=None):
    x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
    if self.cross_attn:
      x = x + self.cross_attn(self.cross_attn_ln(x), xa)
    x = x + self.mlp_ln(x).sequential(self.mlp)
    return x.realize()


class AudioEncoder:
  def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
    self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
    self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1)
    self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
    self.ln_post = nn.LayerNorm(n_audio_state)
    self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
    self.encode = TinyJit(self.__call__)

  def __call__(self, x):
    x = self.conv1(x).gelu()
    x = self.conv2(x).gelu()
    x = x.permute(0, 2, 1)
    x = x + self.positional_embedding[:x.shape[1]]
    x = x.sequential(self.blocks)
    x = self.ln_post(x)
    return x.realize()


class TextDecoder:
  def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
    self.max_tokens_to_sample = n_text_ctx // 2
    self.max_self_attn_cache_len = n_text_ctx
    self.token_embedding = nn.Embedding(n_vocab, n_text_state)
    self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
    self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
    self.ln = nn.LayerNorm(n_text_state)
    self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
    self.getjitted = collections.defaultdict(lambda: TinyJit(self.forward))

  def __call__(self, x, pos, encoded_audio):
    pos = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len - 1).bind(pos) if pos else 0
    return self.getjitted[x.shape](x, pos, encoded_audio)

  def forward(self, x, pos, encoded_audio):
    seqlen = x.shape[-1]
    x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos + seqlen), None))
    for block in self.blocks:
      x = block(x, xa=encoded_audio, mask=self.mask, len=pos)
    return self.output_tok(x)

  def output_tok(self, x):
    return (self.ln(x) @ self.token_embedding.weight.T).realize()


class Whisper:
  def __init__(self, dims, batch_size=1):
    self.encoder = AudioEncoder(**dims)
    self.decoder = TextDecoder(**dims)
    self.is_multilingual = dims["n_vocab"] == 51865
    self.batch_size = batch_size


RATE = 16000
SEGMENT_SECONDS = 30
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH


def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate: bool = False) -> np.ndarray:
  import librosa

  def pad_or_trim(arr, target_len):
    if len(arr) == target_len:
      return arr
    if len(arr) < target_len:
      return np.pad(arr, (0, target_len - len(arr)), 'constant')
    return arr[:target_len]

  max_len = SAMPLES_PER_SEGMENT if truncate else max(len(w) for w in waveforms)
  if (r := max_len % SAMPLES_PER_SEGMENT) > 0:
    max_len += SAMPLES_PER_SEGMENT - r

  waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
  if waveforms.shape[0] < batch_size:
    waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))

  stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
  magnitudes = np.absolute(stft[..., :-1]) ** 2
  mel_spec = mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS).numpy() @ magnitudes
  log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
  log_spec = np.maximum(log_spec, log_spec.max((1, 2), keepdims=True) - 8.0)
  log_spec = (log_spec + 4.0) / 4.0
  return log_spec


LANGUAGES = {
  "en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean",
  "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish", "pl": "polish", "it": "italian",
}


def get_encoding(encoding_name: str):
  import tiktoken

  with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f:
    ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
  n_vocab = len(ranks)
  specials = [
    "<|endoftext|>",
    "<|startoftranscript|>",
    *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
    "<|translate|>",
    "<|transcribe|>",
    "<|startoflm|>",
    "<|startofprev|>",
    "<|nospeech|>",
    "<|notimestamps|>",
    *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
  ]
  special_tokens = dict(zip(specials, itertools.count(n_vocab)))
  return tiktoken.Encoding(
    name=encoding_name,
    explicit_n_vocab=n_vocab + len(specials),
    pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
    mergeable_ranks=ranks,
    special_tokens=special_tokens,
  )


MODEL_URLS = {
  "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
  "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
  "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
  "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
  "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
  "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
}


def init_whisper(model_name: str = "base", batch_size: int = 1):
  filename = fetch(MODEL_URLS[model_name])
  state = torch_load(filename)
  model = Whisper(state['dims'], batch_size)
  load_state_dict(model, state['model_state_dict'], strict=False)
  enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
  return model, enc


def load_file_waveform(filename: str):
  import librosa
  waveform, _ = librosa.load(filename, sr=RATE)
  return waveform


def transcribe_waveform(model: Whisper, enc, waveforms, language: Optional[str] = None, truncate: bool = False) -> str:
  log_spec = prep_audio(waveforms, model.batch_size, truncate)
  nsample = model.decoder.max_tokens_to_sample
  nctx = model.decoder.max_self_attn_cache_len

  start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
  if model.is_multilingual:
    lang = language if (language and language in LANGUAGES) else "en"
    language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index(lang)
    start_tokens.append(language_token)
    start_tokens.append(enc._special_tokens["<|transcribe|>"])
  start_tokens.append(enc._special_tokens["<|notimestamps|>"])

  eot = enc._special_tokens["<|endoftext|>"]

  def inferloop(ctx, encoded_audio):
    pos, next_tokens = 0, ctx
    for _ in range(nsample):
      next_tokens = model.decoder(Tensor(next_tokens, dtype=dtypes.int32), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1)
      next_tokens[ctx[:, -1] == eot] = eot
      ctx = np.concatenate((ctx, next_tokens), axis=1)
      pos = ctx.shape[-1] - 1
      if (next_tokens == eot).all() or pos == nctx:
        break
    return ctx

  ctx = np.tile(start_tokens, (model.batch_size, 1))
  transcriptions: list[list[int]] = [[] for _ in waveforms]

  for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
    encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
    ctx_arr = inferloop(np.array(ctx), encoded_audio)
    for i, arr in enumerate(ctx_arr):
      if i >= len(waveforms):
        break
      end_idxs = np.where(arr == eot)[0]
      start_idx = np.where(arr == start_tokens[-1])[0][0] + 1
      end_idx = end_idxs[0] if len(end_idxs) else None
      transcriptions[i].extend(arr[start_idx:end_idx])
    ctx = ctx_arr

  texts = [enc.decode([int(t) for t in toks]).strip() for toks in transcriptions]
  return texts[0] if len(texts) == 1 else "\n".join(texts)
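For orientation, the preprocessing constants in this file pin down a fixed segment geometry: 30 s of 16 kHz audio is 480 000 samples, and a 160-sample STFT hop yields 3 000 mel frames per segment, which is why `transcribe_waveform` steps through `log_spec` in `FRAMES_PER_SEGMENT` strides. A quick check of that arithmetic:

```python
# Segment geometry implied by the whisper.py constants (values copied from the file above).
RATE = 16000          # audio sample rate in Hz
SEGMENT_SECONDS = 30  # the encoder consumes fixed 30 s windows
HOP_LENGTH = 160      # STFT hop in samples

samples_per_segment = RATE * SEGMENT_SECONDS            # samples per 30 s window
frames_per_segment = samples_per_segment // HOP_LENGTH  # mel frames per window

print(samples_per_segment, frames_per_segment)  # → 480000 3000
```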
@@ -44,13 +44,20 @@ import (
 // BACKEND_TEST_AUDIO_FILE    Path to an already-available sample audio file.
 // BACKEND_TEST_CAPS          Comma-separated list of capabilities to exercise.
 //                            Supported values: health, load, predict, stream,
-//                            embeddings, tools, transcription.
+//                            embeddings, tools, transcription, image.
 //                            Defaults to "health,load,predict,stream".
 //                            A backend that only does embeddings would set this to
-//                            "health,load,embeddings"; an image/TTS backend that cannot
-//                            be driven by a text prompt can set it to "health,load".
+//                            "health,load,embeddings"; an image-generation backend
+//                            that cannot be driven by a text prompt can set it to
+//                            "health,load,image".
+//                            "tools" asks the backend to extract a tool call from the
+//                            model output into ChatDelta.tool_calls.
+//                            "image" exercises the GenerateImage RPC and asserts a
+//                            non-empty file is written to the requested dst path.
+// BACKEND_TEST_IMAGE_PROMPT  Override the positive prompt for the image spec
+//                            (default: "a photograph of an astronaut riding a horse").
+// BACKEND_TEST_IMAGE_STEPS   Override the diffusion step count for the image spec
+//                            (default: 4 — keeps CPU-only runs under a few minutes).
 // BACKEND_TEST_PROMPT        Override the prompt used by predict/stream specs.
 // BACKEND_TEST_CTX_SIZE      Override the context size passed to LoadModel (default 512).
 // BACKEND_TEST_THREADS       Override Threads passed to LoadModel (default 4).
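The capability list documented above is a plain comma-separated set read from `BACKEND_TEST_CAPS`, with a fall-back default. For illustration, the same parsing logic sketched in Python (a hypothetical mirror of the Go harness, not code from this PR):

```python
import os

def parse_caps(raw, default="health,load,predict,stream"):
    # An unset or empty env var falls back to the documented default capability set.
    value = raw if raw else default
    return {c.strip() for c in value.split(",") if c.strip()}

print(sorted(parse_caps(os.environ.get("BACKEND_TEST_CAPS"))))
print(sorted(parse_caps("health,load,image")))  # → ['health', 'image', 'load']
```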
@@ -75,11 +82,14 @@ const (
 	capEmbeddings    = "embeddings"
 	capTools         = "tools"
 	capTranscription = "transcription"
+	capImage         = "image"

-	defaultPrompt     = "The capital of France is"
-	streamPrompt      = "Once upon a time"
-	defaultToolPrompt = "What's the weather like in Paris, France?"
-	defaultToolName   = "get_weather"
+	defaultPrompt      = "The capital of France is"
+	streamPrompt       = "Once upon a time"
+	defaultToolPrompt  = "What's the weather like in Paris, France?"
+	defaultToolName    = "get_weather"
+	defaultImagePrompt = "a photograph of an astronaut riding a horse"
+	defaultImageSteps  = 4
 )

 func defaultCaps() map[string]bool {
@@ -349,6 +359,40 @@ var _ = Describe("Backend container", Ordered, func() {
 		GinkgoWriter.Printf("Embedding: %d dims\n", len(res.GetEmbeddings()))
 	})

+	It("generates an image via GenerateImage", func() {
+		if !caps[capImage] {
+			Skip("image capability not enabled")
+		}
+
+		imgPrompt := os.Getenv("BACKEND_TEST_IMAGE_PROMPT")
+		if imgPrompt == "" {
+			imgPrompt = defaultImagePrompt
+		}
+		steps := envInt32("BACKEND_TEST_IMAGE_STEPS", defaultImageSteps)
+
+		dst := filepath.Join(workDir, "generated.png")
+
+		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
+		defer cancel()
+		res, err := client.GenerateImage(ctx, &pb.GenerateImageRequest{
+			PositivePrompt: imgPrompt,
+			NegativePrompt: "",
+			Width:          512,
+			Height:         512,
+			Step:           steps,
+			Seed:           42,
+			Dst:            dst,
+		})
+		Expect(err).NotTo(HaveOccurred())
+		Expect(res.GetSuccess()).To(BeTrue(), "GenerateImage failed: %s", res.GetMessage())
+
+		info, err := os.Stat(dst)
+		Expect(err).NotTo(HaveOccurred(), "GenerateImage did not write a file at %s", dst)
+		Expect(info.Size()).To(BeNumerically(">", int64(0)),
+			"GenerateImage wrote an empty file at %s", dst)
+		GinkgoWriter.Printf("GenerateImage: wrote %s (%d bytes)\n", dst, info.Size())
+	})
+
 	It("extracts tool calls into ChatDelta", func() {
 		if !caps[capTools] {
 			Skip("tools capability not enabled")