#!/usr/bin/env python3 """LocalAI gRPC backend for sglang. Wraps sglang's async Engine API behind the Backend gRPC contract defined in backend.proto. Mirrors the structure of backend/python/vllm/backend.py so that the two backends stay behavior-equivalent at the protocol level. The streaming path applies sglang's per-request FunctionCallParser and ReasoningParser so tool_calls and reasoning_content are emitted incrementally inside ChatDelta, which is a capability sglang exposes natively and vLLM does not. Like the vLLM backend, this one accepts an arbitrary ``engine_args:`` map in the model YAML; keys are validated against ``ServerArgs`` fields and forwarded to ``Engine(**kwargs)``. That covers speculative decoding (EAGLE/EAGLE3/DFLASH/NGRAM/STANDALONE plus MTP via NEXTN), attention backend selection, MoE knobs, hierarchical cache, and so on. """ import asyncio from concurrent import futures import argparse import dataclasses import difflib import signal import sys import os import json import gc import uuid import base64 import io from typing import Dict, List, Optional, Tuple from PIL import Image import backend_pb2 import backend_pb2_grpc import grpc sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) from grpc_auth import get_auth_interceptors # sglang imports. Engine is the stable public entry point; parser modules # are wrapped in try/except so older / leaner installs that omit them # still load the backend for plain text generation. from sglang.srt.entrypoints.engine import Engine from sglang.srt.server_args import ServerArgs try: from sglang.srt.function_call.function_call_parser import FunctionCallParser # sglang's FunctionCallParser expects a list of pydantic Tool objects # (protocol.Tool with .function.name), not plain dicts. Wrap at the # request boundary to match. from sglang.srt.entrypoints.openai.protocol import Tool as SglTool HAS_TOOL_PARSERS = True except Exception: FunctionCallParser = None # type: ignore SglTool = None # type: ignore HAS_TOOL_PARSERS = False try: from sglang.srt.parser.reasoning_parser import ReasoningParser HAS_REASONING_PARSERS = True except Exception: ReasoningParser = None # type: ignore HAS_REASONING_PARSERS = False try: from transformers import AutoTokenizer HAS_TRANSFORMERS = True except Exception: AutoTokenizer = None # type: ignore HAS_TRANSFORMERS = False # sglang 0.5.11 renamed SamplingParams.seed -> sampling_seed (PR #21952). # Earlier 0.5.x releases (e.g. 0.5.1.post2 — the wheel still pinned by the # pypi.jetson-ai-lab.io sbsa/cu130 mirror used by the l4t13 build profile) # accept only `seed`. Detect the supported keyword once at import time so # both versions work without a hard pin floor. 
try:
    import inspect as _inspect
    from sglang.srt.sampling.sampling_params import SamplingParams as _SamplingParams

    _SEED_KEY = (
        "sampling_seed"
        if "sampling_seed" in _inspect.signature(_SamplingParams).parameters
        else "seed"
    )
except Exception:
    _SEED_KEY = "sampling_seed"

_ONE_DAY_IN_SECONDS = 60 * 60 * 24
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer implementing the Backend service for sglang."""

    def _parse_options(self, options_list) -> Dict[str, str]:
        opts: Dict[str, str] = {}
        for opt in options_list:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            opts[key.strip()] = value.strip()
        return opts

    def _apply_engine_args(self, engine_kwargs: dict, engine_args_json: str) -> dict:
        """Merge user-supplied engine_args (JSON object) into the kwargs dict
        that will be forwarded to ``sglang.Engine`` (which constructs a
        ``ServerArgs`` from them).

        Mirrors ``backend/python/vllm/backend.py::_apply_engine_args`` but
        operates on the kwargs dict because sglang's ``Engine.__init__``
        accepts ``**kwargs`` directly rather than a pre-built dataclass.
        Validation happens against ``ServerArgs`` fields so a typo fails
        early with a close-match suggestion instead of producing a confusing
        ``TypeError`` deep inside engine startup.
        """
        if not engine_args_json:
            return engine_kwargs
        try:
            extra = json.loads(engine_args_json)
        except json.JSONDecodeError as e:
            raise ValueError(f"engine_args is not valid JSON: {e}") from e
        if not isinstance(extra, dict):
            raise ValueError(
                f"engine_args must be a JSON object, got {type(extra).__name__}"
            )
        valid = {f.name for f in dataclasses.fields(ServerArgs)}
        for key in extra:
            if key not in valid:
                suggestion = difflib.get_close_matches(key, valid, n=1)
                hint = f" did you mean {suggestion[0]!r}?" if suggestion else ""
                raise ValueError(f"unknown engine_args key {key!r}.{hint}")
        engine_kwargs.update(extra)
        return engine_kwargs

    def _messages_to_dicts(self, messages) -> List[dict]:
        result: List[dict] = []
        for msg in messages:
            d = {"role": msg.role, "content": msg.content or ""}
            if msg.name:
                d["name"] = msg.name
            if msg.tool_call_id:
                d["tool_call_id"] = msg.tool_call_id
            if msg.reasoning_content:
                d["reasoning_content"] = msg.reasoning_content
            if msg.tool_calls:
                try:
                    d["tool_calls"] = json.loads(msg.tool_calls)
                except json.JSONDecodeError:
                    pass
            result.append(d)
        return result

    def Health(self, request, context):
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        engine_kwargs = {"model_path": request.Model}

        if request.Quantization:
            engine_kwargs["quantization"] = request.Quantization
        if request.LoadFormat:
            engine_kwargs["load_format"] = request.LoadFormat
        if request.GPUMemoryUtilization:
            engine_kwargs["mem_fraction_static"] = float(request.GPUMemoryUtilization)
        if request.TrustRemoteCode:
            engine_kwargs["trust_remote_code"] = True
        if request.EnforceEager:
            engine_kwargs["disable_cuda_graph"] = True
        if request.TensorParallelSize:
            engine_kwargs["tp_size"] = int(request.TensorParallelSize)
        if request.MaxModelLen:
            engine_kwargs["context_length"] = int(request.MaxModelLen)
        if request.DType:
            engine_kwargs["dtype"] = request.DType

        opts = self._parse_options(request.Options)

        # Cache parser names — actual parser instances are created per
        # request because sglang's parsers are stateful.
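        # Illustrative values only (the exact parser names depend on what the
        # installed sglang release registers); a model YAML might set e.g.
        #   options:
        #     - tool_parser:hermes
        #     - reasoning_parser:deepseek-r1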
        self.tool_parser_name: Optional[str] = opts.get("tool_parser") or None
        self.reasoning_parser_name: Optional[str] = opts.get("reasoning_parser") or None

        # Also hand the parser names to sglang's engine so its HTTP/OAI
        # paths work identically if someone hits the engine directly.
        if self.tool_parser_name:
            engine_kwargs["tool_call_parser"] = self.tool_parser_name
        if self.reasoning_parser_name:
            engine_kwargs["reasoning_parser"] = self.reasoning_parser_name

        # engine_args from YAML overrides typed fields above so operators can
        # tune anything ServerArgs exposes (speculative decoding, attention
        # backend, MoE, hierarchical cache, …) without waiting on protobuf
        # changes.
        try:
            engine_kwargs = self._apply_engine_args(engine_kwargs, request.EngineArgs)
        except ValueError as err:
            print(f"engine_args error: {err}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=str(err))

        try:
            self.llm = Engine(**engine_kwargs)
        except Exception as err:
            print(f"sglang Engine init failed: {err!r}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"{err!r}")

        # sglang does not expose a uniform get_tokenizer() off Engine.
        # Use transformers directly — same path sglang uses internally.
        self.tokenizer = None
        if HAS_TRANSFORMERS:
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    request.Model,
                    trust_remote_code=bool(request.TrustRemoteCode),
                )
            except Exception as err:
                print(f"AutoTokenizer load failed (non-fatal): {err!r}", file=sys.stderr)

        print("Model loaded successfully", file=sys.stderr)
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    async def Predict(self, request, context):
        gen = self._predict(request, context, streaming=False)
        res = await gen.__anext__()
        return res

    async def PredictStream(self, request, context):
        iterations = self._predict(request, context, streaming=True)
        try:
            async for iteration in iterations:
                yield iteration
        finally:
            try:
                await iterations.aclose()
            except Exception:
                pass

    async def TokenizeString(self, request, context):
        if not getattr(self, "tokenizer", None):
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("tokenizer not loaded")
            return backend_pb2.TokenizationResponse()
        try:
            tokens = self.tokenizer.encode(request.Prompt)
            return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
        except Exception as e:
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(e))
            return backend_pb2.TokenizationResponse()

    async def Free(self, request, context):
        try:
            if hasattr(self, "llm"):
                try:
                    self.llm.shutdown()
                except Exception:
                    pass
                del self.llm
            if hasattr(self, "tokenizer"):
                del self.tokenizer
            self.tool_parser_name = None
            self.reasoning_parser_name = None
            gc.collect()
            try:
                import torch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except ImportError:
                pass
            return backend_pb2.Result(success=True, message="Model freed")
        except Exception as e:
            return backend_pb2.Result(success=False, message=str(e))

    def _build_sampling_params(self, request) -> dict:
        sampling_params: dict = {"temperature": 0.7, "max_new_tokens": 200}

        mapping = {
            "N": "n",
            "PresencePenalty": "presence_penalty",
            "FrequencyPenalty": "frequency_penalty",
            "RepetitionPenalty": "repetition_penalty",
            "Temperature": "temperature",
            "TopP": "top_p",
            "TopK": "top_k",
            "MinP": "min_p",
            "Seed": _SEED_KEY,
            "StopPrompts": "stop",
            "StopTokenIds": "stop_token_ids",
            "IgnoreEOS": "ignore_eos",
            "Tokens": "max_new_tokens",
            "MinTokens": "min_new_tokens",
            "SkipSpecialTokens": "skip_special_tokens",
        }
        for proto_field, sgl_key in mapping.items():
            if not hasattr(request, proto_field):
                continue
            value = getattr(request, proto_field)
            if value in (None, 0, 0.0, [], False, ""):
                continue
            # repeated fields come back as RepeatedScalarContainer — convert
            if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
                value = list(value)
                if not value:
                    continue
            sampling_params[sgl_key] = value

        # Grammar → JSON schema or EBNF structured decoding.
        if getattr(request, "Grammar", ""):
            grammar = request.Grammar
            try:
                json.loads(grammar)
                sampling_params["json_schema"] = grammar
            except json.JSONDecodeError:
                sampling_params["ebnf"] = grammar

        return sampling_params

    def _build_prompt(self, request) -> str:
        prompt = request.Prompt
        if prompt or not request.UseTokenizerTemplate or not request.Messages:
            return prompt
        if self.tokenizer is None:
            print(
                "UseTokenizerTemplate requested but tokenizer not loaded; "
                "falling back to naive concatenation",
                file=sys.stderr,
            )
            return "\n".join(m.content or "" for m in request.Messages)

        messages_dicts = self._messages_to_dicts(request.Messages)
        template_kwargs: dict = {"tokenize": False, "add_generation_prompt": True}
        if request.Tools:
            try:
                template_kwargs["tools"] = json.loads(request.Tools)
            except json.JSONDecodeError:
                pass
        if request.Metadata.get("enable_thinking", "").lower() == "true":
            template_kwargs["enable_thinking"] = True
        try:
            return self.tokenizer.apply_chat_template(messages_dicts, **template_kwargs)
        except TypeError:
            return self.tokenizer.apply_chat_template(
                messages_dicts,
                tokenize=False,
                add_generation_prompt=True,
            )

    def _make_parsers(self, request):
        """Construct fresh per-request parser instances (stateful)."""
        tool_parser = None
        reasoning_parser = None
        if HAS_TOOL_PARSERS and self.tool_parser_name and request.Tools:
            try:
                tools_raw = json.loads(request.Tools)
                tools = [SglTool.model_validate(t) for t in tools_raw] if SglTool else tools_raw
                tool_parser = FunctionCallParser(
                    tools=tools,
                    tool_call_parser=self.tool_parser_name,
                )
            except Exception as e:
                print(f"FunctionCallParser init failed: {e!r}", file=sys.stderr)
        if HAS_REASONING_PARSERS and self.reasoning_parser_name:
            try:
                reasoning_parser = ReasoningParser(
                    model_type=self.reasoning_parser_name,
                    stream_reasoning=True,
                )
            except Exception as e:
                print(f"ReasoningParser init failed: {e!r}", file=sys.stderr)
        return tool_parser, reasoning_parser

    async def _predict(self, request, context, streaming: bool = False):
        sampling_params = self._build_sampling_params(request)
        prompt = self._build_prompt(request)
        tool_parser, reasoning_parser = self._make_parsers(request)

        image_data = list(request.Images) if request.Images else None
        video_data = list(request.Videos) if request.Videos else None

        # Kick off streaming generation. We always use stream=True so the
        # non-stream path still gets parser coverage on the final text.
        try:
            iterator = await self.llm.async_generate(
                prompt=prompt,
                sampling_params=sampling_params,
                image_data=image_data,
                video_data=video_data,
                stream=True,
            )
        except Exception as e:
            print(f"sglang async_generate failed: {e!r}", file=sys.stderr)
            yield backend_pb2.Reply(message=bytes(f"error: {e!r}", "utf-8"))
            return

        generated_text = ""
        last_chunk: Optional[dict] = None
        # Track tool call ids once per (request, tool_index) to match the
        # OpenAI streaming contract (id sent on first chunk for that tool).
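        # Under that contract the first delta for a given index carries the
        # generated id (and usually the function name); later deltas for the
        # same index carry only argument fragments, which the caller
        # concatenates to rebuild the full arguments JSON.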
        tool_ids_seen: Dict[int, str] = {}

        try:
            async for chunk in iterator:
                last_chunk = chunk
                cumulative = chunk.get("text", "") if isinstance(chunk, dict) else ""
                delta_text = (
                    cumulative[len(generated_text):]
                    if cumulative.startswith(generated_text)
                    else cumulative
                )
                generated_text = cumulative
                if not delta_text:
                    continue

                reasoning_delta = ""
                content_delta = delta_text
                if reasoning_parser is not None:
                    try:
                        r, n = reasoning_parser.parse_stream_chunk(delta_text)
                        reasoning_delta = r or ""
                        content_delta = n or ""
                    except Exception as e:
                        print(f"reasoning_parser.parse_stream_chunk: {e!r}", file=sys.stderr)

                tool_call_deltas: List[backend_pb2.ToolCallDelta] = []
                if tool_parser is not None and content_delta:
                    try:
                        normal_text, calls = tool_parser.parse_stream_chunk(content_delta)
                        content_delta = normal_text or ""
                        for tc in calls:
                            idx = int(getattr(tc, "tool_index", 0) or 0)
                            tc_id = tool_ids_seen.get(idx)
                            if tc_id is None:
                                tc_id = f"call_{uuid.uuid4().hex[:24]}"
                                tool_ids_seen[idx] = tc_id
                            tool_call_deltas.append(backend_pb2.ToolCallDelta(
                                index=idx,
                                id=tc_id,
                                name=getattr(tc, "name", "") or "",
                                arguments=getattr(tc, "parameters", "") or "",
                            ))
                    except Exception as e:
                        print(f"tool_parser.parse_stream_chunk: {e!r}", file=sys.stderr)

                if streaming and (content_delta or reasoning_delta or tool_call_deltas):
                    yield backend_pb2.Reply(
                        message=bytes(content_delta, "utf-8"),
                        chat_deltas=[backend_pb2.ChatDelta(
                            content=content_delta,
                            reasoning_content=reasoning_delta,
                            tool_calls=tool_call_deltas,
                        )],
                    )
        finally:
            try:
                await iterator.aclose()
            except Exception:
                pass

        # Extract token counts from the final chunk's meta_info.
        meta = {}
        if isinstance(last_chunk, dict):
            meta = last_chunk.get("meta_info") or {}
        prompt_tokens = int(meta.get("prompt_tokens", 0) or 0)
        completion_tokens = int(meta.get("completion_tokens", 0) or 0)

        # Non-streaming path: re-parse the full text with fresh parsers
        # so we return a clean, complete ChatDelta. Streaming parsers
        # used above have accumulated state we don't want to reuse.
        final_content = generated_text
        final_reasoning = ""
        final_tool_calls: List[backend_pb2.ToolCallDelta] = []

        if not streaming:
            final_reasoning_parser = None
            if HAS_REASONING_PARSERS and self.reasoning_parser_name:
                try:
                    final_reasoning_parser = ReasoningParser(
                        model_type=self.reasoning_parser_name,
                        stream_reasoning=False,
                    )
                except Exception:
                    final_reasoning_parser = None
            if final_reasoning_parser is not None:
                try:
                    r, n = final_reasoning_parser.parse_non_stream(generated_text)
                    final_reasoning = r or ""
                    final_content = n if n is not None else generated_text
                except Exception as e:
                    print(f"reasoning_parser.parse_non_stream: {e!r}", file=sys.stderr)

            if HAS_TOOL_PARSERS and self.tool_parser_name and request.Tools:
                try:
                    tools_raw = json.loads(request.Tools)
                    tools = [SglTool.model_validate(t) for t in tools_raw] if SglTool else tools_raw
                    fresh_tool_parser = FunctionCallParser(
                        tools=tools,
                        tool_call_parser=self.tool_parser_name,
                    )
                    normal, calls = fresh_tool_parser.parse_non_stream(final_content)
                    if calls:
                        final_content = normal
                        for tc in calls:
                            idx = int(getattr(tc, "tool_index", 0) or 0)
                            final_tool_calls.append(backend_pb2.ToolCallDelta(
                                index=idx,
                                id=f"call_{uuid.uuid4().hex[:24]}",
                                name=getattr(tc, "name", "") or "",
                                arguments=getattr(tc, "parameters", "") or "",
                            ))
                except Exception as e:
                    print(f"tool_parser.parse_non_stream: {e!r}", file=sys.stderr)

        chat_delta = backend_pb2.ChatDelta(
            content=final_content if not streaming else "",
            reasoning_content=final_reasoning,
            tool_calls=final_tool_calls,
        )

        if streaming:
            yield backend_pb2.Reply(
                message=b"",
                prompt_tokens=prompt_tokens,
                tokens=completion_tokens,
                chat_deltas=[chat_delta],
            )
            return

        yield backend_pb2.Reply(
            message=bytes(final_content or "", "utf-8"),
            prompt_tokens=prompt_tokens,
            tokens=completion_tokens,
            chat_deltas=[chat_delta],
        )


async def serve(address):
    server = grpc.aio.server(
        migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
        ],
        interceptors=get_auth_interceptors(aio=True),
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)

    loop = asyncio.get_event_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5)))

    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    await server.wait_for_termination()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the sglang gRPC server.")
    parser.add_argument(
        "--addr",
        default="localhost:50051",
        help="The address to bind the server to.",
    )
    args = parser.parse_args()

    asyncio.run(serve(args.addr))
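
# Illustrative model YAML for this backend. Field values are examples, not
# defaults; parser names must match what the installed sglang version
# registers, and engine_args keys must match ServerArgs field names or
# LoadModel rejects them with a close-match suggestion.
#
#   name: my-sglang-model
#   backend: sglang
#   parameters:
#     model: Qwen/Qwen2.5-7B-Instruct
#   options:
#     - tool_parser:qwen25
#     - reasoning_parser:deepseek-r1
#   engine_args:
#     mem_fraction_static: 0.85
#     attention_backend: flashinfer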