Compare commits

...

1 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
cdf721e6ad Refactor runner for implementing batching 2026-02-27 16:45:39 +00:00
14 changed files with 952 additions and 689 deletions

View File

@@ -258,6 +258,6 @@ def get_node_id_keypair(
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519()
keypair = Keypair.generate()
f.write(keypair.to_bytes())
return keypair

View File

@@ -1,5 +1,6 @@
import contextlib
import multiprocessing as mp
from collections.abc import Generator
from dataclasses import dataclass, field
from math import inf
from multiprocessing.synchronize import Event
@@ -282,6 +283,54 @@ class MpReceiver[T]:
return d
class NonBlockingGenerator[T](Generator[T | None, None, None]):
def __init__(self, source: MpReceiver[T] | Generator[T | None, None, None]) -> None:
self._receiver: MpReceiver[T] | None = None
self._inner: Generator[T | None, None, None] | None = None
if isinstance(source, MpReceiver):
self._receiver = source
else:
self._inner = source
self._exhausted = False
def send(self, value: None, /) -> T | None:
if self._exhausted:
raise StopIteration
if self._inner is not None:
try:
return next(self._inner)
except (StopIteration, ClosedResourceError):
self._exhausted = True
raise StopIteration from None
assert self._receiver is not None
try:
return self._receiver.receive_nowait()
except WouldBlock:
return None
except (EndOfStream, ClosedResourceError):
self._exhausted = True
raise StopIteration from None
def throw(
self,
typ: type[BaseException] | BaseException,
val: BaseException | object = None,
tb: TracebackType | None = None,
/,
) -> T | None:
raise StopIteration
@property
def is_exhausted(self) -> bool:
return self._exhausted
def try_receive(self) -> T | None:
try:
return next(self)
except StopIteration:
return None
class channel[T]: # noqa: N801
"""Create a pair of asynchronous channels for communicating within the same process"""

View File

@@ -437,6 +437,7 @@ def mlx_generate(
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None = None,
distributed_prompt_progress_callback: Callable[[], None] | None = None,
on_generation_token: Callable[[], None] | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -644,6 +645,9 @@ def mlx_generate(
full_prompt_tokens, caches, cache_snapshots
)
if on_generation_token is not None:
on_generation_token()
yield GenerationResponse(
text=text,
token=out.token,

View File

@@ -297,10 +297,10 @@ def _pending_tasks(
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
if task.task_id in runner.completed:
if task.task_id in runner.completed or task.task_id in runner.pending:
continue
if isinstance(runner.status, RunnerReady) and all(
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):

View File

@@ -33,10 +33,15 @@ def entrypoint(
try:
if bound_instance.is_image_model:
from exo.worker.runner.image_models.runner import main
else:
from exo.worker.runner.llm_inference.runner import main
main(bound_instance, event_sender, task_receiver, cancel_receiver)
main(bound_instance, event_sender, task_receiver, cancel_receiver)
else:
from exo.worker.runner.llm_inference.runner import Runner
runner = Runner(
bound_instance, event_sender, task_receiver, cancel_receiver
)
runner.main()
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")

View File

@@ -0,0 +1,178 @@
from collections import deque
from collections.abc import Generator
from dataclasses import dataclass, field
import mlx.core as mx
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.chunks import ErrorChunk, PrefillProgressChunk
from exo.shared.types.common import ModelId
from exo.shared.types.events import ChunkGenerated, Event
from exo.shared.types.mlx import Model
from exo.shared.types.tasks import TaskId, TextGeneration
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import PrefillCancelled, mlx_generate
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
mx_any,
)
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
"""Check for debug prompt triggers in the input."""
import time
from exo.worker.engines.mlx.utils_mlx import mlx_force_oom
if len(task_params.input) == 0:
return
prompt = task_params.input[0].content
if not prompt:
return
if EXO_RUNNER_MUST_FAIL in prompt:
raise Exception("Artificial runner exception - for testing purposes only.")
if EXO_RUNNER_MUST_OOM in prompt:
mlx_force_oom()
if EXO_RUNNER_MUST_TIMEOUT in prompt:
time.sleep(100)
@dataclass(eq=False)
class BatchGenerator:
model: Model
tokenizer: TokenizerWrapper
group: mx.distributed.Group | None
kv_prefix_cache: KVPrefixCache | None
model_id: ModelId
device_rank: int
cancel_receiver: MpReceiver[TaskId]
cancelled_tasks: set[TaskId]
event_sender: MpSender[Event]
check_for_cancel_every: int
_queue: deque[tuple[TextGeneration, MpSender[GenerationResponse]]] = field(
default_factory=deque, init=False
)
_active: (
tuple[
TextGeneration,
MpSender[GenerationResponse],
Generator[GenerationResponse],
]
| None
) = field(default=None, init=False)
def submit(
self,
task: TextGeneration,
sender: MpSender[GenerationResponse],
) -> None:
self._queue.append((task, sender))
if self._active is None:
self._start_next()
def step(self) -> None:
if self._active is None:
if self._queue:
self._start_next()
else:
return
if self._active is None:
return
task, sender, gen = self._active
try:
response = next(gen)
sender.send(response)
except (StopIteration, PrefillCancelled):
sender.close()
self._active = None
if self._queue:
self._start_next()
except Exception as e:
self._send_error(task, e)
sender.close()
self._active = None
raise
def _start_next(self) -> None:
task, sender = self._queue.popleft()
try:
gen = self._build_generator(task)
except Exception as e:
self._send_error(task, e)
sender.close()
raise
self._active = (task, sender, gen)
def _send_error(self, task: TextGeneration, e: Exception) -> None:
if self.device_rank == 0:
self.event_sender.send(
ChunkGenerated(
command_id=task.command_id,
chunk=ErrorChunk(
model=self.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
def _build_generator(self, task: TextGeneration) -> Generator[GenerationResponse]:
_check_for_debug_prompts(task.task_params)
prompt = apply_chat_template(self.tokenizer, task.task_params)
def on_prefill_progress(processed: int, total: int) -> None:
if self.device_rank == 0:
self.event_sender.send(
ChunkGenerated(
command_id=task.command_id,
chunk=PrefillProgressChunk(
model=self.model_id,
processed_tokens=processed,
total_tokens=total,
),
)
)
def distributed_prompt_progress_callback() -> None:
self.cancelled_tasks.update(self.cancel_receiver.collect())
want_to_cancel = (task.task_id in self.cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in self.cancelled_tasks
)
if mx_any(want_to_cancel, self.group):
raise PrefillCancelled()
tokens_since_cancel_check = self.check_for_cancel_every
def on_generation_token() -> None:
nonlocal tokens_since_cancel_check
tokens_since_cancel_check += 1
if tokens_since_cancel_check >= self.check_for_cancel_every:
tokens_since_cancel_check = 0
self.cancelled_tasks.update(self.cancel_receiver.collect())
want_to_cancel = (task.task_id in self.cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in self.cancelled_tasks
)
if mx_any(want_to_cancel, self.group):
raise PrefillCancelled()
return mlx_generate(
model=self.model,
tokenizer=self.tokenizer,
task=task.task_params,
prompt=prompt,
kv_prefix_cache=self.kv_prefix_cache,
on_prefill_progress=on_prefill_progress,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
on_generation_token=on_generation_token,
group=self.group,
)

View File

@@ -0,0 +1,341 @@
from collections.abc import Generator
from functools import cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
HarmonyError, # pyright: ignore[reportUnknownVariableType]
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.types.api import ToolCallItem
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.llm_inference.tool_parsers import ToolParser
@cache
def get_gpt_oss_encoding():
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return encoding
def parse_gpt_oss(
responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
current_tool_name: str | None = None
tool_arg_parts: list[str] = []
for response in responses:
if response is None:
yield None
continue
try:
stream.process(response.token)
except HarmonyError:
logger.error("Encountered critical Harmony Error, returning early")
return
delta = stream.last_content_delta
ch = stream.current_channel
recipient = stream.current_recipient
# Debug: log every token with state
logger.debug(
f"parse_gpt_oss token={response.token} text={response.text!r} "
f"recipient={recipient!r} ch={ch!r} delta={delta!r} "
f"state={stream.state} current_tool={current_tool_name!r}"
)
if recipient != current_tool_name:
if current_tool_name is not None:
prefix = "functions."
if current_tool_name.startswith(prefix):
current_tool_name = current_tool_name[len(prefix) :]
logger.info(
f"parse_gpt_oss yielding tool call: name={current_tool_name!r}"
)
yield ToolCallResponse(
tool_calls=[
ToolCallItem(
name=current_tool_name,
arguments="".join(tool_arg_parts).strip(),
)
],
usage=response.usage,
)
tool_arg_parts = []
current_tool_name = recipient
# If inside a tool call, accumulate arguments
if current_tool_name is not None:
if delta:
tool_arg_parts.append(delta)
continue
if ch == "analysis" and not thinking:
thinking = True
if ch != "analysis" and thinking:
thinking = False
if delta:
yield response.model_copy(update={"text": delta, "is_thinking": thinking})
if response.finish_reason is not None:
yield response
def parse_deepseek_v32(
responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
"""Parse DeepSeek V3.2 DSML tool calls from the generation stream.
Uses accumulated-text matching (not per-token marker checks) because
DSML markers like <DSMLfunction_calls> may span multiple tokens.
Also handles <think>...</think> blocks for thinking mode.
"""
from exo.worker.engines.mlx.dsml_encoding import (
THINKING_END,
THINKING_START,
TOOL_CALLS_END,
TOOL_CALLS_START,
parse_dsml_output,
)
accumulated = ""
in_tool_call = False
thinking = False
# Tokens buffered while we detect the start of a DSML block
pending_buffer: list[GenerationResponse] = []
# Text accumulated during a tool call block
tool_call_text = ""
for response in responses:
if response is None:
yield None
continue
# ── Handle thinking tags ──
if not thinking and THINKING_START in response.text:
thinking = True
# Yield any text before the <think> tag
before = response.text[: response.text.index(THINKING_START)]
if before:
yield response.model_copy(update={"text": before})
continue
if thinking and THINKING_END in response.text:
thinking = False
# Yield any text after the </think> tag
after = response.text[
response.text.index(THINKING_END) + len(THINKING_END) :
]
if after:
yield response.model_copy(update={"text": after, "is_thinking": False})
continue
if thinking:
yield response.model_copy(update={"is_thinking": True})
continue
# ── Handle tool call accumulation ──
if in_tool_call:
tool_call_text += response.text
if TOOL_CALLS_END in tool_call_text:
# Parse the accumulated DSML block
parsed = parse_dsml_output(tool_call_text)
if parsed is not None:
logger.info(f"parsed DSML tool calls: {parsed}")
yield ToolCallResponse(
tool_calls=parsed,
usage=response.usage,
stats=response.stats,
)
else:
logger.warning(
f"DSML tool call parsing failed for: {tool_call_text}"
)
yield response.model_copy(update={"text": tool_call_text})
in_tool_call = False
tool_call_text = ""
continue
# EOS reached before end marker — yield buffered text as-is
if response.finish_reason is not None:
logger.info("DSML tool call parsing interrupted by EOS")
yield response.model_copy(update={"text": tool_call_text})
in_tool_call = False
tool_call_text = ""
continue
# ── Detect start of tool call block ──
accumulated += response.text
if TOOL_CALLS_START in accumulated:
# The start marker might be split across pending_buffer + current token
start_idx = accumulated.index(TOOL_CALLS_START)
# Yield any pending tokens that are purely before the marker
pre_text = accumulated[:start_idx]
if pre_text:
# Flush pending buffer tokens that contributed text before the marker
for buf_resp in pending_buffer:
if pre_text:
chunk = buf_resp.text
if len(chunk) <= len(pre_text):
yield buf_resp
pre_text = pre_text[len(chunk) :]
else:
yield buf_resp.model_copy(update={"text": pre_text})
pre_text = ""
pending_buffer = []
tool_call_text = accumulated[start_idx:]
accumulated = ""
# Check if the end marker is already present (entire tool call in one token)
if TOOL_CALLS_END in tool_call_text:
parsed = parse_dsml_output(tool_call_text)
if parsed is not None:
logger.info(f"parsed DSML tool calls: {parsed}")
yield ToolCallResponse(
tool_calls=parsed,
usage=response.usage,
stats=response.stats,
)
else:
logger.warning(
f"DSML tool call parsing failed for: {tool_call_text}"
)
yield response.model_copy(update={"text": tool_call_text})
tool_call_text = ""
else:
in_tool_call = True
continue
# Check if accumulated text might be the start of a DSML marker
# Buffer tokens if we see a partial match at the end
if _could_be_dsml_prefix(accumulated):
pending_buffer.append(response)
continue
# No partial match — flush all pending tokens and the current one
for buf_resp in pending_buffer:
yield buf_resp
pending_buffer = []
accumulated = ""
yield response
# Flush any remaining pending buffer at generator end
for buf_resp in pending_buffer:
yield buf_resp
def _could_be_dsml_prefix(text: str) -> bool:
"""Check if the end of text could be the start of a DSML function_calls marker.
We look for suffixes of text that are prefixes of the TOOL_CALLS_START pattern.
This allows us to buffer tokens until we can determine if a tool call is starting.
"""
from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START
# Only check the last portion of text that could overlap with the marker
max_check = len(TOOL_CALLS_START)
tail = text[-max_check:] if len(text) > max_check else text
# Check if any suffix of tail is a prefix of TOOL_CALLS_START
for i in range(len(tail)):
suffix = tail[i:]
if TOOL_CALLS_START.startswith(suffix):
return True
return False
def parse_thinking_models(
responses: Generator[GenerationResponse | None],
tokenizer: TokenizerWrapper,
starts_in_thinking: bool = True,
) -> Generator[GenerationResponse | None]:
"""Route thinking tokens via is_thinking flag.
Swallows think tag tokens, sets is_thinking on all others.
Always yields tokens with finish_reason to avoid hanging the chunk stream.
"""
in_thinking = starts_in_thinking
for response in responses:
if response is None:
yield None
continue
if isinstance(response, ToolCallResponse):
yield response
continue
is_think_tag = (
tokenizer.think_end is not None and response.text == tokenizer.think_end
) or (
tokenizer.think_start is not None and response.text == tokenizer.think_start
)
if is_think_tag:
in_thinking = response.text != tokenizer.think_end
# Never swallow finish_reason — the chunk stream needs it to terminate.
if response.finish_reason is not None:
yield response.model_copy(update={"text": "", "is_thinking": False})
continue
yield response.model_copy(update={"is_thinking": in_thinking})
def parse_tool_calls(
responses: Generator[GenerationResponse | None], tool_parser: ToolParser
) -> Generator[GenerationResponse | ToolCallResponse | None]:
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
if response is None:
yield None
continue
if not in_tool_call and response.text.startswith(tool_parser.start_parsing):
in_tool_call = True
if in_tool_call:
tool_call_text_parts.append(response.text)
if response.text.endswith(tool_parser.end_parsing):
# parse the actual tool calls from the tool call text
parsed = tool_parser.parse_tool_calls(
"".join(tool_call_text_parts).strip()
)
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if parsed is not None:
yield ToolCallResponse(
tool_calls=parsed, usage=response.usage, stats=response.stats
)
else:
logger.warning(
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
)
response.text = "".join(tool_call_text_parts)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if response.finish_reason is not None:
logger.info(
"tool call parsing interrupted, yield partial tool call as text"
)
response = response.model_copy(
update={
"text": "".join(tool_call_text_parts),
"token": 0,
}
)
yield response
else:
# fallthrough
yield response

View File

File diff suppressed because it is too large Load Diff

View File

@@ -172,7 +172,7 @@ class RunnerSupervisor:
if isinstance(event, RunnerStatusUpdated):
self.status = event.runner_status
if isinstance(event, TaskAcknowledged):
self.pending.pop(event.task_id).set()
self.pending[event.task_id].set()
continue
if (
isinstance(event, TaskStatusUpdated)
@@ -190,6 +190,7 @@ class RunnerSupervisor:
),
)
self.completed.add(event.task_id)
self.pending.pop(event.task_id, None)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)

View File

@@ -20,6 +20,7 @@ class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
pending: dict[TaskId, object] = field(default_factory=dict)
class OtherTask(BaseTask):

View File

@@ -19,7 +19,7 @@ from exo.worker.engines.mlx.dsml_encoding import (
encode_messages,
parse_dsml_output,
)
from exo.worker.runner.llm_inference.runner import parse_deepseek_v32
from exo.worker.runner.llm_inference.model_output_parsers import parse_deepseek_v32
# ── Shared fixtures ──────────────────────────────────────────────

View File

@@ -6,6 +6,7 @@ from typing import Callable
import mlx.core as mx
import pytest
import exo.worker.runner.llm_inference.batch_generator as mlx_batch_generator
import exo.worker.runner.llm_inference.runner as mlx_runner
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
@@ -115,17 +116,20 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
monkeypatch.setattr(mlx_batch_generator, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_batch_generator, "mx_any", make_nothin(False))
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
monkeypatch.setattr(
mlx_batch_generator, "apply_chat_template", make_nothin("test prompt")
)
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
monkeypatch.setattr(mlx_batch_generator, "mlx_generate", fake_generate)
# Use a fake event_sender to remove test flakiness.
@@ -183,12 +187,13 @@ def _run(tasks: Iterable[Task]):
"exo.worker.runner.llm_inference.runner.mx.distributed.all_gather",
make_nothin(mx.array([1])),
):
mlx_runner.main(
runner = mlx_runner.Runner(
bound_instance,
event_sender, # pyright: ignore[reportArgumentType]
task_receiver,
cancel_receiver,
)
runner.main()
return event_sender.events

View File

@@ -4,7 +4,7 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
ToolCallResponse,
)
from exo.worker.runner.llm_inference.runner import parse_gpt_oss
from exo.worker.runner.llm_inference.model_output_parsers import parse_gpt_oss
# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
# These are stable since they come from the model's vocabulary.
@@ -107,7 +107,7 @@ def _collect(
def _gen() -> Generator[GenerationResponse, None, None]:
yield from _make_gen_responses(tokens)
return list(parse_gpt_oss(_gen()))
return list(x for x in parse_gpt_oss(_gen()) if x is not None)
def _get_tool_call(

View File

@@ -4,7 +4,7 @@ from collections.abc import Generator
from typing import Any
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.llm_inference.runner import parse_tool_calls
from exo.worker.runner.llm_inference.model_output_parsers import parse_tool_calls
from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser