Compare commits

...

17 Commits

Author SHA1 Message Date
Alex Cheema
7469f44e58 fix: clean up stale runners from state when instance is deleted
apply_instance_deleted() previously only removed the instance from
state.instances, leaving its runner entries orphaned in state.runners
with their last known status (e.g. RunnerReady). After a node kill and
rejoin, readiness checks would see these stale entries and attempt
inference against dead runner processes, causing post-recovery failures.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 18:41:22 -08:00
Alex Cheema
5d26d2dcd6 fix: eliminate serialization bottlenecks in continuous batching pipeline
The batch prefill deferral (35973b86) was insufficient alone because
multiple other serialization points prevented true concurrent request
processing. This fixes the following bottlenecks:

- Drain all available tasks per plan_step cycle instead of one-per-100ms
- Keep tasks in supervisor pending set after ACK to prevent re-dispatch
- Process TextGeneration tasks inline during decode loop (no break/restart)
- Pre-tokenize in queue_request so sync_and_insert_pending is lightweight
- Reuse prompt from queue_request for thinking detection (no double apply_chat_template)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 18:36:56 -08:00
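A minimal sketch of the first change above, with `plan` and `dispatch` as generic stand-ins rather than exo's real signatures (the actual change appears in the `Worker.plan_step` diff further down): one wake-up now drains every plannable task instead of dispatching at most one per 100 ms tick.

```python
import asyncio
from typing import Callable, Optional

async def plan_loop(plan: Callable[[], Optional[str]],
                    dispatch: Callable[[str], None]) -> None:
    """Toy planner loop: drain all available work per wake-up."""
    while True:
        await asyncio.sleep(0.1)
        # Keep planning until nothing is left, so concurrent arrivals
        # are dispatched in the same cycle rather than one per tick.
        while (task := plan()) is not None:
            dispatch(task)
```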
Alex Cheema
35973b8698 fix: defer batch prefill for true continuous batching
Move sync_and_insert_pending() out of the per-task loop so all
concurrently-arrived requests share a single batched prefill pass.
Previously each request was prefilled individually as it arrived,
serializing what should be a parallel operation.

Also break early from the generation TimeBudget loop when new tasks
are waiting, so they get inserted sooner rather than blocking for
the full 0.5s budget.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:52:04 -08:00
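A hedged sketch of the shape of this change, using a simplified stand-in for the engine interface; the real `queue_request` takes command/task IDs and `TextGenerationTaskParams`, as the `batch_engine.py` diff below shows.

```python
from typing import Iterable, Protocol

class Engine(Protocol):
    """Simplified stand-in for the batch engine interface on this branch."""
    def queue_request(self, prompt: str) -> None: ...
    def sync_and_insert_pending(self) -> list[int]: ...

def insert_batched(engine: Engine, prompts: Iterable[str]) -> list[int]:
    """Queue every concurrently-arrived prompt, then insert once.

    The pre-fix shape called sync_and_insert_pending() inside the loop,
    prefilling each prompt individually; hoisting it out lets all queued
    prompts share a single batched prefill pass.
    """
    for prompt in prompts:
        engine.queue_request(prompt)
    return engine.sync_and_insert_pending()
```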
Alex Cheema
41d9d2a61f fix: add has_thinking to mock tokenizers in edge case tests
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:02:07 -08:00
Alex Cheema
efbf9850eb style: apply nix fmt to new edge case test file
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 16:55:38 -08:00
Alex Cheema
4a22c4b512 fix: restore per-request sampling, model-specific parsers, error handling, and tracing
- Use per-request temperature/top_p/top_k from TextGenerationTaskParams
  instead of hardcoded sampler defaults in BatchGenerationEngine
- Restore model-specific tokenizer patches (Kimi, GLM) at load time
- Add GptOssTracker for per-request GPT-OSS stream parsing with
  thinking channels and tool call routing
- Filter Kimi section boundary tokens from batch output
- Detect thinking prompt suffix and prepend think_start on first token
- Wrap batch_engine.step() in try/except, send ErrorChunks on failure
- Call _send_traces_if_enabled() when text generation tasks complete

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 16:55:18 -08:00
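For the first bullet, a small illustrative mapping of per-request parameters onto mlx_lm's `make_sampler`, with the same fallback defaults the engine previously hardcoded; note that in the `BatchGenerationEngine` diff below the last queued request's parameters currently win for the whole batch.

```python
from mlx_lm.sample_utils import make_sampler

def sampler_for(temperature: float | None, top_p: float | None, top_k: int | None):
    """Build a sampler from per-request params, falling back to the
    previously hardcoded defaults (temp=0.7, top_p=1.0, top_k off)."""
    return make_sampler(
        temp=temperature if temperature is not None else 0.7,
        top_p=top_p if top_p is not None else 1.0,
        top_k=top_k if top_k is not None else 0,
    )
```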
Alex Cheema
4d0fe5d17b test: add edge-case tests for continuous batching
Cover concurrent tool calls, length/stop finish reasons, multiple
completions per step, staggered draining, and batches of 5-10 requests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 16:27:45 -08:00
Alex Cheema
9fe7251796 fix: add generation loop, deferred task completion, and tool call tracking
The batch engine integration had three critical issues:
1. No generation loop - batch_engine.step() only ran during shutdown drain
2. Tasks marked complete before any tokens were generated
3. Tool calls dropped - parse_tool_calls pipeline was disconnected

Restructure runner main() into a two-phase while loop that alternates
between TimeBudget-based generation steps and task polling. Add
ToolCallTracker for per-request tool call state in the batch path,
and defer TaskStatusUpdated(Complete) until finish_reason is set.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:22:09 -08:00
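Since the `runner.py` diff is suppressed below as too large, here is a rough, non-authoritative sketch of the two-phase shape this commit describes. A plain timer stands in for the TimeBudget iterator added on this branch, and `engine` is assumed to expose the batch-engine methods shown in the `batch_engine.py` diff.

```python
import time
from typing import Callable, Optional

def runner_loop(engine, poll_task: Callable[[], Optional[object]],
                handle_task: Callable[[object], None], budget: float = 0.5) -> None:
    """Toy two-phase loop: time-budgeted decode steps, then task polling."""
    while True:
        # Phase 1: decode under a time budget while requests are active,
        # streaming tokens as they are produced. Task completion is
        # deferred until a response carries a finish_reason.
        start = time.perf_counter()
        while engine.has_active_requests and time.perf_counter() - start < budget:
            for tagged in engine.step():
                ...  # emit ChunkGenerated; mark Complete on finish_reason
        engine.sync_completions()  # sync/remove finished requests at the boundary

        # Phase 2: poll for new tasks, then insert queued requests as one batch.
        task = poll_task()
        if task is not None:
            handle_task(task)  # TextGeneration tasks call engine.queue_request(...)
        if engine.has_pending_inserts:
            engine.sync_and_insert_pending()
```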
Alex Cheema
1c8f69ce00 fix: address PR review comments for continuous batching
- Use get_args(FinishReason) instead of hardcoded finish reason checks
- Use new Python type syntax (def share_object[T]) instead of TypeVar
- Assert obj is not None for rank 0 with message to use mx_barrier()
- Raise RuntimeError on size=0 instead of silently returning None
- Simplify share_object callers since return type is now T (not T | None)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 14:29:13 -08:00
Alex Cheema
f19166617a fix: use EventCollector instead of mp_channel to fix flaky test
Replace mp_channel event receiver with direct EventCollector in
test_continuous_batching.py to eliminate a multiprocessing pipe race
condition that caused test_runner_status_reflects_active_requests
to intermittently miss RunnerRunning events.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 14:00:28 -08:00
Alex Cheema
51e959c979 feat: integrate BatchGenerationEngine into runner for continuous batching
Replace synchronous per-request text generation with BatchGenerationEngine,
enabling continuous batching of multiple concurrent inference requests.

- Runner accepts TextGeneration in both RunnerReady and RunnerRunning states
- Requests are queued and sync-inserted into the batch engine
- Batch engine is drained during shutdown to complete in-flight requests
- Only rank 0 emits ChunkGenerated events (distributed-safe)
- Enable previously-skipped continuous batching tests
- Update event ordering tests for the new batch-based flow

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 13:47:04 -08:00
Alex Cheema
cd43588a04 Merge remote-tracking branch 'origin/main' into alexcheema/continuous-batching 2026-02-13 10:09:40 -08:00
Alex Cheema
7b879593bb fix: update continuous batching types after main merge
Replace ChatCompletionTaskParams with TextGenerationTaskParams and
ChatCompletion with TextGeneration to match the refactored type
hierarchy from main. Add missing usage parameter to GenerationResponse
constructors and add type annotations to StreamingDetokenizer stubs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 07:29:13 -08:00
Alex Cheema
e4e895d7a8 Merge remote-tracking branch 'origin/main' into alexcheema/continuous-batching
# Conflicts:
#	AGENTS.md
2026-02-13 05:57:15 -08:00
Alex Cheema
db400dbb75 skip continuous batching tests pending type migration
The continuous batching runner architecture references old types
(ChatCompletion, ChatCompletionTaskParams) that were renamed on main.
Skip the test module until the batch engine code is updated.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 06:18:40 -08:00
Alex Cheema
15fad9c632 Merge remote-tracking branch 'origin/main' into alexcheema/continuous-batching
# Conflicts:
#	.mlx_typings/mlx_lm/tokenizer_utils.pyi
#	src/exo/worker/runner/runner.py
#	src/exo/worker/runner/runner_supervisor.py
#	src/exo/worker/tests/unittests/test_runner/test_event_ordering.py
2026-02-05 06:12:49 -08:00
Alex Cheema
842beefac0 feat: add continuous batching for distributed inference
Implements continuous batching using mlx_lm's BatchGenerator for efficient
multi-request handling in distributed mode.

Key changes:
- Add BatchGenerationEngine that wraps mlx_lm's BatchGenerator for continuous
  batching with prefill batching (up to 8 requests) and decode batching
- Add TimeBudget pattern for controlling generation loop timing with periodic
  distributed sync
- Add distributed_sync utilities for broadcasting objects across ranks using
  mx.distributed.all_sum()
- Stream tokens immediately as generated for smooth streaming (not in batches)
- Fix distributed correctness: deferred shutdown handling, sync_completions
  always syncs in distributed mode to prevent deadlocks

Performance results on Kimi K2 Thinking (658GB) with Tensor RDMA:
- Batch 1:  10.7 tok/s (baseline)
- Batch 4:  34.6 tok/s (3.2x speedup)
- Batch 16: 41.8 tok/s (3.9x speedup)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 12:27:15 +00:00
18 changed files with 2892 additions and 482 deletions


@@ -276,24 +276,23 @@ class BatchGenerator:
logprobs: mx.array
finish_reason: Optional[str]
unprocessed_prompts: List[Any]
def __init__(
self,
model,
model: nn.Module,
max_tokens: int = ...,
stop_tokens: Optional[set] = ...,
stop_tokens: Optional[set[int]] = ...,
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
completion_batch_size: int = ...,
prefill_batch_size: int = ...,
prefill_step_size: int = ...,
) -> None: ...
def insert(
self, prompts, max_tokens: Union[List[int], int, None] = ...
): # -> list[Any]:
...
def stats(self): # -> BatchStats:
...
def next(self): # -> list[Any]:
...
self, prompts: List[List[int]], max_tokens: Union[List[int], int, None] = ...
) -> List[int]: ...
def stats(self) -> BatchStats: ...
def next(self) -> List[Response]: ...
def batch_generate(
model,


@@ -39,11 +39,11 @@ class StreamingDetokenizer:
"""
__slots__ = ...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
def reset(self) -> None: ...
def add_token(self, token: int) -> None: ...
def finalize(self) -> None: ...
@property
def last_segment(self):
def last_segment(self) -> str:
"""Return the last segment of readable text since last time this property was accessed."""
class NaiveStreamingDetokenizer(StreamingDetokenizer):


@@ -116,10 +116,49 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## Model Storage
Downloaded models are stored in `~/.exo/models/` (not the standard HuggingFace cache location).
## Creating Model Instances via API
When testing with the API, you must first create a model instance before sending chat completions:
```bash
# 1. Get instance previews for a model
curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b"
# 2. Create an instance from the first valid preview
INSTANCE=$(curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | jq -c '.previews[] | select(.error == null) | .instance' | head -n1)
curl -X POST http://localhost:52415/instance -H 'Content-Type: application/json' -d "{\"instance\": $INSTANCE}"
# 3. Wait for the runner to become ready (check logs for "runner ready")
# 4. Send chat completions using the full model ID
curl -X POST http://localhost:52415/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"model": "mlx-community/Llama-3.2-1B-Instruct-4bit", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}'
```
## Logs
Exo logs are stored in `~/.exo/exo.log`. This is useful for debugging runner crashes and distributed issues.
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
### Distributed Testing
When running distributed tests across multiple machines, use `EXO_LIBP2P_NAMESPACE` to isolate your test cluster from other exo instances on the same network:
```bash
# On each machine in the test cluster, use the same unique namespace
EXO_LIBP2P_NAMESPACE=my-test-cluster uv run exo
```
This prevents your test cluster from discovering and interfering with production or other developers' exo clusters.
## Dashboard UI Testing & Screenshots
### Building and Running the Dashboard

conftest.py Normal file

@@ -0,0 +1 @@
collect_ignore = ["tests/start_distributed_test.py"]


@@ -184,10 +184,19 @@ def apply_instance_created(event: InstanceCreated, state: State) -> State:
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
deleted_instance = state.instances.get(event.instance_id)
new_instances: Mapping[InstanceId, Instance] = {
iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
}
return state.model_copy(update={"instances": new_instances})
runner_ids_to_remove: set[RunnerId] = set()
if deleted_instance is not None:
runner_ids_to_remove = set(
deleted_instance.shard_assignments.runner_to_shard.keys()
)
new_runners: Mapping[RunnerId, RunnerStatus] = {
rid: rs for rid, rs in state.runners.items() if rid not in runner_ids_to_remove
}
return state.model_copy(update={"instances": new_instances, "runners": new_runners})
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:


@@ -0,0 +1,142 @@
from exo.shared.apply import apply_instance_deleted
from exo.shared.models.model_cards import ModelId
from exo.shared.tests.conftest import get_pipeline_shard_metadata
from exo.shared.types.common import NodeId
from exo.shared.types.events import InstanceDeleted
from exo.shared.types.state import State
from exo.shared.types.worker.instances import InstanceId, MlxRingInstance
from exo.shared.types.worker.runners import (
RunnerId,
RunnerReady,
ShardAssignments,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.tests.constants import (
INSTANCE_1_ID,
INSTANCE_2_ID,
MODEL_A_ID,
MODEL_B_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
)
def _make_instance(
instance_id: InstanceId,
model_id: ModelId,
node_to_runner: dict[NodeId, RunnerId],
runner_to_shard: dict[RunnerId, ShardMetadata],
) -> MlxRingInstance:
return MlxRingInstance(
instance_id=instance_id,
shard_assignments=ShardAssignments(
model_id=model_id,
node_to_runner=node_to_runner,
runner_to_shard=runner_to_shard,
),
hosts_by_node={},
ephemeral_port=50000,
)
def test_instance_deleted_removes_runners():
"""Deleting an instance must also remove its runner entries from state."""
shard = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0)
instance = _make_instance(
INSTANCE_1_ID,
MODEL_A_ID,
{NODE_A: RUNNER_1_ID},
{RUNNER_1_ID: shard},
)
state = State(
instances={INSTANCE_1_ID: instance},
runners={RUNNER_1_ID: RunnerReady()},
)
new_state = apply_instance_deleted(
InstanceDeleted(instance_id=INSTANCE_1_ID), state
)
assert INSTANCE_1_ID not in new_state.instances
assert RUNNER_1_ID not in new_state.runners
def test_instance_deleted_removes_only_its_runners():
"""Deleting one instance must not remove runners belonging to another."""
shard_a = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0)
shard_b = get_pipeline_shard_metadata(MODEL_B_ID, device_rank=0)
instance_1 = _make_instance(
INSTANCE_1_ID,
MODEL_A_ID,
{NODE_A: RUNNER_1_ID},
{RUNNER_1_ID: shard_a},
)
instance_2 = _make_instance(
INSTANCE_2_ID,
MODEL_B_ID,
{NODE_B: RUNNER_2_ID},
{RUNNER_2_ID: shard_b},
)
state = State(
instances={INSTANCE_1_ID: instance_1, INSTANCE_2_ID: instance_2},
runners={RUNNER_1_ID: RunnerReady(), RUNNER_2_ID: RunnerReady()},
)
new_state = apply_instance_deleted(
InstanceDeleted(instance_id=INSTANCE_1_ID), state
)
assert INSTANCE_1_ID not in new_state.instances
assert RUNNER_1_ID not in new_state.runners
# Instance 2 and its runner must remain
assert INSTANCE_2_ID in new_state.instances
assert RUNNER_2_ID in new_state.runners
def test_instance_deleted_multi_node_removes_all_runners():
"""Deleting a multi-node instance removes all of its runners."""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = _make_instance(
INSTANCE_1_ID,
MODEL_A_ID,
{NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
{RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
)
state = State(
instances={INSTANCE_1_ID: instance},
runners={RUNNER_1_ID: RunnerReady(), RUNNER_2_ID: RunnerReady()},
)
new_state = apply_instance_deleted(
InstanceDeleted(instance_id=INSTANCE_1_ID), state
)
assert INSTANCE_1_ID not in new_state.instances
assert RUNNER_1_ID not in new_state.runners
assert RUNNER_2_ID not in new_state.runners
assert len(new_state.runners) == 0
def test_instance_deleted_unknown_id_is_noop_for_runners():
"""Deleting a non-existent instance should not affect runners."""
shard = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0)
instance = _make_instance(
INSTANCE_1_ID,
MODEL_A_ID,
{NODE_A: RUNNER_1_ID},
{RUNNER_1_ID: shard},
)
unknown_id = InstanceId("99999999-9999-4999-8999-999999999999")
state = State(
instances={INSTANCE_1_ID: instance},
runners={RUNNER_1_ID: RunnerReady()},
)
new_state = apply_instance_deleted(InstanceDeleted(instance_id=unknown_id), state)
# Everything should remain untouched
assert INSTANCE_1_ID in new_state.instances
assert RUNNER_1_ID in new_state.runners


@@ -50,7 +50,9 @@ class RunnerReady(BaseRunnerStatus):
class RunnerRunning(BaseRunnerStatus):
pass
"""Runner is processing requests and can accept more (continuous batching)."""
active_requests: int = 0
class RunnerShuttingDown(BaseRunnerStatus):


@@ -0,0 +1,317 @@
"""Batch generation engine using mlx_lm's BatchGenerator for continuous batching."""
import time
from dataclasses import dataclass, field
from typing import get_args
import mlx.core as mx
from mlx_lm.generate import BatchGenerator
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper
from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import TaskId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import MAX_TOKENS
from exo.worker.engines.mlx.generator.distributed_sync import share_object
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.runner.bootstrap import logger
@dataclass
class PendingInsert:
"""Pre-tokenized request ready for batch insertion."""
command_id: CommandId
task_id: TaskId
tokens: list[int]
max_tokens: int
prompt_tokens: int
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
@dataclass
class ActiveRequest:
"""Tracks an active request in the batch."""
command_id: CommandId
task_id: TaskId
uid: int # BatchGenerator's internal ID
detokenizer: StreamingDetokenizer
tokens_generated: int = 0
prompt_tokens: int = 0
start_time: float = field(default_factory=time.perf_counter)
@dataclass
class BatchedGenerationResponse:
"""Response from batch engine, tagged with command_id and task_id."""
command_id: CommandId
task_id: TaskId
response: GenerationResponse
class BatchGenerationEngine:
"""Manages continuous batching using mlx_lm's BatchGenerator."""
def __init__(
self,
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None = None,
max_tokens: int = MAX_TOKENS,
completion_batch_size: int = 32,
prefill_batch_size: int = 8,
prefill_step_size: int = 2048,
):
self.model = model
self.tokenizer = tokenizer
self.max_tokens = max_tokens
self.active_requests: dict[int, ActiveRequest] = {}
self._pending_inserts: list[PendingInsert] = []
self._pending_completions: list[
int
] = [] # UIDs completed but not yet synced/removed
self.group = group
self.rank = group.rank() if group else 0
self.is_distributed = group is not None and group.size() > 1
sampler = make_sampler(temp=0.7, top_p=1.0)
eos_tokens: set[int] = set(tokenizer.eos_token_ids or [])
self.batch_gen: BatchGenerator = BatchGenerator(
model=model,
max_tokens=max_tokens,
stop_tokens=eos_tokens,
sampler=sampler,
completion_batch_size=completion_batch_size,
prefill_batch_size=prefill_batch_size,
prefill_step_size=prefill_step_size,
)
logger.info(
f"BatchGenerationEngine initialized with completion_batch_size={completion_batch_size}, "
f"prefill_batch_size={prefill_batch_size}, distributed={self.is_distributed}"
)
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: TextGenerationTaskParams,
) -> str:
"""Queue a pre-tokenized request for insertion. Only rank 0 should call this.
Tokenization happens here (eagerly) so that sync_and_insert_pending()
only does the lightweight batch_gen.insert() call, keeping the decode
thread unblocked for as long as possible.
Returns the prompt string for caller use (e.g. thinking-mode detection).
"""
assert self.rank == 0, "Only rank 0 should queue requests"
prompt_str = apply_chat_template(self.tokenizer, task_params)
tokens: list[int] = self.tokenizer.encode(prompt_str, add_special_tokens=False)
max_tokens = task_params.max_output_tokens or self.max_tokens
self._pending_inserts.append(
PendingInsert(
command_id=command_id,
task_id=task_id,
tokens=tokens,
max_tokens=max_tokens,
prompt_tokens=len(tokens),
temperature=task_params.temperature,
top_p=task_params.top_p,
top_k=task_params.top_k,
)
)
logger.info(
f"Queued request {command_id} for insertion (pending={len(self._pending_inserts)}, prompt_tokens={len(tokens)})"
)
return prompt_str
def sync_and_insert_pending(self) -> list[int]:
"""Sync pre-tokenized pending inserts across ranks and insert them. Returns UIDs.
Tokens are already prepared by queue_request(), so this method only does
the lightweight batch_gen.insert() call plus distributed sync if needed.
"""
inserts_to_process: list[PendingInsert]
if not self.is_distributed:
# Non-distributed: just insert directly from pending
inserts_to_process = list(self._pending_inserts)
else:
# Distributed: broadcast pre-tokenized inserts from rank 0 to all ranks
assert self.group is not None
inserts_to_process = share_object(
self._pending_inserts if self.rank == 0 else None,
self.rank,
self.group,
)
if not inserts_to_process:
self._pending_inserts.clear()
return []
# Update sampler from per-request parameters (last request wins for batch)
last = inserts_to_process[-1]
self.batch_gen.sampler = make_sampler( # pyright: ignore[reportAttributeAccessIssue]
temp=last.temperature if last.temperature is not None else 0.7,
top_p=last.top_p if last.top_p is not None else 1.0,
top_k=last.top_k if last.top_k is not None else 0,
)
# Single batched insert for efficient prefill — tokens already prepared
all_tokens = [p.tokens for p in inserts_to_process]
all_max_tokens = [p.max_tokens for p in inserts_to_process]
uids = self.batch_gen.insert(all_tokens, max_tokens=all_max_tokens)
# Track all inserted requests
for i, uid in enumerate(uids):
p = inserts_to_process[i]
self.active_requests[uid] = ActiveRequest(
command_id=p.command_id,
task_id=p.task_id,
uid=uid,
detokenizer=self.tokenizer.detokenizer,
prompt_tokens=p.prompt_tokens,
)
logger.info(
f"Inserted request {p.command_id} with uid={uid}, prompt_tokens={p.prompt_tokens}, max_tokens={p.max_tokens}"
)
self._pending_inserts.clear()
return uids
def step(self) -> list[BatchedGenerationResponse]:
"""Run one decode step. Tracks completions but does not sync - call sync_completions() at budget boundaries."""
responses = self.batch_gen.next()
if not responses:
return []
results: list[BatchedGenerationResponse] = []
for r in responses:
uid: int = r.uid
req = self.active_requests.get(uid)
if req is None:
logger.warning(f"Received response for unknown uid={uid}")
continue
req.tokens_generated += 1
# Decode the token
token: int = r.token
req.detokenizer.add_token(token)
text: str = req.detokenizer.last_segment
stats: GenerationStats | None = None
finish_reason: FinishReason | None = None
raw_finish_reason: str | None = r.finish_reason
if raw_finish_reason is not None:
# Finalize to get remaining text
req.detokenizer.finalize()
text = req.detokenizer.last_segment
elapsed = time.perf_counter() - req.start_time
generation_tps = req.tokens_generated / elapsed if elapsed > 0 else 0.0
stats = GenerationStats(
prompt_tps=0.0, # Not tracked per-request in batch mode
generation_tps=generation_tps,
prompt_tokens=req.prompt_tokens,
generation_tokens=req.tokens_generated,
peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
)
if raw_finish_reason in get_args(FinishReason):
finish_reason = raw_finish_reason # pyright: ignore[reportAssignmentType]
else:
logger.warning(f"Unknown finish_reason: {raw_finish_reason}")
finish_reason = "stop"
# Track completion but don't remove yet - wait for sync_completions()
self._pending_completions.append(uid)
logger.info(
f"Request {req.command_id} completed: {req.tokens_generated} tokens, {generation_tps:.2f} tps, reason={finish_reason}"
)
results.append(
BatchedGenerationResponse(
command_id=req.command_id,
task_id=req.task_id,
response=GenerationResponse(
text=text,
token=token,
finish_reason=finish_reason,
stats=stats,
usage=None,
),
)
)
# In non-distributed mode, clean up completions immediately
if not self.is_distributed:
self._remove_completed()
return results
def sync_completions(self) -> None:
"""Sync and remove completed requests. Call at time budget boundaries in distributed mode."""
if not self.is_distributed:
# Non-distributed: early return if nothing to do
if not self._pending_completions:
return
self._remove_completed()
return
# Distributed mode: ALWAYS sync to ensure all ranks participate in collective op
# This prevents deadlock if one rank has completions and another doesn't
assert self.group is not None
self._pending_completions = share_object(
self._pending_completions if self.rank == 0 else None,
self.rank,
self.group,
)
self._remove_completed()
def _remove_completed(self) -> None:
"""Remove completed requests from tracking."""
for uid in self._pending_completions:
if uid in self.active_requests:
del self.active_requests[uid]
self._pending_completions.clear()
@property
def has_active_requests(self) -> bool:
return bool(self.active_requests or self.batch_gen.unprocessed_prompts)
@property
def has_pending_inserts(self) -> bool:
return bool(self._pending_inserts)
@property
def active_count(self) -> int:
return len(self.active_requests)
@property
def pending_count(self) -> int:
return len(self.batch_gen.unprocessed_prompts)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
@property
def has_pending_completions(self) -> bool:
return bool(self._pending_completions)
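Because the runner that drives this class has its diff suppressed, a minimal single-rank usage sketch may help. It uses only methods defined above and assumes an already constructed engine and request parameters; model and tokenizer loading are out of scope.

```python
from exo.shared.types.common import CommandId
from exo.shared.types.tasks import TaskId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.worker.engines.mlx.generator.batch_engine import BatchGenerationEngine

def generate_blocking(
    engine: BatchGenerationEngine,
    command_id: CommandId,
    task_id: TaskId,
    params: TextGenerationTaskParams,
) -> str:
    """Drive one request to completion on a single rank (no distributed peers)."""
    engine.queue_request(command_id, task_id, params)
    engine.sync_and_insert_pending()          # one batched prefill
    pieces: list[str] = []
    while engine.has_active_requests:
        for tagged in engine.step():          # one decode step for the whole batch
            pieces.append(tagged.response.text)
        engine.sync_completions()             # drop finished requests
    return "".join(pieces)
```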


@@ -0,0 +1,34 @@
"""Distributed sync utilities using mx.distributed.all_sum() to broadcast from rank 0."""
# pyright: reportAny=false
import pickle
from typing import cast
import mlx.core as mx
def share_object[T](obj: T | None, rank: int, group: mx.distributed.Group) -> T:
"""Broadcast object from rank 0 to all ranks. Two-phase: size then data.
Rank 0 must always provide a non-None object. Non-rank-0 callers pass None
(they are receivers only). Use mx_barrier() instead if no data needs to be shared.
"""
if rank == 0:
assert obj is not None, (
"Rank 0 must provide data; use mx_barrier() to sync without data"
)
data = mx.array(list(pickle.dumps(obj)), dtype=mx.uint8)
mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))
mx.eval(mx.distributed.all_sum(data, group=group))
return obj
else:
size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
if size == 0:
raise RuntimeError(
"share_object received size=0 from rank 0 — protocol violation"
)
data = mx.zeros(size, dtype=mx.uint8)
data = mx.distributed.all_sum(data, group=group)
mx.eval(data)
return cast(T, pickle.loads(bytes(cast(list[int], data.tolist()))))
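Two properties of this helper are worth spelling out: since non-zero ranks contribute only zeros, `all_sum` acts as a broadcast from rank 0, and the size-then-data exchange lets receivers allocate a buffer of the right length before the payload arrives. Every rank must reach the call at the same point; below is a purely illustrative wrapper showing the calling convention (the call sites in `batch_engine.py` above use `share_object` directly).

```python
import mlx.core as mx

from exo.worker.engines.mlx.generator.distributed_sync import share_object

def broadcast_from_rank0[T](payload: T | None, group: mx.distributed.Group) -> T:
    """Rank 0 supplies the object; every other rank passes None and receives a copy.
    Purely illustrative - not part of the branch itself."""
    rank = group.rank()
    return share_object(payload if rank == 0 else None, rank, group)
```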


@@ -0,0 +1,104 @@
"""Time budget iterator for controlling generation loop timing in distributed mode.
Based on mlx-lm's TimeBudget pattern - runs for a time budget then syncs,
rather than syncing every token. This reduces distributed sync overhead.
"""
import time
from typing import Iterator
import mlx.core as mx
from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
class TimeBudget(Iterator[None]):
"""Controls generation loop timing, syncing across ranks periodically.
In distributed mode, periodically syncs timing across all ranks to
dynamically adjust iteration count based on actual performance.
In non-distributed mode, simply runs for the time budget.
Usage:
for _ in TimeBudget(budget=0.5):
batch_engine.step()
# ... process responses ...
"""
def __init__(
self,
budget: float = 0.5,
iterations: int = 25,
sync_frequency: int = 10,
group: mx.distributed.Group | None = None,
):
"""Initialize TimeBudget.
Args:
budget: Time budget in seconds before yielding control
iterations: Initial number of iterations per budget period (distributed only)
sync_frequency: How often to sync timing across ranks (distributed only)
group: Distributed group, or None for non-distributed mode
"""
self._budget = budget
self._iterations = iterations
self._sync_frequency = sync_frequency
self._group = group
self._is_distributed = group is not None and group.size() > 1
# Runtime state
self._start: float = 0.0
self._current_iterations: int = 0
self._loops: int = 0
self._time_spent: float = 0.0
def __iter__(self) -> "TimeBudget":
self._start = time.perf_counter()
self._current_iterations = 0
return self
def __next__(self) -> None:
if not self._is_distributed:
# Non-distributed: just check time budget
if time.perf_counter() - self._start > self._budget:
raise StopIteration()
return None
# Distributed mode: iteration-based with periodic timing sync
self._current_iterations += 1
if self._current_iterations > self._iterations:
self._loops += 1
self._time_spent += time.perf_counter() - self._start
if self._loops % self._sync_frequency == 0:
# Sync timing across all ranks
assert self._group is not None
with mx.stream(generation_stream):
time_array = mx.array([self._time_spent], dtype=mx.float32)
total_time = mx.distributed.all_sum(time_array, group=self._group)
mx.eval(total_time)
loop_time = float(total_time.item())
avg_loop_time = loop_time / (self._group.size() * self._sync_frequency)
if avg_loop_time > 0:
factor = self._budget / avg_loop_time
self._iterations = max(round(self._iterations * factor), 1)
logger.debug(
f"TimeBudget adjusted iterations to {self._iterations}"
)
self._loops = 0
self._time_spent = 0.0
raise StopIteration()
return None
@property
def iterations(self) -> int:
"""Current iterations per budget period."""
return self._iterations


@@ -172,109 +172,120 @@ class Worker:
async def plan_step(self):
while True:
await anyio.sleep(0.1)
task: Task | None = plan(
self.node_id,
self.runners,
self.state.downloads,
self.state.instances,
self.state.runners,
self.state.tasks,
self.input_chunk_buffer,
self.input_chunk_counts,
)
if task is None:
continue
# Drain all available tasks before sleeping again.
# This ensures concurrent request arrivals are dispatched
# rapidly rather than one-per-100ms.
while True:
task: Task | None = plan(
self.node_id,
self.runners,
self.state.downloads,
self.state.instances,
self.state.runners,
self.state.tasks,
self.input_chunk_buffer,
self.input_chunk_counts,
)
if task is None:
break
# Gate DownloadModel on backoff BEFORE emitting TaskCreated
# to prevent flooding the event log with useless events
if isinstance(task, DownloadModel):
model_id = task.shard_metadata.model_card.model_id
if not self._download_backoff.should_proceed(model_id):
continue
# Gate DownloadModel on backoff BEFORE emitting TaskCreated
# to prevent flooding the event log with useless events
if isinstance(task, DownloadModel):
model_id = task.shard_metadata.model_card.model_id
if not self._download_backoff.should_proceed(model_id):
break
logger.info(f"Worker plan: {task.__class__.__name__}")
assert task.task_status
await self.event_sender.send(TaskCreated(task_id=task.task_id, task=task))
logger.info(f"Worker plan: {task.__class__.__name__}")
assert task.task_status
await self.event_sender.send(
TaskCreated(task_id=task.task_id, task=task)
)
# lets not kill the worker if a runner is unresponsive
match task:
case CreateRunner():
self._create_supervisor(task)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
case DownloadModel(shard_metadata=shard):
model_id = shard.model_card.model_id
self._download_backoff.record_attempt(model_id)
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
),
)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
case Shutdown(runner_id=runner_id):
try:
with fail_after(3):
await self.runners.pop(runner_id).start_task(task)
except TimeoutError:
# lets not kill the worker if a runner is unresponsive
match task:
case CreateRunner():
self._create_supervisor(task)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut
task_id=task.task_id,
task_status=TaskStatus.Complete,
)
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
chunks = self.input_chunk_buffer.get(cmd_id, {})
assembled = "".join(chunks[i] for i in range(len(chunks)))
logger.info(
f"Assembled input image from {len(chunks)} chunks, "
f"total size: {len(assembled)} bytes"
)
# Create modified task with assembled image data
modified_task = ImageEdits(
task_id=task.task_id,
command_id=task.command_id,
instance_id=task.instance_id,
task_status=task.task_status,
task_params=ImageEditsTaskParams(
image_data=assembled,
total_input_chunks=task.task_params.total_input_chunks,
prompt=task.task_params.prompt,
model=task.task_params.model,
n=task.task_params.n,
quality=task.task_params.quality,
output_format=task.task_params.output_format,
response_format=task.task_params.response_format,
size=task.task_params.size,
image_strength=task.task_params.image_strength,
bench=task.task_params.bench,
stream=task.task_params.stream,
partial_images=task.task_params.partial_images,
advanced_params=task.task_params.advanced_params,
),
)
# Cleanup buffers
if cmd_id in self.input_chunk_buffer:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
case DownloadModel(shard_metadata=shard):
model_id = shard.model_card.model_id
self._download_backoff.record_attempt(model_id)
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
),
)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.Running,
)
)
case Shutdown(runner_id=runner_id):
try:
with fail_after(3):
await self.runners.pop(runner_id).start_task(task)
except TimeoutError:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.TimedOut,
)
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
chunks = self.input_chunk_buffer.get(cmd_id, {})
assembled = "".join(chunks[i] for i in range(len(chunks)))
logger.info(
f"Assembled input image from {len(chunks)} chunks, "
f"total size: {len(assembled)} bytes"
)
# Create modified task with assembled image data
modified_task = ImageEdits(
task_id=task.task_id,
command_id=task.command_id,
instance_id=task.instance_id,
task_status=task.task_status,
task_params=ImageEditsTaskParams(
image_data=assembled,
total_input_chunks=task.task_params.total_input_chunks,
prompt=task.task_params.prompt,
model=task.task_params.model,
n=task.task_params.n,
quality=task.task_params.quality,
output_format=task.task_params.output_format,
response_format=task.task_params.response_format,
size=task.task_params.size,
image_strength=task.task_params.image_strength,
bench=task.task_params.bench,
stream=task.task_params.stream,
partial_images=task.task_params.partial_images,
advanced_params=task.task_params.advanced_params,
),
)
# Cleanup buffers
if cmd_id in self.input_chunk_buffer:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(
task
)
def shutdown(self):
self._tg.cancel_scope.cancel()


@@ -295,12 +295,14 @@ def _pending_tasks(
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
# Also skip tasks in pending to prevent duplicate forwarding with continuous batching
if task.task_id in runner.completed or task.task_id in runner.pending:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(
# Allow forwarding tasks when runner is Ready or Running (for continuous batching)
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):

File diff suppressed because it is too large.

@@ -148,7 +148,11 @@ class RunnerSupervisor:
if isinstance(event, RunnerStatusUpdated):
self.status = event.runner_status
if isinstance(event, TaskAcknowledged):
self.pending.pop(event.task_id).set()
# Signal start_task() to return, but keep the entry
# in self.pending so _pending_tasks won't re-dispatch.
pending_event = self.pending.get(event.task_id)
if pending_event is not None:
pending_event.set()
continue
if (
isinstance(event, TaskStatusUpdated)
@@ -166,6 +170,8 @@ class RunnerSupervisor:
),
)
self.completed.add(event.task_id)
# Clean up from pending now that it's fully complete
self.pending.pop(event.task_id, None)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)


@@ -20,6 +20,7 @@ class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
pending: dict[TaskId, object] = field(default_factory=dict)
class OtherTask(BaseTask):


@@ -0,0 +1,388 @@
"""
Tests for continuous batching behavior in the runner.
These tests verify that:
1. Single requests work through the batch path
2. Multiple concurrent requests batch together
3. Tokens are routed to the correct requests
4. Requests complete at different times appropriately
NOTE: These tests require the continuous-batching runner architecture
(BatchGenerationEngine) which is not yet integrated with main.
"""
# ruff: noqa: E402
# pyright: reportAny=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportAttributeAccessIssue=false
# pyright: reportInvalidTypeVarUse=false
from typing import Any
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import (
ChunkGenerated,
Event,
RunnerStatusUpdated,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
TextGeneration,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import RunnerRunning
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
BatchedGenerationResponse,
)
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
)
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
class FakeBatchEngineWithTokens:
"""
Fake batch engine that generates a specified number of tokens per request.
This simulates realistic batch generation behavior where:
- Requests are queued on insert
- Each step() call generates one token for all active requests
- Requests complete when they've generated all their tokens
"""
def __init__(self, *_args: Any, **_kwargs: Any):
self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
self._pending_inserts: list[
tuple[CommandId, TaskId, TextGenerationTaskParams]
] = []
self._uid_counter = 0
self._tokens_per_request = 3 # Default: generate 3 tokens before completing
self.rank = 0 # Fake rank for testing
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: TextGenerationTaskParams,
) -> str:
"""Queue a request for insertion."""
self._pending_inserts.append((command_id, task_id, task_params))
return ""
def sync_and_insert_pending(self) -> list[int]:
"""Insert all pending requests."""
uids: list[int] = []
for command_id, task_id, task_params in self._pending_inserts:
uid = self._do_insert(command_id, task_id, task_params)
uids.append(uid)
self._pending_inserts.clear()
return uids
@property
def has_pending_inserts(self) -> bool:
return len(self._pending_inserts) > 0
def _do_insert(
self,
command_id: CommandId,
task_id: TaskId,
task_params: TextGenerationTaskParams | None,
) -> int:
uid = self._uid_counter
self._uid_counter += 1
# Track: (command_id, task_id, tokens_generated, max_tokens)
max_tokens = (
task_params.max_output_tokens if task_params else self._tokens_per_request
)
self._active_requests[uid] = (command_id, task_id, 0, max_tokens or 3)
return uid
def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
uids_to_remove: list[int] = []
for uid, (command_id, task_id, tokens_gen, max_tokens) in list(
self._active_requests.items()
):
tokens_gen += 1
finish_reason = "stop" if tokens_gen >= max_tokens else None
text = f"token{tokens_gen}"
if finish_reason:
uids_to_remove.append(uid)
else:
self._active_requests[uid] = (
command_id,
task_id,
tokens_gen,
max_tokens,
)
results.append(
BatchedGenerationResponse(
command_id=command_id,
task_id=task_id,
response=GenerationResponse(
token=tokens_gen,
text=text,
finish_reason=finish_reason,
usage=None,
),
)
)
for uid in uids_to_remove:
del self._active_requests[uid]
return results
@property
def has_active_requests(self) -> bool:
return len(self._active_requests) > 0
@property
def active_count(self) -> int:
return len(self._active_requests)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
def sync_completions(self) -> None:
pass # Completions already removed in step()
@property
def is_distributed(self) -> bool:
return False # Non-distributed mode for testing
class MockTokenizer:
"""Mock tokenizer with tool calling disabled."""
tool_parser = None
tool_call_start = None
tool_call_end = None
has_tool_calling = False
has_thinking = False
class FakeGroup:
"""Fake MLX distributed group for testing."""
def rank(self) -> int:
return 0
def size(self) -> int:
return 1 # Single node (non-distributed)
def make_nothin[T, U, V](res: T):
def nothin(*_1: U, **_2: V) -> T:
return res
return nothin
@pytest.fixture
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
"""Patch MLX dependencies and use FakeBatchEngineWithTokens."""
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup()))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)
class EventCollector:
"""Collects events directly into a list to avoid mp_channel flakiness."""
def __init__(self) -> None:
self.events: list[Event] = []
def send(self, event: Event) -> None:
self.events.append(event)
def close(self) -> None:
pass
def join(self) -> None:
pass
def _run_with_tasks(tasks: list[Task]) -> list[Event]:
"""
Run tasks through the runner, adding shutdown at the end.
Tasks are sent in order, with shutdown sent last.
The batch engine processes between task handling.
"""
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NodeId(NODE_A),
)
task_sender, task_receiver = mp_channel[Task]()
event_collector = EventCollector()
shutdown_task = Shutdown(
task_id=TaskId("shutdown"),
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
)
with task_sender:
# Send all tasks including shutdown
for t in tasks:
task_sender.send(t)
task_sender.send(shutdown_task)
# Disable cleanup methods to prevent issues
task_receiver.close = lambda: None
task_receiver.join = lambda: None
mlx_runner.main(bound_instance, event_collector, task_receiver) # type: ignore[arg-type]
return event_collector.events
INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)
def make_chat_task(
task_id: str, command_id: str, max_tokens: int = 3
) -> TextGeneration:
return TextGeneration(
task_id=TaskId(task_id),
command_id=CommandId(command_id),
task_params=TextGenerationTaskParams(
model=MODEL_A_ID,
input=[InputMessage(role="user", content="hello")],
stream=True,
max_output_tokens=max_tokens,
),
instance_id=INSTANCE_1_ID,
)
def test_single_request_generates_tokens(patch_batch_engine: None):
"""
Verify a single request generates the expected tokens through the batch path.
Tokens are generated during the generation loop (not during shutdown drain).
The task completes after all tokens are generated.
"""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=3)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Verify ChunkGenerated events are emitted for all tokens
chunk_events = [
e
for e in events
if isinstance(e, ChunkGenerated) and e.command_id == CommandId("cmd1")
]
assert len(chunk_events) == 3, (
f"Expected 3 ChunkGenerated events, got {len(chunk_events)}"
)
# Last chunk should have finish_reason="stop"
last_chunk = chunk_events[-1].chunk
assert isinstance(last_chunk, TokenChunk)
assert last_chunk.finish_reason == "stop"
# Task should be marked complete after tokens are generated
chat_complete = [
e
for e in events
if isinstance(e, TaskStatusUpdated)
and e.task_id == TaskId("chat1")
and e.task_status == TaskStatus.Complete
]
assert len(chat_complete) == 1, "Expected exactly one chat task Complete status"
def test_runner_status_reflects_active_requests(patch_batch_engine: None):
"""Verify RunnerRunning status includes active_requests count."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find RunnerRunning status events
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]
assert len(running_events) > 0, "Expected at least one RunnerRunning event"
assert running_events[0].runner_status.active_requests == 1
def test_chat_task_acknowledged(patch_batch_engine: None):
"""Verify chat completion task is acknowledged with proper status updates."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find the chat task status events
chat_running = [
e
for e in events
if isinstance(e, TaskStatusUpdated)
and e.task_id == TaskId("chat1")
and e.task_status == TaskStatus.Running
]
assert len(chat_running) == 1, "Expected exactly one chat task Running status"
def test_multiple_requests_generate_tokens(patch_batch_engine: None):
"""Verify multiple requests each generate their expected tokens."""
chat1 = make_chat_task("chat1", "cmd1", max_tokens=2)
chat2 = make_chat_task("chat2", "cmd2", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat1, chat2])
# Both requests should generate their expected number of tokens
cmd1_chunks = [
e
for e in events
if isinstance(e, ChunkGenerated) and e.command_id == CommandId("cmd1")
]
cmd2_chunks = [
e
for e in events
if isinstance(e, ChunkGenerated) and e.command_id == CommandId("cmd2")
]
assert len(cmd1_chunks) == 2, f"Expected 2 chunks for cmd1, got {len(cmd1_chunks)}"
assert len(cmd2_chunks) == 2, f"Expected 2 chunks for cmd2, got {len(cmd2_chunks)}"
# Both tasks should be completed
completed_task_ids = {
e.task_id
for e in events
if isinstance(e, TaskStatusUpdated)
and e.task_status == TaskStatus.Complete
and e.task_id in (TaskId("chat1"), TaskId("chat2"))
}
assert TaskId("chat1") in completed_task_ids
assert TaskId("chat2") in completed_task_ids


@@ -0,0 +1,719 @@
"""
Edge-case tests for continuous batching in the runner.
Tests cover:
1. Concurrent requests with overlapping tool calls
2. Requests that finish mid-generation with 'length' reason
3. Multiple requests finishing on the same step() call
4. Batch of 5+ simultaneous completions
"""
# ruff: noqa: E402
# pyright: reportAny=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportAttributeAccessIssue=false
# pyright: reportInvalidTypeVarUse=false
# pyright: reportPrivateUsage=false
import json
from typing import Any
from unittest.mock import MagicMock
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import FinishReason
from exo.shared.types.chunks import TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import (
ChunkGenerated,
Event,
RunnerStatusUpdated,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
TextGeneration,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import RunnerReady
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
BatchedGenerationResponse,
)
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
)
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
# ---------------------------------------------------------------------------
# Fake batch engines
# ---------------------------------------------------------------------------
class ScriptedBatchEngine:
"""Batch engine driven by scripted per-request token sequences.
Each request produces a predefined list of (text, finish_reason) pairs.
One step() call pops one token per active request.
"""
def __init__(self, *_args: Any, **_kwargs: Any):
self._active: dict[
int, tuple[CommandId, TaskId, list[tuple[str, FinishReason | None]]]
] = {}
self._pending: list[tuple[CommandId, TaskId, TextGenerationTaskParams]] = []
self._uid = 0
self.rank = 0
# map command_id -> scripted tokens, set externally before tasks arrive
self.scripts: dict[str, list[tuple[str, FinishReason | None]]] = {}
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: TextGenerationTaskParams,
) -> str:
self._pending.append((command_id, task_id, task_params))
return ""
def sync_and_insert_pending(self) -> list[int]:
uids: list[int] = []
for cmd_id, task_id, _params in self._pending:
uid = self._uid
self._uid += 1
script = list(self.scripts.get(str(cmd_id), [("tok", "stop")]))
self._active[uid] = (cmd_id, task_id, script)
uids.append(uid)
self._pending.clear()
return uids
@property
def has_pending_inserts(self) -> bool:
return bool(self._pending)
@property
def pending_insert_count(self) -> int:
return len(self._pending)
def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
done: list[int] = []
for uid, (cmd_id, task_id, script) in self._active.items():
if not script:
continue
text, finish_reason = script.pop(0)
results.append(
BatchedGenerationResponse(
command_id=cmd_id,
task_id=task_id,
response=GenerationResponse(
token=0, text=text, finish_reason=finish_reason, usage=None
),
)
)
if finish_reason is not None:
done.append(uid)
for uid in done:
del self._active[uid]
return results
@property
def has_active_requests(self) -> bool:
return bool(self._active)
@property
def active_count(self) -> int:
return len(self._active)
def sync_completions(self) -> None:
pass
@property
def is_distributed(self) -> bool:
return False
class FakeBatchEngineWithTokens:
"""Generates N tokens per request (reused from the main test file)."""
def __init__(self, *_args: Any, **_kwargs: Any):
self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
self._pending_inserts: list[
tuple[CommandId, TaskId, TextGenerationTaskParams]
] = []
self._uid_counter = 0
self.rank = 0
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: TextGenerationTaskParams,
) -> str:
self._pending_inserts.append((command_id, task_id, task_params))
return ""
def sync_and_insert_pending(self) -> list[int]:
uids: list[int] = []
for command_id, task_id, task_params in self._pending_inserts:
uid = self._uid_counter
self._uid_counter += 1
max_tokens = task_params.max_output_tokens or 3
self._active_requests[uid] = (command_id, task_id, 0, max_tokens)
uids.append(uid)
self._pending_inserts.clear()
return uids
@property
def has_pending_inserts(self) -> bool:
return bool(self._pending_inserts)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
done: list[int] = []
for uid, (cmd_id, task_id, tokens_gen, max_tokens) in list(
self._active_requests.items()
):
tokens_gen += 1
finish = "stop" if tokens_gen >= max_tokens else None
results.append(
BatchedGenerationResponse(
command_id=cmd_id,
task_id=task_id,
response=GenerationResponse(
token=tokens_gen,
text=f"token{tokens_gen}",
finish_reason=finish,
usage=None,
),
)
)
if finish:
done.append(uid)
else:
self._active_requests[uid] = (cmd_id, task_id, tokens_gen, max_tokens)
for uid in done:
del self._active_requests[uid]
return results
@property
def has_active_requests(self) -> bool:
return bool(self._active_requests)
@property
def active_count(self) -> int:
return len(self._active_requests)
def sync_completions(self) -> None:
pass
@property
def is_distributed(self) -> bool:
return False
# ---------------------------------------------------------------------------
# Mock tokenizers
# ---------------------------------------------------------------------------
class MockTokenizer:
tool_parser = None
tool_call_start = None
tool_call_end = None
has_tool_calling = False
has_thinking = False
class MockToolTokenizer:
"""Tokenizer with tool calling enabled for testing."""
has_tool_calling = True
has_thinking = False
tool_call_start = "<tool>"
tool_call_end = "</tool>"
@staticmethod
def _tool_parser(text: str) -> dict[str, Any]:
return json.loads(text)
class FakeGroup:
def rank(self) -> int:
return 0
def size(self) -> int:
return 1
# ---------------------------------------------------------------------------
# Event collector & runner helper
# ---------------------------------------------------------------------------
class EventCollector:
def __init__(self) -> None:
self.events: list[Event] = []
def send(self, event: Event) -> None:
self.events.append(event)
def close(self) -> None:
pass
def join(self) -> None:
pass
def make_nothin[T, U, V](res: T):
def nothin(*_1: U, **_2: V) -> T:
return res
return nothin
INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)
SETUP_TASKS: list[Task] = [INIT_TASK, LOAD_TASK, WARMUP_TASK]
def make_chat_task(
task_id: str, command_id: str, max_tokens: int = 3
) -> TextGeneration:
return TextGeneration(
task_id=TaskId(task_id),
command_id=CommandId(command_id),
task_params=TextGenerationTaskParams(
model=MODEL_A_ID,
input=[InputMessage(role="user", content="hello")],
stream=True,
max_output_tokens=max_tokens,
),
instance_id=INSTANCE_1_ID,
)
def _run_with_tasks(
tasks: list[Task],
engine_cls: type = FakeBatchEngineWithTokens,
tokenizer_cls: type = MockTokenizer,
engine_instance: Any | None = None,
) -> list[Event]:
"""Run tasks through the runner with configurable engine and tokenizer."""
bound = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NodeId(NODE_A),
)
task_sender, task_receiver = mp_channel[Task]()
collector = EventCollector()
shutdown = Shutdown(
task_id=TaskId("shutdown"),
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
)
import exo.worker.runner.runner as r
orig_init_mlx = r.initialize_mlx
orig_load = r.load_mlx_items
orig_warmup = r.warmup_inference
orig_check = r._check_for_debug_prompts
orig_engine = r.BatchGenerationEngine
r.initialize_mlx = make_nothin(FakeGroup())
r.load_mlx_items = make_nothin((MagicMock(), tokenizer_cls))
r.warmup_inference = make_nothin(1)
r._check_for_debug_prompts = make_nothin(None)
if engine_instance is not None:
r.BatchGenerationEngine = lambda *_a, **_kw: engine_instance # pyright: ignore[reportUnknownLambdaType]
else:
r.BatchGenerationEngine = engine_cls
try:
with task_sender:
for t in tasks:
task_sender.send(t)
task_sender.send(shutdown)
task_receiver.close = lambda: None
task_receiver.join = lambda: None
r.main(bound, collector, task_receiver) # pyright: ignore[reportArgumentType]
finally:
r.initialize_mlx = orig_init_mlx
r.load_mlx_items = orig_load
r.warmup_inference = orig_warmup
r._check_for_debug_prompts = orig_check
r.BatchGenerationEngine = orig_engine
return collector.events
# ---------------------------------------------------------------------------
# Helpers for querying events
# ---------------------------------------------------------------------------
def chunks_for(events: list[Event], command_id: str) -> list[ChunkGenerated]:
return [
e
for e in events
if isinstance(e, ChunkGenerated) and e.command_id == CommandId(command_id)
]
def completed_task_ids(events: list[Event]) -> set[TaskId]:
return {
e.task_id
for e in events
if isinstance(e, TaskStatusUpdated) and e.task_status == TaskStatus.Complete
}
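# chunks_for preserves event order, so chunks_for(events, cmd)[-1] is the final
# chunk emitted for that command; completed_task_ids collapses status updates
# into a set of completed task ids.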
# ===========================================================================
# Test 1: Concurrent requests with overlapping tool calls
# ===========================================================================
def test_concurrent_tool_calls_and_normal_text():
"""Two concurrent requests: one emits normal text, the other a tool call.
Verifies that:
- The normal request produces TokenChunks with its text
- The tool-call request produces a ToolCallChunk
- Both tasks complete
"""
engine = ScriptedBatchEngine()
# cmd_normal: 2 normal tokens then stop
engine.scripts["cmd_normal"] = [
("hello", None),
(" world", "stop"),
]
    # cmd_tool: <tool> marker, JSON body, </tool> marker (markers suppressed from output), then a trailing token that finishes with "stop"
engine.scripts["cmd_tool"] = [
("<tool>", None), # swallowed by tracker
('{"name":"get_weather","arguments":{"city":"SF"}}', None), # accumulated
("</tool>", None), # triggers ToolCallChunk emission
("done", "stop"), # normal trailing token
]
chat_normal = make_chat_task("t_normal", "cmd_normal", max_tokens=100)
chat_tool = make_chat_task("t_tool", "cmd_tool", max_tokens=100)
events = _run_with_tasks(
[*SETUP_TASKS, chat_normal, chat_tool],
tokenizer_cls=MockToolTokenizer,
engine_instance=engine,
)
# Normal request: all chunks should be TokenChunk
normal_chunks = chunks_for(events, "cmd_normal")
assert len(normal_chunks) == 2
assert all(isinstance(c.chunk, TokenChunk) for c in normal_chunks)
assert normal_chunks[-1].chunk.finish_reason == "stop"
# Tool-call request
tool_chunks = chunks_for(events, "cmd_tool")
# <tool> → swallowed, body → accumulated, </tool> → ToolCallChunk, "done" → TokenChunk
tool_call_events = [c for c in tool_chunks if isinstance(c.chunk, ToolCallChunk)]
token_events = [c for c in tool_chunks if isinstance(c.chunk, TokenChunk)]
assert len(tool_call_events) == 1, (
f"Expected 1 ToolCallChunk, got {len(tool_call_events)}"
)
tc_chunk = tool_call_events[0].chunk
assert isinstance(tc_chunk, ToolCallChunk)
assert tc_chunk.tool_calls[0].name == "get_weather"
assert json.loads(tc_chunk.tool_calls[0].arguments) == {"city": "SF"}
assert len(token_events) == 1, "Expected 1 trailing TokenChunk after tool call"
assert token_events[0].chunk.finish_reason == "stop"
# Both tasks should complete
done = completed_task_ids(events)
assert TaskId("t_normal") in done
assert TaskId("t_tool") in done
def test_tool_call_interrupted_by_finish_reason():
"""Tool call in progress when finish_reason fires — partial text emitted."""
engine = ScriptedBatchEngine()
engine.scripts["cmd1"] = [
("<tool>", None),
('{"name":"f"', "stop"), # finish while inside tool call
]
chat = make_chat_task("t1", "cmd1", max_tokens=100)
events = _run_with_tasks(
[*SETUP_TASKS, chat],
tokenizer_cls=MockToolTokenizer,
engine_instance=engine,
)
chunks = chunks_for(events, "cmd1")
assert len(chunks) == 1
chunk = chunks[0].chunk
assert isinstance(chunk, TokenChunk)
# The interrupted tool call should be emitted as raw text
assert "<tool>" in chunk.text
assert '{"name":"f"' in chunk.text
assert chunk.finish_reason == "stop"
assert TaskId("t1") in completed_task_ids(events)
# ===========================================================================
# Test 2: Requests finishing with the 'length' reason (token limit reached mid-generation)
# ===========================================================================
def test_request_finishes_with_length_reason():
"""Request that hits max_tokens limit and finishes with 'length'."""
engine = ScriptedBatchEngine()
engine.scripts["cmd1"] = [
("tok1", None),
("tok2", None),
("tok3", "length"), # hit the token limit
]
chat = make_chat_task("t1", "cmd1", max_tokens=100)
events = _run_with_tasks(
[*SETUP_TASKS, chat],
engine_instance=engine,
)
chunks = chunks_for(events, "cmd1")
assert len(chunks) == 3
# Last chunk should have finish_reason="length"
assert isinstance(chunks[-1].chunk, TokenChunk)
assert chunks[-1].chunk.finish_reason == "length"
# Earlier chunks should have no finish_reason
for c in chunks[:-1]:
assert isinstance(c.chunk, TokenChunk)
assert c.chunk.finish_reason is None
assert TaskId("t1") in completed_task_ids(events)
def test_mixed_finish_reasons_across_requests():
"""Two requests finishing with different reasons: 'stop' and 'length'."""
engine = ScriptedBatchEngine()
engine.scripts["cmd_stop"] = [("a", None), ("b", "stop")]
engine.scripts["cmd_len"] = [("x", None), ("y", "length")]
chat1 = make_chat_task("t_stop", "cmd_stop", max_tokens=100)
chat2 = make_chat_task("t_len", "cmd_len", max_tokens=100)
events = _run_with_tasks(
[*SETUP_TASKS, chat1, chat2],
engine_instance=engine,
)
stop_chunks = chunks_for(events, "cmd_stop")
len_chunks = chunks_for(events, "cmd_len")
assert stop_chunks[-1].chunk.finish_reason == "stop"
assert len_chunks[-1].chunk.finish_reason == "length"
done = completed_task_ids(events)
assert TaskId("t_stop") in done
assert TaskId("t_len") in done
# ===========================================================================
# Test 3: Multiple finish reasons in rapid succession (same step)
# ===========================================================================
def test_all_requests_finish_on_same_step():
"""Three requests that all finish on the same step() call.
This tests that the runner and _process_generation_results correctly
handle multiple completions in a single step.
"""
engine = ScriptedBatchEngine()
# All three produce exactly 1 token and finish
engine.scripts["cmd_a"] = [("alpha", "stop")]
engine.scripts["cmd_b"] = [("beta", "stop")]
engine.scripts["cmd_c"] = [("gamma", "stop")]
tasks = [
*SETUP_TASKS,
make_chat_task("ta", "cmd_a", max_tokens=100),
make_chat_task("tb", "cmd_b", max_tokens=100),
make_chat_task("tc", "cmd_c", max_tokens=100),
]
events = _run_with_tasks([*tasks], engine_instance=engine)
for cmd_id, expected_text in [
("cmd_a", "alpha"),
("cmd_b", "beta"),
("cmd_c", "gamma"),
]:
c = chunks_for(events, cmd_id)
assert len(c) == 1, f"Expected 1 chunk for {cmd_id}, got {len(c)}"
assert isinstance(c[0].chunk, TokenChunk)
assert c[0].chunk.text == expected_text
assert c[0].chunk.finish_reason == "stop"
done = completed_task_ids(events)
assert TaskId("ta") in done
assert TaskId("tb") in done
assert TaskId("tc") in done
    # The runner should report RunnerReady at least once (after warmup).
# With inline task processing, later requests may be inserted into the
# batch before the generation loop exits, so the runner can stay
# RunnerRunning until Shutdown without an intermediate RunnerReady.
ready_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerReady)
]
assert len(ready_events) >= 1, "Expected RunnerReady at least after warmup"
def test_staggered_completions_in_batch():
"""Four requests with different token counts — they complete at different steps.
Verifies each request gets the right number of chunks and the runner
tracks active_requests correctly as requests drain.
"""
engine = ScriptedBatchEngine()
engine.scripts["c1"] = [("a", "stop")] # finishes step 1
engine.scripts["c2"] = [("a", None), ("b", "stop")] # finishes step 2
engine.scripts["c3"] = [("a", None), ("b", None), ("c", "stop")] # finishes step 3
engine.scripts["c4"] = [
("a", None),
("b", None),
("c", None),
("d", "stop"),
] # finishes step 4
tasks = [
*SETUP_TASKS,
make_chat_task("t1", "c1", max_tokens=100),
make_chat_task("t2", "c2", max_tokens=100),
make_chat_task("t3", "c3", max_tokens=100),
make_chat_task("t4", "c4", max_tokens=100),
]
events = _run_with_tasks([*tasks], engine_instance=engine)
assert len(chunks_for(events, "c1")) == 1
assert len(chunks_for(events, "c2")) == 2
assert len(chunks_for(events, "c3")) == 3
assert len(chunks_for(events, "c4")) == 4
done = completed_task_ids(events)
for tid in ["t1", "t2", "t3", "t4"]:
assert TaskId(tid) in done, f"Task {tid} should be complete"
# ===========================================================================
# Test 4: Batch of 5+ simultaneous completions
# ===========================================================================
@pytest.fixture
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
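    # Patch the same MLX-facing symbols that _run_with_tasks patches, but via
    # monkeypatch so they are undone automatically after each test.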
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup()))
monkeypatch.setattr(
mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MockTokenizer))
)
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)
def test_five_simultaneous_completions(patch_batch_engine: None):
"""Five requests submitted together, all generating tokens and completing."""
chats = [make_chat_task(f"t{i}", f"cmd{i}", max_tokens=2) for i in range(5)]
events = _run_with_tasks([*SETUP_TASKS, *chats])
for i in range(5):
c = chunks_for(events, f"cmd{i}")
assert len(c) == 2, f"Expected 2 chunks for cmd{i}, got {len(c)}"
assert c[-1].chunk.finish_reason == "stop"
done = completed_task_ids(events)
for i in range(5):
assert TaskId(f"t{i}") in done
def test_eight_requests_staggered(patch_batch_engine: None):
"""Eight requests with varying token counts, verifying all complete correctly."""
chats = [make_chat_task(f"t{i}", f"cmd{i}", max_tokens=i + 1) for i in range(8)]
events = _run_with_tasks([*SETUP_TASKS, *chats])
for i in range(8):
c = chunks_for(events, f"cmd{i}")
expected = i + 1
assert len(c) == expected, (
f"Expected {expected} chunks for cmd{i}, got {len(c)}"
)
assert c[-1].chunk.finish_reason == "stop"
done = completed_task_ids(events)
for i in range(8):
assert TaskId(f"t{i}") in done
    # Verify the runner transitions back to RunnerReady after all requests
    # complete: collect RunnerReady events and check that at least one occurs
    # before the Shutdown task starts running.
ready_events = [
(idx, e)
for idx, e in enumerate(events)
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerReady)
]
shutdown_idx = next(
idx
for idx, e in enumerate(events)
if isinstance(e, TaskStatusUpdated)
and e.task_id == TaskId("shutdown")
and e.task_status == TaskStatus.Running
)
# There should be a RunnerReady event between generation and shutdown
ready_before_shutdown = [idx for idx, _ in ready_events if idx < shutdown_idx]
assert len(ready_before_shutdown) >= 1, (
"Expected RunnerReady between generation completion and shutdown"
)
def test_ten_simultaneous_single_token():
"""Ten requests that each produce exactly one token — all finish on step 1."""
engine = ScriptedBatchEngine()
for i in range(10):
engine.scripts[f"cmd{i}"] = [(f"word{i}", "stop")]
chats = [make_chat_task(f"t{i}", f"cmd{i}", max_tokens=100) for i in range(10)]
events = _run_with_tasks([*SETUP_TASKS, *chats], engine_instance=engine)
for i in range(10):
c = chunks_for(events, f"cmd{i}")
assert len(c) == 1
assert isinstance(c[0].chunk, TokenChunk)
assert c[0].chunk.text == f"word{i}"
assert c[0].chunk.finish_reason == "stop"
done = completed_task_ids(events)
assert len(done & {TaskId(f"t{i}") for i in range(10)}) == 10


@@ -6,6 +6,7 @@ import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -19,6 +20,7 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
TextGeneration,
)
@@ -37,6 +39,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import BatchedGenerationResponse
from ...constants import (
CHAT_COMPLETION_TASK_ID,
@@ -113,15 +116,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngine)
# Use a fake event_sender to remove test flakiness.
@@ -144,6 +139,7 @@ class MockTokenizer:
tool_call_start = None
tool_call_end = None
has_tool_calling = False
has_thinking = False
class MockGroup:
@@ -154,6 +150,70 @@ class MockGroup:
return 1
class FakeBatchEngine:
"""Fake batch engine that generates a single 'hi' token per request."""
def __init__(self, *_args: object, **_kwargs: object):
self._active_requests: dict[int, tuple[CommandId, TaskId]] = {}
self._pending_inserts: list[tuple[CommandId, TaskId, object]] = []
self._uid_counter = 0
self.rank = 0
def queue_request(
self, command_id: CommandId, task_id: TaskId, task_params: object
) -> str:
self._pending_inserts.append((command_id, task_id, task_params))
return ""
def sync_and_insert_pending(self) -> list[int]:
uids: list[int] = []
for cmd_id, task_id, _params in self._pending_inserts:
uid = self._uid_counter
self._uid_counter += 1
self._active_requests[uid] = (cmd_id, task_id)
uids.append(uid)
self._pending_inserts.clear()
return uids
def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
for _uid, (cmd_id, task_id) in list(self._active_requests.items()):
results.append(
BatchedGenerationResponse(
command_id=cmd_id,
task_id=task_id,
response=GenerationResponse(
token=0, text="hi", finish_reason="stop", usage=None
),
)
)
self._active_requests.clear()
return results
def sync_completions(self) -> None:
pass
@property
def has_active_requests(self) -> bool:
return bool(self._active_requests)
@property
def has_pending_inserts(self) -> bool:
return bool(self._pending_inserts)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
@property
def active_count(self) -> int:
return len(self._active_requests)
@property
def is_distributed(self) -> bool:
return False
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
@@ -219,17 +279,22 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
TaskAcknowledged(task_id=WARMUP_TASK_ID),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
# CHAT TASK: queued, tokens generated, then completed
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID,
runner_status=RunnerRunning(active_requests=1),
),
# Generation loop produces token and completes the task
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
),
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
# SHUTDOWN
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
@@ -238,7 +303,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),
# SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
],
)