Mirror of https://github.com/mudler/LocalAI.git (synced 2026-04-16 12:59:33 -04:00)
feat: refactor shared helpers and enhance MLX backend functionality (#9335)
* refactor(backends): extract python_utils + add mlx_utils shared helpers
Move parse_options() and messages_to_dicts() out of vllm_utils.py into a
new framework-agnostic python_utils.py, and re-export them from vllm_utils
so existing vllm / vllm-omni imports keep working.
Add mlx_utils.py with split_reasoning() and parse_tool_calls() — ported
from mlx_vlm/server.py's process_tool_calls. These work with any
mlx-lm / mlx-vlm tool module (anything exposing tool_call_start,
tool_call_end, parse_tool_call). Used by the mlx and mlx-vlm backends in
later commits to emit structured ChatDelta.tool_calls without
reimplementing per-model parsing.
Shared smoke tests confirm:
- parse_options round-trips bool/int/float/string
- vllm_utils re-exports are identity-equal to python_utils originals
- mlx_utils parse_tool_calls handles <tool_call>...</tool_call> with a
shim module and produces a correctly-indexed list with JSON arguments
- mlx_utils split_reasoning extracts <think> blocks and leaves clean
content
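
For illustration, a server-less check in the spirit of those smoke tests. The parse_options body below is inlined from the new python_utils.py (added later in this commit) so the snippet stands alone:

```python
# Inlined from backend/python/common/python_utils.py (this commit), so the
# round-trip check runs without the backend tree on sys.path.
def parse_options(options_list):
    opts = {}
    for opt in options_list:
        if ":" not in opt:
            continue
        key, value = opt.split(":", 1)
        key, value = key.strip(), value.strip()
        if value.lower() in ("true", "false"):
            opts[key] = value.lower() == "true"
        else:
            try:
                opts[key] = int(value)
            except ValueError:
                try:
                    opts[key] = float(value)
                except ValueError:
                    opts[key] = value
    return opts

opts = parse_options(["trust_remote_code:true", "gpu_layers:32", "temp:0.7", "tool_parser:hermes"])
print(opts)
# → {'trust_remote_code': True, 'gpu_layers': 32, 'temp': 0.7, 'tool_parser': 'hermes'}
```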
* feat(mlx): wire native tool parsers + ChatDelta + token usage + logprobs
Bring the MLX backend up to the same structured-output contract as vLLM
and llama.cpp: emit Reply.chat_deltas so the OpenAI HTTP layer sees
tool_calls and reasoning_content, not just raw text.
Key insight: mlx_lm.load() returns a TokenizerWrapper that already auto-
detects the right tool parser from the model's chat template
(_infer_tool_parser in mlx_lm/tokenizer_utils.py). The wrapper exposes
has_tool_calling, has_thinking, tool_parser, tool_call_start,
tool_call_end, think_start, think_end — no user configuration needed,
unlike vLLM.
Changes in backend/python/mlx/backend.py:
- Imports: replace inline parse_options / messages_to_dicts with the
shared helpers from python_utils. Pull split_reasoning / parse_tool_calls
from the new mlx_utils shared module.
- LoadModel: log the auto-detected has_tool_calling / has_thinking /
tool_parser_type for observability. Drop the local is_float / is_int
duplicates.
- _prepare_prompt: run request.Messages through messages_to_dicts so
tool_call_id / tool_calls / reasoning_content survive the conversion,
and pass tools=json.loads(request.Tools) + enable_thinking=True (when
request.Metadata says so) to apply_chat_template. Falls back on
TypeError for tokenizers whose template doesn't accept those kwargs.
- _build_generation_params: return an additional (logits_params,
stop_words) pair. Maps RepetitionPenalty / PresencePenalty /
FrequencyPenalty to mlx_lm.sample_utils.make_logits_processors and
threads StopPrompts through to post-decode truncation.
- New _tool_module_from_tokenizer / _finalize_output / _truncate_at_stop
helpers. _finalize_output runs split_reasoning when has_thinking is
true and parse_tool_calls (using a SimpleNamespace shim around the
wrapper's tool_parser callable) when has_tool_calling is true, then
extracts prompt_tokens, generation_tokens and (best-effort) logprobs
from the last GenerationResponse chunk.
- Predict: use make_logits_processors, accumulate text + last_response,
finalize into a structured Reply carrying chat_deltas,
prompt_tokens, tokens, logprobs. Early-stops on user stop sequences.
- PredictStream: per-chunk Reply still carries raw message bytes for
back-compat but now also emits chat_deltas=[ChatDelta(content=delta)].
On loop exit, emit a terminal Reply with structured
reasoning_content / tool_calls / token counts / logprobs — so the Go
side sees tool calls without needing the regex fallback.
- TokenizeString RPC: uses the TokenizerWrapper's encode(); returns
length + tokens or FAILED_PRECONDITION if the model isn't loaded.
- Free RPC: drops model / tokenizer / lru_cache, runs gc.collect(),
calls mx.metal.clear_cache() when available, and best-effort clears
torch.cuda as a belt-and-suspenders.
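
The reasoning split that _finalize_output performs can be seen in isolation. The split_reasoning body below is copied from the mlx_utils.py added by this commit; the sample strings are illustrative:

```python
import re

# Copied from backend/python/common/mlx_utils.py (this commit).
def split_reasoning(text, think_start, think_end):
    if not think_start or not text or think_start not in text:
        return "", text
    pattern = re.compile(
        re.escape(think_start) + r"(.*?)" + re.escape(think_end or ""),
        re.DOTALL,
    )
    reasoning_parts = pattern.findall(text)
    if not reasoning_parts:
        return "", text
    remaining = pattern.sub("", text).strip()
    return "\n".join(p.strip() for p in reasoning_parts), remaining

reasoning, content = split_reasoning(
    "<think>The user wants Paris weather.</think>It is sunny in Paris.",
    "<think>", "</think>",
)
print(reasoning)  # → The user wants Paris weather.
print(content)    # → It is sunny in Paris.
```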
* feat(mlx-vlm): mirror MLX parity (tool parsers + ChatDelta + samplers)
Same treatment as the MLX backend: emit structured Reply.chat_deltas,
tool_calls, reasoning_content, token counts and logprobs, and extend
sampling parameter coverage beyond the temp/top_p pair the backend
used to handle.
- Imports: drop the inline is_float/is_int helpers, pull parse_options /
messages_to_dicts from python_utils and split_reasoning /
parse_tool_calls from mlx_utils. Also import make_sampler and
make_logits_processors from mlx_lm.sample_utils — mlx-vlm re-uses them.
- LoadModel: use parse_options; call mlx_vlm.tool_parsers._infer_tool_parser
/ load_tool_module to auto-detect a tool module from the processor's
chat_template. Stash think_start / think_end / has_thinking so later
finalisation can split reasoning blocks without duck-typing on each
call. Logs the detected parser type.
- _prepare_prompt: convert proto Messages via messages_to_dicts (so
tool_call_id / tool_calls survive), pass tools=json.loads(request.Tools)
and enable_thinking=True to apply_chat_template when present, fall
back on TypeError for older mlx-vlm versions. Also handle the
prompt-only + media and empty-prompt + media paths consistently.
- _build_generation_params: return (max_tokens, sampler_params,
logits_params, stop_words). Maps repetition_penalty / presence_penalty /
frequency_penalty and passes them through make_logits_processors.
- _finalize_output / _truncate_at_stop: common helper used by Predict
and PredictStream to split reasoning, run parse_tool_calls against the
auto-detected tool module, build ToolCallDelta list, and extract token
counts + logprobs from the last GenerationResult.
- Predict / PredictStream: switch from mlx_vlm.generate to mlx_vlm.stream_generate
in both paths, accumulate text + last_response, pass sampler and
logits_processors through, emit content-only ChatDelta per streaming
chunk followed by a terminal Reply carrying reasoning_content,
tool_calls, prompt_tokens, tokens and logprobs. Non-streaming Predict
returns the same structured Reply shape.
- New helper _collect_media extracted from the duplicated base64 image /
audio decode loop.
- New TokenizeString RPC using the processor's tokenizer.encode and
Free RPC that drops model/processor/config, runs gc + Metal cache
clear + best-effort torch.cuda cache clear.
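
The tool-call extraction those helpers perform looks like this in isolation. parse_tool_calls is lightly condensed from the mlx_utils.py added by this commit, and the SimpleNamespace shim stands in for an auto-detected tool module, as in the backend test suites; the sample text is illustrative:

```python
import json, re, sys, types, uuid

# Lightly condensed from backend/python/common/mlx_utils.py (this commit).
def parse_tool_calls(text, tool_module, tools):
    if tool_module is None or not text:
        return [], text
    start = getattr(tool_module, "tool_call_start", None)
    end = getattr(tool_module, "tool_call_end", None)
    parse_fn = getattr(tool_module, "parse_tool_call", None)
    if not start or parse_fn is None or start not in text:
        return [], text
    if end:
        pattern = re.compile(re.escape(start) + r".*?" + re.escape(end), re.DOTALL)
    else:
        pattern = re.compile(re.escape(start) + r".*?(?:\n|$)", re.DOTALL)
    matches = pattern.findall(text)
    if not matches:
        return [], text
    remaining = pattern.sub(" ", text).strip()
    calls = []
    for match in matches:
        call_body = match.strip().removeprefix(start)
        if end:
            call_body = call_body.removesuffix(end)
        try:
            parsed = parse_fn(call_body.strip(), tools)
        except Exception as e:
            print(f"[mlx_utils] Invalid tool call: {e}", file=sys.stderr)
            continue
        if not isinstance(parsed, list):
            parsed = [parsed]
        for tc in parsed:
            calls.append({
                "index": len(calls),
                "id": str(uuid.uuid4()),
                "name": (tc.get("name") or "").strip(),
                "arguments": json.dumps(tc.get("arguments", {}), ensure_ascii=False),
            })
    return calls, remaining

# Shim faking a tool module, as the unit tests do (no model required).
shim = types.SimpleNamespace(
    tool_call_start="<tool_call>",
    tool_call_end="</tool_call>",
    parse_tool_call=lambda body, tools=None: json.loads(body),
)
calls, rest = parse_tool_calls(
    'Checking. <tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>',
    shim, tools=None,
)
print(calls[0]["name"], rest)  # → get_weather Checking.
```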
* feat(importer/mlx): auto-set tool_parser/reasoning_parser on import
Mirror what core/gallery/importers/vllm.go does: after applying the
shared inference defaults, look up the model URI in parser_defaults.json
and append matching tool_parser:/reasoning_parser: entries to Options.
The MLX backends auto-detect tool parsers from the chat template at
runtime so they don't actually consume these options — but surfacing
them in the generated YAML:
- keeps the import experience consistent with vllm
- gives users a single visible place to override
- documents the intended parser for a given model family
* test(mlx): add helper unit tests + TokenizeString/Free + e2e make targets
- backend/python/mlx/test.py: add TestSharedHelpers with server-less
unit tests for parse_options, messages_to_dicts, split_reasoning and
parse_tool_calls (using a SimpleNamespace shim to fake a tool module
without requiring a model). Plus test_tokenize_string and test_free
RPC tests that load a tiny MLX-quantized Llama and exercise the new
RPCs end-to-end.
- backend/python/mlx-vlm/test.py: same helper unit tests + cleanup of
the duplicated import block at the top of the file.
- Makefile: register BACKEND_MLX and BACKEND_MLX_VLM (they were missing
from the docker-build-target eval list — only mlx-distributed had a
generated target before). Add test-extra-backend-mlx and
test-extra-backend-mlx-vlm convenience targets that build the
respective image and run tests/e2e-backends with the tools capability
against mlx-community/Qwen2.5-0.5B-Instruct-4bit. The MLX backend
auto-detects the tool parser from the chat template so no
BACKEND_TEST_OPTIONS is needed (unlike vllm).
* fix(libbackend): don't pass --copies to venv unless PORTABLE_PYTHON=true
backend/python/common/libbackend.sh:ensureVenv() always invoked
'python -m venv --copies', but macOS system python (and some other
builds) refuses with:
Error: This build of python cannot create venvs without using symlinks
--copies only matters when _makeVenvPortable later relocates the venv,
which only happens when PORTABLE_PYTHON=true. Make --copies conditional
on that flag and fall back to default (symlinked) venv otherwise.
Caught while bringing up the mlx backend on Apple Silicon — the same
build path is used by every Python backend with USE_PIP=true.
* fix(mlx): support mlx-lm 0.29.x tool calling + drop deprecated clear_cache
The released mlx-lm 0.29.x ships a much simpler tool-calling API than
HEAD: TokenizerWrapper detects the <tool_call>...</tool_call> markers
from the tokenizer vocab and exposes has_tool_calling /
tool_call_start / tool_call_end, but does NOT expose a tool_parser
callable on the wrapper and does NOT ship a mlx_lm.tool_parsers
subpackage at all (those only exist on main).
Caught while running the smoke test on Apple Silicon with the
released mlx-lm 0.29.1: tokenizer.tool_parser raised AttributeError
(falling through to the underlying HF tokenizer), so
_tool_module_from_tokenizer always returned None and tool calls slipped
through as raw <tool_call>...</tool_call> text in Reply.message instead
of being parsed into ChatDelta.tool_calls.
Fix: when has_tool_calling is True but tokenizer.tool_parser is missing,
default the parse_tool_call callable to json.loads(body.strip()) — that's
exactly what mlx_lm.tool_parsers.json_tools.parse_tool_call does on HEAD
and covers the only format 0.29 detects (<tool_call>JSON</tool_call>).
Future mlx-lm releases that ship more parsers will be picked up
automatically via the tokenizer.tool_parser attribute when present.
Also tighten the LoadModel logging — the old log line read
init_kwargs.get('tool_parser_type') which doesn't exist on 0.29 and
showed None even when has_tool_calling was True. Log the actual
tool_call_start / tool_call_end markers instead.
While here, switch Free()'s Metal cache clear from the deprecated
mx.metal.clear_cache to mx.clear_cache (mlx >= 0.30), with a
fallback for older releases. Mirrored to the mlx-vlm backend.
* feat(mlx-distributed): mirror MLX parity (tool calls + ChatDelta + sampler)
Same treatment as the mlx and mlx-vlm backends: emit Reply.chat_deltas
with structured tool_calls / reasoning_content / token counts /
logprobs, expand sampling parameter coverage beyond temp+top_p, and
add the missing TokenizeString and Free RPCs.
Notes specific to mlx-distributed:
- Rank 0 is the only rank that owns a sampler — workers participate in
the pipeline-parallel forward pass via mx.distributed and don't
re-implement sampling. So the new logits_params (repetition_penalty,
presence_penalty, frequency_penalty) and stop_words apply on rank 0
only; we don't need to extend coordinator.broadcast_generation_params,
which still ships only max_tokens / temperature / top_p to workers
(everything else is a rank-0 concern).
- Free() now broadcasts CMD_SHUTDOWN to workers when a coordinator is
active, so they release the model on their end too. The constant is
already defined and handled by the existing worker loop in
backend.py:633 (CMD_SHUTDOWN = -1).
- Drop the locally-defined is_float / is_int / parse_options trio in
favor of python_utils.parse_options, re-exported under the module
name for back-compat with anything that imported it directly.
- _prepare_prompt: route through messages_to_dicts so tool_call_id /
tool_calls / reasoning_content survive, pass tools=json.loads(
request.Tools) and enable_thinking=True to apply_chat_template, fall
back on TypeError for templates that don't accept those kwargs.
- New _tool_module_from_tokenizer (with the json.loads fallback for
mlx-lm 0.29.x), _finalize_output, _truncate_at_stop helpers — same
contract as the mlx backend.
- LoadModel logs the auto-detected has_tool_calling / has_thinking /
tool_call_start / tool_call_end so users can see what the wrapper
picked up for the loaded model.
- backend/python/mlx-distributed/test.py: add the same TestSharedHelpers
unit tests (parse_options, messages_to_dicts, split_reasoning,
parse_tool_calls) that exist for mlx and mlx-vlm.
Committed by GitHub. Parent daa0272f2e, commit 016da02845.

Makefile (20 changed lines)
@@ -519,6 +519,22 @@ test-extra-backend-vllm: docker-build-vllm
 	BACKEND_TEST_OPTIONS=tool_parser:hermes \
 	$(MAKE) test-extra-backend
 
+## mlx is Apple-Silicon-first — the MLX backend auto-detects the right tool
+## parser from the chat template, so no tool_parser: option is needed (it
+## would be ignored at runtime). Run this on macOS / arm64 with Metal; the
+## Linux/CPU mlx variant is untested in CI.
+test-extra-backend-mlx: docker-build-mlx
+	BACKEND_IMAGE=local-ai-backend:mlx \
+	BACKEND_TEST_MODEL_NAME=mlx-community/Qwen2.5-0.5B-Instruct-4bit \
+	BACKEND_TEST_CAPS=health,load,predict,stream,tools \
+	$(MAKE) test-extra-backend
+
+test-extra-backend-mlx-vlm: docker-build-mlx-vlm
+	BACKEND_IMAGE=local-ai-backend:mlx-vlm \
+	BACKEND_TEST_MODEL_NAME=mlx-community/Qwen2.5-0.5B-Instruct-4bit \
+	BACKEND_TEST_CAPS=health,load,predict,stream,tools \
+	$(MAKE) test-extra-backend
+
 DOCKER_IMAGE?=local-ai
 IMAGE_TYPE?=core
 BASE_IMAGE?=ubuntu:24.04

@@ -652,6 +668,8 @@ BACKEND_NEMO = nemo|python|.|false|true
 BACKEND_VOXCPM = voxcpm|python|.|false|true
 BACKEND_WHISPERX = whisperx|python|.|false|true
 BACKEND_ACE_STEP = ace-step|python|.|false|true
+BACKEND_MLX = mlx|python|.|false|true
+BACKEND_MLX_VLM = mlx-vlm|python|.|false|true
 BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
 BACKEND_TRL = trl|python|.|false|true
 BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true

@@ -720,6 +738,8 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
 $(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
 $(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
 $(eval $(call generate-docker-build-target,$(BACKEND_QWEN3_TTS_CPP)))
+$(eval $(call generate-docker-build-target,$(BACKEND_MLX)))
+$(eval $(call generate-docker-build-target,$(BACKEND_MLX_VLM)))
 $(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
 $(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
 $(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
backend/python/common/libbackend.sh
@@ -344,7 +344,16 @@ function ensureVenv() {
 
     if [ ! -d "${EDIR}/venv" ]; then
         if [ "x${USE_PIP}" == "xtrue" ]; then
-            "${interpreter}" -m venv --copies "${EDIR}/venv"
+            # --copies is only needed when we will later relocate the venv via
+            # _makeVenvPortable (PORTABLE_PYTHON=true). Some Python builds —
+            # notably macOS system Python — refuse to create a venv with
+            # --copies because the build doesn't support it. Fall back to
+            # symlinks in that case.
+            local venv_args=""
+            if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
+                venv_args="--copies"
+            fi
+            "${interpreter}" -m venv ${venv_args} "${EDIR}/venv"
             source "${EDIR}/venv/bin/activate"
             "${interpreter}" -m pip install --upgrade pip
         else
backend/python/common/mlx_utils.py (new file, 100 lines)
@@ -0,0 +1,100 @@
"""Shared utilities for the mlx and mlx-vlm gRPC backends.

These helpers wrap mlx-lm's and mlx-vlm's native tool-parser modules, which
auto-detect the right parser from the model's chat template. Each tool
module exposes ``tool_call_start``, ``tool_call_end`` and
``parse_tool_call(text, tools) -> dict | list[dict]``.

The split-reasoning helper is generic enough to work with any think-start /
think-end delimiter pair.
"""
import json
import re
import sys
import uuid


def split_reasoning(text, think_start, think_end):
    """Split ``<think>...</think>`` blocks out of ``text``.

    Returns ``(reasoning_content, remaining_text)``. When ``think_start`` is
    empty or not found, returns ``("", text)`` unchanged.
    """
    if not think_start or not text or think_start not in text:
        return "", text
    pattern = re.compile(
        re.escape(think_start) + r"(.*?)" + re.escape(think_end or ""),
        re.DOTALL,
    )
    reasoning_parts = pattern.findall(text)
    if not reasoning_parts:
        return "", text
    remaining = pattern.sub("", text).strip()
    return "\n".join(p.strip() for p in reasoning_parts), remaining


def parse_tool_calls(text, tool_module, tools):
    """Extract tool calls from ``text`` using a mlx-lm tool module.

    Ports the ``process_tool_calls`` logic from
    ``mlx_vlm/server.py`` (v0.10 onwards). ``tool_module`` must expose
    ``tool_call_start``, ``tool_call_end`` and ``parse_tool_call``.

    Returns ``(calls, remaining_text)`` where ``calls`` is a list of dicts:

        [{"index": int, "id": str, "name": str, "arguments": str (JSON)}]

    and ``remaining_text`` is the free-form text with the tool call blocks
    removed. ``(calls, text)`` is returned unchanged if ``tool_module`` is
    ``None`` or the start delimiter isn't present.
    """
    if tool_module is None or not text:
        return [], text
    start = getattr(tool_module, "tool_call_start", None)
    end = getattr(tool_module, "tool_call_end", None)
    parse_fn = getattr(tool_module, "parse_tool_call", None)
    if not start or parse_fn is None or start not in text:
        return [], text

    if end == "" or end is None:
        pattern = re.compile(
            re.escape(start) + r".*?(?:\n|$)",
            re.DOTALL,
        )
    else:
        pattern = re.compile(
            re.escape(start) + r".*?" + re.escape(end),
            re.DOTALL,
        )

    matches = pattern.findall(text)
    if not matches:
        return [], text

    remaining = pattern.sub(" ", text).strip()
    calls = []
    for match in matches:
        call_body = match.strip().removeprefix(start)
        if end:
            call_body = call_body.removesuffix(end)
        call_body = call_body.strip()
        try:
            parsed = parse_fn(call_body, tools)
        except Exception as e:
            print(
                f"[mlx_utils] Invalid tool call: {call_body!r} ({e})",
                file=sys.stderr,
            )
            continue
        if not isinstance(parsed, list):
            parsed = [parsed]
        for tc in parsed:
            calls.append(
                {
                    "index": len(calls),
                    "id": str(uuid.uuid4()),
                    "name": (tc.get("name") or "").strip(),
                    "arguments": json.dumps(tc.get("arguments", {}), ensure_ascii=False),
                }
            )
    return calls, remaining
backend/python/common/python_utils.py (new file, 65 lines)
@@ -0,0 +1,65 @@
"""Generic utilities shared across Python gRPC backends.

These helpers don't depend on any specific inference framework and can be
imported by any backend that needs to parse LocalAI gRPC options or build a
chat-template-compatible message list from proto Message objects.
"""
import json


def parse_options(options_list):
    """Parse Options[] list of ``key:value`` strings into a dict.

    Supports type inference for common cases (bool, int, float). Unknown or
    mixed-case values are returned as strings.

    Used by LoadModel to extract backend-specific options passed via
    ``ModelOptions.Options`` in ``backend.proto``.
    """
    opts = {}
    for opt in options_list:
        if ":" not in opt:
            continue
        key, value = opt.split(":", 1)
        key = key.strip()
        value = value.strip()
        # Try type conversion
        if value.lower() in ("true", "false"):
            opts[key] = value.lower() == "true"
        else:
            try:
                opts[key] = int(value)
            except ValueError:
                try:
                    opts[key] = float(value)
                except ValueError:
                    opts[key] = value
    return opts


def messages_to_dicts(proto_messages):
    """Convert proto ``Message`` objects to dicts suitable for ``apply_chat_template``.

    Handles: ``role``, ``content``, ``name``, ``tool_call_id``,
    ``reasoning_content``, ``tool_calls`` (JSON string → Python list).

    HuggingFace chat templates (and their MLX/vLLM wrappers) expect a list of
    plain dicts — proto Message objects don't work directly with Jinja, so
    this conversion is needed before every ``apply_chat_template`` call.
    """
    result = []
    for msg in proto_messages:
        d = {"role": msg.role, "content": msg.content or ""}
        if msg.name:
            d["name"] = msg.name
        if msg.tool_call_id:
            d["tool_call_id"] = msg.tool_call_id
        if msg.reasoning_content:
            d["reasoning_content"] = msg.reasoning_content
        if msg.tool_calls:
            try:
                d["tool_calls"] = json.loads(msg.tool_calls)
            except json.JSONDecodeError:
                pass
        result.append(d)
    return result
backend/python/common/vllm_utils.py
@@ -1,63 +1,22 @@
-"""Shared utilities for vLLM-based backends."""
-import json
+"""vLLM-specific helpers for the vllm and vllm-omni gRPC backends.
+
+Generic helpers (``parse_options``, ``messages_to_dicts``) live in
+``python_utils`` and are re-exported here for backwards compatibility with
+existing imports in both backends.
+"""
 import sys
 
+from python_utils import messages_to_dicts, parse_options
 
-def parse_options(options_list):
-    """Parse Options[] list of 'key:value' strings into a dict.
-
-    Supports type inference for common cases (bool, int, float).
-    Used by LoadModel to extract backend-specific options.
-    """
-    opts = {}
-    for opt in options_list:
-        if ":" not in opt:
-            continue
-        key, value = opt.split(":", 1)
-        key = key.strip()
-        value = value.strip()
-        # Try type conversion
-        if value.lower() in ("true", "false"):
-            opts[key] = value.lower() == "true"
-        else:
-            try:
-                opts[key] = int(value)
-            except ValueError:
-                try:
-                    opts[key] = float(value)
-                except ValueError:
-                    opts[key] = value
-    return opts
-
-
-def messages_to_dicts(proto_messages):
-    """Convert proto Message objects to list of dicts for apply_chat_template().
-
-    Handles: role, content, name, tool_call_id, reasoning_content, tool_calls (JSON string -> list).
-    """
-    result = []
-    for msg in proto_messages:
-        d = {"role": msg.role, "content": msg.content or ""}
-        if msg.name:
-            d["name"] = msg.name
-        if msg.tool_call_id:
-            d["tool_call_id"] = msg.tool_call_id
-        if msg.reasoning_content:
-            d["reasoning_content"] = msg.reasoning_content
-        if msg.tool_calls:
-            try:
-                d["tool_calls"] = json.loads(msg.tool_calls)
-            except json.JSONDecodeError:
-                pass
-        result.append(d)
-    return result
+__all__ = ["parse_options", "messages_to_dicts", "setup_parsers"]
 
 
 def setup_parsers(opts):
-    """Return (tool_parser_cls, reasoning_parser_cls) tuple from opts dict.
+    """Return ``(tool_parser_cls, reasoning_parser_cls)`` from an opts dict.
 
-    Uses vLLM's native ToolParserManager and ReasoningParserManager.
-    Returns (None, None) if vLLM is not installed or parsers not available.
+    Uses vLLM's native ``ToolParserManager`` / ``ReasoningParserManager``.
+    Returns ``(None, None)`` if vLLM isn't installed or the requested
+    parser name can't be resolved.
     """
     tool_parser_cls = None
     reasoning_parser_cls = None
backend/python/mlx-distributed/backend.py
@@ -15,17 +15,21 @@ Two startup modes:
 import asyncio
 from concurrent import futures
 import argparse
 import gc
 import json
 import os
 import signal
 import sys
 import tempfile
+import types
 from typing import List
 
 import grpc
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
 from grpc_auth import get_auth_interceptors
+from python_utils import messages_to_dicts, parse_options as _shared_parse_options
+from mlx_utils import parse_tool_calls, split_reasoning
 
 
 import backend_pb2
@@ -62,37 +66,10 @@ def mlx_distributed_init(rank, hostfile, backend="ring", coordinator=None):
         raise ValueError(f"Unknown backend: {backend}")
 
 
-def is_float(s):
-    try:
-        float(s)
-        return True
-    except ValueError:
-        return False
-
-
-def is_int(s):
-    try:
-        int(s)
-        return True
-    except ValueError:
-        return False
-
-
-def parse_options(options):
-    """Parse key:value option strings into a dict."""
-    result = {}
-    for opt in options:
-        if ":" not in opt:
-            continue
-        key, value = opt.split(":", 1)
-        if is_float(value):
-            value = float(value)
-        elif is_int(value):
-            value = int(value)
-        elif value.lower() in ["true", "false"]:
-            value = value.lower() == "true"
-        result[key] = value
-    return result
+# Re-export the shared helper under the local name for back-compat with
+# any callers (and the existing distributed worker tests) that imported
+# parse_options directly from this module.
+parse_options = _shared_parse_options
 
 
 class BackendServicer(backend_pb2_grpc.BackendServicer):
@@ -188,6 +165,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             )
             print("[Rank 0] Model loaded (single-node with prompt cache)", file=sys.stderr)
 
+            # Log auto-detected TokenizerWrapper capabilities. Same shape
+            # as the mlx backend: has_tool_calling / has_thinking from
+            # mlx_lm.tokenizer_utils + the start/end markers it sniffed
+            # from the chat template / vocab.
+            has_tools = bool(getattr(self.tokenizer, "has_tool_calling", False))
+            has_thinking = bool(getattr(self.tokenizer, "has_thinking", False))
+            tcs = getattr(self.tokenizer, "tool_call_start", None)
+            tce = getattr(self.tokenizer, "tool_call_end", None)
+            print(
+                f"[Rank 0] Tokenizer capabilities: has_tool_calling={has_tools} "
+                f"has_thinking={has_thinking} tool_call_start={tcs!r} tool_call_end={tce!r}",
+                file=sys.stderr,
+            )
+
         except Exception as err:
             print(f"[Rank 0] Error loading model: {err}", file=sys.stderr)
             return backend_pb2.Result(success=False, message=f"Error loading model: {err}")
@@ -201,7 +192,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         try:
             import mlx.core as mx
             from mlx_lm import stream_generate
-            from mlx_lm.sample_utils import make_sampler
+            from mlx_lm.sample_utils import make_logits_processors, make_sampler
 
             prompt_text = self._prepare_prompt(request)
             tokens = self._get_tokens_from_prompt(prompt_text)

@@ -211,7 +202,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
                 self.coordinator.broadcast_tokens(tokens)
 
-            max_tokens, sampler_params = self._build_generation_params(request)
+            max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request)
 
             if self.coordinator:
                 gen_params = self.coordinator.broadcast_generation_params(

@@ -222,6 +213,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 max_tokens = gen_params["max_tokens"]
 
             sampler = make_sampler(**sampler_params)
+            logits_processors = make_logits_processors(**logits_params) if logits_params else None
 
             # Use prompt cache in single-node mode
             gen_kwargs = {}

@@ -238,22 +230,44 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 tokens = remaining_tokens if remaining_tokens else cache_key
 
             generated = []
+            last_response = None
             for response in stream_generate(
                 self.model,
                 self.tokenizer,
                 prompt=tokens,
                 max_tokens=max_tokens,
                 sampler=sampler,
+                logits_processors=logits_processors,
                 **gen_kwargs,
             ):
                 generated.append(response.text)
+                last_response = response
                 if cache_key is not None:
                     cache_key.append(response.token)
+                if stop_words and any(s in "".join(generated) for s in stop_words):
+                    break
 
             if self.lru_cache is not None and cache_key is not None:
                 self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
 
-            return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8'))
+            full_text = self._truncate_at_stop("".join(generated), stop_words)
+            content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
+                self._finalize_output(request, full_text, last_response)
+            )
+
+            return backend_pb2.Reply(
+                message=bytes(content, encoding='utf-8'),
+                prompt_tokens=prompt_tokens,
+                tokens=completion_tokens,
+                logprobs=logprobs_bytes,
+                chat_deltas=[
+                    backend_pb2.ChatDelta(
+                        content=content,
+                        reasoning_content=reasoning_content,
+                        tool_calls=tool_calls_proto,
+                    )
+                ],
+            )
 
         except Exception as e:
             print(f"[Rank 0] Error in Predict: {e}", file=sys.stderr)
@@ -268,7 +282,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         try:
             import mlx.core as mx
             from mlx_lm import stream_generate
-            from mlx_lm.sample_utils import make_sampler
+            from mlx_lm.sample_utils import make_logits_processors, make_sampler
 
             prompt_text = self._prepare_prompt(request)
             tokens = self._get_tokens_from_prompt(prompt_text)

@@ -278,7 +292,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
                 self.coordinator.broadcast_tokens(tokens)
 
-            max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)
+            max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(
+                request, default_max_tokens=512
+            )
 
             if self.coordinator:
                 gen_params = self.coordinator.broadcast_generation_params(

@@ -289,6 +305,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 max_tokens = gen_params["max_tokens"]
 
             sampler = make_sampler(**sampler_params)
+            logits_processors = make_logits_processors(**logits_params) if logits_params else None
 
             # Use prompt cache in single-node mode
             gen_kwargs = {}

@@ -304,17 +321,45 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 gen_kwargs['prompt_cache'] = prompt_cache
                 tokens = remaining_tokens if remaining_tokens else cache_key
 
+            accumulated = []
+            last_response = None
             for response in stream_generate(
                 self.model,
                 self.tokenizer,
                 prompt=tokens,
                 max_tokens=max_tokens,
                 sampler=sampler,
+                logits_processors=logits_processors,
                 **gen_kwargs,
             ):
                 if cache_key is not None:
                     cache_key.append(response.token)
-                yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
+                accumulated.append(response.text)
+                last_response = response
+                yield backend_pb2.Reply(
+                    message=bytes(response.text, encoding='utf-8'),
+                    chat_deltas=[backend_pb2.ChatDelta(content=response.text)],
+                )
+                if stop_words and any(s in "".join(accumulated) for s in stop_words):
+                    break
 
+            full_text = self._truncate_at_stop("".join(accumulated), stop_words)
+            content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
+                self._finalize_output(request, full_text, last_response)
+            )
+            yield backend_pb2.Reply(
+                message=b"",
+                prompt_tokens=prompt_tokens,
+                tokens=completion_tokens,
+                logprobs=logprobs_bytes,
+                chat_deltas=[
+                    backend_pb2.ChatDelta(
+                        content="",
+                        reasoning_content=reasoning_content,
+                        tool_calls=tool_calls_proto,
+                    )
+                ],
+            )
 
         except Exception as e:
             print(f"[Rank 0] Error in PredictStream: {e}", file=sys.stderr)
@@ -335,12 +380,74 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
        context.set_details("Embeddings are not supported in the MLX distributed backend.")
        return backend_pb2.EmbeddingResult()

    async def TokenizeString(self, request, context):
        if not hasattr(self, "tokenizer") or self.tokenizer is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("tokenizer not loaded")
            return backend_pb2.TokenizationResponse()
        try:
            tokens = self.tokenizer.encode(request.Prompt)
            if hasattr(tokens, "tolist"):
                tokens = tokens.tolist()
            tokens = list(tokens)
            return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
        except Exception as e:
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(e))
            return backend_pb2.TokenizationResponse()

    async def Free(self, request, context):
        try:
            # If we're rank 0 of a distributed run, tell workers to shut
            # down their per-request loops first so they release the model.
            if self.coordinator is not None:
                try:
                    from coordinator import CMD_SHUTDOWN
                    self.coordinator.broadcast_command(CMD_SHUTDOWN)
                except Exception as e:
                    print(f"[Rank 0] failed to broadcast shutdown: {e}", file=sys.stderr)
            if hasattr(self, "model"):
                del self.model
            if hasattr(self, "tokenizer"):
                del self.tokenizer
            if self.lru_cache is not None:
                try:
                    self.lru_cache.clear()
                except Exception:
                    pass
            self.lru_cache = None
            self.coordinator = None
            self.group = None
            gc.collect()
            try:
                import mlx.core as mx  # type: ignore
                if hasattr(mx, "clear_cache"):
                    mx.clear_cache()
                elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
                    mx.metal.clear_cache()
            except Exception:
                pass
            return backend_pb2.Result(success=True, message="MLX distributed model freed")
        except Exception as e:
            return backend_pb2.Result(success=False, message=str(e))

    def _prepare_prompt(self, request):
        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages]
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            messages = messages_to_dicts(request.Messages)
            kwargs = {"tokenize": False, "add_generation_prompt": True}
            if request.Tools:
                try:
                    kwargs["tools"] = json.loads(request.Tools)
                except json.JSONDecodeError:
                    pass
            if request.Metadata.get("enable_thinking", "").lower() == "true":
                kwargs["enable_thinking"] = True
            try:
                return self.tokenizer.apply_chat_template(messages, **kwargs)
            except TypeError:
                return self.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
        return request.Prompt

    def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
@@ -349,6 +456,82 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
            return tokens.tolist()
        return list(tokens)

    def _tool_module_from_tokenizer(self):
        """Same shim as the mlx backend: fall back to json.loads when the
        installed mlx-lm doesn't expose a tool_parser callable on the
        wrapper (true on 0.29.x — only HEAD ships parsers)."""
        start = getattr(self.tokenizer, "tool_call_start", None)
        end = getattr(self.tokenizer, "tool_call_end", None)
        if not start:
            return None
        parse_fn = getattr(self.tokenizer, "tool_parser", None)
        if parse_fn is None:
            def parse_fn(body, tools):  # noqa: E306
                return json.loads(body.strip())
        return types.SimpleNamespace(
            tool_call_start=start,
            tool_call_end=end or "",
            parse_tool_call=parse_fn,
        )

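The json.loads fallback path can be exercised without mlx-lm installed. This is a standalone sketch, not backend code: `tool_module_from` is an illustrative name, and the bare namespace stands in for mlx-lm's TokenizerWrapper.

```python
import json
import types

def tool_module_from(tokenizer):
    # Mirrors _tool_module_from_tokenizer: use the wrapper's own parser
    # when present, otherwise fall back to json.loads on the call body.
    start = getattr(tokenizer, "tool_call_start", None)
    end = getattr(tokenizer, "tool_call_end", None)
    if not start:
        return None
    parse_fn = getattr(tokenizer, "tool_parser", None)
    if parse_fn is None:
        def parse_fn(body, tools):
            return json.loads(body.strip())
    return types.SimpleNamespace(
        tool_call_start=start, tool_call_end=end or "", parse_tool_call=parse_fn
    )

# A 0.29.x-style wrapper: markers present, but no tool_parser attribute.
tok = types.SimpleNamespace(tool_call_start="<tool_call>", tool_call_end="</tool_call>")
tm = tool_module_from(tok)
call = tm.parse_tool_call(' {"name": "f", "arguments": {}} ', None)
print(call["name"])  # -> f
```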
    def _truncate_at_stop(self, text, stop_words):
        if not stop_words:
            return text
        earliest = len(text)
        for stop in stop_words:
            if not stop:
                continue
            idx = text.find(stop)
            if idx >= 0 and idx < earliest:
                earliest = idx
        return text[:earliest] if earliest < len(text) else text

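Because streaming may overrun a stop word by a few tokens before the break fires, the accumulated text is cut at the earliest match. A server-less copy of the same logic:

```python
def truncate_at_stop(text, stop_words):
    # Cut at the earliest occurrence of any stop word; an empty stop
    # list leaves the text untouched (same logic as _truncate_at_stop).
    if not stop_words:
        return text
    earliest = len(text)
    for stop in stop_words:
        if not stop:
            continue
        idx = text.find(stop)
        if 0 <= idx < earliest:
            earliest = idx
    return text[:earliest]

print(truncate_at_stop("Hello</s> trailing tokens", ["</s>", "<|end|>"]))  # -> Hello
```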
    def _finalize_output(self, request, generated_text, last_response):
        content = generated_text
        reasoning_content = ""
        if getattr(self.tokenizer, "has_thinking", False):
            think_start = getattr(self.tokenizer, "think_start", "") or ""
            think_end = getattr(self.tokenizer, "think_end", "") or ""
            reasoning_content, content = split_reasoning(content, think_start, think_end)

        tool_calls_proto: List[backend_pb2.ToolCallDelta] = []
        tool_module = None
        if getattr(self.tokenizer, "has_tool_calling", False):
            tool_module = self._tool_module_from_tokenizer()
        if tool_module is not None:
            parsed_tools = None
            if request.Tools:
                try:
                    parsed_tools = json.loads(request.Tools)
                except json.JSONDecodeError:
                    parsed_tools = None
            calls, content = parse_tool_calls(content, tool_module, parsed_tools)
            for c in calls:
                tool_calls_proto.append(
                    backend_pb2.ToolCallDelta(
                        index=c["index"], id=c["id"], name=c["name"], arguments=c["arguments"],
                    )
                )

        prompt_token_count = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0
        completion_token_count = int(getattr(last_response, "generation_tokens", 0) or 0) if last_response else 0

        logprobs_bytes = b""
        if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0:
            try:
                lp = getattr(last_response, "logprobs", None)
                if lp is not None:
                    token_id = int(getattr(last_response, "token", 0) or 0)
                    token_text = self.tokenizer.decode([token_id]) if token_id else ""
                    top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0
                    logprobs_bytes = json.dumps(
                        {"content": [{"token": token_text, "logprob": top_logprob}]}
                    ).encode("utf-8")
            except Exception as e:
                print(f"[Rank 0] Logprobs extraction failed: {e}", file=sys.stderr)

        return content, reasoning_content, tool_calls_proto, prompt_token_count, completion_token_count, logprobs_bytes

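`_finalize_output` delegates `<think>` extraction to the shared `split_reasoning` helper. A minimal re-implementation consistent with this PR's smoke tests, under the assumption of at most one reasoning block per generation:

```python
def split_reasoning(text, think_start, think_end):
    # Assumed behavior of mlx_utils.split_reasoning: pull the first
    # think-block out as reasoning, return (reasoning, cleaned content).
    if not think_start or think_start not in text:
        return "", text
    before, _, rest = text.partition(think_start)
    reasoning, _, after = rest.partition(think_end)
    return reasoning.strip(), (before + after).strip()

r, c = split_reasoning("<think>plan</think>final", "<think>", "</think>")
print(r, c)  # -> plan final
```

Text without markers passes through untouched, so models whose tokenizer does not advertise `has_thinking` are unaffected.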
    def _build_generation_params(self, request, default_max_tokens=200):
        import mlx.core as mx

@@ -373,6 +556,22 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
            'xtc_probability': 0.0,
        }

        # Logits processor parameters — pulled from the request and
        # forwarded to make_logits_processors. Rank 0 is the only rank
        # running the sampler so we don't need to broadcast these to
        # workers (workers participate in the pipeline-parallel forward
        # pass only).
        logits_params = {}
        repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0
        if repetition_penalty and repetition_penalty != 1.0:
            logits_params['repetition_penalty'] = repetition_penalty
        presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0
        if presence_penalty:
            logits_params['presence_penalty'] = presence_penalty
        frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0
        if frequency_penalty:
            logits_params['frequency_penalty'] = frequency_penalty

        seed = getattr(request, 'Seed', 0)
        if seed != 0:
            mx.random.seed(seed)
@@ -392,9 +591,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
        for opt_key, param_key in option_mapping.items():
            if opt_key in self.options:
                sampler_params[param_key] = self.options[opt_key]
        for opt_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'):
            if opt_key in self.options:
                logits_params[opt_key] = self.options[opt_key]
        if 'seed' in self.options:
            mx.random.seed(self.options['seed'])

        stop_words = list(getattr(request, 'StopPrompts', []) or [])
        return max_tokens, sampler_params, logits_params, stop_words

        # XTC special tokens
        xtc_special_tokens = []
        if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:

@@ -1,3 +1,6 @@
import os
import sys
import types
import unittest
import subprocess
import time
@@ -6,6 +9,12 @@ import grpc
import backend_pb2
import backend_pb2_grpc

# Make the shared helpers importable so we can unit-test them without a
# running gRPC server.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
from python_utils import messages_to_dicts, parse_options
from mlx_utils import parse_tool_calls, split_reasoning


class TestBackendServicer(unittest.TestCase):
    def setUp(self):
@@ -85,3 +94,44 @@ class TestBackendServicer(unittest.TestCase):
            self.fail("sampling params service failed")
        finally:
            self.tearDown()


class TestSharedHelpers(unittest.TestCase):
    """Server-less unit tests for the helpers the mlx-distributed backend depends on."""

    def test_parse_options_typed(self):
        opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true"])
        self.assertEqual(opts["temperature"], 0.7)
        self.assertEqual(opts["max_tokens"], 128)
        self.assertIs(opts["trust"], True)

    def test_messages_to_dicts_roundtrip(self):
        msgs = [
            backend_pb2.Message(role="user", content="hi"),
            backend_pb2.Message(
                role="assistant",
                content="",
                tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]',
            ),
            backend_pb2.Message(role="tool", content="42", tool_call_id="call_1", name="f"),
        ]
        out = messages_to_dicts(msgs)
        self.assertEqual(out[0], {"role": "user", "content": "hi"})
        self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f")
        self.assertEqual(out[2]["tool_call_id"], "call_1")

    def test_split_reasoning(self):
        r, c = split_reasoning("<think>plan</think>final", "<think>", "</think>")
        self.assertEqual(r, "plan")
        self.assertEqual(c, "final")

    def test_parse_tool_calls_with_shim(self):
        tm = types.SimpleNamespace(
            tool_call_start="<tool_call>",
            tool_call_end="</tool_call>",
            parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}},
        )
        calls, remaining = parse_tool_calls("<tool_call>Paris</tool_call>", tm, tools=None)
        self.assertEqual(len(calls), 1)
        self.assertEqual(calls[0]["name"], "get_weather")
        self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}')

@@ -2,11 +2,14 @@
import asyncio
from concurrent import futures
import argparse
import gc
import json
import signal
import sys
import os
import tempfile
import types
from typing import List
import time

import backend_pb2
import backend_pb2_grpc
@@ -15,30 +18,18 @@ import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from python_utils import messages_to_dicts, parse_options
from mlx_utils import parse_tool_calls, split_reasoning

from mlx_vlm import load, generate, stream_generate
from mlx_vlm import load, stream_generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config, load_image
from mlx_vlm.tool_parsers import _infer_tool_parser, load_tool_module
from mlx_vlm.utils import load_config
from mlx_lm.sample_utils import make_logits_processors, make_sampler
import mlx.core as mx
import base64
import io
from PIL import Image
import tempfile

def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False
def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

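The is_float/is_int pair above is retired in favor of the shared parse_options helper. The exact implementation lives in python_utils.py; this sketch only matches the typed coercion the PR's tests assert (split on the first colon, then coerce to bool/int/float where possible):

```python
def parse_options(options):
    # Illustrative re-implementation: each entry is "optname:optvalue";
    # values become bool, int, or float when they parse as one,
    # otherwise stay strings.
    parsed = {}
    for opt in options:
        if ":" not in opt:
            continue
        key, value = opt.split(":", 1)  # first colon only, values may contain colons
        if value.lower() in ("true", "false"):
            parsed[key] = value.lower() == "true"
            continue
        for cast in (int, float):
            try:
                parsed[key] = cast(value)
                break
            except ValueError:
                continue
        else:
            parsed[key] = value
    return parsed

print(parse_options(["temperature:0.7", "max_tokens:128", "trust:true"]))
# -> {'temperature': 0.7, 'max_tokens': 128, 'trust': True}
```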
@@ -78,36 +69,52 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
        try:
            print(f"Loading MLX-VLM model: {request.Model}", file=sys.stderr)
            print(f"Request: {request}", file=sys.stderr)

            # Parse options like in the diffusers backend
            options = request.Options
            self.options = {}

            # The options are a list of strings in this form optname:optvalue
            # We store all the options in a dict for later use
            for opt in options:
                if ":" not in opt:
                    continue
                key, value = opt.split(":", 1)  # Split only on first colon to handle values with colons

                if is_float(value):
                    value = float(value)
                elif is_int(value):
                    value = int(value)
                elif value.lower() in ["true", "false"]:
                    value = value.lower() == "true"

                self.options[key] = value

            # Parse Options[] key:value strings into a typed dict
            self.options = parse_options(request.Options)
            print(f"Options: {self.options}", file=sys.stderr)

            # Load model and processor using MLX-VLM
            # mlx-vlm load function returns (model, processor) instead of (model, tokenizer)
            self.model, self.processor = load(request.Model)

            # Load model config for chat template support
            self.config = load_config(request.Model)

            # Auto-infer the tool parser from the chat template. mlx-vlm has
            # its own _infer_tool_parser that falls back to mlx-lm parsers.
            tokenizer = (
                self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor
            )
            self.tool_module = None
            if hasattr(tokenizer, "chat_template"):
                try:
                    parser_type = _infer_tool_parser(tokenizer.chat_template)
                    if parser_type is not None:
                        self.tool_module = load_tool_module(parser_type)
                        print(
                            f"[mlx-vlm] auto-detected tool parser: {parser_type}",
                            file=sys.stderr,
                        )
                    else:
                        print(
                            "[mlx-vlm] no tool parser matched the chat template",
                            file=sys.stderr,
                        )
                except Exception as e:
                    print(
                        f"[mlx-vlm] failed to load tool parser: {e}",
                        file=sys.stderr,
                    )

            # Reasoning tokens — check if the tokenizer advertises thinking
            # markers. Fall back to empty strings (split_reasoning no-ops).
            self.think_start = getattr(tokenizer, "think_start", "") or ""
            self.think_end = getattr(tokenizer, "think_end", "") or ""
            self.has_thinking = bool(
                getattr(tokenizer, "has_thinking", False) or self.think_start
            )

        except Exception as err:
            print(f"Error loading MLX-VLM model {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Error loading MLX-VLM model: {err}")
@@ -128,63 +135,72 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
        """
        temp_files = []
        try:
            # Process images and audios from request
            image_paths = []
            audio_paths = []

            # Process images
            if request.Images:
                for img_data in request.Images:
                    img_path = self.load_image_from_base64(img_data)
                    if img_path:
                        image_paths.append(img_path)
                        temp_files.append(img_path)

            # Process audios
            if request.Audios:
                for audio_data in request.Audios:
                    audio_path = self.load_audio_from_base64(audio_data)
                    if audio_path:
                        audio_paths.append(audio_path)
                        temp_files.append(audio_path)

            # Prepare the prompt with multimodal information
            prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths))

            # Build generation parameters using request attributes and options
            max_tokens, generation_params = self._build_generation_params(request)

            print(f"Generating text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
            print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)

            # Generate text using MLX-VLM with multimodal inputs
            response = generate(
            image_paths, audio_paths = self._collect_media(request, temp_files)

            prompt = self._prepare_prompt(
                request,
                num_images=len(image_paths),
                num_audios=len(audio_paths),
            )

            max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request)
            sampler = make_sampler(**sampler_params)
            logits_processors = make_logits_processors(**logits_params) if logits_params else None

            print(
                f"Generating text with MLX-VLM - max_tokens: {max_tokens}, "
                f"images: {len(image_paths)}, audios: {len(audio_paths)}",
                file=sys.stderr,
            )

            accumulated = []
            last_response = None
            for response in stream_generate(
                model=self.model,
                processor=self.processor,
                prompt=prompt,
                image=image_paths if image_paths else None,
                audio=audio_paths if audio_paths else None,
                max_tokens=max_tokens,
                temperature=generation_params.get('temp', 0.6),
                top_p=generation_params.get('top_p', 1.0),
                verbose=False
                sampler=sampler,
                logits_processors=logits_processors,
            ):
                accumulated.append(response.text)
                last_response = response
                if stop_words and any(s in "".join(accumulated) for s in stop_words):
                    break

            full_text = self._truncate_at_stop("".join(accumulated), stop_words)
            content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
                self._finalize_output(request, full_text, last_response)
            )

            return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))

            return backend_pb2.Reply(
                message=bytes(content, encoding='utf-8'),
                prompt_tokens=prompt_tokens,
                tokens=completion_tokens,
                logprobs=logprobs_bytes,
                chat_deltas=[
                    backend_pb2.ChatDelta(
                        content=content,
                        reasoning_content=reasoning_content,
                        tool_calls=tool_calls_proto,
                    )
                ],
            )

        except Exception as e:
            print(f"Error in MLX-VLM Predict: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Generation failed: {str(e)}")
            return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
        finally:
            # Clean up temporary files
            self.cleanup_temp_files(temp_files)

    def Embedding(self, request, context):
        """
        A gRPC method that calculates embeddings for a given sentence.

        Note: MLX-VLM doesn't support embeddings directly. This method returns an error.

        Args:
@@ -199,6 +215,79 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
        context.set_details("Embeddings are not supported in the MLX-VLM backend.")
        return backend_pb2.EmbeddingResult()

    def _collect_media(self, request, temp_files):
        """Decode base64 Images and Audios into temp file paths.

        Appends every temp file to ``temp_files`` so the finally block can
        clean up even on mid-generation errors.
        """
        image_paths = []
        audio_paths = []
        if request.Images:
            for img_data in request.Images:
                img_path = self.load_image_from_base64(img_data)
                if img_path:
                    image_paths.append(img_path)
                    temp_files.append(img_path)
        if request.Audios:
            for audio_data in request.Audios:
                audio_path = self.load_audio_from_base64(audio_data)
                if audio_path:
                    audio_paths.append(audio_path)
                    temp_files.append(audio_path)
        return image_paths, audio_paths

    async def TokenizeString(self, request, context):
        """Tokenize ``request.Prompt`` via the processor's tokenizer."""
        if not hasattr(self, "processor") or self.processor is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("processor not loaded")
            return backend_pb2.TokenizationResponse()
        try:
            tokenizer = (
                self.processor.tokenizer
                if hasattr(self.processor, "tokenizer")
                else self.processor
            )
            tokens = tokenizer.encode(request.Prompt)
            if hasattr(tokens, "tolist"):
                tokens = tokens.tolist()
            tokens = list(tokens)
            return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
        except Exception as e:
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(e))
            return backend_pb2.TokenizationResponse()

    async def Free(self, request, context):
        """Drop the loaded model, processor and tool module."""
        try:
            if hasattr(self, "model"):
                del self.model
            if hasattr(self, "processor"):
                del self.processor
            if hasattr(self, "config"):
                del self.config
            self.tool_module = None
            gc.collect()
            # mlx.clear_cache (mlx >= 0.30) supersedes mlx.metal.clear_cache.
            try:
                if hasattr(mx, "clear_cache"):
                    mx.clear_cache()
                elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
                    mx.metal.clear_cache()
            except Exception:
                pass
            try:
                import torch  # type: ignore
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception:
                pass
            return backend_pb2.Result(success=True, message="MLX-VLM model freed")
        except Exception as e:
            return backend_pb2.Result(success=False, message=str(e))

    async def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results using MLX-VLM with multimodal support.
@@ -212,36 +301,28 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
        """
        temp_files = []
        try:
            # Process images and audios from request
            image_paths = []
            audio_paths = []

            # Process images
            if request.Images:
                for img_data in request.Images:
                    img_path = self.load_image_from_base64(img_data)
                    if img_path:
                        image_paths.append(img_path)
                        temp_files.append(img_path)

            # Process audios
            if request.Audios:
                for audio_data in request.Audios:
                    audio_path = self.load_audio_from_base64(audio_data)
                    if audio_path:
                        audio_paths.append(audio_path)
                        temp_files.append(audio_path)

            # Prepare the prompt with multimodal information
            prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths))

            # Build generation parameters using request attributes and options
            max_tokens, generation_params = self._build_generation_params(request, default_max_tokens=512)

            print(f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
            print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)

            # Stream text generation using MLX-VLM with multimodal inputs
            image_paths, audio_paths = self._collect_media(request, temp_files)

            prompt = self._prepare_prompt(
                request,
                num_images=len(image_paths),
                num_audios=len(audio_paths),
            )

            max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(
                request, default_max_tokens=512
            )
            sampler = make_sampler(**sampler_params)
            logits_processors = make_logits_processors(**logits_params) if logits_params else None

            print(
                f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, "
                f"images: {len(image_paths)}, audios: {len(audio_paths)}",
                file=sys.stderr,
            )

            accumulated = []
            last_response = None
            for response in stream_generate(
                model=self.model,
                processor=self.processor,
@@ -249,77 +330,91 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                image=image_paths if image_paths else None,
                audio=audio_paths if audio_paths else None,
                max_tokens=max_tokens,
                temperature=generation_params.get('temp', 0.6),
                top_p=generation_params.get('top_p', 1.0),
                sampler=sampler,
                logits_processors=logits_processors,
            ):
                yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))

                accumulated.append(response.text)
                last_response = response
                yield backend_pb2.Reply(
                    message=bytes(response.text, encoding='utf-8'),
                    chat_deltas=[backend_pb2.ChatDelta(content=response.text)],
                )
                if stop_words and any(s in "".join(accumulated) for s in stop_words):
                    break

            full_text = self._truncate_at_stop("".join(accumulated), stop_words)
            content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
                self._finalize_output(request, full_text, last_response)
            )
            yield backend_pb2.Reply(
                message=b"",
                prompt_tokens=prompt_tokens,
                tokens=completion_tokens,
                logprobs=logprobs_bytes,
                chat_deltas=[
                    backend_pb2.ChatDelta(
                        content="",
                        reasoning_content=reasoning_content,
                        tool_calls=tool_calls_proto,
                    )
                ],
            )

        except Exception as e:
            print(f"Error in MLX-VLM PredictStream: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Streaming generation failed: {str(e)}")
            yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
        finally:
            # Clean up temporary files
            self.cleanup_temp_files(temp_files)

    def _build_template_kwargs(self, request, num_images, num_audios):
        """Collect kwargs for ``apply_chat_template`` that survive model variants."""
        kwargs = {"num_images": num_images, "num_audios": num_audios}
        if request.Tools:
            try:
                kwargs["tools"] = json.loads(request.Tools)
            except json.JSONDecodeError:
                pass
        if request.Metadata.get("enable_thinking", "").lower() == "true":
            kwargs["enable_thinking"] = True
        return kwargs

    def _apply_template(self, request, messages, num_images, num_audios):
        kwargs = self._build_template_kwargs(request, num_images, num_audios)
        try:
            return apply_chat_template(self.processor, self.config, messages, **kwargs)
        except TypeError:
            # Fallback for older mlx-vlm versions that reject tools=/enable_thinking=
            return apply_chat_template(
                self.processor,
                self.config,
                messages,
                num_images=num_images,
                num_audios=num_audios,
            )

    def _prepare_prompt(self, request, num_images=0, num_audios=0):
        """
        Prepare the prompt for MLX-VLM generation, handling chat templates and multimodal inputs.

        Args:
            request: The gRPC request containing prompt and message information.
            num_images: Number of images in the request.
            num_audios: Number of audio files in the request.

        Returns:
            str: The prepared prompt.
        Prepare the prompt for MLX-VLM generation, handling chat templates and
        multimodal inputs. Forwards tool definitions and enable_thinking when
        present on the request.
        """
        # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            # Convert gRPC messages to the format expected by apply_chat_template
            messages = []
            for msg in request.Messages:
                messages.append({"role": msg.role, "content": msg.content})

            # Use mlx-vlm's apply_chat_template which handles multimodal inputs
            prompt = apply_chat_template(
                self.processor,
                self.config,
                messages,
                num_images=num_images,
                num_audios=num_audios
            )
            return prompt
        elif request.Prompt:
            # If we have a direct prompt but also have images/audio, we need to format it properly
            messages = messages_to_dicts(request.Messages)
            return self._apply_template(request, messages, num_images, num_audios)

        if request.Prompt:
            if num_images > 0 or num_audios > 0:
                # Create a simple message structure for multimodal prompt
                messages = [{"role": "user", "content": request.Prompt}]
                prompt = apply_chat_template(
                    self.processor,
                    self.config,
                    messages,
                    num_images=num_images,
                    num_audios=num_audios
                )
                return prompt
            else:
                return request.Prompt
        else:
            # Fallback to empty prompt with multimodal template if we have media
            if num_images > 0 or num_audios > 0:
                messages = [{"role": "user", "content": ""}]
                prompt = apply_chat_template(
                    self.processor,
                    self.config,
                    messages,
                    num_images=num_images,
                    num_audios=num_audios
                )
                return prompt
            else:
                return ""
                return self._apply_template(request, messages, num_images, num_audios)
            return request.Prompt

        # Fallback to empty prompt with multimodal template if we have media
        if num_images > 0 or num_audios > 0:
            messages = [{"role": "user", "content": ""}]
            return self._apply_template(request, messages, num_images, num_audios)
        return ""

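The try/except TypeError pattern in _apply_template degrades gracefully across library versions: pass the richer kwargs first, and retry with only the universally supported ones when the installed function rejects them. In isolation, with a stand-in template function (all names here are illustrative, not mlx-vlm APIs):

```python
def apply_template_v1(messages, num_images=0, num_audios=0):
    # Stand-in for an older apply_chat_template that rejects tools=.
    return f"v1:{len(messages)}m/{num_images}i/{num_audios}a"

def render(messages, **kwargs):
    # Try the full kwargs set first; on TypeError retry with the minimal
    # signature, mirroring _apply_template's fallback for older releases.
    try:
        return apply_template_v1(messages, **kwargs)
    except TypeError:
        return apply_template_v1(
            messages,
            num_images=kwargs.get("num_images", 0),
            num_audios=kwargs.get("num_audios", 0),
        )

out = render([{"role": "user", "content": "hi"}], num_images=1, tools=[{"type": "function"}])
print(out)  # -> v1:1m/1i/0a
```

The tool definitions are silently dropped on old versions, which matches the backend's best-effort stance: generation still works, only structured tool calling is unavailable.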
@@ -327,62 +422,122 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):

    def _build_generation_params(self, request, default_max_tokens=200):
        """
        Build generation parameters from request attributes and options for MLX-VLM.

        Args:
            request: The gRPC request.
            default_max_tokens: Default max_tokens if not specified.
        Build generation parameters from request attributes and options.

        Returns:
            tuple: (max_tokens, generation_params dict)
            tuple: (max_tokens, sampler_params, logits_params, stop_words)
        """
        # Extract max_tokens
        max_tokens = getattr(request, 'Tokens', default_max_tokens)
        if max_tokens == 0:
            max_tokens = default_max_tokens

        # Extract generation parameters from request attributes
        temp = getattr(request, 'Temperature', 0.0)
        if temp == 0.0:
            temp = 0.6  # Default temperature

        top_p = getattr(request, 'TopP', 0.0)
        if top_p == 0.0:
            top_p = 1.0  # Default top_p

        # Initialize generation parameters for MLX-VLM
        generation_params = {
        max_tokens = getattr(request, 'Tokens', default_max_tokens) or default_max_tokens

        temp = getattr(request, 'Temperature', 0.0) or 0.6
        top_p = getattr(request, 'TopP', 0.0) or 1.0
        min_p = getattr(request, 'MinP', 0.0) or 0.0
        top_k = getattr(request, 'TopK', 0) or 0

        sampler_params = {
            'temp': temp,
            'top_p': top_p,
            'min_p': min_p,
            'top_k': top_k,
        }

        # Add seed if specified

        logits_params = {}
        repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0
        if repetition_penalty and repetition_penalty != 1.0:
            logits_params['repetition_penalty'] = repetition_penalty
        presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0
        if presence_penalty:
            logits_params['presence_penalty'] = presence_penalty
        frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0
        if frequency_penalty:
            logits_params['frequency_penalty'] = frequency_penalty

        seed = getattr(request, 'Seed', 0)
        if seed != 0:
            mx.random.seed(seed)

        # Override with options if available
        if hasattr(self, 'options'):
            # Max tokens from options
            if 'max_tokens' in self.options:
                max_tokens = self.options['max_tokens']

            # Generation parameters from options
            param_option_mapping = {
                'temp': 'temp',
                'temperature': 'temp',  # alias
                'top_p': 'top_p',
            option_mapping = {
                'temp': 'temp', 'temperature': 'temp',
                'top_p': 'top_p', 'min_p': 'min_p', 'top_k': 'top_k',
            }

            for option_key, param_key in param_option_mapping.items():
            for option_key, param_key in option_mapping.items():
                if option_key in self.options:
                    generation_params[param_key] = self.options[option_key]

            # Handle seed from options
                    sampler_params[param_key] = self.options[option_key]
            for option_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'):
                if option_key in self.options:
                    logits_params[option_key] = self.options[option_key]
            if 'seed' in self.options:
                mx.random.seed(self.options['seed'])

        return max_tokens, generation_params

        stop_words = list(getattr(request, 'StopPrompts', []) or [])
        return max_tokens, sampler_params, logits_params, stop_words

    def _finalize_output(self, request, generated_text, last_response):
        """Split reasoning + tool calls out of generated_text and return the
        tuple consumed by Reply-builders."""
        content = generated_text
        reasoning_content = ""

        if getattr(self, "has_thinking", False):
            reasoning_content, content = split_reasoning(content, self.think_start, self.think_end)

        tool_calls_proto: List[backend_pb2.ToolCallDelta] = []
        if self.tool_module is not None:
            parsed_tools = None
            if request.Tools:
                try:
                    parsed_tools = json.loads(request.Tools)
                except json.JSONDecodeError:
                    parsed_tools = None
            calls, content = parse_tool_calls(content, self.tool_module, parsed_tools)
            for c in calls:
                tool_calls_proto.append(
                    backend_pb2.ToolCallDelta(
                        index=c["index"],
                        id=c["id"],
                        name=c["name"],
                        arguments=c["arguments"],
                    )
                )

        prompt_tokens = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0
        completion_tokens = int(getattr(last_response, "generation_tokens", 0) or 0) if last_response else 0

        logprobs_bytes = b""
        if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0:
            try:
                lp = getattr(last_response, "logprobs", None)
                if lp is not None:
                    token_id = int(getattr(last_response, "token", 0) or 0)
                    tokenizer = (
                        self.processor.tokenizer
                        if hasattr(self.processor, "tokenizer")
                        else self.processor
                    )
                    token_text = tokenizer.decode([token_id]) if token_id else ""
                    top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0
                    logprobs_bytes = json.dumps(
                        {"content": [{"token": token_text, "logprob": top_logprob}]}
                    ).encode("utf-8")
            except Exception as e:
                print(f"[mlx-vlm] Logprobs extraction failed: {e}", file=sys.stderr)

        return content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes

    def _truncate_at_stop(self, text, stop_words):
        if not stop_words:
            return text
        earliest = len(text)
        for stop in stop_words:
            if not stop:
                continue
            idx = text.find(stop)
            if idx >= 0 and idx < earliest:
                earliest = idx
        return text[:earliest] if earliest < len(text) else text

    def load_image_from_base64(self, image_data: str):
        """
@@ -1,17 +1,19 @@
import os
import sys
import types
import unittest
import subprocess
import time

import grpc
import backend_pb2
import backend_pb2_grpc

import grpc

import unittest
import subprocess
import time
import grpc
import backend_pb2_grpc
import backend_pb2
# Make the shared helpers importable so we can unit-test them without a
# running gRPC server.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
from python_utils import messages_to_dicts, parse_options
from mlx_utils import parse_tool_calls, split_reasoning

class TestBackendServicer(unittest.TestCase):
    """
@@ -143,4 +145,55 @@ class TestBackendServicer(unittest.TestCase):
                print(err)
                self.fail("Embedding service failed")
            finally:
                self.tearDown()
        self.tearDown()


class TestSharedHelpers(unittest.TestCase):
    """Server-less unit tests for the helpers the mlx-vlm backend depends on."""

    def test_parse_options_typed(self):
        opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true", "name:hello"])
        self.assertEqual(opts["temperature"], 0.7)
        self.assertEqual(opts["max_tokens"], 128)
        self.assertIs(opts["trust"], True)
        self.assertEqual(opts["name"], "hello")

    def test_messages_to_dicts_roundtrip(self):
        msgs = [
            backend_pb2.Message(role="user", content="hi"),
            backend_pb2.Message(
                role="assistant",
                content="",
                tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]',
            ),
            backend_pb2.Message(
                role="tool",
                content="42",
                tool_call_id="call_1",
                name="f",
            ),
        ]
        out = messages_to_dicts(msgs)
        self.assertEqual(out[0], {"role": "user", "content": "hi"})
        self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f")
        self.assertEqual(out[2]["tool_call_id"], "call_1")

    def test_split_reasoning(self):
        r, c = split_reasoning("<think>plan</think>final", "<think>", "</think>")
        self.assertEqual(r, "plan")
        self.assertEqual(c, "final")

    def test_parse_tool_calls_with_shim(self):
        tm = types.SimpleNamespace(
            tool_call_start="<tool_call>",
            tool_call_end="</tool_call>",
            parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}},
        )
        calls, remaining = parse_tool_calls(
            "<tool_call>Paris</tool_call>",
            tm,
            tools=None,
        )
        self.assertEqual(len(calls), 1)
        self.assertEqual(calls[0]["name"], "get_weather")
        self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}')
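For reference, the `split_reasoning` contract that `test_split_reasoning` pins down can be sketched as follows — an approximation consistent with the test expectations, not the actual `mlx_utils` source:

```python
def split_reasoning(text, think_start, think_end):
    # Pull the first think_start...think_end block out as reasoning;
    # everything else is the visible content. Sketch only: the real
    # mlx_utils helper may differ in edge-case handling.
    start = text.find(think_start)
    end = text.find(think_end)
    if start == -1 or end == -1 or end < start:
        return "", text
    reasoning = text[start + len(think_start):end].strip()
    content = (text[:start] + text[end + len(think_end):]).strip()
    return reasoning, content

print(split_reasoning("<think>plan</think>final", "<think>", "</think>"))  # ('plan', 'final')
```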
@@ -2,11 +2,13 @@
import asyncio
from concurrent import futures
import argparse
import gc
import json
import signal
import sys
import os
import types
from typing import List
import time

import backend_pb2
import backend_pb2_grpc
@@ -15,13 +17,13 @@ import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from python_utils import messages_to_dicts, parse_options
from mlx_utils import parse_tool_calls, split_reasoning

from mlx_lm import load, generate, stream_generate
from mlx_lm.sample_utils import make_sampler
from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
import mlx.core as mx
import base64
import io

from mlx_cache import ThreadSafeLRUPromptCache

@@ -30,21 +32,6 @@ _ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS is specified in the environment, use it; otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False

def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
@@ -78,46 +65,27 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
        try:
            print(f"Loading MLX model: {request.Model}", file=sys.stderr)
            print(f"Request: {request}", file=sys.stderr)

            # Parse options like in the diffusers backend
            options = request.Options
            self.options = {}

            # The options are a list of strings in this form optname:optvalue
            # We store all the options in a dict for later use
            for opt in options:
                if ":" not in opt:
                    continue
                key, value = opt.split(":", 1)  # Split only on first colon to handle values with colons

                # Convert numeric values to appropriate types
                if is_float(value):
                    value = float(value)
                elif is_int(value):
                    value = int(value)
                elif value.lower() in ["true", "false"]:
                    value = value.lower() == "true"

                self.options[key] = value

            # Parse Options[] key:value strings into a typed dict (shared helper)
            self.options = parse_options(request.Options)
            print(f"Options: {self.options}", file=sys.stderr)
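The shared `parse_options` helper replaces the inline loop above. A minimal sketch of the typed parsing it performs — the int-before-float ordering here is an assumption for illustration, not taken from the helper's source:

```python
def parse_options(options):
    # Parse ["key:value", ...] strings into a dict with typed values
    # (int, float, bool, or str). Sketch of the shared helper's contract.
    parsed = {}
    for opt in options:
        if ":" not in opt:
            continue
        key, value = opt.split(":", 1)  # split on the first colon only
        try:
            value = int(value)
        except ValueError:
            try:
                value = float(value)
            except ValueError:
                if value.lower() in ("true", "false"):
                    value = value.lower() == "true"
        parsed[key] = value
    return parsed

print(parse_options(["temperature:0.7", "max_tokens:128", "trust:true", "name:hello"]))
```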

            # Build tokenizer config for MLX using options
            tokenizer_config = {}

            # Handle trust_remote_code from request or options
            if request.TrustRemoteCode or self.options.get("trust_remote_code", False):
                tokenizer_config["trust_remote_code"] = True

            # Handle EOS token from options
            if "eos_token" in self.options:
                tokenizer_config["eos_token"] = self.options["eos_token"]

            # Handle other tokenizer config options
            for key in ["pad_token", "bos_token", "unk_token", "sep_token", "cls_token", "mask_token"]:
                if key in self.options:
                    tokenizer_config[key] = self.options[key]

            # Load model and tokenizer using MLX
            if tokenizer_config:
                print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
@@ -125,6 +93,21 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
            else:
                self.model, self.tokenizer = load(request.Model)

            # mlx_lm.load() returns a TokenizerWrapper that detects tool
            # calling and thinking markers from the chat template / vocab.
            # mlx-lm >= 0.30 also exposes a parser callable on the wrapper;
            # earlier versions don't (we fall back to json.loads inside
            # _tool_module_from_tokenizer below).
            has_tools = bool(getattr(self.tokenizer, "has_tool_calling", False))
            has_thinking = bool(getattr(self.tokenizer, "has_thinking", False))
            tcs = getattr(self.tokenizer, "tool_call_start", None)
            tce = getattr(self.tokenizer, "tool_call_end", None)
            print(
                f"MLX tokenizer capabilities: has_tool_calling={has_tools} "
                f"has_thinking={has_thinking} tool_call_start={tcs!r} tool_call_end={tce!r}",
                file=sys.stderr,
            )

            # Initialize thread-safe LRU prompt cache for efficient generation
            max_cache_entries = self.options.get("max_cache_entries", 10)
            self.max_kv_size = self.options.get("max_kv_size", None)
@@ -134,7 +117,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                can_trim_fn=can_trim_prompt_cache,
                trim_fn=trim_prompt_cache,
            )

        except Exception as err:
            print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Error loading MLX model: {err}")
@@ -172,30 +155,58 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
            remaining_tokens = cache_key

            # Build generation parameters using request attributes and options
            max_tokens, sampler_params = self._build_generation_params(request)
            max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request)

            print(f"Generating text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)
            print(
                f"Generating text with MLX - max_tokens: {max_tokens}, "
                f"cache_hit: {len(remaining_tokens) < len(cache_key)}",
                file=sys.stderr,
            )

            # Create sampler with parameters
            # Create sampler and optional logits processors (penalties)
            sampler = make_sampler(**sampler_params)
            logits_processors = make_logits_processors(**logits_params) if logits_params else None

            # Use stream_generate to track generated tokens for cache key
            # Use stream_generate to collect text + track tokens for cache key
            generated_text = []
            last_response = None
            for response in stream_generate(
                self.model,
                self.tokenizer,
                prompt=remaining_tokens if remaining_tokens else cache_key,
                max_tokens=max_tokens,
                sampler=sampler,
                logits_processors=logits_processors,
                prompt_cache=prompt_cache,
            ):
                generated_text.append(response.text)
                cache_key.append(response.token)
                last_response = response
                # Early stop on user-provided stop sequences
                if stop_words and any(s in "".join(generated_text) for s in stop_words):
                    break

            # Insert completed cache
            self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)

            return backend_pb2.Reply(message=bytes(''.join(generated_text), encoding='utf-8'))
            full_text = self._truncate_at_stop("".join(generated_text), stop_words)
            content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
                self._finalize_output(request, full_text, last_response)
            )

            return backend_pb2.Reply(
                message=bytes(content, encoding='utf-8'),
                prompt_tokens=prompt_tokens,
                tokens=completion_tokens,
                logprobs=logprobs_bytes,
                chat_deltas=[
                    backend_pb2.ChatDelta(
                        content=content,
                        reasoning_content=reasoning_content,
                        tool_calls=tool_calls_proto,
                    )
                ],
            )

        except Exception as e:
            print(f"Error in MLX Predict: {e}", file=sys.stderr)
@@ -206,7 +217,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
    def Embedding(self, request, context):
        """
        A gRPC method that calculates embeddings for a given sentence.

        Note: MLX-LM doesn't support embeddings directly. This method returns an error.

        Args:
@@ -221,6 +232,62 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
        context.set_details("Embeddings are not supported in the MLX backend.")
        return backend_pb2.EmbeddingResult()

    async def TokenizeString(self, request, context):
        """Tokenize ``request.Prompt`` using the loaded model's tokenizer."""
        if not hasattr(self, "tokenizer") or self.tokenizer is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("tokenizer not loaded")
            return backend_pb2.TokenizationResponse()
        try:
            tokens = self.tokenizer.encode(request.Prompt)
            if hasattr(tokens, "tolist"):
                tokens = tokens.tolist()
            tokens = list(tokens)
            return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
        except Exception as e:
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(e))
            return backend_pb2.TokenizationResponse()

    async def Free(self, request, context):
        """Drop the loaded model, tokenizer and prompt cache.

        Metal / CUDA memory is released via ``gc.collect()`` + the
        platform-specific cache clear hooks when available.
        """
        try:
            if hasattr(self, "model"):
                del self.model
            if hasattr(self, "tokenizer"):
                del self.tokenizer
            if hasattr(self, "lru_cache") and self.lru_cache is not None:
                try:
                    self.lru_cache.clear()
                except Exception:
                    pass
                self.lru_cache = None
            gc.collect()
            # Metal: drop the cached allocator. mlx.clear_cache (mlx >= 0.30)
            # supersedes the now-deprecated mlx.metal.clear_cache.
            try:
                if hasattr(mx, "clear_cache"):
                    mx.clear_cache()
                elif hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
                    mx.metal.clear_cache()
            except Exception:
                pass
            # CUDA: release the torch cache if a CUDA-backed mlx variant
            # happens to be installed alongside torch (best-effort).
            try:
                import torch  # type: ignore
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception:
                pass
            return backend_pb2.Result(success=True, message="MLX model freed")
        except Exception as e:
            return backend_pb2.Result(success=False, message=str(e))
    async def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results using MLX.
@@ -251,24 +318,64 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
            remaining_tokens = cache_key

            # Build generation parameters using request attributes and options
            max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)
            max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(
                request, default_max_tokens=512
            )

            print(f"Streaming text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)
            print(
                f"Streaming text with MLX - max_tokens: {max_tokens}, "
                f"cache_hit: {len(remaining_tokens) < len(cache_key)}",
                file=sys.stderr,
            )

            # Create sampler with parameters
            # Create sampler and optional logits processors (penalties)
            sampler = make_sampler(**sampler_params)
            logits_processors = make_logits_processors(**logits_params) if logits_params else None

            # Stream text generation using MLX with proper parameters
            accumulated = []
            last_response = None
            for response in stream_generate(
                self.model,
                self.tokenizer,
                prompt=remaining_tokens if remaining_tokens else cache_key,
                max_tokens=max_tokens,
                sampler=sampler,
                logits_processors=logits_processors,
                prompt_cache=prompt_cache,
            ):
                cache_key.append(response.token)
                yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
                accumulated.append(response.text)
                last_response = response
                # Emit a content delta. Structured reasoning / tool parsing
                # happens on the final chunk so we don't fragment the state
                # machine in v1.
                yield backend_pb2.Reply(
                    message=bytes(response.text, encoding='utf-8'),
                    chat_deltas=[backend_pb2.ChatDelta(content=response.text)],
                )
                # Early stop on user-provided stop sequences
                if stop_words and any(s in "".join(accumulated) for s in stop_words):
                    break

            # Final chunk: run reasoning + tool parsing on accumulated text
            # and emit the structured ChatDelta with token counts + logprobs.
            full_text = self._truncate_at_stop("".join(accumulated), stop_words)
            content, reasoning_content, tool_calls_proto, prompt_tokens, completion_tokens, logprobs_bytes = (
                self._finalize_output(request, full_text, last_response)
            )
            yield backend_pb2.Reply(
                message=b"",
                prompt_tokens=prompt_tokens,
                tokens=completion_tokens,
                logprobs=logprobs_bytes,
                chat_deltas=[
                    backend_pb2.ChatDelta(
                        content="",
                        reasoning_content=reasoning_content,
                        tool_calls=tool_calls_proto,
                    )
                ],
            )

        except Exception as e:
            print(f"Error in MLX PredictStream: {e}", file=sys.stderr)
@@ -294,21 +401,33 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
        Returns:
            str: The prepared prompt.
        """
        # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
        # If tokenizer template is enabled and messages are provided instead
        # of prompt, apply the tokenizer template (forwards tool definitions
        # and enable_thinking when the model supports them).
        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            # Convert gRPC messages to the format expected by apply_chat_template
            messages = []
            for msg in request.Messages:
                messages.append({"role": msg.role, "content": msg.content})
            messages = messages_to_dicts(request.Messages)

            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            return prompt
        else:
            return request.Prompt
            kwargs = {"tokenize": False, "add_generation_prompt": True}
            if request.Tools:
                try:
                    kwargs["tools"] = json.loads(request.Tools)
                except json.JSONDecodeError:
                    pass
            enable_thinking = request.Metadata.get("enable_thinking", "").lower()
            if enable_thinking == "true":
                kwargs["enable_thinking"] = True

            try:
                return self.tokenizer.apply_chat_template(messages, **kwargs)
            except TypeError:
                # Fallback for tokenizers whose template doesn't accept
                # tools= or enable_thinking=.
                return self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
        return request.Prompt

    def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
        """
@@ -338,18 +457,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
            default_max_tokens: Default max_tokens if not specified.

        Returns:
            tuple: (max_tokens, sampler_params dict)
            tuple: (max_tokens, sampler_params dict, logits_processor_params dict,
                    stop_words list)
        """
        # Extract max_tokens
        max_tokens = getattr(request, 'Tokens', default_max_tokens)
        if max_tokens == 0:
            max_tokens = default_max_tokens

        # Extract sampler parameters from request attributes
        temp = getattr(request, 'Temperature', 0.0)
        if temp == 0.0:
            temp = 0.6  # Default temperature

        top_p = getattr(request, 'TopP', 0.0)
        if top_p == 0.0:
            top_p = 1.0  # Default top_p
@@ -369,18 +489,31 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
            'xtc_threshold': 0.0,
            'xtc_probability': 0.0,
        }

        # Logits processor parameters — only set fields the request actually
        # provides so we can feed them unconditionally to make_logits_processors.
        logits_params = {}
        repetition_penalty = getattr(request, 'RepetitionPenalty', 0.0) or 0.0
        if repetition_penalty and repetition_penalty != 1.0:
            logits_params['repetition_penalty'] = repetition_penalty
        presence_penalty = getattr(request, 'PresencePenalty', 0.0) or 0.0
        if presence_penalty:
            logits_params['presence_penalty'] = presence_penalty
        frequency_penalty = getattr(request, 'FrequencyPenalty', 0.0) or 0.0
        if frequency_penalty:
            logits_params['frequency_penalty'] = frequency_penalty

        # Add seed if specified
        seed = getattr(request, 'Seed', 0)
        if seed != 0:
            mx.random.seed(seed)

        # Override with options if available
        if hasattr(self, 'options'):
            # Max tokens from options
            if 'max_tokens' in self.options:
                max_tokens = self.options['max_tokens']

            # Sampler parameters from options
            sampler_option_mapping = {
                'temp': 'temp',
@@ -391,32 +524,142 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                'xtc_threshold': 'xtc_threshold',
                'xtc_probability': 'xtc_probability',
            }

            for option_key, param_key in sampler_option_mapping.items():
                if option_key in self.options:
                    sampler_params[param_key] = self.options[option_key]

            # Logits processor overrides
            for option_key in ('repetition_penalty', 'presence_penalty', 'frequency_penalty'):
                if option_key in self.options:
                    logits_params[option_key] = self.options[option_key]

            # Handle seed from options
            if 'seed' in self.options:
                mx.random.seed(self.options['seed'])

        # Special tokens for XTC sampling (if tokenizer has eos_token_ids)
        xtc_special_tokens = []
        if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
            xtc_special_tokens = list(self.tokenizer.eos_token_ids)
        elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            xtc_special_tokens = [self.tokenizer.eos_token_id]

        # Add newline token if available
        try:
            newline_tokens = self.tokenizer.encode("\n")
            xtc_special_tokens.extend(newline_tokens)
        except:
        except Exception:
            pass  # Skip if encoding fails

        sampler_params['xtc_special_tokens'] = xtc_special_tokens

        return max_tokens, sampler_params

        # Stop sequences are applied post-decode (mlx-lm doesn't have a
        # built-in stop-sequence sampler param). Preserve the list here.
        stop_words = list(getattr(request, 'StopPrompts', []) or [])

        return max_tokens, sampler_params, logits_params, stop_words
    def _tool_module_from_tokenizer(self):
        """Build a duck-typed tool module from the TokenizerWrapper.

        On mlx-lm >= 0.30 the wrapper exposes a ``tool_parser`` callable
        that's been resolved from the model's chat template. On older
        releases (e.g. 0.29.x) the wrapper only carries the start/end
        markers — fall back to ``json.loads`` of the body, which matches
        what ``mlx_lm.tool_parsers.json_tools.parse_tool_call`` does on
        HEAD and covers the only format 0.29 detects (``<tool_call>``).
        """
        start = getattr(self.tokenizer, "tool_call_start", None)
        end = getattr(self.tokenizer, "tool_call_end", None)
        if not start:
            return None
        parse_fn = getattr(self.tokenizer, "tool_parser", None)
        if parse_fn is None:
            def parse_fn(body, tools):  # noqa: E306 — local fallback
                return json.loads(body.strip())
        return types.SimpleNamespace(
            tool_call_start=start,
            tool_call_end=end or "",
            parse_tool_call=parse_fn,
        )

    def _finalize_output(self, request, generated_text, last_response):
        """Build a ChatDelta + token counts + logprobs from accumulated output.

        Returns ``(content, reasoning_content, tool_calls_proto,
        prompt_token_count, completion_token_count, logprobs_bytes)``.
        """
        content = generated_text
        reasoning_content = ""

        if getattr(self.tokenizer, "has_thinking", False):
            think_start = getattr(self.tokenizer, "think_start", "") or ""
            think_end = getattr(self.tokenizer, "think_end", "") or ""
            reasoning_content, content = split_reasoning(content, think_start, think_end)

        tool_calls_proto: List[backend_pb2.ToolCallDelta] = []
        tool_module = None
        if getattr(self.tokenizer, "has_tool_calling", False):
            tool_module = self._tool_module_from_tokenizer()
        if tool_module is not None:
            parsed_tools = None
            if request.Tools:
                try:
                    parsed_tools = json.loads(request.Tools)
                except json.JSONDecodeError:
                    parsed_tools = None
            calls, content = parse_tool_calls(content, tool_module, parsed_tools)
            for c in calls:
                tool_calls_proto.append(
                    backend_pb2.ToolCallDelta(
                        index=c["index"],
                        id=c["id"],
                        name=c["name"],
                        arguments=c["arguments"],
                    )
                )

        prompt_token_count = int(getattr(last_response, "prompt_tokens", 0) or 0) if last_response else 0
        completion_token_count = int(getattr(last_response, "generation_tokens", 0) or 0) if last_response else 0

        logprobs_bytes = b""
        # Logprobs extraction — only when the request asked for them.
        if last_response is not None and int(getattr(request, "Logprobs", 0) or 0) > 0:
            try:
                lp = getattr(last_response, "logprobs", None)
                if lp is not None:
                    # GenerationResponse.logprobs on the last chunk is the
                    # logprob distribution of the final token. Without a
                    # per-token history we at minimum surface the last token's
                    # top-1 logprob so clients get a non-empty field.
                    token_id = int(getattr(last_response, "token", 0) or 0)
                    token_text = self.tokenizer.decode([token_id]) if token_id else ""
                    top_logprob = float(lp[token_id]) if hasattr(lp, "__getitem__") else 0.0
                    logprobs_bytes = json.dumps(
                        {
                            "content": [
                                {"token": token_text, "logprob": top_logprob}
                            ]
                        }
                    ).encode("utf-8")
            except Exception as e:
                print(f"[mlx] Logprobs extraction failed: {e}", file=sys.stderr)

        return content, reasoning_content, tool_calls_proto, prompt_token_count, completion_token_count, logprobs_bytes

    def _truncate_at_stop(self, text, stop_words):
        """Truncate ``text`` at the first occurrence of any stop sequence."""
        if not stop_words:
            return text
        earliest = len(text)
        for stop in stop_words:
            if not stop:
                continue
            idx = text.find(stop)
            if idx >= 0 and idx < earliest:
                earliest = idx
        return text[:earliest] if earliest < len(text) else text
|
||||
async def serve(address):
|
||||
# Start asyncio gRPC server
|
||||
|
||||
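The stop-sequence truncation above is small enough to exercise standalone. A minimal sketch of the same logic as a free function (the name `truncate_at_stop` is just for illustration; the backend's method is `_truncate_at_stop`):

```python
def truncate_at_stop(text, stop_words):
    """Cut text at the earliest occurrence of any non-empty stop sequence."""
    if not stop_words:
        return text
    earliest = len(text)
    for stop in stop_words:
        if not stop:
            continue  # empty stop words would match everywhere; skip them
        idx = text.find(stop)
        if 0 <= idx < earliest:
            earliest = idx
    # earliest == len(text) means no stop word matched; slicing is a no-op then
    return text[:earliest]

print(truncate_at_stop("Hello</s> world<eot>", ["</s>", "<eot>"]))  # -> Hello
```

Because `earliest` starts at `len(text)`, the no-match case falls out of the same slice instead of needing a separate branch.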
@@ -1,11 +1,20 @@
import os
import sys
import unittest
import subprocess
import time
import types

import grpc
import backend_pb2
import backend_pb2_grpc

# Make the shared helpers importable so we can unit-test them without a
# running gRPC server.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
from python_utils import messages_to_dicts, parse_options
from mlx_utils import parse_tool_calls, split_reasoning


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.
@@ -231,4 +240,104 @@ class TestBackendServicer(unittest.TestCase):
            self.tearDown()

    def test_tokenize_string(self):
        """TokenizeString should return a non-empty token list for a known prompt."""
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(
                    backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")
                )
                self.assertTrue(response.success)
                resp = stub.TokenizeString(backend_pb2.PredictOptions(Prompt="Hello, world"))
                self.assertGreater(resp.length, 0)
                self.assertEqual(len(list(resp.tokens)), resp.length)
        except Exception as err:
            print(err)
            self.fail("TokenizeString service failed")
        finally:
            self.tearDown()

    def test_free(self):
        """Free should release the model and not crash on subsequent calls."""
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(
                    backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")
                )
                self.assertTrue(response.success)
                free_resp = stub.Free(backend_pb2.HealthMessage())
                self.assertTrue(free_resp.success)
        except Exception as err:
            print(err)
            self.fail("Free service failed")
        finally:
            self.tearDown()


class TestSharedHelpers(unittest.TestCase):
    """Server-less unit tests for the helpers the mlx backend depends on."""

    def test_parse_options_typed(self):
        opts = parse_options(["temperature:0.7", "max_tokens:128", "trust:true", "name:hello", "no_colon_skipped"])
        self.assertEqual(opts["temperature"], 0.7)
        self.assertEqual(opts["max_tokens"], 128)
        self.assertIs(opts["trust"], True)
        self.assertEqual(opts["name"], "hello")
        self.assertNotIn("no_colon_skipped", opts)

    def test_messages_to_dicts_roundtrip(self):
        # Build proto Message objects (via backend_pb2 to match real gRPC)
        msgs = [
            backend_pb2.Message(role="user", content="hi"),
            backend_pb2.Message(
                role="assistant",
                content="",
                tool_calls='[{"id":"call_1","type":"function","function":{"name":"f","arguments":"{}"}}]',
            ),
            backend_pb2.Message(
                role="tool",
                content="42",
                tool_call_id="call_1",
                name="f",
            ),
        ]
        out = messages_to_dicts(msgs)
        self.assertEqual(out[0], {"role": "user", "content": "hi"})
        self.assertEqual(out[1]["role"], "assistant")
        self.assertEqual(out[1]["tool_calls"][0]["function"]["name"], "f")
        self.assertEqual(out[2]["tool_call_id"], "call_1")
        self.assertEqual(out[2]["name"], "f")

    def test_split_reasoning(self):
        r, c = split_reasoning("<think>step 1\nstep 2</think>The answer is 42.", "<think>", "</think>")
        self.assertEqual(r, "step 1\nstep 2")
        self.assertEqual(c, "The answer is 42.")

    def test_split_reasoning_no_marker(self):
        r, c = split_reasoning("just text", "<think>", "</think>")
        self.assertEqual(r, "")
        self.assertEqual(c, "just text")

    def test_parse_tool_calls_with_shim(self):
        tm = types.SimpleNamespace(
            tool_call_start="<tool_call>",
            tool_call_end="</tool_call>",
            parse_tool_call=lambda body, tools: {"name": "get_weather", "arguments": {"location": body.strip()}},
        )
        calls, remaining = parse_tool_calls(
            "Sure: <tool_call>Paris</tool_call>",
            tm,
            tools=None,
        )
        self.assertEqual(len(calls), 1)
        self.assertEqual(calls[0]["name"], "get_weather")
        self.assertEqual(calls[0]["arguments"], '{"location": "Paris"}')
        self.assertEqual(calls[0]["index"], 0)
        self.assertNotIn("<tool_call>", remaining)


# Unit tests for ThreadSafeLRUPromptCache are in test_mlx_cache.py
@@ -84,6 +84,20 @@ func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) {
	// Apply per-model-family inference parameter defaults
	config.ApplyInferenceDefaults(&modelConfig, details.URI)

	// Auto-set tool_parser / reasoning_parser from parser_defaults.json so
	// the generated YAML mirrors what the vllm importer produces. The mlx
	// backends auto-detect parsers from the chat template at runtime and
	// ignore these Options, but surfacing them in the config keeps the two
	// paths consistent and gives users a single place to override.
	if parsers := config.MatchParserDefaults(details.URI); parsers != nil {
		if tp, ok := parsers["tool_parser"]; ok {
			modelConfig.Options = append(modelConfig.Options, "tool_parser:"+tp)
		}
		if rp, ok := parsers["reasoning_parser"]; ok {
			modelConfig.Options = append(modelConfig.Options, "reasoning_parser:"+rp)
		}
	}

	data, err := yaml.Marshal(modelConfig)
	if err != nil {
		return gallery.ModelConfig{}, err
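The `"tool_parser:"+tp` strings the importer appends here use the same `key:value` Options wire format that the shared `parse_options` helper coerces into typed values on the Python side. A minimal sketch of that coercion, written locally to match the behavior `test_parse_options_typed` asserts (it is not the shipped `python_utils` source, and `hermes` is a hypothetical parser name used only for illustration):

```python
def parse_options(pairs):
    """Parse 'key:value' option strings into a dict with typed values."""
    opts = {}
    for pair in pairs:
        if ":" not in pair:
            continue  # entries without a colon are skipped, per the tests
        key, value = pair.split(":", 1)
        if value.lower() in ("true", "false"):
            opts[key] = value.lower() == "true"  # booleans first
        else:
            for cast in (int, float):
                try:
                    opts[key] = cast(value)
                    break
                except ValueError:
                    continue
            else:
                opts[key] = value  # neither int nor float: keep the raw string
    return opts

print(parse_options(["tool_parser:hermes", "temperature:0.7"]))
# -> {'tool_parser': 'hermes', 'temperature': 0.7}
```

So a parser name survives as a plain string while numeric defaults come back typed, which is why the mlx backends can safely ignore these Options while vLLM consumes them.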