Compare commits

...

7 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
3239c55e40 Move mx_all_gather_tasks into utils_mlx 2026-03-03 14:39:23 +00:00
Ryuichi Leo Takashige
725264cc33 pass CI yet again 2026-03-03 14:33:14 +00:00
rltakashige
401ccfbd30 Merge branch 'main' into leo/prepare-batch-implementation 2026-03-03 14:32:25 +00:00
Ryuichi Leo Takashige
06beffe0e2 Pass CI 2026-03-03 14:18:34 +00:00
Evan Quiney
e9193581bc Batch cleanup (#1649)
what da ya think!

---------

Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-03-03 14:06:13 +00:00
Ryuichi Leo Takashige
69628383c5 Match with image runner 2026-03-03 10:49:17 +00:00
Ryuichi Leo Takashige
f77a672126 Refactor runner for implementing batching 2026-03-03 10:49:17 +00:00
16 changed files with 1390 additions and 1009 deletions

View File

@@ -1,16 +1,20 @@
 import asyncio
 import pytest
 
-from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError
+from exo_pyo3_bindings import (
+    Keypair,
+    NetworkingHandle,
+    NoPeersSubscribedToTopicError,
+    PyFromSwarm,
+)
 
 
 @pytest.mark.asyncio
 async def test_sleep_on_multiple_items() -> None:
     print("PYTHON: starting handle")
-    h = NetworkingHandle(Keypair.generate_ed25519())
-    ct = asyncio.create_task(_await_cons(h))
-    mt = asyncio.create_task(_await_msg(h))
+    h = NetworkingHandle(Keypair.generate())
+    rt = asyncio.create_task(_await_recv(h))
 
     # sleep for 4 ticks
     for i in range(4):
@@ -22,13 +26,11 @@ async def test_sleep_on_multiple_items() -> None:
         print("caught it", e)
 
 
-async def _await_cons(h: NetworkingHandle):
+async def _await_recv(h: NetworkingHandle):
     while True:
-        c = await h.connection_update_recv()
-        print(f"PYTHON: connection update: {c}")
-
-
-async def _await_msg(h: NetworkingHandle):
-    while True:
-        m = await h.gossipsub_recv()
-        print(f"PYTHON: message: {m}")
+        event = await h.recv()
+        match event:
+            case PyFromSwarm.Connection() as c:
+                print(f"PYTHON: connection update: {c}")
+            case PyFromSwarm.Message() as m:
+                print(f"PYTHON: message: {m}")

View File

@@ -258,6 +258,6 @@ def get_node_id_keypair(
     # if no valid credentials, create new ones and persist
     with open(path, "w+b") as f:
-        keypair = Keypair.generate_ed25519()
+        keypair = Keypair.generate()
         f.write(keypair.to_bytes())
         return keypair

View File

@@ -437,6 +437,7 @@ def mlx_generate(
     group: mx.distributed.Group | None,
     on_prefill_progress: Callable[[int, int], None] | None = None,
     distributed_prompt_progress_callback: Callable[[], None] | None = None,
+    on_generation_token: Callable[[], None] | None = None,
 ) -> Generator[GenerationResponse]:
     # Ensure that generation stats only contains peak memory for this generation
     mx.reset_peak_memory()
@@ -644,6 +645,9 @@ def mlx_generate(
             full_prompt_tokens, caches, cache_snapshots
         )
 
+        if on_generation_token is not None:
+            on_generation_token()
+
         yield GenerationResponse(
             text=text,
             token=out.token,
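
Editor's note: the new on_generation_token hook is a per-token callback supplied by the caller. A hypothetical caller-side sketch (not code from this diff; model, tokenizer, task_params and prompt are assumed to already exist in scope) of counting tokens for periodic bookkeeping:

token_count = 0

def on_generation_token() -> None:
    # Hypothetical: count decoded tokens so the caller can run periodic
    # bookkeeping (for example a cancellation check) every N tokens.
    global token_count
    token_count += 1

for response in mlx_generate(
    model=model,
    tokenizer=tokenizer,
    task=task_params,
    prompt=prompt,
    kv_prefix_cache=None,
    group=None,
    on_generation_token=on_generation_token,
):
    print(response.text, end="")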

View File

@@ -41,6 +41,7 @@ from exo.download.download_utils import build_model_path
 from exo.shared.types.common import Host
 from exo.shared.types.memory import Memory
 from exo.shared.types.mlx import Model
+from exo.shared.types.tasks import TaskId, TextGeneration
 from exo.shared.types.text_generation import TextGenerationTaskParams
 from exo.shared.types.worker.instances import (
     BoundInstance,
@@ -748,3 +749,55 @@ def _parse_kimi_tool_calls(text: str):
         return [_parse_single_tool(match) for match in tool_matches]  # pyright: ignore[reportAny]
     else:
         return [_parse_single_tool(text)]
+
+
+def mx_all_gather_tasks(
+    tasks: list[TextGeneration],
+    group: mx.distributed.Group | None,
+) -> tuple[list[TextGeneration], list[TextGeneration]]:
+    def encode_task_id(task_id: TaskId) -> list[int]:
+        utf8_task_id = task_id.encode()
+        return [
+            int.from_bytes(utf8_task_id[i : i + 1]) for i in range(len(utf8_task_id))
+        ]
+
+    def decode_task_id(encoded_task_id: list[int]) -> TaskId:
+        return TaskId(
+            bytes.decode(b"".join((x).to_bytes(length=1) for x in encoded_task_id))
+        )
+
+    uuid_byte_length = 36
+    n_tasks = len(tasks)
+    all_counts = cast(
+        list[int],
+        mx.distributed.all_gather(mx.array([n_tasks]), group=group).tolist(),
+    )
+    max_tasks = max(all_counts)
+    world_size: int = 1 if group is None else group.size()
+    if max_tasks == 0:
+        return [], []
+
+    padded = [encode_task_id(task.task_id) for task in tasks] + [
+        [0] * uuid_byte_length
+    ] * (max_tasks - n_tasks)
+    gathered = cast(
+        list[list[list[int]]],
+        mx.distributed.all_gather(mx.array(padded), group=group)
+        .reshape(world_size, max_tasks, -1)
+        .tolist(),
+    )
+    all_task_ids: list[list[TaskId]] = [
+        [decode_task_id(encoded_task_id) for encoded_task_id in rank_tasks[:count]]
+        for rank_tasks, count in zip(gathered, all_counts, strict=True)
+    ]
+
+    agreed_ids: set[TaskId] = set(all_task_ids[0])
+    for rank_tasks in all_task_ids[1:]:
+        agreed_ids &= set(rank_tasks)
+
+    local_tasks = {task.task_id: task for task in tasks}
+    agreed = [local_tasks[tid] for tid in sorted(agreed_ids)]
+    different = [task for task in tasks if task.task_id not in agreed_ids]
+    return agreed, different
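
Editor's note: the agreement step above boils down to a set intersection over the task IDs each rank has seen so far; a task runs only once every rank has it, the rest stay pending. A minimal single-process sketch of that rule (hypothetical rank lists, no mlx involved):

# Illustration only: three ranks with partially overlapping pending task IDs.
rank_task_ids = [
    ["task-a", "task-b", "task-c"],  # rank 0 (local)
    ["task-a", "task-c"],            # rank 1
    ["task-c", "task-a", "task-d"],  # rank 2
]

agreed = set(rank_task_ids[0])
for ids in rank_task_ids[1:]:
    agreed &= set(ids)

local = rank_task_ids[0]
agreed_local = [t for t in sorted(agreed)]                # ["task-a", "task-c"]
still_pending = [t for t in local if t not in agreed]     # ["task-b"]
print(agreed_local, still_pending)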

View File

@@ -297,10 +297,10 @@ def _pending_tasks(
             # the task status _should_ be set to completed by the LAST runner
             # it is currently set by the first
             # this is definitely a hack
-            if task.task_id in runner.completed:
+            if task.task_id in runner.completed or task.task_id in runner.in_progress:
                 continue
 
-            if isinstance(runner.status, RunnerReady) and all(
+            if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
                 isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
                 for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
             ):

View File

@@ -32,11 +32,19 @@ def entrypoint(
     # Import main after setting global logger - this lets us just import logger from this module
     try:
         if bound_instance.is_image_model:
-            from exo.worker.runner.image_models.runner import main
-        else:
-            from exo.worker.runner.llm_inference.runner import main
+            from exo.worker.runner.image_models.runner import Runner as ImageRunner
 
-        main(bound_instance, event_sender, task_receiver, cancel_receiver)
+            runner = ImageRunner(
+                bound_instance, event_sender, task_receiver, cancel_receiver
+            )
+            runner.main()
+        else:
+            from exo.worker.runner.llm_inference.runner import Runner
+
+            runner = Runner(
+                bound_instance, event_sender, task_receiver, cancel_receiver
+            )
+            runner.main()
     except ClosedResourceError:
         logger.warning("Runner communication closed unexpectedly")

View File

@@ -182,272 +182,266 @@ def _send_image_chunk(
    )

Removed: the old module-level main() entry point, which ran the whole state machine inline:

def main(
    bound_instance: BoundInstance,
    event_sender: MpSender[Event],
    task_receiver: MpReceiver[Task],
    cancel_receiver: MpReceiver[TaskId],
):
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))

    instance, runner_id, shard_metadata = (
        bound_instance.instance,
        bound_instance.bound_runner_id,
        bound_instance.bound_shard,
    )
    device_rank = shard_metadata.device_rank

    logger.info("hello from the runner")
    if getattr(shard_metadata, "immediate_exception", False):
        raise Exception("Fake exception - runner failed to spin up.")
    if timeout := getattr(shard_metadata, "should_timeout", 0):
        time.sleep(timeout)

    setup_start_time = time.time()
    cancelled_tasks = set[TaskId]()

    image_model: DistributedImageModel | None = None
    group = None

    current_status: RunnerStatus = RunnerIdle()
    logger.info("runner created")
    event_sender.send(
        RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
    )
    seen = set[TaskId]()

    with task_receiver as tasks:
        for task in tasks:
            if task.task_id in seen:
                logger.warning("repeat task - potential error")
            seen.add(task.task_id)
            cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
            event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
            )
            match task:
                case ConnectToGroup() if isinstance(
                    current_status, (RunnerIdle, RunnerFailed)
                ):
                    logger.info("runner connecting")
                    current_status = RunnerConnecting()
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    event_sender.send(TaskAcknowledged(task_id=task.task_id))

                    group = initialize_mlx(bound_instance)

                    logger.info("runner connected")
                    current_status = RunnerConnected()

                # we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
                case LoadModel() if (
                    isinstance(current_status, RunnerConnected) and group is not None
                ) or (isinstance(current_status, RunnerIdle) and group is None):
                    current_status = RunnerLoading()
                    logger.info("runner loading")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    event_sender.send(TaskAcknowledged(task_id=task.task_id))

                    assert (
                        ModelTask.TextToImage in shard_metadata.model_card.tasks
                        or ModelTask.ImageToImage in shard_metadata.model_card.tasks
                    ), f"Incorrect model task(s): {shard_metadata.model_card.tasks}"

                    image_model = initialize_image_model(bound_instance)
                    current_status = RunnerLoaded()
                    logger.info("runner loaded")

                case StartWarmup() if isinstance(current_status, RunnerLoaded):
                    current_status = RunnerWarmingUp()
                    logger.info("runner warming up")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    event_sender.send(TaskAcknowledged(task_id=task.task_id))

                    logger.info(f"warming up inference for instance: {instance}")
                    assert image_model
                    image = warmup_image_generator(model=image_model)
                    if image is not None:
                        logger.info(f"warmed up by generating {image.size} image")
                    else:
                        logger.info("warmup completed (non-primary node)")
                    logger.info(
                        f"runner initialized in {time.time() - setup_start_time} seconds"
                    )
                    current_status = RunnerReady()
                    logger.info("runner ready")

                case ImageGeneration(
                    task_params=task_params, command_id=command_id
                ) if isinstance(current_status, RunnerReady):
                    assert image_model
                    logger.info(f"received image generation request: {str(task)[:500]}")
                    current_status = RunnerRunning()
                    logger.info("runner running")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    event_sender.send(TaskAcknowledged(task_id=task.task_id))
                    try:
                        image_index = 0
                        for response in generate_image(
                            model=image_model, task=task_params
                        ):
                            is_primary_output = _is_primary_output_node(shard_metadata)
                            if is_primary_output:
                                match response:
                                    case PartialImageResponse():
                                        logger.info(
                                            f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
                                        )
                                        _process_image_response(
                                            response,
                                            command_id,
                                            shard_metadata,
                                            event_sender,
                                            image_index,
                                        )
                                    case ImageGenerationResponse():
                                        logger.info("sending final ImageChunk")
                                        _process_image_response(
                                            response,
                                            command_id,
                                            shard_metadata,
                                            event_sender,
                                            image_index,
                                        )
                                        image_index += 1
                    # can we make this more explicit?
                    except Exception as e:
                        if _is_primary_output_node(shard_metadata):
                            event_sender.send(
                                ChunkGenerated(
                                    command_id=command_id,
                                    chunk=ErrorChunk(
                                        model=shard_metadata.model_card.model_id,
                                        finish_reason="error",
                                        error_message=str(e),
                                    ),
                                )
                            )
                        raise
                    finally:
                        _send_traces_if_enabled(event_sender, task.task_id, device_rank)
                    current_status = RunnerReady()
                    logger.info("runner ready")

                case ImageEdits(task_params=task_params, command_id=command_id) if (
                    isinstance(current_status, RunnerReady)
                ):
                    assert image_model
                    logger.info(f"received image edits request: {str(task)[:500]}")
                    current_status = RunnerRunning()
                    logger.info("runner running")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    event_sender.send(TaskAcknowledged(task_id=task.task_id))
                    try:
                        image_index = 0
                        for response in generate_image(
                            model=image_model, task=task_params
                        ):
                            if _is_primary_output_node(shard_metadata):
                                match response:
                                    case PartialImageResponse():
                                        logger.info(
                                            f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
                                        )
                                        _process_image_response(
                                            response,
                                            command_id,
                                            shard_metadata,
                                            event_sender,
                                            image_index,
                                        )
                                    case ImageGenerationResponse():
                                        logger.info("sending final ImageChunk")
                                        _process_image_response(
                                            response,
                                            command_id,
                                            shard_metadata,
                                            event_sender,
                                            image_index,
                                        )
                                        image_index += 1
                    except Exception as e:
                        if _is_primary_output_node(shard_metadata):
                            event_sender.send(
                                ChunkGenerated(
                                    command_id=command_id,
                                    chunk=ErrorChunk(
                                        model=shard_metadata.model_card.model_id,
                                        finish_reason="error",
                                        error_message=str(e),
                                    ),
                                )
                            )
                        raise
                    finally:
                        _send_traces_if_enabled(event_sender, task.task_id, device_rank)
                    current_status = RunnerReady()
                    logger.info("runner ready")

                case Shutdown():
                    current_status = RunnerShuttingDown()
                    logger.info("runner shutting down")
                    if not TYPE_CHECKING:
                        del image_model, group
                        mx.clear_cache()
                        import gc

                        gc.collect()
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    event_sender.send(TaskAcknowledged(task_id=task.task_id))
                    current_status = RunnerShutdown()
                case _:
                    raise ValueError(
                        f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
                    )

            was_cancelled = (task.task_id in cancelled_tasks) or (
                TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
            )
            if not was_cancelled:
                event_sender.send(
                    TaskStatusUpdated(
                        task_id=task.task_id, task_status=TaskStatus.Complete
                    )
                )
            event_sender.send(
                RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
            )
            if isinstance(current_status, RunnerShutdown):
                break

Added: the replacement Runner class; its handle_task method continues unchanged below:

class Runner:
    def __init__(
        self,
        bound_instance: BoundInstance,
        event_sender: MpSender[Event],
        task_receiver: MpReceiver[Task],
        cancel_receiver: MpReceiver[TaskId],
    ):
        self.event_sender = event_sender
        self.task_receiver = task_receiver
        self.cancel_receiver = cancel_receiver
        self.bound_instance = bound_instance

        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))

        self.instance, self.runner_id, self.shard_metadata = (
            bound_instance.instance,
            bound_instance.bound_runner_id,
            bound_instance.bound_shard,
        )
        self.device_rank = self.shard_metadata.device_rank

        logger.info("hello from the runner")
        if getattr(self.shard_metadata, "immediate_exception", False):
            raise Exception("Fake exception - runner failed to spin up.")
        if timeout := getattr(self.shard_metadata, "should_timeout", 0):
            time.sleep(timeout)

        self.setup_start_time = time.time()
        self.cancelled_tasks = set[TaskId]()

        self.image_model: DistributedImageModel | None = None
        self.group = None

        self.current_status: RunnerStatus = RunnerIdle()
        logger.info("runner created")
        self.update_status(RunnerIdle())
        self.seen = set[TaskId]()

    def update_status(self, status: RunnerStatus):
        self.current_status = status
        self.event_sender.send(
            RunnerStatusUpdated(
                runner_id=self.runner_id, runner_status=self.current_status
            )
        )

    def send_task_status(self, task: Task, status: TaskStatus):
        self.event_sender.send(
            TaskStatusUpdated(task_id=task.task_id, task_status=status)
        )

    def acknowledge_task(self, task: Task):
        self.event_sender.send(TaskAcknowledged(task_id=task.task_id))

    def main(self):
        with self.task_receiver as tasks:
            for task in tasks:
                if task.task_id in self.seen:
                    logger.warning("repeat task - potential error")
                self.seen.add(task.task_id)
                self.cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
                self.send_task_status(task, TaskStatus.Running)

                self.handle_task(task)

                was_cancelled = (task.task_id in self.cancelled_tasks) or (
                    TaskId("CANCEL_CURRENT_TASK") in self.cancelled_tasks
                )
                if not was_cancelled:
                    self.send_task_status(task, TaskStatus.Complete)
                self.update_status(self.current_status)
                if isinstance(self.current_status, RunnerShutdown):
                    break
def handle_task(self, task: Task):
match task:
case ConnectToGroup() if isinstance(
self.current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
self.update_status(RunnerConnecting())
self.acknowledge_task(task)
self.group = initialize_mlx(self.bound_instance)
logger.info("runner connected")
self.current_status = RunnerConnected()
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
case LoadModel() if (
isinstance(self.current_status, RunnerConnected)
and self.group is not None
) or (isinstance(self.current_status, RunnerIdle) and self.group is None):
logger.info("runner loading")
self.update_status(RunnerLoading())
self.acknowledge_task(task)
assert (
ModelTask.TextToImage in self.shard_metadata.model_card.tasks
or ModelTask.ImageToImage in self.shard_metadata.model_card.tasks
), f"Incorrect model task(s): {self.shard_metadata.model_card.tasks}"
self.image_model = initialize_image_model(self.bound_instance)
self.current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(self.current_status, RunnerLoaded):
logger.info("runner warming up")
self.update_status(RunnerWarmingUp())
self.acknowledge_task(task)
logger.info(f"warming up inference for instance: {self.instance}")
assert self.image_model
image = warmup_image_generator(model=self.image_model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
logger.info("warmup completed (non-primary node)")
logger.info(
f"runner initialized in {time.time() - self.setup_start_time} seconds"
)
self.current_status = RunnerReady()
logger.info("runner ready")
case ImageGeneration(task_params=task_params, command_id=command_id) if (
isinstance(self.current_status, RunnerReady)
):
assert self.image_model
logger.info(f"received image generation request: {str(task)[:500]}")
logger.info("runner running")
self.update_status(RunnerRunning())
self.acknowledge_task(task)
try:
image_index = 0
for response in generate_image(
model=self.image_model, task=task_params
):
is_primary_output = _is_primary_output_node(self.shard_metadata)
if is_primary_output:
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
self.shard_metadata,
self.event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
self.shard_metadata,
self.event_sender,
image_index,
)
image_index += 1
# can we make this more explicit?
except Exception as e:
if _is_primary_output_node(self.shard_metadata):
self.event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=self.shard_metadata.model_card.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
raise
finally:
_send_traces_if_enabled(
self.event_sender, task.task_id, self.device_rank
)
self.current_status = RunnerReady()
logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(self.current_status, RunnerReady)
):
assert self.image_model
logger.info(f"received image edits request: {str(task)[:500]}")
logger.info("runner running")
self.update_status(RunnerRunning())
self.acknowledge_task(task)
try:
image_index = 0
for response in generate_image(
model=self.image_model, task=task_params
):
if _is_primary_output_node(self.shard_metadata):
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
self.shard_metadata,
self.event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
self.shard_metadata,
self.event_sender,
image_index,
)
image_index += 1
except Exception as e:
if _is_primary_output_node(self.shard_metadata):
self.event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=self.shard_metadata.model_card.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
raise
finally:
_send_traces_if_enabled(
self.event_sender, task.task_id, self.device_rank
)
self.current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
logger.info("runner shutting down")
if not TYPE_CHECKING:
del self.image_model, self.group
mx.clear_cache()
import gc
gc.collect()
self.update_status(RunnerShuttingDown())
self.acknowledge_task(task)
self.current_status = RunnerShutdown()
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {self.current_status=}"
)
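
Editor's note: after this refactor the entrypoint simply constructs the class and calls main(). A compressed lifecycle sketch (illustration only; mp_channel, event_sender and the task constants stand in for the real plumbing in entrypoint() and the test fixtures later in this diff):

# Hypothetical driver: feed the state machine in order, then run it to completion.
task_sender, task_receiver = mp_channel[Task]()
_cancel_sender, cancel_receiver = mp_channel[TaskId]()

runner = Runner(bound_instance, event_sender, task_receiver, cancel_receiver)

with task_sender:
    for task in (INIT_TASK, LOAD_TASK, WARMUP_TASK, IMAGE_TASK, SHUTDOWN_TASK):
        task_sender.send(task)

runner.main()  # returns once a Shutdown task drives the status to RunnerShutdown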

View File

@@ -0,0 +1,293 @@
import itertools
import math
import time
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Generator, Iterable
from dataclasses import dataclass, field
import mlx.core as mx
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.chunks import ErrorChunk, PrefillProgressChunk
from exo.shared.types.common import ModelId
from exo.shared.types.events import ChunkGenerated, Event
from exo.shared.types.mlx import Model
from exo.shared.types.tasks import TaskId, TextGeneration
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import (
PrefillCancelled,
mlx_generate,
warmup_inference,
)
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
mx_all_gather_tasks,
mx_any,
)
from exo.worker.runner.bootstrap import logger
from .model_output_parsers import apply_all_parsers
from .tool_parsers import ToolParser
class Cancelled:
pass
class Finished:
pass
class GeneratorQueue[T]:
def __init__(self):
self._q = deque[T]()
def push(self, t: T):
self._q.append(t)
def gen(self) -> Generator[T | None]:
while True:
if len(self._q) == 0:
yield None
else:
yield self._q.popleft()
class InferenceGenerator(ABC):
@abstractmethod
def warmup(self) -> None: ...
@abstractmethod
def submit(
self,
task: TextGeneration,
) -> None: ...
@abstractmethod
def step(
self,
) -> Iterable[
tuple[TaskId, ToolCallResponse | GenerationResponse | Cancelled | Finished]
]: ...
@abstractmethod
def close(self) -> None: ...
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
"""Check for debug prompt triggers in the input."""
from exo.worker.engines.mlx.utils_mlx import mlx_force_oom
if len(task_params.input) == 0:
return
prompt = task_params.input[0].content
if not prompt:
return
if EXO_RUNNER_MUST_FAIL in prompt:
raise Exception("Artificial runner exception - for testing purposes only.")
if EXO_RUNNER_MUST_OOM in prompt:
mlx_force_oom()
if EXO_RUNNER_MUST_TIMEOUT in prompt:
time.sleep(100)
@dataclass(eq=False)
class SequentialGenerator(InferenceGenerator):
model: Model
tokenizer: TokenizerWrapper
group: mx.distributed.Group | None
kv_prefix_cache: KVPrefixCache | None
tool_parser: ToolParser | None
model_id: ModelId
device_rank: int
cancel_receiver: MpReceiver[TaskId]
event_sender: MpSender[Event]
check_for_cancel_every: int = 50
_cancelled_tasks: set[TaskId] = field(default_factory=set, init=False)
_maybe_queue: list[TextGeneration] = field(default_factory=list, init=False)
_queue: deque[TextGeneration] = field(default_factory=deque, init=False)
_active: (
tuple[
TextGeneration,
# mlx generator that does work
Generator[GenerationResponse],
# queue that the 1st generator should push to and 3rd generator should pull from
GeneratorQueue[GenerationResponse],
# generator to get parsed outputs
Generator[GenerationResponse | ToolCallResponse | None],
]
| None
) = field(default=None, init=False)
def warmup(self):
logger.info(f"warming up inference for instance: {self.model_id}")
t = time.monotonic()
toks = warmup_inference(
model=self.model,
tokenizer=self.tokenizer,
group=self.group,
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
)
if self.group is not None:
self.check_for_cancel_every = int(
mx.max(
mx.distributed.all_gather(
mx.array([check_for_cancel_every]),
group=self.group,
)
).item()
)
logger.info(
f"runner checking for cancellation every {check_for_cancel_every} tokens"
)
def submit(
self,
task: TextGeneration,
) -> None:
self._cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
self._maybe_queue.append(task)
def agree_on_tasks(self) -> None:
"""Agree between all ranks about the task ordering (some may have received in different order or not at all)."""
agreed, different = mx_all_gather_tasks(self._maybe_queue, self.group)
self._queue.extend(task for task in self._maybe_queue if task in agreed)
self._maybe_queue = [task for task in self._maybe_queue if task in different]
def step(
self,
) -> Iterable[
tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]
]:
if self._active is None:
self.agree_on_tasks()
if self._queue:
self._start_next()
else:
return map(lambda task: (task, Cancelled()), self._cancelled_tasks)
assert self._active is not None
task, mlx_gen, queue, output_generator = self._active
response = None
try:
queue.push(next(mlx_gen))
response = next(output_generator)
except (StopIteration, PrefillCancelled):
response = Finished()
self._active = None
if self._queue:
self._start_next()
except Exception as e:
self._send_error(task, e)
self._active = None
raise
return itertools.chain(
[] if response is None else [(task.task_id, response)],
map(lambda task: (task, Cancelled()), self._cancelled_tasks),
)
def _start_next(self) -> None:
task = self._queue.popleft()
try:
mlx_gen = self._build_generator(task)
except Exception as e:
self._send_error(task, e)
raise
queue = GeneratorQueue[GenerationResponse]()
output_generator = apply_all_parsers(
queue.gen(),
apply_chat_template(self.tokenizer, task.task_params),
self.tool_parser,
self.tokenizer,
type(self.model),
self.model_id,
)
self._active = (task, mlx_gen, queue, output_generator)
def _send_error(self, task: TextGeneration, e: Exception) -> None:
if self.device_rank == 0:
self.event_sender.send(
ChunkGenerated(
command_id=task.command_id,
chunk=ErrorChunk(
model=self.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
def _build_generator(self, task: TextGeneration) -> Generator[GenerationResponse]:
_check_for_debug_prompts(task.task_params)
prompt = apply_chat_template(self.tokenizer, task.task_params)
def on_prefill_progress(processed: int, total: int) -> None:
if self.device_rank == 0:
self.event_sender.send(
ChunkGenerated(
command_id=task.command_id,
chunk=PrefillProgressChunk(
model=self.model_id,
processed_tokens=processed,
total_tokens=total,
),
)
)
def distributed_prompt_progress_callback() -> None:
self._cancelled_tasks.update(self.cancel_receiver.collect())
want_to_cancel = (task.task_id in self._cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in self._cancelled_tasks
)
if mx_any(want_to_cancel, self.group):
raise PrefillCancelled()
self.agree_on_tasks()
tokens_since_cancel_check = self.check_for_cancel_every
def on_generation_token() -> None:
nonlocal tokens_since_cancel_check
tokens_since_cancel_check += 1
if tokens_since_cancel_check >= self.check_for_cancel_every:
tokens_since_cancel_check = 0
self._cancelled_tasks.update(self.cancel_receiver.collect())
want_to_cancel = (task.task_id in self._cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in self._cancelled_tasks
)
if mx_any(want_to_cancel, self.group):
raise PrefillCancelled()
self.agree_on_tasks()
return mlx_generate(
model=self.model,
tokenizer=self.tokenizer,
task=task.task_params,
prompt=prompt,
kv_prefix_cache=self.kv_prefix_cache,
on_prefill_progress=on_prefill_progress,
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
on_generation_token=on_generation_token,
group=self.group,
)
def close(self) -> None:
del self.model, self.tokenizer, self.group
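
Editor's note: the runner (its own diff is suppressed below) is expected to drive this interface with a submit/step loop. A rough sketch of that loop, assuming gen is an already-constructed SequentialGenerator and pending is a list of TextGeneration tasks (names are illustrative, not code from this diff):

gen.warmup()

for task in pending:
    gen.submit(task)

done: set[TaskId] = set()
while len(done) < len(pending):
    for task_id, item in gen.step():
        match item:
            case Finished() | Cancelled():
                done.add(task_id)
            case ToolCallResponse() as tc:
                print("tool call:", tc.tool_calls)
            case GenerationResponse() as r:
                print(r.text, end="")

gen.close()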

View File

@@ -0,0 +1,376 @@
from collections.abc import Generator
from functools import cache
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
HarmonyError, # pyright: ignore[reportUnknownVariableType]
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.types.api import ToolCallItem
from exo.shared.types.common import ModelId
from exo.shared.types.mlx import Model
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.engines.mlx.utils_mlx import (
detect_thinking_prompt_suffix,
)
from exo.worker.runner.bootstrap import logger
from .tool_parsers import ToolParser
@cache
def get_gpt_oss_encoding():
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return encoding
def apply_all_parsers(
receiver: Generator[GenerationResponse | None],
prompt: str,
tool_parser: ToolParser | None,
tokenizer: TokenizerWrapper,
model_type: type[Model],
model_id: ModelId,
) -> Generator[GenerationResponse | ToolCallResponse | None]:
mlx_generator = receiver
if tokenizer.has_thinking:
mlx_generator = parse_thinking_models(
mlx_generator,
tokenizer,
starts_in_thinking=detect_thinking_prompt_suffix(prompt, tokenizer),
)
if issubclass(model_type, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
elif (
issubclass(model_type, DeepseekV32Model)
and "deepseek" in model_id.normalize().lower()
):
mlx_generator = parse_deepseek_v32(mlx_generator)
elif tool_parser:
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
return mlx_generator
def parse_gpt_oss(
responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
current_tool_name: str | None = None
tool_arg_parts: list[str] = []
for response in responses:
if response is None:
yield None
continue
try:
stream.process(response.token)
except HarmonyError:
logger.error("Encountered critical Harmony Error, returning early")
return
delta = stream.last_content_delta
ch = stream.current_channel
recipient = stream.current_recipient
# Debug: log every token with state
logger.debug(
f"parse_gpt_oss token={response.token} text={response.text!r} "
f"recipient={recipient!r} ch={ch!r} delta={delta!r} "
f"state={stream.state} current_tool={current_tool_name!r}"
)
if recipient != current_tool_name:
if current_tool_name is not None:
prefix = "functions."
if current_tool_name.startswith(prefix):
current_tool_name = current_tool_name[len(prefix) :]
logger.info(
f"parse_gpt_oss yielding tool call: name={current_tool_name!r}"
)
yield ToolCallResponse(
tool_calls=[
ToolCallItem(
name=current_tool_name,
arguments="".join(tool_arg_parts).strip(),
)
],
usage=response.usage,
)
tool_arg_parts = []
current_tool_name = recipient
# If inside a tool call, accumulate arguments
if current_tool_name is not None:
if delta:
tool_arg_parts.append(delta)
continue
if ch == "analysis" and not thinking:
thinking = True
if ch != "analysis" and thinking:
thinking = False
if delta:
yield response.model_copy(update={"text": delta, "is_thinking": thinking})
if response.finish_reason is not None:
yield response
def parse_deepseek_v32(
responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
"""Parse DeepSeek V3.2 DSML tool calls from the generation stream.
Uses accumulated-text matching (not per-token marker checks) because
DSML markers like <DSMLfunction_calls> may span multiple tokens.
Also handles <think>...</think> blocks for thinking mode.
"""
from exo.worker.engines.mlx.dsml_encoding import (
THINKING_END,
THINKING_START,
TOOL_CALLS_END,
TOOL_CALLS_START,
parse_dsml_output,
)
accumulated = ""
in_tool_call = False
thinking = False
# Tokens buffered while we detect the start of a DSML block
pending_buffer: list[GenerationResponse] = []
# Text accumulated during a tool call block
tool_call_text = ""
for response in responses:
if response is None:
yield None
continue
# ── Handle thinking tags ──
if not thinking and THINKING_START in response.text:
thinking = True
# Yield any text before the <think> tag
before = response.text[: response.text.index(THINKING_START)]
if before:
yield response.model_copy(update={"text": before})
continue
if thinking and THINKING_END in response.text:
thinking = False
# Yield any text after the </think> tag
after = response.text[
response.text.index(THINKING_END) + len(THINKING_END) :
]
if after:
yield response.model_copy(update={"text": after, "is_thinking": False})
continue
if thinking:
yield response.model_copy(update={"is_thinking": True})
continue
# ── Handle tool call accumulation ──
if in_tool_call:
tool_call_text += response.text
if TOOL_CALLS_END in tool_call_text:
# Parse the accumulated DSML block
parsed = parse_dsml_output(tool_call_text)
if parsed is not None:
logger.info(f"parsed DSML tool calls: {parsed}")
yield ToolCallResponse(
tool_calls=parsed,
usage=response.usage,
stats=response.stats,
)
else:
logger.warning(
f"DSML tool call parsing failed for: {tool_call_text}"
)
yield response.model_copy(update={"text": tool_call_text})
in_tool_call = False
tool_call_text = ""
continue
# EOS reached before end marker — yield buffered text as-is
if response.finish_reason is not None:
logger.info("DSML tool call parsing interrupted by EOS")
yield response.model_copy(update={"text": tool_call_text})
in_tool_call = False
tool_call_text = ""
continue
# ── Detect start of tool call block ──
accumulated += response.text
if TOOL_CALLS_START in accumulated:
# The start marker might be split across pending_buffer + current token
start_idx = accumulated.index(TOOL_CALLS_START)
# Yield any pending tokens that are purely before the marker
pre_text = accumulated[:start_idx]
if pre_text:
# Flush pending buffer tokens that contributed text before the marker
for buf_resp in pending_buffer:
if pre_text:
chunk = buf_resp.text
if len(chunk) <= len(pre_text):
yield buf_resp
pre_text = pre_text[len(chunk) :]
else:
yield buf_resp.model_copy(update={"text": pre_text})
pre_text = ""
pending_buffer = []
tool_call_text = accumulated[start_idx:]
accumulated = ""
# Check if the end marker is already present (entire tool call in one token)
if TOOL_CALLS_END in tool_call_text:
parsed = parse_dsml_output(tool_call_text)
if parsed is not None:
logger.info(f"parsed DSML tool calls: {parsed}")
yield ToolCallResponse(
tool_calls=parsed,
usage=response.usage,
stats=response.stats,
)
else:
logger.warning(
f"DSML tool call parsing failed for: {tool_call_text}"
)
yield response.model_copy(update={"text": tool_call_text})
tool_call_text = ""
else:
in_tool_call = True
continue
# Check if accumulated text might be the start of a DSML marker
# Buffer tokens if we see a partial match at the end
if _could_be_dsml_prefix(accumulated):
pending_buffer.append(response)
continue
# No partial match — flush all pending tokens and the current one
for buf_resp in pending_buffer:
yield buf_resp
pending_buffer = []
accumulated = ""
yield response
# Flush any remaining pending buffer at generator end
for buf_resp in pending_buffer:
yield buf_resp
def _could_be_dsml_prefix(text: str) -> bool:
"""Check if the end of text could be the start of a DSML function_calls marker.
We look for suffixes of text that are prefixes of the TOOL_CALLS_START pattern.
This allows us to buffer tokens until we can determine if a tool call is starting.
"""
from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START
# Only check the last portion of text that could overlap with the marker
max_check = len(TOOL_CALLS_START)
tail = text[-max_check:] if len(text) > max_check else text
# Check if any suffix of tail is a prefix of TOOL_CALLS_START
for i in range(len(tail)):
suffix = tail[i:]
if TOOL_CALLS_START.startswith(suffix):
return True
return False
def parse_thinking_models(
responses: Generator[GenerationResponse | None],
tokenizer: TokenizerWrapper,
starts_in_thinking: bool = True,
) -> Generator[GenerationResponse | None]:
"""Route thinking tokens via is_thinking flag.
Swallows think tag tokens, sets is_thinking on all others.
Always yields tokens with finish_reason to avoid hanging the chunk stream.
"""
in_thinking = starts_in_thinking
for response in responses:
if response is None:
yield None
continue
is_think_tag = (
tokenizer.think_end is not None and response.text == tokenizer.think_end
) or (
tokenizer.think_start is not None and response.text == tokenizer.think_start
)
if is_think_tag:
in_thinking = response.text != tokenizer.think_end
# Never swallow finish_reason — the chunk stream needs it to terminate.
if response.finish_reason is not None:
yield response.model_copy(update={"text": "", "is_thinking": False})
continue
yield response.model_copy(update={"is_thinking": in_thinking})
def parse_tool_calls(
responses: Generator[GenerationResponse | None], tool_parser: ToolParser
) -> Generator[GenerationResponse | ToolCallResponse | None]:
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
if response is None:
yield None
continue
if not in_tool_call and response.text.startswith(tool_parser.start_parsing):
in_tool_call = True
if in_tool_call:
tool_call_text_parts.append(response.text)
if response.text.endswith(tool_parser.end_parsing):
# parse the actual tool calls from the tool call text
parsed = tool_parser.parse_tool_calls(
"".join(tool_call_text_parts).strip()
)
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if parsed is not None:
yield ToolCallResponse(
tool_calls=parsed, usage=response.usage, stats=response.stats
)
else:
logger.warning(
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
)
response.text = "".join(tool_call_text_parts)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if response.finish_reason is not None:
logger.info(
"tool call parsing interrupted, yield partial tool call as text"
)
response = response.model_copy(
update={
"text": "".join(tool_call_text_parts),
"token": 0,
}
)
yield response
else:
# fallthrough
yield response
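
Editor's note: a quick, self-contained illustration of how one of these parsers behaves over a canned stream. The stub tokenizer and the GenerationResponse construction follow the patterns used in the tests below; this is a sketch, not part of the diff:

# Stub with only the two attributes parse_thinking_models reads.
class _StubTokenizer:
    think_start = "<think>"
    think_end = "</think>"

def _gen():
    for text in ["<think>", "planning", "</think>", "answer"]:
        yield GenerationResponse(token=0, text=text, finish_reason=None, usage=None)
    yield GenerationResponse(token=0, text="", finish_reason="stop", usage=None)

# Think-tag tokens are swallowed; other tokens carry an is_thinking flag.
for out in parse_thinking_models(_gen(), _StubTokenizer(), starts_in_thinking=False):
    if out is not None:
        print(out.is_thinking, repr(out.text))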

View File

File diff suppressed because it is too large.

View File

@@ -52,6 +52,7 @@ class RunnerSupervisor:
     _tg: TaskGroup = field(default_factory=TaskGroup, init=False)
     status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
     pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
+    in_progress: set[TaskId] = field(default_factory=set, init=False)
     completed: set[TaskId] = field(default_factory=set, init=False)
     cancelled: set[TaskId] = field(default_factory=set, init=False)
     _cancel_watch_runner: anyio.CancelScope = field(
@@ -157,6 +158,7 @@ class RunnerSupervisor:
     async def cancel_task(self, task_id: TaskId):
         if task_id in self.completed:
             logger.info(f"Unable to cancel {task_id} as it has been completed")
+            self.cancelled.add(task_id)
             return
         self.cancelled.add(task_id)
         with anyio.move_on_after(0.5) as scope:
@@ -173,6 +175,7 @@ class RunnerSupervisor:
                 self.status = event.runner_status
             if isinstance(event, TaskAcknowledged):
                 self.pending.pop(event.task_id).set()
+                self.in_progress.add(event.task_id)
                 continue
             if (
                 isinstance(event, TaskStatusUpdated)
@@ -189,6 +192,7 @@ class RunnerSupervisor:
                         RunnerShuttingDown,
                     ),
                 )
+                self.in_progress.discard(event.task_id)
                 self.completed.add(event.task_id)
             await self._event_sender.send(event)
     except (ClosedResourceError, BrokenResourceError) as e:

View File

@@ -20,6 +20,8 @@ class FakeRunnerSupervisor:
     bound_instance: BoundInstance
     status: RunnerStatus
     completed: set[TaskId] = field(default_factory=set)
+    in_progress: set[TaskId] = field(default_factory=set)
+    pending: dict[TaskId, object] = field(default_factory=dict)
 
 
 class OtherTask(BaseTask):

View File

@@ -19,7 +19,7 @@ from exo.worker.engines.mlx.dsml_encoding import (
     encode_messages,
     parse_dsml_output,
 )
-from exo.worker.runner.llm_inference.runner import parse_deepseek_v32
+from exo.worker.runner.llm_inference.model_output_parsers import parse_deepseek_v32
 
 # ── Shared fixtures ──────────────────────────────────────────────

View File

@@ -6,6 +6,8 @@ from typing import Callable
 import mlx.core as mx
 import pytest
 
+import exo.worker.runner.llm_inference.batch_generator as mlx_batch_generator
+import exo.worker.runner.llm_inference.model_output_parsers as mlx_model_output_parsers
 import exo.worker.runner.llm_inference.runner as mlx_runner
 from exo.shared.types.chunks import TokenChunk
 from exo.shared.types.events import (
@@ -114,27 +116,41 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
     # initialize_mlx returns a mock group
     monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
     monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
-    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
-    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
-    monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
+    monkeypatch.setattr(mlx_batch_generator, "warmup_inference", make_nothin(1))
+    monkeypatch.setattr(mlx_batch_generator, "_check_for_debug_prompts", nothin)
+    monkeypatch.setattr(mlx_batch_generator, "mx_any", make_nothin(False))
+
+    def fake_all_gather(
+        tasks: list[TextGeneration], group: object
+    ) -> tuple[list[TextGeneration], list[TextGeneration]]:
+        return (tasks, [])
+
+    monkeypatch.setattr(mlx_batch_generator, "mx_all_gather_tasks", fake_all_gather)
 
     # Mock apply_chat_template since we're using a fake tokenizer (integer 1).
     # Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
-    monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
-    monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
+    monkeypatch.setattr(
+        mlx_batch_generator, "apply_chat_template", make_nothin("test prompt")
+    )
+    monkeypatch.setattr(
+        mlx_model_output_parsers, "detect_thinking_prompt_suffix", make_nothin(False)
+    )
 
     def fake_generate(*_1: object, **_2: object):
         yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
 
-    monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
+    monkeypatch.setattr(mlx_batch_generator, "mlx_generate", fake_generate)
 
 # Use a fake event_sender to remove test flakiness.
 class EventCollector:
-    def __init__(self) -> None:
+    def __init__(self, on_event: Callable[[Event], None] | None = None) -> None:
         self.events: list[Event] = []
+        self._on_event = on_event
 
     def send(self, event: Event) -> None:
         self.events.append(event)
+        if self._on_event:
+            self._on_event(event)
 
     def close(self) -> None:
         pass
@@ -159,7 +175,7 @@ class MockGroup:
         return 1
 
 
-def _run(tasks: Iterable[Task]):
+def _run(tasks: Iterable[Task], send_after_ready: list[Task] | None = None):
     bound_instance = get_bound_mlx_ring_instance(
         instance_id=INSTANCE_1_ID,
         model_id=MODEL_A_ID,
@@ -169,7 +185,23 @@ def _run(tasks: Iterable[Task]):
     task_sender, task_receiver = mp_channel[Task]()
     _cancel_sender, cancel_receiver = mp_channel[TaskId]()
 
-    event_sender = EventCollector()
+    on_event: Callable[[Event], None] | None = None
+    if send_after_ready:
+        _saw_running = False
+
+        def _on_event(event: Event) -> None:
+            nonlocal _saw_running
+            if isinstance(event, RunnerStatusUpdated):
+                if isinstance(event.runner_status, RunnerRunning):
+                    _saw_running = True
+                elif _saw_running and isinstance(event.runner_status, RunnerReady):
+                    for t in send_after_ready:
+                        task_sender.send(t)
+
+        on_event = _on_event
+
+    event_sender = EventCollector(on_event=on_event)
 
     with task_sender:
         for t in tasks:
@@ -183,18 +215,22 @@ def _run(tasks: Iterable[Task]):
         "exo.worker.runner.llm_inference.runner.mx.distributed.all_gather",
         make_nothin(mx.array([1])),
     ):
-        mlx_runner.main(
+        runner = mlx_runner.Runner(
             bound_instance,
             event_sender,  # pyright: ignore[reportArgumentType]
             task_receiver,
             cancel_receiver,
         )
+        runner.main()
 
     return event_sender.events
 
 
 def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
-    events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
+    events = _run(
+        [INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK],
+        send_after_ready=[SHUTDOWN_TASK],
+    )
 
     expected_chunk = ChunkGenerated(
         command_id=COMMAND_1_ID,
View File

@@ -4,7 +4,7 @@ from exo.shared.types.worker.runner_response import (
     GenerationResponse,
     ToolCallResponse,
 )
-from exo.worker.runner.llm_inference.runner import parse_gpt_oss
+from exo.worker.runner.llm_inference.model_output_parsers import parse_gpt_oss
 
 # Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
 # These are stable since they come from the model's vocabulary.
@@ -107,7 +107,7 @@ def _collect(
     def _gen() -> Generator[GenerationResponse, None, None]:
         yield from _make_gen_responses(tokens)
 
-    return list(parse_gpt_oss(_gen()))
+    return list(x for x in parse_gpt_oss(_gen()) if x is not None)
 
 
 def _get_tool_call(

View File

@@ -4,7 +4,7 @@ from collections.abc import Generator
 from typing import Any
 
 from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
-from exo.worker.runner.llm_inference.runner import parse_tool_calls
+from exo.worker.runner.llm_inference.model_output_parsers import parse_tool_calls
 from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser