LocalAI/backend/python/sglang/backend.py
Richard Palethorpe c894d9c826 feat(sglang): wire engine_args, add cuda13 build, ship MTP gallery demos (#9686)
Bring the sglang Python backend up to feature parity with vLLM by
adding the same engine_args: map plumbing the vLLM backend already
has. Any ServerArgs field (~380 in sglang 0.5.11) becomes settable
from a model YAML, including the speculative-decoding flags needed for
Multi-Token Prediction. Validation matches the vLLM backend's: keys
are checked against dataclasses.fields(ServerArgs), unknown keys raise
a ValueError with a difflib close-match suggestion at LoadModel time,
and the typed ModelOptions fields keep their existing meaning, with
engine_args values overriding them.
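
For illustration, a minimal sketch of what the plumbing does (the model
path and values are placeholders; the two keys are ServerArgs field
names used illustratively here, so check dataclasses.fields(ServerArgs)
on your install; importing backend.py needs sglang installed but does
not load an engine):

    import json
    from backend import BackendServicer  # backend/python/sglang/backend.py

    servicer = BackendServicer()
    kwargs = servicer._apply_engine_args(
        {"model_path": "some/model"},
        json.dumps({"speculative_algorithm": "NEXTN",
                    "mem_fraction_static": 0.7}),
    )
    # kwargs now carries both keys; a typo such as "mem_fraction_statc"
    # raises ValueError with a "did you mean 'mem_fraction_static'?" hint.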

Backend code:
* backend/python/sglang/backend.py: add _apply_engine_args, import
  dataclasses/difflib/ServerArgs, and call it from LoadModel; map the
  Seed option onto the renamed SamplingParams field (sampling_seed in
  sglang 0.5.11, seed in earlier 0.5.x), detected once at import time.
* backend/python/sglang/test.py + test.sh + Makefile: six unit tests
  exercising the helper directly (no engine load required).

Build / CI / backend gallery (cuda13 + l4t13 paths are now first-class):
* backend/python/sglang/install.sh: add --prerelease=allow because
  sglang 0.5.11 hard-pins flash-attn-4, which only ships beta wheels;
  add --index-strategy=unsafe-best-match for cublas12 so the cu128
  torch index wins over the cu130 wheels on default PyPI; add a
  pyproject.toml-driven l4t13 install path so [tool.uv.sources] can
  pin torch/torchvision/torchaudio/sglang to the jetson-ai-lab index
  without forcing every transitive PyPI dep through the L4T mirror's
  flaky proxy (mirrors the equivalent fix in
  backend/python/vllm/install.sh).
* backend/python/sglang/pyproject.toml (new): L4T project spec with
  explicit-source jetson-ai-lab index. Replaces requirements-l4t13.txt
  for the l4t13 BUILD_PROFILE; other profiles still go through the
  requirements-*.txt pipeline via libbackend.sh's installRequirements.
* backend/python/sglang/requirements-l4t13.txt: removed; superseded
  by pyproject.toml.
* backend/python/sglang/requirements-cublas{12,13}{,-after}.txt: pin
  sglang>=0.5.11 (Gemma 4 floor); add the cu130 torch index for
  cublas13 (new files) and the cu128 torch index for cublas12 (PyPI
  now ships cu130 torch wheels by default, which breaks cu12 hosts).
* backend/index.yaml: add cuda13-sglang and cuda13-sglang-development
  capability mappings + image entries pointing at
  quay.io/.../-gpu-nvidia-cuda-13-sglang.
* .github/workflows/backend.yml: new cublas13 sglang matrix entry,
  mirroring vllm's cuda13 build.

Model gallery + docs:
* gallery/sglang.yaml: base sglang config template, mirrors vllm.yaml.
* gallery/sglang-gemma-4-{e2b,e4b}-mtp.yaml: Gemma 4 MTP demos
  transcribed verbatim from the SGLang Gemma 4 cookbook MTP commands.
* gallery/sglang-mimo-7b-mtp.yaml: MiMo-7B-RL with built-in MTP heads
  + online fp8 weight quantization, verified end-to-end on a 16 GB
  RTX 5070 Ti at ~88 tok/s. Uses mem_fraction_static: 0.7 because the
  MTP draft worker's vocab embedding is loaded unquantized and OOMs
  the static reservation at sglang's 0.85 default.
* gallery/index.yaml: three new entries (gemma-4-e2b-it:sglang-mtp,
  gemma-4-e4b-it:sglang-mtp, mimo-7b-mtp:sglang).
* docs/content/features/text-generation.md: new SGLang section with
  setup, engine_args reference, MTP demos, version requirements.
* .agents/sglang-backend.md (new): agent one-pager covering the flat
  ServerArgs structure, the typed-vs-engine_args precedence, the
  speculative-decoding cheatsheet, and the mem_fraction_static gotcha
  documented above.
* AGENTS.md: index entry for the new agent doc.

Known limitation: the two Gemma 4 MTP gallery entries ship a recipe
that doesn't yet run on stock libraries. The drafter checkpoints
(google/gemma-4-{E2B,E4B}-it-assistant) declare
model_type: gemma4_assistant / Gemma4AssistantForCausalLM, which
neither transformers (<=5.6.0, including the SGLang cookbook's pinned
commit 91b1ab1f... and main HEAD) nor sglang's own model registry
(<=0.5.11) registers as of 2026-05-06. They will start working when
HF or sglang upstream registers the architecture -- no LocalAI
changes needed. The MiMo MTP demo and the non-MTP Gemma 4 paths work
today on this build (verified on RTX 5070 Ti, 16 GB).

Assisted-by: Claude:claude-opus-4-7 [Read] [Edit] [Bash] [WebFetch] [WebSearch]

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-05-07 17:27:29 +02:00

566 lines
23 KiB
Python

#!/usr/bin/env python3
"""LocalAI gRPC backend for sglang.
Wraps sglang's async Engine API behind the Backend gRPC contract defined
in backend.proto. Mirrors the structure of backend/python/vllm/backend.py
so that the two backends stay behavior-equivalent at the protocol level.
The streaming path applies sglang's per-request FunctionCallParser and
ReasoningParser so tool_calls and reasoning_content are emitted
incrementally inside ChatDelta, which is a capability sglang exposes
natively and vLLM does not.
Like the vLLM backend, this one accepts an arbitrary ``engine_args:``
map in the model YAML; keys are validated against ``ServerArgs`` fields
and forwarded to ``Engine(**kwargs)``. That covers speculative decoding
(EAGLE/EAGLE3/DFLASH/NGRAM/STANDALONE plus MTP via NEXTN), attention
backend selection, MoE knobs, hierarchical cache, and so on.
"""
import asyncio
from concurrent import futures
import argparse
import dataclasses
import difflib
import signal
import sys
import os
import json
import gc
import uuid
import base64
import io
from typing import Dict, List, Optional, Tuple
from PIL import Image
import backend_pb2
import backend_pb2_grpc
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
# sglang imports. Engine is the stable public entry point; parser modules
# are wrapped in try/except so older / leaner installs that omit them
# still load the backend for plain text generation.
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.server_args import ServerArgs
try:
    from sglang.srt.function_call.function_call_parser import FunctionCallParser
    # sglang's FunctionCallParser expects a list of pydantic Tool objects
    # (protocol.Tool with .function.name), not plain dicts. Wrap at the
    # request boundary to match.
    from sglang.srt.entrypoints.openai.protocol import Tool as SglTool
    HAS_TOOL_PARSERS = True
except Exception:
    FunctionCallParser = None  # type: ignore
    SglTool = None  # type: ignore
    HAS_TOOL_PARSERS = False
try:
    from sglang.srt.parser.reasoning_parser import ReasoningParser
    HAS_REASONING_PARSERS = True
except Exception:
    ReasoningParser = None  # type: ignore
    HAS_REASONING_PARSERS = False
try:
    from transformers import AutoTokenizer
    HAS_TRANSFORMERS = True
except Exception:
    AutoTokenizer = None  # type: ignore
    HAS_TRANSFORMERS = False
# sglang 0.5.11 renamed SamplingParams.seed -> sampling_seed (PR #21952).
# Earlier 0.5.x releases (e.g. 0.5.1.post2 — the wheel still pinned by the
# pypi.jetson-ai-lab.io sbsa/cu130 mirror used by the l4t13 build profile)
# accept only `seed`. Detect the supported keyword once at import time so
# both versions work without a hard pin floor.
try:
    import inspect as _inspect
    from sglang.srt.sampling.sampling_params import SamplingParams as _SamplingParams
    _SEED_KEY = "sampling_seed" if "sampling_seed" in _inspect.signature(_SamplingParams).parameters else "seed"
except Exception:
    _SEED_KEY = "sampling_seed"
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer implementing the Backend service for sglang."""

    def _parse_options(self, options_list) -> Dict[str, str]:
        opts: Dict[str, str] = {}
        for opt in options_list:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            opts[key.strip()] = value.strip()
        return opts
    def _apply_engine_args(self, engine_kwargs: dict, engine_args_json: str) -> dict:
        """Merge user-supplied engine_args (JSON object) into the kwargs dict
        that will be forwarded to ``sglang.Engine`` (which constructs a
        ``ServerArgs`` from them).

        Mirrors ``backend/python/vllm/backend.py::_apply_engine_args`` but
        operates on the kwargs dict because sglang's ``Engine.__init__``
        accepts ``**kwargs`` directly rather than a pre-built dataclass.
        Validation happens against ``ServerArgs`` fields so a typo fails
        early with a close-match suggestion instead of producing a confusing
        ``TypeError`` deep inside engine startup.
        """
        if not engine_args_json:
            return engine_kwargs
        try:
            extra = json.loads(engine_args_json)
        except json.JSONDecodeError as e:
            raise ValueError(f"engine_args is not valid JSON: {e}") from e
        if not isinstance(extra, dict):
            raise ValueError(
                f"engine_args must be a JSON object, got {type(extra).__name__}"
            )
        valid = {f.name for f in dataclasses.fields(ServerArgs)}
        for key in extra:
            if key not in valid:
                suggestion = difflib.get_close_matches(key, valid, n=1)
                hint = f" did you mean {suggestion[0]!r}?" if suggestion else ""
                raise ValueError(f"unknown engine_args key {key!r}.{hint}")
        engine_kwargs.update(extra)
        return engine_kwargs
    def _messages_to_dicts(self, messages) -> List[dict]:
        result: List[dict] = []
        for msg in messages:
            d = {"role": msg.role, "content": msg.content or ""}
            if msg.name:
                d["name"] = msg.name
            if msg.tool_call_id:
                d["tool_call_id"] = msg.tool_call_id
            if msg.reasoning_content:
                d["reasoning_content"] = msg.reasoning_content
            if msg.tool_calls:
                try:
                    d["tool_calls"] = json.loads(msg.tool_calls)
                except json.JSONDecodeError:
                    pass
            result.append(d)
        return result
    def Health(self, request, context):
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
    async def LoadModel(self, request, context):
        engine_kwargs = {"model_path": request.Model}
        if request.Quantization:
            engine_kwargs["quantization"] = request.Quantization
        if request.LoadFormat:
            engine_kwargs["load_format"] = request.LoadFormat
        if request.GPUMemoryUtilization:
            engine_kwargs["mem_fraction_static"] = float(request.GPUMemoryUtilization)
        if request.TrustRemoteCode:
            engine_kwargs["trust_remote_code"] = True
        if request.EnforceEager:
            engine_kwargs["disable_cuda_graph"] = True
        if request.TensorParallelSize:
            engine_kwargs["tp_size"] = int(request.TensorParallelSize)
        if request.MaxModelLen:
            engine_kwargs["context_length"] = int(request.MaxModelLen)
        if request.DType:
            engine_kwargs["dtype"] = request.DType
        opts = self._parse_options(request.Options)
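        # Options arrive as "key:value" strings; for example (parser names
        # here are illustrative, not a recommendation):
        #   ["tool_parser:qwen25", "reasoning_parser:deepseek-r1"]
        # parses to {"tool_parser": "qwen25", "reasoning_parser": "deepseek-r1"}.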
        # Cache parser names — actual parser instances are created per
        # request because sglang's parsers are stateful.
        self.tool_parser_name: Optional[str] = opts.get("tool_parser") or None
        self.reasoning_parser_name: Optional[str] = opts.get("reasoning_parser") or None
        # Also hand the parser names to sglang's engine so its HTTP/OAI
        # paths work identically if someone hits the engine directly.
        if self.tool_parser_name:
            engine_kwargs["tool_call_parser"] = self.tool_parser_name
        if self.reasoning_parser_name:
            engine_kwargs["reasoning_parser"] = self.reasoning_parser_name
        # engine_args from YAML overrides typed fields above so operators can
        # tune anything ServerArgs exposes (speculative decoding, attention
        # backend, MoE, hierarchical cache, …) without waiting on protobuf
        # changes.
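        # Precedence example: if the YAML sets GPUMemoryUtilization (mapped to
        # mem_fraction_static above) and engine_args also carries
        # mem_fraction_static, the engine_args value wins because the dict
        # update below runs last.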
        try:
            engine_kwargs = self._apply_engine_args(engine_kwargs, request.EngineArgs)
        except ValueError as err:
            print(f"engine_args error: {err}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=str(err))
        try:
            self.llm = Engine(**engine_kwargs)
        except Exception as err:
            print(f"sglang Engine init failed: {err!r}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"{err!r}")
        # sglang does not expose a uniform get_tokenizer() off Engine.
        # Use transformers directly — same path sglang uses internally.
        self.tokenizer = None
        if HAS_TRANSFORMERS:
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    request.Model,
                    trust_remote_code=bool(request.TrustRemoteCode),
                )
            except Exception as err:
                print(f"AutoTokenizer load failed (non-fatal): {err!r}", file=sys.stderr)
        print("Model loaded successfully", file=sys.stderr)
        return backend_pb2.Result(message="Model loaded successfully", success=True)
    async def Predict(self, request, context):
        gen = self._predict(request, context, streaming=False)
        res = await gen.__anext__()
        return res
    async def PredictStream(self, request, context):
        iterations = self._predict(request, context, streaming=True)
        try:
            async for iteration in iterations:
                yield iteration
        finally:
            try:
                await iterations.aclose()
            except Exception:
                pass
    async def TokenizeString(self, request, context):
        if not getattr(self, "tokenizer", None):
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("tokenizer not loaded")
            return backend_pb2.TokenizationResponse()
        try:
            tokens = self.tokenizer.encode(request.Prompt)
            return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
        except Exception as e:
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(e))
            return backend_pb2.TokenizationResponse()
    async def Free(self, request, context):
        try:
            if hasattr(self, "llm"):
                try:
                    self.llm.shutdown()
                except Exception:
                    pass
                del self.llm
            if hasattr(self, "tokenizer"):
                del self.tokenizer
            self.tool_parser_name = None
            self.reasoning_parser_name = None
            gc.collect()
            try:
                import torch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except ImportError:
                pass
            return backend_pb2.Result(success=True, message="Model freed")
        except Exception as e:
            return backend_pb2.Result(success=False, message=str(e))
    def _build_sampling_params(self, request) -> dict:
        sampling_params: dict = {"temperature": 0.7, "max_new_tokens": 200}
        mapping = {
            "N": "n",
            "PresencePenalty": "presence_penalty",
            "FrequencyPenalty": "frequency_penalty",
            "RepetitionPenalty": "repetition_penalty",
            "Temperature": "temperature",
            "TopP": "top_p",
            "TopK": "top_k",
            "MinP": "min_p",
            "Seed": _SEED_KEY,
            "StopPrompts": "stop",
            "StopTokenIds": "stop_token_ids",
            "IgnoreEOS": "ignore_eos",
            "Tokens": "max_new_tokens",
            "MinTokens": "min_new_tokens",
            "SkipSpecialTokens": "skip_special_tokens",
        }
        for proto_field, sgl_key in mapping.items():
            if not hasattr(request, proto_field):
                continue
            value = getattr(request, proto_field)
            if value in (None, 0, 0.0, [], False, ""):
                continue
            # repeated fields come back as RepeatedScalarContainer — convert
            if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
                value = list(value)
                if not value:
                    continue
            sampling_params[sgl_key] = value
        # Grammar → JSON schema or EBNF structured decoding.
        if getattr(request, "Grammar", ""):
            grammar = request.Grammar
            try:
                json.loads(grammar)
                sampling_params["json_schema"] = grammar
            except json.JSONDecodeError:
                sampling_params["ebnf"] = grammar
        return sampling_params
    def _build_prompt(self, request) -> str:
        prompt = request.Prompt
        if prompt or not request.UseTokenizerTemplate or not request.Messages:
            return prompt
        if self.tokenizer is None:
            print(
                "UseTokenizerTemplate requested but tokenizer not loaded; "
                "falling back to naive concatenation",
                file=sys.stderr,
            )
            return "\n".join(m.content or "" for m in request.Messages)
        messages_dicts = self._messages_to_dicts(request.Messages)
        template_kwargs: dict = {"tokenize": False, "add_generation_prompt": True}
        if request.Tools:
            try:
                template_kwargs["tools"] = json.loads(request.Tools)
            except json.JSONDecodeError:
                pass
        if request.Metadata.get("enable_thinking", "").lower() == "true":
            template_kwargs["enable_thinking"] = True
        try:
            return self.tokenizer.apply_chat_template(messages_dicts, **template_kwargs)
        except TypeError:
            return self.tokenizer.apply_chat_template(
                messages_dicts, tokenize=False, add_generation_prompt=True,
            )
    def _make_parsers(self, request):
        """Construct fresh per-request parser instances (stateful)."""
        tool_parser = None
        reasoning_parser = None
        if HAS_TOOL_PARSERS and self.tool_parser_name and request.Tools:
            try:
                tools_raw = json.loads(request.Tools)
                tools = [SglTool.model_validate(t) for t in tools_raw] if SglTool else tools_raw
                tool_parser = FunctionCallParser(
                    tools=tools, tool_call_parser=self.tool_parser_name,
                )
            except Exception as e:
                print(f"FunctionCallParser init failed: {e!r}", file=sys.stderr)
        if HAS_REASONING_PARSERS and self.reasoning_parser_name:
            try:
                reasoning_parser = ReasoningParser(
                    model_type=self.reasoning_parser_name,
                    stream_reasoning=True,
                )
            except Exception as e:
                print(f"ReasoningParser init failed: {e!r}", file=sys.stderr)
        return tool_parser, reasoning_parser
    async def _predict(self, request, context, streaming: bool = False):
        sampling_params = self._build_sampling_params(request)
        prompt = self._build_prompt(request)
        tool_parser, reasoning_parser = self._make_parsers(request)
        image_data = list(request.Images) if request.Images else None
        video_data = list(request.Videos) if request.Videos else None
        # Kick off streaming generation. We always use stream=True so the
        # non-stream path still gets parser coverage on the final text.
        try:
            iterator = await self.llm.async_generate(
                prompt=prompt,
                sampling_params=sampling_params,
                image_data=image_data,
                video_data=video_data,
                stream=True,
            )
        except Exception as e:
            print(f"sglang async_generate failed: {e!r}", file=sys.stderr)
            yield backend_pb2.Reply(message=bytes(f"error: {e!r}", "utf-8"))
            return
        generated_text = ""
        last_chunk: Optional[dict] = None
        # Track tool call ids once per (request, tool_index) to match the
        # OpenAI streaming contract (id sent on first chunk for that tool).
        tool_ids_seen: Dict[int, str] = {}
        try:
            async for chunk in iterator:
                last_chunk = chunk
                cumulative = chunk.get("text", "") if isinstance(chunk, dict) else ""
                delta_text = cumulative[len(generated_text):] if cumulative.startswith(generated_text) else cumulative
                generated_text = cumulative
                if not delta_text:
                    continue
                reasoning_delta = ""
                content_delta = delta_text
                if reasoning_parser is not None:
                    try:
                        r, n = reasoning_parser.parse_stream_chunk(delta_text)
                        reasoning_delta = r or ""
                        content_delta = n or ""
                    except Exception as e:
                        print(f"reasoning_parser.parse_stream_chunk: {e!r}", file=sys.stderr)
                tool_call_deltas: List[backend_pb2.ToolCallDelta] = []
                if tool_parser is not None and content_delta:
                    try:
                        normal_text, calls = tool_parser.parse_stream_chunk(content_delta)
                        content_delta = normal_text or ""
                        for tc in calls:
                            idx = int(getattr(tc, "tool_index", 0) or 0)
                            tc_id = tool_ids_seen.get(idx)
                            if tc_id is None:
                                tc_id = f"call_{uuid.uuid4().hex[:24]}"
                                tool_ids_seen[idx] = tc_id
                            tool_call_deltas.append(backend_pb2.ToolCallDelta(
                                index=idx,
                                id=tc_id,
                                name=getattr(tc, "name", "") or "",
                                arguments=getattr(tc, "parameters", "") or "",
                            ))
                    except Exception as e:
                        print(f"tool_parser.parse_stream_chunk: {e!r}", file=sys.stderr)
                if streaming and (content_delta or reasoning_delta or tool_call_deltas):
                    yield backend_pb2.Reply(
                        message=bytes(content_delta, "utf-8"),
                        chat_deltas=[backend_pb2.ChatDelta(
                            content=content_delta,
                            reasoning_content=reasoning_delta,
                            tool_calls=tool_call_deltas,
                        )],
                    )
        finally:
            try:
                await iterator.aclose()
            except Exception:
                pass
        # Extract token counts from the final chunk's meta_info.
        meta = {}
        if isinstance(last_chunk, dict):
            meta = last_chunk.get("meta_info") or {}
        prompt_tokens = int(meta.get("prompt_tokens", 0) or 0)
        completion_tokens = int(meta.get("completion_tokens", 0) or 0)
        # Non-streaming path: re-parse the full text with fresh parsers
        # so we return a clean, complete ChatDelta. Streaming parsers
        # used above have accumulated state we don't want to reuse.
        final_content = generated_text
        final_reasoning = ""
        final_tool_calls: List[backend_pb2.ToolCallDelta] = []
        if not streaming:
            final_reasoning_parser = None
            if HAS_REASONING_PARSERS and self.reasoning_parser_name:
                try:
                    final_reasoning_parser = ReasoningParser(
                        model_type=self.reasoning_parser_name,
                        stream_reasoning=False,
                    )
                except Exception:
                    final_reasoning_parser = None
            if final_reasoning_parser is not None:
                try:
                    r, n = final_reasoning_parser.parse_non_stream(generated_text)
                    final_reasoning = r or ""
                    final_content = n if n is not None else generated_text
                except Exception as e:
                    print(f"reasoning_parser.parse_non_stream: {e!r}", file=sys.stderr)
            if HAS_TOOL_PARSERS and self.tool_parser_name and request.Tools:
                try:
                    tools_raw = json.loads(request.Tools)
                    tools = [SglTool.model_validate(t) for t in tools_raw] if SglTool else tools_raw
                    fresh_tool_parser = FunctionCallParser(
                        tools=tools, tool_call_parser=self.tool_parser_name,
                    )
                    normal, calls = fresh_tool_parser.parse_non_stream(final_content)
                    if calls:
                        final_content = normal
                        for tc in calls:
                            idx = int(getattr(tc, "tool_index", 0) or 0)
                            final_tool_calls.append(backend_pb2.ToolCallDelta(
                                index=idx,
                                id=f"call_{uuid.uuid4().hex[:24]}",
                                name=getattr(tc, "name", "") or "",
                                arguments=getattr(tc, "parameters", "") or "",
                            ))
                except Exception as e:
                    print(f"tool_parser.parse_non_stream: {e!r}", file=sys.stderr)
        chat_delta = backend_pb2.ChatDelta(
            content=final_content if not streaming else "",
            reasoning_content=final_reasoning,
            tool_calls=final_tool_calls,
        )
        if streaming:
            yield backend_pb2.Reply(
                message=b"",
                prompt_tokens=prompt_tokens,
                tokens=completion_tokens,
                chat_deltas=[chat_delta],
            )
            return
        yield backend_pb2.Reply(
            message=bytes(final_content or "", "utf-8"),
            prompt_tokens=prompt_tokens,
            tokens=completion_tokens,
            chat_deltas=[chat_delta],
        )


async def serve(address):
    server = grpc.aio.server(
        migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
        ],
        interceptors=get_auth_interceptors(aio=True),
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    loop = asyncio.get_event_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5)))
    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    await server.wait_for_termination()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the sglang gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to.",
    )
    args = parser.parse_args()
    asyncio.run(serve(args.addr))
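# Local smoke test (assumes the environment created by this backend's
# install.sh is active; LoadModel still needs a reachable model and GPU):
#   python backend.py --addr localhost:50051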