diff --git a/Makefile b/Makefile
index 7e2e35052..5ef606297 100644
--- a/Makefile
+++ b/Makefile
@@ -519,6 +519,22 @@ test-extra-backend-vllm: docker-build-vllm
BACKEND_TEST_OPTIONS=tool_parser:hermes \
$(MAKE) test-extra-backend
+## mlx is Apple-Silicon-first — the MLX backend auto-detects the right tool
+## parser from the chat template, so no tool_parser: option is needed (it
+## would be ignored at runtime). Run this on macOS / arm64 with Metal; the
+## Linux/CPU mlx variant is untested in CI.
+test-extra-backend-mlx: docker-build-mlx
+ BACKEND_IMAGE=local-ai-backend:mlx \
+ BACKEND_TEST_MODEL_NAME=mlx-community/Qwen2.5-0.5B-Instruct-4bit \
+ BACKEND_TEST_CAPS=health,load,predict,stream,tools \
+ $(MAKE) test-extra-backend
+
+test-extra-backend-mlx-vlm: docker-build-mlx-vlm
+ BACKEND_IMAGE=local-ai-backend:mlx-vlm \
+ BACKEND_TEST_MODEL_NAME=mlx-community/Qwen2.5-0.5B-Instruct-4bit \
+ BACKEND_TEST_CAPS=health,load,predict,stream,tools \
+ $(MAKE) test-extra-backend
+
DOCKER_IMAGE?=local-ai
IMAGE_TYPE?=core
BASE_IMAGE?=ubuntu:24.04
@@ -652,6 +668,8 @@ BACKEND_NEMO = nemo|python|.|false|true
BACKEND_VOXCPM = voxcpm|python|.|false|true
BACKEND_WHISPERX = whisperx|python|.|false|true
BACKEND_ACE_STEP = ace-step|python|.|false|true
+BACKEND_MLX = mlx|python|.|false|true
+BACKEND_MLX_VLM = mlx-vlm|python|.|false|true
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
BACKEND_TRL = trl|python|.|false|true
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
@@ -720,6 +738,8 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
$(eval $(call generate-docker-build-target,$(BACKEND_QWEN3_TTS_CPP)))
+$(eval $(call generate-docker-build-target,$(BACKEND_MLX)))
+$(eval $(call generate-docker-build-target,$(BACKEND_MLX_VLM)))
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
diff --git a/backend/python/common/libbackend.sh b/backend/python/common/libbackend.sh
index c923c12cf..982dafab3 100644
--- a/backend/python/common/libbackend.sh
+++ b/backend/python/common/libbackend.sh
@@ -344,7 +344,16 @@ function ensureVenv() {
if [ ! -d "${EDIR}/venv" ]; then
if [ "x${USE_PIP}" == "xtrue" ]; then
- "${interpreter}" -m venv --copies "${EDIR}/venv"
+ # --copies is only needed when we will later relocate the venv via
+ # _makeVenvPortable (PORTABLE_PYTHON=true). Some Python builds —
+ # notably macOS system Python — refuse to create a venv with
+ # --copies because the build doesn't support it. Fall back to
+ # symlinks in that case.
+ local venv_args=""
+ if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
+ venv_args="--copies"
+ fi
+ "${interpreter}" -m venv ${venv_args} "${EDIR}/venv"
source "${EDIR}/venv/bin/activate"
"${interpreter}" -m pip install --upgrade pip
else
diff --git a/backend/python/common/mlx_utils.py b/backend/python/common/mlx_utils.py
new file mode 100644
index 000000000..6b34eb962
--- /dev/null
+++ b/backend/python/common/mlx_utils.py
@@ -0,0 +1,100 @@
+"""Shared utilities for the mlx and mlx-vlm gRPC backends.
+
+These helpers wrap mlx-lm's and mlx-vlm's native tool-parser modules, which
+auto-detect the right parser from the model's chat template. Each tool
+module exposes ``tool_call_start``, ``tool_call_end`` and
+``parse_tool_call(text, tools) -> dict | list[dict]``.
+
+The split-reasoning helper is generic enough to work with any think-start /
+think-end delimiter pair.
+"""
+import json
+import re
+import sys
+import uuid
+
+
+def split_reasoning(text, think_start, think_end):
+    """Split ``<think>...</think>``-style blocks out of ``text``.
+
+ Returns ``(reasoning_content, remaining_text)``. When ``think_start`` is
+ empty or not found, returns ``("", text)`` unchanged.
+ """
+ if not think_start or not text or think_start not in text:
+ return "", text
+ pattern = re.compile(
+ re.escape(think_start) + r"(.*?)" + re.escape(think_end or ""),
+ re.DOTALL,
+ )
+ reasoning_parts = pattern.findall(text)
+ if not reasoning_parts:
+ return "", text
+ remaining = pattern.sub("", text).strip()
+ return "\n".join(p.strip() for p in reasoning_parts), remaining
+
+
+def parse_tool_calls(text, tool_module, tools):
+ """Extract tool calls from ``text`` using a mlx-lm tool module.
+
+ Ports the ``process_tool_calls`` logic from
+ ``mlx_vlm/server.py`` (v0.10 onwards). ``tool_module`` must expose
+ ``tool_call_start``, ``tool_call_end`` and ``parse_tool_call``.
+
+ Returns ``(calls, remaining_text)`` where ``calls`` is a list of dicts:
+
+ [{"index": int, "id": str, "name": str, "arguments": str (JSON)}]
+
+ and ``remaining_text`` is the free-form text with the tool call blocks
+ removed. ``(calls, text)`` is returned unchanged if ``tool_module`` is
+ ``None`` or the start delimiter isn't present.
+ """
+ if tool_module is None or not text:
+ return [], text
+ start = getattr(tool_module, "tool_call_start", None)
+ end = getattr(tool_module, "tool_call_end", None)
+ parse_fn = getattr(tool_module, "parse_tool_call", None)
+ if not start or parse_fn is None or start not in text:
+ return [], text
+
+ if end == "" or end is None:
+ pattern = re.compile(
+ re.escape(start) + r".*?(?:\n|$)",
+ re.DOTALL,
+ )
+ else:
+ pattern = re.compile(
+ re.escape(start) + r".*?" + re.escape(end),
+ re.DOTALL,
+ )
+
+ matches = pattern.findall(text)
+ if not matches:
+ return [], text
+
+ remaining = pattern.sub(" ", text).strip()
+ calls = []
+ for match in matches:
+ call_body = match.strip().removeprefix(start)
+ if end:
+ call_body = call_body.removesuffix(end)
+ call_body = call_body.strip()
+ try:
+ parsed = parse_fn(call_body, tools)
+ except Exception as e:
+ print(
+ f"[mlx_utils] Invalid tool call: {call_body!r} ({e})",
+ file=sys.stderr,
+ )
+ continue
+ if not isinstance(parsed, list):
+ parsed = [parsed]
+ for tc in parsed:
+ calls.append(
+ {
+ "index": len(calls),
+ "id": str(uuid.uuid4()),
+ "name": (tc.get("name") or "").strip(),
+ "arguments": json.dumps(tc.get("arguments", {}), ensure_ascii=False),
+ }
+ )
+ return calls, remaining
diff --git a/backend/python/common/python_utils.py b/backend/python/common/python_utils.py
new file mode 100644
index 000000000..aa61ab578
--- /dev/null
+++ b/backend/python/common/python_utils.py
@@ -0,0 +1,65 @@
+"""Generic utilities shared across Python gRPC backends.
+
+These helpers don't depend on any specific inference framework and can be
+imported by any backend that needs to parse LocalAI gRPC options or build a
+chat-template-compatible message list from proto Message objects.
+"""
+import json
+
+
+def parse_options(options_list):
+ """Parse Options[] list of ``key:value`` strings into a dict.
+
+ Supports type inference for common cases (bool, int, float). Unknown or
+ mixed-case values are returned as strings.
+
+ Used by LoadModel to extract backend-specific options passed via
+ ``ModelOptions.Options`` in ``backend.proto``.
+ """
+ opts = {}
+ for opt in options_list:
+ if ":" not in opt:
+ continue
+ key, value = opt.split(":", 1)
+ key = key.strip()
+ value = value.strip()
+ # Try type conversion
+ if value.lower() in ("true", "false"):
+ opts[key] = value.lower() == "true"
+ else:
+ try:
+ opts[key] = int(value)
+ except ValueError:
+ try:
+ opts[key] = float(value)
+ except ValueError:
+ opts[key] = value
+ return opts
+
+
+def messages_to_dicts(proto_messages):
+ """Convert proto ``Message`` objects to dicts suitable for ``apply_chat_template``.
+
+ Handles: ``role``, ``content``, ``name``, ``tool_call_id``,
+ ``reasoning_content``, ``tool_calls`` (JSON string → Python list).
+
+ HuggingFace chat templates (and their MLX/vLLM wrappers) expect a list of
+ plain dicts — proto Message objects don't work directly with Jinja, so
+ this conversion is needed before every ``apply_chat_template`` call.
+ """
+ result = []
+ for msg in proto_messages:
+ d = {"role": msg.role, "content": msg.content or ""}
+ if msg.name:
+ d["name"] = msg.name
+ if msg.tool_call_id:
+ d["tool_call_id"] = msg.tool_call_id
+ if msg.reasoning_content:
+ d["reasoning_content"] = msg.reasoning_content
+ if msg.tool_calls:
+ try:
+ d["tool_calls"] = json.loads(msg.tool_calls)
+ except json.JSONDecodeError:
+ pass
+ result.append(d)
+ return result
diff --git a/backend/python/common/vllm_utils.py b/backend/python/common/vllm_utils.py
index bc0518663..9124645ac 100644
--- a/backend/python/common/vllm_utils.py
+++ b/backend/python/common/vllm_utils.py
@@ -1,63 +1,22 @@
-"""Shared utilities for vLLM-based backends."""
-import json
+"""vLLM-specific helpers for the vllm and vllm-omni gRPC backends.
+
+Generic helpers (``parse_options``, ``messages_to_dicts``) live in
+``python_utils`` and are re-exported here for backwards compatibility with
+existing imports in both backends.
+"""
import sys
+from python_utils import messages_to_dicts, parse_options
-def parse_options(options_list):
- """Parse Options[] list of 'key:value' strings into a dict.
-
- Supports type inference for common cases (bool, int, float).
- Used by LoadModel to extract backend-specific options.
- """
- opts = {}
- for opt in options_list:
- if ":" not in opt:
- continue
- key, value = opt.split(":", 1)
- key = key.strip()
- value = value.strip()
- # Try type conversion
- if value.lower() in ("true", "false"):
- opts[key] = value.lower() == "true"
- else:
- try:
- opts[key] = int(value)
- except ValueError:
- try:
- opts[key] = float(value)
- except ValueError:
- opts[key] = value
- return opts
-
-
-def messages_to_dicts(proto_messages):
- """Convert proto Message objects to list of dicts for apply_chat_template().
-
- Handles: role, content, name, tool_call_id, reasoning_content, tool_calls (JSON string -> list).
- """
- result = []
- for msg in proto_messages:
- d = {"role": msg.role, "content": msg.content or ""}
- if msg.name:
- d["name"] = msg.name
- if msg.tool_call_id:
- d["tool_call_id"] = msg.tool_call_id
- if msg.reasoning_content:
- d["reasoning_content"] = msg.reasoning_content
- if msg.tool_calls:
- try:
- d["tool_calls"] = json.loads(msg.tool_calls)
- except json.JSONDecodeError:
- pass
- result.append(d)
- return result
+__all__ = ["parse_options", "messages_to_dicts", "setup_parsers"]
def setup_parsers(opts):
- """Return (tool_parser_cls, reasoning_parser_cls) tuple from opts dict.
+ """Return ``(tool_parser_cls, reasoning_parser_cls)`` from an opts dict.
- Uses vLLM's native ToolParserManager and ReasoningParserManager.
- Returns (None, None) if vLLM is not installed or parsers not available.
+ Uses vLLM's native ``ToolParserManager`` / ``ReasoningParserManager``.
+ Returns ``(None, None)`` if vLLM isn't installed or the requested
+ parser name can't be resolved.
"""
tool_parser_cls = None
reasoning_parser_cls = None
diff --git a/backend/python/mlx-distributed/backend.py b/backend/python/mlx-distributed/backend.py
index 90d74eba8..b03775f58 100644
--- a/backend/python/mlx-distributed/backend.py
+++ b/backend/python/mlx-distributed/backend.py
@@ -15,17 +15,21 @@ Two startup modes:
import asyncio
from concurrent import futures
import argparse
+import gc
import json
import os
import signal
import sys
import tempfile
+import types
from typing import List
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
+from python_utils import messages_to_dicts, parse_options as _shared_parse_options
+from mlx_utils import parse_tool_calls, split_reasoning
import backend_pb2
@@ -62,37 +66,10 @@ def mlx_distributed_init(rank, hostfile, backend="ring", coordinator=None):
raise ValueError(f"Unknown backend: {backend}")
-def is_float(s):
- try:
- float(s)
- return True
- except ValueError:
- return False
-
-
-def is_int(s):
- try:
- int(s)
- return True
- except ValueError:
- return False
-
-
-def parse_options(options):
- """Parse key:value option strings into a dict."""
- result = {}
- for opt in options:
- if ":" not in opt:
- continue
- key, value = opt.split(":", 1)
- if is_float(value):
- value = float(value)
- elif is_int(value):
- value = int(value)
- elif value.lower() in ["true", "false"]:
- value = value.lower() == "true"
- result[key] = value
- return result
+# Re-export the shared helper under the local name for back-compat with
+# any callers (and the existing distributed worker tests) that imported
+# parse_options directly from this module.
+parse_options = _shared_parse_options
class BackendServicer(backend_pb2_grpc.BackendServicer):
@@ -188,6 +165,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
)
print("[Rank 0] Model loaded (single-node with prompt cache)", file=sys.stderr)
+ # Log auto-detected TokenizerWrapper capabilities. Same shape
+ # as the mlx backend: has_tool_calling / has_thinking from
+ # mlx_lm.tokenizer_utils + the start/end markers it sniffed
+ # from the chat template / vocab.
+ has_tools = bool(getattr(self.tokenizer, "has_tool_calling", False))
+ has_thinking = bool(getattr(self.tokenizer, "has_thinking", False))
+ tcs = getattr(self.tokenizer, "tool_call_start", None)
+ tce = getattr(self.tokenizer, "tool_call_end", None)
+ print(
+ f"[Rank 0] Tokenizer capabilities: has_tool_calling={has_tools} "
+ f"has_thinking={has_thinking} tool_call_start={tcs!r} tool_call_end={tce!r}",
+ file=sys.stderr,
+ )
+
except Exception as err:
print(f"[Rank 0] Error loading model: {err}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Error loading model: {err}")
@@ -201,7 +192,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
try:
import mlx.core as mx
from mlx_lm import stream_generate
- from mlx_lm.sample_utils import make_sampler
+ from mlx_lm.sample_utils import make_logits_processors, make_sampler
prompt_text = self._prepare_prompt(request)
tokens = self._get_tokens_from_prompt(prompt_text)
@@ -211,7 +202,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
self.coordinator.broadcast_tokens(tokens)
- max_tokens, sampler_params = self._build_generation_params(request)
+ max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request)
if self.coordinator:
gen_params = self.coordinator.broadcast_generation_params(
@@ -222,6 +213,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
max_tokens = gen_params["max_tokens"]
sampler = make_sampler(**sampler_params)
+ logits_processors = make_logits_processors(**logits_params) if logits_params else None
# Use prompt cache in single-node mode
gen_kwargs = {}
@@ -238,22 +230,44 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
tokens = remaining_tokens if remaining_tokens else cache_key
generated = []
+ last_response = None
for response in stream_generate(
self.model,
self.tokenizer,
prompt=tokens,
max_tokens=max_tokens,
sampler=sampler,
+ logits_processors=logits_processors,
**gen_kwargs,
):
generated.append(response.text)
+ last_response = response
if cache_key is not None:
cache_key.append(response.token)
+ if stop_words and any(s in "".join(generated) for s in stop_words):
+ break
if self.lru_cache is not None and cache_key is not None:
self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
- return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8'))
+ full_text = self._truncate_at_stop("".join(generated), stop_words)
+ content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
+ self._finalize_output(request, full_text, last_response)
+ )
+
+ return backend_pb2.Reply(
+ message=bytes(content, encoding='utf-8'),
+ prompt_tokens=prompt_tokens,
+ tokens=completion_tokens,
+ logprobs=logprobs_bytes,
+ chat_deltas=[
+ backend_pb2.ChatDelta(
+ content=content,
+ reasoning_content=reasoning_content,
+ tool_calls=tool_calls_proto,
+ )
+ ],
+ )
except Exception as e:
print(f"[Rank 0] Error in Predict: {e}", file=sys.stderr)
@@ -268,7 +282,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
try:
import mlx.core as mx
from mlx_lm import stream_generate
- from mlx_lm.sample_utils import make_sampler
+ from mlx_lm.sample_utils import make_logits_processors, make_sampler
prompt_text = self._prepare_prompt(request)
tokens = self._get_tokens_from_prompt(prompt_text)
@@ -278,7 +292,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
self.coordinator.broadcast_tokens(tokens)
- max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)
+ max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(
+ request, default_max_tokens=512
+ )
if self.coordinator:
gen_params = self.coordinator.broadcast_generation_params(
@@ -289,6 +305,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
max_tokens = gen_params["max_tokens"]
sampler = make_sampler(**sampler_params)
+ logits_processors = make_logits_processors(**logits_params) if logits_params else None
# Use prompt cache in single-node mode
gen_kwargs = {}
@@ -304,17 +321,45 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
gen_kwargs['prompt_cache'] = prompt_cache
tokens = remaining_tokens if remaining_tokens else cache_key
+ accumulated = []
+ last_response = None
for response in stream_generate(
self.model,
self.tokenizer,
prompt=tokens,
max_tokens=max_tokens,
sampler=sampler,
+ logits_processors=logits_processors,
**gen_kwargs,
):
if cache_key is not None:
cache_key.append(response.token)
- yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
+ accumulated.append(response.text)
+ last_response = response
+ yield backend_pb2.Reply(
+ message=bytes(response.text, encoding='utf-8'),
+ chat_deltas=[backend_pb2.ChatDelta(content=response.text)],
+ )
+ if stop_words and any(s in "".join(accumulated) for s in stop_words):
+ break
+
+ full_text = self._truncate_at_stop("".join(accumulated), stop_words)
+ content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
+ self._finalize_output(request, full_text, last_response)
+ )
+ yield backend_pb2.Reply(
+ message=b"",
+ prompt_tokens=prompt_tokens,
+ tokens=completion_tokens,
+ logprobs=logprobs_bytes,
+ chat_deltas=[
+ backend_pb2.ChatDelta(
+ content="",
+ reasoning_content=reasoning_content,
+ tool_calls=tool_calls_proto,
+ )
+ ],
+ )
except Exception as e:
print(f"[Rank 0] Error in PredictStream: {e}", file=sys.stderr)
@@ -335,12 +380,74 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
context.set_details("Embeddings are not supported in the MLX distributed backend.")
return backend_pb2.EmbeddingResult()
+ async def TokenizeString(self, request, context):
+ if not hasattr(self, "tokenizer") or self.tokenizer is None:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ context.set_details("tokenizer not loaded")
+ return backend_pb2.TokenizationResponse()
+ try:
+ tokens = self.tokenizer.encode(request.Prompt)
+ if hasattr(tokens, "tolist"):
+ tokens = tokens.tolist()
+ tokens = list(tokens)
+ return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
+ except Exception as e:
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(str(e))
+ return backend_pb2.TokenizationResponse()
+
+ async def Free(self, request, context):
+ try:
+ # If we're rank 0 of a distributed run, tell workers to shut
+ # down their per-request loops first so they release the model.
+ if self.coordinator is not None:
+ try:
+ from coordinator import CMD_SHUTDOWN
+ self.coordinator.broadcast_command(CMD_SHUTDOWN)
+ except Exception as e:
+ print(f"[Rank 0] failed to broadcast shutdown: {e}", file=sys.stderr)
+ if hasattr(self, "model"):
+ del self.model
+ if hasattr(self, "tokenizer"):
+ del self.tokenizer
+ if self.lru_cache is not None:
+ try:
+ self.lru_cache.clear()
+ except Exception:
+ pass
+ self.lru_cache = None
+ self.coordinator = None
+ self.group = None
+ gc.collect()
+ try:
+ import mlx.core as mx # type: ignore
+ if hasattr(mx, "clear_cache"):
+ mx.clear_cache()
+ elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
+ mx.metal.clear_cache()
+ except Exception:
+ pass
+ return backend_pb2.Result(success=True, message="MLX distributed model freed")
+ except Exception as e:
+ return backend_pb2.Result(success=False, message=str(e))
+
def _prepare_prompt(self, request):
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
- messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages]
- return self.tokenizer.apply_chat_template(
- messages, tokenize=False, add_generation_prompt=True
- )
+ messages = messages_to_dicts(request.Messages)
+ kwargs = {"tokenize": False, "add_generation_prompt": True}
+ if request.Tools:
+ try:
+ kwargs["tools"] = json.loads(request.Tools)
+ except json.JSONDecodeError:
+ pass
+ if request.Metadata.get("enable_thinking", "").lower() == "true":
+ kwargs["enable_thinking"] = True
+ try:
+ return self.tokenizer.apply_chat_template(messages, **kwargs)
+ except TypeError:
+ return self.tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
return request.Prompt
def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
@@ -349,6 +456,82 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return tokens.tolist()
return list(tokens)
+ def _tool_module_from_tokenizer(self):
+ """Same shim as the mlx backend: fall back to json.loads when the
+ installed mlx-lm doesn't expose a tool_parser callable on the
+ wrapper (true on 0.29.x — only HEAD ships parsers)."""
+ start = getattr(self.tokenizer, "tool_call_start", None)
+ end = getattr(self.tokenizer, "tool_call_end", None)
+ if not start:
+ return None
+ parse_fn = getattr(self.tokenizer, "tool_parser", None)
+ if parse_fn is None:
+ def parse_fn(body, tools): # noqa: E306
+ return json.loads(body.strip())
+ return types.SimpleNamespace(
+ tool_call_start=start,
+ tool_call_end=end or "",
+ parse_tool_call=parse_fn,
+ )
+
+ def _truncate_at_stop(self, text, stop_words):
+ if not stop_words:
+ return text
+ earliest = len(text)
+ for stop in stop_words:
+ if not stop:
+ continue
+ idx = text.find(stop)
+ if idx >= 0 and idx < earliest:
+ earliest = idx
+ return text[:earliest] if earliest < len(text) else text
+
+ def _finalize_output(self, request, generated_text, last_response):
+ content = generated_text
+ reasoning_content = ""
+ if getattr(self.tokenizer, "has_thinking", False):
+ think_start = getattr(self.tokenizer, "think_start", "") or ""
+ think_end = getattr(self.tokenizer, "think_end", "") or ""
+ reasoning_content, content = split_reasoning(content, think_start, think_end)
+
+ tool_calls_proto: List[backend_pb2.ToolCallDelta] = []
+ tool_module = None
+ if getattr(self.tokenizer, "has_tool_calling", False):
+ tool_module = self._tool_module_from_tokenizer()
+ if tool_module is not None:
+ parsed_tools = None
+ if request.Tools:
+ try:
+ parsed_tools = json.loads(request.Tools)
+ except json.JSONDecodeError:
+ parsed_tools = None
+ calls, content = parse_tool_calls(content, tool_module, parsed_tools)
+ for c in calls:
+ tool_calls_proto.append(
+ backend_pb2.ToolCallDelta(
+ index=c["index"], id=c["id"], name=c["name"], arguments=c["arguments"],
+ )
+ )
+
+ prompt_token_count = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0
+ completion_token_count = int(getattr(last_response, "generation_tokens", 0) or 0) if last_response else 0
+
+ logprobs_bytes = b""
+ if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0:
+ try:
+ lp = getattr(last_response, "logprobs", None)
+ if lp is not None:
+ token_id = int(getattr(last_response, "token", 0) or 0)
+ token_text = self.tokenizer.decode([token_id]) if token_id else ""
+ top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0
+ logprobs_bytes = json.dumps(
+ {"content": [{"token": token_text, "logprob": top_logprob}]}
+ ).encode("utf-8")
+ except Exception as e:
+ print(f"[Rank 0] Logprobs extraction failed: {e}", file=sys.stderr)
+
+ return content, reasoning_content, tool_calls_proto, prompt_token_count, completion_token_count, logprobs_bytes
+
def _build_generation_params(self, request, default_max_tokens=200):
import mlx.core as mx
@@ -373,6 +556,22 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
'xtc_probability': 0.0,
}
+ # Logits processor parameters — pulled from the request and
+ # forwarded to make_logits_processors. Rank 0 is the only rank
+ # running the sampler so we don't need to broadcast these to
+ # workers (workers participate in the pipeline-parallel forward
+ # pass only).
+ logits_params = {}
+ repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0
+ if repetition_penalty and repetition_penalty != 1.0:
+ logits_params['repetition_penalty'] = repetition_penalty
+ presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0
+ if presence_penalty:
+ logits_params['presence_penalty'] = presence_penalty
+ frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0
+ if frequency_penalty:
+ logits_params['frequency_penalty'] = frequency_penalty
+
seed = getattr(request, 'Seed', 0)
if seed != 0:
mx.random.seed(seed)
@@ -392,9 +591,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
for opt_key, param_key in option_mapping.items():
if opt_key in self.options:
sampler_params[param_key] = self.options[opt_key]
+ for opt_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'):
+ if opt_key in self.options:
+ logits_params[opt_key] = self.options[opt_key]
if 'seed' in self.options:
mx.random.seed(self.options['seed'])
+ stop_words = list(getattr(request, 'StopPrompts', []) or [])
+ return max_tokens, sampler_params, logits_params, stop_words
+
# XTC special tokens
xtc_special_tokens = []
if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
diff --git a/backend/python/mlx-distributed/test.py b/backend/python/mlx-distributed/test.py
index 4cb1440ed..81bf6c67c 100644
--- a/backend/python/mlx-distributed/test.py
+++ b/backend/python/mlx-distributed/test.py
@@ -1,3 +1,6 @@
+import os
+import sys
+import types
import unittest
import subprocess
import time
@@ -6,6 +9,12 @@ import grpc
import backend_pb2
import backend_pb2_grpc
+# Make the shared helpers importable so we can unit-test them without a
+# running gRPC server.
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
+from python_utils import messages_to_dicts, parse_options
+from mlx_utils import parse_tool_calls, split_reasoning
+
class TestBackendServicer(unittest.TestCase):
def setUp(self):
@@ -85,3 +94,44 @@ class TestBackendServicer(unittest.TestCase):
self.fail("sampling params service failed")
finally:
self.tearDown()
+
+
+class TestSharedHelpers(unittest.TestCase):
+ """Server-less unit tests for the helpers the mlx-distributed backend depends on."""
+
+ def test_parse_options_typed(self):
+ opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true"])
+ self.assertEqual(opts["temperature"], 0.7)
+ self.assertEqual(opts["max_tokens"], 128)
+ self.assertIs(opts["trust"], True)
+
+ def test_messages_to_dicts_roundtrip(self):
+ msgs = [
+ backend_pb2.Message(role="user", content="hi"),
+ backend_pb2.Message(
+ role="assistant",
+ content="",
+ tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]',
+ ),
+ backend_pb2.Message(role="tool", content="42", tool_call_id="call_1", name="f"),
+ ]
+ out = messages_to_dicts(msgs)
+ self.assertEqual(out[0], {"role": "user", "content": "hi"})
+ self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f")
+ self.assertEqual(out[2]["tool_call_id"], "call_1")
+
+ def test_split_reasoning(self):
+        r, c = split_reasoning("<think>plan</think>final", "<think>", "</think>")
+ self.assertEqual(r, "plan")
+ self.assertEqual(c, "final")
+
+ def test_parse_tool_calls_with_shim(self):
+ tm = types.SimpleNamespace(
+            tool_call_start="<tool_call>",
+            tool_call_end="</tool_call>",
+ parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}},
+ )
+        calls, remaining = parse_tool_calls("<tool_call>Paris</tool_call>", tm, tools=None)
+ self.assertEqual(len(calls), 1)
+ self.assertEqual(calls[0]["name"], "get_weather")
+ self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}')
diff --git a/backend/python/mlx-vlm/backend.py b/backend/python/mlx-vlm/backend.py
index 578a5e563..074214ece 100644
--- a/backend/python/mlx-vlm/backend.py
+++ b/backend/python/mlx-vlm/backend.py
@@ -2,11 +2,14 @@
import asyncio
from concurrent import futures
import argparse
+import gc
+import json
import signal
import sys
import os
+import tempfile
+import types
from typing import List
-import time
import backend_pb2
import backend_pb2_grpc
@@ -15,30 +18,18 @@ import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
+from python_utils import messages_to_dicts, parse_options
+from mlx_utils import parse_tool_calls, split_reasoning
-from mlx_vlm import load, generate, stream_generate
+from mlx_vlm import load, stream_generate
from mlx_vlm.prompt_utils import apply_chat_template
-from mlx_vlm.utils import load_config, load_image
+from mlx_vlm.tool_parsers import _infer_tool_parser, load_tool_module
+from mlx_vlm.utils import load_config
+from mlx_lm.sample_utils import make_logits_processors, make_sampler
import mlx.core as mx
import base64
import io
from PIL import Image
-import tempfile
-
-def is_float(s):
- """Check if a string can be converted to float."""
- try:
- float(s)
- return True
- except ValueError:
- return False
-def is_int(s):
- """Check if a string can be converted to int."""
- try:
- int(s)
- return True
- except ValueError:
- return False
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -78,36 +69,52 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
try:
print(f"Loading MLX-VLM model: {request.Model}", file=sys.stderr)
print(f"Request: {request}", file=sys.stderr)
-
- # Parse options like in the diffusers backend
- options = request.Options
- self.options = {}
-
- # The options are a list of strings in this form optname:optvalue
- # We store all the options in a dict for later use
- for opt in options:
- if ":" not in opt:
- continue
- key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
-
- if is_float(value):
- value = float(value)
- elif is_int(value):
- value = int(value)
- elif value.lower() in ["true", "false"]:
- value = value.lower() == "true"
-
- self.options[key] = value
-
+
+ # Parse Options[] key:value strings into a typed dict
+ self.options = parse_options(request.Options)
print(f"Options: {self.options}", file=sys.stderr)
-
+
# Load model and processor using MLX-VLM
# mlx-vlm load function returns (model, processor) instead of (model, tokenizer)
self.model, self.processor = load(request.Model)
-
+
# Load model config for chat template support
self.config = load_config(request.Model)
-
+
+ # Auto-infer the tool parser from the chat template. mlx-vlm has
+ # its own _infer_tool_parser that falls back to mlx-lm parsers.
+ tokenizer = (
+ self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor
+ )
+ self.tool_module = None
+ if hasattr(tokenizer, "chat_template"):
+ try:
+ parser_type = _infer_tool_parser(tokenizer.chat_template)
+ if parser_type is not None:
+ self.tool_module = load_tool_module(parser_type)
+ print(
+ f"[mlx-vlm] auto-detected tool parser: {parser_type}",
+ file=sys.stderr,
+ )
+ else:
+ print(
+ "[mlx-vlm] no tool parser matched the chat template",
+ file=sys.stderr,
+ )
+ except Exception as e:
+ print(
+ f"[mlx-vlm] failed to load tool parser: {e}",
+ file=sys.stderr,
+ )
+
+ # Reasoning tokens — check if the tokenizer advertises thinking
+ # markers. Fall back to empty strings (split_reasoning no-ops).
+ self.think_start = getattr(tokenizer, "think_start", "") or ""
+ self.think_end = getattr(tokenizer, "think_end", "") or ""
+ self.has_thinking = bool(
+ getattr(tokenizer, "has_thinking", False) or self.think_start
+ )
+
except Exception as err:
print(f"Error loading MLX-VLM model {err=}, {type(err)=}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Error loading MLX-VLM model: {err}")
@@ -128,63 +135,72 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
temp_files = []
try:
- # Process images and audios from request
- image_paths = []
- audio_paths = []
-
- # Process images
- if request.Images:
- for img_data in request.Images:
- img_path = self.load_image_from_base64(img_data)
- if img_path:
- image_paths.append(img_path)
- temp_files.append(img_path)
-
- # Process audios
- if request.Audios:
- for audio_data in request.Audios:
- audio_path = self.load_audio_from_base64(audio_data)
- if audio_path:
- audio_paths.append(audio_path)
- temp_files.append(audio_path)
-
- # Prepare the prompt with multimodal information
- prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths))
-
- # Build generation parameters using request attributes and options
- max_tokens, generation_params = self._build_generation_params(request)
-
- print(f"Generating text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
- print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)
-
- # Generate text using MLX-VLM with multimodal inputs
- response = generate(
+ image_paths, audio_paths = self._collect_media(request, temp_files)
+
+ prompt = self._prepare_prompt(
+ request,
+ num_images=len(image_paths),
+ num_audios=len(audio_paths),
+ )
+
+ max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request)
+ sampler = make_sampler(**sampler_params)
+ logits_processors = make_logits_processors(**logits_params) if logits_params else None
+
+ print(
+ f"Generating text with MLX-VLM - max_tokens: {max_tokens}, "
+ f"images: {len(image_paths)}, audios: {len(audio_paths)}",
+ file=sys.stderr,
+ )
+
+ accumulated = []
+ last_response = None
+ for response in stream_generate(
model=self.model,
processor=self.processor,
prompt=prompt,
image=image_paths if image_paths else None,
audio=audio_paths if audio_paths else None,
max_tokens=max_tokens,
- temperature=generation_params.get('temp', 0.6),
- top_p=generation_params.get('top_p', 1.0),
- verbose=False
+ sampler=sampler,
+ logits_processors=logits_processors,
+ ):
+ accumulated.append(response.text)
+ last_response = response
+ if stop_words and any(s in "".join(accumulated) for s in stop_words):
+ break
+
+ full_text = self._truncate_at_stop("".join(accumulated), stop_words)
+ content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
+ self._finalize_output(request, full_text, last_response)
)
-
- return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))
-
+
+ return backend_pb2.Reply(
+ message=bytes(content, encoding='utf-8'),
+ prompt_tokens=prompt_tokens,
+ tokens=completion_tokens,
+ logprobs=logprobs_bytes,
+ chat_deltas=[
+ backend_pb2.ChatDelta(
+ content=content,
+ reasoning_content=reasoning_content,
+ tool_calls=tool_calls_proto,
+ )
+ ],
+ )
+
except Exception as e:
print(f"Error in MLX-VLM Predict: {e}", file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"Generation failed: {str(e)}")
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
finally:
- # Clean up temporary files
self.cleanup_temp_files(temp_files)
def Embedding(self, request, context):
"""
A gRPC method that calculates embeddings for a given sentence.
-
+
Note: MLX-VLM doesn't support embeddings directly. This method returns an error.
Args:
@@ -199,6 +215,79 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
context.set_details("Embeddings are not supported in the MLX-VLM backend.")
return backend_pb2.EmbeddingResult()
+ def _collect_media(self, request, temp_files):
+ """Decode base64 Images and Audios into temp file paths.
+
+ Appends every temp file to ``temp_files`` so the finally block can
+ clean up even on mid-generation errors.
+ """
+ image_paths = []
+ audio_paths = []
+ if request.Images:
+ for img_data in request.Images:
+ img_path = self.load_image_from_base64(img_data)
+ if img_path:
+ image_paths.append(img_path)
+ temp_files.append(img_path)
+ if request.Audios:
+ for audio_data in request.Audios:
+ audio_path = self.load_audio_from_base64(audio_data)
+ if audio_path:
+ audio_paths.append(audio_path)
+ temp_files.append(audio_path)
+ return image_paths, audio_paths
+
+ async def TokenizeString(self, request, context):
+ """Tokenize ``request.Prompt`` via the processor's tokenizer."""
+ if not hasattr(self, "processor") or self.processor is None:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ context.set_details("processor not loaded")
+ return backend_pb2.TokenizationResponse()
+ try:
+ tokenizer = (
+ self.processor.tokenizer
+ if hasattr(self.processor, "tokenizer")
+ else self.processor
+ )
+ tokens = tokenizer.encode(request.Prompt)
+ if hasattr(tokens, "tolist"):
+ tokens = tokens.tolist()
+ tokens = list(tokens)
+ return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
+ except Exception as e:
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(str(e))
+ return backend_pb2.TokenizationResponse()
+
+ async def Free(self, request, context):
+ """Drop the loaded model, processor and tool module."""
+ try:
+ if hasattr(self, "model"):
+ del self.model
+ if hasattr(self, "processor"):
+ del self.processor
+ if hasattr(self, "config"):
+ del self.config
+ self.tool_module = None
+ gc.collect()
+ # mlx.clear_cache (mlx >= 0.30) supersedes mlx.metal.clear_cache.
+ try:
+ if hasattr(mx, "clear_cache"):
+ mx.clear_cache()
+ elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
+ mx.metal.clear_cache()
+ except Exception:
+ pass
+ try:
+ import torch # type: ignore
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ except Exception:
+ pass
+ return backend_pb2.Result(success=True, message="MLX-VLM model freed")
+ except Exception as e:
+ return backend_pb2.Result(success=False, message=str(e))
+
async def PredictStream(self, request, context):
"""
Generates text based on the given prompt and sampling parameters, and streams the results using MLX-VLM with multimodal support.
@@ -212,36 +301,28 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
temp_files = []
try:
- # Process images and audios from request
- image_paths = []
- audio_paths = []
-
- # Process images
- if request.Images:
- for img_data in request.Images:
- img_path = self.load_image_from_base64(img_data)
- if img_path:
- image_paths.append(img_path)
- temp_files.append(img_path)
-
- # Process audios
- if request.Audios:
- for audio_data in request.Audios:
- audio_path = self.load_audio_from_base64(audio_data)
- if audio_path:
- audio_paths.append(audio_path)
- temp_files.append(audio_path)
-
- # Prepare the prompt with multimodal information
- prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths))
-
- # Build generation parameters using request attributes and options
- max_tokens, generation_params = self._build_generation_params(request, default_max_tokens=512)
-
- print(f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
- print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)
-
- # Stream text generation using MLX-VLM with multimodal inputs
+ image_paths, audio_paths = self._collect_media(request, temp_files)
+
+ prompt = self._prepare_prompt(
+ request,
+ num_images=len(image_paths),
+ num_audios=len(audio_paths),
+ )
+
+ max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(
+ request, default_max_tokens=512
+ )
+ sampler = make_sampler(**sampler_params)
+ logits_processors = make_logits_processors(**logits_params) if logits_params else None
+
+ print(
+ f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, "
+ f"images: {len(image_paths)}, audios: {len(audio_paths)}",
+ file=sys.stderr,
+ )
+
+ accumulated = []
+ last_response = None
for response in stream_generate(
model=self.model,
processor=self.processor,
@@ -249,77 +330,91 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
image=image_paths if image_paths else None,
audio=audio_paths if audio_paths else None,
max_tokens=max_tokens,
- temperature=generation_params.get('temp', 0.6),
- top_p=generation_params.get('top_p', 1.0),
+ sampler=sampler,
+ logits_processors=logits_processors,
):
- yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
-
+ accumulated.append(response.text)
+ last_response = response
+ yield backend_pb2.Reply(
+ message=bytes(response.text, encoding='utf-8'),
+ chat_deltas=[backend_pb2.ChatDelta(content=response.text)],
+ )
+ if stop_words and any(s in "".join(accumulated) for s in stop_words):
+ break
+
+ full_text = self._truncate_at_stop("".join(accumulated), stop_words)
+ content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
+ self._finalize_output(request, full_text, last_response)
+ )
+ yield backend_pb2.Reply(
+ message=b"",
+ prompt_tokens=prompt_tokens,
+ tokens=completion_tokens,
+ logprobs=logprobs_bytes,
+ chat_deltas=[
+ backend_pb2.ChatDelta(
+ content="",
+ reasoning_content=reasoning_content,
+ tool_calls=tool_calls_proto,
+ )
+ ],
+ )
+
except Exception as e:
print(f"Error in MLX-VLM PredictStream: {e}", file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"Streaming generation failed: {str(e)}")
yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
finally:
- # Clean up temporary files
self.cleanup_temp_files(temp_files)
+ def _build_template_kwargs(self, request, num_images, num_audios):
+ """Collect kwargs for ``apply_chat_template`` that survive model variants."""
+ kwargs = {"num_images": num_images, "num_audios": num_audios}
+ if request.Tools:
+ try:
+ kwargs["tools"] = json.loads(request.Tools)
+ except json.JSONDecodeError:
+ pass
+ if request.Metadata.get("enable_thinking", "").lower() == "true":
+ kwargs["enable_thinking"] = True
+ return kwargs
+
+ def _apply_template(self, request, messages, num_images, num_audios):
+ kwargs = self._build_template_kwargs(request, num_images, num_audios)
+ try:
+ return apply_chat_template(self.processor, self.config, messages, **kwargs)
+ except TypeError:
+ # Fallback for older mlx-vlm versions that reject tools=/enable_thinking=
+ return apply_chat_template(
+ self.processor,
+ self.config,
+ messages,
+ num_images=num_images,
+ num_audios=num_audios,
+ )
+
def _prepare_prompt(self, request, num_images=0, num_audios=0):
"""
- Prepare the prompt for MLX-VLM generation, handling chat templates and multimodal inputs.
-
- Args:
- request: The gRPC request containing prompt and message information.
- num_images: Number of images in the request.
- num_audios: Number of audio files in the request.
-
- Returns:
- str: The prepared prompt.
+ Prepare the prompt for MLX-VLM generation, handling chat templates and
+ multimodal inputs. Forwards tool definitions and enable_thinking when
+ present on the request.
"""
- # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
- # Convert gRPC messages to the format expected by apply_chat_template
- messages = []
- for msg in request.Messages:
- messages.append({"role": msg.role, "content": msg.content})
-
- # Use mlx-vlm's apply_chat_template which handles multimodal inputs
- prompt = apply_chat_template(
- self.processor,
- self.config,
- messages,
- num_images=num_images,
- num_audios=num_audios
- )
- return prompt
- elif request.Prompt:
- # If we have a direct prompt but also have images/audio, we need to format it properly
+ messages = messages_to_dicts(request.Messages)
+ return self._apply_template(request, messages, num_images, num_audios)
+
+ if request.Prompt:
if num_images > 0 or num_audios > 0:
- # Create a simple message structure for multimodal prompt
messages = [{"role": "user", "content": request.Prompt}]
- prompt = apply_chat_template(
- self.processor,
- self.config,
- messages,
- num_images=num_images,
- num_audios=num_audios
- )
- return prompt
- else:
- return request.Prompt
- else:
- # Fallback to empty prompt with multimodal template if we have media
- if num_images > 0 or num_audios > 0:
- messages = [{"role": "user", "content": ""}]
- prompt = apply_chat_template(
- self.processor,
- self.config,
- messages,
- num_images=num_images,
- num_audios=num_audios
- )
- return prompt
- else:
- return ""
+ return self._apply_template(request, messages, num_images, num_audios)
+ return request.Prompt
+
+ # Fallback to empty prompt with multimodal template if we have media
+ if num_images > 0 or num_audios > 0:
+ messages = [{"role": "user", "content": ""}]
+ return self._apply_template(request, messages, num_images, num_audios)
+ return ""
@@ -327,62 +422,122 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def _build_generation_params(self, request, default_max_tokens=200):
"""
- Build generation parameters from request attributes and options for MLX-VLM.
-
- Args:
- request: The gRPC request.
- default_max_tokens: Default max_tokens if not specified.
+ Build generation parameters from request attributes and options.
Returns:
- tuple: (max_tokens, generation_params dict)
+ tuple: (max_tokens, sampler_params, logits_params, stop_words)
"""
- # Extract max_tokens
- max_tokens = getattr(request, 'Tokens', default_max_tokens)
- if max_tokens == 0:
- max_tokens = default_max_tokens
-
- # Extract generation parameters from request attributes
- temp = getattr(request, 'Temperature', 0.0)
- if temp == 0.0:
- temp = 0.6 # Default temperature
-
- top_p = getattr(request, 'TopP', 0.0)
- if top_p == 0.0:
- top_p = 1.0 # Default top_p
-
- # Initialize generation parameters for MLX-VLM
- generation_params = {
+ max_tokens = getattr(request, 'Tokens', default_max_tokens) or default_max_tokens
+
+ temp = getattr(request, 'Temperature', 0.0) or 0.6
+ top_p = getattr(request, 'TopP', 0.0) or 1.0
+ min_p = getattr(request, 'MinP', 0.0) or 0.0
+ top_k = getattr(request, 'TopK', 0) or 0
+
+ sampler_params = {
'temp': temp,
'top_p': top_p,
+ 'min_p': min_p,
+ 'top_k': top_k,
}
-
- # Add seed if specified
+
+ logits_params = {}
+ repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0
+ if repetition_penalty and repetition_penalty != 1.0:
+ logits_params['repetition_penalty'] = repetition_penalty
+ presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0
+ if presence_penalty:
+ logits_params['presence_penalty'] = presence_penalty
+ frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0
+ if frequency_penalty:
+ logits_params['frequency_penalty'] = frequency_penalty
+
seed = getattr(request, 'Seed', 0)
if seed != 0:
mx.random.seed(seed)
-
- # Override with options if available
+
if hasattr(self, 'options'):
- # Max tokens from options
if 'max_tokens' in self.options:
max_tokens = self.options['max_tokens']
-
- # Generation parameters from options
- param_option_mapping = {
- 'temp': 'temp',
- 'temperature': 'temp', # alias
- 'top_p': 'top_p',
+ option_mapping = {
+ 'temp': 'temp', 'temperature': 'temp',
+ 'top_p': 'top_p', 'min_p': 'min_p', 'top_k': 'top_k',
}
-
- for option_key, param_key in param_option_mapping.items():
+ for option_key, param_key in option_mapping.items():
if option_key in self.options:
- generation_params[param_key] = self.options[option_key]
-
- # Handle seed from options
+ sampler_params[param_key] = self.options[option_key]
+ for option_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'):
+ if option_key in self.options:
+ logits_params[option_key] = self.options[option_key]
if 'seed' in self.options:
mx.random.seed(self.options['seed'])
-
- return max_tokens, generation_params
+
+ stop_words = list(getattr(request, 'StopPrompts', []) or [])
+ return max_tokens, sampler_params, logits_params, stop_words
+
+ def _finalize_output(self, request, generated_text, last_response):
+ """Split reasoning + tool calls out of generated_text and return the
+ tuple consumed by Reply-builders."""
+ content = generated_text
+ reasoning_content = ""
+
+ if getattr(self, "has_thinking", False):
+ reasoning_content, content = split_reasoning(content, self.think_start, self.think_end)
+
+ tool_calls_proto: List[backend_pb2.ToolCallDelta] = []
+ if self.tool_module is not None:
+ parsed_tools = None
+ if request.Tools:
+ try:
+ parsed_tools = json.loads(request.Tools)
+ except json.JSONDecodeError:
+ parsed_tools = None
+ calls, content = parse_tool_calls(content, self.tool_module, parsed_tools)
+ for c in calls:
+ tool_calls_proto.append(
+ backend_pb2.ToolCallDelta(
+ index=c["index"],
+ id=c["id"],
+ name=c["name"],
+ arguments=c["arguments"],
+ )
+ )
+
+ prompt_tokens = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0
+ completion_tokens = int(getattr(last_response, "generation_tokens", 0) or 0) if last_response else 0
+
+ logprobs_bytes = b""
+ if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0:
+ try:
+ lp = getattr(last_response, "logprobs", None)
+ if lp is not None:
+ token_id = int(getattr(last_response, "token", 0) or 0)
+ tokenizer = (
+ self.processor.tokenizer
+ if hasattr(self.processor, "tokenizer")
+ else self.processor
+ )
+ token_text = tokenizer.decode([token_id]) if token_id else ""
+ top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0
+ logprobs_bytes = json.dumps(
+ {"content": [{"token": token_text, "logprob": top_logprob}]}
+ ).encode("utf-8")
+ except Exception as e:
+ print(f"[mlx-vlm] Logprobs extraction failed: {e}", file=sys.stderr)
+
+ return content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes
+
+ def _truncate_at_stop(self, text, stop_words):
+ if not stop_words:
+ return text
+ earliest = len(text)
+ for stop in stop_words:
+ if not stop:
+ continue
+ idx = text.find(stop)
+ if idx >= 0 and idx < earliest:
+ earliest = idx
+ return text[:earliest] if earliest < len(text) else text
def load_image_from_base64(self, image_data: str):
"""
diff --git a/backend/python/mlx-vlm/test.py b/backend/python/mlx-vlm/test.py
index 827aa71a3..96cef3fac 100644
--- a/backend/python/mlx-vlm/test.py
+++ b/backend/python/mlx-vlm/test.py
@@ -1,17 +1,19 @@
+import os
+import sys
+import types
import unittest
import subprocess
import time
+
+import grpc
import backend_pb2
import backend_pb2_grpc
-import grpc
-
-import unittest
-import subprocess
-import time
-import grpc
-import backend_pb2_grpc
-import backend_pb2
+# Make the shared helpers importable so we can unit-test them without a
+# running gRPC server.
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
+from python_utils import messages_to_dicts, parse_options
+from mlx_utils import parse_tool_calls, split_reasoning
class TestBackendServicer(unittest.TestCase):
"""
@@ -143,4 +145,55 @@ class TestBackendServicer(unittest.TestCase):
print(err)
self.fail("Embedding service failed")
finally:
- self.tearDown()
\ No newline at end of file
+ self.tearDown()
+
+
+class TestSharedHelpers(unittest.TestCase):
+ """Server-less unit tests for the helpers the mlx-vlm backend depends on."""
+
+ def test_parse_options_typed(self):
+ opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true", "name:hello"])
+ self.assertEqual(opts["temperature"], 0.7)
+ self.assertEqual(opts["max_tokens"], 128)
+ self.assertIs(opts["trust"], True)
+ self.assertEqual(opts["name"], "hello")
+
+ def test_messages_to_dicts_roundtrip(self):
+ msgs = [
+ backend_pb2.Message(role="user", content="hi"),
+ backend_pb2.Message(
+ role="assistant",
+ content="",
+ tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]',
+ ),
+ backend_pb2.Message(
+ role="tool",
+ content="42",
+ tool_call_id="call_1",
+ name="f",
+ ),
+ ]
+ out = messages_to_dicts(msgs)
+ self.assertEqual(out[0], {"role": "user", "content": "hi"})
+ self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f")
+ self.assertEqual(out[2]["tool_call_id"], "call_1")
+
+ def test_split_reasoning(self):
+ r, c = split_reasoning("<think>plan</think>final", "<think>", "</think>")
+ self.assertEqual(r, "plan")
+ self.assertEqual(c, "final")
+
+ def test_parse_tool_calls_with_shim(self):
+ tm = types.SimpleNamespace(
+ tool_call_start="<tool_call>",
+ tool_call_end="</tool_call>",
+ parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}},
+ )
+ calls, remaining = parse_tool_calls(
+ "<tool_call>Paris</tool_call>",
+ tm,
+ tools=None,
+ )
+ self.assertEqual(len(calls), 1)
+ self.assertEqual(calls[0]["name"], "get_weather")
+ self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}')
\ No newline at end of file
diff --git a/backend/python/mlx/backend.py b/backend/python/mlx/backend.py
index 1a41020f5..a71da522c 100644
--- a/backend/python/mlx/backend.py
+++ b/backend/python/mlx/backend.py
@@ -2,11 +2,13 @@
import asyncio
from concurrent import futures
import argparse
+import gc
+import json
import signal
import sys
import os
+import types
from typing import List
-import time
import backend_pb2
import backend_pb2_grpc
@@ -15,13 +17,13 @@ import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
+from python_utils import messages_to_dicts, parse_options
+from mlx_utils import parse_tool_calls, split_reasoning
-from mlx_lm import load, generate, stream_generate
-from mlx_lm.sample_utils import make_sampler
+from mlx_lm import load, stream_generate
+from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
import mlx.core as mx
-import base64
-import io
from mlx_cache import ThreadSafeLRUPromptCache
@@ -30,21 +32,6 @@ _ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
-def is_float(s):
- """Check if a string can be converted to float."""
- try:
- float(s)
- return True
- except ValueError:
- return False
-def is_int(s):
- """Check if a string can be converted to int."""
- try:
- int(s)
- return True
- except ValueError:
- return False
-
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
@@ -78,46 +65,27 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
try:
print(f"Loading MLX model: {request.Model}", file=sys.stderr)
print(f"Request: {request}", file=sys.stderr)
-
- # Parse options like in the diffusers backend
- options = request.Options
- self.options = {}
-
- # The options are a list of strings in this form optname:optvalue
- # We store all the options in a dict for later use
- for opt in options:
- if ":" not in opt:
- continue
- key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
-
- # Convert numeric values to appropriate types
- if is_float(value):
- value = float(value)
- elif is_int(value):
- value = int(value)
- elif value.lower() in ["true", "false"]:
- value = value.lower() == "true"
-
- self.options[key] = value
-
+
+ # Parse Options[] key:value strings into a typed dict (shared helper)
+ self.options = parse_options(request.Options)
print(f"Options: {self.options}", file=sys.stderr)
-
+
# Build tokenizer config for MLX using options
tokenizer_config = {}
-
+
# Handle trust_remote_code from request or options
if request.TrustRemoteCode or self.options.get("trust_remote_code", False):
tokenizer_config["trust_remote_code"] = True
-
+
# Handle EOS token from options
if "eos_token" in self.options:
tokenizer_config["eos_token"] = self.options["eos_token"]
-
+
# Handle other tokenizer config options
for key in ["pad_token", "bos_token", "unk_token", "sep_token", "cls_token", "mask_token"]:
if key in self.options:
tokenizer_config[key] = self.options[key]
-
+
# Load model and tokenizer using MLX
if tokenizer_config:
print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
@@ -125,6 +93,21 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
else:
self.model, self.tokenizer = load(request.Model)
+ # mlx_lm.load() returns a TokenizerWrapper that detects tool
+ # calling and thinking markers from the chat template / vocab.
+ # mlx-lm >= 0.30 also exposes a parser callable on the wrapper;
+ # earlier versions don't (we fall back to json.loads inside
+ # _tool_module_from_tokenizer below).
+ has_tools = bool(getattr(self.tokenizer, "has_tool_calling", False))
+ has_thinking = bool(getattr(self.tokenizer, "has_thinking", False))
+ tcs = getattr(self.tokenizer, "tool_call_start", None)
+ tce = getattr(self.tokenizer, "tool_call_end", None)
+ print(
+ f"MLX tokenizer capabilities: has_tool_calling={has_tools} "
+ f"has_thinking={has_thinking} tool_call_start={tcs!r} tool_call_end={tce!r}",
+ file=sys.stderr,
+ )
+
# Initialize thread-safe LRU prompt cache for efficient generation
max_cache_entries = self.options.get("max_cache_entries", 10)
self.max_kv_size = self.options.get("max_kv_size", None)
@@ -134,7 +117,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
can_trim_fn=can_trim_prompt_cache,
trim_fn=trim_prompt_cache,
)
-
+
except Exception as err:
print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Error loading MLX model: {err}")
@@ -172,30 +155,58 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
remaining_tokens = cache_key
# Build generation parameters using request attributes and options
- max_tokens, sampler_params = self._build_generation_params(request)
+ max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request)
- print(f"Generating text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)
+ print(
+ f"Generating text with MLX - max_tokens: {max_tokens}, "
+ f"cache_hit: {len(remaining_tokens) < len(cache_key)}",
+ file=sys.stderr,
+ )
- # Create sampler with parameters
+ # Create sampler and optional logits processors (penalties)
sampler = make_sampler(**sampler_params)
+ logits_processors = make_logits_processors(**logits_params) if logits_params else None
- # Use stream_generate to track generated tokens for cache key
+ # Use stream_generate to collect text + track tokens for cache key
generated_text = []
+ last_response = None
for response in stream_generate(
self.model,
self.tokenizer,
prompt=remaining_tokens if remaining_tokens else cache_key,
max_tokens=max_tokens,
sampler=sampler,
+ logits_processors=logits_processors,
prompt_cache=prompt_cache,
):
generated_text.append(response.text)
cache_key.append(response.token)
+ last_response = response
+ # Early stop on user-provided stop sequences
+ if stop_words and any(s in "".join(generated_text) for s in stop_words):
+ break
# Insert completed cache
self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
- return backend_pb2.Reply(message=bytes(''.join(generated_text), encoding='utf-8'))
+ full_text = self._truncate_at_stop("".join(generated_text), stop_words)
+ content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
+ self._finalize_output(request, full_text, last_response)
+ )
+
+ return backend_pb2.Reply(
+ message=bytes(content, encoding='utf-8'),
+ prompt_tokens=prompt_tokens,
+ tokens=completion_tokens,
+ logprobs=logprobs_bytes,
+ chat_deltas=[
+ backend_pb2.ChatDelta(
+ content=content,
+ reasoning_content=reasoning_content,
+ tool_calls=tool_calls_proto,
+ )
+ ],
+ )
except Exception as e:
print(f"Error in MLX Predict: {e}", file=sys.stderr)
@@ -206,7 +217,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def Embedding(self, request, context):
"""
A gRPC method that calculates embeddings for a given sentence.
-
+
Note: MLX-LM doesn't support embeddings directly. This method returns an error.
Args:
@@ -221,6 +232,62 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
context.set_details("Embeddings are not supported in the MLX backend.")
return backend_pb2.EmbeddingResult()
+ async def TokenizeString(self, request, context):
+ """Tokenize ``request.Prompt`` using the loaded model's tokenizer."""
+ if not hasattr(self, "tokenizer") or self.tokenizer is None:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ context.set_details("tokenizer not loaded")
+ return backend_pb2.TokenizationResponse()
+ try:
+ tokens = self.tokenizer.encode(request.Prompt)
+ if hasattr(tokens, "tolist"):
+ tokens = tokens.tolist()
+ tokens = list(tokens)
+ return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
+ except Exception as e:
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(str(e))
+ return backend_pb2.TokenizationResponse()
+
+ async def Free(self, request, context):
+ """Drop the loaded model, tokenizer and prompt cache.
+
+ Metal / CUDA memory is released via ``gc.collect()`` + the
+ platform-specific cache clear hooks when available.
+ """
+ try:
+ if hasattr(self, "model"):
+ del self.model
+ if hasattr(self, "tokenizer"):
+ del self.tokenizer
+ if hasattr(self, "lru_cache") and self.lru_cache is not None:
+ try:
+ self.lru_cache.clear()
+ except Exception:
+ pass
+ self.lru_cache = None
+ gc.collect()
+ # Metal: drop the cached allocator. mlx.clear_cache (mlx >= 0.30)
+ # supersedes the now-deprecated mlx.metal.clear_cache.
+ try:
+ if hasattr(mx, "clear_cache"):
+ mx.clear_cache()
+ elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
+ mx.metal.clear_cache()
+ except Exception:
+ pass
+ # CUDA: release the torch cache if a CUDA-backed mlx variant
+ # happens to be installed alongside torch (best-effort).
+ try:
+ import torch # type: ignore
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ except Exception:
+ pass
+ return backend_pb2.Result(success=True, message="MLX model freed")
+ except Exception as e:
+ return backend_pb2.Result(success=False, message=str(e))
+
async def PredictStream(self, request, context):
"""
Generates text based on the given prompt and sampling parameters, and streams the results using MLX.
@@ -251,24 +318,64 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
remaining_tokens = cache_key
# Build generation parameters using request attributes and options
- max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)
+ max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(
+ request, default_max_tokens=512
+ )
- print(f"Streaming text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)
+ print(
+ f"Streaming text with MLX - max_tokens: {max_tokens}, "
+ f"cache_hit: {len(remaining_tokens) < len(cache_key)}",
+ file=sys.stderr,
+ )
- # Create sampler with parameters
+ # Create sampler and optional logits processors (penalties)
sampler = make_sampler(**sampler_params)
+ logits_processors = make_logits_processors(**logits_params) if logits_params else None
- # Stream text generation using MLX with proper parameters
+ accumulated = []
+ last_response = None
for response in stream_generate(
self.model,
self.tokenizer,
prompt=remaining_tokens if remaining_tokens else cache_key,
max_tokens=max_tokens,
sampler=sampler,
+ logits_processors=logits_processors,
prompt_cache=prompt_cache,
):
cache_key.append(response.token)
- yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
+ accumulated.append(response.text)
+ last_response = response
+ # Emit a content delta. Structured reasoning / tool parsing
+ # happens on the final chunk so we don't fragment the state
+ # machine in v1.
+ yield backend_pb2.Reply(
+ message=bytes(response.text, encoding='utf-8'),
+ chat_deltas=[backend_pb2.ChatDelta(content=response.text)],
+ )
+ # Early stop on user-provided stop sequences
+ if stop_words and any(s in "".join(accumulated) for s in stop_words):
+ break
+
+ # Final chunk: run reasoning + tool parsing on accumulated text
+ # and emit the structured ChatDelta with token counts + logprobs.
+ full_text = self._truncate_at_stop("".join(accumulated), stop_words)
+ content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
+ self._finalize_output(request, full_text, last_response)
+ )
+ yield backend_pb2.Reply(
+ message=b"",
+ prompt_tokens=prompt_tokens,
+ tokens=completion_tokens,
+ logprobs=logprobs_bytes,
+ chat_deltas=[
+ backend_pb2.ChatDelta(
+ content="",
+ reasoning_content=reasoning_content,
+ tool_calls=tool_calls_proto,
+ )
+ ],
+ )
except Exception as e:
print(f"Error in MLX PredictStream: {e}", file=sys.stderr)
@@ -294,21 +401,33 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
Returns:
str: The prepared prompt.
"""
- # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
+ # If tokenizer template is enabled and messages are provided instead
+ # of prompt, apply the tokenizer template (forwards tool definitions
+ # and enable_thinking when the model supports them).
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
- # Convert gRPC messages to the format expected by apply_chat_template
- messages = []
- for msg in request.Messages:
- messages.append({"role": msg.role, "content": msg.content})
+ messages = messages_to_dicts(request.Messages)
- prompt = self.tokenizer.apply_chat_template(
- messages,
- tokenize=False,
- add_generation_prompt=True
- )
- return prompt
- else:
- return request.Prompt
+ kwargs = {"tokenize": False, "add_generation_prompt": True}
+ if request.Tools:
+ try:
+ kwargs["tools"] = json.loads(request.Tools)
+ except json.JSONDecodeError:
+ pass
+ enable_thinking = request.Metadata.get("enable_thinking", "").lower()
+ if enable_thinking == "true":
+ kwargs["enable_thinking"] = True
+
+ try:
+ return self.tokenizer.apply_chat_template(messages, **kwargs)
+ except TypeError:
+ # Fallback for tokenizers whose template doesn't accept
+ # tools= or enable_thinking=.
+ return self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ return request.Prompt
def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
"""
@@ -338,18 +457,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
default_max_tokens: Default max_tokens if not specified.
Returns:
- tuple: (max_tokens, sampler_params dict)
+ tuple: (max_tokens, sampler_params dict, logits_processor_params dict,
+ stop_words list)
"""
# Extract max_tokens
max_tokens = getattr(request, 'Tokens', default_max_tokens)
if max_tokens == 0:
max_tokens = default_max_tokens
-
+
# Extract sampler parameters from request attributes
temp = getattr(request, 'Temperature', 0.0)
if temp == 0.0:
temp = 0.6 # Default temperature
-
+
top_p = getattr(request, 'TopP', 0.0)
if top_p == 0.0:
top_p = 1.0 # Default top_p
@@ -369,18 +489,31 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
'xtc_threshold': 0.0,
'xtc_probability': 0.0,
}
-
+
+ # Logits processor parameters — only set fields the request actually
+ # provides so we can feed them unconditionally to make_logits_processors.
+ logits_params = {}
+ repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0
+ if repetition_penalty and repetition_penalty != 1.0:
+ logits_params['repetition_penalty'] = repetition_penalty
+ presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0
+ if presence_penalty:
+ logits_params['presence_penalty'] = presence_penalty
+ frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0
+ if frequency_penalty:
+ logits_params['frequency_penalty'] = frequency_penalty
+
# Add seed if specified
seed = getattr(request, 'Seed', 0)
if seed != 0:
mx.random.seed(seed)
-
+
# Override with options if available
if hasattr(self, 'options'):
# Max tokens from options
if 'max_tokens' in self.options:
max_tokens = self.options['max_tokens']
-
+
# Sampler parameters from options
sampler_option_mapping = {
'temp': 'temp',
@@ -391,32 +524,142 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
'xtc_threshold': 'xtc_threshold',
'xtc_probability': 'xtc_probability',
}
-
+
for option_key, param_key in sampler_option_mapping.items():
if option_key in self.options:
sampler_params[param_key] = self.options[option_key]
-
+
+ # Logits processor overrides
+ for option_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'):
+ if option_key in self.options:
+ logits_params[option_key] = self.options[option_key]
+
# Handle seed from options
if 'seed' in self.options:
mx.random.seed(self.options['seed'])
-
+
# Special tokens for XTC sampling (if tokenizer has eos_token_ids)
xtc_special_tokens = []
if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
xtc_special_tokens = list(self.tokenizer.eos_token_ids)
elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
xtc_special_tokens = [self.tokenizer.eos_token_id]
-
+
# Add newline token if available
try:
newline_tokens = self.tokenizer.encode("\n")
xtc_special_tokens.extend(newline_tokens)
- except:
+ except Exception:
pass # Skip if encoding fails
-
+
sampler_params['xtc_special_tokens'] = xtc_special_tokens
-
- return max_tokens, sampler_params
+
+ # Stop sequences are applied post-decode (mlx-lm doesn't have a
+ # built-in stop-sequence sampler param). Preserve the list here.
+ stop_words = list(getattr(request, 'StopPrompts', []) or [])
+
+ return max_tokens, sampler_params, logits_params, stop_words
+
+ def _tool_module_from_tokenizer(self):
+ """Build a duck-typed tool module from the TokenizerWrapper.
+
+ On mlx-lm >= 0.30 the wrapper exposes a ``tool_parser`` callable
+ that's been resolved from the model's chat template. On older
+ releases (e.g. 0.29.x) the wrapper only carries the start/end
+ markers — fall back to ``json.loads`` of the body, which matches
+ what ``mlx_lm.tool_parsers.json_tools.parse_tool_call`` does on
+    HEAD and covers the only format 0.29 detects (``<tool_call>`` … ``</tool_call>``).
+ """
+ start = getattr(self.tokenizer, "tool_call_start", None)
+ end = getattr(self.tokenizer, "tool_call_end", None)
+ if not start:
+ return None
+ parse_fn = getattr(self.tokenizer, "tool_parser", None)
+ if parse_fn is None:
+ def parse_fn(body, tools): # noqa: E306 — local fallback
+ return json.loads(body.strip())
+ return types.SimpleNamespace(
+ tool_call_start=start,
+ tool_call_end=end or "",
+ parse_tool_call=parse_fn,
+ )
+
+ def _finalize_output(self, request, generated_text, last_response):
+ """Build a ChatDelta + token counts + logprobs from accumulated output.
+
+ Returns ``(content, reasoning_content, tool_calls_proto,
+ prompt_token_count, completion_token_count, logprobs_bytes)``.
+ """
+ content = generated_text
+ reasoning_content = ""
+
+ if getattr(self.tokenizer, "has_thinking", False):
+ think_start = getattr(self.tokenizer, "think_start", "") or ""
+ think_end = getattr(self.tokenizer, "think_end", "") or ""
+ reasoning_content, content = split_reasoning(content, think_start, think_end)
+
+ tool_calls_proto: List[backend_pb2.ToolCallDelta] = []
+ tool_module = None
+ if getattr(self.tokenizer, "has_tool_calling", False):
+ tool_module = self._tool_module_from_tokenizer()
+ if tool_module is not None:
+ parsed_tools = None
+ if request.Tools:
+ try:
+ parsed_tools = json.loads(request.Tools)
+ except json.JSONDecodeError:
+ parsed_tools = None
+ calls, content = parse_tool_calls(content, tool_module, parsed_tools)
+ for c in calls:
+ tool_calls_proto.append(
+ backend_pb2.ToolCallDelta(
+ index=c["index"],
+ id=c["id"],
+ name=c["name"],
+ arguments=c["arguments"],
+ )
+ )
+
+ prompt_token_count = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0
+ completion_token_count = int(getattr(last_response, "generation_tokens", 0) or 0) if last_response else 0
+
+ logprobs_bytes = b""
+ # Logprobs extraction — only when the request asked for them.
+ if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0:
+ try:
+ lp = getattr(last_response, "logprobs", None)
+ if lp is not None:
+ # GenerationResponse.logprobs on the last chunk is the
+ # logprob distribution of the final token. Without a
+ # per-token history we at minimum surface the last token's
+ # top-1 logprob so clients get a non-empty field.
+ token_id = int(getattr(last_response, "token", 0) or 0)
+ token_text = self.tokenizer.decode([token_id]) if token_id else ""
+ top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0
+ logprobs_bytes = json.dumps(
+ {
+ "content": [
+ {"token": token_text, "logprob": top_logprob}
+ ]
+ }
+ ).encode("utf-8")
+ except Exception as e:
+ print(f"[mlx] Logprobs extraction failed: {e}", file=sys.stderr)
+
+ return content, reasoning_content, tool_calls_proto, prompt_token_count, completion_token_count, logprobs_bytes
+
+ def _truncate_at_stop(self, text, stop_words):
+ """Truncate ``text`` at the first occurrence of any stop sequence."""
+ if not stop_words:
+ return text
+ earliest = len(text)
+ for stop in stop_words:
+ if not stop:
+ continue
+ idx = text.find(stop)
+ if idx >= 0 and idx < earliest:
+ earliest = idx
+ return text[:earliest] if earliest < len(text) else text
async def serve(address):
# Start asyncio gRPC server
diff --git a/backend/python/mlx/test.py b/backend/python/mlx/test.py
index 53d7bc7ec..ac5ff4c06 100644
--- a/backend/python/mlx/test.py
+++ b/backend/python/mlx/test.py
@@ -1,11 +1,20 @@
+import os
+import sys
import unittest
import subprocess
import time
+import types
import grpc
import backend_pb2
import backend_pb2_grpc
+# Make the shared helpers importable so we can unit-test them without a
+# running gRPC server.
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
+from python_utils import messages_to_dicts, parse_options
+from mlx_utils import parse_tool_calls, split_reasoning
+
class TestBackendServicer(unittest.TestCase):
"""
TestBackendServicer is the class that tests the gRPC service.
@@ -231,4 +240,104 @@ class TestBackendServicer(unittest.TestCase):
self.tearDown()
+ def test_tokenize_string(self):
+ """TokenizeString should return a non-empty token list for a known prompt."""
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(
+ backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")
+ )
+ self.assertTrue(response.success)
+ resp = stub.TokenizeString(backend_pb2.PredictOptions(Prompt="Hello, world"))
+ self.assertGreater(resp.length, 0)
+ self.assertEqual(len(list(resp.tokens)), resp.length)
+ except Exception as err:
+ print(err)
+ self.fail("TokenizeString service failed")
+ finally:
+ self.tearDown()
+
+ def test_free(self):
+ """Free should release the model and not crash on subsequent calls."""
+ try:
+ self.setUp()
+ with grpc.insecure_channel("localhost:50051") as channel:
+ stub = backend_pb2_grpc.BackendStub(channel)
+ response = stub.LoadModel(
+ backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")
+ )
+ self.assertTrue(response.success)
+ free_resp = stub.Free(backend_pb2.HealthMessage())
+ self.assertTrue(free_resp.success)
+ except Exception as err:
+ print(err)
+ self.fail("Free service failed")
+ finally:
+ self.tearDown()
+
+
+class TestSharedHelpers(unittest.TestCase):
+ """Server-less unit tests for the helpers the mlx backend depends on."""
+
+ def test_parse_options_typed(self):
+ opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true", "name:hello", "no_colon_skipped"])
+ self.assertEqual(opts["temperature"], 0.7)
+ self.assertEqual(opts["max_tokens"], 128)
+ self.assertIs(opts["trust"], True)
+ self.assertEqual(opts["name"], "hello")
+ self.assertNotIn("no_colon_skipped", opts)
+
+ def test_messages_to_dicts_roundtrip(self):
+ # Build proto Message objects (via backend_pb2 to match real gRPC)
+ msgs = [
+ backend_pb2.Message(role="user", content="hi"),
+ backend_pb2.Message(
+ role="assistant",
+ content="",
+ tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]',
+ ),
+ backend_pb2.Message(
+ role="tool",
+ content="42",
+ tool_call_id="call_1",
+ name="f",
+ ),
+ ]
+ out = messages_to_dicts(msgs)
+ self.assertEqual(out[0], {"role": "user", "content": "hi"})
+ self.assertEqual(out[1]["role"], "assistant")
+ self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f")
+ self.assertEqual(out[2]["tool_call_id"], "call_1")
+ self.assertEqual(out[2]["name"], "f")
+
+ def test_split_reasoning(self):
+        r, c = split_reasoning("<think>step 1\nstep 2</think>The answer is 42.", "<think>", "</think>")
+ self.assertEqual(r, "step 1\nstep 2")
+ self.assertEqual(c, "The answer is 42.")
+
+ def test_split_reasoning_no_marker(self):
+        r, c = split_reasoning("just text", "<think>", "</think>")
+ self.assertEqual(r, "")
+ self.assertEqual(c, "just text")
+
+ def test_parse_tool_calls_with_shim(self):
+ tm = types.SimpleNamespace(
+            tool_call_start="<tool_call>",
+            tool_call_end="</tool_call>",
+ parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}},
+ )
+ calls, remaining = parse_tool_calls(
+            "Sure: <tool_call>Paris</tool_call>",
+ tm,
+ tools=None,
+ )
+ self.assertEqual(len(calls), 1)
+ self.assertEqual(calls[0]["name"], "get_weather")
+ self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}')
+ self.assertEqual(calls[0]["index"], 0)
+        self.assertNotIn("<tool_call>", remaining)
+
+
# Unit tests for ThreadSafeLRUPromptCache are in test_mlx_cache.py
\ No newline at end of file
diff --git a/core/gallery/importers/mlx.go b/core/gallery/importers/mlx.go
index 7ab513f6d..075c58cf7 100644
--- a/core/gallery/importers/mlx.go
+++ b/core/gallery/importers/mlx.go
@@ -84,6 +84,20 @@ func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) {
// Apply per-model-family inference parameter defaults
config.ApplyInferenceDefaults(&modelConfig, details.URI)
+ // Auto-set tool_parser / reasoning_parser from parser_defaults.json so
+ // the generated YAML mirrors what the vllm importer produces. The mlx
+ // backends auto-detect parsers from the chat template at runtime and
+ // ignore these Options, but surfacing them in the config keeps the two
+ // paths consistent and gives users a single place to override.
+ if parsers := config.MatchParserDefaults(details.URI); parsers != nil {
+ if tp, ok := parsers["tool_parser"]; ok {
+ modelConfig.Options = append(modelConfig.Options, "tool_parser:"+tp)
+ }
+ if rp, ok := parsers["reasoning_parser"]; ok {
+ modelConfig.Options = append(modelConfig.Options, "reasoning_parser:"+rp)
+ }
+ }
+
data, err := yaml.Marshal(modelConfig)
if err != nil {
return gallery.ModelConfig{}, err