Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-16 00:52:56 -05:00)

Compare commits: 17 commits
Author: alexcheema

| SHA1 |
|---|
| 7469f44e58 |
| 5d26d2dcd6 |
| 35973b8698 |
| 41d9d2a61f |
| efbf9850eb |
| 4a22c4b512 |
| 4d0fe5d17b |
| 9fe7251796 |
| 1c8f69ce00 |
| f19166617a |
| 51e959c979 |
| cd43588a04 |
| 7b879593bb |
| e4e895d7a8 |
| db400dbb75 |
| 15fad9c632 |
| 842beefac0 |
@@ -276,24 +276,23 @@ class BatchGenerator:
    logprobs: mx.array
    finish_reason: Optional[str]

    unprocessed_prompts: List[Any]

    def __init__(
        self,
        model,
        model: nn.Module,
        max_tokens: int = ...,
        stop_tokens: Optional[set] = ...,
        stop_tokens: Optional[set[int]] = ...,
        sampler: Optional[Callable[[mx.array], mx.array]] = ...,
        completion_batch_size: int = ...,
        prefill_batch_size: int = ...,
        prefill_step_size: int = ...,
    ) -> None: ...
    def insert(
        self, prompts, max_tokens: Union[List[int], int, None] = ...
    ): # -> list[Any]:
        ...
    def stats(self): # -> BatchStats:
        ...
    def next(self): # -> list[Any]:
        ...
        self, prompts: List[List[int]], max_tokens: Union[List[int], int, None] = ...
    ) -> List[int]: ...
    def stats(self) -> BatchStats: ...
    def next(self) -> List[Response]: ...

def batch_generate(
    model,

@@ -39,11 +39,11 @@ class StreamingDetokenizer:
    """

    __slots__ = ...
    def reset(self): ...
    def add_token(self, token): ...
    def finalize(self): ...
    def reset(self) -> None: ...
    def add_token(self, token: int) -> None: ...
    def finalize(self) -> None: ...
    @property
    def last_segment(self):
    def last_segment(self) -> str:
        """Return the last segment of readable text since last time this property was accessed."""

class NaiveStreamingDetokenizer(StreamingDetokenizer):
AGENTS.md (39 lines changed)

@@ -116,10 +116,49 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable

## Model Storage

Downloaded models are stored in `~/.exo/models/` (not the standard HuggingFace cache location).

## Creating Model Instances via API

When testing with the API, you must first create a model instance before sending chat completions:

```bash
# 1. Get instance previews for a model
curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b"

# 2. Create an instance from the first valid preview
INSTANCE=$(curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | jq -c '.previews[] | select(.error == null) | .instance' | head -n1)
curl -X POST http://localhost:52415/instance -H 'Content-Type: application/json' -d "{\"instance\": $INSTANCE}"

# 3. Wait for the runner to become ready (check logs for "runner ready")

# 4. Send chat completions using the full model ID
curl -X POST http://localhost:52415/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "mlx-community/Llama-3.2-1B-Instruct-4bit", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}'
```

## Logs

Exo logs are stored in `~/.exo/exo.log`. This is useful for debugging runner crashes and distributed issues.

## Testing

Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
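With `asyncio_mode = "auto"`, async test functions are collected and run without an explicit `@pytest.mark.asyncio` marker. A minimal sketch (the file name and assertions are illustrative, not taken from the repo):

```python
# tests/test_async_example.py (hypothetical file)
import asyncio
import os


async def test_event_loop_runs_without_marker():
    # asyncio_mode = "auto" runs this coroutine in an event loop automatically.
    await asyncio.sleep(0)

    # The section above states EXO_TESTS=1 is set while tests run.
    assert os.environ.get("EXO_TESTS") == "1"
```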
### Distributed Testing

When running distributed tests across multiple machines, use `EXO_LIBP2P_NAMESPACE` to isolate your test cluster from other exo instances on the same network:

```bash
# On each machine in the test cluster, use the same unique namespace
EXO_LIBP2P_NAMESPACE=my-test-cluster uv run exo
```

This prevents your test cluster from discovering and interfering with production or other developers' exo clusters.

## Dashboard UI Testing & Screenshots

### Building and Running the Dashboard
conftest.py (new file, 1 line)

@@ -0,0 +1 @@
collect_ignore = ["tests/start_distributed_test.py"]
@@ -184,10 +184,19 @@ def apply_instance_created(event: InstanceCreated, state: State) -> State:


def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
    deleted_instance = state.instances.get(event.instance_id)
    new_instances: Mapping[InstanceId, Instance] = {
        iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
    }
    return state.model_copy(update={"instances": new_instances})
    runner_ids_to_remove: set[RunnerId] = set()
    if deleted_instance is not None:
        runner_ids_to_remove = set(
            deleted_instance.shard_assignments.runner_to_shard.keys()
        )
    new_runners: Mapping[RunnerId, RunnerStatus] = {
        rid: rs for rid, rs in state.runners.items() if rid not in runner_ids_to_remove
    }
    return state.model_copy(update={"instances": new_instances, "runners": new_runners})


def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
src/exo/shared/tests/test_apply/test_apply_instance_deleted.py (new file, 142 lines)

@@ -0,0 +1,142 @@
|
||||
from exo.shared.apply import apply_instance_deleted
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.tests.conftest import get_pipeline_shard_metadata
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import InstanceDeleted
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.instances import InstanceId, MlxRingInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerId,
|
||||
RunnerReady,
|
||||
ShardAssignments,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
INSTANCE_2_ID,
|
||||
MODEL_A_ID,
|
||||
MODEL_B_ID,
|
||||
NODE_A,
|
||||
NODE_B,
|
||||
RUNNER_1_ID,
|
||||
RUNNER_2_ID,
|
||||
)
|
||||
|
||||
|
||||
def _make_instance(
|
||||
instance_id: InstanceId,
|
||||
model_id: ModelId,
|
||||
node_to_runner: dict[NodeId, RunnerId],
|
||||
runner_to_shard: dict[RunnerId, ShardMetadata],
|
||||
) -> MlxRingInstance:
|
||||
return MlxRingInstance(
|
||||
instance_id=instance_id,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=model_id,
|
||||
node_to_runner=node_to_runner,
|
||||
runner_to_shard=runner_to_shard,
|
||||
),
|
||||
hosts_by_node={},
|
||||
ephemeral_port=50000,
|
||||
)
|
||||
|
||||
|
||||
def test_instance_deleted_removes_runners():
|
||||
"""Deleting an instance must also remove its runner entries from state."""
|
||||
shard = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0)
|
||||
instance = _make_instance(
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
{NODE_A: RUNNER_1_ID},
|
||||
{RUNNER_1_ID: shard},
|
||||
)
|
||||
state = State(
|
||||
instances={INSTANCE_1_ID: instance},
|
||||
runners={RUNNER_1_ID: RunnerReady()},
|
||||
)
|
||||
|
||||
new_state = apply_instance_deleted(
|
||||
InstanceDeleted(instance_id=INSTANCE_1_ID), state
|
||||
)
|
||||
|
||||
assert INSTANCE_1_ID not in new_state.instances
|
||||
assert RUNNER_1_ID not in new_state.runners
|
||||
|
||||
|
||||
def test_instance_deleted_removes_only_its_runners():
|
||||
"""Deleting one instance must not remove runners belonging to another."""
|
||||
shard_a = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0)
|
||||
shard_b = get_pipeline_shard_metadata(MODEL_B_ID, device_rank=0)
|
||||
instance_1 = _make_instance(
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
{NODE_A: RUNNER_1_ID},
|
||||
{RUNNER_1_ID: shard_a},
|
||||
)
|
||||
instance_2 = _make_instance(
|
||||
INSTANCE_2_ID,
|
||||
MODEL_B_ID,
|
||||
{NODE_B: RUNNER_2_ID},
|
||||
{RUNNER_2_ID: shard_b},
|
||||
)
|
||||
state = State(
|
||||
instances={INSTANCE_1_ID: instance_1, INSTANCE_2_ID: instance_2},
|
||||
runners={RUNNER_1_ID: RunnerReady(), RUNNER_2_ID: RunnerReady()},
|
||||
)
|
||||
|
||||
new_state = apply_instance_deleted(
|
||||
InstanceDeleted(instance_id=INSTANCE_1_ID), state
|
||||
)
|
||||
|
||||
assert INSTANCE_1_ID not in new_state.instances
|
||||
assert RUNNER_1_ID not in new_state.runners
|
||||
# Instance 2 and its runner must remain
|
||||
assert INSTANCE_2_ID in new_state.instances
|
||||
assert RUNNER_2_ID in new_state.runners
|
||||
|
||||
|
||||
def test_instance_deleted_multi_node_removes_all_runners():
|
||||
"""Deleting a multi-node instance removes all of its runners."""
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
instance = _make_instance(
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
{NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
{RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
|
||||
)
|
||||
state = State(
|
||||
instances={INSTANCE_1_ID: instance},
|
||||
runners={RUNNER_1_ID: RunnerReady(), RUNNER_2_ID: RunnerReady()},
|
||||
)
|
||||
|
||||
new_state = apply_instance_deleted(
|
||||
InstanceDeleted(instance_id=INSTANCE_1_ID), state
|
||||
)
|
||||
|
||||
assert INSTANCE_1_ID not in new_state.instances
|
||||
assert RUNNER_1_ID not in new_state.runners
|
||||
assert RUNNER_2_ID not in new_state.runners
|
||||
assert len(new_state.runners) == 0
|
||||
|
||||
|
||||
def test_instance_deleted_unknown_id_is_noop_for_runners():
|
||||
"""Deleting a non-existent instance should not affect runners."""
|
||||
shard = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0)
|
||||
instance = _make_instance(
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
{NODE_A: RUNNER_1_ID},
|
||||
{RUNNER_1_ID: shard},
|
||||
)
|
||||
unknown_id = InstanceId("99999999-9999-4999-8999-999999999999")
|
||||
state = State(
|
||||
instances={INSTANCE_1_ID: instance},
|
||||
runners={RUNNER_1_ID: RunnerReady()},
|
||||
)
|
||||
|
||||
new_state = apply_instance_deleted(InstanceDeleted(instance_id=unknown_id), state)
|
||||
|
||||
# Everything should remain untouched
|
||||
assert INSTANCE_1_ID in new_state.instances
|
||||
assert RUNNER_1_ID in new_state.runners
|
||||
@@ -50,7 +50,9 @@ class RunnerReady(BaseRunnerStatus):


class RunnerRunning(BaseRunnerStatus):
    pass
    """Runner is processing requests and can accept more (continuous batching)."""

    active_requests: int = 0


class RunnerShuttingDown(BaseRunnerStatus):
src/exo/worker/engines/mlx/generator/batch_engine.py (new file, 317 lines)

@@ -0,0 +1,317 @@
|
||||
"""Batch generation engine using mlx_lm's BatchGenerator for continuous batching."""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import BatchGenerator
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper
|
||||
|
||||
from exo.shared.types.api import FinishReason, GenerationStats
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import TaskId
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import MAX_TOKENS
|
||||
from exo.worker.engines.mlx.generator.distributed_sync import share_object
|
||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingInsert:
|
||||
"""Pre-tokenized request ready for batch insertion."""
|
||||
|
||||
command_id: CommandId
|
||||
task_id: TaskId
|
||||
tokens: list[int]
|
||||
max_tokens: int
|
||||
prompt_tokens: int
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveRequest:
|
||||
"""Tracks an active request in the batch."""
|
||||
|
||||
command_id: CommandId
|
||||
task_id: TaskId
|
||||
uid: int # BatchGenerator's internal ID
|
||||
detokenizer: StreamingDetokenizer
|
||||
tokens_generated: int = 0
|
||||
prompt_tokens: int = 0
|
||||
start_time: float = field(default_factory=time.perf_counter)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchedGenerationResponse:
|
||||
"""Response from batch engine, tagged with command_id and task_id."""
|
||||
|
||||
command_id: CommandId
|
||||
task_id: TaskId
|
||||
response: GenerationResponse
|
||||
|
||||
|
||||
class BatchGenerationEngine:
|
||||
"""Manages continuous batching using mlx_lm's BatchGenerator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
group: mx.distributed.Group | None = None,
|
||||
max_tokens: int = MAX_TOKENS,
|
||||
completion_batch_size: int = 32,
|
||||
prefill_batch_size: int = 8,
|
||||
prefill_step_size: int = 2048,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.max_tokens = max_tokens
|
||||
self.active_requests: dict[int, ActiveRequest] = {}
|
||||
self._pending_inserts: list[PendingInsert] = []
|
||||
self._pending_completions: list[
|
||||
int
|
||||
] = [] # UIDs completed but not yet synced/removed
|
||||
|
||||
self.group = group
|
||||
self.rank = group.rank() if group else 0
|
||||
self.is_distributed = group is not None and group.size() > 1
|
||||
|
||||
sampler = make_sampler(temp=0.7, top_p=1.0)
|
||||
|
||||
eos_tokens: set[int] = set(tokenizer.eos_token_ids or [])
|
||||
|
||||
self.batch_gen: BatchGenerator = BatchGenerator(
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
stop_tokens=eos_tokens,
|
||||
sampler=sampler,
|
||||
completion_batch_size=completion_batch_size,
|
||||
prefill_batch_size=prefill_batch_size,
|
||||
prefill_step_size=prefill_step_size,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"BatchGenerationEngine initialized with completion_batch_size={completion_batch_size}, "
|
||||
f"prefill_batch_size={prefill_batch_size}, distributed={self.is_distributed}"
|
||||
)
|
||||
|
||||
def queue_request(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: TextGenerationTaskParams,
|
||||
) -> str:
|
||||
"""Queue a pre-tokenized request for insertion. Only rank 0 should call this.
|
||||
|
||||
Tokenization happens here (eagerly) so that sync_and_insert_pending()
|
||||
only does the lightweight batch_gen.insert() call, keeping the decode
|
||||
thread unblocked for as long as possible.
|
||||
|
||||
Returns the prompt string for caller use (e.g. thinking-mode detection).
|
||||
"""
|
||||
assert self.rank == 0, "Only rank 0 should queue requests"
|
||||
prompt_str = apply_chat_template(self.tokenizer, task_params)
|
||||
tokens: list[int] = self.tokenizer.encode(prompt_str, add_special_tokens=False)
|
||||
max_tokens = task_params.max_output_tokens or self.max_tokens
|
||||
self._pending_inserts.append(
|
||||
PendingInsert(
|
||||
command_id=command_id,
|
||||
task_id=task_id,
|
||||
tokens=tokens,
|
||||
max_tokens=max_tokens,
|
||||
prompt_tokens=len(tokens),
|
||||
temperature=task_params.temperature,
|
||||
top_p=task_params.top_p,
|
||||
top_k=task_params.top_k,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"Queued request {command_id} for insertion (pending={len(self._pending_inserts)}, prompt_tokens={len(tokens)})"
|
||||
)
|
||||
return prompt_str
|
||||
|
||||
def sync_and_insert_pending(self) -> list[int]:
|
||||
"""Sync pre-tokenized pending inserts across ranks and insert them. Returns UIDs.
|
||||
|
||||
Tokens are already prepared by queue_request(), so this method only does
|
||||
the lightweight batch_gen.insert() call plus distributed sync if needed.
|
||||
"""
|
||||
inserts_to_process: list[PendingInsert]
|
||||
|
||||
if not self.is_distributed:
|
||||
# Non-distributed: just insert directly from pending
|
||||
inserts_to_process = list(self._pending_inserts)
|
||||
else:
|
||||
# Distributed: broadcast pre-tokenized inserts from rank 0 to all ranks
|
||||
assert self.group is not None
|
||||
inserts_to_process = share_object(
|
||||
self._pending_inserts if self.rank == 0 else None,
|
||||
self.rank,
|
||||
self.group,
|
||||
)
|
||||
|
||||
if not inserts_to_process:
|
||||
self._pending_inserts.clear()
|
||||
return []
|
||||
|
||||
# Update sampler from per-request parameters (last request wins for batch)
|
||||
last = inserts_to_process[-1]
|
||||
self.batch_gen.sampler = make_sampler( # pyright: ignore[reportAttributeAccessIssue]
|
||||
temp=last.temperature if last.temperature is not None else 0.7,
|
||||
top_p=last.top_p if last.top_p is not None else 1.0,
|
||||
top_k=last.top_k if last.top_k is not None else 0,
|
||||
)
|
||||
|
||||
# Single batched insert for efficient prefill — tokens already prepared
|
||||
all_tokens = [p.tokens for p in inserts_to_process]
|
||||
all_max_tokens = [p.max_tokens for p in inserts_to_process]
|
||||
uids = self.batch_gen.insert(all_tokens, max_tokens=all_max_tokens)
|
||||
|
||||
# Track all inserted requests
|
||||
for i, uid in enumerate(uids):
|
||||
p = inserts_to_process[i]
|
||||
self.active_requests[uid] = ActiveRequest(
|
||||
command_id=p.command_id,
|
||||
task_id=p.task_id,
|
||||
uid=uid,
|
||||
detokenizer=self.tokenizer.detokenizer,
|
||||
prompt_tokens=p.prompt_tokens,
|
||||
)
|
||||
logger.info(
|
||||
f"Inserted request {p.command_id} with uid={uid}, prompt_tokens={p.prompt_tokens}, max_tokens={p.max_tokens}"
|
||||
)
|
||||
|
||||
self._pending_inserts.clear()
|
||||
return uids
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
"""Run one decode step. Tracks completions but does not sync - call sync_completions() at budget boundaries."""
|
||||
responses = self.batch_gen.next()
|
||||
if not responses:
|
||||
return []
|
||||
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
|
||||
for r in responses:
|
||||
uid: int = r.uid
|
||||
req = self.active_requests.get(uid)
|
||||
if req is None:
|
||||
logger.warning(f"Received response for unknown uid={uid}")
|
||||
continue
|
||||
|
||||
req.tokens_generated += 1
|
||||
|
||||
# Decode the token
|
||||
token: int = r.token
|
||||
req.detokenizer.add_token(token)
|
||||
text: str = req.detokenizer.last_segment
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
raw_finish_reason: str | None = r.finish_reason
|
||||
if raw_finish_reason is not None:
|
||||
# Finalize to get remaining text
|
||||
req.detokenizer.finalize()
|
||||
text = req.detokenizer.last_segment
|
||||
|
||||
elapsed = time.perf_counter() - req.start_time
|
||||
generation_tps = req.tokens_generated / elapsed if elapsed > 0 else 0.0
|
||||
|
||||
stats = GenerationStats(
|
||||
prompt_tps=0.0, # Not tracked per-request in batch mode
|
||||
generation_tps=generation_tps,
|
||||
prompt_tokens=req.prompt_tokens,
|
||||
generation_tokens=req.tokens_generated,
|
||||
peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
|
||||
)
|
||||
|
||||
if raw_finish_reason in get_args(FinishReason):
|
||||
finish_reason = raw_finish_reason # pyright: ignore[reportAssignmentType]
|
||||
else:
|
||||
logger.warning(f"Unknown finish_reason: {raw_finish_reason}")
|
||||
finish_reason = "stop"
|
||||
|
||||
# Track completion but don't remove yet - wait for sync_completions()
|
||||
self._pending_completions.append(uid)
|
||||
logger.info(
|
||||
f"Request {req.command_id} completed: {req.tokens_generated} tokens, {generation_tps:.2f} tps, reason={finish_reason}"
|
||||
)
|
||||
|
||||
results.append(
|
||||
BatchedGenerationResponse(
|
||||
command_id=req.command_id,
|
||||
task_id=req.task_id,
|
||||
response=GenerationResponse(
|
||||
text=text,
|
||||
token=token,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
usage=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# In non-distributed mode, clean up completions immediately
|
||||
if not self.is_distributed:
|
||||
self._remove_completed()
|
||||
|
||||
return results
|
||||
|
||||
def sync_completions(self) -> None:
|
||||
"""Sync and remove completed requests. Call at time budget boundaries in distributed mode."""
|
||||
if not self.is_distributed:
|
||||
# Non-distributed: early return if nothing to do
|
||||
if not self._pending_completions:
|
||||
return
|
||||
self._remove_completed()
|
||||
return
|
||||
|
||||
# Distributed mode: ALWAYS sync to ensure all ranks participate in collective op
|
||||
# This prevents deadlock if one rank has completions and another doesn't
|
||||
assert self.group is not None
|
||||
self._pending_completions = share_object(
|
||||
self._pending_completions if self.rank == 0 else None,
|
||||
self.rank,
|
||||
self.group,
|
||||
)
|
||||
self._remove_completed()
|
||||
|
||||
def _remove_completed(self) -> None:
|
||||
"""Remove completed requests from tracking."""
|
||||
for uid in self._pending_completions:
|
||||
if uid in self.active_requests:
|
||||
del self.active_requests[uid]
|
||||
self._pending_completions.clear()
|
||||
|
||||
@property
|
||||
def has_active_requests(self) -> bool:
|
||||
return bool(self.active_requests or self.batch_gen.unprocessed_prompts)
|
||||
|
||||
@property
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return bool(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self.active_requests)
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return len(self.batch_gen.unprocessed_prompts)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def has_pending_completions(self) -> bool:
|
||||
return bool(self._pending_completions)
|
||||
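A minimal driver sketch for the engine above, showing the intended call order (queue on rank 0, collective insert, decode, then completion sync). The `engine`, `requests`, and `send_chunk` names are placeholders, not part of the module:

```python
# Hypothetical driver loop around BatchGenerationEngine; only the method names
# and properties used here come from the class above.
def drive(engine, requests, send_chunk):
    # Rank 0 queues incoming requests; tokenization happens eagerly in queue_request().
    if engine.rank == 0:
        for command_id, task_id, params in requests:
            engine.queue_request(command_id, task_id, params)

    # Every rank calls this at the same point; pending prompts are broadcast
    # from rank 0 via share_object() and inserted into the local BatchGenerator.
    engine.sync_and_insert_pending()

    # Decode until nothing is active, streaming each tagged response out, and
    # sync completions at the boundary so all ranks drop the same finished uids.
    while engine.has_active_requests:
        for tagged in engine.step():
            send_chunk(tagged.command_id, tagged.task_id, tagged.response)
        engine.sync_completions()
```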
src/exo/worker/engines/mlx/generator/distributed_sync.py (new file, 34 lines)

@@ -0,0 +1,34 @@
"""Distributed sync utilities using mx.distributed.all_sum() to broadcast from rank 0."""

# pyright: reportAny=false

import pickle
from typing import cast

import mlx.core as mx


def share_object[T](obj: T | None, rank: int, group: mx.distributed.Group) -> T:
    """Broadcast object from rank 0 to all ranks. Two-phase: size then data.

    Rank 0 must always provide a non-None object. Non-rank-0 callers pass None
    (they are receivers only). Use mx_barrier() instead if no data needs to be shared.
    """
    if rank == 0:
        assert obj is not None, (
            "Rank 0 must provide data; use mx_barrier() to sync without data"
        )
        data = mx.array(list(pickle.dumps(obj)), dtype=mx.uint8)
        mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))
        mx.eval(mx.distributed.all_sum(data, group=group))
        return obj
    else:
        size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
        if size == 0:
            raise RuntimeError(
                "share_object received size=0 from rank 0 — protocol violation"
            )
        data = mx.zeros(size, dtype=mx.uint8)
        data = mx.distributed.all_sum(data, group=group)
        mx.eval(data)
        return cast(T, pickle.loads(bytes(cast(list[int], data.tolist()))))
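To make the two-phase broadcast concrete, here is a hedged usage sketch; the payload and the way the group is obtained are assumptions, and only `share_object` itself comes from the module above:

```python
# All ranks call share_object at the same point; rank 0 supplies the payload,
# the others pass None and receive a pickled copy via the two all_sum phases.
import mlx.core as mx

group = mx.distributed.init()   # assumes a distributed backend is configured
rank = group.rank()

payload = {"pending": [1, 2, 3]} if rank == 0 else None
synced = share_object(payload, rank, group)

# Every rank now holds an identical deserialized copy of rank 0's object.
assert synced == {"pending": [1, 2, 3]}
```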
src/exo/worker/engines/mlx/generator/time_budget.py (new file, 104 lines)

@@ -0,0 +1,104 @@
|
||||
"""Time budget iterator for controlling generation loop timing in distributed mode.
|
||||
|
||||
Based on mlx-lm's TimeBudget pattern - runs for a time budget then syncs,
|
||||
rather than syncing every token. This reduces distributed sync overhead.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Iterator
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
class TimeBudget(Iterator[None]):
|
||||
"""Controls generation loop timing, syncing across ranks periodically.
|
||||
|
||||
In distributed mode, periodically syncs timing across all ranks to
|
||||
dynamically adjust iteration count based on actual performance.
|
||||
|
||||
In non-distributed mode, simply runs for the time budget.
|
||||
|
||||
Usage:
|
||||
for _ in TimeBudget(budget=0.5):
|
||||
batch_engine.step()
|
||||
# ... process responses ...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
budget: float = 0.5,
|
||||
iterations: int = 25,
|
||||
sync_frequency: int = 10,
|
||||
group: mx.distributed.Group | None = None,
|
||||
):
|
||||
"""Initialize TimeBudget.
|
||||
|
||||
Args:
|
||||
budget: Time budget in seconds before yielding control
|
||||
iterations: Initial number of iterations per budget period (distributed only)
|
||||
sync_frequency: How often to sync timing across ranks (distributed only)
|
||||
group: Distributed group, or None for non-distributed mode
|
||||
"""
|
||||
self._budget = budget
|
||||
self._iterations = iterations
|
||||
self._sync_frequency = sync_frequency
|
||||
self._group = group
|
||||
self._is_distributed = group is not None and group.size() > 1
|
||||
|
||||
# Runtime state
|
||||
self._start: float = 0.0
|
||||
self._current_iterations: int = 0
|
||||
self._loops: int = 0
|
||||
self._time_spent: float = 0.0
|
||||
|
||||
def __iter__(self) -> "TimeBudget":
|
||||
self._start = time.perf_counter()
|
||||
self._current_iterations = 0
|
||||
return self
|
||||
|
||||
def __next__(self) -> None:
|
||||
if not self._is_distributed:
|
||||
# Non-distributed: just check time budget
|
||||
if time.perf_counter() - self._start > self._budget:
|
||||
raise StopIteration()
|
||||
return None
|
||||
|
||||
# Distributed mode: iteration-based with periodic timing sync
|
||||
self._current_iterations += 1
|
||||
if self._current_iterations > self._iterations:
|
||||
self._loops += 1
|
||||
self._time_spent += time.perf_counter() - self._start
|
||||
|
||||
if self._loops % self._sync_frequency == 0:
|
||||
# Sync timing across all ranks
|
||||
assert self._group is not None
|
||||
with mx.stream(generation_stream):
|
||||
time_array = mx.array([self._time_spent], dtype=mx.float32)
|
||||
total_time = mx.distributed.all_sum(time_array, group=self._group)
|
||||
mx.eval(total_time)
|
||||
loop_time = float(total_time.item())
|
||||
|
||||
avg_loop_time = loop_time / (self._group.size() * self._sync_frequency)
|
||||
|
||||
if avg_loop_time > 0:
|
||||
factor = self._budget / avg_loop_time
|
||||
self._iterations = max(round(self._iterations * factor), 1)
|
||||
logger.debug(
|
||||
f"TimeBudget adjusted iterations to {self._iterations}"
|
||||
)
|
||||
|
||||
self._loops = 0
|
||||
self._time_spent = 0.0
|
||||
|
||||
raise StopIteration()
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def iterations(self) -> int:
|
||||
"""Current iterations per budget period."""
|
||||
return self._iterations
|
||||
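To illustrate the adjustment rule in `__next__`, a short worked example with assumed numbers (illustrative, not measured):

```python
# Assumed values: 2 ranks, 0.5 s budget, syncing every 10 budget periods.
budget = 0.5
iterations = 25
sync_frequency = 10
world_size = 2

# Suppose the all_sum of accumulated loop time across both ranks is 12.0 s.
total_time = 12.0
avg_loop_time = total_time / (world_size * sync_frequency)  # 0.6 s per period

factor = budget / avg_loop_time                  # ~0.83: running over budget
iterations = max(round(iterations * factor), 1)  # 21 iterations next period

print(iterations)  # 21: slightly fewer decode steps before yielding control
```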
@@ -172,109 +172,120 @@ class Worker:
|
||||
async def plan_step(self):
|
||||
while True:
|
||||
await anyio.sleep(0.1)
|
||||
task: Task | None = plan(
|
||||
self.node_id,
|
||||
self.runners,
|
||||
self.state.downloads,
|
||||
self.state.instances,
|
||||
self.state.runners,
|
||||
self.state.tasks,
|
||||
self.input_chunk_buffer,
|
||||
self.input_chunk_counts,
|
||||
)
|
||||
if task is None:
|
||||
continue
|
||||
# Drain all available tasks before sleeping again.
|
||||
# This ensures concurrent request arrivals are dispatched
|
||||
# rapidly rather than one-per-100ms.
|
||||
while True:
|
||||
task: Task | None = plan(
|
||||
self.node_id,
|
||||
self.runners,
|
||||
self.state.downloads,
|
||||
self.state.instances,
|
||||
self.state.runners,
|
||||
self.state.tasks,
|
||||
self.input_chunk_buffer,
|
||||
self.input_chunk_counts,
|
||||
)
|
||||
if task is None:
|
||||
break
|
||||
|
||||
# Gate DownloadModel on backoff BEFORE emitting TaskCreated
|
||||
# to prevent flooding the event log with useless events
|
||||
if isinstance(task, DownloadModel):
|
||||
model_id = task.shard_metadata.model_card.model_id
|
||||
if not self._download_backoff.should_proceed(model_id):
|
||||
continue
|
||||
# Gate DownloadModel on backoff BEFORE emitting TaskCreated
|
||||
# to prevent flooding the event log with useless events
|
||||
if isinstance(task, DownloadModel):
|
||||
model_id = task.shard_metadata.model_card.model_id
|
||||
if not self._download_backoff.should_proceed(model_id):
|
||||
break
|
||||
|
||||
logger.info(f"Worker plan: {task.__class__.__name__}")
|
||||
assert task.task_status
|
||||
await self.event_sender.send(TaskCreated(task_id=task.task_id, task=task))
|
||||
logger.info(f"Worker plan: {task.__class__.__name__}")
|
||||
assert task.task_status
|
||||
await self.event_sender.send(
|
||||
TaskCreated(task_id=task.task_id, task=task)
|
||||
)
|
||||
|
||||
# lets not kill the worker if a runner is unresponsive
|
||||
match task:
|
||||
case CreateRunner():
|
||||
self._create_supervisor(task)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
case DownloadModel(shard_metadata=shard):
|
||||
model_id = shard.model_card.model_id
|
||||
self._download_backoff.record_attempt(model_id)
|
||||
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id,
|
||||
command=StartDownload(
|
||||
target_node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
),
|
||||
)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Running
|
||||
)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
try:
|
||||
with fail_after(3):
|
||||
await self.runners.pop(runner_id).start_task(task)
|
||||
except TimeoutError:
|
||||
# lets not kill the worker if a runner is unresponsive
|
||||
match task:
|
||||
case CreateRunner():
|
||||
self._create_supervisor(task)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
task_id=task.task_id,
|
||||
task_status=TaskStatus.Complete,
|
||||
)
|
||||
)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
chunks = self.input_chunk_buffer.get(cmd_id, {})
|
||||
assembled = "".join(chunks[i] for i in range(len(chunks)))
|
||||
logger.info(
|
||||
f"Assembled input image from {len(chunks)} chunks, "
|
||||
f"total size: {len(assembled)} bytes"
|
||||
)
|
||||
# Create modified task with assembled image data
|
||||
modified_task = ImageEdits(
|
||||
task_id=task.task_id,
|
||||
command_id=task.command_id,
|
||||
instance_id=task.instance_id,
|
||||
task_status=task.task_status,
|
||||
task_params=ImageEditsTaskParams(
|
||||
image_data=assembled,
|
||||
total_input_chunks=task.task_params.total_input_chunks,
|
||||
prompt=task.task_params.prompt,
|
||||
model=task.task_params.model,
|
||||
n=task.task_params.n,
|
||||
quality=task.task_params.quality,
|
||||
output_format=task.task_params.output_format,
|
||||
response_format=task.task_params.response_format,
|
||||
size=task.task_params.size,
|
||||
image_strength=task.task_params.image_strength,
|
||||
bench=task.task_params.bench,
|
||||
stream=task.task_params.stream,
|
||||
partial_images=task.task_params.partial_images,
|
||||
advanced_params=task.task_params.advanced_params,
|
||||
),
|
||||
)
|
||||
# Cleanup buffers
|
||||
if cmd_id in self.input_chunk_buffer:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
case task:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
case DownloadModel(shard_metadata=shard):
|
||||
model_id = shard.model_card.model_id
|
||||
self._download_backoff.record_attempt(model_id)
|
||||
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id,
|
||||
command=StartDownload(
|
||||
target_node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
),
|
||||
)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id,
|
||||
task_status=TaskStatus.Running,
|
||||
)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
try:
|
||||
with fail_after(3):
|
||||
await self.runners.pop(runner_id).start_task(task)
|
||||
except TimeoutError:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id,
|
||||
task_status=TaskStatus.TimedOut,
|
||||
)
|
||||
)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
chunks = self.input_chunk_buffer.get(cmd_id, {})
|
||||
assembled = "".join(chunks[i] for i in range(len(chunks)))
|
||||
logger.info(
|
||||
f"Assembled input image from {len(chunks)} chunks, "
|
||||
f"total size: {len(assembled)} bytes"
|
||||
)
|
||||
# Create modified task with assembled image data
|
||||
modified_task = ImageEdits(
|
||||
task_id=task.task_id,
|
||||
command_id=task.command_id,
|
||||
instance_id=task.instance_id,
|
||||
task_status=task.task_status,
|
||||
task_params=ImageEditsTaskParams(
|
||||
image_data=assembled,
|
||||
total_input_chunks=task.task_params.total_input_chunks,
|
||||
prompt=task.task_params.prompt,
|
||||
model=task.task_params.model,
|
||||
n=task.task_params.n,
|
||||
quality=task.task_params.quality,
|
||||
output_format=task.task_params.output_format,
|
||||
response_format=task.task_params.response_format,
|
||||
size=task.task_params.size,
|
||||
image_strength=task.task_params.image_strength,
|
||||
bench=task.task_params.bench,
|
||||
stream=task.task_params.stream,
|
||||
partial_images=task.task_params.partial_images,
|
||||
advanced_params=task.task_params.advanced_params,
|
||||
),
|
||||
)
|
||||
# Cleanup buffers
|
||||
if cmd_id in self.input_chunk_buffer:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
case task:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
task
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
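The `plan_step` change above replaces one-task-per-tick planning with an inner drain loop. A condensed sketch of the pattern (the `plan` and `dispatch` callables stand in for the worker's real helpers):

```python
import anyio


async def plan_step(plan, dispatch):
    while True:
        await anyio.sleep(0.1)
        # Drain every plannable task before sleeping again, so tasks arriving
        # concurrently are dispatched promptly rather than one per 100 ms tick.
        while True:
            task = plan()
            if task is None:
                break
            await dispatch(task)
```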
@@ -295,12 +295,14 @@ def _pending_tasks(
        # I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
        # however, realistically the task status should be set to completed by the LAST runner, so this is a true race
        # the actual solution is somewhat deeper than this bypass - TODO!
        if task.task_id in runner.completed:
        # Also skip tasks in pending to prevent duplicate forwarding with continuous batching
        if task.task_id in runner.completed or task.task_id in runner.pending:
            continue

        # TODO: Check ordering aligns with MLX distributeds expectations.

        if isinstance(runner.status, RunnerReady) and all(
        # Allow forwarding tasks when runner is Ready or Running (for continuous batching)
        if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
            isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
            for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
        ):
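The forwarding gate above only dispatches a task when the local runner and every peer runner in the instance are Ready or Running, and skips tasks already pending or completed. A condensed sketch of that predicate (field names mirror the diff; the surrounding types are assumptions):

```python
# Hedged sketch: all_runners is assumed to map runner id -> runner status.
def should_forward(task, runner, all_runners) -> bool:
    if task.task_id in runner.completed or task.task_id in runner.pending:
        return False  # already acknowledged or finished; don't re-forward
    peer_ids = runner.bound_instance.instance.shard_assignments.runner_to_shard
    return isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
        isinstance(all_runners[rid], (RunnerReady, RunnerRunning)) for rid in peer_ids
    )
```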
File diff suppressed because it is too large.
@@ -148,7 +148,11 @@ class RunnerSupervisor:
            if isinstance(event, RunnerStatusUpdated):
                self.status = event.runner_status
            if isinstance(event, TaskAcknowledged):
                self.pending.pop(event.task_id).set()
                # Signal start_task() to return, but keep the entry
                # in self.pending so _pending_tasks won't re-dispatch.
                pending_event = self.pending.get(event.task_id)
                if pending_event is not None:
                    pending_event.set()
                continue
            if (
                isinstance(event, TaskStatusUpdated)
@@ -166,6 +170,8 @@ class RunnerSupervisor:
                ),
            )
            self.completed.add(event.task_id)
            # Clean up from pending now that it's fully complete
            self.pending.pop(event.task_id, None)
            await self._event_sender.send(event)
        except (ClosedResourceError, BrokenResourceError) as e:
            await self._check_runner(e)
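The acknowledgement handling above separates "runner received the task" from "task finished". A compressed sketch of the pending-entry lifecycle (event names follow the diff; the standalone functions and the anyio.Event choice are illustrative):

```python
import anyio

pending: dict[str, anyio.Event] = {}
completed: set[str] = set()


async def start_task(task_id: str) -> None:
    ack = anyio.Event()
    pending[task_id] = ack  # entry created when the task is dispatched
    await ack.wait()        # returns once the runner acknowledges the task


def on_task_acknowledged(task_id: str) -> None:
    ev = pending.get(task_id)  # keep the entry so _pending_tasks won't re-dispatch
    if ev is not None:
        ev.set()


def on_task_complete(task_id: str) -> None:
    completed.add(task_id)
    pending.pop(task_id, None)  # only now is the pending entry removed
```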
@@ -20,6 +20,7 @@ class FakeRunnerSupervisor:
    bound_instance: BoundInstance
    status: RunnerStatus
    completed: set[TaskId] = field(default_factory=set)
    pending: dict[TaskId, object] = field(default_factory=dict)


class OtherTask(BaseTask):
|
||||
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
Tests for continuous batching behavior in the runner.
|
||||
|
||||
These tests verify that:
|
||||
1. Single requests work through the batch path
|
||||
2. Multiple concurrent requests batch together
|
||||
3. Tokens are routed to the correct requests
|
||||
4. Requests complete at different times appropriately
|
||||
|
||||
NOTE: These tests require the continuous-batching runner architecture
|
||||
(BatchGenerationEngine) which is not yet integrated with main.
|
||||
"""
|
||||
|
||||
# ruff: noqa: E402
|
||||
# pyright: reportAny=false
|
||||
# pyright: reportUnknownArgumentType=false
|
||||
# pyright: reportUnknownMemberType=false
|
||||
# pyright: reportAttributeAccessIssue=false
|
||||
# pyright: reportInvalidTypeVarUse=false
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
import exo.worker.runner.runner as mlx_runner
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
RunnerStatusUpdated,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.shared.types.worker.runners import RunnerRunning
|
||||
from exo.utils.channels import mp_channel
|
||||
from exo.worker.engines.mlx.generator.batch_engine import (
|
||||
BatchedGenerationResponse,
|
||||
)
|
||||
from exo.worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
NODE_A,
|
||||
RUNNER_1_ID,
|
||||
)
|
||||
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
|
||||
|
||||
|
||||
class FakeBatchEngineWithTokens:
|
||||
"""
|
||||
Fake batch engine that generates a specified number of tokens per request.
|
||||
|
||||
This simulates realistic batch generation behavior where:
|
||||
- Requests are queued on insert
|
||||
- Each step() call generates one token for all active requests
|
||||
- Requests complete when they've generated all their tokens
|
||||
"""
|
||||
|
||||
def __init__(self, *_args: Any, **_kwargs: Any):
|
||||
self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
|
||||
self._pending_inserts: list[
|
||||
tuple[CommandId, TaskId, TextGenerationTaskParams]
|
||||
] = []
|
||||
self._uid_counter = 0
|
||||
self._tokens_per_request = 3 # Default: generate 3 tokens before completing
|
||||
self.rank = 0 # Fake rank for testing
|
||||
|
||||
def queue_request(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: TextGenerationTaskParams,
|
||||
) -> str:
|
||||
"""Queue a request for insertion."""
|
||||
self._pending_inserts.append((command_id, task_id, task_params))
|
||||
return ""
|
||||
|
||||
def sync_and_insert_pending(self) -> list[int]:
|
||||
"""Insert all pending requests."""
|
||||
uids: list[int] = []
|
||||
for command_id, task_id, task_params in self._pending_inserts:
|
||||
uid = self._do_insert(command_id, task_id, task_params)
|
||||
uids.append(uid)
|
||||
self._pending_inserts.clear()
|
||||
return uids
|
||||
|
||||
@property
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return len(self._pending_inserts) > 0
|
||||
|
||||
def _do_insert(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: TextGenerationTaskParams | None,
|
||||
) -> int:
|
||||
uid = self._uid_counter
|
||||
self._uid_counter += 1
|
||||
# Track: (command_id, task_id, tokens_generated, max_tokens)
|
||||
max_tokens = (
|
||||
task_params.max_output_tokens if task_params else self._tokens_per_request
|
||||
)
|
||||
self._active_requests[uid] = (command_id, task_id, 0, max_tokens or 3)
|
||||
return uid
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
uids_to_remove: list[int] = []
|
||||
|
||||
for uid, (command_id, task_id, tokens_gen, max_tokens) in list(
|
||||
self._active_requests.items()
|
||||
):
|
||||
tokens_gen += 1
|
||||
finish_reason = "stop" if tokens_gen >= max_tokens else None
|
||||
text = f"token{tokens_gen}"
|
||||
|
||||
if finish_reason:
|
||||
uids_to_remove.append(uid)
|
||||
else:
|
||||
self._active_requests[uid] = (
|
||||
command_id,
|
||||
task_id,
|
||||
tokens_gen,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
results.append(
|
||||
BatchedGenerationResponse(
|
||||
command_id=command_id,
|
||||
task_id=task_id,
|
||||
response=GenerationResponse(
|
||||
token=tokens_gen,
|
||||
text=text,
|
||||
finish_reason=finish_reason,
|
||||
usage=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
for uid in uids_to_remove:
|
||||
del self._active_requests[uid]
|
||||
|
||||
return results
|
||||
|
||||
@property
|
||||
def has_active_requests(self) -> bool:
|
||||
return len(self._active_requests) > 0
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self._active_requests)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending_inserts)
|
||||
|
||||
def sync_completions(self) -> None:
|
||||
pass # Completions already removed in step()
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return False # Non-distributed mode for testing
|
||||
|
||||
|
||||
class MockTokenizer:
|
||||
"""Mock tokenizer with tool calling disabled."""
|
||||
|
||||
tool_parser = None
|
||||
tool_call_start = None
|
||||
tool_call_end = None
|
||||
has_tool_calling = False
|
||||
has_thinking = False
|
||||
|
||||
|
||||
class FakeGroup:
|
||||
"""Fake MLX distributed group for testing."""
|
||||
|
||||
def rank(self) -> int:
|
||||
return 0
|
||||
|
||||
def size(self) -> int:
|
||||
return 1 # Single node (non-distributed)
|
||||
|
||||
|
||||
def make_nothin[T, U, V](res: T):
|
||||
def nothin(*_1: U, **_2: V) -> T:
|
||||
return res
|
||||
|
||||
return nothin
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Patch MLX dependencies and use FakeBatchEngineWithTokens."""
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup()))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
|
||||
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)
|
||||
|
||||
|
||||
class EventCollector:
|
||||
"""Collects events directly into a list to avoid mp_channel flakiness."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.events: list[Event] = []
|
||||
|
||||
def send(self, event: Event) -> None:
|
||||
self.events.append(event)
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
def join(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _run_with_tasks(tasks: list[Task]) -> list[Event]:
|
||||
"""
|
||||
Run tasks through the runner, adding shutdown at the end.
|
||||
|
||||
Tasks are sent in order, with shutdown sent last.
|
||||
The batch engine processes between task handling.
|
||||
"""
|
||||
bound_instance = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
runner_id=RUNNER_1_ID,
|
||||
node_id=NodeId(NODE_A),
|
||||
)
|
||||
|
||||
task_sender, task_receiver = mp_channel[Task]()
|
||||
event_collector = EventCollector()
|
||||
|
||||
shutdown_task = Shutdown(
|
||||
task_id=TaskId("shutdown"),
|
||||
instance_id=INSTANCE_1_ID,
|
||||
runner_id=RUNNER_1_ID,
|
||||
)
|
||||
|
||||
with task_sender:
|
||||
# Send all tasks including shutdown
|
||||
for t in tasks:
|
||||
task_sender.send(t)
|
||||
task_sender.send(shutdown_task)
|
||||
|
||||
# Disable cleanup methods to prevent issues
|
||||
task_receiver.close = lambda: None
|
||||
task_receiver.join = lambda: None
|
||||
|
||||
mlx_runner.main(bound_instance, event_collector, task_receiver) # type: ignore[arg-type]
|
||||
|
||||
return event_collector.events
|
||||
|
||||
|
||||
INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
|
||||
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
|
||||
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)
|
||||
|
||||
|
||||
def make_chat_task(
|
||||
task_id: str, command_id: str, max_tokens: int = 3
|
||||
) -> TextGeneration:
|
||||
return TextGeneration(
|
||||
task_id=TaskId(task_id),
|
||||
command_id=CommandId(command_id),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=MODEL_A_ID,
|
||||
input=[InputMessage(role="user", content="hello")],
|
||||
stream=True,
|
||||
max_output_tokens=max_tokens,
|
||||
),
|
||||
instance_id=INSTANCE_1_ID,
|
||||
)
|
||||
|
||||
|
||||
def test_single_request_generates_tokens(patch_batch_engine: None):
|
||||
"""
|
||||
Verify a single request generates the expected tokens through the batch path.
|
||||
|
||||
Tokens are generated during the generation loop (not during shutdown drain).
|
||||
The task completes after all tokens are generated.
|
||||
"""
|
||||
chat_task = make_chat_task("chat1", "cmd1", max_tokens=3)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
|
||||
|
||||
# Verify ChunkGenerated events are emitted for all tokens
|
||||
chunk_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, ChunkGenerated) and e.command_id == CommandId("cmd1")
|
||||
]
|
||||
assert len(chunk_events) == 3, (
|
||||
f"Expected 3 ChunkGenerated events, got {len(chunk_events)}"
|
||||
)
|
||||
|
||||
# Last chunk should have finish_reason="stop"
|
||||
last_chunk = chunk_events[-1].chunk
|
||||
assert isinstance(last_chunk, TokenChunk)
|
||||
assert last_chunk.finish_reason == "stop"
|
||||
|
||||
# Task should be marked complete after tokens are generated
|
||||
chat_complete = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, TaskStatusUpdated)
|
||||
and e.task_id == TaskId("chat1")
|
||||
and e.task_status == TaskStatus.Complete
|
||||
]
|
||||
assert len(chat_complete) == 1, "Expected exactly one chat task Complete status"
|
||||
|
||||
|
||||
def test_runner_status_reflects_active_requests(patch_batch_engine: None):
|
||||
"""Verify RunnerRunning status includes active_requests count."""
|
||||
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
|
||||
|
||||
# Find RunnerRunning status events
|
||||
running_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, RunnerStatusUpdated)
|
||||
and isinstance(e.runner_status, RunnerRunning)
|
||||
]
|
||||
|
||||
assert len(running_events) > 0, "Expected at least one RunnerRunning event"
|
||||
assert running_events[0].runner_status.active_requests == 1
|
||||
|
||||
|
||||
def test_chat_task_acknowledged(patch_batch_engine: None):
|
||||
"""Verify chat completion task is acknowledged with proper status updates."""
|
||||
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
|
||||
|
||||
# Find the chat task status events
|
||||
chat_running = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, TaskStatusUpdated)
|
||||
and e.task_id == TaskId("chat1")
|
||||
and e.task_status == TaskStatus.Running
|
||||
]
|
||||
|
||||
assert len(chat_running) == 1, "Expected exactly one chat task Running status"
|
||||
|
||||
|
||||
def test_multiple_requests_generate_tokens(patch_batch_engine: None):
|
||||
"""Verify multiple requests each generate their expected tokens."""
|
||||
chat1 = make_chat_task("chat1", "cmd1", max_tokens=2)
|
||||
chat2 = make_chat_task("chat2", "cmd2", max_tokens=2)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat1, chat2])
|
||||
|
||||
# Both requests should generate their expected number of tokens
|
||||
cmd1_chunks = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, ChunkGenerated) and e.command_id == CommandId("cmd1")
|
||||
]
|
||||
cmd2_chunks = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, ChunkGenerated) and e.command_id == CommandId("cmd2")
|
||||
]
|
||||
|
||||
assert len(cmd1_chunks) == 2, f"Expected 2 chunks for cmd1, got {len(cmd1_chunks)}"
|
||||
assert len(cmd2_chunks) == 2, f"Expected 2 chunks for cmd2, got {len(cmd2_chunks)}"
|
||||
|
||||
# Both tasks should be completed
|
||||
completed_task_ids = {
|
||||
e.task_id
|
||||
for e in events
|
||||
if isinstance(e, TaskStatusUpdated)
|
||||
and e.task_status == TaskStatus.Complete
|
||||
and e.task_id in (TaskId("chat1"), TaskId("chat2"))
|
||||
}
|
||||
assert TaskId("chat1") in completed_task_ids
|
||||
assert TaskId("chat2") in completed_task_ids
|
||||
@@ -0,0 +1,719 @@
|
||||
"""
|
||||
Edge-case tests for continuous batching in the runner.
|
||||
|
||||
Tests cover:
|
||||
1. Concurrent requests with overlapping tool calls
|
||||
2. Requests that finish mid-generation with 'length' reason
|
||||
3. Multiple requests finishing on the same step() call
|
||||
4. Batch of 5+ simultaneous completions
|
||||
"""
|
||||
|
||||
# ruff: noqa: E402
|
||||
# pyright: reportAny=false
|
||||
# pyright: reportUnknownArgumentType=false
|
||||
# pyright: reportUnknownMemberType=false
|
||||
# pyright: reportAttributeAccessIssue=false
|
||||
# pyright: reportInvalidTypeVarUse=false
|
||||
# pyright: reportPrivateUsage=false
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import exo.worker.runner.runner as mlx_runner
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.shared.types.chunks import TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
RunnerStatusUpdated,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.shared.types.worker.runners import RunnerReady
|
||||
from exo.utils.channels import mp_channel
|
||||
from exo.worker.engines.mlx.generator.batch_engine import (
|
||||
BatchedGenerationResponse,
|
||||
)
|
||||
from exo.worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
NODE_A,
|
||||
RUNNER_1_ID,
|
||||
)
|
||||
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake batch engines
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScriptedBatchEngine:
|
||||
"""Batch engine driven by scripted per-request token sequences.
|
||||
|
||||
Each request produces a predefined list of (text, finish_reason) pairs.
|
||||
One step() call pops one token per active request.
|
||||
"""
|
||||
|
||||
def __init__(self, *_args: Any, **_kwargs: Any):
|
||||
self._active: dict[
|
||||
int, tuple[CommandId, TaskId, list[tuple[str, FinishReason | None]]]
|
||||
] = {}
|
||||
self._pending: list[tuple[CommandId, TaskId, TextGenerationTaskParams]] = []
|
||||
self._uid = 0
|
||||
self.rank = 0
|
||||
# map command_id -> scripted tokens, set externally before tasks arrive
|
||||
self.scripts: dict[str, list[tuple[str, FinishReason | None]]] = {}
|
||||
|
||||
def queue_request(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: TextGenerationTaskParams,
|
||||
) -> str:
|
||||
self._pending.append((command_id, task_id, task_params))
|
||||
return ""
|
||||
|
||||
def sync_and_insert_pending(self) -> list[int]:
|
||||
uids: list[int] = []
|
||||
for cmd_id, task_id, _params in self._pending:
|
||||
uid = self._uid
|
||||
self._uid += 1
|
||||
script = list(self.scripts.get(str(cmd_id), [("tok", "stop")]))
|
||||
self._active[uid] = (cmd_id, task_id, script)
|
||||
uids.append(uid)
|
||||
self._pending.clear()
|
||||
return uids
|
||||
|
||||
@property
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return bool(self._pending)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending)
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
done: list[int] = []
|
||||
for uid, (cmd_id, task_id, script) in self._active.items():
|
||||
if not script:
|
||||
continue
|
||||
text, finish_reason = script.pop(0)
|
||||
results.append(
|
||||
BatchedGenerationResponse(
|
||||
command_id=cmd_id,
|
||||
task_id=task_id,
|
||||
response=GenerationResponse(
|
||||
token=0, text=text, finish_reason=finish_reason, usage=None
|
||||
),
|
||||
)
|
||||
)
|
||||
if finish_reason is not None:
|
||||
done.append(uid)
|
||||
for uid in done:
|
||||
del self._active[uid]
|
||||
return results
|
||||
|
||||
@property
|
||||
def has_active_requests(self) -> bool:
|
||||
return bool(self._active)
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self._active)
|
||||
|
||||
def sync_completions(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return False


class FakeBatchEngineWithTokens:
    """Generates N tokens per request (reused from the main test file)."""

    def __init__(self, *_args: Any, **_kwargs: Any):
        self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
        self._pending_inserts: list[
            tuple[CommandId, TaskId, TextGenerationTaskParams]
        ] = []
        self._uid_counter = 0
        self.rank = 0

    def queue_request(
        self,
        command_id: CommandId,
        task_id: TaskId,
        task_params: TextGenerationTaskParams,
    ) -> str:
        self._pending_inserts.append((command_id, task_id, task_params))
        return ""

    def sync_and_insert_pending(self) -> list[int]:
        uids: list[int] = []
        for command_id, task_id, task_params in self._pending_inserts:
            uid = self._uid_counter
            self._uid_counter += 1
            max_tokens = task_params.max_output_tokens or 3
            self._active_requests[uid] = (command_id, task_id, 0, max_tokens)
            uids.append(uid)
        self._pending_inserts.clear()
        return uids

    @property
    def has_pending_inserts(self) -> bool:
        return bool(self._pending_inserts)

    @property
    def pending_insert_count(self) -> int:
        return len(self._pending_inserts)

    def step(self) -> list[BatchedGenerationResponse]:
        results: list[BatchedGenerationResponse] = []
        done: list[int] = []
        for uid, (cmd_id, task_id, tokens_gen, max_tokens) in list(
            self._active_requests.items()
        ):
            tokens_gen += 1
            finish = "stop" if tokens_gen >= max_tokens else None
            results.append(
                BatchedGenerationResponse(
                    command_id=cmd_id,
                    task_id=task_id,
                    response=GenerationResponse(
                        token=tokens_gen,
                        text=f"token{tokens_gen}",
                        finish_reason=finish,
                        usage=None,
                    ),
                )
            )
            if finish:
                done.append(uid)
            else:
                self._active_requests[uid] = (cmd_id, task_id, tokens_gen, max_tokens)
        for uid in done:
            del self._active_requests[uid]
        return results

    @property
    def has_active_requests(self) -> bool:
        return bool(self._active_requests)

    @property
    def active_count(self) -> int:
        return len(self._active_requests)

    def sync_completions(self) -> None:
        pass

    @property
    def is_distributed(self) -> bool:
        return False


# ---------------------------------------------------------------------------
# Mock tokenizers
# ---------------------------------------------------------------------------


class MockTokenizer:
    tool_parser = None
    tool_call_start = None
    tool_call_end = None
    has_tool_calling = False
    has_thinking = False


class MockToolTokenizer:
    """Tokenizer with tool calling enabled for testing."""

    has_tool_calling = True
    has_thinking = False
    tool_call_start = "<tool>"
    tool_call_end = "</tool>"

    @staticmethod
    def _tool_parser(text: str) -> dict[str, Any]:
        return json.loads(text)
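
# Note: in the tool-call tests below, the text accumulated between
# tool_call_start and tool_call_end is handed to _tool_parser as JSON, e.g.
# '{"name":"get_weather","arguments":{"city":"SF"}}'.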


class FakeGroup:
    def rank(self) -> int:
        return 0

    def size(self) -> int:
        return 1


# ---------------------------------------------------------------------------
# Event collector & runner helper
# ---------------------------------------------------------------------------


class EventCollector:
    def __init__(self) -> None:
        self.events: list[Event] = []

    def send(self, event: Event) -> None:
        self.events.append(event)

    def close(self) -> None:
        pass

    def join(self) -> None:
        pass


def make_nothin[T, U, V](res: T):
    def nothin(*_1: U, **_2: V) -> T:
        return res

    return nothin


INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)
SETUP_TASKS: list[Task] = [INIT_TASK, LOAD_TASK, WARMUP_TASK]


def make_chat_task(
    task_id: str, command_id: str, max_tokens: int = 3
) -> TextGeneration:
    return TextGeneration(
        task_id=TaskId(task_id),
        command_id=CommandId(command_id),
        task_params=TextGenerationTaskParams(
            model=MODEL_A_ID,
            input=[InputMessage(role="user", content="hello")],
            stream=True,
            max_output_tokens=max_tokens,
        ),
        instance_id=INSTANCE_1_ID,
    )
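
# Example: make_chat_task("t1", "cmd1", max_tokens=100) builds a streaming
# TextGeneration task for MODEL_A_ID bound to INSTANCE_1_ID.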


def _run_with_tasks(
    tasks: list[Task],
    engine_cls: type = FakeBatchEngineWithTokens,
    tokenizer_cls: type = MockTokenizer,
    engine_instance: Any | None = None,
) -> list[Event]:
    """Run tasks through the runner with configurable engine and tokenizer."""
    bound = get_bound_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        runner_id=RUNNER_1_ID,
        node_id=NodeId(NODE_A),
    )
    task_sender, task_receiver = mp_channel[Task]()
    collector = EventCollector()
    shutdown = Shutdown(
        task_id=TaskId("shutdown"),
        instance_id=INSTANCE_1_ID,
        runner_id=RUNNER_1_ID,
    )

    import exo.worker.runner.runner as r

    orig_init_mlx = r.initialize_mlx
    orig_load = r.load_mlx_items
    orig_warmup = r.warmup_inference
    orig_check = r._check_for_debug_prompts
    orig_engine = r.BatchGenerationEngine

    r.initialize_mlx = make_nothin(FakeGroup())
    r.load_mlx_items = make_nothin((MagicMock(), tokenizer_cls))
    r.warmup_inference = make_nothin(1)
    r._check_for_debug_prompts = make_nothin(None)
    if engine_instance is not None:
        r.BatchGenerationEngine = lambda *_a, **_kw: engine_instance  # pyright: ignore[reportUnknownLambdaType]
    else:
        r.BatchGenerationEngine = engine_cls

    try:
        with task_sender:
            for t in tasks:
                task_sender.send(t)
            task_sender.send(shutdown)
        task_receiver.close = lambda: None
        task_receiver.join = lambda: None
        r.main(bound, collector, task_receiver)  # pyright: ignore[reportArgumentType]
    finally:
        r.initialize_mlx = orig_init_mlx
        r.load_mlx_items = orig_load
        r.warmup_inference = orig_warmup
        r._check_for_debug_prompts = orig_check
        r.BatchGenerationEngine = orig_engine

    return collector.events
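
# Typical call, as used by the tests below:
#   events = _run_with_tasks([*SETUP_TASKS, chat], engine_instance=ScriptedBatchEngine())
# A Shutdown task is always appended, so the runner loop terminates on its own.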


# ---------------------------------------------------------------------------
# Helpers for querying events
# ---------------------------------------------------------------------------


def chunks_for(events: list[Event], command_id: str) -> list[ChunkGenerated]:
    return [
        e
        for e in events
        if isinstance(e, ChunkGenerated) and e.command_id == CommandId(command_id)
    ]


def completed_task_ids(events: list[Event]) -> set[TaskId]:
    return {
        e.task_id
        for e in events
        if isinstance(e, TaskStatusUpdated) and e.task_status == TaskStatus.Complete
    }


# ===========================================================================
# Test 1: Concurrent requests with overlapping tool calls
# ===========================================================================


def test_concurrent_tool_calls_and_normal_text():
    """Two concurrent requests: one emits normal text, the other a tool call.

    Verifies that:
    - The normal request produces TokenChunks with its text
    - The tool-call request produces a ToolCallChunk
    - Both tasks complete
    """
    engine = ScriptedBatchEngine()
    # cmd_normal: 2 normal tokens then stop
    engine.scripts["cmd_normal"] = [
        ("hello", None),
        (" world", "stop"),
    ]
    # cmd_tool: tool_start, body, tool_end (suppressed), then finish
    engine.scripts["cmd_tool"] = [
        ("<tool>", None),  # swallowed by tracker
        ('{"name":"get_weather","arguments":{"city":"SF"}}', None),  # accumulated
        ("</tool>", None),  # triggers ToolCallChunk emission
        ("done", "stop"),  # normal trailing token
    ]

    chat_normal = make_chat_task("t_normal", "cmd_normal", max_tokens=100)
    chat_tool = make_chat_task("t_tool", "cmd_tool", max_tokens=100)

    events = _run_with_tasks(
        [*SETUP_TASKS, chat_normal, chat_tool],
        tokenizer_cls=MockToolTokenizer,
        engine_instance=engine,
    )

    # Normal request: all chunks should be TokenChunk
    normal_chunks = chunks_for(events, "cmd_normal")
    assert len(normal_chunks) == 2
    assert all(isinstance(c.chunk, TokenChunk) for c in normal_chunks)
    assert normal_chunks[-1].chunk.finish_reason == "stop"

    # Tool-call request
    tool_chunks = chunks_for(events, "cmd_tool")
    # <tool> → swallowed, body → accumulated, </tool> → ToolCallChunk, "done" → TokenChunk
    tool_call_events = [c for c in tool_chunks if isinstance(c.chunk, ToolCallChunk)]
    token_events = [c for c in tool_chunks if isinstance(c.chunk, TokenChunk)]

    assert len(tool_call_events) == 1, (
        f"Expected 1 ToolCallChunk, got {len(tool_call_events)}"
    )
    tc_chunk = tool_call_events[0].chunk
    assert isinstance(tc_chunk, ToolCallChunk)
    assert tc_chunk.tool_calls[0].name == "get_weather"
    assert json.loads(tc_chunk.tool_calls[0].arguments) == {"city": "SF"}

    assert len(token_events) == 1, "Expected 1 trailing TokenChunk after tool call"
    assert token_events[0].chunk.finish_reason == "stop"

    # Both tasks should complete
    done = completed_task_ids(events)
    assert TaskId("t_normal") in done
    assert TaskId("t_tool") in done


def test_tool_call_interrupted_by_finish_reason():
    """Tool call in progress when finish_reason fires — partial text emitted."""
    engine = ScriptedBatchEngine()
    engine.scripts["cmd1"] = [
        ("<tool>", None),
        ('{"name":"f"', "stop"),  # finish while inside tool call
    ]

    chat = make_chat_task("t1", "cmd1", max_tokens=100)
    events = _run_with_tasks(
        [*SETUP_TASKS, chat],
        tokenizer_cls=MockToolTokenizer,
        engine_instance=engine,
    )

    chunks = chunks_for(events, "cmd1")
    assert len(chunks) == 1
    chunk = chunks[0].chunk
    assert isinstance(chunk, TokenChunk)
    # The interrupted tool call should be emitted as raw text
    assert "<tool>" in chunk.text
    assert '{"name":"f"' in chunk.text
    assert chunk.finish_reason == "stop"

    assert TaskId("t1") in completed_task_ids(events)


# ===========================================================================
# Test 2: Request finishing with 'length' reason (max_tokens hit mid-generation)
# ===========================================================================


def test_request_finishes_with_length_reason():
    """Request that hits the max_tokens limit and finishes with 'length'."""
    engine = ScriptedBatchEngine()
    engine.scripts["cmd1"] = [
        ("tok1", None),
        ("tok2", None),
        ("tok3", "length"),  # hit the token limit
    ]

    chat = make_chat_task("t1", "cmd1", max_tokens=100)
    events = _run_with_tasks(
        [*SETUP_TASKS, chat],
        engine_instance=engine,
    )

    chunks = chunks_for(events, "cmd1")
    assert len(chunks) == 3

    # Last chunk should have finish_reason="length"
    assert isinstance(chunks[-1].chunk, TokenChunk)
    assert chunks[-1].chunk.finish_reason == "length"

    # Earlier chunks should have no finish_reason
    for c in chunks[:-1]:
        assert isinstance(c.chunk, TokenChunk)
        assert c.chunk.finish_reason is None

    assert TaskId("t1") in completed_task_ids(events)


def test_mixed_finish_reasons_across_requests():
    """Two requests finishing with different reasons: 'stop' and 'length'."""
    engine = ScriptedBatchEngine()
    engine.scripts["cmd_stop"] = [("a", None), ("b", "stop")]
    engine.scripts["cmd_len"] = [("x", None), ("y", "length")]

    chat1 = make_chat_task("t_stop", "cmd_stop", max_tokens=100)
    chat2 = make_chat_task("t_len", "cmd_len", max_tokens=100)

    events = _run_with_tasks(
        [*SETUP_TASKS, chat1, chat2],
        engine_instance=engine,
    )

    stop_chunks = chunks_for(events, "cmd_stop")
    len_chunks = chunks_for(events, "cmd_len")

    assert stop_chunks[-1].chunk.finish_reason == "stop"
    assert len_chunks[-1].chunk.finish_reason == "length"

    done = completed_task_ids(events)
    assert TaskId("t_stop") in done
    assert TaskId("t_len") in done


# ===========================================================================
# Test 3: Multiple finish reasons in rapid succession (same step)
# ===========================================================================


def test_all_requests_finish_on_same_step():
    """Three requests that all finish on the same step() call.

    This tests that the runner and _process_generation_results correctly
    handle multiple completions in a single step.
    """
    engine = ScriptedBatchEngine()
    # All three produce exactly 1 token and finish
    engine.scripts["cmd_a"] = [("alpha", "stop")]
    engine.scripts["cmd_b"] = [("beta", "stop")]
    engine.scripts["cmd_c"] = [("gamma", "stop")]

    tasks = [
        *SETUP_TASKS,
        make_chat_task("ta", "cmd_a", max_tokens=100),
        make_chat_task("tb", "cmd_b", max_tokens=100),
        make_chat_task("tc", "cmd_c", max_tokens=100),
    ]
    events = _run_with_tasks([*tasks], engine_instance=engine)

    for cmd_id, expected_text in [
        ("cmd_a", "alpha"),
        ("cmd_b", "beta"),
        ("cmd_c", "gamma"),
    ]:
        c = chunks_for(events, cmd_id)
        assert len(c) == 1, f"Expected 1 chunk for {cmd_id}, got {len(c)}"
        assert isinstance(c[0].chunk, TokenChunk)
        assert c[0].chunk.text == expected_text
        assert c[0].chunk.finish_reason == "stop"

    done = completed_task_ids(events)
    assert TaskId("ta") in done
    assert TaskId("tb") in done
    assert TaskId("tc") in done

    # Runner should reach RunnerReady at least after warmup.
    # With inline task processing, later requests may be inserted into the
    # batch before the generation loop exits, so the runner can stay
    # RunnerRunning until Shutdown without an intermediate RunnerReady.
    ready_events = [
        e
        for e in events
        if isinstance(e, RunnerStatusUpdated)
        and isinstance(e.runner_status, RunnerReady)
    ]
    assert len(ready_events) >= 1, "Expected RunnerReady at least after warmup"


def test_staggered_completions_in_batch():
    """Four requests with different token counts — they complete at different steps.

    Verifies each request gets the right number of chunks and the runner
    tracks active_requests correctly as requests drain.
    """
    engine = ScriptedBatchEngine()
    engine.scripts["c1"] = [("a", "stop")]  # finishes step 1
    engine.scripts["c2"] = [("a", None), ("b", "stop")]  # finishes step 2
    engine.scripts["c3"] = [("a", None), ("b", None), ("c", "stop")]  # finishes step 3
    engine.scripts["c4"] = [
        ("a", None),
        ("b", None),
        ("c", None),
        ("d", "stop"),
    ]  # finishes step 4

    tasks = [
        *SETUP_TASKS,
        make_chat_task("t1", "c1", max_tokens=100),
        make_chat_task("t2", "c2", max_tokens=100),
        make_chat_task("t3", "c3", max_tokens=100),
        make_chat_task("t4", "c4", max_tokens=100),
    ]
    events = _run_with_tasks([*tasks], engine_instance=engine)

    assert len(chunks_for(events, "c1")) == 1
    assert len(chunks_for(events, "c2")) == 2
    assert len(chunks_for(events, "c3")) == 3
    assert len(chunks_for(events, "c4")) == 4

    done = completed_task_ids(events)
    for tid in ["t1", "t2", "t3", "t4"]:
        assert TaskId(tid) in done, f"Task {tid} should be complete"


# ===========================================================================
# Test 4: Batch of 5+ simultaneous completions
# ===========================================================================


@pytest.fixture
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup()))
    monkeypatch.setattr(
        mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MockTokenizer))
    )
    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
    monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)
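
# Unlike _run_with_tasks, this fixture patches through pytest's monkeypatch,
# so the original mlx_runner attributes are restored automatically after each test.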


def test_five_simultaneous_completions(patch_batch_engine: None):
    """Five requests submitted together, all generating tokens and completing."""
    chats = [make_chat_task(f"t{i}", f"cmd{i}", max_tokens=2) for i in range(5)]
    events = _run_with_tasks([*SETUP_TASKS, *chats])

    for i in range(5):
        c = chunks_for(events, f"cmd{i}")
        assert len(c) == 2, f"Expected 2 chunks for cmd{i}, got {len(c)}"
        assert c[-1].chunk.finish_reason == "stop"

    done = completed_task_ids(events)
    for i in range(5):
        assert TaskId(f"t{i}") in done


def test_eight_requests_staggered(patch_batch_engine: None):
    """Eight requests with varying token counts, verifying all complete correctly."""
    chats = [make_chat_task(f"t{i}", f"cmd{i}", max_tokens=i + 1) for i in range(8)]
    events = _run_with_tasks([*SETUP_TASKS, *chats])

    for i in range(8):
        c = chunks_for(events, f"cmd{i}")
        expected = i + 1
        assert len(c) == expected, (
            f"Expected {expected} chunks for cmd{i}, got {len(c)}"
        )
        assert c[-1].chunk.finish_reason == "stop"

    done = completed_task_ids(events)
    for i in range(8):
        assert TaskId(f"t{i}") in done

    # Verify runner transitions back to ready after all requests complete
    # Find the last RunnerReady before shutdown
    ready_events = [
        (idx, e)
        for idx, e in enumerate(events)
        if isinstance(e, RunnerStatusUpdated)
        and isinstance(e.runner_status, RunnerReady)
    ]
    shutdown_idx = next(
        idx
        for idx, e in enumerate(events)
        if isinstance(e, TaskStatusUpdated)
        and e.task_id == TaskId("shutdown")
        and e.task_status == TaskStatus.Running
    )
    # There should be a RunnerReady event between generation and shutdown
    ready_before_shutdown = [idx for idx, _ in ready_events if idx < shutdown_idx]
    assert len(ready_before_shutdown) >= 1, (
        "Expected RunnerReady between generation completion and shutdown"
    )


def test_ten_simultaneous_single_token():
    """Ten requests that each produce exactly one token — all finish on step 1."""
    engine = ScriptedBatchEngine()
    for i in range(10):
        engine.scripts[f"cmd{i}"] = [(f"word{i}", "stop")]

    chats = [make_chat_task(f"t{i}", f"cmd{i}", max_tokens=100) for i in range(10)]
    events = _run_with_tasks([*SETUP_TASKS, *chats], engine_instance=engine)

    for i in range(10):
        c = chunks_for(events, f"cmd{i}")
        assert len(c) == 1
        assert isinstance(c[0].chunk, TokenChunk)
        assert c[0].chunk.text == f"word{i}"
        assert c[0].chunk.finish_reason == "stop"

    done = completed_task_ids(events)
    assert len(done & {TaskId(f"t{i}") for i in range(10)}) == 10

@@ -6,6 +6,7 @@ import pytest

import exo.worker.runner.runner as mlx_runner
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
    ChunkGenerated,
    Event,
@@ -19,6 +20,7 @@ from exo.shared.types.tasks import (
    Shutdown,
    StartWarmup,
    Task,
    TaskId,
    TaskStatus,
    TextGeneration,
)
@@ -37,6 +39,7 @@ from exo.shared.types.worker.runners import (
    RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import BatchedGenerationResponse

from ...constants import (
    CHAT_COMPLETION_TASK_ID,
@@ -113,15 +116,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
    # Mock apply_chat_template since we're using a fake tokenizer (integer 1).
    # Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
    monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
    monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))

    def fake_generate(*_1: object, **_2: object):
        yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)

    monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
    monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngine)


# Use a fake event_sender to remove test flakiness.
@@ -144,6 +139,7 @@ class MockTokenizer:
    tool_call_start = None
    tool_call_end = None
    has_tool_calling = False
    has_thinking = False


class MockGroup:
@@ -154,6 +150,70 @@ class MockGroup:
        return 1


class FakeBatchEngine:
    """Fake batch engine that generates a single 'hi' token per request."""

    def __init__(self, *_args: object, **_kwargs: object):
        self._active_requests: dict[int, tuple[CommandId, TaskId]] = {}
        self._pending_inserts: list[tuple[CommandId, TaskId, object]] = []
        self._uid_counter = 0
        self.rank = 0

    def queue_request(
        self, command_id: CommandId, task_id: TaskId, task_params: object
    ) -> str:
        self._pending_inserts.append((command_id, task_id, task_params))
        return ""

    def sync_and_insert_pending(self) -> list[int]:
        uids: list[int] = []
        for cmd_id, task_id, _params in self._pending_inserts:
            uid = self._uid_counter
            self._uid_counter += 1
            self._active_requests[uid] = (cmd_id, task_id)
            uids.append(uid)
        self._pending_inserts.clear()
        return uids

    def step(self) -> list[BatchedGenerationResponse]:
        results: list[BatchedGenerationResponse] = []
        for _uid, (cmd_id, task_id) in list(self._active_requests.items()):
            results.append(
                BatchedGenerationResponse(
                    command_id=cmd_id,
                    task_id=task_id,
                    response=GenerationResponse(
                        token=0, text="hi", finish_reason="stop", usage=None
                    ),
                )
            )
        self._active_requests.clear()
        return results

    def sync_completions(self) -> None:
        pass

    @property
    def has_active_requests(self) -> bool:
        return bool(self._active_requests)

    @property
    def has_pending_inserts(self) -> bool:
        return bool(self._pending_inserts)

    @property
    def pending_insert_count(self) -> int:
        return len(self._pending_inserts)

    @property
    def active_count(self) -> int:
        return len(self._active_requests)

    @property
    def is_distributed(self) -> bool:
        return False


def _run(tasks: Iterable[Task]):
    bound_instance = get_bound_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
@@ -219,17 +279,22 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
            TaskAcknowledged(task_id=WARMUP_TASK_ID),
            TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
            # CHAT TASK: queued, tokens generated, then completed
            TaskStatusUpdated(
                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
            ),
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
            TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
            RunnerStatusUpdated(
                runner_id=RUNNER_1_ID,
                runner_status=RunnerRunning(active_requests=1),
            ),
            # Generation loop produces token and completes the task
            expected_chunk,
            TaskStatusUpdated(
                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
            ),
            # CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
            # SHUTDOWN
            TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
            RunnerStatusUpdated(
                runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
@@ -238,7 +303,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
            TaskStatusUpdated(
                task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
            ),
            # SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
        ],
    )