refactor(tinygrad): reuse tinygrad.apps.llm instead of vendored Transformer (#9380)

Drop the 295-line vendor/llama.py fork in favor of `tinygrad.apps.llm`,
which now provides the Transformer blocks, GGUF loader (incl. Q4/Q6/Q8
quantization), KV-cache and generate loop we were maintaining ourselves.

What changed:
- New vendor/appsllm_adapter.py (~90 LOC) — HF -> GGUF-native state-dict
  keymap, Transformer kwargs builder, `_embed_hidden` helper, and a hard
  rejection of qkv_bias models (Qwen2 / 2.5 are no longer supported; the
  apps.llm Transformer ties `bias=False` on Q/K/V projections).
- backend.py routes both safetensors and GGUF paths through
  apps.llm.Transformer. Generation now delegates to its (greedy-only)
  `generate()`; Temperature / TopK / TopP / RepetitionPenalty are still
  accepted on the wire but ignored — documented in the module docstring.
- Jinja chat render now passes `enable_thinking=False` so Qwen3's
  reasoning preamble doesn't eat the tool-call token budget on small
  models.
- Embedding path uses `_embed_hidden` (block stack + output_norm) rather
  than the custom `embed()` method we were carrying on the vendored
  Transformer.
- test.py gains TestAppsLLMAdapter covering the keymap rename, tied
  embedding fallback, unknown-key skipping, and qkv_bias rejection.
- Makefile fixtures move from Qwen/Qwen2.5-0.5B-Instruct to Qwen/Qwen3-0.6B
  (apps.llm-compatible) and tool_parser from qwen3_xml to hermes (the
  HF chat template emits hermes-style JSON tool calls).

Verified with the docker-backed targets:
  test-extra-backend-tinygrad             5/5 PASS
  test-extra-backend-tinygrad-embeddings  3/3 PASS
  test-extra-backend-tinygrad-whisper     4/4 PASS
  test-extra-backend-tinygrad-sd          3/3 PASS
This commit is contained in:
Ettore Di Giacinto
2026-04-16 22:41:18 +02:00
committed by GitHub
parent b4e30692a2
commit a0cbc46be9
5 changed files with 342 additions and 430 deletions

View File

@@ -2,15 +2,28 @@
"""
LocalAI gRPC backend for tinygrad.
Multimodal scope (landing incrementally):
- LLM text generation (Llama 3 / Qwen 2.5 / Mistral via HF safetensors + GGUF)
- Native tool-call extraction via pluggable parsers (hermes, llama3_json, ...)
- Embeddings, Stable Diffusion, Whisper — planned; currently return UNIMPLEMENTED.
LLM execution is delegated to `tinygrad.apps.llm.Transformer` — we keep
only a thin HF → GGUF-name adapter (vendor/appsllm_adapter.py) for the
safetensors path; GGUF models load through `Transformer.from_gguf()`
with native Q4/Q6/Q8 support.
The heavy imports (tinygrad, tokenizers, vendor.llama) are deferred until
`LoadModel`, because tinygrad binds its compute device at import time from
env vars. `_select_tinygrad_device()` maps LocalAI's BUILD_TYPE onto the
corresponding tinygrad env flag before any import happens.
Scope:
- LLM text generation via apps.llm (Qwen3 / Qwen3.5 / Llama 3.x /
GLM-4 / OLMoE / Kimi-K2 / Moonlight — anything apps.llm supports).
- Native tool-call extraction via pluggable parsers (hermes,
llama3_json, qwen3_xml, mistral).
- Embeddings — mean-pooled last-hidden-state over the block stack.
- Stable Diffusion 1.x, Whisper — handled by the vendored paths.
Sampling is greedy-only because `apps.llm.Transformer.generate` (in the
tinygrad 0.12.0 PyPI release) ends with `.argmax(-1)` and takes no
temperature / top-k / top-p / repetition-penalty arguments. These
request fields are accepted and ignored.
The heavy imports (tinygrad, tokenizers, tinygrad.apps.llm) are deferred
until `LoadModel`, because tinygrad binds its compute device at import
time from env vars. `_select_tinygrad_device()` maps LocalAI's BUILD_TYPE
onto the corresponding tinygrad env flag before any import happens.
"""
from __future__ import annotations
@@ -62,8 +75,10 @@ def _select_tinygrad_device() -> None:
def _resolve_model_assets(model_ref: str) -> Path:
"""
Accept either a local path or a HuggingFace repo id (e.g.
"Qwen/Qwen2.5-0.5B-Instruct") and return the local directory / file.
HF ids are materialized via `huggingface_hub.snapshot_download`.
"unsloth/Qwen3.5-0.8B-GGUF") and return the local directory / file.
HF ids are materialized via `huggingface_hub.snapshot_download` — we
pull both safetensors (for fp16 HF repos) and GGUF (for quantized
repos) so the same code path handles either.
"""
p = Path(model_ref)
if p.exists():
@@ -80,23 +95,27 @@ def _resolve_model_assets(model_ref: str) -> Path:
"generation_config.json",
"*.safetensors",
"*.safetensors.index.json",
"*.gguf",
],
)
return Path(local)
raise FileNotFoundError(f"Model not found: {model_ref}")
def _load_llm_weights(model_dir: Path):
"""Load HF safetensors weights from a directory or a single file."""
from tinygrad import Device, Tensor, dtypes
from tinygrad.nn.state import safe_load, gguf_load
def _gguf_path(model_ref: Path) -> Optional[Path]:
"""Return the GGUF file to load from a path that may be a file or dir."""
if model_ref.is_file() and str(model_ref).endswith(".gguf"):
return model_ref
if model_ref.is_dir():
ggufs = sorted(model_ref.glob("*.gguf"))
if ggufs:
return ggufs[0]
return None
if model_dir.is_file():
if str(model_dir).endswith(".gguf"):
gguf_tensor = Tensor.empty(os.stat(model_dir).st_size, dtype=dtypes.uint8,
device=f"disk:{model_dir}").to(Device.DEFAULT)
return gguf_load(gguf_tensor)[1], "gguf"
return safe_load(str(model_dir)), "safetensors"
def _load_hf_safetensors(model_dir: Path) -> dict[str, Any]:
"""Load sharded or single-file HF safetensors from a directory."""
from tinygrad.nn.state import safe_load
index = model_dir / "model.safetensors.index.json"
if index.exists():
@@ -105,28 +124,13 @@ def _load_llm_weights(model_dir: Path):
shards: dict[str, Any] = {}
for shard_name in set(weight_map.values()):
shards[shard_name] = safe_load(str(model_dir / shard_name))
merged = {k: shards[n][k] for k, n in weight_map.items()}
return merged, "safetensors"
return {k: shards[n][k] for k, n in weight_map.items()}
single = model_dir / "model.safetensors"
if single.exists():
return safe_load(str(single)), "safetensors"
return safe_load(str(single))
ggufs = sorted(model_dir.glob("*.gguf"))
if ggufs:
gguf_tensor = Tensor.empty(os.stat(ggufs[0]).st_size, dtype=dtypes.uint8,
device=f"disk:{ggufs[0]}").to(Device.DEFAULT)
return gguf_load(gguf_tensor)[1], "gguf"
raise FileNotFoundError(f"No safetensors or gguf weights found under {model_dir}")
def _infer_qkv_bias(config: dict) -> bool:
"""Qwen2-family uses Q/K/V projection bias; Llama/Mistral do not."""
architectures = [a.lower() for a in config.get("architectures", [])]
if any("qwen" in a for a in architectures):
return True
return bool(config.get("attention_bias", False)) or bool(config.get("qkv_bias", False))
raise FileNotFoundError(f"No safetensors weights found under {model_dir}")
def _auto_tool_parser(model_ref: Optional[str], config: dict) -> Optional[str]:
@@ -252,69 +256,102 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
messages=self._messages_to_dicts(request.Messages),
tools=tools,
add_generation_prompt=True,
# Qwen3's chat template enables <think>...</think> reasoning
# by default. On small models (0.6B) that reasoning preamble
# eats the whole token budget before a tool call emerges, so
# we disable it. Templates that don't know this var ignore it.
enable_thinking=False,
)
# --------------------- LLM path -------------------------------------
def _load_llm(self, model_dir: Path) -> None:
config_path = model_dir / "config.json" if model_dir.is_dir() else None
if config_path is None or not config_path.exists():
raise FileNotFoundError(f"config.json not found under {model_dir}")
with open(config_path) as fp:
config = json.load(fp)
self.llm_config = config
def _load_llm(self, model_path: Path) -> None:
"""Load an LLM through `tinygrad.apps.llm.Transformer`.
from tinygrad import nn
Two paths:
- GGUF file (anywhere in the tree) → `Transformer.from_gguf()`
handles config, weight conversion (incl. Q4/Q6/Q8 quantization)
and RoPE permute natively.
- HF safetensors directory → build `TransformerConfig` from
config.json and load weights via a small HF→GGUF-name adapter.
"""
from tinygrad import Device, Tensor, dtypes
from tinygrad.apps.llm import Transformer
from tinygrad.nn.state import load_state_dict
from vendor.llama import Transformer, convert_from_huggingface, convert_from_gguf, fix_bf16
n_layers = config["num_hidden_layers"]
n_heads = config["num_attention_heads"]
n_kv_heads = config.get("num_key_value_heads", n_heads)
dim = config["hidden_size"]
hidden_dim = config["intermediate_size"]
vocab_size = config["vocab_size"]
norm_eps = config.get("rms_norm_eps", 1e-5)
rope_theta = config.get("rope_theta", 10000)
self.max_context = min(config.get("max_position_embeddings", 4096), 8192)
qkv_bias = _infer_qkv_bias(config)
weights, weight_format = _load_llm_weights(model_dir)
model = Transformer(
dim=dim,
hidden_dim=hidden_dim,
n_heads=n_heads,
n_layers=n_layers,
norm_eps=norm_eps,
vocab_size=vocab_size,
linear=nn.Linear,
embedding=nn.Embedding,
n_kv_heads=n_kv_heads,
rope_theta=rope_theta,
max_context=self.max_context,
jit=True,
qkv_bias=qkv_bias,
from vendor.appsllm_adapter import (
_hf_to_appsllm_state_dict,
_hf_to_transformer_kwargs,
)
if weight_format == "safetensors":
weights = convert_from_huggingface(weights, n_layers, n_heads, n_kv_heads)
max_context_cap = 8192
gguf_file = _gguf_path(model_path)
if gguf_file is not None:
# GGUF path: apps.llm handles everything — config, quant, RoPE.
gguf_tensor = Tensor.empty(
os.stat(gguf_file).st_size, dtype=dtypes.uint8,
device=f"disk:{gguf_file}",
).to(Device.DEFAULT)
model, kv = Transformer.from_gguf(gguf_tensor, max_context=max_context_cap)
self.llm_model = model
self.max_context = model.max_context
# Preserve a config-shaped dict for tool-parser heuristics and
# the "loaded" message.
arch = kv.get("general.architecture", "")
self.llm_config = {
"architectures": [kv.get("general.name", arch) or arch],
"gguf_kv": kv,
}
# Tokenizer: prefer sidecar tokenizer.json (richer HF Jinja2
# templates), fall back to apps.llm's SimpleTokenizer built
# from GGUF metadata.
self._load_tokenizer_for_dir(model_path if model_path.is_dir() else gguf_file.parent, gguf_kv=kv)
else:
weights = convert_from_gguf(weights, n_layers)
weights = fix_bf16(weights)
# HF safetensors path.
if not model_path.is_dir():
raise FileNotFoundError(f"Expected HF model directory, got file: {model_path}")
config_path = model_path / "config.json"
if not config_path.exists():
raise FileNotFoundError(f"config.json not found under {model_path}")
with open(config_path) as fp:
hf_config = json.load(fp)
self.llm_config = hf_config
from tinygrad.nn.state import load_state_dict
load_state_dict(model, weights, strict=False, consume=True)
self.llm_model = model
raw_weights = _load_hf_safetensors(model_path)
n_layers = hf_config["num_hidden_layers"]
state_dict = _hf_to_appsllm_state_dict(raw_weights, n_layers)
# Tokenizer
kwargs = _hf_to_transformer_kwargs(hf_config, state_dict, max_context_cap)
self.max_context = kwargs["max_context"]
model = Transformer(**kwargs)
load_state_dict(model, state_dict, strict=False, consume=True)
self.llm_model = model
self._load_tokenizer_for_dir(model_path, gguf_kv=None)
# Auto-pick tool parser from options or model family.
parser_name = self.options.get("tool_parser") or _auto_tool_parser(self.model_ref, self.llm_config)
self.tool_parser = resolve_parser(parser_name)
def _load_tokenizer_for_dir(self, model_dir: Path, gguf_kv: Optional[dict]) -> None:
"""Load HF tokenizer + chat template + EOS ids from a model directory.
Falls back to apps.llm's `SimpleTokenizer.from_gguf_kv` when there
is no `tokenizer.json` sidecar (single-file GGUF, no HF repo).
"""
tokenizer_json = model_dir / "tokenizer.json"
if not tokenizer_json.exists():
if tokenizer_json.exists():
from tokenizers import Tokenizer as HFTokenizer
self.llm_tokenizer = HFTokenizer.from_file(str(tokenizer_json))
elif gguf_kv is not None:
from tinygrad.apps.llm import SimpleTokenizer
self.llm_tokenizer = SimpleTokenizer.from_gguf_kv(gguf_kv)
else:
raise FileNotFoundError(f"tokenizer.json not found under {model_dir}")
from tokenizers import Tokenizer as HFTokenizer
self.llm_tokenizer = HFTokenizer.from_file(str(tokenizer_json))
# Chat template + EOS ids (from tokenizer_config.json + generation_config.json)
tok_cfg_path = model_dir / "tokenizer_config.json"
if tok_cfg_path.exists():
with open(tok_cfg_path) as fp:
@@ -322,26 +359,24 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.chat_template = tok_cfg.get("chat_template")
self.llm_eos_ids = []
gen_cfg_path = model_dir / "generation_config.json"
if gen_cfg_path.exists():
with open(gen_cfg_path) as fp:
gen_cfg = json.load(fp)
eos = gen_cfg.get("eos_token_id")
for cfg_name in ("generation_config.json", "config.json"):
cfg_path = model_dir / cfg_name
if not cfg_path.exists():
continue
with open(cfg_path) as fp:
cfg = json.load(fp)
eos = cfg.get("eos_token_id")
if isinstance(eos, list):
self.llm_eos_ids.extend(int(x) for x in eos)
elif isinstance(eos, int):
self.llm_eos_ids.append(eos)
if not self.llm_eos_ids:
eos = config.get("eos_token_id")
if isinstance(eos, list):
self.llm_eos_ids.extend(int(x) for x in eos)
elif isinstance(eos, int):
if self.llm_eos_ids:
break
if not self.llm_eos_ids and gguf_kv is not None:
eos = gguf_kv.get("tokenizer.ggml.eos_token_id")
if isinstance(eos, int):
self.llm_eos_ids.append(eos)
# Auto-pick tool parser from options or model family.
parser_name = self.options.get("tool_parser") or _auto_tool_parser(self.model_ref, config)
self.tool_parser = resolve_parser(parser_name)
# --------------------- Stable Diffusion path ------------------------
def _load_sd(self, model_ref: str) -> None:
@@ -400,28 +435,37 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# --------------------- LLM generation -------------------------------
def _generate_tokens(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int):
"""Synchronous generator yielding (token_id, token_text) pairs."""
from tinygrad import Tensor
def _encode_prompt(self, prompt: str) -> list[int]:
"""Normalize tokenizer output: HF `tokenizers.Tokenizer.encode()`
returns an `Encoding` with `.ids`; apps.llm's `SimpleTokenizer.encode()`
returns `list[int]` directly."""
encoded = self.llm_tokenizer.encode(prompt)
ids = encoded.ids
return list(getattr(encoded, "ids", encoded))
def _decode_tokens(self, ids: list[int]) -> str:
    """Decode token ids back to text.

    No normalization is needed here: both HF `tokenizers.Tokenizer` and
    apps.llm's `SimpleTokenizer` expose `decode(list[int]) -> str`.
    """
    return self.llm_tokenizer.decode(ids)
def _generate_tokens(self, prompt: str, max_new_tokens: int, temperature: float):
"""Yield (token_id, token_text) pairs using `apps.llm.Transformer.generate()`.
tinygrad 0.12.0's `generate()` is greedy-only (its `forward` ends
with `.argmax(-1)` and it takes no temperature / top-k / top-p
knobs). We accept `temperature` in the signature for API
compatibility but it is ignored.
"""
del temperature # tinygrad.apps.llm.Transformer.generate is greedy-only
ids = self._encode_prompt(prompt)
if not ids:
return
# Prefill: run one forward pass over the full prompt, sampling from the last token.
prompt_tensor = Tensor([ids])
next_tok = int(self.llm_model(prompt_tensor, 0, temperature=temperature, top_k=top_k, top_p=top_p).item())
pos = len(ids)
for _ in range(max_new_tokens):
count = 0
for next_tok in self.llm_model.generate(list(ids)):
if next_tok in self.llm_eos_ids:
break
text = self.llm_tokenizer.decode([next_tok])
yield next_tok, text
next_tok = int(self.llm_model(Tensor([[next_tok]]), pos, temperature=temperature,
top_k=top_k, top_p=top_p).item())
pos += 1
yield next_tok, self._decode_tokens([next_tok])
count += 1
if count >= max_new_tokens:
break
# --------------------- gRPC methods ---------------------------------
@@ -455,12 +499,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
)
model_path = _resolve_model_assets(self.model_ref)
if model_path.is_file():
return backend_pb2.Result(
success=False,
message=("tinygrad currently requires an HF model directory with config.json + "
"tokenizer.json; single-file GGUF support lands with gallery metadata wiring."),
)
self._load_llm(model_path)
return backend_pb2.Result(
@@ -483,13 +521,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
prompt = self._render_prompt(request)
max_new = request.Tokens if request.Tokens > 0 else 256
temperature = request.Temperature if request.Temperature > 0 else 0.7
top_p = request.TopP if request.TopP > 0 else 0.95
top_k = request.TopK if request.TopK > 0 else 0
t0 = time.monotonic()
pieces: list[str] = []
ntok = 0
for _, text in self._generate_tokens(prompt, max_new, temperature, top_p, top_k):
for _, text in self._generate_tokens(prompt, max_new, temperature):
pieces.append(text)
ntok += 1
elapsed = time.monotonic() - t0
@@ -534,11 +570,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
prompt = self._render_prompt(request)
max_new = request.Tokens if request.Tokens > 0 else 256
temperature = request.Temperature if request.Temperature > 0 else 0.7
top_p = request.TopP if request.TopP > 0 else 0.95
top_k = request.TopK if request.TopK > 0 else 0
buffer = ""
for _, text in self._generate_tokens(prompt, max_new, temperature, top_p, top_k):
for _, text in self._generate_tokens(prompt, max_new, temperature):
buffer += text
yield backend_pb2.Reply(
message=text.encode("utf-8"),
@@ -585,8 +619,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.EmbeddingResult()
from tinygrad import Tensor, dtypes
from vendor.appsllm_adapter import _embed_hidden
ids = self.llm_tokenizer.encode(text).ids
ids = self._encode_prompt(text)
if not ids:
return backend_pb2.EmbeddingResult(embeddings=[])
@@ -594,7 +629,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
ids = ids[: self.max_context]
tokens = Tensor([ids])
hidden = self.llm_model.embed(tokens) # (1, seqlen, dim)
hidden = _embed_hidden(self.llm_model, tokens) # (1, seqlen, dim)
# Mean pool over sequence dim
pooled = hidden.mean(axis=1).squeeze(0) # (dim,)
# L2 normalize

View File

@@ -22,6 +22,7 @@ import backend_pb2_grpc
sys.path.insert(0, os.path.dirname(__file__))
from tool_parsers.hermes import HermesToolParser # noqa: E402
from vendor.appsllm_adapter import _hf_to_appsllm_state_dict # noqa: E402
class TestHealth(unittest.TestCase):
@@ -80,5 +81,73 @@ class TestHermesParser(unittest.TestCase):
self.assertEqual(calls, [])
class TestAppsLLMAdapter(unittest.TestCase):
    """Smoke tests for the HF → tinygrad.apps.llm state-dict keymap."""

    #: per-layer HF weight name suffixes produced by dense Qwen3/Llama-shaped models
    _HF_LAYER_SUFFIXES = [
        "input_layernorm.weight",
        "post_attention_layernorm.weight",
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
        "self_attn.q_norm.weight",
        "self_attn.k_norm.weight",
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
    ]

    def _fake_hf_weights(self, n_layers: int = 2, include_lm_head: bool = True):
        """Build a minimal HF-shaped key set with sentinel values."""
        names = ["model.embed_tokens.weight", "model.norm.weight"]
        if include_lm_head:
            names.append("lm_head.weight")
        for idx in range(n_layers):
            names.extend(f"model.layers.{idx}.{suffix}" for suffix in self._HF_LAYER_SUFFIXES)
        # sentinel objects so we can verify identity-based aliasing
        return {name: object() for name in names}

    def test_keymap_renames_every_hf_key(self):
        sd = _hf_to_appsllm_state_dict(self._fake_hf_weights(n_layers=2), 2)
        expected = {"token_embd.weight", "output_norm.weight", "output.weight"}
        gguf_suffixes = [
            "attn_norm.weight", "ffn_norm.weight",
            "attn_q.weight", "attn_k.weight", "attn_v.weight",
            "attn_output.weight",
            "attn_q_norm.weight", "attn_k_norm.weight",
            "ffn_gate.weight", "ffn_up.weight", "ffn_down.weight",
        ]
        for idx in range(2):
            expected.update(f"blk.{idx}.{suffix}" for suffix in gguf_suffixes)
        self.assertEqual(set(sd.keys()), expected)

    def test_tied_embedding_fallback_when_lm_head_missing(self):
        sd = _hf_to_appsllm_state_dict(
            self._fake_hf_weights(n_layers=1, include_lm_head=False), 1)
        self.assertIn("output.weight", sd)
        # Fallback must alias (not copy) the token embedding.
        self.assertIs(sd["output.weight"], sd["token_embd.weight"])

    def test_unknown_keys_are_skipped(self):
        weights = self._fake_hf_weights(n_layers=1)
        weights["model.layers.0.self_attn.rotary_emb.inv_freq"] = object()
        weights["model.some_unknown.weight"] = object()
        sd = _hf_to_appsllm_state_dict(weights, 1)
        self.assertNotIn("model.some_unknown.weight", sd)
        # Renamed keys still present
        self.assertIn("blk.0.attn_q.weight", sd)

    def test_qkv_bias_models_rejected(self):
        weights = self._fake_hf_weights(n_layers=1)
        weights["model.layers.0.self_attn.q_proj.bias"] = object()
        with self.assertRaises(ValueError) as ctx:
            _hf_to_appsllm_state_dict(weights, 1)
        self.assertIn("Qwen3", str(ctx.exception))
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,102 @@
"""Glue code between LocalAI's HF-shaped model assets and tinygrad.apps.llm.
apps.llm's `Transformer` uses GGUF-native weight names and consumes a
`TransformerConfig` dataclass. LocalAI resolves models from HuggingFace
snapshots (HF safetensors + config.json) so we translate both sides here.
This module does NOT subclass anything from apps.llm. With the Qwen3+
scope the backend targets, we can use `apps.llm.Transformer` unchanged
(no qkv_bias, no RoPE permute). Everything below is a thin adapter.
"""
from __future__ import annotations
from typing import Any
def _hf_to_appsllm_state_dict(hf_weights: dict[str, Any], n_layers: int) -> dict[str, Any]:
"""Rename a HuggingFace-style state dict to the GGUF-native keys that
`tinygrad.apps.llm.Transformer` expects.
HF and apps.llm both store RoPE weights in half-split layout, so no
permute is required — only a direct key rename and a tied-embedding
fallback for models like Llama 3.2 that drop `lm_head.weight`.
"""
keymap: dict[str, str] = {
"model.embed_tokens.weight": "token_embd.weight",
"model.norm.weight": "output_norm.weight",
"lm_head.weight": "output.weight",
}
for layer in range(n_layers):
keymap[f"model.layers.{layer}.input_layernorm.weight"] = f"blk.{layer}.attn_norm.weight"
keymap[f"model.layers.{layer}.post_attention_layernorm.weight"] = f"blk.{layer}.ffn_norm.weight"
for hf_proj, gguf_proj in (("q", "q"), ("k", "k"), ("v", "v"), ("o", "output")):
keymap[f"model.layers.{layer}.self_attn.{hf_proj}_proj.weight"] = f"blk.{layer}.attn_{gguf_proj}.weight"
keymap[f"model.layers.{layer}.self_attn.q_norm.weight"] = f"blk.{layer}.attn_q_norm.weight"
keymap[f"model.layers.{layer}.self_attn.k_norm.weight"] = f"blk.{layer}.attn_k_norm.weight"
for hf_name, gguf_name in (("gate", "gate"), ("up", "up"), ("down", "down")):
keymap[f"model.layers.{layer}.mlp.{hf_name}_proj.weight"] = f"blk.{layer}.ffn_{gguf_name}.weight"
# Fail loudly if the model carries Q/K/V projection bias (Qwen2 / 2.5).
# apps.llm's `TransformerBlock` hardcodes `bias=False`, so these weights
# would be silently dropped by `load_state_dict(strict=False)` and the
# model would produce garbage. Supported families (Qwen3, Qwen3.5,
# Llama 3.x, GLM-4, Mistral) have no qkv bias.
bias_keys = [k for k in hf_weights
if k.startswith("model.layers.") and
any(k.endswith(f".self_attn.{p}_proj.bias") for p in ("q", "k", "v"))]
if bias_keys:
raise ValueError(
"tinygrad backend: model has Q/K/V projection bias ("
f"{bias_keys[0]} etc). Supported families are Qwen3, Qwen3.5, "
"Llama 3.x, GLM-4, Mistral. For Qwen2 / 2.5 please use a "
"newer model or the vLLM / llama.cpp backends."
)
sd = {dst: hf_weights[src] for src, dst in keymap.items() if src in hf_weights}
if "output.weight" not in sd and "token_embd.weight" in sd:
sd["output.weight"] = sd["token_embd.weight"]
return sd
def _hf_to_transformer_kwargs(hf_config: dict, state_dict: dict[str, Any], max_context: int) -> dict:
"""Build the kwargs dict for `tinygrad.apps.llm.Transformer(**kwargs)`.
Supports dense Qwen3 / Qwen3.5 / Llama 3.x / GLM-4 / Mistral-shaped
models. The tinygrad 0.12.0 `Transformer` takes keyword-only args (no
`TransformerConfig` dataclass) — so we return a plain dict.
"""
n_heads = hf_config["num_attention_heads"]
head_dim = hf_config.get("head_dim") or (hf_config["hidden_size"] // n_heads)
# Detect qk_norm presence from the GGUF-shaped state dict (matches
# apps.llm's own heuristic in `from_gguf`).
qk_norm = 0
qn = state_dict.get("blk.0.attn_q_norm.weight")
if qn is not None:
qk_norm = int(qn.shape[0])
max_pos = hf_config.get("max_position_embeddings", 4096)
return dict(
num_blocks=hf_config["num_hidden_layers"],
dim=hf_config["hidden_size"],
hidden_dim=hf_config["intermediate_size"],
n_heads=n_heads,
n_kv_heads=hf_config.get("num_key_value_heads", n_heads),
norm_eps=hf_config.get("rms_norm_eps", 1e-5),
vocab_size=hf_config["vocab_size"],
head_dim=head_dim,
rope_theta=float(hf_config.get("rope_theta", 10000.0)),
max_context=min(max_pos, max_context),
qk_norm=qk_norm,
)
def _embed_hidden(model, tokens):
"""Return mean-poolable hidden states by running the block stack
without going through the LM head + Gumbel-max sampler baked into
`Transformer.forward`."""
x = model.token_embd(tokens).float()
for blk in model.blk:
x = blk(x, 0)
return model.output_norm(x)

View File

@@ -1,294 +0,0 @@
# Vendored from tinygrad extra/models/llama.py (MIT license).
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/extra/models/llama.py
#
# Local modification: Attention / TransformerBlock / Transformer accept an
# optional `qkv_bias` flag so the same module can host Qwen2-style models that
# use bias on the Q/K/V projections (Llama 3 has no bias). Changes are marked
# with `# LOCALAI`.
#
# Copyright (c) 2023- the tinygrad authors
# SPDX-License-Identifier: MIT
from typing import Union, Optional, Any
import collections, math
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv, DEBUG
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
    """Precompute RoPE rotation factors as interleaved (cos, sin) pairs.

    Returns a tensor of shape (1, end, 1, dim//2, 2) — the singleton batch
    and head dims let it broadcast against the rank-5 views produced in
    `apply_rotary_emb`.
    """
    # Inverse frequencies: theta^(-2i/dim) for i in [0, dim//2).
    freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
    # Outer product position × frequency → per-position rotation angles.
    freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
    return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).reshape(1, end, 1, dim//2, 2)
def complex_mult(A, c, d):
    """Multiply complex values stored as trailing (real, imag) pairs in `A`
    by the complex number (c + d·i), keeping the same pair layout."""
    real, imag = A[..., 0:1], A[..., 1:2]
    out_real = real * c - imag * d
    out_imag = real * d + imag * c
    return out_real.cat(out_imag, dim=-1)
def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Apply rotary position embeddings to query and key tensors.

    The last dim of xq/xk is viewed as head_dim//2 complex (real, imag)
    pairs, rotated by the (cos, sin) factors in `freqs_cis`, then flattened
    back to the original head_dim.
    """
    assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
    # View the last dimension as (head_dim//2, 2) complex pairs.
    xq = xq.reshape(*xq.shape[0:-1], -1, 2)
    xk = xk.reshape(*xk.shape[0:-1], -1, 2)
    # Rank-5 after the reshape: inputs are assumed (bsz, seqlen, heads, head_dim).
    assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
    c, d = freqs_cis[..., 0:1], freqs_cis[..., 1:2]  # cos / sin components
    xq_out = complex_mult(xq, c, d)
    xk_out = complex_mult(xk, c, d)
    # Flatten the complex-pair dims back into head_dim.
    return xq_out.flatten(3), xk_out.flatten(3)
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
    """Tile KV heads `n_rep` times for grouped-query attention.

    (bs, seqlen, n_kv_heads, head_dim) → (bs, seqlen, n_kv_heads * n_rep,
    head_dim), with the copies of each KV head placed in consecutive slots.
    """
    bs, seqlen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # Repeating along the last axis then reshaping puts the n_rep copies of
    # each head next to each other in the head dimension.
    tiled = x.repeat((1, 1, 1, n_rep))
    return tiled.reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
class Attention:
    """Multi-head attention with optional grouped-query KV heads, RoPE, and
    a lazily-allocated persistent KV cache (used when `max_context > 0`)."""
    # LOCALAI: added qkv_bias
    def __init__(self, dim, n_heads, n_kv_heads=None, max_context=0, linear=nn.Linear, qk_norm: float | None = None, qkv_bias: bool = False):
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
        self.head_dim = dim // n_heads
        # Query heads per KV head (GQA repeat factor).
        self.n_rep = self.n_heads // self.n_kv_heads
        self.max_context = max_context
        # qkv_bias enables Qwen2-style bias on Q/K/V only; the output proj never has bias.
        self.wq = linear(dim, self.n_heads * self.head_dim, bias=qkv_bias)
        self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=qkv_bias)
        self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=qkv_bias)
        self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
        # NOTE(review): the q/k RMSNorm is built over the full `dim` and applied
        # before the per-head reshape — confirm this matches the per-head
        # qk-norm convention of the checkpoints this path serves.
        self.q_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
        self.k_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None

    def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)
        if self.q_norm is not None and self.k_norm is not None:
            xq = self.q_norm(xq)
            xk = self.k_norm(xk)
        if x.dtype == dtypes.bfloat16:
            xq, xk = xq.contiguous_backward(), xk.contiguous_backward()
        # Split into heads: (bsz, seqlen, n_heads|n_kv_heads, head_dim).
        xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
        xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
        xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
        bsz, seqlen, _, _ = xq.shape
        if self.max_context:
            # KV-cache path: allocate the cache on first use, then append the
            # new keys/values at start_pos and attend over the full prefix.
            if not hasattr(self, "cache_kv"):
                self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
                if isinstance(x.device, tuple):
                    # Multi-device: shard the cache across devices (by head when SHARD_KVCACHE is set).
                    self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
            assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
            self.cache_kv[:, :, start_pos:start_pos + seqlen, :, :].assign(Tensor.stack(xk, xv)).realize()
            keys = self.cache_kv[0, :, 0:start_pos + seqlen, :, :]
            values = self.cache_kv[1, :, 0:start_pos + seqlen, :, :]
        else:
            # Cache-free path only supports a single full-sequence pass.
            assert start_pos == 0
            keys, values = xk, xv
        if self.max_context:
            # Expand KV heads manually, then attend with the explicit causal mask.
            keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
            xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
            attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
        else:
            # No cache: let SDPA handle causality and GQA expansion itself.
            xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
            attn = xq.scaled_dot_product_attention(keys, values, is_causal=True, enable_gqa=True).transpose(1, 2)
        # Merge heads back into the model dim and project out.
        attn = attn.reshape(bsz, seqlen, -1)
        return self.wo(attn)
class FeedForward:
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
        # Attribute names w1/w2/w3 are state-dict keys — do not rename.
        self.w1 = linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = linear(hidden_dim, dim, bias=False)  # down projection
        self.w3 = linear(dim, hidden_dim, bias=False)  # up projection

    def __call__(self, x: Tensor) -> Tensor:
        gate = self.w1(x).silu()
        up = self.w3(x.contiguous_backward())
        return self.w2(gate * up)
class TransformerBlock:
    """Pre-norm transformer block: RMSNorm → attention → residual, then
    RMSNorm → feed-forward → residual."""
    # LOCALAI: added qkv_bias
    def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int,
                 linear=nn.Linear, feed_forward=FeedForward, qk_norm=None, qkv_bias: bool = False):
        self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear, qk_norm, qkv_bias=qkv_bias)
        self.feed_forward = feed_forward(dim, hidden_dim, linear)
        self.attention_norm = nn.RMSNorm(dim, norm_eps)  # pre-attention norm
        self.ffn_norm = nn.RMSNorm(dim, norm_eps)        # pre-feed-forward norm

    def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
        # Residual around attention, then residual around the MLP.
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        return (h + self.feed_forward(self.ffn_norm(h))).contiguous().contiguous_backward()
def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
    """Sample one token id from a 1-D logits tensor.

    temp: softmax temperature; values below 1e-6 short-circuit to argmax.
    k / p: top-k and top-p (nucleus) filtering; k == 0 disables the top-k path.
    af / ap: alpha-frequency / alpha-presence repetition penalties.

    NOTE(review): the repetition-penalty counter lives as a function
    attribute (`sample.alpha_counter`), so it persists across calls and is
    shared process-wide — confirm that is intended for multi-model use.
    """
    assert logits.ndim == 1, "only works on 1d tensors"
    assert 0 <= p <= 1, "p must be between 0 and 1"
    assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
    # Greedy path: deterministic, and leaves the penalty counter untouched.
    if temp < 1e-6:
        return logits.argmax()
    logits = logits.to(Device.DEFAULT)
    # Apply alpha (repetition) penalties based on previously sampled tokens.
    if af or ap:
        if not hasattr(sample, "alpha_counter"):
            setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
        logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0) * ap)
    # Replace NaNs with -inf so they can never be sampled.
    logits = (logits != logits).where(-float("inf"), logits)
    t = (logits / temp).softmax()
    counter = Tensor.arange(t.numel(), device=logits.device).contiguous()
    counter2 = Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
    if k:
        # Top-k: iteratively extract the k largest probabilities and their
        # indices (argmax-and-zero loop, keeping everything on-device).
        output = Tensor.zeros(k, device=logits.device).contiguous()
        output_indices = Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
        for i in range(k):
            t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
            output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
            output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
            t = (counter == t_argmax).where(0, t)
        # Top-p over the k candidates: zero out entries whose suffix
        # cumulative probability falls below 1 - p.
        output_cumsum = output[::-1].cumsum()[::-1] + t.sum()
        output = (output_cumsum >= (1 - p)) * output
        output_indices = (output_cumsum >= (1 - p)) * output_indices
        # Multinomial draw over the surviving candidates.
        output_idx = output.multinomial()
        output_token = output_indices[output_idx]
    else:
        output_token = t.multinomial()
    # Record the sampled token so future calls can penalize repeats.
    if af or ap:
        sample.alpha_counter = (counter == output_token).where(sample.alpha_counter + 1, sample.alpha_counter)
    return output_token
class Transformer:
  """llama-style decoder: token embedding -> n_layers TransformerBlocks -> RMSNorm -> LM head.

  forward() either samples one token for the last position (via sample()) or,
  when temperature is NaN, returns the raw logits. Steady-state single-token
  decode is optionally compiled with TinyJit, binding start_pos as a symbolic
  Variable so one compiled step serves every cache position.
  """
  # LOCALAI: added qkv_bias
  def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_layers: int, norm_eps: float, vocab_size,
               linear=nn.Linear, embedding=nn.Embedding, n_kv_heads=None, rope_theta=10000,
               max_context=1024, jit=True, feed_forward=FeedForward, qk_norm=None, disable_kv_cache=False,
               qkv_bias: bool = False):
    self.layers = [
      TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps,
                       0 if disable_kv_cache else max_context, linear,
                       feed_forward=feed_forward, qk_norm=qk_norm, qkv_bias=qkv_bias)
      for _ in range(n_layers)
    ]
    self.norm = nn.RMSNorm(dim, norm_eps)
    self.tok_embeddings = embedding(vocab_size, dim)
    # quantized checkpoints pass a custom `linear`; plain nn.Embedding implies an unquantized head
    self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False)
    self.max_context = max_context
    # rotary tables are precomputed for 2 * max_context positions up front
    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous().requires_grad_(False)
    self.forward_jit = TinyJit(self.forward) if jit else None
  def forward(self, tokens: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
    """One pass over `tokens` starting at KV-cache position `start_pos`.

    Returns the full logits when temperature is NaN, otherwise a sampled token
    id for the final position.
    """
    _bsz, seqlen = tokens.shape
    h = self.tok_embeddings(tokens).contiguous()
    freqs_cis = self.freqs_cis.cast(h.dtype)[:, start_pos:start_pos + seqlen, :, :, :]
    # a causal mask is only needed when attending over more than one new token
    if self.max_context != 0 and seqlen > 1:
      mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos + 1)
    else:
      mask = None
    for layer in self.layers:
      h = layer(h, start_pos, freqs_cis, mask)
    logits = self.output(self.norm(h).contiguous().contiguous_backward()).contiguous_backward()
    # NaN temperature is the sentinel for "return raw logits, don't sample"
    if math.isnan(temperature):
      return logits
    return sample(logits[:, -1, :].flatten(), temperature, top_k, top_p, alpha_f, alpha_p)
  def __call__(self, tokens: Tensor, start_pos: int, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
    # JIT only the steady-state single-token decode step; prefill (start_pos == 0)
    # and multi-token shapes take the uncompiled path.
    if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
      return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context - 1).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
    return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
  # LOCALAI: extract last hidden state for embeddings. Skips the LM head and
  # the causal-mask branch is left intact so the pooling sees the full sequence.
  def embed(self, tokens: Tensor) -> Tensor:
    """Return the final-norm hidden states for every position (LM head skipped).

    Always runs from position 0 with a fresh causal mask, independent of any
    KV-cache state used by forward().
    """
    _bsz, seqlen = tokens.shape
    h = self.tok_embeddings(tokens).contiguous()
    freqs_cis = self.freqs_cis.cast(h.dtype)[:, 0:seqlen, :, :, :]
    mask = Tensor.full((1, 1, seqlen, seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(1) if seqlen > 1 else None
    for layer in self.layers:
      h = layer(h, 0, freqs_cis, mask)
    return self.norm(h)
def convert_from_huggingface(weights: dict[str, Tensor], n_layers: int, n_heads: int, n_kv_heads: int, permute_layers: bool = True):
  """Translate a HuggingFace llama-style state dict into this module's names.

  Q/K projection weights (and q/k norms) are optionally un-permuted from HF's
  interleaved rotary layout; MoE expert tensors are gathered per layer and
  stacked along a new leading axis; rotary_emb buffers are dropped. When no
  lm_head weight exists, the token embedding is reused as the output head.
  Raises KeyError on any other unrecognized key.
  """
  def permute(v: Tensor, n_heads: int):
    # reverse HF's (heads, 2, half, in) interleaving back to this file's layout
    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1] if len(v.shape) > 1 else 1).transpose(1, 2).reshape(*v.shape[:2])
  # build the HF-name -> local-name table
  keymap = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
  }
  for l in range(n_layers):
    keymap[f"model.layers.{l}.input_layernorm.weight"] = f"layers.{l}.attention_norm.weight"
    keymap[f"model.layers.{l}.post_attention_layernorm.weight"] = f"layers.{l}.ffn_norm.weight"
    keymap[f"model.layers.{l}.mlp.gate.weight"] = f"layers.{l}.feed_forward.gate.weight"
    for x in ("q", "k"):
      keymap[f"model.layers.{l}.self_attn.{x}_norm.weight"] = f"layers.{l}.attention.{x}_norm.weight"
    for x in ("q", "k", "v", "o"):
      keymap[f"model.layers.{l}.self_attn.{x}_proj.weight"] = f"layers.{l}.attention.w{x}.weight"
      keymap[f"model.layers.{l}.self_attn.{x}_proj.bias"] = f"layers.{l}.attention.w{x}.bias"
    for hf_name, idx in (("gate", "1"), ("down", "2"), ("up", "3")):
      keymap[f"model.layers.{l}.mlp.{hf_name}_proj.weight"] = f"layers.{l}.feed_forward.w{idx}.weight"
  sd = {}
  experts = collections.defaultdict(dict)
  for hf_key, tensor in weights.items():
    # rotary tables are recomputed locally, never loaded
    if ".rotary_emb." in hf_key:
      continue
    tensor = tensor.to(Device.DEFAULT)
    if "model.layers" in hf_key:
      if permute_layers and ("q_proj" in hf_key or "q_norm" in hf_key):
        tensor = permute(tensor, n_heads)
      elif permute_layers and ("k_proj" in hf_key or "k_norm" in hf_key):
        tensor = permute(tensor, n_kv_heads)
      if '.mlp.experts.' in hf_key:
        # model.layers.{l}.mlp.experts.{e}.{proj}.weight -> bucket by layer/proj, index by expert id
        _, _, layer, _, _, expert, name, _ = hf_key.split('.')
        experts[f'layers.{layer}.feed_forward.{name}'][int(expert)] = tensor
        continue
    sd[keymap[hf_key]] = tensor
  # stack each expert bucket into one tensor, ordered by expert id
  for target, by_expert in experts.items():
    sd[target] = Tensor.stack(*[by_expert[i] for i in range(len(by_expert))])
  # tied embeddings fallback when the checkpoint has no separate lm_head
  if "output.weight" not in sd and "tok_embeddings.weight" in sd:
    sd["output.weight"] = sd["tok_embeddings.weight"]
  return sd
def convert_from_gguf(weights: dict[str, Tensor], n_layers: int):
  """Map GGUF tensor names onto this module's Transformer attribute names.

  Keys absent from the keymap are silently dropped. If the checkpoint carries
  no dedicated output head, the token embedding is reused (tied weights).
  """
  keymap = {
    "token_embd.weight": "tok_embeddings.weight",
    **{f"blk.{l}.attn_norm.weight": f"layers.{l}.attention_norm.weight" for l in range(n_layers)},
    **{f"blk.{l}.attn_{x}.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v"] for l in range(n_layers)},
    **{f"blk.{l}.attn_{x}.bias": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v"] for l in range(n_layers)},
    **{f"blk.{l}.attn_output.weight": f"layers.{l}.attention.wo.weight" for l in range(n_layers)},
    **{f"blk.{l}.ffn_norm.weight": f"layers.{l}.ffn_norm.weight" for l in range(n_layers)},
    **{f"blk.{l}.ffn_{x}.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(n_layers)},
    "output_norm.weight": "norm.weight",
    # BUGFIX: pass a dedicated LM head through. Without this entry a real
    # output.weight was dropped by the filter below and the tied-embedding
    # fallback always fired (no keymap value was ever "output.weight").
    "output.weight": "output.weight",
    "rope_freqs.weight": "rope_freqs.weight",
  }
  sd = {keymap[k]: v for k, v in weights.items() if k in keymap}
  # tied embeddings: reuse token_embd when the file stored no output head
  if "output.weight" not in sd and "token_embd.weight" in weights:
    sd["output.weight"] = weights["token_embd.weight"]
  return sd
def fix_bf16(weights: dict[Any, Tensor]):
  """Cast every bfloat16 tensor to float16 (via a float32 round-trip);
  tensors of any other dtype pass through unchanged."""
  fixed = {}
  for key, tensor in weights.items():
    if tensor.dtype == dtypes.bfloat16:
      tensor = tensor.cast(dtypes.float32).cast(dtypes.float16)
    fixed[key] = tensor
  return fixed