feat: refactor shared helpers and enhance MLX backend functionality (#9335)

* refactor(backends): extract python_utils + add mlx_utils shared helpers

Move parse_options() and messages_to_dicts() out of vllm_utils.py into a
new framework-agnostic python_utils.py, and re-export them from vllm_utils
so existing vllm / vllm-omni imports keep working.

Add mlx_utils.py with split_reasoning() and parse_tool_calls() — ported
from mlx_vlm/server.py's process_tool_calls. These work with any
mlx-lm / mlx-vlm tool module (anything exposing tool_call_start,
tool_call_end, parse_tool_call). Used by the mlx and mlx-vlm backends in
later commits to emit structured ChatDelta.tool_calls without
reimplementing per-model parsing.

Shared smoke tests confirm:
- parse_options round-trips bool/int/float/string
- vllm_utils re-exports are identity-equal to python_utils originals
- mlx_utils parse_tool_calls handles <tool_call>...</tool_call> with a
  shim module and produces a correctly-indexed list with JSON arguments
- mlx_utils split_reasoning extracts <think> blocks and leaves clean
  content
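The parse_options round-trip the smoke tests cover can be sketched standalone (logic inlined from the new python_utils.py so it runs without the backend tree):

```python
def parse_options(options_list):
    """Parse 'key:value' option strings into a dict with bool/int/float inference."""
    opts = {}
    for opt in options_list:
        if ":" not in opt:
            continue
        key, value = opt.split(":", 1)
        key, value = key.strip(), value.strip()
        if value.lower() in ("true", "false"):
            opts[key] = value.lower() == "true"
        else:
            try:
                opts[key] = int(value)
            except ValueError:
                try:
                    opts[key] = float(value)
                except ValueError:
                    opts[key] = value
    return opts

# bool / int / float / string all round-trip to native types
print(parse_options(["trust_remote_code:true", "max_len:2048",
                     "temp:0.7", "tool_parser:hermes"]))
# → {'trust_remote_code': True, 'max_len': 2048, 'temp': 0.7, 'tool_parser': 'hermes'}
```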

* feat(mlx): wire native tool parsers + ChatDelta + token usage + logprobs

Bring the MLX backend up to the same structured-output contract as vLLM
and llama.cpp: emit Reply.chat_deltas so the OpenAI HTTP layer sees
tool_calls and reasoning_content, not just raw text.

Key insight: mlx_lm.load() returns a TokenizerWrapper that already auto-
detects the right tool parser from the model's chat template
(_infer_tool_parser in mlx_lm/tokenizer_utils.py). The wrapper exposes
has_tool_calling, has_thinking, tool_parser, tool_call_start,
tool_call_end, think_start, think_end — no user configuration needed,
unlike vLLM.
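The capability probe the backend runs at load time boils down to defensive getattr reads, so older mlx-lm releases that lack an attribute don't raise. A minimal sketch with a stand-in object in place of a real TokenizerWrapper:

```python
from types import SimpleNamespace

# Stand-in for an mlx_lm TokenizerWrapper; attribute names follow the commit.
wrapper = SimpleNamespace(has_tool_calling=True, tool_call_start="<tool_call>",
                          tool_call_end="</tool_call>", has_thinking=False)

# Probe with getattr + defaults so missing attributes degrade gracefully.
has_tools = bool(getattr(wrapper, "has_tool_calling", False))
has_thinking = bool(getattr(wrapper, "has_thinking", False))
tcs = getattr(wrapper, "tool_call_start", None)
tce = getattr(wrapper, "tool_call_end", None)
print(f"has_tool_calling={has_tools} has_thinking={has_thinking} "
      f"tool_call_start={tcs!r} tool_call_end={tce!r}")
```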

Changes in backend/python/mlx/backend.py:

- Imports: replace inline parse_options / messages_to_dicts with the
  shared helpers from python_utils. Pull split_reasoning / parse_tool_calls
  from the new mlx_utils shared module.
- LoadModel: log the auto-detected has_tool_calling / has_thinking /
  tool_parser_type for observability. Drop the local is_float / is_int
  duplicates.
- _prepare_prompt: run request.Messages through messages_to_dicts so
  tool_call_id / tool_calls / reasoning_content survive the conversion,
  and pass tools=json.loads(request.Tools) + enable_thinking=True (when
  request.Metadata says so) to apply_chat_template. Falls back on
  TypeError for tokenizers whose template doesn't accept those kwargs.
- _build_generation_params: return an additional (logits_params,
  stop_words) pair. Maps RepetitionPenalty / PresencePenalty /
  FrequencyPenalty to mlx_lm.sample_utils.make_logits_processors and
  threads StopPrompts through to post-decode truncation.
- New _tool_module_from_tokenizer / _finalize_output / _truncate_at_stop
  helpers. _finalize_output runs split_reasoning when has_thinking is
  true and parse_tool_calls (using a SimpleNamespace shim around the
  wrapper's tool_parser callable) when has_tool_calling is true, then
  extracts prompt_tokens, generation_tokens and (best-effort) logprobs
  from the last GenerationResponse chunk.
- Predict: use make_logits_processors, accumulate text + last_response,
  finalize into a structured Reply carrying chat_deltas,
  prompt_tokens, tokens, logprobs. Early-stops on user stop sequences.
- PredictStream: per-chunk Reply still carries raw message bytes for
  back-compat but now also emits chat_deltas=[ChatDelta(content=delta)].
  On loop exit, emit a terminal Reply with structured
  reasoning_content / tool_calls / token counts / logprobs — so the Go
  side sees tool calls without needing the regex fallback.
- TokenizeString RPC: uses the TokenizerWrapper's encode(); returns
  length + tokens or FAILED_PRECONDITION if the model isn't loaded.
- Free RPC: drops model / tokenizer / lru_cache, runs gc.collect(),
  calls mx.metal.clear_cache() when available, and best-effort clears
  torch.cuda as a belt-and-suspenders.

* feat(mlx-vlm): mirror MLX parity (tool parsers + ChatDelta + samplers)

Same treatment as the MLX backend: emit structured Reply.chat_deltas,
tool_calls, reasoning_content, token counts and logprobs, and extend
sampling parameter coverage beyond the temp/top_p pair the backend
used to handle.

- Imports: drop the inline is_float/is_int helpers, pull parse_options /
  messages_to_dicts from python_utils and split_reasoning /
  parse_tool_calls from mlx_utils. Also import make_sampler and
  make_logits_processors from mlx_lm.sample_utils — mlx-vlm re-uses them.
- LoadModel: use parse_options; call mlx_vlm.tool_parsers._infer_tool_parser
  / load_tool_module to auto-detect a tool module from the processor's
  chat_template. Stash think_start / think_end / has_thinking so later
  finalisation can split reasoning blocks without duck-typing on each
  call. Logs the detected parser type.
- _prepare_prompt: convert proto Messages via messages_to_dicts (so
  tool_call_id / tool_calls survive), pass tools=json.loads(request.Tools)
  and enable_thinking=True to apply_chat_template when present, fall
  back on TypeError for older mlx-vlm versions. Also handle the
  prompt-only + media and empty-prompt + media paths consistently.
- _build_generation_params: return (max_tokens, sampler_params,
  logits_params, stop_words). Maps repetition_penalty / presence_penalty /
  frequency_penalty and passes them through make_logits_processors.
- _finalize_output / _truncate_at_stop: common helper used by Predict
  and PredictStream to split reasoning, run parse_tool_calls against the
  auto-detected tool module, build ToolCallDelta list, and extract token
  counts + logprobs from the last GenerationResult.
- Predict / PredictStream: switch from mlx_vlm.generate to mlx_vlm.stream_generate
  in both paths, accumulate text + last_response, pass sampler and
  logits_processors through, emit content-only ChatDelta per streaming
  chunk followed by a terminal Reply carrying reasoning_content,
  tool_calls, prompt_tokens, tokens and logprobs. Non-streaming Predict
  returns the same structured Reply shape.
- New helper _collect_media extracted from the duplicated base64 image /
  audio decode loop.
- New TokenizeString RPC using the processor's tokenizer.encode and
  Free RPC that drops model/processor/config, runs gc + Metal cache
  clear + best-effort torch.cuda cache clear.
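The penalty mapping in _build_generation_params can be sketched as collecting only the penalties the request actually sets (a hypothetical helper — the real code reads the proto's PredictOptions and feeds the dict to mlx_lm.sample_utils.make_logits_processors):

```python
from types import SimpleNamespace

def build_logits_params(request):
    """Collect set penalties into kwargs; names are illustrative."""
    params = {}
    if getattr(request, "RepetitionPenalty", 0.0):
        params["repetition_penalty"] = request.RepetitionPenalty
    if getattr(request, "PresencePenalty", 0.0):
        params["presence_penalty"] = request.PresencePenalty
    if getattr(request, "FrequencyPenalty", 0.0):
        params["frequency_penalty"] = request.FrequencyPenalty
    return params

# unset (zero) penalties are omitted so the library sees only real overrides
req = SimpleNamespace(RepetitionPenalty=1.1, PresencePenalty=0.0, FrequencyPenalty=0.5)
print(build_logits_params(req))
# → {'repetition_penalty': 1.1, 'frequency_penalty': 0.5}
```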

* feat(importer/mlx): auto-set tool_parser/reasoning_parser on import

Mirror what core/gallery/importers/vllm.go does: after applying the
shared inference defaults, look up the model URI in parser_defaults.json
and append matching tool_parser:/reasoning_parser: entries to Options.

The MLX backends auto-detect tool parsers from the chat template at
runtime so they don't actually consume these options — but surfacing
them in the generated YAML:
  - keeps the import experience consistent with vllm
  - gives users a single visible place to override
  - documents the intended parser for a given model family

* test(mlx): add helper unit tests + TokenizeString/Free + e2e make targets

- backend/python/mlx/test.py: add TestSharedHelpers with server-less
  unit tests for parse_options, messages_to_dicts, split_reasoning and
  parse_tool_calls (using a SimpleNamespace shim to fake a tool module
  without requiring a model). Plus test_tokenize_string and test_free
  RPC tests that load a tiny MLX-quantized Llama and exercise the new
  RPCs end-to-end.

- backend/python/mlx-vlm/test.py: same helper unit tests + cleanup of
  the duplicated import block at the top of the file.

- Makefile: register BACKEND_MLX and BACKEND_MLX_VLM (they were missing
  from the docker-build-target eval list — only mlx-distributed had a
  generated target before). Add test-extra-backend-mlx and
  test-extra-backend-mlx-vlm convenience targets that build the
  respective image and run tests/e2e-backends with the tools capability
  against mlx-community/Qwen2.5-0.5B-Instruct-4bit. The MLX backend
  auto-detects the tool parser from the chat template so no
  BACKEND_TEST_OPTIONS is needed (unlike vllm).

* fix(libbackend): don't pass --copies to venv unless PORTABLE_PYTHON=true

backend/python/common/libbackend.sh:ensureVenv() always invoked
'python -m venv --copies', but macOS system python (and some other
builds) refuses with:

    Error: This build of python cannot create venvs without using symlinks

--copies only matters when _makeVenvPortable later relocates the venv,
which only happens when PORTABLE_PYTHON=true. Make --copies conditional
on that flag and fall back to default (symlinked) venv otherwise.

Caught while bringing up the mlx backend on Apple Silicon — the same
build path is used by every Python backend with USE_PIP=true.

* fix(mlx): support mlx-lm 0.29.x tool calling + drop deprecated clear_cache

The released mlx-lm 0.29.x ships a much simpler tool-calling API than
HEAD: TokenizerWrapper detects the <tool_call>...</tool_call> markers
from the tokenizer vocab and exposes has_tool_calling /
tool_call_start / tool_call_end, but does NOT expose a tool_parser
callable on the wrapper and does NOT ship a mlx_lm.tool_parsers
subpackage at all (those only exist on main).

Caught while running the smoke test on Apple Silicon with the
released mlx-lm 0.29.1: tokenizer.tool_parser raised AttributeError
(falling through to the underlying HF tokenizer), so
_tool_module_from_tokenizer always returned None and tool calls slipped
through as raw <tool_call>...</tool_call> text in Reply.message instead
of being parsed into ChatDelta.tool_calls.

Fix: when has_tool_calling is True but tokenizer.tool_parser is missing,
default the parse_tool_call callable to json.loads(body.strip()) — that's
exactly what mlx_lm.tool_parsers.json_tools.parse_tool_call does on HEAD
and covers the only format 0.29 detects (<tool_call>JSON</tool_call>).
Future mlx-lm releases that ship more parsers will be picked up
automatically via the tokenizer.tool_parser attribute when present.
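A sketch of that fallback shim — when the wrapper has no tool_parser, parse_tool_call defaults to plain JSON parsing, the one format 0.29 detects:

```python
import json
from types import SimpleNamespace

# Default parser for mlx-lm 0.29.x: the <tool_call> body is raw JSON.
def parse_tool_call(body, tools):
    return json.loads(body.strip())

shim = SimpleNamespace(tool_call_start="<tool_call>",
                       tool_call_end="</tool_call>",
                       parse_tool_call=parse_tool_call)

body = ' {"name": "get_weather", "arguments": {"city": "Paris"}} '
print(shim.parse_tool_call(body, tools=None))
# → {'name': 'get_weather', 'arguments': {'city': 'Paris'}}
```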

Also tighten the LoadModel logging — the old log line read
init_kwargs.get('tool_parser_type') which doesn't exist on 0.29 and
showed None even when has_tool_calling was True. Log the actual
tool_call_start / tool_call_end markers instead.

While here, switch Free()'s Metal cache clear from the deprecated
mx.metal.clear_cache to mx.clear_cache (mlx >= 0.30), with a
fallback for older releases. Mirrored to the mlx-vlm backend.
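The version-tolerant cache clear looks like this (a best-effort sketch matching the Free() change — it no-ops when mlx isn't installed):

```python
def clear_metal_cache():
    """Clear the Metal cache across mlx versions: mx.clear_cache on
    mlx >= 0.30, mx.metal.clear_cache on older releases."""
    try:
        import mlx.core as mx  # type: ignore
        if hasattr(mx, "clear_cache"):
            mx.clear_cache()
        elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
            mx.metal.clear_cache()
    except Exception:
        pass  # no mlx installed, or clearing unsupported — safe to skip
```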

* feat(mlx-distributed): mirror MLX parity (tool calls + ChatDelta + sampler)

Same treatment as the mlx and mlx-vlm backends: emit Reply.chat_deltas
with structured tool_calls / reasoning_content / token counts /
logprobs, expand sampling parameter coverage beyond temp+top_p, and
add the missing TokenizeString and Free RPCs.

Notes specific to mlx-distributed:

- Rank 0 is the only rank that owns a sampler — workers participate in
  the pipeline-parallel forward pass via mx.distributed and don't
  re-implement sampling. So the new logits_params (repetition_penalty,
  presence_penalty, frequency_penalty) and stop_words apply on rank 0
  only; we don't need to extend coordinator.broadcast_generation_params,
  which still ships only max_tokens / temperature / top_p to workers
  (everything else is a rank-0 concern).
- Free() now broadcasts CMD_SHUTDOWN to workers when a coordinator is
  active, so they release the model on their end too. The constant is
  already defined and handled by the existing worker loop in
  backend.py:633 (CMD_SHUTDOWN = -1).
- Drop the locally-defined is_float / is_int / parse_options trio in
  favor of python_utils.parse_options, re-exported under the module
  name for back-compat with anything that imported it directly.
- _prepare_prompt: route through messages_to_dicts so tool_call_id /
  tool_calls / reasoning_content survive, pass tools=json.loads(
  request.Tools) and enable_thinking=True to apply_chat_template, fall
  back on TypeError for templates that don't accept those kwargs.
- New _tool_module_from_tokenizer (with the json.loads fallback for
  mlx-lm 0.29.x), _finalize_output, _truncate_at_stop helpers — same
  contract as the mlx backend.
- LoadModel logs the auto-detected has_tool_calling / has_thinking /
  tool_call_start / tool_call_end so users can see what the wrapper
  picked up for the loaded model.
- backend/python/mlx-distributed/test.py: add the same TestSharedHelpers
  unit tests (parse_options, messages_to_dicts, split_reasoning,
  parse_tool_calls) that exist for mlx and mlx-vlm.
This commit is contained in:
Ettore Di Giacinto, 2026-04-13 18:44:03 +02:00, committed by GitHub
parent daa0272f2e
commit 016da02845
12 changed files with 1380 additions and 398 deletions


@@ -519,6 +519,22 @@ test-extra-backend-vllm: docker-build-vllm
	BACKEND_TEST_OPTIONS=tool_parser:hermes \
	$(MAKE) test-extra-backend
## mlx is Apple-Silicon-first — the MLX backend auto-detects the right tool
## parser from the chat template, so no tool_parser: option is needed (it
## would be ignored at runtime). Run this on macOS / arm64 with Metal; the
## Linux/CPU mlx variant is untested in CI.
test-extra-backend-mlx: docker-build-mlx
	BACKEND_IMAGE=local-ai-backend:mlx \
	BACKEND_TEST_MODEL_NAME=mlx-community/Qwen2.5-0.5B-Instruct-4bit \
	BACKEND_TEST_CAPS=health,load,predict,stream,tools \
	$(MAKE) test-extra-backend

test-extra-backend-mlx-vlm: docker-build-mlx-vlm
	BACKEND_IMAGE=local-ai-backend:mlx-vlm \
	BACKEND_TEST_MODEL_NAME=mlx-community/Qwen2.5-0.5B-Instruct-4bit \
	BACKEND_TEST_CAPS=health,load,predict,stream,tools \
	$(MAKE) test-extra-backend
DOCKER_IMAGE?=local-ai
IMAGE_TYPE?=core
BASE_IMAGE?=ubuntu:24.04
@@ -652,6 +668,8 @@ BACKEND_NEMO = nemo|python|.|false|true
BACKEND_VOXCPM = voxcpm|python|.|false|true
BACKEND_WHISPERX = whisperx|python|.|false|true
BACKEND_ACE_STEP = ace-step|python|.|false|true
BACKEND_MLX = mlx|python|.|false|true
BACKEND_MLX_VLM = mlx-vlm|python|.|false|true
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
BACKEND_TRL = trl|python|.|false|true
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
@@ -720,6 +738,8 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
$(eval $(call generate-docker-build-target,$(BACKEND_QWEN3_TTS_CPP)))
$(eval $(call generate-docker-build-target,$(BACKEND_MLX)))
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_VLM)))
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))


@@ -344,7 +344,16 @@ function ensureVenv() {
     if [ ! -d "${EDIR}/venv" ]; then
         if [ "x${USE_PIP}" == "xtrue" ]; then
-            "${interpreter}" -m venv --copies "${EDIR}/venv"
+            # --copies is only needed when we will later relocate the venv via
+            # _makeVenvPortable (PORTABLE_PYTHON=true). Some Python builds —
+            # notably macOS system Python — refuse to create a venv with
+            # --copies because the build doesn't support it. Fall back to
+            # symlinks in that case.
+            local venv_args=""
+            if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
+                venv_args="--copies"
+            fi
+            "${interpreter}" -m venv ${venv_args} "${EDIR}/venv"
             source "${EDIR}/venv/bin/activate"
             "${interpreter}" -m pip install --upgrade pip
         else


@@ -0,0 +1,100 @@
"""Shared utilities for the mlx and mlx-vlm gRPC backends.

These helpers wrap mlx-lm's and mlx-vlm's native tool-parser modules, which
auto-detect the right parser from the model's chat template. Each tool
module exposes ``tool_call_start``, ``tool_call_end`` and
``parse_tool_call(text, tools) -> dict | list[dict]``.

The split-reasoning helper is generic enough to work with any think-start /
think-end delimiter pair.
"""

import json
import re
import sys
import uuid


def split_reasoning(text, think_start, think_end):
    """Split ``<think>...</think>`` blocks out of ``text``.

    Returns ``(reasoning_content, remaining_text)``. When ``think_start`` is
    empty or not found, returns ``("", text)`` unchanged.
    """
    if not think_start or not text or think_start not in text:
        return "", text
    pattern = re.compile(
        re.escape(think_start) + r"(.*?)" + re.escape(think_end or ""),
        re.DOTALL,
    )
    reasoning_parts = pattern.findall(text)
    if not reasoning_parts:
        return "", text
    remaining = pattern.sub("", text).strip()
    return "\n".join(p.strip() for p in reasoning_parts), remaining


def parse_tool_calls(text, tool_module, tools):
    """Extract tool calls from ``text`` using a mlx-lm tool module.

    Ports the ``process_tool_calls`` logic from ``mlx_vlm/server.py``
    (v0.10 onwards). ``tool_module`` must expose ``tool_call_start``,
    ``tool_call_end`` and ``parse_tool_call``.

    Returns ``(calls, remaining_text)`` where ``calls`` is a list of dicts:

        [{"index": int, "id": str, "name": str, "arguments": str (JSON)}]

    and ``remaining_text`` is the free-form text with the tool call blocks
    removed. ``([], text)`` is returned unchanged if ``tool_module`` is
    ``None`` or the start delimiter isn't present.
    """
    if tool_module is None or not text:
        return [], text
    start = getattr(tool_module, "tool_call_start", None)
    end = getattr(tool_module, "tool_call_end", None)
    parse_fn = getattr(tool_module, "parse_tool_call", None)
    if not start or parse_fn is None or start not in text:
        return [], text
    if end == "" or end is None:
        pattern = re.compile(
            re.escape(start) + r".*?(?:\n|$)",
            re.DOTALL,
        )
    else:
        pattern = re.compile(
            re.escape(start) + r".*?" + re.escape(end),
            re.DOTALL,
        )
    matches = pattern.findall(text)
    if not matches:
        return [], text
    remaining = pattern.sub(" ", text).strip()
    calls = []
    for match in matches:
        call_body = match.strip().removeprefix(start)
        if end:
            call_body = call_body.removesuffix(end)
        call_body = call_body.strip()
        try:
            parsed = parse_fn(call_body, tools)
        except Exception as e:
            print(
                f"[mlx_utils] Invalid tool call: {call_body!r} ({e})",
                file=sys.stderr,
            )
            continue
        if not isinstance(parsed, list):
            parsed = [parsed]
        for tc in parsed:
            calls.append(
                {
                    "index": len(calls),
                    "id": str(uuid.uuid4()),
                    "name": (tc.get("name") or "").strip(),
                    "arguments": json.dumps(tc.get("arguments", {}), ensure_ascii=False),
                }
            )
    return calls, remaining
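A quick check of split_reasoning's contract (function copied verbatim from above so the snippet runs standalone):

```python
import re

def split_reasoning(text, think_start, think_end):
    if not think_start or not text or think_start not in text:
        return "", text
    pattern = re.compile(re.escape(think_start) + r"(.*?)" + re.escape(think_end or ""),
                         re.DOTALL)
    reasoning_parts = pattern.findall(text)
    if not reasoning_parts:
        return "", text
    remaining = pattern.sub("", text).strip()
    return "\n".join(p.strip() for p in reasoning_parts), remaining

# the <think> block becomes reasoning_content; clean content remains
print(split_reasoning("<think>plan the answer</think>The answer is 4.",
                      "<think>", "</think>"))
# → ('plan the answer', 'The answer is 4.')
```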


@@ -0,0 +1,65 @@
"""Generic utilities shared across Python gRPC backends.

These helpers don't depend on any specific inference framework and can be
imported by any backend that needs to parse LocalAI gRPC options or build a
chat-template-compatible message list from proto Message objects.
"""

import json


def parse_options(options_list):
    """Parse Options[] list of ``key:value`` strings into a dict.

    Supports type inference for common cases (bool, int, float). Unknown or
    mixed-case values are returned as strings.

    Used by LoadModel to extract backend-specific options passed via
    ``ModelOptions.Options`` in ``backend.proto``.
    """
    opts = {}
    for opt in options_list:
        if ":" not in opt:
            continue
        key, value = opt.split(":", 1)
        key = key.strip()
        value = value.strip()
        # Try type conversion
        if value.lower() in ("true", "false"):
            opts[key] = value.lower() == "true"
        else:
            try:
                opts[key] = int(value)
            except ValueError:
                try:
                    opts[key] = float(value)
                except ValueError:
                    opts[key] = value
    return opts


def messages_to_dicts(proto_messages):
    """Convert proto ``Message`` objects to dicts suitable for ``apply_chat_template``.

    Handles: ``role``, ``content``, ``name``, ``tool_call_id``,
    ``reasoning_content``, ``tool_calls`` (JSON string → Python list).

    HuggingFace chat templates (and their MLX/vLLM wrappers) expect a list of
    plain dicts — proto Message objects don't work directly with Jinja, so
    this conversion is needed before every ``apply_chat_template`` call.
    """
    result = []
    for msg in proto_messages:
        d = {"role": msg.role, "content": msg.content or ""}
        if msg.name:
            d["name"] = msg.name
        if msg.tool_call_id:
            d["tool_call_id"] = msg.tool_call_id
        if msg.reasoning_content:
            d["reasoning_content"] = msg.reasoning_content
        if msg.tool_calls:
            try:
                d["tool_calls"] = json.loads(msg.tool_calls)
            except json.JSONDecodeError:
                pass
        result.append(d)
    return result
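And a usage check for messages_to_dicts, faking a proto Message with SimpleNamespace (logic copied from above so it runs standalone):

```python
import json
from types import SimpleNamespace

def messages_to_dicts(proto_messages):
    # same logic as above, copied for a standalone run
    result = []
    for msg in proto_messages:
        d = {"role": msg.role, "content": msg.content or ""}
        if msg.name:
            d["name"] = msg.name
        if msg.tool_call_id:
            d["tool_call_id"] = msg.tool_call_id
        if msg.reasoning_content:
            d["reasoning_content"] = msg.reasoning_content
        if msg.tool_calls:
            try:
                d["tool_calls"] = json.loads(msg.tool_calls)
            except json.JSONDecodeError:
                pass
        result.append(d)
    return result

# empty optional fields are omitted; tool_call_id survives the conversion
msg = SimpleNamespace(role="tool", content="22C", name="", tool_call_id="call_1",
                      reasoning_content="", tool_calls="")
print(messages_to_dicts([msg]))
# → [{'role': 'tool', 'content': '22C', 'tool_call_id': 'call_1'}]
```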


@@ -1,63 +1,22 @@
"""Shared utilities for vLLM-based backends."""
import json
"""vLLM-specific helpers for the vllm and vllm-omni gRPC backends.
Generic helpers (``parse_options``, ``messages_to_dicts``) live in
``python_utils`` and are re-exported here for backwards compatibility with
existing imports in both backends.
"""
import sys
from python_utils import messages_to_dicts, parse_options
def parse_options(options_list):
"""Parse Options[] list of 'key:value' strings into a dict.
Supports type inference for common cases (bool, int, float).
Used by LoadModel to extract backend-specific options.
"""
opts = {}
for opt in options_list:
if ":" not in opt:
continue
key, value = opt.split(":", 1)
key = key.strip()
value = value.strip()
# Try type conversion
if value.lower() in ("true", "false"):
opts[key] = value.lower() == "true"
else:
try:
opts[key] = int(value)
except ValueError:
try:
opts[key] = float(value)
except ValueError:
opts[key] = value
return opts
def messages_to_dicts(proto_messages):
"""Convert proto Message objects to list of dicts for apply_chat_template().
Handles: role, content, name, tool_call_id, reasoning_content, tool_calls (JSON string -> list).
"""
result = []
for msg in proto_messages:
d = {"role": msg.role, "content": msg.content or ""}
if msg.name:
d["name"] = msg.name
if msg.tool_call_id:
d["tool_call_id"] = msg.tool_call_id
if msg.reasoning_content:
d["reasoning_content"] = msg.reasoning_content
if msg.tool_calls:
try:
d["tool_calls"] = json.loads(msg.tool_calls)
except json.JSONDecodeError:
pass
result.append(d)
return result
__all__ = ["parse_options", "messages_to_dicts", "setup_parsers"]
def setup_parsers(opts):
"""Return (tool_parser_cls, reasoning_parser_cls) tuple from opts dict.
"""Return ``(tool_parser_cls, reasoning_parser_cls)`` from an opts dict.
Uses vLLM's native ToolParserManager and ReasoningParserManager.
Returns (None, None) if vLLM is not installed or parsers not available.
Uses vLLM's native ``ToolParserManager`` / ``ReasoningParserManager``.
Returns ``(None, None)`` if vLLM isn't installed or the requested
parser name can't be resolved.
"""
tool_parser_cls = None
reasoning_parser_cls = None


@@ -15,17 +15,21 @@ Two startup modes:
import asyncio
from concurrent import futures
import argparse
import gc
import json
import os
import signal
import sys
import tempfile
import types
from typing import List
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from python_utils import messages_to_dicts, parse_options as _shared_parse_options
from mlx_utils import parse_tool_calls, split_reasoning
import backend_pb2
@@ -62,37 +66,10 @@ def mlx_distributed_init(rank, hostfile, backend="ring", coordinator=None):
raise ValueError(f"Unknown backend: {backend}")
-def is_float(s):
-    try:
-        float(s)
-        return True
-    except ValueError:
-        return False
-
-def is_int(s):
-    try:
-        int(s)
-        return True
-    except ValueError:
-        return False
-
-def parse_options(options):
-    """Parse key:value option strings into a dict."""
-    result = {}
-    for opt in options:
-        if ":" not in opt:
-            continue
-        key, value = opt.split(":", 1)
-        if is_float(value):
-            value = float(value)
-        elif is_int(value):
-            value = int(value)
-        elif value.lower() in ["true", "false"]:
-            value = value.lower() == "true"
-        result[key] = value
-    return result
+# Re-export the shared helper under the local name for back-compat with
+# any callers (and the existing distributed worker tests) that imported
+# parse_options directly from this module.
+parse_options = _shared_parse_options
class BackendServicer(backend_pb2_grpc.BackendServicer):
@@ -188,6 +165,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
)
print("[Rank 0] Model loaded (single-node with prompt cache)", file=sys.stderr)
# Log auto-detected TokenizerWrapper capabilities. Same shape
# as the mlx backend: has_tool_calling / has_thinking from
# mlx_lm.tokenizer_utils + the start/end markers it sniffed
# from the chat template / vocab.
has_tools = bool(getattr(self.tokenizer, "has_tool_calling", False))
has_thinking = bool(getattr(self.tokenizer, "has_thinking", False))
tcs = getattr(self.tokenizer, "tool_call_start", None)
tce = getattr(self.tokenizer, "tool_call_end", None)
print(
f"[Rank 0] Tokenizer capabilities: has_tool_calling={has_tools} "
f"has_thinking={has_thinking} tool_call_start={tcs!r} tool_call_end={tce!r}",
file=sys.stderr,
)
except Exception as err:
print(f"[Rank 0] Error loading model: {err}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Error loading model: {err}")
@@ -201,7 +192,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
try:
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_sampler
from mlx_lm.sample_utils import make_logits_processors, make_sampler
prompt_text = self._prepare_prompt(request)
tokens = self._get_tokens_from_prompt(prompt_text)
@@ -211,7 +202,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
self.coordinator.broadcast_tokens(tokens)
max_tokens, sampler_params = self._build_generation_params(request)
max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request)
if self.coordinator:
gen_params = self.coordinator.broadcast_generation_params(
@@ -222,6 +213,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
max_tokens = gen_params["max_tokens"]
sampler = make_sampler(**sampler_params)
logits_processors = make_logits_processors(**logits_params) if logits_params else None
# Use prompt cache in single-node mode
gen_kwargs = {}
@@ -238,22 +230,44 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
tokens = remaining_tokens if remaining_tokens else cache_key
generated = []
last_response = None
for response in stream_generate(
self.model,
self.tokenizer,
prompt=tokens,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
**gen_kwargs,
):
generated.append(response.text)
last_response = response
if cache_key is not None:
cache_key.append(response.token)
if stop_words and any(s in "".join(generated) for s in stop_words):
break
if self.lru_cache is not None and cache_key is not None:
self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8'))
full_text = self._truncate_at_stop("".join(generated), stop_words)
content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
self._finalize_output(request, full_text, last_response)
)
return backend_pb2.Reply(
message=bytes(content, encoding='utf-8'),
prompt_tokens=prompt_tokens,
tokens=completion_tokens,
logprobs=logprobs_bytes,
chat_deltas=[
backend_pb2.ChatDelta(
content=content,
reasoning_content=reasoning_content,
tool_calls=tool_calls_proto,
)
],
)
except Exception as e:
print(f"[Rank 0] Error in Predict: {e}", file=sys.stderr)
@@ -268,7 +282,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
try:
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_sampler
from mlx_lm.sample_utils import make_logits_processors, make_sampler
prompt_text = self._prepare_prompt(request)
tokens = self._get_tokens_from_prompt(prompt_text)
@@ -278,7 +292,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
self.coordinator.broadcast_tokens(tokens)
max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)
max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(
request, default_max_tokens=512
)
if self.coordinator:
gen_params = self.coordinator.broadcast_generation_params(
@@ -289,6 +305,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
max_tokens = gen_params["max_tokens"]
sampler = make_sampler(**sampler_params)
logits_processors = make_logits_processors(**logits_params) if logits_params else None
# Use prompt cache in single-node mode
gen_kwargs = {}
@@ -304,17 +321,45 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
gen_kwargs['prompt_cache'] = prompt_cache
tokens = remaining_tokens if remaining_tokens else cache_key
accumulated = []
last_response = None
for response in stream_generate(
self.model,
self.tokenizer,
prompt=tokens,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
**gen_kwargs,
):
if cache_key is not None:
cache_key.append(response.token)
yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
accumulated.append(response.text)
last_response = response
yield backend_pb2.Reply(
message=bytes(response.text, encoding='utf-8'),
chat_deltas=[backend_pb2.ChatDelta(content=response.text)],
)
if stop_words and any(s in "".join(accumulated) for s in stop_words):
break
full_text = self._truncate_at_stop("".join(accumulated), stop_words)
content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
self._finalize_output(request, full_text, last_response)
)
yield backend_pb2.Reply(
message=b"",
prompt_tokens=prompt_tokens,
tokens=completion_tokens,
logprobs=logprobs_bytes,
chat_deltas=[
backend_pb2.ChatDelta(
content="",
reasoning_content=reasoning_content,
tool_calls=tool_calls_proto,
)
],
)
except Exception as e:
print(f"[Rank 0] Error in PredictStream: {e}", file=sys.stderr)
@@ -335,12 +380,74 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
context.set_details("Embeddings are not supported in the MLX distributed backend.")
return backend_pb2.EmbeddingResult()
    async def TokenizeString(self, request, context):
        if not hasattr(self, "tokenizer") or self.tokenizer is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("tokenizer not loaded")
            return backend_pb2.TokenizationResponse()
        try:
            tokens = self.tokenizer.encode(request.Prompt)
            if hasattr(tokens, "tolist"):
                tokens = tokens.tolist()
            tokens = list(tokens)
            return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
        except Exception as e:
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(e))
            return backend_pb2.TokenizationResponse()
    async def Free(self, request, context):
        try:
            # If we're rank 0 of a distributed run, tell workers to shut
            # down their per-request loops first so they release the model.
            if self.coordinator is not None:
                try:
                    from coordinator import CMD_SHUTDOWN
                    self.coordinator.broadcast_command(CMD_SHUTDOWN)
                except Exception as e:
                    print(f"[Rank 0] failed to broadcast shutdown: {e}", file=sys.stderr)
            if hasattr(self, "model"):
                del self.model
            if hasattr(self, "tokenizer"):
                del self.tokenizer
            if self.lru_cache is not None:
                try:
                    self.lru_cache.clear()
                except Exception:
                    pass
            self.lru_cache = None
            self.coordinator = None
            self.group = None
            gc.collect()
            try:
                import mlx.core as mx  # type: ignore
                if hasattr(mx, "clear_cache"):
                    mx.clear_cache()
                elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
                    mx.metal.clear_cache()
            except Exception:
                pass
            return backend_pb2.Result(success=True, message="MLX distributed model freed")
        except Exception as e:
            return backend_pb2.Result(success=False, message=str(e))
     def _prepare_prompt(self, request):
         if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
-            messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages]
-            return self.tokenizer.apply_chat_template(
-                messages, tokenize=False, add_generation_prompt=True
-            )
+            messages = messages_to_dicts(request.Messages)
+            kwargs = {"tokenize": False, "add_generation_prompt": True}
+            if request.Tools:
+                try:
+                    kwargs["tools"] = json.loads(request.Tools)
+                except json.JSONDecodeError:
+                    pass
+            if request.Metadata.get("enable_thinking", "").lower() == "true":
+                kwargs["enable_thinking"] = True
+            try:
+                return self.tokenizer.apply_chat_template(messages, **kwargs)
+            except TypeError:
+                return self.tokenizer.apply_chat_template(
+                    messages, tokenize=False, add_generation_prompt=True
+                )
         return request.Prompt
def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
@@ -349,6 +456,82 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return tokens.tolist()
return list(tokens)
def _tool_module_from_tokenizer(self):
"""Same shim as the mlx backend: fall back to json.loads when the
installed mlx-lm doesn't expose a tool_parser callable on the
wrapper (true on 0.29.x — only HEAD ships parsers)."""
start = getattr(self.tokenizer, "tool_call_start", None)
end = getattr(self.tokenizer, "tool_call_end", None)
if not start:
return None
parse_fn = getattr(self.tokenizer, "tool_parser", None)
if parse_fn is None:
def parse_fn(body, tools): # noqa: E306
return json.loads(body.strip())
return types.SimpleNamespace(
tool_call_start=start,
tool_call_end=end or "",
parse_tool_call=parse_fn,
)
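For reference, the shim above can be exercised without a real TokenizerWrapper. This is a hypothetical stand-alone mirror of the method (the backend's own version lives on the servicer), using a fake 0.29.x-style wrapper that exposes the markers but no `tool_parser` callable:

```python
import json
import types

def tool_module_from_tokenizer(tokenizer):
    # Mirrors _tool_module_from_tokenizer: fall back to json.loads when the
    # wrapper exposes tool-call markers but no tool_parser callable.
    start = getattr(tokenizer, "tool_call_start", None)
    end = getattr(tokenizer, "tool_call_end", None)
    if not start:
        return None
    parse_fn = getattr(tokenizer, "tool_parser", None)
    if parse_fn is None:
        def parse_fn(body, tools):
            return json.loads(body.strip())
    return types.SimpleNamespace(
        tool_call_start=start,
        tool_call_end=end or "",
        parse_tool_call=parse_fn,
    )

# Fake wrapper: markers present, no parser callable -> json.loads fallback.
fake = types.SimpleNamespace(tool_call_start="<tool_call>", tool_call_end="</tool_call>")
shim = tool_module_from_tokenizer(fake)
call = shim.parse_tool_call(' {"name": "f", "arguments": {}} ', None)
print(call["name"])  # f
```

A wrapper with no `tool_call_start` yields `None`, which the callers above treat as "no tool parsing".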
def _truncate_at_stop(self, text, stop_words):
if not stop_words:
return text
earliest = len(text)
for stop in stop_words:
if not stop:
continue
idx = text.find(stop)
if idx >= 0 and idx < earliest:
earliest = idx
return text[:earliest] if earliest < len(text) else text
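The truncation semantics can be checked server-less with a stand-alone copy of the helper (illustrative only; the backend uses the method above):

```python
def truncate_at_stop(text, stop_words):
    # Cut at the earliest occurrence of any stop word; empty entries are
    # skipped, and text without any stop word passes through unchanged.
    if not stop_words:
        return text
    earliest = len(text)
    for stop in stop_words:
        if not stop:
            continue
        idx = text.find(stop)
        if 0 <= idx < earliest:
            earliest = idx
    return text[:earliest] if earliest < len(text) else text

print(truncate_at_stop("hello<|eot|>world</s>", ["</s>", "<|eot|>", ""]))  # hello
```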
def _finalize_output(self, request, generated_text, last_response):
content = generated_text
reasoning_content = ""
if getattr(self.tokenizer, "has_thinking", False):
think_start = getattr(self.tokenizer, "think_start", "") or ""
think_end = getattr(self.tokenizer, "think_end", "") or ""
reasoning_content, content = split_reasoning(content, think_start, think_end)
tool_calls_proto: List[backend_pb2.ToolCallDelta] = []
tool_module = None
if getattr(self.tokenizer, "has_tool_calling", False):
tool_module = self._tool_module_from_tokenizer()
if tool_module is not None:
parsed_tools = None
if request.Tools:
try:
parsed_tools = json.loads(request.Tools)
except json.JSONDecodeError:
parsed_tools = None
calls, content = parse_tool_calls(content, tool_module, parsed_tools)
for c in calls:
tool_calls_proto.append(
backend_pb2.ToolCallDelta(
index=c["index"], id=c["id"], name=c["name"], arguments=c["arguments"],
)
)
prompt_token_count = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0
completion_token_count = int(getattr(last_response, "generation_tokens", 0) or 0) if last_response else 0
logprobs_bytes = b""
if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0:
try:
lp = getattr(last_response, "logprobs", None)
if lp is not None:
token_id = int(getattr(last_response, "token", 0) or 0)
token_text = self.tokenizer.decode([token_id]) if token_id else ""
top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0
logprobs_bytes = json.dumps(
{"content": [{"token": token_text, "logprob": top_logprob}]}
).encode("utf-8")
except Exception as e:
print(f"[Rank 0] Logprobs extraction failed: {e}", file=sys.stderr)
return content, reasoning_content, tool_calls_proto, prompt_token_count, completion_token_count, logprobs_bytes
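The reasoning split that feeds this finalizer can be sketched with a minimal stand-in consistent with the shared-helper tests (the real `split_reasoning` lives in `mlx_utils`; this local version only illustrates the contract of extracting one think block and returning clean content):

```python
def split_reasoning(text, start, end):
    # Minimal stand-in: pull out one start...end block and return
    # (reasoning, remaining_content). No markers -> no reasoning.
    if start and start in text and end in text:
        pre, _, rest = text.partition(start)
        reasoning, _, post = rest.partition(end)
        return reasoning, pre + post
    return "", text

r, c = split_reasoning("<think>plan</think>final", "<think>", "</think>")
print((r, c))  # ('plan', 'final')
```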
def _build_generation_params(self, request, default_max_tokens=200):
import mlx.core as mx
@@ -373,6 +556,22 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
'xtc_probability': 0.0,
}
# Logits processor parameters — pulled from the request and
# forwarded to make_logits_processors. Rank 0 is the only rank
# running the sampler so we don't need to broadcast these to
# workers (workers participate in the pipeline-parallel forward
# pass only).
logits_params = {}
repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0
if repetition_penalty and repetition_penalty != 1.0:
logits_params['repetition_penalty'] = repetition_penalty
presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0
if presence_penalty:
logits_params['presence_penalty'] = presence_penalty
frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0
if frequency_penalty:
logits_params['frequency_penalty'] = frequency_penalty
seed = getattr(request, 'Seed', 0)
if seed != 0:
mx.random.seed(seed)
@@ -392,9 +591,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
for opt_key, param_key in option_mapping.items():
if opt_key in self.options:
sampler_params[param_key] = self.options[opt_key]
for opt_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'):
if opt_key in self.options:
logits_params[opt_key] = self.options[opt_key]
if 'seed' in self.options:
mx.random.seed(self.options['seed'])
stop_words = list(getattr(request, 'StopPrompts', []) or [])
return max_tokens, sampler_params, logits_params, stop_words
# XTC special tokens
xtc_special_tokens = []
if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:

View File

@@ -1,3 +1,6 @@
import os
import sys
import types
import unittest
import subprocess
import time
@@ -6,6 +9,12 @@ import grpc
import backend_pb2
import backend_pb2_grpc
# Make the shared helpers importable so we can unit-test them without a
# running gRPC server.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
from python_utils import messages_to_dicts, parse_options
from mlx_utils import parse_tool_calls, split_reasoning
class TestBackendServicer(unittest.TestCase):
def setUp(self):
@@ -85,3 +94,44 @@ class TestBackendServicer(unittest.TestCase):
self.fail("sampling params service failed")
finally:
self.tearDown()
class TestSharedHelpers(unittest.TestCase):
"""Server-less unit tests for the helpers the mlx-distributed backend depends on."""
def test_parse_options_typed(self):
opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true"])
self.assertEqual(opts["temperature"], 0.7)
self.assertEqual(opts["max_tokens"], 128)
self.assertIs(opts["trust"], True)
def test_messages_to_dicts_roundtrip(self):
msgs = [
backend_pb2.Message(role="user", content="hi"),
backend_pb2.Message(
role="assistant",
content="",
tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]',
),
backend_pb2.Message(role="tool", content="42", tool_call_id="call_1", name="f"),
]
out = messages_to_dicts(msgs)
self.assertEqual(out[0], {"role": "user", "content": "hi"})
self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f")
self.assertEqual(out[2]["tool_call_id"], "call_1")
def test_split_reasoning(self):
r, c = split_reasoning("<think>plan</think>final", "<think>", "</think>")
self.assertEqual(r, "plan")
self.assertEqual(c, "final")
def test_parse_tool_calls_with_shim(self):
tm = types.SimpleNamespace(
tool_call_start="<tool_call>",
tool_call_end="</tool_call>",
parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}},
)
calls, remaining = parse_tool_calls("<tool_call>Paris</tool_call>", tm, tools=None)
self.assertEqual(len(calls), 1)
self.assertEqual(calls[0]["name"], "get_weather")
self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}')

View File

@@ -2,11 +2,14 @@
import asyncio
from concurrent import futures
import argparse
import gc
import json
import signal
import sys
import os
import tempfile
import types
from typing import List
import time
import backend_pb2
import backend_pb2_grpc
@@ -15,30 +18,18 @@ import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from python_utils import messages_to_dicts, parse_options
from mlx_utils import parse_tool_calls, split_reasoning
from mlx_vlm import load, stream_generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.tool_parsers import _infer_tool_parser, load_tool_module
from mlx_vlm.utils import load_config
from mlx_lm.sample_utils import make_logits_processors, make_sampler
import mlx.core as mx
import base64
import io
from PIL import Image
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -78,36 +69,52 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
try:
print(f"Loading MLX-VLM model: {request.Model}", file=sys.stderr)
print(f"Request: {request}", file=sys.stderr)
# Parse Options[] key:value strings into a typed dict
self.options = parse_options(request.Options)
print(f"Options: {self.options}", file=sys.stderr)
# Load model and processor using MLX-VLM
# mlx-vlm load function returns (model, processor) instead of (model, tokenizer)
self.model, self.processor = load(request.Model)
# Load model config for chat template support
self.config = load_config(request.Model)
# Auto-infer the tool parser from the chat template. mlx-vlm has
# its own _infer_tool_parser that falls back to mlx-lm parsers.
tokenizer = (
self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor
)
self.tool_module = None
if hasattr(tokenizer, "chat_template"):
try:
parser_type = _infer_tool_parser(tokenizer.chat_template)
if parser_type is not None:
self.tool_module = load_tool_module(parser_type)
print(
f"[mlx-vlm] auto-detected tool parser: {parser_type}",
file=sys.stderr,
)
else:
print(
"[mlx-vlm] no tool parser matched the chat template",
file=sys.stderr,
)
except Exception as e:
print(
f"[mlx-vlm] failed to load tool parser: {e}",
file=sys.stderr,
)
# Reasoning tokens — check if the tokenizer advertises thinking
# markers. Fall back to empty strings (split_reasoning no-ops).
self.think_start = getattr(tokenizer, "think_start", "") or ""
self.think_end = getattr(tokenizer, "think_end", "") or ""
self.has_thinking = bool(
getattr(tokenizer, "has_thinking", False) or self.think_start
)
except Exception as err:
print(f"Error loading MLX-VLM model {err=}, {type(err)=}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Error loading MLX-VLM model: {err}")
@@ -128,63 +135,72 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
temp_files = []
try:
image_paths, audio_paths = self._collect_media(request, temp_files)
prompt = self._prepare_prompt(
request,
num_images=len(image_paths),
num_audios=len(audio_paths),
)
max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request)
sampler = make_sampler(**sampler_params)
logits_processors = make_logits_processors(**logits_params) if logits_params else None
print(
f"Generating text with MLX-VLM - max_tokens: {max_tokens}, "
f"images: {len(image_paths)}, audios: {len(audio_paths)}",
file=sys.stderr,
)
accumulated = []
last_response = None
for response in stream_generate(
model=self.model,
processor=self.processor,
prompt=prompt,
image=image_paths if image_paths else None,
audio=audio_paths if audio_paths else None,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
):
accumulated.append(response.text)
last_response = response
if stop_words and any(s in "".join(accumulated) for s in stop_words):
break
full_text = self._truncate_at_stop("".join(accumulated), stop_words)
content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
self._finalize_output(request, full_text, last_response)
)
return backend_pb2.Reply(
message=bytes(content, encoding='utf-8'),
prompt_tokens=prompt_tokens,
tokens=completion_tokens,
logprobs=logprobs_bytes,
chat_deltas=[
backend_pb2.ChatDelta(
content=content,
reasoning_content=reasoning_content,
tool_calls=tool_calls_proto,
)
],
)
except Exception as e:
print(f"Error in MLX-VLM Predict: {e}", file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"Generation failed: {str(e)}")
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
finally:
# Clean up temporary files
self.cleanup_temp_files(temp_files)
def Embedding(self, request, context):
"""
A gRPC method that calculates embeddings for a given sentence.
Note: MLX-VLM doesn't support embeddings directly. This method returns an error.
Args:
@@ -199,6 +215,79 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
context.set_details("Embeddings are not supported in the MLX-VLM backend.")
return backend_pb2.EmbeddingResult()
def _collect_media(self, request, temp_files):
"""Decode base64 Images and Audios into temp file paths.
Appends every temp file to ``temp_files`` so the finally block can
clean up even on mid-generation errors.
"""
image_paths = []
audio_paths = []
if request.Images:
for img_data in request.Images:
img_path = self.load_image_from_base64(img_data)
if img_path:
image_paths.append(img_path)
temp_files.append(img_path)
if request.Audios:
for audio_data in request.Audios:
audio_path = self.load_audio_from_base64(audio_data)
if audio_path:
audio_paths.append(audio_path)
temp_files.append(audio_path)
return image_paths, audio_paths
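The base64-to-temp-file step that `load_image_from_base64` / `load_audio_from_base64` are assumed to perform can be sketched as follows (hypothetical helper name; the real methods also validate the payload and pick an appropriate suffix):

```python
import base64
import os
import tempfile

def save_base64_to_temp(data_b64, suffix):
    # Decode a base64 payload and persist it to a temp file whose path is
    # handed to mlx-vlm; the caller appends it to temp_files for cleanup.
    raw = base64.b64decode(data_b64)
    fd, path = tempfile.mkstemp(suffix=suffix)
    with os.fdopen(fd, "wb") as f:
        f.write(raw)
    return path

path = save_base64_to_temp(base64.b64encode(b"RIFF....WAVE").decode(), ".wav")
print(os.path.exists(path))  # True
os.remove(path)
```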
async def TokenizeString(self, request, context):
"""Tokenize ``request.Prompt`` via the processor's tokenizer."""
if not hasattr(self, "processor") or self.processor is None:
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
context.set_details("processor not loaded")
return backend_pb2.TokenizationResponse()
try:
tokenizer = (
self.processor.tokenizer
if hasattr(self.processor, "tokenizer")
else self.processor
)
tokens = tokenizer.encode(request.Prompt)
if hasattr(tokens, "tolist"):
tokens = tokens.tolist()
tokens = list(tokens)
return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
except Exception as e:
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(e))
return backend_pb2.TokenizationResponse()
async def Free(self, request, context):
"""Drop the loaded model, processor and tool module."""
try:
if hasattr(self, "model"):
del self.model
if hasattr(self, "processor"):
del self.processor
if hasattr(self, "config"):
del self.config
self.tool_module = None
gc.collect()
# mlx.clear_cache (mlx >= 0.30) supersedes mlx.metal.clear_cache.
try:
if hasattr(mx, "clear_cache"):
mx.clear_cache()
elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
mx.metal.clear_cache()
except Exception:
pass
try:
import torch # type: ignore
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception:
pass
return backend_pb2.Result(success=True, message="MLX-VLM model freed")
except Exception as e:
return backend_pb2.Result(success=False, message=str(e))
async def PredictStream(self, request, context):
"""
Generates text based on the given prompt and sampling parameters, and streams the results using MLX-VLM with multimodal support.
@@ -212,36 +301,28 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
temp_files = []
try:
image_paths, audio_paths = self._collect_media(request, temp_files)
prompt = self._prepare_prompt(
request,
num_images=len(image_paths),
num_audios=len(audio_paths),
)
max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(
request, default_max_tokens=512
)
sampler = make_sampler(**sampler_params)
logits_processors = make_logits_processors(**logits_params) if logits_params else None
print(
f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, "
f"images: {len(image_paths)}, audios: {len(audio_paths)}",
file=sys.stderr,
)
accumulated = []
last_response = None
for response in stream_generate(
model=self.model,
processor=self.processor,
@@ -249,77 +330,91 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
image=image_paths if image_paths else None,
audio=audio_paths if audio_paths else None,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
):
accumulated.append(response.text)
last_response = response
yield backend_pb2.Reply(
message=bytes(response.text, encoding='utf-8'),
chat_deltas=[backend_pb2.ChatDelta(content=response.text)],
)
if stop_words and any(s in "".join(accumulated) for s in stop_words):
break
full_text = self._truncate_at_stop("".join(accumulated), stop_words)
content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
self._finalize_output(request, full_text, last_response)
)
yield backend_pb2.Reply(
message=b"",
prompt_tokens=prompt_tokens,
tokens=completion_tokens,
logprobs=logprobs_bytes,
chat_deltas=[
backend_pb2.ChatDelta(
content="",
reasoning_content=reasoning_content,
tool_calls=tool_calls_proto,
)
],
)
except Exception as e:
print(f"Error in MLX-VLM PredictStream: {e}", file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"Streaming generation failed: {str(e)}")
yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
finally:
# Clean up temporary files
self.cleanup_temp_files(temp_files)
def _build_template_kwargs(self, request, num_images, num_audios):
"""Collect kwargs for ``apply_chat_template`` that survive model variants."""
kwargs = {"num_images": num_images, "num_audios": num_audios}
if request.Tools:
try:
kwargs["tools"] = json.loads(request.Tools)
except json.JSONDecodeError:
pass
if request.Metadata.get("enable_thinking", "").lower() == "true":
kwargs["enable_thinking"] = True
return kwargs
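The kwargs this helper produces can be checked stand-alone. The sketch below mirrors the logic with a `SimpleNamespace` standing in for the gRPC request (illustrative; the backend passes the real request through `_apply_template`):

```python
import json
from types import SimpleNamespace

def build_template_kwargs(request, num_images, num_audios):
    # Mirrors _build_template_kwargs: malformed Tools JSON is dropped
    # silently; enable_thinking is only set when the metadata flag is "true"
    # (case-insensitive).
    kwargs = {"num_images": num_images, "num_audios": num_audios}
    if request.Tools:
        try:
            kwargs["tools"] = json.loads(request.Tools)
        except json.JSONDecodeError:
            pass
    if request.Metadata.get("enable_thinking", "").lower() == "true":
        kwargs["enable_thinking"] = True
    return kwargs

req = SimpleNamespace(
    Tools='[{"type": "function", "function": {"name": "f"}}]',
    Metadata={"enable_thinking": "True"},
)
print(build_template_kwargs(req, 1, 0))
```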
def _apply_template(self, request, messages, num_images, num_audios):
kwargs = self._build_template_kwargs(request, num_images, num_audios)
try:
return apply_chat_template(self.processor, self.config, messages, **kwargs)
except TypeError:
# Fallback for older mlx-vlm versions that reject tools=/enable_thinking=
return apply_chat_template(
self.processor,
self.config,
messages,
num_images=num_images,
num_audios=num_audios,
)
def _prepare_prompt(self, request, num_images=0, num_audios=0):
"""
Prepare the prompt for MLX-VLM generation, handling chat templates and
multimodal inputs. Forwards tool definitions and enable_thinking when
present on the request.
"""
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
messages = messages_to_dicts(request.Messages)
return self._apply_template(request, messages, num_images, num_audios)
if request.Prompt:
if num_images > 0 or num_audios > 0:
# Wrap the raw prompt in a single user message so the template
# can place image/audio tokens around it.
messages = [{"role": "user", "content": request.Prompt}]
return self._apply_template(request, messages, num_images, num_audios)
return request.Prompt
# Fallback to empty prompt with multimodal template if we have media
if num_images > 0 or num_audios > 0:
messages = [{"role": "user", "content": ""}]
return self._apply_template(request, messages, num_images, num_audios)
return ""
@@ -327,62 +422,122 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def _build_generation_params(self, request, default_max_tokens=200):
"""
Build generation parameters from request attributes and options for MLX-VLM.
Args:
request: The gRPC request.
default_max_tokens: Default max_tokens if not specified.
Build generation parameters from request attributes and options.
Returns:
tuple: (max_tokens, generation_params dict)
tuple: (max_tokens, sampler_params, logits_params, stop_words)
"""
max_tokens = getattr(request, 'Tokens', default_max_tokens) or default_max_tokens
temp = getattr(request, 'Temperature', 0.0) or 0.6
top_p = getattr(request, 'TopP', 0.0) or 1.0
min_p = getattr(request, 'MinP', 0.0) or 0.0
top_k = getattr(request, 'TopK', 0) or 0
sampler_params = {
'temp': temp,
'top_p': top_p,
'min_p': min_p,
'top_k': top_k,
}
logits_params = {}
repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0
if repetition_penalty and repetition_penalty != 1.0:
logits_params['repetition_penalty'] = repetition_penalty
presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0
if presence_penalty:
logits_params['presence_penalty'] = presence_penalty
frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0
if frequency_penalty:
logits_params['frequency_penalty'] = frequency_penalty
seed = getattr(request, 'Seed', 0)
if seed != 0:
mx.random.seed(seed)
# Override with options if available
if hasattr(self, 'options'):
# Max tokens from options
if 'max_tokens' in self.options:
max_tokens = self.options['max_tokens']
# Generation parameters from options
option_mapping = {
'temp': 'temp', 'temperature': 'temp',
'top_p': 'top_p', 'min_p': 'min_p', 'top_k': 'top_k',
}
for option_key, param_key in option_mapping.items():
if option_key in self.options:
sampler_params[param_key] = self.options[option_key]
for option_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'):
if option_key in self.options:
logits_params[option_key] = self.options[option_key]
if 'seed' in self.options:
mx.random.seed(self.options['seed'])
stop_words = list(getattr(request, 'StopPrompts', []) or [])
return max_tokens, sampler_params, logits_params, stop_words
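The precedence this builder implements — a zero request field falls back to the default, and `Options[]` entries override whatever the request supplied — can be sketched for a single parameter (hypothetical helper, shown only to make the ordering explicit):

```python
from types import SimpleNamespace

def resolve_temp(request, options):
    # 1. Request field wins over the default only when non-zero.
    # 2. Options ('temp' or its 'temperature' alias) override the request.
    temp = getattr(request, 'Temperature', 0.0) or 0.6
    for key in ('temp', 'temperature'):
        if key in options:
            temp = options[key]
    return temp

req = SimpleNamespace(Temperature=0.0)
print(resolve_temp(req, {}))                    # 0.6 (default kicks in)
print(resolve_temp(req, {'temperature': 0.2}))  # 0.2 (option wins)
```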
def _finalize_output(self, request, generated_text, last_response):
"""Split reasoning + tool calls out of generated_text and return the
tuple consumed by Reply-builders."""
content = generated_text
reasoning_content = ""
if getattr(self, "has_thinking", False):
reasoning_content, content = split_reasoning(content, self.think_start, self.think_end)
tool_calls_proto: List[backend_pb2.ToolCallDelta] = []
if self.tool_module is not None:
parsed_tools = None
if request.Tools:
try:
parsed_tools = json.loads(request.Tools)
except json.JSONDecodeError:
parsed_tools = None
calls, content = parse_tool_calls(content, self.tool_module, parsed_tools)
for c in calls:
tool_calls_proto.append(
backend_pb2.ToolCallDelta(
index=c["index"],
id=c["id"],
name=c["name"],
arguments=c["arguments"],
)
)
prompt_tokens = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0
completion_tokens = int(getattr(last_response, "generation_tokens", 0) or 0) if last_response else 0
logprobs_bytes = b""
if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0:
try:
lp = getattr(last_response, "logprobs", None)
if lp is not None:
token_id = int(getattr(last_response, "token", 0) or 0)
tokenizer = (
self.processor.tokenizer
if hasattr(self.processor, "tokenizer")
else self.processor
)
token_text = tokenizer.decode([token_id]) if token_id else ""
top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0
logprobs_bytes = json.dumps(
{"content": [{"token": token_text, "logprob": top_logprob}]}
).encode("utf-8")
except Exception as e:
print(f"[mlx-vlm] Logprobs extraction failed: {e}", file=sys.stderr)
return content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes
def _truncate_at_stop(self, text, stop_words):
if not stop_words:
return text
earliest = len(text)
for stop in stop_words:
if not stop:
continue
idx = text.find(stop)
if idx >= 0 and idx < earliest:
earliest = idx
return text[:earliest] if earliest < len(text) else text
def load_image_from_base64(self, image_data: str):
"""

View File

@@ -1,17 +1,19 @@
import os
import sys
import types
import unittest
import subprocess
import time
import grpc
import backend_pb2
import backend_pb2_grpc
# Make the shared helpers importable so we can unit-test them without a
# running gRPC server.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
from python_utils import messages_to_dicts, parse_options
from mlx_utils import parse_tool_calls, split_reasoning
class TestBackendServicer(unittest.TestCase):
"""
@@ -143,4 +145,55 @@ class TestBackendServicer(unittest.TestCase):
print(err)
self.fail("Embedding service failed")
finally:
self.tearDown()
class TestSharedHelpers(unittest.TestCase):
"""Server-less unit tests for the helpers the mlx-vlm backend depends on."""
def test_parse_options_typed(self):
opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true", "name:hello"])
self.assertEqual(opts["temperature"], 0.7)
self.assertEqual(opts["max_tokens"], 128)
self.assertIs(opts["trust"], True)
self.assertEqual(opts["name"], "hello")
def test_messages_to_dicts_roundtrip(self):
msgs = [
backend_pb2.Message(role="user", content="hi"),
backend_pb2.Message(
role="assistant",
content="",
tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]',
),
backend_pb2.Message(
role="tool",
content="42",
tool_call_id="call_1",
name="f",
),
]
out = messages_to_dicts(msgs)
self.assertEqual(out[0], {"role": "user", "content": "hi"})
self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f")
self.assertEqual(out[2]["tool_call_id"], "call_1")
def test_split_reasoning(self):
r, c = split_reasoning("<think>plan</think>final", "<think>", "</think>")
self.assertEqual(r, "plan")
self.assertEqual(c, "final")
def test_parse_tool_calls_with_shim(self):
tm = types.SimpleNamespace(
tool_call_start="<tool_call>",
tool_call_end="</tool_call>",
parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}},
)
calls, remaining = parse_tool_calls(
"<tool_call>Paris</tool_call>",
tm,
tools=None,
)
self.assertEqual(len(calls), 1)
self.assertEqual(calls[0]["name"], "get_weather")
self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}')

View File

@@ -2,11 +2,13 @@
import asyncio
from concurrent import futures
import argparse
import gc
import json
import signal
import sys
import os
import types
from typing import List
import time
import backend_pb2
import backend_pb2_grpc
@@ -15,13 +17,13 @@ import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from python_utils import messages_to_dicts, parse_options
from mlx_utils import parse_tool_calls, split_reasoning
from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
import mlx.core as mx
import base64
import io
from mlx_cache import ThreadSafeLRUPromptCache
@@ -30,21 +32,6 @@ _ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
@@ -78,46 +65,27 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
try:
print(f"Loading MLX model: {request.Model}", file=sys.stderr)
print(f"Request: {request}", file=sys.stderr)
# Parse Options[] key:value strings into a typed dict (shared helper)
self.options = parse_options(request.Options)
print(f"Options: {self.options}", file=sys.stderr)
# Build tokenizer config for MLX using options
tokenizer_config = {}
# Handle trust_remote_code from request or options
if request.TrustRemoteCode or self.options.get("trust_remote_code", False):
tokenizer_config["trust_remote_code"] = True
# Handle EOS token from options
if "eos_token" in self.options:
tokenizer_config["eos_token"] = self.options["eos_token"]
# Handle other tokenizer config options
for key in ["pad_token", "bos_token", "unk_token", "sep_token", "cls_token", "mask_token"]:
if key in self.options:
tokenizer_config[key] = self.options[key]
# Load model and tokenizer using MLX
if tokenizer_config:
print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
@@ -125,6 +93,21 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
else:
self.model, self.tokenizer = load(request.Model)
# mlx_lm.load() returns a TokenizerWrapper that detects tool
# calling and thinking markers from the chat template / vocab.
# mlx-lm >= 0.30 also exposes a parser callable on the wrapper;
# earlier versions don't (we fall back to json.loads inside
# _tool_module_from_tokenizer below).
has_tools = bool(getattr(self.tokenizer, "has_tool_calling", False))
has_thinking = bool(getattr(self.tokenizer, "has_thinking", False))
tcs = getattr(self.tokenizer, "tool_call_start", None)
tce = getattr(self.tokenizer, "tool_call_end", None)
print(
f"MLX tokenizer capabilities: has_tool_calling={has_tools} "
f"has_thinking={has_thinking} tool_call_start={tcs!r} tool_call_end={tce!r}",
file=sys.stderr,
)
# Initialize thread-safe LRU prompt cache for efficient generation
max_cache_entries = self.options.get("max_cache_entries", 10)
self.max_kv_size = self.options.get("max_kv_size", None)
@@ -134,7 +117,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
can_trim_fn=can_trim_prompt_cache,
trim_fn=trim_prompt_cache,
)
except Exception as err:
print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Error loading MLX model: {err}")
@@ -172,30 +155,58 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
remaining_tokens = cache_key
# Build generation parameters using request attributes and options
max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request)
print(
f"Generating text with MLX - max_tokens: {max_tokens}, "
f"cache_hit: {len(remaining_tokens) < len(cache_key)}",
file=sys.stderr,
)
# Create sampler and optional logits processors (penalties)
sampler = make_sampler(**sampler_params)
logits_processors = make_logits_processors(**logits_params) if logits_params else None
# Use stream_generate to collect text + track tokens for cache key
generated_text = []
last_response = None
for response in stream_generate(
self.model,
self.tokenizer,
prompt=remaining_tokens if remaining_tokens else cache_key,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=prompt_cache,
):
generated_text.append(response.text)
cache_key.append(response.token)
last_response = response
# Early stop on user-provided stop sequences
if stop_words and any(s in "".join(generated_text) for s in stop_words):
break
# Insert completed cache
self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
full_text = self._truncate_at_stop("".join(generated_text), stop_words)
content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
self._finalize_output(request, full_text, last_response)
)
return backend_pb2.Reply(
message=bytes(content, encoding='utf-8'),
prompt_tokens=prompt_tokens,
tokens=completion_tokens,
logprobs=logprobs_bytes,
chat_deltas=[
backend_pb2.ChatDelta(
content=content,
reasoning_content=reasoning_content,
tool_calls=tool_calls_proto,
)
],
)
except Exception as e:
print(f"Error in MLX Predict: {e}", file=sys.stderr)
@@ -206,7 +217,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def Embedding(self, request, context):
"""
A gRPC method that calculates embeddings for a given sentence.
Note: MLX-LM doesn't support embeddings directly. This method returns an error.
Args:
@@ -221,6 +232,62 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
context.set_details("Embeddings are not supported in the MLX backend.")
return backend_pb2.EmbeddingResult()
async def TokenizeString(self, request, context):
"""Tokenize ``request.Prompt`` using the loaded model's tokenizer."""
if not hasattr(self, "tokenizer") or self.tokenizer is None:
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
context.set_details("tokenizer not loaded")
return backend_pb2.TokenizationResponse()
try:
tokens = self.tokenizer.encode(request.Prompt)
if hasattr(tokens, "tolist"):
tokens = tokens.tolist()
tokens = list(tokens)
return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
except Exception as e:
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(e))
return backend_pb2.TokenizationResponse()
async def Free(self, request, context):
"""Drop the loaded model, tokenizer and prompt cache.
Metal / CUDA memory is released via ``gc.collect()`` + the
platform-specific cache clear hooks when available.
"""
try:
if hasattr(self, "model"):
del self.model
if hasattr(self, "tokenizer"):
del self.tokenizer
if hasattr(self, "lru_cache") and self.lru_cache is not None:
try:
self.lru_cache.clear()
except Exception:
pass
self.lru_cache = None
gc.collect()
# Metal: drop the cached allocator. mlx.clear_cache (mlx >= 0.30)
# supersedes the now-deprecated mlx.metal.clear_cache.
try:
if hasattr(mx, "clear_cache"):
mx.clear_cache()
elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
mx.metal.clear_cache()
except Exception:
pass
# CUDA: release the torch cache if a CUDA-backed mlx variant
# happens to be installed alongside torch (best-effort).
try:
import torch # type: ignore
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception:
pass
return backend_pb2.Result(success=True, message="MLX model freed")
except Exception as e:
return backend_pb2.Result(success=False, message=str(e))
async def PredictStream(self, request, context):
"""
Generates text based on the given prompt and sampling parameters, and streams the results using MLX.
@@ -251,24 +318,64 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
remaining_tokens = cache_key
# Build generation parameters using request attributes and options
max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(
request, default_max_tokens=512
)
print(
f"Streaming text with MLX - max_tokens: {max_tokens}, "
f"cache_hit: {len(remaining_tokens) < len(cache_key)}",
file=sys.stderr,
)
# Create sampler and optional logits processors (penalties)
sampler = make_sampler(**sampler_params)
logits_processors = make_logits_processors(**logits_params) if logits_params else None
# Stream text generation using MLX with proper parameters
accumulated = []
last_response = None
for response in stream_generate(
self.model,
self.tokenizer,
prompt=remaining_tokens if remaining_tokens else cache_key,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=prompt_cache,
):
cache_key.append(response.token)
accumulated.append(response.text)
last_response = response
# Emit a content delta. Structured reasoning / tool parsing
# happens on the final chunk so we don't fragment the state
# machine in v1.
yield backend_pb2.Reply(
message=bytes(response.text, encoding='utf-8'),
chat_deltas=[backend_pb2.ChatDelta(content=response.text)],
)
# Early stop on user-provided stop sequences
if stop_words and any(s in "".join(accumulated) for s in stop_words):
break
# Final chunk: run reasoning + tool parsing on accumulated text
# and emit the structured ChatDelta with token counts + logprobs.
full_text = self._truncate_at_stop("".join(accumulated), stop_words)
content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
self._finalize_output(request, full_text, last_response)
)
yield backend_pb2.Reply(
message=b"",
prompt_tokens=prompt_tokens,
tokens=completion_tokens,
logprobs=logprobs_bytes,
chat_deltas=[
backend_pb2.ChatDelta(
content="",
reasoning_content=reasoning_content,
tool_calls=tool_calls_proto,
)
],
)
except Exception as e:
print(f"Error in MLX PredictStream: {e}", file=sys.stderr)
@@ -294,21 +401,33 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
Returns:
str: The prepared prompt.
"""
# If tokenizer template is enabled and messages are provided instead
# of prompt, apply the tokenizer template (forwards tool definitions
# and enable_thinking when the model supports them).
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
messages = messages_to_dicts(request.Messages)
kwargs = {"tokenize": False, "add_generation_prompt": True}
if request.Tools:
try:
kwargs["tools"] = json.loads(request.Tools)
except json.JSONDecodeError:
pass
enable_thinking = request.Metadata.get("enable_thinking", "").lower()
if enable_thinking == "true":
kwargs["enable_thinking"] = True
try:
return self.tokenizer.apply_chat_template(messages, **kwargs)
except TypeError:
# Fallback for tokenizers whose template doesn't accept
# tools= or enable_thinking=.
return self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
return request.Prompt
def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
"""
@@ -338,18 +457,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
default_max_tokens: Default max_tokens if not specified.
Returns:
tuple: (max_tokens, sampler_params dict, logits_processor_params dict,
stop_words list)
"""
# Extract max_tokens
max_tokens = getattr(request, 'Tokens', default_max_tokens)
if max_tokens == 0:
max_tokens = default_max_tokens
# Extract sampler parameters from request attributes
temp = getattr(request, 'Temperature', 0.0)
if temp == 0.0:
temp = 0.6 # Default temperature
top_p = getattr(request, 'TopP', 0.0)
if top_p == 0.0:
top_p = 1.0 # Default top_p
@@ -369,18 +489,31 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
'xtc_threshold': 0.0,
'xtc_probability': 0.0,
}
# Logits processor parameters — only set fields the request actually
# provides so we can feed them unconditionally to make_logits_processors.
logits_params = {}
repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0
if repetition_penalty and repetition_penalty != 1.0:
logits_params['repetition_penalty'] = repetition_penalty
presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0
if presence_penalty:
logits_params['presence_penalty'] = presence_penalty
frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0
if frequency_penalty:
logits_params['frequency_penalty'] = frequency_penalty
# Add seed if specified
seed = getattr(request, 'Seed', 0)
if seed != 0:
mx.random.seed(seed)
# Override with options if available
if hasattr(self, 'options'):
# Max tokens from options
if 'max_tokens' in self.options:
max_tokens = self.options['max_tokens']
# Sampler parameters from options
sampler_option_mapping = {
'temp': 'temp',
@@ -391,32 +524,142 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
'xtc_threshold': 'xtc_threshold',
'xtc_probability': 'xtc_probability',
}
for option_key, param_key in sampler_option_mapping.items():
if option_key in self.options:
sampler_params[param_key] = self.options[option_key]
# Logits processor overrides
for option_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'):
if option_key in self.options:
logits_params[option_key] = self.options[option_key]
# Handle seed from options
if 'seed' in self.options:
mx.random.seed(self.options['seed'])
# Special tokens for XTC sampling (if tokenizer has eos_token_ids)
xtc_special_tokens = []
if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
xtc_special_tokens = list(self.tokenizer.eos_token_ids)
elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
xtc_special_tokens = [self.tokenizer.eos_token_id]
# Add newline token if available
try:
newline_tokens = self.tokenizer.encode("\n")
xtc_special_tokens.extend(newline_tokens)
except Exception:
pass # Skip if encoding fails
sampler_params['xtc_special_tokens'] = xtc_special_tokens
# Stop sequences are applied post-decode (mlx-lm doesn't have a
# built-in stop-sequence sampler param). Preserve the list here.
stop_words = list(getattr(request, 'StopPrompts', []) or [])
return max_tokens, sampler_params, logits_params, stop_words
def _tool_module_from_tokenizer(self):
"""Build a duck-typed tool module from the TokenizerWrapper.
On mlx-lm >= 0.30 the wrapper exposes a ``tool_parser`` callable
that's been resolved from the model's chat template. On older
releases (e.g. 0.29.x) the wrapper only carries the start/end
markers — fall back to ``json.loads`` of the body, which matches
what ``mlx_lm.tool_parsers.json_tools.parse_tool_call`` does on
HEAD and covers the only format 0.29 detects (``<tool_call>``).
"""
start = getattr(self.tokenizer, "tool_call_start", None)
end = getattr(self.tokenizer, "tool_call_end", None)
if not start:
return None
parse_fn = getattr(self.tokenizer, "tool_parser", None)
if parse_fn is None:
def parse_fn(body, tools): # noqa: E306 — local fallback
return json.loads(body.strip())
return types.SimpleNamespace(
tool_call_start=start,
tool_call_end=end or "",
parse_tool_call=parse_fn,
)
def _finalize_output(self, request, generated_text, last_response):
"""Build a ChatDelta + token counts + logprobs from accumulated output.
Returns ``(content, reasoning_content, tool_calls_proto,
prompt_token_count, completion_token_count, logprobs_bytes)``.
"""
content = generated_text
reasoning_content = ""
if getattr(self.tokenizer, "has_thinking", False):
think_start = getattr(self.tokenizer, "think_start", "") or ""
think_end = getattr(self.tokenizer, "think_end", "") or ""
reasoning_content, content = split_reasoning(content, think_start, think_end)
tool_calls_proto: List[backend_pb2.ToolCallDelta] = []
tool_module = None
if getattr(self.tokenizer, "has_tool_calling", False):
tool_module = self._tool_module_from_tokenizer()
if tool_module is not None:
parsed_tools = None
if request.Tools:
try:
parsed_tools = json.loads(request.Tools)
except json.JSONDecodeError:
parsed_tools = None
calls, content = parse_tool_calls(content, tool_module, parsed_tools)
for c in calls:
tool_calls_proto.append(
backend_pb2.ToolCallDelta(
index=c["index"],
id=c["id"],
name=c["name"],
arguments=c["arguments"],
)
)
prompt_token_count = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0
completion_token_count = int(getattr(last_response, "generation_tokens", 0) or 0) if last_response else 0
logprobs_bytes = b""
# Logprobs extraction — only when the request asked for them.
if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0:
try:
lp = getattr(last_response, "logprobs", None)
if lp is not None:
# GenerationResponse.logprobs on the last chunk is the
# logprob distribution of the final token. Without a
# per-token history we at minimum surface the last token's
# top-1 logprob so clients get a non-empty field.
token_id = int(getattr(last_response, "token", 0) or 0)
token_text = self.tokenizer.decode([token_id]) if token_id else ""
top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0
logprobs_bytes = json.dumps(
{
"content": [
{"token": token_text, "logprob": top_logprob}
]
}
).encode("utf-8")
except Exception as e:
print(f"[mlx] Logprobs extraction failed: {e}", file=sys.stderr)
return content, reasoning_content, tool_calls_proto, prompt_token_count, completion_token_count, logprobs_bytes
def _truncate_at_stop(self, text, stop_words):
"""Truncate ``text`` at the first occurrence of any stop sequence."""
if not stop_words:
return text
earliest = len(text)
for stop in stop_words:
if not stop:
continue
idx = text.find(stop)
if 0 <= idx < earliest:
earliest = idx
return text[:earliest] if earliest < len(text) else text
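Since mlx-lm has no built-in stop-sequence handling, `_truncate_at_stop` does the cut post-decode. The same logic is self-contained enough to exercise without a server; a standalone sketch:

```python
def truncate_at_stop(text, stop_words):
    """Cut text at the earliest occurrence of any stop sequence."""
    if not stop_words:
        return text
    earliest = len(text)
    for stop in stop_words:
        if not stop:
            continue  # ignore empty stop strings
        idx = text.find(stop)
        if 0 <= idx < earliest:
            earliest = idx
    return text[:earliest]

out = truncate_at_stop("Hello<|eot|> trailing", ["<|eot|>", "\n\n"])  # -> "Hello"
```

Note the stop marker itself is dropped, matching what the backend sends to clients.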
async def serve(address):
# Start asyncio gRPC server

View File

@@ -1,11 +1,20 @@
import os
import sys
import unittest
import subprocess
import time
import types
import grpc
import backend_pb2
import backend_pb2_grpc
# Make the shared helpers importable so we can unit-test them without a
# running gRPC server.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
from python_utils import messages_to_dicts, parse_options
from mlx_utils import parse_tool_calls, split_reasoning
class TestBackendServicer(unittest.TestCase):
"""
TestBackendServicer is the class that tests the gRPC service.
@@ -231,4 +240,104 @@ class TestBackendServicer(unittest.TestCase):
self.tearDown()
def test_tokenize_string(self):
"""TokenizeString should return a non-empty token list for a known prompt."""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(
backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")
)
self.assertTrue(response.success)
resp = stub.TokenizeString(backend_pb2.PredictOptions(Prompt="Hello, world"))
self.assertGreater(resp.length, 0)
self.assertEqual(len(list(resp.tokens)), resp.length)
except Exception as err:
print(err)
self.fail("TokenizeString service failed")
finally:
self.tearDown()
def test_free(self):
"""Free should release the model and not crash on subsequent calls."""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(
backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")
)
self.assertTrue(response.success)
free_resp = stub.Free(backend_pb2.HealthMessage())
self.assertTrue(free_resp.success)
except Exception as err:
print(err)
self.fail("Free service failed")
finally:
self.tearDown()
class TestSharedHelpers(unittest.TestCase):
"""Server-less unit tests for the helpers the mlx backend depends on."""
def test_parse_options_typed(self):
opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true", "name:hello", "no_colon_skipped"])
self.assertEqual(opts["temperature"], 0.7)
self.assertEqual(opts["max_tokens"], 128)
self.assertIs(opts["trust"], True)
self.assertEqual(opts["name"], "hello")
self.assertNotIn("no_colon_skipped", opts)
def test_messages_to_dicts_roundtrip(self):
# Build proto Message objects (via backend_pb2 to match real gRPC)
msgs = [
backend_pb2.Message(role="user", content="hi"),
backend_pb2.Message(
role="assistant",
content="",
tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]',
),
backend_pb2.Message(
role="tool",
content="42",
tool_call_id="call_1",
name="f",
),
]
out = messages_to_dicts(msgs)
self.assertEqual(out[0], {"role": "user", "content": "hi"})
self.assertEqual(out[1]["role"], "assistant")
self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f")
self.assertEqual(out[2]["tool_call_id"], "call_1")
self.assertEqual(out[2]["name"], "f")
def test_split_reasoning(self):
r, c = split_reasoning("<think>step 1\nstep 2</think>The answer is 42.", "<think>", "</think>")
self.assertEqual(r, "step 1\nstep 2")
self.assertEqual(c, "The answer is 42.")
def test_split_reasoning_no_marker(self):
r, c = split_reasoning("just text", "<think>", "</think>")
self.assertEqual(r, "")
self.assertEqual(c, "just text")
def test_parse_tool_calls_with_shim(self):
tm = types.SimpleNamespace(
tool_call_start="<tool_call>",
tool_call_end="</tool_call>",
parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}},
)
calls, remaining = parse_tool_calls(
"Sure: <tool_call>Paris</tool_call>",
tm,
tools=None,
)
self.assertEqual(len(calls), 1)
self.assertEqual(calls[0]["name"], "get_weather")
self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}')
self.assertEqual(calls[0]["index"], 0)
self.assertNotIn("<tool_call>", remaining)
# Unit tests for ThreadSafeLRUPromptCache are in test_mlx_cache.py
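The `split_reasoning` tests above fully determine the helper's happy-path behavior. A minimal sketch consistent with them (the exact whitespace stripping is an assumption; the real implementation lives in `mlx_utils.py`):

```python
def split_reasoning(text, think_start, think_end):
    """Return (reasoning, content) for a single think block in text."""
    if not think_start or think_start not in text or think_end not in text:
        return "", text  # no complete block: everything is content
    pre, _, rest = text.partition(think_start)
    body, _, post = rest.partition(think_end)
    # reasoning is the block body; content is everything outside it
    return body.strip(), (pre + post).strip()

r, c = split_reasoning(
    "<think>step 1\nstep 2</think>The answer is 42.", "<think>", "</think>"
)
```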

View File

@@ -84,6 +84,20 @@ func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) {
// Apply per-model-family inference parameter defaults
config.ApplyInferenceDefaults(&modelConfig, details.URI)
// Auto-set tool_parser / reasoning_parser from parser_defaults.json so
// the generated YAML mirrors what the vllm importer produces. The mlx
// backends auto-detect parsers from the chat template at runtime and
// ignore these Options, but surfacing them in the config keeps the two
// paths consistent and gives users a single place to override.
if parsers := config.MatchParserDefaults(details.URI); parsers != nil {
if tp, ok := parsers["tool_parser"]; ok {
modelConfig.Options = append(modelConfig.Options, "tool_parser:"+tp)
}
if rp, ok := parsers["reasoning_parser"]; ok {
modelConfig.Options = append(modelConfig.Options, "reasoning_parser:"+rp)
}
}
data, err := yaml.Marshal(modelConfig)
if err != nil {
return gallery.ModelConfig{}, err