mirror of https://github.com/exo-explore/exo.git
synced 2026-01-14 00:48:58 -05:00

Compare commits: 1 commit on main

| Author | SHA1 | Date |
|---|---|---|
| alexcheema | 16578d35a4 | |

@@ -276,24 +276,23 @@ class BatchGenerator:
logprobs: mx.array
finish_reason: Optional[str]

unprocessed_prompts: List[Any]

def __init__(
self,
model,
model: nn.Module,
max_tokens: int = ...,
stop_tokens: Optional[set] = ...,
stop_tokens: Optional[set[int]] = ...,
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
completion_batch_size: int = ...,
prefill_batch_size: int = ...,
prefill_step_size: int = ...,
) -> None: ...
def insert(
self, prompts, max_tokens: Union[List[int], int, None] = ...
): # -> list[Any]:
...
def stats(self): # -> BatchStats:
...
def next(self): # -> list[Any]:
...
self, prompts: List[List[int]], max_tokens: Union[List[int], int, None] = ...
) -> List[int]: ...
def stats(self) -> BatchStats: ...
def next(self) -> List[Response]: ...

def batch_generate(
model,

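Read as before/after signatures, this hunk types the batching API end to end: `insert` takes tokenized prompts and returns one uid per prompt, and `next` returns typed `Response` objects. A minimal driver sketch using only the signatures shown above; the prompt token lists and the 64-token cap are illustrative, not part of the commit:

```python
# Sketch of driving the typed BatchGenerator stub above; values are illustrative.
from mlx_lm.generate import BatchGenerator

def drain(gen: BatchGenerator, prompts: list[list[int]]) -> int:
    uids = gen.insert(prompts, max_tokens=[64] * len(prompts))  # one uid per prompt
    finished = 0
    while finished < len(uids):
        for response in gen.next():  # one decode step across the whole batch
            if response.finish_reason is not None:
                finished += 1
    return finished
```
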
@@ -39,12 +39,18 @@ class StreamingDetokenizer:
"""

__slots__ = ...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
tokens: list[int]
def reset(self) -> None: ...
def add_token(self, token: int) -> None: ...
def finalize(self) -> None: ...
@property
def last_segment(self):
def text(self) -> str:
"""The full text decoded so far."""
...
@property
def last_segment(self) -> str:
"""Return the last segment of readable text since last time this property was accessed."""
...

class NaiveStreamingDetokenizer(StreamingDetokenizer):
"""NaiveStreamingDetokenizer relies on the underlying tokenizer
@@ -108,6 +114,7 @@ class TokenizerWrapper:
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
eos_token_ids: list[int] | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int

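Per the docstrings above, `last_segment` yields only the text that became readable since the previous access, while `text` returns everything decoded so far. A minimal streaming sketch, assuming a `TokenizerWrapper` instance and a list of generated token ids (both assumed, not part of this commit):

```python
# Streaming detokenization sketch based on the stub above.
def stream_text(tokenizer, token_ids: list[int]) -> str:
    detok = tokenizer.detokenizer  # StreamingDetokenizer, as used in batch_engine.py
    detok.reset()
    chunks: list[str] = []
    for tok in token_ids:
        detok.add_token(tok)
        chunks.append(detok.last_segment)  # only the newly readable text
    detok.finalize()
    chunks.append(detok.last_segment)  # flush whatever remains
    return "".join(chunks)
```
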
flake.lock | 37 (generated)
@@ -8,11 +8,11 @@
"rust-analyzer-src": "rust-analyzer-src"
},
"locked": {
"lastModified": 1768287139,
"narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
"lastModified": 1761893049,
"narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
"owner": "nix-community",
"repo": "fenix",
"rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
"type": "github"
},
"original": {
@@ -42,22 +42,6 @@
}
},
"nixpkgs": {
"locked": {
"lastModified": 1768127708,
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"nixpkgs-swift": {
"locked": {
"lastModified": 1761672384,
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
@@ -68,8 +52,8 @@
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"type": "github"
}
},
@@ -78,18 +62,17 @@
"fenix": "fenix",
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"treefmt-nix": "treefmt-nix"
}
},
"rust-analyzer-src": {
"flake": false,
"locked": {
"lastModified": 1768224240,
"narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
"lastModified": 1761849405,
"narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
"owner": "rust-lang",
"repo": "rust-analyzer",
"rev": "725349602e525df37f377701e001fe8aab807878",
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
"type": "github"
},
"original": {
@@ -106,11 +89,11 @@
]
},
"locked": {
"lastModified": 1768158989,
"narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
"lastModified": 1762938485,
"narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
"owner": "numtide",
"repo": "treefmt-nix",
"rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
"rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
"type": "github"
},
"original": {

flake.nix | 12
@@ -18,9 +18,6 @@
url = "github:numtide/treefmt-nix";
inputs.nixpkgs.follows = "nixpkgs";
};

# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
};

nixConfig = {
@@ -42,11 +39,9 @@
];

perSystem =
{ config, inputs', pkgs, lib, system, ... }:
{ config, inputs', pkgs, lib, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
treefmt = {
@@ -65,10 +60,7 @@
enable = true;
includes = [ "*.ts" ];
};
swift-format = {
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
swift-format.enable = true;
};
};

@@ -50,7 +50,9 @@ class RunnerReady(BaseRunnerStatus):

class RunnerRunning(BaseRunnerStatus):
pass
"""Runner is processing requests and can accept more (continuous batching)."""

active_requests: int = 0

class RunnerShuttingDown(BaseRunnerStatus):

src/exo/worker/engines/mlx/generator/batch_engine.py | 208 (new file)
@@ -0,0 +1,208 @@
"""Batch generation engine using mlx_lm's BatchGenerator for continuous batching."""

import time
from dataclasses import dataclass, field

import mlx.core as mx
from mlx_lm.generate import BatchGenerator
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper

from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams, TaskId
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import MAX_TOKENS
from exo.worker.engines.mlx.generator.distributed_sync import share_object
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.runner.bootstrap import logger


@dataclass
class ActiveRequest:
"""Tracks an active request in the batch."""

command_id: CommandId
task_id: TaskId
uid: int # BatchGenerator's internal ID
detokenizer: StreamingDetokenizer
tokens_generated: int = 0
prompt_tokens: int = 0
start_time: float = field(default_factory=time.perf_counter)


@dataclass
class BatchedGenerationResponse:
"""Response from batch engine, tagged with command_id and task_id."""

command_id: CommandId
task_id: TaskId
response: GenerationResponse


class BatchGenerationEngine:
"""Manages continuous batching using mlx_lm's BatchGenerator."""

def __init__(
self,
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None = None,
max_tokens: int = MAX_TOKENS,
completion_batch_size: int = 32,
prefill_batch_size: int = 8,
prefill_step_size: int = 2048,
):
self.model = model
self.tokenizer = tokenizer
self.max_tokens = max_tokens
self.active_requests: dict[int, ActiveRequest] = {}

self.group = group
self.rank = group.rank() if group else 0
self.is_distributed = group is not None and group.size() > 1

sampler = make_sampler(temp=0.7, top_p=1.0)

eos_tokens: set[int] = set(tokenizer.eos_token_ids or [])

self.batch_gen: BatchGenerator = BatchGenerator(
model=model,
max_tokens=max_tokens,
stop_tokens=eos_tokens,
sampler=sampler,
completion_batch_size=completion_batch_size,
prefill_batch_size=prefill_batch_size,
prefill_step_size=prefill_step_size,
)

logger.info(
f"BatchGenerationEngine initialized with completion_batch_size={completion_batch_size}, "
f"prefill_batch_size={prefill_batch_size}, distributed={self.is_distributed}"
)

def insert_request(
self,
command_id: CommandId | None,
task_id: TaskId | None,
task_params: ChatCompletionTaskParams | None,
) -> int:
"""Insert a request, syncing across ranks if distributed. Returns uid."""
if self.is_distributed:
assert self.group is not None
request_data = share_object(
(command_id, task_id, task_params) if self.rank == 0 else None,
self.rank,
self.group,
)
if self.rank != 0 and request_data is not None:
command_id, task_id, task_params = request_data

assert command_id is not None
assert task_id is not None
assert task_params is not None

prompt_str = apply_chat_template(self.tokenizer, task_params)
tokens: list[int] = self.tokenizer.encode(prompt_str, add_special_tokens=False)
prompt_tokens = len(tokens)
max_tokens = task_params.max_tokens or self.max_tokens

uids = self.batch_gen.insert([tokens], max_tokens=[max_tokens])
uid = uids[0]
detokenizer = self.tokenizer.detokenizer

self.active_requests[uid] = ActiveRequest(
command_id=command_id,
task_id=task_id,
uid=uid,
detokenizer=detokenizer,
prompt_tokens=prompt_tokens,
)

logger.info(f"Inserted request {command_id} with uid={uid}, prompt_tokens={prompt_tokens}, max_tokens={max_tokens}")
return uid

def step(self) -> list[BatchedGenerationResponse]:
"""Run one decode step. Syncs completed UIDs across ranks if distributed."""
responses = self.batch_gen.next()
if not responses:
return []

results: list[BatchedGenerationResponse] = []
uids_to_remove: list[int] = []

for r in responses:
uid: int = r.uid
req = self.active_requests.get(uid)
if req is None:
logger.warning(f"Received response for unknown uid={uid}")
continue

req.tokens_generated += 1

# Decode the token
token: int = r.token
req.detokenizer.add_token(token)
text: str = req.detokenizer.last_segment

stats: GenerationStats | None = None
finish_reason: FinishReason | None = None

raw_finish_reason: str | None = r.finish_reason
if raw_finish_reason is not None:
# Finalize to get remaining text
req.detokenizer.finalize()
text = req.detokenizer.last_segment

elapsed = time.perf_counter() - req.start_time
generation_tps = req.tokens_generated / elapsed if elapsed > 0 else 0.0

stats = GenerationStats(
prompt_tps=0.0, # Not tracked per-request in batch mode
generation_tps=generation_tps,
prompt_tokens=req.prompt_tokens,
generation_tokens=req.tokens_generated,
peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
)

if raw_finish_reason == "stop":
finish_reason = "stop"
elif raw_finish_reason == "length":
finish_reason = "length"
else:
logger.warning(f"Unknown finish_reason: {raw_finish_reason}")
finish_reason = "stop"

uids_to_remove.append(uid) # Sync before removal
logger.info(f"Request {req.command_id} completed: {req.tokens_generated} tokens, {generation_tps:.2f} tps, reason={finish_reason}")

results.append(BatchedGenerationResponse(
command_id=req.command_id,
task_id=req.task_id,
response=GenerationResponse(text=text, token=token, finish_reason=finish_reason, stats=stats),
))

# Sync completed UIDs across ranks before removing
if self.is_distributed and uids_to_remove:
assert self.group is not None
uids_to_remove = share_object(uids_to_remove if self.rank == 0 else None, self.rank, self.group) or []

for uid in uids_to_remove:
if uid in self.active_requests:
del self.active_requests[uid]

return results

@property
def has_active_requests(self) -> bool:
return bool(self.active_requests or self.batch_gen.unprocessed_prompts)

@property
def active_count(self) -> int:
return len(self.active_requests)

@property
def pending_count(self) -> int:
return len(self.batch_gen.unprocessed_prompts)

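For context, a minimal single-process sketch of how this engine is driven: construct it with `group=None`, insert a request, then call `step()` until nothing is active. The `cmd_id`, `task_id`, and `params` arguments are assumed to come from the surrounding runner code; this sketch is not part of the commit.

```python
# Sketch only: assumes model, tokenizer, and a ChatCompletionTaskParams instance
# are constructed elsewhere (as the runner does after warmup).
def run_one(model, tokenizer, cmd_id, task_id, params) -> None:
    engine = BatchGenerationEngine(model=model, tokenizer=tokenizer, group=None)
    engine.insert_request(command_id=cmd_id, task_id=task_id, task_params=params)
    while engine.has_active_requests:
        for tagged in engine.step():  # one decode step for every active request
            print(tagged.response.text, end="", flush=True)
            if tagged.response.finish_reason is not None:
                print(f"\n[finished: {tagged.response.finish_reason}]")
```
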
src/exo/worker/engines/mlx/generator/distributed_sync.py | 30 (new file)
@@ -0,0 +1,30 @@
"""Distributed sync utilities using mx.distributed.all_sum() to broadcast from rank 0."""

# pyright: reportAny=false

import pickle
from typing import TypeVar, cast

import mlx.core as mx

T = TypeVar("T")


def share_object(obj: T | None, rank: int, group: mx.distributed.Group) -> T | None:
"""Broadcast object from rank 0 to all ranks. Two-phase: size then data."""
if rank == 0:
if obj is None:
mx.eval(mx.distributed.all_sum(mx.array([0]), group=group))
return None
data = mx.array(list(pickle.dumps(obj)), dtype=mx.uint8)
mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))
mx.eval(mx.distributed.all_sum(data, group=group))
return obj
else:
size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
if size == 0:
return None
data = mx.zeros(size, dtype=mx.uint8)
data = mx.distributed.all_sum(data, group=group)
mx.eval(data)
return cast(T, pickle.loads(bytes(cast(list[int], data.tolist()))))

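A small usage sketch of the two-phase broadcast above: rank 0 supplies the payload, every other rank passes `None`, and each rank returns the same unpickled value. The payload and the `mx.distributed.init()` call are illustrative, not part of the commit.

```python
# Usage sketch for share_object; assumes MLX distributed has been initialized.
import mlx.core as mx

group = mx.distributed.init()
rank = group.rank()

payload = {"prompt_tokens": [1, 2, 3], "max_tokens": 16} if rank == 0 else None
synced = share_object(payload, rank, group)  # identical object on every rank
assert synced is not None and synced["max_tokens"] == 16
```
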
@@ -277,12 +277,14 @@ def _pending_tasks(
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
# Also skip tasks in pending to prevent duplicate forwarding with continuous batching
if task.task_id in runner.completed or task.task_id in runner.pending:
continue

# TODO: Check ordering aligns with MLX distributeds expectations.

if isinstance(runner.status, RunnerReady) and all(
# Allow forwarding tasks when runner is Ready or Running (for continuous batching)
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):

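Distilled from the hunk above: a task is skipped if the runner has already completed it or still has it pending, and forwarding is now allowed while runners are either Ready or Running. A schematic of the combined guard; the `task`, `runner`, and `all_runners` objects stand in for the scheduler's real types:

```python
# Schematic only; mirrors the condition introduced in this hunk.
def should_forward(task, runner, all_runners) -> bool:
    # Completed or still-pending tasks are never forwarded again, which
    # prevents duplicate forwarding under continuous batching.
    if task.task_id in runner.completed or task.task_id in runner.pending:
        return False
    # Ready and Running runners both accept new work now.
    return isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
        isinstance(all_runners[rid], (RunnerReady, RunnerRunning))
        for rid in runner.bound_instance.instance.shard_assignments.runner_to_shard
    )
```
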
@@ -1,6 +1,8 @@
import gc
import time

import mlx.core as mx
from anyio import WouldBlock

from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
@@ -21,9 +23,6 @@ from exo.shared.types.tasks import (
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
@@ -39,7 +38,8 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.generator.batch_engine import BatchGenerationEngine
from exo.worker.engines.mlx.generator.generate import warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
@@ -69,143 +69,156 @@ def main(
model = None
tokenizer = None
group = None
batch_engine: BatchGenerationEngine | None = None
pending_shutdown: Shutdown | None = None

current_status: RunnerStatus = RunnerIdle()

def send_status(status: RunnerStatus) -> None:
event_sender.send(RunnerStatusUpdated(runner_id=runner_id, runner_status=status))

logger.info("runner created")
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
group = initialize_mlx(bound_instance)
send_status(current_status)

logger.info("runner connected")
current_status = RunnerConnected()
def handle_task(task: Task, is_deferred: bool = False) -> bool:
nonlocal current_status, model, tokenizer, group, batch_engine, pending_shutdown

# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
case LoadModel() if (
isinstance(current_status, RunnerConnected) and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
# For Shutdown, check if we need to defer BEFORE sending Running/Acknowledged
if isinstance(task, Shutdown) and not is_deferred:
if batch_engine is not None and batch_engine.has_active_requests:
logger.info("deferring shutdown until active requests complete")
pending_shutdown = task
return True

model, tokenizer = load_mlx_items(bound_instance, group)
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running))
event_sender.send(TaskAcknowledged(task_id=task.task_id))

current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
send_status(current_status)
group = initialize_mlx(bound_instance)

logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
current_status = RunnerReady()
logger.info("runner ready")
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert model
assert tokenizer
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
assert task_params.messages[0].content is not None
logger.info("runner connected")
current_status = RunnerConnected()
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
send_status(current_status)

case LoadModel() if (
isinstance(current_status, RunnerConnected) and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
send_status(current_status)

model, tokenizer = load_mlx_items(bound_instance, group)

current_status = RunnerLoaded()
logger.info("runner loaded")
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
send_status(current_status)

case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model is not None
assert tokenizer is not None
current_status = RunnerWarmingUp()
logger.info("runner warming up")
send_status(current_status)

logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(model=model, tokenizer=tokenizer)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(f"runner initialized in {time.time() - setup_start_time} seconds")

batch_engine = BatchGenerationEngine(model=model, tokenizer=tokenizer, group=group)

current_status = RunnerReady()
logger.info("runner ready")
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
send_status(current_status)

case ChatCompletion(
task_params=task_params, command_id=command_id
) if isinstance(current_status, (RunnerReady, RunnerRunning)):
assert batch_engine is not None

if task_params.messages and task_params.messages[0].content is not None:
_check_for_debug_prompts(task_params.messages[0].content)

# Generate responses using the actual MLX generation
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
# case TokenizedResponse():
# TODO: something here ig
batch_engine.insert_request(command_id=command_id, task_id=task.task_id, task_params=task_params)

current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
current_status = RunnerShutdown()
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
mx.clear_cache()
import gc
current_status = RunnerRunning(active_requests=batch_engine.active_count)
send_status(current_status)

gc.collect()
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
send_status(current_status)
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
current_status = RunnerShutdown()
send_status(current_status)
return False

case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)

return True

with task_receiver as tasks:
running = True
while running:
while True:
try:
task = tasks.receive_nowait()
running = handle_task(task)
if not running:
break
except WouldBlock:
break

if not running:
break

if batch_engine is not None and batch_engine.has_active_requests:
for resp in batch_engine.step():
if shard_metadata.device_rank == 0:
event_sender.send(ChunkGenerated(
command_id=resp.command_id,
chunk=TokenChunk(
idx=resp.response.token,
model=shard_metadata.model_meta.model_id,
text=resp.response.text,
token_id=resp.response.token,
finish_reason=resp.response.finish_reason,
stats=resp.response.stats,
),
))
if resp.response.finish_reason is not None:
event_sender.send(TaskStatusUpdated(task_id=resp.task_id, task_status=TaskStatus.Complete))

if batch_engine.has_active_requests:
current_status = RunnerRunning(active_requests=batch_engine.active_count)
else:
current_status = RunnerReady()
send_status(current_status)

# Process deferred shutdown after all requests complete
if pending_shutdown is not None and not batch_engine.has_active_requests:
running = handle_task(pending_shutdown, is_deferred=True)
else:
task = tasks.receive()
running = handle_task(task)

# Cleanup
del model, tokenizer, group, batch_engine
mx.clear_cache()
gc.collect()


EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"

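Stripped of status updates and event plumbing, the restructured main loop above reduces to the following control flow: drain whatever tasks are already queued without blocking, then either advance the batch engine by one decode step or block for the next task. A distilled sketch, not the full implementation:

```python
# Distilled control flow of the new runner loop; error handling and status
# reporting from the real implementation are omitted.
from anyio import WouldBlock

def run_loop(tasks, batch_engine, handle_task) -> None:
    running = True
    while running:
        while True:  # drain every already-queued task without blocking
            try:
                running = handle_task(tasks.receive_nowait())
                if not running:
                    return
            except WouldBlock:
                break
        if batch_engine is not None and batch_engine.has_active_requests:
            batch_engine.step()  # one decode step for all active requests
        else:
            running = handle_task(tasks.receive())  # nothing to decode: block
```
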
@@ -105,7 +105,7 @@ class RunnerSupervisor:
return

# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
logger.warning("Runner process didn't shutdown successfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
@@ -128,9 +128,11 @@ class RunnerSupervisor:

async def start_task(self, task: Task):
if task.task_id in self.completed:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
logger.info(f"Skipping task {task.task_id} - already completed")
return
if task.task_id in self.pending:
logger.info(f"Skipping task {task.task_id} - already pending")
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -149,13 +151,17 @@ class RunnerSupervisor:
if isinstance(event, RunnerStatusUpdated):
self.status = event.runner_status
if isinstance(event, TaskAcknowledged):
self.pending.pop(event.task_id).set()
# Just set the event to unblock start_task, but keep in pending
# to prevent duplicate forwarding until completion
if event.task_id in self.pending:
self.pending[event.task_id].set()
continue
if (
isinstance(event, TaskStatusUpdated)
and event.task_status == TaskStatus.Complete
if isinstance(event, TaskStatusUpdated) and event.task_status in (
TaskStatus.Complete,
TaskStatus.TimedOut,
TaskStatus.Failed,
):
# If a task has just been completed, we should be working on it.
# If a task has just finished, we should be working on it.
assert isinstance(
self.status,
(
@@ -166,6 +172,8 @@ class RunnerSupervisor:
RunnerShuttingDown,
),
)
# Now safe to remove from pending and add to completed
self.pending.pop(event.task_id, None)
self.completed.add(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:

@@ -20,6 +20,7 @@ class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
pending: dict[TaskId, object] = field(default_factory=dict)


class OtherTask(BaseTask):

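The supervisor hunks above separate acknowledgement from completion: `TaskAcknowledged` only unblocks `start_task` while the task id stays in `pending`, and any terminal `TaskStatusUpdated` (Complete, TimedOut, Failed) moves the id from `pending` to `completed`. A schematic of that bookkeeping, with `supervisor` standing in for the real `RunnerSupervisor` instance:

```python
# Schematic of the event bookkeeping after this change; not the full method.
def on_event(supervisor, event) -> None:
    if isinstance(event, TaskAcknowledged):
        # Unblock start_task but keep the id in pending until the task finishes,
        # so the scheduler does not forward it a second time.
        if event.task_id in supervisor.pending:
            supervisor.pending[event.task_id].set()
        return
    if isinstance(event, TaskStatusUpdated) and event.task_status in (
        TaskStatus.Complete,
        TaskStatus.TimedOut,
        TaskStatus.Failed,
    ):
        supervisor.pending.pop(event.task_id, None)
        supervisor.completed.add(event.task_id)
```
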
@@ -0,0 +1,289 @@
"""
Tests for continuous batching behavior in the runner.

These tests verify that:
1. Single requests work through the batch path
2. Multiple concurrent requests batch together
3. Tokens are routed to the correct requests
4. Requests complete at different times appropriately
"""

# pyright: reportAny=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportAttributeAccessIssue=false
# pyright: reportInvalidTypeVarUse=false

from typing import Any
from unittest.mock import MagicMock

import pytest

import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import RunnerRunning
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
BatchedGenerationResponse,
)
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
)
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance


class FakeBatchEngineWithTokens:
"""
Fake batch engine that generates a specified number of tokens per request.

This simulates realistic batch generation behavior where:
- Requests are queued on insert
- Each step() call generates one token for all active requests
- Requests complete when they've generated all their tokens
"""

def __init__(self, *_args: Any, **_kwargs: Any):
self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
self._uid_counter = 0
self._tokens_per_request = 3 # Default: generate 3 tokens before completing

def insert_request(
self,
command_id: CommandId | None,
task_id: TaskId | None,
task_params: ChatCompletionTaskParams | None,
) -> int:
assert command_id is not None
assert task_id is not None
uid = self._uid_counter
self._uid_counter += 1
# Track: (command_id, task_id, tokens_generated, max_tokens)
max_tokens = task_params.max_tokens if task_params else self._tokens_per_request
self._active_requests[uid] = (command_id, task_id, 0, max_tokens or 3)
return uid

def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
uids_to_remove: list[int] = []

for uid, (command_id, task_id, tokens_gen, max_tokens) in list(
self._active_requests.items()
):
tokens_gen += 1
finish_reason = "stop" if tokens_gen >= max_tokens else None
text = f"token{tokens_gen}"

if finish_reason:
uids_to_remove.append(uid)
else:
self._active_requests[uid] = (
command_id,
task_id,
tokens_gen,
max_tokens,
)

results.append(
BatchedGenerationResponse(
command_id=command_id,
task_id=task_id,
response=GenerationResponse(
token=tokens_gen,
text=text,
finish_reason=finish_reason,
),
)
)

for uid in uids_to_remove:
del self._active_requests[uid]

return results

@property
def has_active_requests(self) -> bool:
return len(self._active_requests) > 0

@property
def active_count(self) -> int:
return len(self._active_requests)


def make_nothin[T, U, V](res: T):
def nothin(*_1: U, **_2: V) -> T:
return res

return nothin


@pytest.fixture
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
"""Patch MLX dependencies and use FakeBatchEngineWithTokens."""
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MagicMock()))
monkeypatch.setattr(
mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock()))
)
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)


def _run_with_tasks(tasks: list[Task]) -> list[Event]:
"""
Run tasks through the runner, adding shutdown at the end.

Tasks are sent in order, with shutdown sent last.
The batch engine processes between task handling.
"""
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NodeId(NODE_A),
)

task_sender, task_receiver = mp_channel[Task]()
event_sender, event_receiver = mp_channel[Event]()

shutdown_task = Shutdown(
task_id=TaskId("shutdown"),
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
)

with task_sender, event_receiver:
# Send all tasks including shutdown
for t in tasks:
task_sender.send(t)
task_sender.send(shutdown_task)

# Disable cleanup methods to prevent issues
event_sender.close = lambda: None
event_sender.join = lambda: None
task_receiver.close = lambda: None
task_receiver.join = lambda: None

mlx_runner.main(bound_instance, event_sender, task_receiver)

return event_receiver.collect()


INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)


def make_chat_task(
task_id: str, command_id: str, max_tokens: int = 3
) -> ChatCompletion:
return ChatCompletion(
task_id=TaskId(task_id),
command_id=CommandId(command_id),
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
stream=True,
max_tokens=max_tokens,
),
instance_id=INSTANCE_1_ID,
)


def test_single_request_generates_tokens(patch_batch_engine: None):
"""
Verify a single request generates the expected tokens through the batch path.

Note: With the current non-blocking design, shutdown is processed before
batch steps run when all tasks are queued together. This test verifies
the runner status reflects active requests.
"""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=3)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])

# Find RunnerRunning status events - this shows the request was inserted
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]

assert len(running_events) >= 1, "Expected at least one RunnerRunning event"
assert running_events[0].runner_status.active_requests == 1


def test_runner_status_reflects_active_requests(patch_batch_engine: None):
"""Verify RunnerRunning status includes active_requests count."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])

# Find RunnerRunning status events
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]

assert len(running_events) > 0, "Expected at least one RunnerRunning event"
assert running_events[0].runner_status.active_requests == 1


def test_chat_task_acknowledged(patch_batch_engine: None):
"""Verify chat completion task is acknowledged with proper status updates."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])

# Find the chat task status events
chat_running = [
e
for e in events
if isinstance(e, TaskStatusUpdated)
and e.task_id == TaskId("chat1")
and e.task_status == TaskStatus.Running
]

assert len(chat_running) == 1, "Expected exactly one chat task Running status"


def test_multiple_requests_tracked(patch_batch_engine: None):
"""Verify multiple concurrent requests are tracked in active_requests."""
chat1 = make_chat_task("chat1", "cmd1", max_tokens=2)
chat2 = make_chat_task("chat2", "cmd2", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat1, chat2])

# Find RunnerRunning status events
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]

# Should have at least 2 RunnerRunning events (one per request inserted)
assert len(running_events) >= 2, f"Expected at least 2 RunnerRunning events, got {len(running_events)}"

# First should have 1 active request, second should have 2
assert running_events[0].runner_status.active_requests == 1
assert running_events[1].runner_status.active_requests == 2

@@ -1,11 +1,16 @@
# Check tasks are complete before runner is ever ready.

# pyright: reportAny=false

from collections.abc import Iterable
from typing import Callable
from typing import Any, Callable
from unittest.mock import MagicMock

import pytest

import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import CommandId
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
@@ -22,6 +27,7 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
@@ -38,6 +44,9 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
BatchedGenerationResponse,
)

from ...constants import (
CHAT_COMPLETION_TASK_ID,
@@ -107,18 +116,68 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
assert test_event == true_event, f"{test_event} != {true_event}"


class FakeBatchEngine:
"""
Fake batch engine for testing.

Queues requests on insert, returns one token per step.
The runner's non-blocking loop drains all tasks before running batch steps,
so this engine queues requests and has_active_requests returns True only
after at least one request has been inserted.
"""

def __init__(self, *_args: Any, **_kwargs: Any):
self._active_requests: dict[int, tuple[CommandId, TaskId]] = {}
self._uid_counter = 0

def insert_request(
self,
command_id: CommandId | None,
task_id: TaskId | None,
task_params: ChatCompletionTaskParams | None,
) -> int:
assert command_id is not None
assert task_id is not None
uid = self._uid_counter
self._uid_counter += 1
self._active_requests[uid] = (command_id, task_id)
return uid

def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
# Process all active requests - return one token and complete
for uid, (command_id, task_id) in list(self._active_requests.items()):
results.append(
BatchedGenerationResponse(
command_id=command_id,
task_id=task_id,
response=GenerationResponse(
token=0,
text="hi",
finish_reason="stop",
),
)
)
del self._active_requests[uid]
return results

@property
def has_active_requests(self) -> bool:
return len(self._active_requests) > 0

@property
def active_count(self) -> int:
return len(self._active_requests)


@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
# initialize_mlx returns a fake "group" (non-None for state machine)
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MagicMock()))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock())))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)

def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")

monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngine)


def _run(tasks: Iterable[Task]):
@@ -148,7 +207,8 @@ def _run(tasks: Iterable[Task]):
return event_receiver.collect()


def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
def test_chat_completion_generates_and_completes(patch_out_mlx: pytest.MonkeyPatch):
"""Verify chat completion generates tokens, completes, and runner returns to Ready."""
events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])

expected_chunk = ChunkGenerated(
@@ -191,7 +251,9 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerRunning(active_requests=1)
),
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
@@ -206,7 +268,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),
# SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
],
)