diff --git a/Makefile b/Makefile index 7e2e35052..5ef606297 100644 --- a/Makefile +++ b/Makefile @@ -519,6 +519,22 @@ test-extra-backend-vllm: docker-build-vllm BACKEND_TEST_OPTIONS=tool_parser:hermes \ $(MAKE) test-extra-backend +## mlx is Apple-Silicon-first — the MLX backend auto-detects the right tool +## parser from the chat template, so no tool_parser: option is needed (it +## would be ignored at runtime). Run this on macOS / arm64 with Metal; the +## Linux/CPU mlx variant is untested in CI. +test-extra-backend-mlx: docker-build-mlx + BACKEND_IMAGE=local-ai-backend:mlx \ + BACKEND_TEST_MODEL_NAME=mlx-community/Qwen2.5-0.5B-Instruct-4bit \ + BACKEND_TEST_CAPS=health,load,predict,stream,tools \ + $(MAKE) test-extra-backend + +test-extra-backend-mlx-vlm: docker-build-mlx-vlm + BACKEND_IMAGE=local-ai-backend:mlx-vlm \ + BACKEND_TEST_MODEL_NAME=mlx-community/Qwen2.5-0.5B-Instruct-4bit \ + BACKEND_TEST_CAPS=health,load,predict,stream,tools \ + $(MAKE) test-extra-backend + DOCKER_IMAGE?=local-ai IMAGE_TYPE?=core BASE_IMAGE?=ubuntu:24.04 @@ -652,6 +668,8 @@ BACKEND_NEMO = nemo|python|.|false|true BACKEND_VOXCPM = voxcpm|python|.|false|true BACKEND_WHISPERX = whisperx|python|.|false|true BACKEND_ACE_STEP = ace-step|python|.|false|true +BACKEND_MLX = mlx|python|.|false|true +BACKEND_MLX_VLM = mlx-vlm|python|.|false|true BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true BACKEND_TRL = trl|python|.|false|true BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true @@ -720,6 +738,8 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX))) $(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP))) $(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP))) $(eval $(call generate-docker-build-target,$(BACKEND_QWEN3_TTS_CPP))) +$(eval $(call generate-docker-build-target,$(BACKEND_MLX))) +$(eval $(call generate-docker-build-target,$(BACKEND_MLX_VLM))) $(eval $(call 
generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED))) $(eval $(call generate-docker-build-target,$(BACKEND_TRL))) $(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION))) diff --git a/backend/python/common/libbackend.sh b/backend/python/common/libbackend.sh index c923c12cf..982dafab3 100644 --- a/backend/python/common/libbackend.sh +++ b/backend/python/common/libbackend.sh @@ -344,7 +344,16 @@ function ensureVenv() { if [ ! -d "${EDIR}/venv" ]; then if [ "x${USE_PIP}" == "xtrue" ]; then - "${interpreter}" -m venv --copies "${EDIR}/venv" + # --copies is only needed when we will later relocate the venv via + # _makeVenvPortable (PORTABLE_PYTHON=true). Some Python builds — + # notably macOS system Python — refuse to create a venv with + # --copies because the build doesn't support it. Fall back to + # symlinks in that case. + local venv_args="" + if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then + venv_args="--copies" + fi + "${interpreter}" -m venv ${venv_args} "${EDIR}/venv" source "${EDIR}/venv/bin/activate" "${interpreter}" -m pip install --upgrade pip else diff --git a/backend/python/common/mlx_utils.py b/backend/python/common/mlx_utils.py new file mode 100644 index 000000000..6b34eb962 --- /dev/null +++ b/backend/python/common/mlx_utils.py @@ -0,0 +1,100 @@ +"""Shared utilities for the mlx and mlx-vlm gRPC backends. + +These helpers wrap mlx-lm's and mlx-vlm's native tool-parser modules, which +auto-detect the right parser from the model's chat template. Each tool +module exposes ``tool_call_start``, ``tool_call_end`` and +``parse_tool_call(text, tools) -> dict | list[dict]``. + +The split-reasoning helper is generic enough to work with any think-start / +think-end delimiter pair. +""" +import json +import re +import sys +import uuid + + +def split_reasoning(text, think_start, think_end): + """Split ``...`` blocks out of ``text``. + + Returns ``(reasoning_content, remaining_text)``. 
When ``think_start`` is + empty or not found, returns ``("", text)`` unchanged. + """ + if not think_start or not text or think_start not in text: + return "", text + pattern = re.compile( + re.escape(think_start) + r"(.*?)" + re.escape(think_end or ""), + re.DOTALL, + ) + reasoning_parts = pattern.findall(text) + if not reasoning_parts: + return "", text + remaining = pattern.sub("", text).strip() + return "\n".join(p.strip() for p in reasoning_parts), remaining + + +def parse_tool_calls(text, tool_module, tools): + """Extract tool calls from ``text`` using a mlx-lm tool module. + + Ports the ``process_tool_calls`` logic from + ``mlx_vlm/server.py`` (v0.10 onwards). ``tool_module`` must expose + ``tool_call_start``, ``tool_call_end`` and ``parse_tool_call``. + + Returns ``(calls, remaining_text)`` where ``calls`` is a list of dicts: + + [{"index": int, "id": str, "name": str, "arguments": str (JSON)}] + + and ``remaining_text`` is the free-form text with the tool call blocks + removed. ``(calls, text)`` is returned unchanged if ``tool_module`` is + ``None`` or the start delimiter isn't present. + """ + if tool_module is None or not text: + return [], text + start = getattr(tool_module, "tool_call_start", None) + end = getattr(tool_module, "tool_call_end", None) + parse_fn = getattr(tool_module, "parse_tool_call", None) + if not start or parse_fn is None or start not in text: + return [], text + + if end == "" or end is None: + pattern = re.compile( + re.escape(start) + r".*?(?:\n|$)", + re.DOTALL, + ) + else: + pattern = re.compile( + re.escape(start) + r".*?" 
+ re.escape(end), + re.DOTALL, + ) + + matches = pattern.findall(text) + if not matches: + return [], text + + remaining = pattern.sub(" ", text).strip() + calls = [] + for match in matches: + call_body = match.strip().removeprefix(start) + if end: + call_body = call_body.removesuffix(end) + call_body = call_body.strip() + try: + parsed = parse_fn(call_body, tools) + except Exception as e: + print( + f"[mlx_utils] Invalid tool call: {call_body!r} ({e})", + file=sys.stderr, + ) + continue + if not isinstance(parsed, list): + parsed = [parsed] + for tc in parsed: + calls.append( + { + "index": len(calls), + "id": str(uuid.uuid4()), + "name": (tc.get("name") or "").strip(), + "arguments": json.dumps(tc.get("arguments", {}), ensure_ascii=False), + } + ) + return calls, remaining diff --git a/backend/python/common/python_utils.py b/backend/python/common/python_utils.py new file mode 100644 index 000000000..aa61ab578 --- /dev/null +++ b/backend/python/common/python_utils.py @@ -0,0 +1,65 @@ +"""Generic utilities shared across Python gRPC backends. + +These helpers don't depend on any specific inference framework and can be +imported by any backend that needs to parse LocalAI gRPC options or build a +chat-template-compatible message list from proto Message objects. +""" +import json + + +def parse_options(options_list): + """Parse Options[] list of ``key:value`` strings into a dict. + + Supports type inference for common cases (bool, int, float). Unknown or + mixed-case values are returned as strings. + + Used by LoadModel to extract backend-specific options passed via + ``ModelOptions.Options`` in ``backend.proto``. 
+ """ + opts = {} + for opt in options_list: + if ":" not in opt: + continue + key, value = opt.split(":", 1) + key = key.strip() + value = value.strip() + # Try type conversion + if value.lower() in ("true", "false"): + opts[key] = value.lower() == "true" + else: + try: + opts[key] = int(value) + except ValueError: + try: + opts[key] = float(value) + except ValueError: + opts[key] = value + return opts + + +def messages_to_dicts(proto_messages): + """Convert proto ``Message`` objects to dicts suitable for ``apply_chat_template``. + + Handles: ``role``, ``content``, ``name``, ``tool_call_id``, + ``reasoning_content``, ``tool_calls`` (JSON string → Python list). + + HuggingFace chat templates (and their MLX/vLLM wrappers) expect a list of + plain dicts — proto Message objects don't work directly with Jinja, so + this conversion is needed before every ``apply_chat_template`` call. + """ + result = [] + for msg in proto_messages: + d = {"role": msg.role, "content": msg.content or ""} + if msg.name: + d["name"] = msg.name + if msg.tool_call_id: + d["tool_call_id"] = msg.tool_call_id + if msg.reasoning_content: + d["reasoning_content"] = msg.reasoning_content + if msg.tool_calls: + try: + d["tool_calls"] = json.loads(msg.tool_calls) + except json.JSONDecodeError: + pass + result.append(d) + return result diff --git a/backend/python/common/vllm_utils.py b/backend/python/common/vllm_utils.py index bc0518663..9124645ac 100644 --- a/backend/python/common/vllm_utils.py +++ b/backend/python/common/vllm_utils.py @@ -1,63 +1,22 @@ -"""Shared utilities for vLLM-based backends.""" -import json +"""vLLM-specific helpers for the vllm and vllm-omni gRPC backends. + +Generic helpers (``parse_options``, ``messages_to_dicts``) live in +``python_utils`` and are re-exported here for backwards compatibility with +existing imports in both backends. 
+""" import sys +from python_utils import messages_to_dicts, parse_options -def parse_options(options_list): - """Parse Options[] list of 'key:value' strings into a dict. - - Supports type inference for common cases (bool, int, float). - Used by LoadModel to extract backend-specific options. - """ - opts = {} - for opt in options_list: - if ":" not in opt: - continue - key, value = opt.split(":", 1) - key = key.strip() - value = value.strip() - # Try type conversion - if value.lower() in ("true", "false"): - opts[key] = value.lower() == "true" - else: - try: - opts[key] = int(value) - except ValueError: - try: - opts[key] = float(value) - except ValueError: - opts[key] = value - return opts - - -def messages_to_dicts(proto_messages): - """Convert proto Message objects to list of dicts for apply_chat_template(). - - Handles: role, content, name, tool_call_id, reasoning_content, tool_calls (JSON string -> list). - """ - result = [] - for msg in proto_messages: - d = {"role": msg.role, "content": msg.content or ""} - if msg.name: - d["name"] = msg.name - if msg.tool_call_id: - d["tool_call_id"] = msg.tool_call_id - if msg.reasoning_content: - d["reasoning_content"] = msg.reasoning_content - if msg.tool_calls: - try: - d["tool_calls"] = json.loads(msg.tool_calls) - except json.JSONDecodeError: - pass - result.append(d) - return result +__all__ = ["parse_options", "messages_to_dicts", "setup_parsers"] def setup_parsers(opts): - """Return (tool_parser_cls, reasoning_parser_cls) tuple from opts dict. + """Return ``(tool_parser_cls, reasoning_parser_cls)`` from an opts dict. - Uses vLLM's native ToolParserManager and ReasoningParserManager. - Returns (None, None) if vLLM is not installed or parsers not available. + Uses vLLM's native ``ToolParserManager`` / ``ReasoningParserManager``. + Returns ``(None, None)`` if vLLM isn't installed or the requested + parser name can't be resolved. 
""" tool_parser_cls = None reasoning_parser_cls = None diff --git a/backend/python/mlx-distributed/backend.py b/backend/python/mlx-distributed/backend.py index 90d74eba8..b03775f58 100644 --- a/backend/python/mlx-distributed/backend.py +++ b/backend/python/mlx-distributed/backend.py @@ -15,17 +15,21 @@ Two startup modes: import asyncio from concurrent import futures import argparse +import gc import json import os import signal import sys import tempfile +import types from typing import List import grpc sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) from grpc_auth import get_auth_interceptors +from python_utils import messages_to_dicts, parse_options as _shared_parse_options +from mlx_utils import parse_tool_calls, split_reasoning import backend_pb2 @@ -62,37 +66,10 @@ def mlx_distributed_init(rank, hostfile, backend="ring", coordinator=None): raise ValueError(f"Unknown backend: {backend}") -def is_float(s): - try: - float(s) - return True - except ValueError: - return False - - -def is_int(s): - try: - int(s) - return True - except ValueError: - return False - - -def parse_options(options): - """Parse key:value option strings into a dict.""" - result = {} - for opt in options: - if ":" not in opt: - continue - key, value = opt.split(":", 1) - if is_float(value): - value = float(value) - elif is_int(value): - value = int(value) - elif value.lower() in ["true", "false"]: - value = value.lower() == "true" - result[key] = value - return result +# Re-export the shared helper under the local name for back-compat with +# any callers (and the existing distributed worker tests) that imported +# parse_options directly from this module. 
+parse_options = _shared_parse_options class BackendServicer(backend_pb2_grpc.BackendServicer): @@ -188,6 +165,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): ) print("[Rank 0] Model loaded (single-node with prompt cache)", file=sys.stderr) + # Log auto-detected TokenizerWrapper capabilities. Same shape + # as the mlx backend: has_tool_calling / has_thinking from + # mlx_lm.tokenizer_utils + the start/end markers it sniffed + # from the chat template / vocab. + has_tools = bool(getattr(self.tokenizer, "has_tool_calling", False)) + has_thinking = bool(getattr(self.tokenizer, "has_thinking", False)) + tcs = getattr(self.tokenizer, "tool_call_start", None) + tce = getattr(self.tokenizer, "tool_call_end", None) + print( + f"[Rank 0] Tokenizer capabilities: has_tool_calling={has_tools} " + f"has_thinking={has_thinking} tool_call_start={tcs!r} tool_call_end={tce!r}", + file=sys.stderr, + ) + except Exception as err: print(f"[Rank 0] Error loading model: {err}", file=sys.stderr) return backend_pb2.Result(success=False, message=f"Error loading model: {err}") @@ -201,7 +192,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): try: import mlx.core as mx from mlx_lm import stream_generate - from mlx_lm.sample_utils import make_sampler + from mlx_lm.sample_utils import make_logits_processors, make_sampler prompt_text = self._prepare_prompt(request) tokens = self._get_tokens_from_prompt(prompt_text) @@ -211,7 +202,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) self.coordinator.broadcast_tokens(tokens) - max_tokens, sampler_params = self._build_generation_params(request) + max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request) if self.coordinator: gen_params = self.coordinator.broadcast_generation_params( @@ -222,6 +213,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): max_tokens = gen_params["max_tokens"] sampler = 
make_sampler(**sampler_params) + logits_processors = make_logits_processors(**logits_params) if logits_params else None # Use prompt cache in single-node mode gen_kwargs = {} @@ -238,22 +230,44 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): tokens = remaining_tokens if remaining_tokens else cache_key generated = [] + last_response = None for response in stream_generate( self.model, self.tokenizer, prompt=tokens, max_tokens=max_tokens, sampler=sampler, + logits_processors=logits_processors, **gen_kwargs, ): generated.append(response.text) + last_response = response if cache_key is not None: cache_key.append(response.token) + if stop_words and any(s in "".join(generated) for s in stop_words): + break if self.lru_cache is not None and cache_key is not None: self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache) - return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8')) + full_text = self._truncate_at_stop("".join(generated), stop_words) + content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = ( + self._finalize_output(request, full_text, last_response) + ) + + return backend_pb2.Reply( + message=bytes(content, encoding='utf-8'), + prompt_tokens=prompt_tokens, + tokens=completion_tokens, + logprobs=logprobs_bytes, + chat_deltas=[ + backend_pb2.ChatDelta( + content=content, + reasoning_content=reasoning_content, + tool_calls=tool_calls_proto, + ) + ], + ) except Exception as e: print(f"[Rank 0] Error in Predict: {e}", file=sys.stderr) @@ -268,7 +282,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): try: import mlx.core as mx from mlx_lm import stream_generate - from mlx_lm.sample_utils import make_sampler + from mlx_lm.sample_utils import make_logits_processors, make_sampler prompt_text = self._prepare_prompt(request) tokens = self._get_tokens_from_prompt(prompt_text) @@ -278,7 +292,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): 
self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) self.coordinator.broadcast_tokens(tokens) - max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512) + max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params( + request, default_max_tokens=512 + ) if self.coordinator: gen_params = self.coordinator.broadcast_generation_params( @@ -289,6 +305,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): max_tokens = gen_params["max_tokens"] sampler = make_sampler(**sampler_params) + logits_processors = make_logits_processors(**logits_params) if logits_params else None # Use prompt cache in single-node mode gen_kwargs = {} @@ -304,17 +321,45 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): gen_kwargs['prompt_cache'] = prompt_cache tokens = remaining_tokens if remaining_tokens else cache_key + accumulated = [] + last_response = None for response in stream_generate( self.model, self.tokenizer, prompt=tokens, max_tokens=max_tokens, sampler=sampler, + logits_processors=logits_processors, **gen_kwargs, ): if cache_key is not None: cache_key.append(response.token) - yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8')) + accumulated.append(response.text) + last_response = response + yield backend_pb2.Reply( + message=bytes(response.text, encoding='utf-8'), + chat_deltas=[backend_pb2.ChatDelta(content=response.text)], + ) + if stop_words and any(s in "".join(accumulated) for s in stop_words): + break + + full_text = self._truncate_at_stop("".join(accumulated), stop_words) + content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = ( + self._finalize_output(request, full_text, last_response) + ) + yield backend_pb2.Reply( + message=b"", + prompt_tokens=prompt_tokens, + tokens=completion_tokens, + logprobs=logprobs_bytes, + chat_deltas=[ + backend_pb2.ChatDelta( + content="", + reasoning_content=reasoning_content, + 
tool_calls=tool_calls_proto, + ) + ], + ) except Exception as e: print(f"[Rank 0] Error in PredictStream: {e}", file=sys.stderr) @@ -335,12 +380,74 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): context.set_details("Embeddings are not supported in the MLX distributed backend.") return backend_pb2.EmbeddingResult() + async def TokenizeString(self, request, context): + if not hasattr(self, "tokenizer") or self.tokenizer is None: + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("tokenizer not loaded") + return backend_pb2.TokenizationResponse() + try: + tokens = self.tokenizer.encode(request.Prompt) + if hasattr(tokens, "tolist"): + tokens = tokens.tolist() + tokens = list(tokens) + return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens) + except Exception as e: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + return backend_pb2.TokenizationResponse() + + async def Free(self, request, context): + try: + # If we're rank 0 of a distributed run, tell workers to shut + # down their per-request loops first so they release the model. 
+ if self.coordinator is not None: + try: + from coordinator import CMD_SHUTDOWN + self.coordinator.broadcast_command(CMD_SHUTDOWN) + except Exception as e: + print(f"[Rank 0] failed to broadcast shutdown: {e}", file=sys.stderr) + if hasattr(self, "model"): + del self.model + if hasattr(self, "tokenizer"): + del self.tokenizer + if self.lru_cache is not None: + try: + self.lru_cache.clear() + except Exception: + pass + self.lru_cache = None + self.coordinator = None + self.group = None + gc.collect() + try: + import mlx.core as mx # type: ignore + if hasattr(mx, "clear_cache"): + mx.clear_cache() + elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"): + mx.metal.clear_cache() + except Exception: + pass + return backend_pb2.Result(success=True, message="MLX distributed model freed") + except Exception as e: + return backend_pb2.Result(success=False, message=str(e)) + def _prepare_prompt(self, request): if not request.Prompt and request.UseTokenizerTemplate and request.Messages: - messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages] - return self.tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) + messages = messages_to_dicts(request.Messages) + kwargs = {"tokenize": False, "add_generation_prompt": True} + if request.Tools: + try: + kwargs["tools"] = json.loads(request.Tools) + except json.JSONDecodeError: + pass + if request.Metadata.get("enable_thinking", "").lower() == "true": + kwargs["enable_thinking"] = True + try: + return self.tokenizer.apply_chat_template(messages, **kwargs) + except TypeError: + return self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) return request.Prompt def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]: @@ -349,6 +456,82 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): return tokens.tolist() return list(tokens) + def _tool_module_from_tokenizer(self): + """Same shim as the mlx backend: 
fall back to json.loads when the + installed mlx-lm doesn't expose a tool_parser callable on the + wrapper (true on 0.29.x — only HEAD ships parsers).""" + start = getattr(self.tokenizer, "tool_call_start", None) + end = getattr(self.tokenizer, "tool_call_end", None) + if not start: + return None + parse_fn = getattr(self.tokenizer, "tool_parser", None) + if parse_fn is None: + def parse_fn(body, tools): # noqa: E306 + return json.loads(body.strip()) + return types.SimpleNamespace( + tool_call_start=start, + tool_call_end=end or "", + parse_tool_call=parse_fn, + ) + + def _truncate_at_stop(self, text, stop_words): + if not stop_words: + return text + earliest = len(text) + for stop in stop_words: + if not stop: + continue + idx = text.find(stop) + if idx >= 0 and idx < earliest: + earliest = idx + return text[:earliest] if earliest < len(text) else text + + def _finalize_output(self, request, generated_text, last_response): + content = generated_text + reasoning_content = "" + if getattr(self.tokenizer, "has_thinking", False): + think_start = getattr(self.tokenizer, "think_start", "") or "" + think_end = getattr(self.tokenizer, "think_end", "") or "" + reasoning_content, content = split_reasoning(content, think_start, think_end) + + tool_calls_proto: List[backend_pb2.ToolCallDelta] = [] + tool_module = None + if getattr(self.tokenizer, "has_tool_calling", False): + tool_module = self._tool_module_from_tokenizer() + if tool_module is not None: + parsed_tools = None + if request.Tools: + try: + parsed_tools = json.loads(request.Tools) + except json.JSONDecodeError: + parsed_tools = None + calls, content = parse_tool_calls(content, tool_module, parsed_tools) + for c in calls: + tool_calls_proto.append( + backend_pb2.ToolCallDelta( + index=c["index"], id=c["id"], name=c["name"], arguments=c["arguments"], + ) + ) + + prompt_token_count = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0 + completion_token_count = int(getattr(last_response, 
"generation_tokens", 0) or 0) if last_response else 0 + + logprobs_bytes = b"" + if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0: + try: + lp = getattr(last_response, "logprobs", None) + if lp is not None: + token_id = int(getattr(last_response, "token", 0) or 0) + token_text = self.tokenizer.decode([token_id]) if token_id else "" + top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0 + logprobs_bytes = json.dumps( + {"content": [{"token": token_text, "logprob": top_logprob}]} + ).encode("utf-8") + except Exception as e: + print(f"[Rank 0] Logprobs extraction failed: {e}", file=sys.stderr) + + return content, reasoning_content, tool_calls_proto, prompt_token_count, completion_token_count, logprobs_bytes + def _build_generation_params(self, request, default_max_tokens=200): import mlx.core as mx @@ -373,6 +556,22 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): 'xtc_probability': 0.0, } + # Logits processor parameters — pulled from the request and + # forwarded to make_logits_processors. Rank 0 is the only rank + # running the sampler so we don't need to broadcast these to + # workers (workers participate in the pipeline-parallel forward + # pass only). 
+ logits_params = {} + repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0 + if repetition_penalty and repetition_penalty != 1.0: + logits_params['repetition_penalty'] = repetition_penalty + presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0 + if presence_penalty: + logits_params['presence_penalty'] = presence_penalty + frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0 + if frequency_penalty: + logits_params['frequency_penalty'] = frequency_penalty + seed = getattr(request, 'Seed', 0) if seed != 0: mx.random.seed(seed) @@ -392,9 +591,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): for opt_key, param_key in option_mapping.items(): if opt_key in self.options: sampler_params[param_key] = self.options[opt_key] + for opt_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'): + if opt_key in self.options: + logits_params[opt_key] = self.options[opt_key] if 'seed' in self.options: mx.random.seed(self.options['seed']) + stop_words = list(getattr(request, 'StopPrompts', []) or []) + return max_tokens, sampler_params, logits_params, stop_words + # XTC special tokens xtc_special_tokens = [] if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids: diff --git a/backend/python/mlx-distributed/test.py b/backend/python/mlx-distributed/test.py index 4cb1440ed..81bf6c67c 100644 --- a/backend/python/mlx-distributed/test.py +++ b/backend/python/mlx-distributed/test.py @@ -1,3 +1,6 @@ +import os +import sys +import types import unittest import subprocess import time @@ -6,6 +9,12 @@ import grpc import backend_pb2 import backend_pb2_grpc +# Make the shared helpers importable so we can unit-test them without a +# running gRPC server. 
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +from python_utils import messages_to_dicts, parse_options +from mlx_utils import parse_tool_calls, split_reasoning + class TestBackendServicer(unittest.TestCase): def setUp(self): @@ -85,3 +94,44 @@ class TestBackendServicer(unittest.TestCase): self.fail("sampling params service failed") finally: self.tearDown() + + +class TestSharedHelpers(unittest.TestCase): + """Server-less unit tests for the helpers the mlx-distributed backend depends on.""" + + def test_parse_options_typed(self): + opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true"]) + self.assertEqual(opts["temperature"], 0.7) + self.assertEqual(opts["max_tokens"], 128) + self.assertIs(opts["trust"], True) + + def test_messages_to_dicts_roundtrip(self): + msgs = [ + backend_pb2.Message(role="user", content="hi"), + backend_pb2.Message( + role="assistant", + content="", + tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]', + ), + backend_pb2.Message(role="tool", content="42", tool_call_id="call_1", name="f"), + ] + out = messages_to_dicts(msgs) + self.assertEqual(out[0], {"role": "user", "content": "hi"}) + self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f") + self.assertEqual(out[2]["tool_call_id"], "call_1") + + def test_split_reasoning(self): + r, c = split_reasoning("<think>plan</think>final", "<think>", "</think>") + self.assertEqual(r, "plan") + self.assertEqual(c, "final") + + def test_parse_tool_calls_with_shim(self): + tm = types.SimpleNamespace( + tool_call_start="<tool_call>", + tool_call_end="</tool_call>", + parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}}, + ) + calls, remaining = parse_tool_calls("<tool_call>Paris</tool_call>", tm, tools=None) + self.assertEqual(len(calls), 1) + self.assertEqual(calls[0]["name"], "get_weather") + self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}') diff --git a/backend/python/mlx-vlm/backend.py 
b/backend/python/mlx-vlm/backend.py index 578a5e563..074214ece 100644 --- a/backend/python/mlx-vlm/backend.py +++ b/backend/python/mlx-vlm/backend.py @@ -2,11 +2,14 @@ import asyncio from concurrent import futures import argparse +import gc +import json import signal import sys import os +import tempfile +import types from typing import List -import time import backend_pb2 import backend_pb2_grpc @@ -15,30 +18,18 @@ import grpc sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) from grpc_auth import get_auth_interceptors +from python_utils import messages_to_dicts, parse_options +from mlx_utils import parse_tool_calls, split_reasoning -from mlx_vlm import load, generate, stream_generate +from mlx_vlm import load, stream_generate from mlx_vlm.prompt_utils import apply_chat_template -from mlx_vlm.utils import load_config, load_image +from mlx_vlm.tool_parsers import _infer_tool_parser, load_tool_module +from mlx_vlm.utils import load_config +from mlx_lm.sample_utils import make_logits_processors, make_sampler import mlx.core as mx import base64 import io from PIL import Image -import tempfile - -def is_float(s): - """Check if a string can be converted to float.""" - try: - float(s) - return True - except ValueError: - return False -def is_int(s): - """Check if a string can be converted to int.""" - try: - int(s) - return True - except ValueError: - return False _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -78,36 +69,52 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): try: print(f"Loading MLX-VLM model: {request.Model}", file=sys.stderr) print(f"Request: {request}", file=sys.stderr) - - # Parse options like in the diffusers backend - options = request.Options - self.options = {} - - # The options are a list of strings in this form optname:optvalue - # We store all the options in a dict for later use - for opt in options: - if ":" not in opt: - continue - key, value = 
opt.split(":", 1) # Split only on first colon to handle values with colons - - if is_float(value): - value = float(value) - elif is_int(value): - value = int(value) - elif value.lower() in ["true", "false"]: - value = value.lower() == "true" - - self.options[key] = value - + + # Parse Options[] key:value strings into a typed dict + self.options = parse_options(request.Options) print(f"Options: {self.options}", file=sys.stderr) - + # Load model and processor using MLX-VLM # mlx-vlm load function returns (model, processor) instead of (model, tokenizer) self.model, self.processor = load(request.Model) - + # Load model config for chat template support self.config = load_config(request.Model) - + + # Auto-infer the tool parser from the chat template. mlx-vlm has + # its own _infer_tool_parser that falls back to mlx-lm parsers. + tokenizer = ( + self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor + ) + self.tool_module = None + if hasattr(tokenizer, "chat_template"): + try: + parser_type = _infer_tool_parser(tokenizer.chat_template) + if parser_type is not None: + self.tool_module = load_tool_module(parser_type) + print( + f"[mlx-vlm] auto-detected tool parser: {parser_type}", + file=sys.stderr, + ) + else: + print( + "[mlx-vlm] no tool parser matched the chat template", + file=sys.stderr, + ) + except Exception as e: + print( + f"[mlx-vlm] failed to load tool parser: {e}", + file=sys.stderr, + ) + + # Reasoning tokens — check if the tokenizer advertises thinking + # markers. Fall back to empty strings (split_reasoning no-ops). 
+ self.think_start = getattr(tokenizer, "think_start", "") or "" + self.think_end = getattr(tokenizer, "think_end", "") or "" + self.has_thinking = bool( + getattr(tokenizer, "has_thinking", False) or self.think_start + ) + except Exception as err: print(f"Error loading MLX-VLM model {err=}, {type(err)=}", file=sys.stderr) return backend_pb2.Result(success=False, message=f"Error loading MLX-VLM model: {err}") @@ -128,63 +135,72 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): """ temp_files = [] try: - # Process images and audios from request - image_paths = [] - audio_paths = [] - - # Process images - if request.Images: - for img_data in request.Images: - img_path = self.load_image_from_base64(img_data) - if img_path: - image_paths.append(img_path) - temp_files.append(img_path) - - # Process audios - if request.Audios: - for audio_data in request.Audios: - audio_path = self.load_audio_from_base64(audio_data) - if audio_path: - audio_paths.append(audio_path) - temp_files.append(audio_path) - - # Prepare the prompt with multimodal information - prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths)) - - # Build generation parameters using request attributes and options - max_tokens, generation_params = self._build_generation_params(request) - - print(f"Generating text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr) - print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr) - - # Generate text using MLX-VLM with multimodal inputs - response = generate( + image_paths, audio_paths = self._collect_media(request, temp_files) + + prompt = self._prepare_prompt( + request, + num_images=len(image_paths), + num_audios=len(audio_paths), + ) + + max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request) + sampler = make_sampler(**sampler_params) + logits_processors = make_logits_processors(**logits_params) if logits_params else 
None + + print( + f"Generating text with MLX-VLM - max_tokens: {max_tokens}, " + f"images: {len(image_paths)}, audios: {len(audio_paths)}", + file=sys.stderr, + ) + + accumulated = [] + last_response = None + for response in stream_generate( model=self.model, processor=self.processor, prompt=prompt, image=image_paths if image_paths else None, audio=audio_paths if audio_paths else None, max_tokens=max_tokens, - temperature=generation_params.get('temp', 0.6), - top_p=generation_params.get('top_p', 1.0), - verbose=False + sampler=sampler, + logits_processors=logits_processors, + ): + accumulated.append(response.text) + last_response = response + if stop_words and any(s in "".join(accumulated) for s in stop_words): + break + + full_text = self._truncate_at_stop("".join(accumulated), stop_words) + content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = ( + self._finalize_output(request, full_text, last_response) ) - - return backend_pb2.Reply(message=bytes(response, encoding='utf-8')) - + + return backend_pb2.Reply( + message=bytes(content, encoding='utf-8'), + prompt_tokens=prompt_tokens, + tokens=completion_tokens, + logprobs=logprobs_bytes, + chat_deltas=[ + backend_pb2.ChatDelta( + content=content, + reasoning_content=reasoning_content, + tool_calls=tool_calls_proto, + ) + ], + ) + except Exception as e: print(f"Error in MLX-VLM Predict: {e}", file=sys.stderr) context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"Generation failed: {str(e)}") return backend_pb2.Reply(message=bytes("", encoding='utf-8')) finally: - # Clean up temporary files self.cleanup_temp_files(temp_files) def Embedding(self, request, context): """ A gRPC method that calculates embeddings for a given sentence. - + Note: MLX-VLM doesn't support embeddings directly. This method returns an error. 
Args: @@ -199,6 +215,79 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): context.set_details("Embeddings are not supported in the MLX-VLM backend.") return backend_pb2.EmbeddingResult() + def _collect_media(self, request, temp_files): + """Decode base64 Images and Audios into temp file paths. + + Appends every temp file to ``temp_files`` so the finally block can + clean up even on mid-generation errors. + """ + image_paths = [] + audio_paths = [] + if request.Images: + for img_data in request.Images: + img_path = self.load_image_from_base64(img_data) + if img_path: + image_paths.append(img_path) + temp_files.append(img_path) + if request.Audios: + for audio_data in request.Audios: + audio_path = self.load_audio_from_base64(audio_data) + if audio_path: + audio_paths.append(audio_path) + temp_files.append(audio_path) + return image_paths, audio_paths + + async def TokenizeString(self, request, context): + """Tokenize ``request.Prompt`` via the processor's tokenizer.""" + if not hasattr(self, "processor") or self.processor is None: + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("processor not loaded") + return backend_pb2.TokenizationResponse() + try: + tokenizer = ( + self.processor.tokenizer + if hasattr(self.processor, "tokenizer") + else self.processor + ) + tokens = tokenizer.encode(request.Prompt) + if hasattr(tokens, "tolist"): + tokens = tokens.tolist() + tokens = list(tokens) + return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens) + except Exception as e: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + return backend_pb2.TokenizationResponse() + + async def Free(self, request, context): + """Drop the loaded model, processor and tool module.""" + try: + if hasattr(self, "model"): + del self.model + if hasattr(self, "processor"): + del self.processor + if hasattr(self, "config"): + del self.config + self.tool_module = None + gc.collect() + # mlx.clear_cache (mlx >= 0.30) 
supersedes mlx.metal.clear_cache. + try: + if hasattr(mx, "clear_cache"): + mx.clear_cache() + elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"): + mx.metal.clear_cache() + except Exception: + pass + try: + import torch # type: ignore + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except Exception: + pass + return backend_pb2.Result(success=True, message="MLX-VLM model freed") + except Exception as e: + return backend_pb2.Result(success=False, message=str(e)) + async def PredictStream(self, request, context): """ Generates text based on the given prompt and sampling parameters, and streams the results using MLX-VLM with multimodal support. @@ -212,36 +301,28 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): """ temp_files = [] try: - # Process images and audios from request - image_paths = [] - audio_paths = [] - - # Process images - if request.Images: - for img_data in request.Images: - img_path = self.load_image_from_base64(img_data) - if img_path: - image_paths.append(img_path) - temp_files.append(img_path) - - # Process audios - if request.Audios: - for audio_data in request.Audios: - audio_path = self.load_audio_from_base64(audio_data) - if audio_path: - audio_paths.append(audio_path) - temp_files.append(audio_path) - - # Prepare the prompt with multimodal information - prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths)) - - # Build generation parameters using request attributes and options - max_tokens, generation_params = self._build_generation_params(request, default_max_tokens=512) - - print(f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr) - print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr) - - # Stream text generation using MLX-VLM with multimodal inputs + image_paths, audio_paths = self._collect_media(request, temp_files) + + prompt = self._prepare_prompt( + request, + 
num_images=len(image_paths), + num_audios=len(audio_paths), + ) + + max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params( + request, default_max_tokens=512 + ) + sampler = make_sampler(**sampler_params) + logits_processors = make_logits_processors(**logits_params) if logits_params else None + + print( + f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, " + f"images: {len(image_paths)}, audios: {len(audio_paths)}", + file=sys.stderr, + ) + + accumulated = [] + last_response = None for response in stream_generate( model=self.model, processor=self.processor, @@ -249,77 +330,91 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): image=image_paths if image_paths else None, audio=audio_paths if audio_paths else None, max_tokens=max_tokens, - temperature=generation_params.get('temp', 0.6), - top_p=generation_params.get('top_p', 1.0), + sampler=sampler, + logits_processors=logits_processors, ): - yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8')) - + accumulated.append(response.text) + last_response = response + yield backend_pb2.Reply( + message=bytes(response.text, encoding='utf-8'), + chat_deltas=[backend_pb2.ChatDelta(content=response.text)], + ) + if stop_words and any(s in "".join(accumulated) for s in stop_words): + break + + full_text = self._truncate_at_stop("".join(accumulated), stop_words) + content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = ( + self._finalize_output(request, full_text, last_response) + ) + yield backend_pb2.Reply( + message=b"", + prompt_tokens=prompt_tokens, + tokens=completion_tokens, + logprobs=logprobs_bytes, + chat_deltas=[ + backend_pb2.ChatDelta( + content="", + reasoning_content=reasoning_content, + tool_calls=tool_calls_proto, + ) + ], + ) + except Exception as e: print(f"Error in MLX-VLM PredictStream: {e}", file=sys.stderr) context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"Streaming generation failed: 
{str(e)}") yield backend_pb2.Reply(message=bytes("", encoding='utf-8')) finally: - # Clean up temporary files self.cleanup_temp_files(temp_files) + def _build_template_kwargs(self, request, num_images, num_audios): + """Collect kwargs for ``apply_chat_template`` that survive model variants.""" + kwargs = {"num_images": num_images, "num_audios": num_audios} + if request.Tools: + try: + kwargs["tools"] = json.loads(request.Tools) + except json.JSONDecodeError: + pass + if request.Metadata.get("enable_thinking", "").lower() == "true": + kwargs["enable_thinking"] = True + return kwargs + + def _apply_template(self, request, messages, num_images, num_audios): + kwargs = self._build_template_kwargs(request, num_images, num_audios) + try: + return apply_chat_template(self.processor, self.config, messages, **kwargs) + except TypeError: + # Fallback for older mlx-vlm versions that reject tools=/enable_thinking= + return apply_chat_template( + self.processor, + self.config, + messages, + num_images=num_images, + num_audios=num_audios, + ) + def _prepare_prompt(self, request, num_images=0, num_audios=0): """ - Prepare the prompt for MLX-VLM generation, handling chat templates and multimodal inputs. - - Args: - request: The gRPC request containing prompt and message information. - num_images: Number of images in the request. - num_audios: Number of audio files in the request. - - Returns: - str: The prepared prompt. + Prepare the prompt for MLX-VLM generation, handling chat templates and + multimodal inputs. Forwards tool definitions and enable_thinking when + present on the request. 
""" - # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template if not request.Prompt and request.UseTokenizerTemplate and request.Messages: - # Convert gRPC messages to the format expected by apply_chat_template - messages = [] - for msg in request.Messages: - messages.append({"role": msg.role, "content": msg.content}) - - # Use mlx-vlm's apply_chat_template which handles multimodal inputs - prompt = apply_chat_template( - self.processor, - self.config, - messages, - num_images=num_images, - num_audios=num_audios - ) - return prompt - elif request.Prompt: - # If we have a direct prompt but also have images/audio, we need to format it properly + messages = messages_to_dicts(request.Messages) + return self._apply_template(request, messages, num_images, num_audios) + + if request.Prompt: if num_images > 0 or num_audios > 0: - # Create a simple message structure for multimodal prompt messages = [{"role": "user", "content": request.Prompt}] - prompt = apply_chat_template( - self.processor, - self.config, - messages, - num_images=num_images, - num_audios=num_audios - ) - return prompt - else: - return request.Prompt - else: - # Fallback to empty prompt with multimodal template if we have media - if num_images > 0 or num_audios > 0: - messages = [{"role": "user", "content": ""}] - prompt = apply_chat_template( - self.processor, - self.config, - messages, - num_images=num_images, - num_audios=num_audios - ) - return prompt - else: - return "" + return self._apply_template(request, messages, num_images, num_audios) + return request.Prompt + + # Fallback to empty prompt with multimodal template if we have media + if num_images > 0 or num_audios > 0: + messages = [{"role": "user", "content": ""}] + return self._apply_template(request, messages, num_images, num_audios) + return "" @@ -327,62 +422,122 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): def _build_generation_params(self, request, 
default_max_tokens=200): """ - Build generation parameters from request attributes and options for MLX-VLM. - - Args: - request: The gRPC request. - default_max_tokens: Default max_tokens if not specified. + Build generation parameters from request attributes and options. Returns: - tuple: (max_tokens, generation_params dict) + tuple: (max_tokens, sampler_params, logits_params, stop_words) """ - # Extract max_tokens - max_tokens = getattr(request, 'Tokens', default_max_tokens) - if max_tokens == 0: - max_tokens = default_max_tokens - - # Extract generation parameters from request attributes - temp = getattr(request, 'Temperature', 0.0) - if temp == 0.0: - temp = 0.6 # Default temperature - - top_p = getattr(request, 'TopP', 0.0) - if top_p == 0.0: - top_p = 1.0 # Default top_p - - # Initialize generation parameters for MLX-VLM - generation_params = { + max_tokens = getattr(request, 'Tokens', default_max_tokens) or default_max_tokens + + temp = getattr(request, 'Temperature', 0.0) or 0.6 + top_p = getattr(request, 'TopP', 0.0) or 1.0 + min_p = getattr(request, 'MinP', 0.0) or 0.0 + top_k = getattr(request, 'TopK', 0) or 0 + + sampler_params = { 'temp': temp, 'top_p': top_p, + 'min_p': min_p, + 'top_k': top_k, } - - # Add seed if specified + + logits_params = {} + repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0 + if repetition_penalty and repetition_penalty != 1.0: + logits_params['repetition_penalty'] = repetition_penalty + presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0 + if presence_penalty: + logits_params['presence_penalty'] = presence_penalty + frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0 + if frequency_penalty: + logits_params['frequency_penalty'] = frequency_penalty + seed = getattr(request, 'Seed', 0) if seed != 0: mx.random.seed(seed) - - # Override with options if available + if hasattr(self, 'options'): - # Max tokens from options if 'max_tokens' in self.options: max_tokens = 
self.options['max_tokens'] - - # Generation parameters from options - param_option_mapping = { - 'temp': 'temp', - 'temperature': 'temp', # alias - 'top_p': 'top_p', + option_mapping = { + 'temp': 'temp', 'temperature': 'temp', + 'top_p': 'top_p', 'min_p': 'min_p', 'top_k': 'top_k', } - - for option_key, param_key in param_option_mapping.items(): + for option_key, param_key in option_mapping.items(): if option_key in self.options: - generation_params[param_key] = self.options[option_key] - - # Handle seed from options + sampler_params[param_key] = self.options[option_key] + for option_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'): + if option_key in self.options: + logits_params[option_key] = self.options[option_key] if 'seed' in self.options: mx.random.seed(self.options['seed']) - - return max_tokens, generation_params + + stop_words = list(getattr(request, 'StopPrompts', []) or []) + return max_tokens, sampler_params, logits_params, stop_words + + def _finalize_output(self, request, generated_text, last_response): + """Split reasoning + tool calls out of generated_text and return the + tuple consumed by Reply-builders.""" + content = generated_text + reasoning_content = "" + + if getattr(self, "has_thinking", False): + reasoning_content, content = split_reasoning(content, self.think_start, self.think_end) + + tool_calls_proto: List[backend_pb2.ToolCallDelta] = [] + if self.tool_module is not None: + parsed_tools = None + if request.Tools: + try: + parsed_tools = json.loads(request.Tools) + except json.JSONDecodeError: + parsed_tools = None + calls, content = parse_tool_calls(content, self.tool_module, parsed_tools) + for c in calls: + tool_calls_proto.append( + backend_pb2.ToolCallDelta( + index=c["index"], + id=c["id"], + name=c["name"], + arguments=c["arguments"], + ) + ) + + prompt_tokens = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0 + completion_tokens = int(getattr(last_response, "generation_tokens", 
0) or 0) if last_response else 0 + + logprobs_bytes = b"" + if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0: + try: + lp = getattr(last_response, "logprobs", None) + if lp is not None: + token_id = int(getattr(last_response, "token", 0) or 0) + tokenizer = ( + self.processor.tokenizer + if hasattr(self.processor, "tokenizer") + else self.processor + ) + token_text = tokenizer.decode([token_id]) if token_id else "" + top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0 + logprobs_bytes = json.dumps( + {"content": [{"token": token_text, "logprob": top_logprob}]} + ).encode("utf-8") + except Exception as e: + print(f"[mlx-vlm] Logprobs extraction failed: {e}", file=sys.stderr) + + return content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes + + def _truncate_at_stop(self, text, stop_words): + if not stop_words: + return text + earliest = len(text) + for stop in stop_words: + if not stop: + continue + idx = text.find(stop) + if idx >= 0 and idx < earliest: + earliest = idx + return text[:earliest] if earliest < len(text) else text def load_image_from_base64(self, image_data: str): """ diff --git a/backend/python/mlx-vlm/test.py b/backend/python/mlx-vlm/test.py index 827aa71a3..96cef3fac 100644 --- a/backend/python/mlx-vlm/test.py +++ b/backend/python/mlx-vlm/test.py @@ -1,17 +1,19 @@ +import os +import sys +import types import unittest import subprocess import time + +import grpc import backend_pb2 import backend_pb2_grpc -import grpc - -import unittest -import subprocess -import time -import grpc -import backend_pb2_grpc -import backend_pb2 +# Make the shared helpers importable so we can unit-test them without a +# running gRPC server. 
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +from python_utils import messages_to_dicts, parse_options +from mlx_utils import parse_tool_calls, split_reasoning class TestBackendServicer(unittest.TestCase): """ @@ -143,4 +145,55 @@ class TestBackendServicer(unittest.TestCase): print(err) self.fail("Embedding service failed") finally: - self.tearDown() \ No newline at end of file + self.tearDown() + + +class TestSharedHelpers(unittest.TestCase): + """Server-less unit tests for the helpers the mlx-vlm backend depends on.""" + + def test_parse_options_typed(self): + opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true", "name:hello"]) + self.assertEqual(opts["temperature"], 0.7) + self.assertEqual(opts["max_tokens"], 128) + self.assertIs(opts["trust"], True) + self.assertEqual(opts["name"], "hello") + + def test_messages_to_dicts_roundtrip(self): + msgs = [ + backend_pb2.Message(role="user", content="hi"), + backend_pb2.Message( + role="assistant", + content="", + tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]', + ), + backend_pb2.Message( + role="tool", + content="42", + tool_call_id="call_1", + name="f", + ), + ] + out = messages_to_dicts(msgs) + self.assertEqual(out[0], {"role": "user", "content": "hi"}) + self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f") + self.assertEqual(out[2]["tool_call_id"], "call_1") + + def test_split_reasoning(self): + r, c = split_reasoning("planfinal", "", "") + self.assertEqual(r, "plan") + self.assertEqual(c, "final") + + def test_parse_tool_calls_with_shim(self): + tm = types.SimpleNamespace( + tool_call_start="", + tool_call_end="", + parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}}, + ) + calls, remaining = parse_tool_calls( + "Paris", + tm, + tools=None, + ) + self.assertEqual(len(calls), 1) + self.assertEqual(calls[0]["name"], "get_weather") + 
self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}') \ No newline at end of file diff --git a/backend/python/mlx/backend.py b/backend/python/mlx/backend.py index 1a41020f5..a71da522c 100644 --- a/backend/python/mlx/backend.py +++ b/backend/python/mlx/backend.py @@ -2,11 +2,13 @@ import asyncio from concurrent import futures import argparse +import gc +import json import signal import sys import os +import types from typing import List -import time import backend_pb2 import backend_pb2_grpc @@ -15,13 +17,13 @@ import grpc sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common')) from grpc_auth import get_auth_interceptors +from python_utils import messages_to_dicts, parse_options +from mlx_utils import parse_tool_calls, split_reasoning -from mlx_lm import load, generate, stream_generate -from mlx_lm.sample_utils import make_sampler +from mlx_lm import load, stream_generate +from mlx_lm.sample_utils import make_logits_processors, make_sampler from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache import mlx.core as mx -import base64 -import io from mlx_cache import ThreadSafeLRUPromptCache @@ -30,21 +32,6 @@ _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) -def is_float(s): - """Check if a string can be converted to float.""" - try: - float(s) - return True - except ValueError: - return False -def is_int(s): - """Check if a string can be converted to int.""" - try: - int(s) - return True - except ValueError: - return False - # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): """ @@ -78,46 +65,27 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): try: print(f"Loading MLX model: {request.Model}", file=sys.stderr) 
print(f"Request: {request}", file=sys.stderr) - - # Parse options like in the diffusers backend - options = request.Options - self.options = {} - - # The options are a list of strings in this form optname:optvalue - # We store all the options in a dict for later use - for opt in options: - if ":" not in opt: - continue - key, value = opt.split(":", 1) # Split only on first colon to handle values with colons - - # Convert numeric values to appropriate types - if is_float(value): - value = float(value) - elif is_int(value): - value = int(value) - elif value.lower() in ["true", "false"]: - value = value.lower() == "true" - - self.options[key] = value - + + # Parse Options[] key:value strings into a typed dict (shared helper) + self.options = parse_options(request.Options) print(f"Options: {self.options}", file=sys.stderr) - + # Build tokenizer config for MLX using options tokenizer_config = {} - + # Handle trust_remote_code from request or options if request.TrustRemoteCode or self.options.get("trust_remote_code", False): tokenizer_config["trust_remote_code"] = True - + # Handle EOS token from options if "eos_token" in self.options: tokenizer_config["eos_token"] = self.options["eos_token"] - + # Handle other tokenizer config options for key in ["pad_token", "bos_token", "unk_token", "sep_token", "cls_token", "mask_token"]: if key in self.options: tokenizer_config[key] = self.options[key] - + # Load model and tokenizer using MLX if tokenizer_config: print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr) @@ -125,6 +93,21 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): else: self.model, self.tokenizer = load(request.Model) + # mlx_lm.load() returns a TokenizerWrapper that detects tool + # calling and thinking markers from the chat template / vocab. + # mlx-lm >= 0.30 also exposes a parser callable on the wrapper; + # earlier versions don't (we fall back to json.loads inside + # _tool_module_from_tokenizer below). 
+ has_tools = bool(getattr(self.tokenizer, "has_tool_calling", False)) + has_thinking = bool(getattr(self.tokenizer, "has_thinking", False)) + tcs = getattr(self.tokenizer, "tool_call_start", None) + tce = getattr(self.tokenizer, "tool_call_end", None) + print( + f"MLX tokenizer capabilities: has_tool_calling={has_tools} " + f"has_thinking={has_thinking} tool_call_start={tcs!r} tool_call_end={tce!r}", + file=sys.stderr, + ) + # Initialize thread-safe LRU prompt cache for efficient generation max_cache_entries = self.options.get("max_cache_entries", 10) self.max_kv_size = self.options.get("max_kv_size", None) @@ -134,7 +117,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): can_trim_fn=can_trim_prompt_cache, trim_fn=trim_prompt_cache, ) - + except Exception as err: print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr) return backend_pb2.Result(success=False, message=f"Error loading MLX model: {err}") @@ -172,30 +155,58 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): remaining_tokens = cache_key # Build generation parameters using request attributes and options - max_tokens, sampler_params = self._build_generation_params(request) + max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request) - print(f"Generating text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr) + print( + f"Generating text with MLX - max_tokens: {max_tokens}, " + f"cache_hit: {len(remaining_tokens) < len(cache_key)}", + file=sys.stderr, + ) - # Create sampler with parameters + # Create sampler and optional logits processors (penalties) sampler = make_sampler(**sampler_params) + logits_processors = make_logits_processors(**logits_params) if logits_params else None - # Use stream_generate to track generated tokens for cache key + # Use stream_generate to collect text + track tokens for cache key generated_text = [] + last_response = None for response in 
stream_generate( self.model, self.tokenizer, prompt=remaining_tokens if remaining_tokens else cache_key, max_tokens=max_tokens, sampler=sampler, + logits_processors=logits_processors, prompt_cache=prompt_cache, ): generated_text.append(response.text) cache_key.append(response.token) + last_response = response + # Early stop on user-provided stop sequences + if stop_words and any(s in "".join(generated_text) for s in stop_words): + break # Insert completed cache self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache) - return backend_pb2.Reply(message=bytes(''.join(generated_text), encoding='utf-8')) + full_text = self._truncate_at_stop("".join(generated_text), stop_words) + content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = ( + self._finalize_output(request, full_text, last_response) + ) + + return backend_pb2.Reply( + message=bytes(content, encoding='utf-8'), + prompt_tokens=prompt_tokens, + tokens=completion_tokens, + logprobs=logprobs_bytes, + chat_deltas=[ + backend_pb2.ChatDelta( + content=content, + reasoning_content=reasoning_content, + tool_calls=tool_calls_proto, + ) + ], + ) except Exception as e: print(f"Error in MLX Predict: {e}", file=sys.stderr) @@ -206,7 +217,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): def Embedding(self, request, context): """ A gRPC method that calculates embeddings for a given sentence. - + Note: MLX-LM doesn't support embeddings directly. This method returns an error. 
Args: @@ -221,6 +232,62 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): context.set_details("Embeddings are not supported in the MLX backend.") return backend_pb2.EmbeddingResult() + async def TokenizeString(self, request, context): + """Tokenize ``request.Prompt`` using the loaded model's tokenizer.""" + if not hasattr(self, "tokenizer") or self.tokenizer is None: + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details("tokenizer not loaded") + return backend_pb2.TokenizationResponse() + try: + tokens = self.tokenizer.encode(request.Prompt) + if hasattr(tokens, "tolist"): + tokens = tokens.tolist() + tokens = list(tokens) + return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens) + except Exception as e: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + return backend_pb2.TokenizationResponse() + + async def Free(self, request, context): + """Drop the loaded model, tokenizer and prompt cache. + + Metal / CUDA memory is released via ``gc.collect()`` + the + platform-specific cache clear hooks when available. + """ + try: + if hasattr(self, "model"): + del self.model + if hasattr(self, "tokenizer"): + del self.tokenizer + if hasattr(self, "lru_cache") and self.lru_cache is not None: + try: + self.lru_cache.clear() + except Exception: + pass + self.lru_cache = None + gc.collect() + # Metal: drop the cached allocator. mlx.clear_cache (mlx >= 0.30) + # supersedes the now-deprecated mlx.metal.clear_cache. + try: + if hasattr(mx, "clear_cache"): + mx.clear_cache() + elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"): + mx.metal.clear_cache() + except Exception: + pass + # CUDA: release the torch cache if a CUDA-backed mlx variant + # happens to be installed alongside torch (best-effort). 
+ try: + import torch # type: ignore + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except Exception: + pass + return backend_pb2.Result(success=True, message="MLX model freed") + except Exception as e: + return backend_pb2.Result(success=False, message=str(e)) + async def PredictStream(self, request, context): """ Generates text based on the given prompt and sampling parameters, and streams the results using MLX. @@ -251,24 +318,64 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): remaining_tokens = cache_key # Build generation parameters using request attributes and options - max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512) + max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params( + request, default_max_tokens=512 + ) - print(f"Streaming text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr) + print( + f"Streaming text with MLX - max_tokens: {max_tokens}, " + f"cache_hit: {len(remaining_tokens) < len(cache_key)}", + file=sys.stderr, + ) - # Create sampler with parameters + # Create sampler and optional logits processors (penalties) sampler = make_sampler(**sampler_params) + logits_processors = make_logits_processors(**logits_params) if logits_params else None - # Stream text generation using MLX with proper parameters + accumulated = [] + last_response = None for response in stream_generate( self.model, self.tokenizer, prompt=remaining_tokens if remaining_tokens else cache_key, max_tokens=max_tokens, sampler=sampler, + logits_processors=logits_processors, prompt_cache=prompt_cache, ): cache_key.append(response.token) - yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8')) + accumulated.append(response.text) + last_response = response + # Emit a content delta. Structured reasoning / tool parsing + # happens on the final chunk so we don't fragment the state + # machine in v1. 
+ yield backend_pb2.Reply( + message=bytes(response.text, encoding='utf-8'), + chat_deltas=[backend_pb2.ChatDelta(content=response.text)], + ) + # Early stop on user-provided stop sequences + if stop_words and any(s in "".join(accumulated) for s in stop_words): + break + + # Final chunk: run reasoning + tool parsing on accumulated text + # and emit the structured ChatDelta with token counts + logprobs. + full_text = self._truncate_at_stop("".join(accumulated), stop_words) + content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = ( + self._finalize_output(request, full_text, last_response) + ) + yield backend_pb2.Reply( + message=b"", + prompt_tokens=prompt_tokens, + tokens=completion_tokens, + logprobs=logprobs_bytes, + chat_deltas=[ + backend_pb2.ChatDelta( + content="", + reasoning_content=reasoning_content, + tool_calls=tool_calls_proto, + ) + ], + ) except Exception as e: print(f"Error in MLX PredictStream: {e}", file=sys.stderr) @@ -294,21 +401,33 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): Returns: str: The prepared prompt. """ - # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template + # If tokenizer template is enabled and messages are provided instead + # of prompt, apply the tokenizer template (forwards tool definitions + # and enable_thinking when the model supports them). 
if not request.Prompt and request.UseTokenizerTemplate and request.Messages: - # Convert gRPC messages to the format expected by apply_chat_template - messages = [] - for msg in request.Messages: - messages.append({"role": msg.role, "content": msg.content}) + messages = messages_to_dicts(request.Messages) - prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True - ) - return prompt - else: - return request.Prompt + kwargs = {"tokenize": False, "add_generation_prompt": True} + if request.Tools: + try: + kwargs["tools"] = json.loads(request.Tools) + except json.JSONDecodeError: + pass + enable_thinking = request.Metadata.get("enable_thinking", "").lower() + if enable_thinking == "true": + kwargs["enable_thinking"] = True + + try: + return self.tokenizer.apply_chat_template(messages, **kwargs) + except TypeError: + # Fallback for tokenizers whose template doesn't accept + # tools= or enable_thinking=. + return self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + return request.Prompt def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]: """ @@ -338,18 +457,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): default_max_tokens: Default max_tokens if not specified. 
Returns: - tuple: (max_tokens, sampler_params dict) + tuple: (max_tokens, sampler_params dict, logits_processor_params dict, + stop_words list) """ # Extract max_tokens max_tokens = getattr(request, 'Tokens', default_max_tokens) if max_tokens == 0: max_tokens = default_max_tokens - + # Extract sampler parameters from request attributes temp = getattr(request, 'Temperature', 0.0) if temp == 0.0: temp = 0.6 # Default temperature - + top_p = getattr(request, 'TopP', 0.0) if top_p == 0.0: top_p = 1.0 # Default top_p @@ -369,18 +489,31 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): 'xtc_threshold': 0.0, 'xtc_probability': 0.0, } - + + # Logits processor parameters — only set fields the request actually + # provides so we can feed them unconditionally to make_logits_processors. + logits_params = {} + repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0 + if repetition_penalty and repetition_penalty != 1.0: + logits_params['repetition_penalty'] = repetition_penalty + presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0 + if presence_penalty: + logits_params['presence_penalty'] = presence_penalty + frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0 + if frequency_penalty: + logits_params['frequency_penalty'] = frequency_penalty + # Add seed if specified seed = getattr(request, 'Seed', 0) if seed != 0: mx.random.seed(seed) - + # Override with options if available if hasattr(self, 'options'): # Max tokens from options if 'max_tokens' in self.options: max_tokens = self.options['max_tokens'] - + # Sampler parameters from options sampler_option_mapping = { 'temp': 'temp', @@ -391,32 +524,142 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): 'xtc_threshold': 'xtc_threshold', 'xtc_probability': 'xtc_probability', } - + for option_key, param_key in sampler_option_mapping.items(): if option_key in self.options: sampler_params[param_key] = self.options[option_key] - + + # Logits processor overrides + for 
option_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'): + if option_key in self.options: + logits_params[option_key] = self.options[option_key] + # Handle seed from options if 'seed' in self.options: mx.random.seed(self.options['seed']) - + # Special tokens for XTC sampling (if tokenizer has eos_token_ids) xtc_special_tokens = [] if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids: xtc_special_tokens = list(self.tokenizer.eos_token_ids) elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None: xtc_special_tokens = [self.tokenizer.eos_token_id] - + # Add newline token if available try: newline_tokens = self.tokenizer.encode("\n") xtc_special_tokens.extend(newline_tokens) - except: + except Exception: pass # Skip if encoding fails - + sampler_params['xtc_special_tokens'] = xtc_special_tokens - - return max_tokens, sampler_params + + # Stop sequences are applied post-decode (mlx-lm doesn't have a + # built-in stop-sequence sampler param). Preserve the list here. + stop_words = list(getattr(request, 'StopPrompts', []) or []) + + return max_tokens, sampler_params, logits_params, stop_words + + def _tool_module_from_tokenizer(self): + """Build a duck-typed tool module from the TokenizerWrapper. + + On mlx-lm >= 0.30 the wrapper exposes a ``tool_parser`` callable + that's been resolved from the model's chat template. On older + releases (e.g. 0.29.x) the wrapper only carries the start/end + markers — fall back to ``json.loads`` of the body, which matches + what ``mlx_lm.tool_parsers.json_tools.parse_tool_call`` does on + HEAD and covers the only format 0.29 detects (````). 
+ """ + start = getattr(self.tokenizer, "tool_call_start", None) + end = getattr(self.tokenizer, "tool_call_end", None) + if not start: + return None + parse_fn = getattr(self.tokenizer, "tool_parser", None) + if parse_fn is None: + def parse_fn(body, tools): # noqa: E306 — local fallback + return json.loads(body.strip()) + return types.SimpleNamespace( + tool_call_start=start, + tool_call_end=end or "", + parse_tool_call=parse_fn, + ) + + def _finalize_output(self, request, generated_text, last_response): + """Build a ChatDelta + token counts + logprobs from accumulated output. + + Returns ``(content, reasoning_content, tool_calls_proto, + prompt_token_count, completion_token_count, logprobs_bytes)``. + """ + content = generated_text + reasoning_content = "" + + if getattr(self.tokenizer, "has_thinking", False): + think_start = getattr(self.tokenizer, "think_start", "") or "" + think_end = getattr(self.tokenizer, "think_end", "") or "" + reasoning_content, content = split_reasoning(content, think_start, think_end) + + tool_calls_proto: List[backend_pb2.ToolCallDelta] = [] + tool_module = None + if getattr(self.tokenizer, "has_tool_calling", False): + tool_module = self._tool_module_from_tokenizer() + if tool_module is not None: + parsed_tools = None + if request.Tools: + try: + parsed_tools = json.loads(request.Tools) + except json.JSONDecodeError: + parsed_tools = None + calls, content = parse_tool_calls(content, tool_module, parsed_tools) + for c in calls: + tool_calls_proto.append( + backend_pb2.ToolCallDelta( + index=c["index"], + id=c["id"], + name=c["name"], + arguments=c["arguments"], + ) + ) + + prompt_token_count = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0 + completion_token_count = int(getattr(last_response, "generation_tokens", 0) or 0) if last_response else 0 + + logprobs_bytes = b"" + # Logprobs extraction — only when the request asked for them. 
+ if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0: + try: + lp = getattr(last_response, "logprobs", None) + if lp is not None: + # GenerationResponse.logprobs on the last chunk is the + # logprob distribution of the final token. Without a + # per-token history we at minimum surface the last token's + # top-1 logprob so clients get a non-empty field. + token_id = int(getattr(last_response, "token", 0) or 0) + token_text = self.tokenizer.decode([token_id]) if token_id else "" + top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0 + logprobs_bytes = json.dumps( + { + "content": [ + {"token": token_text, "logprob": top_logprob} + ] + } + ).encode("utf-8") + except Exception as e: + print(f"[mlx] Logprobs extraction failed: {e}", file=sys.stderr) + + return content, reasoning_content, tool_calls_proto, prompt_token_count, completion_token_count, logprobs_bytes + + def _truncate_at_stop(self, text, stop_words): + """Truncate ``text`` at the first occurrence of any stop sequence.""" + if not stop_words: + return text + earliest = len(text) + for stop in stop_words: + if not stop: + continue + idx = text.find(stop) + if idx >= 0 and idx < earliest: + earliest = idx + return text[:earliest] if earliest < len(text) else text async def serve(address): # Start asyncio gRPC server diff --git a/backend/python/mlx/test.py b/backend/python/mlx/test.py index 53d7bc7ec..ac5ff4c06 100644 --- a/backend/python/mlx/test.py +++ b/backend/python/mlx/test.py @@ -1,11 +1,20 @@ +import os +import sys import unittest import subprocess import time +import types import grpc import backend_pb2 import backend_pb2_grpc +# Make the shared helpers importable so we can unit-test them without a +# running gRPC server. 
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +from python_utils import messages_to_dicts, parse_options +from mlx_utils import parse_tool_calls, split_reasoning + class TestBackendServicer(unittest.TestCase): """ TestBackendServicer is the class that tests the gRPC service. @@ -231,4 +240,104 @@ class TestBackendServicer(unittest.TestCase): self.tearDown() + def test_tokenize_string(self): + """TokenizeString should return a non-empty token list for a known prompt.""" + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel( + backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit") + ) + self.assertTrue(response.success) + resp = stub.TokenizeString(backend_pb2.PredictOptions(Prompt="Hello, world")) + self.assertGreater(resp.length, 0) + self.assertEqual(len(list(resp.tokens)), resp.length) + except Exception as err: + print(err) + self.fail("TokenizeString service failed") + finally: + self.tearDown() + + def test_free(self): + """Free should release the model and not crash on subsequent calls.""" + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel( + backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit") + ) + self.assertTrue(response.success) + free_resp = stub.Free(backend_pb2.HealthMessage()) + self.assertTrue(free_resp.success) + except Exception as err: + print(err) + self.fail("Free service failed") + finally: + self.tearDown() + + +class TestSharedHelpers(unittest.TestCase): + """Server-less unit tests for the helpers the mlx backend depends on.""" + + def test_parse_options_typed(self): + opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true", "name:hello", "no_colon_skipped"]) + self.assertEqual(opts["temperature"], 0.7) + self.assertEqual(opts["max_tokens"], 128) + 
self.assertIs(opts["trust"], True) + self.assertEqual(opts["name"], "hello") + self.assertNotIn("no_colon_skipped", opts) + + def test_messages_to_dicts_roundtrip(self): + # Build proto Message objects (via backend_pb2 to match real gRPC) + msgs = [ + backend_pb2.Message(role="user", content="hi"), + backend_pb2.Message( + role="assistant", + content="", + tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]', + ), + backend_pb2.Message( + role="tool", + content="42", + tool_call_id="call_1", + name="f", + ), + ] + out = messages_to_dicts(msgs) + self.assertEqual(out[0], {"role": "user", "content": "hi"}) + self.assertEqual(out[1]["role"], "assistant") + self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f") + self.assertEqual(out[2]["tool_call_id"], "call_1") + self.assertEqual(out[2]["name"], "f") + + def test_split_reasoning(self): + r, c = split_reasoning("step 1\nstep 2The answer is 42.", "", "") + self.assertEqual(r, "step 1\nstep 2") + self.assertEqual(c, "The answer is 42.") + + def test_split_reasoning_no_marker(self): + r, c = split_reasoning("just text", "", "") + self.assertEqual(r, "") + self.assertEqual(c, "just text") + + def test_parse_tool_calls_with_shim(self): + tm = types.SimpleNamespace( + tool_call_start="", + tool_call_end="", + parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}}, + ) + calls, remaining = parse_tool_calls( + "Sure: Paris", + tm, + tools=None, + ) + self.assertEqual(len(calls), 1) + self.assertEqual(calls[0]["name"], "get_weather") + self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}') + self.assertEqual(calls[0]["index"], 0) + self.assertNotIn("", remaining) + + # Unit tests for ThreadSafeLRUPromptCache are in test_mlx_cache.py \ No newline at end of file diff --git a/core/gallery/importers/mlx.go b/core/gallery/importers/mlx.go index 7ab513f6d..075c58cf7 100644 --- a/core/gallery/importers/mlx.go +++ 
b/core/gallery/importers/mlx.go @@ -84,6 +84,20 @@ func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) { // Apply per-model-family inference parameter defaults config.ApplyInferenceDefaults(&modelConfig, details.URI) + // Auto-set tool_parser / reasoning_parser from parser_defaults.json so + // the generated YAML mirrors what the vllm importer produces. The mlx + // backends auto-detect parsers from the chat template at runtime and + // ignore these Options, but surfacing them in the config keeps the two + // paths consistent and gives users a single place to override. + if parsers := config.MatchParserDefaults(details.URI); parsers != nil { + if tp, ok := parsers["tool_parser"]; ok { + modelConfig.Options = append(modelConfig.Options, "tool_parser:"+tp) + } + if rp, ok := parsers["reasoning_parser"]; ok { + modelConfig.Options = append(modelConfig.Options, "reasoning_parser:"+rp) + } + } + data, err := yaml.Marshal(modelConfig) if err != nil { return gallery.ModelConfig{}, err