Mirror of https://github.com/exo-explore/exo.git
Synced 2026-02-15 00:23:07 -05:00
Compare commits
6 Commits
| Author | SHA1 | Date |
|---|---|---|
| alexcheema | cd43588a04 | |
| alexcheema | 7b879593bb | |
| alexcheema | e4e895d7a8 | |
| alexcheema | db400dbb75 | |
| alexcheema | 15fad9c632 | |
| alexcheema | 842beefac0 | |
@@ -276,24 +276,23 @@ class BatchGenerator:
    logprobs: mx.array
    finish_reason: Optional[str]

    unprocessed_prompts: List[Any]

    def __init__(
        self,
        model,
        model: nn.Module,
        max_tokens: int = ...,
        stop_tokens: Optional[set] = ...,
        stop_tokens: Optional[set[int]] = ...,
        sampler: Optional[Callable[[mx.array], mx.array]] = ...,
        completion_batch_size: int = ...,
        prefill_batch_size: int = ...,
        prefill_step_size: int = ...,
    ) -> None: ...
    def insert(
        self, prompts, max_tokens: Union[List[int], int, None] = ...
    ): # -> list[Any]:
        ...
    def stats(self): # -> BatchStats:
        ...
    def next(self): # -> list[Any]:
        ...
        self, prompts: List[List[int]], max_tokens: Union[List[int], int, None] = ...
    ) -> List[int]: ...
    def stats(self) -> BatchStats: ...
    def next(self) -> List[Response]: ...

def batch_generate(
    model,
@@ -39,11 +39,11 @@ class StreamingDetokenizer:
    """

    __slots__ = ...
    def reset(self): ...
    def add_token(self, token): ...
    def finalize(self): ...
    def reset(self) -> None: ...
    def add_token(self, token: int) -> None: ...
    def finalize(self) -> None: ...
    @property
    def last_segment(self):
    def last_segment(self) -> str:
        """Return the last segment of readable text since last time this property was accessed."""

class NaiveStreamingDetokenizer(StreamingDetokenizer):
AGENTS.md (76 lines changed)
@@ -116,10 +116,49 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable

## Model Storage

Downloaded models are stored in `~/.exo/models/` (not the standard HuggingFace cache location).

## Creating Model Instances via API

When testing with the API, you must first create a model instance before sending chat completions:

```bash
# 1. Get instance previews for a model
curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b"

# 2. Create an instance from the first valid preview
INSTANCE=$(curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | jq -c '.previews[] | select(.error == null) | .instance' | head -n1)
curl -X POST http://localhost:52415/instance -H 'Content-Type: application/json' -d "{\"instance\": $INSTANCE}"

# 3. Wait for the runner to become ready (check logs for "runner ready")

# 4. Send chat completions using the full model ID
curl -X POST http://localhost:52415/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "mlx-community/Llama-3.2-1B-Instruct-4bit", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}'
```

## Logs

Exo logs are stored in `~/.exo/exo.log`. This is useful for debugging runner crashes and distributed issues.

## Testing

Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
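As an editorial aside (not part of the AGENTS.md diff): under `asyncio_mode = "auto"`, async test functions are collected and awaited without an explicit marker. A minimal sketch, assuming pytest-asyncio is installed; the test names are hypothetical:

```python
# Illustrative sketch only: auto mode means no @pytest.mark.asyncio is needed.
import asyncio
import os


async def test_env_flag_is_set_during_tests():
    # EXO_TESTS=1 is injected by the pytest env configuration described above.
    assert os.environ.get("EXO_TESTS") == "1"


async def test_async_code_runs_without_a_marker():
    await asyncio.sleep(0)
```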
### Distributed Testing

When running distributed tests across multiple machines, use `EXO_LIBP2P_NAMESPACE` to isolate your test cluster from other exo instances on the same network:

```bash
# On each machine in the test cluster, use the same unique namespace
EXO_LIBP2P_NAMESPACE=my-test-cluster uv run exo
```

This prevents your test cluster from discovering and interfering with production or other developers' exo clusters.

## Dashboard UI Testing & Screenshots

### Building and Running the Dashboard

@@ -194,40 +233,3 @@ GitHub's API doesn't support direct image upload for PR comments. Workaround:

git push origin <branch>
```
The images still render in the PR comment because they reference the permanent commit SHA.

## Running exo Remotely via SSH (macOS mDNS)

**CRITICAL: On macOS, mDNS multicast (used for peer discovery) only works when the process runs in a proper macOS user session.** Background processes started via `nohup ... &`, `screen`, or plain SSH commands will NOT send mDNS packets and nodes will never discover each other.

### The Problem
When you SSH into a Mac and run `nohup uv run exo &`, the process runs in a detached session without access to macOS multicast networking. The exo node will start but will never discover peers, even if they're on the same network.

### The Solution: Use `open` with a `.command` wrapper

Create a `.command` script that `open` will execute in the proper macOS GUI session context:

```bash
# 1. Create wrapper script on the remote machine
ssh user@remote-mac "cat > /tmp/run_exo.command << 'SCRIPT'
#!/bin/bash
export PATH=/opt/homebrew/bin:\$HOME/.local/bin:\$PATH
export EXO_LIBP2P_NAMESPACE=your-namespace  # must match across all nodes
cd ~/path/to/exo
exec uv run exo -vv 2>&1 | tee /tmp/exo.log
SCRIPT
chmod +x /tmp/run_exo.command"

# 2. Launch it via `open` (runs in macOS GUI session with proper mDNS)
ssh user@remote-mac "open /tmp/run_exo.command"

# 3. Check logs
ssh user@remote-mac "tail -f /tmp/exo.log"
```

### Key Details
- **`EXO_LIBP2P_NAMESPACE`**: All nodes in a cluster MUST use the same namespace value. The EXO.app uses a build-specific namespace (check with `ps eww <pid> | grep NAMESPACE`). If mixing dev builds with EXO.app, set the dev build's namespace to match.
- **`open *.command`**: This is the macOS equivalent of double-clicking the script in Finder. It runs in the user's GUI session with full network access.
- **Do NOT use**: `nohup ... &`, `screen -dm`, `tmux new-session -d`, or `sshpass`. These all create detached sessions where mDNS won't work.
- **Killing**: `ssh user@remote-mac "pkill -f 'python.*exo'"` works fine for stopping.
- **Dashboard**: Must be built before running: `cd dashboard && npm install && npm run build && cd ..`. Node.js is at `/opt/homebrew/bin/node` on Apple Silicon Macs.
- **Verifying cluster**: `curl -s http://localhost:52415/state | python3 -c "import json,sys; s=json.load(sys.stdin); print(len(s['topology']['nodes']), 'nodes')"` — should show 2+ nodes.
conftest.py (new file, 1 line)
@@ -0,0 +1 @@
collect_ignore = ["tests/start_distributed_test.py"]
@@ -132,7 +132,7 @@ markers = [
env = [
    "EXO_TESTS=1"
]
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
addopts = "-m 'not slow'"
filterwarnings = [
    "ignore:builtin type Swig:DeprecationWarning",
]
@@ -73,8 +73,6 @@ from exo.shared.types.api import (
    CreateInstanceResponse,
    DeleteDownloadResponse,
    DeleteInstanceResponse,
    DistributeModelParams,
    DistributeModelResponse,
    ErrorInfo,
    ErrorResponse,
    FinishReason,
@@ -119,7 +117,6 @@ from exo.shared.types.commands import (
    CreateInstance,
    DeleteDownload,
    DeleteInstance,
    DistributeModel,
    DownloadCommand,
    ForwarderCommand,
    ForwarderDownloadCommand,
@@ -145,7 +142,6 @@ from exo.shared.types.openai_responses import (
    ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
@@ -302,7 +298,6 @@ class API:
        self.app.get("/events")(self.stream_events)
        self.app.post("/download/start")(self.start_download)
        self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
        self.app.post("/v1/models/{model_id:path}/distribute")(self.distribute_model)
        self.app.get("/v1/traces")(self.list_traces)
        self.app.get("/v1/traces/{task_id}")(self.get_trace)
        self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
@@ -1482,57 +1477,6 @@ class API:
        await self._send_download(command)
        return DeleteDownloadResponse(command_id=command.command_id)

    async def distribute_model(
        self, model_id: ModelId, payload: DistributeModelParams
    ) -> DistributeModelResponse:
        """Distribute model files from one node to others via MLX distributed."""
        # Find a source node that has the model downloaded
        source_node_id: NodeId | None = None
        for nid, downloads in self.state.downloads.items():
            for dp in downloads:
                if (
                    isinstance(dp, DownloadCompleted)
                    and dp.shard_metadata.model_card.model_id == model_id
                ):
                    source_node_id = nid
                    break
            if source_node_id is not None:
                break

        if source_node_id is None:
            raise HTTPException(
                status_code=404,
                detail=f"No node has model {model_id} downloaded",
            )

        # Determine target nodes
        if payload.target_node_ids is not None:
            target_node_ids = [
                nid for nid in payload.target_node_ids if nid != source_node_id
            ]
        else:
            target_node_ids = [
                nid for nid in self.state.topology.list_nodes() if nid != source_node_id
            ]

        if not target_node_ids:
            raise HTTPException(
                status_code=400,
                detail="No target nodes to distribute to",
            )

        command = DistributeModel(
            model_id=model_id,
            source_node_id=source_node_id,
            target_node_ids=target_node_ids,
        )
        await self._send(command)

        return DistributeModelResponse(
            command_id=command.command_id,
            message=f"Distributing {model_id} from {source_node_id} to {len(target_node_ids)} node(s)",
        )

    def _get_trace_path(self, task_id: str) -> Path:
        return EXO_TRACING_CACHE_DIR / f"trace_{task_id}.json"
@@ -17,7 +17,6 @@ from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
from exo.shared.types.commands import (
    CreateInstance,
    DeleteInstance,
    DistributeModel,
    ForwarderCommand,
    ForwarderDownloadCommand,
    ImageEdits,
@@ -313,37 +312,6 @@ class Master:
                    self.state.instances, placement
                )
                generated_events.extend(transition_events)
            case DistributeModel():
                from exo.shared.models.model_cards import ModelCard
                from exo.shared.types.worker.instances import InstanceMeta
                from exo.shared.types.worker.shards import Sharding

                model_card = await ModelCard.load(command.model_id)
                all_node_ids = set(
                    [command.source_node_id] + list(command.target_node_ids)
                )
                place_command = PlaceInstance(
                    model_card=model_card,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=len(all_node_ids),
                )
                placement = place_instance(
                    place_command,
                    self.state.topology,
                    self.state.instances,
                    self.state.node_memory,
                    self.state.node_network,
                    required_nodes=all_node_ids,
                )
                # Mark new instances as transfer-only
                for instance_id, instance in placement.items():
                    if instance_id not in self.state.instances:
                        instance.shard_assignments.transfer_only = True
                transition_events = get_transition_events(
                    self.state.instances, placement
                )
                generated_events.extend(transition_events)
            case SendInputChunk(chunk=chunk):
                generated_events.append(
                    InputChunkReceived(
@@ -374,15 +374,6 @@ class DeleteDownloadResponse(CamelCaseModel):
    command_id: CommandId


class DistributeModelParams(CamelCaseModel):
    target_node_ids: list[NodeId] | None = None  # None = all connected nodes


class DistributeModelResponse(CamelCaseModel):
    command_id: CommandId
    message: str


class TraceEventResponse(CamelCaseModel):
    name: str
    start_us: int
@@ -77,14 +77,6 @@ class CancelDownload(BaseCommand):
    model_id: ModelId


class DistributeModel(BaseCommand):
    """Distribute model files from one node to others via MLX distributed."""

    model_id: ModelId
    source_node_id: NodeId
    target_node_ids: list[NodeId]


DownloadCommand = StartDownload | DeleteDownload | CancelDownload

@@ -99,7 +91,6 @@ Command = (
    | DeleteInstance
    | TaskFinished
    | SendInputChunk
    | DistributeModel
)
@@ -41,7 +41,7 @@ class DownloadModel(BaseTask): # emitted by Worker

class LoadModel(BaseTask):  # emitted by Worker
    has_local_model: bool = Field(default=True)
    pass


class ConnectToGroup(BaseTask):  # emitted by Worker
@@ -76,13 +76,6 @@ class ImageEdits(BaseTask): # emitted by Master
    error_message: str | None = Field(default=None)


class TransferModelToDisk(BaseTask):  # emitted by Worker
    """Transfer all model files from source to receivers' disk via MLX distributed."""

    shard_metadata: ShardMetadata
    has_local_model: bool = Field(default=True)


class Shutdown(BaseTask):  # emitted by Worker
    runner_id: RunnerId

@@ -92,7 +85,6 @@ Task = (
    | DownloadModel
    | ConnectToGroup
    | LoadModel
    | TransferModelToDisk
    | StartWarmup
    | TextGeneration
    | ImageGeneration
@@ -50,7 +50,9 @@ class RunnerReady(BaseRunnerStatus):

class RunnerRunning(BaseRunnerStatus):
    pass
    """Runner is processing requests and can accept more (continuous batching)."""

    active_requests: int = 0


class RunnerShuttingDown(BaseRunnerStatus):
@@ -84,7 +86,6 @@ class ShardAssignments(CamelCaseModel):
    model_id: ModelId
    runner_to_shard: Mapping[RunnerId, ShardMetadata]
    node_to_runner: Mapping[NodeId, RunnerId]
    transfer_only: bool = False

    @model_validator(mode="after")
    def validate_runners_exist(self) -> "ShardAssignments":
@@ -47,7 +47,6 @@ if TYPE_CHECKING:
    from mlx_lm.models.cache import Cache

TimeoutCallback = Callable[[], None]
WeightLoader = Callable[[nn.Module, int], None] | None


def eval_with_timeout(
@@ -347,7 +346,6 @@ def tensor_auto_parallel(
    group: mx.distributed.Group,
    timeout_seconds: float = 60.0,
    on_timeout: TimeoutCallback | None = None,
    weight_loader: WeightLoader = None,
) -> nn.Module:
    all_to_sharded_linear = partial(
        shard_linear,
@@ -457,7 +455,7 @@
        raise ValueError(f"Unsupported model type: {type(model)}")

    model = tensor_parallel_sharding_strategy.shard_model(
        model, timeout_seconds, on_timeout, weight_loader
        model, timeout_seconds, on_timeout
    )
    return patch_tensor_model(model)

@@ -484,7 +482,6 @@ class TensorParallelShardingStrategy(ABC):
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
        weight_loader: WeightLoader = None,
    ) -> nn.Module: ...


@@ -494,12 +491,9 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
        weight_loader: WeightLoader = None,
    ) -> nn.Module:
        model = cast(LlamaModel, model)
        for i, layer in enumerate(model.layers):
            if weight_loader is not None:
                weight_loader(model, i)
        for layer in model.layers:
            # Force load weights before sharding to avoid FAST_SYNCH deadlock
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
@@ -551,12 +545,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
        weight_loader: WeightLoader = None,
    ) -> nn.Module:
        model = cast(DeepseekV3Model, model)
        for i, layer in enumerate(model.layers):
            if weight_loader is not None:
                weight_loader(model, i)
        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )
@@ -629,12 +620,9 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
        weight_loader: WeightLoader = None,
    ) -> nn.Module:
        model = cast(GLM4MoeLiteModel, model)
        for i, layer in enumerate(model.layers):  # type: ignore
            if weight_loader is not None:
                weight_loader(model, i)
        for layer in model.layers:  # type: ignore
            layer = cast(Glm4MoeLiteDecoderLayer, layer)
            eval_with_timeout(
                layer.parameters(),
@@ -774,12 +762,9 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
        weight_loader: WeightLoader = None,
    ) -> nn.Module:
        model = cast(MiniMaxModel, model)
        for i, layer in enumerate(model.layers):
            if weight_loader is not None:
                weight_loader(model, i)
        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )
@@ -817,12 +802,9 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
        weight_loader: WeightLoader = None,
    ) -> nn.Module:
        model = cast(Qwen3MoeModel | Qwen3NextModel, model)
        for i, layer in enumerate(model.layers):
            if weight_loader is not None:
                weight_loader(model, i)
        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )
@@ -944,12 +926,9 @@ class Glm4MoeShardingStrategy(TensorParallelShardingStrategy):
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
        weight_loader: WeightLoader = None,
    ) -> nn.Module:
        model = cast(Glm4MoeModel, model)
        for i, layer in enumerate(model.layers):
            if weight_loader is not None:
                weight_loader(model, i)
        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )
@@ -993,13 +972,10 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
        weight_loader: WeightLoader = None,
    ) -> nn.Module:
        model = cast(GptOssMoeModel, model)

        for i, layer in enumerate(model.layers):
            if weight_loader is not None:
                weight_loader(model, i)
        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )
@@ -1037,7 +1013,6 @@ class Step35ShardingStrategy(TensorParallelShardingStrategy):
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
        weight_loader: WeightLoader = None,
    ) -> nn.Module:
        model = cast(Step35Model, model)
src/exo/worker/engines/mlx/generator/batch_engine.py (new file, 307 lines)
@@ -0,0 +1,307 @@
"""Batch generation engine using mlx_lm's BatchGenerator for continuous batching."""

import time
from dataclasses import dataclass, field

import mlx.core as mx
from mlx_lm.generate import BatchGenerator
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper

from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import TaskId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import MAX_TOKENS
from exo.worker.engines.mlx.generator.distributed_sync import share_object
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.runner.bootstrap import logger


@dataclass
class ActiveRequest:
    """Tracks an active request in the batch."""

    command_id: CommandId
    task_id: TaskId
    uid: int  # BatchGenerator's internal ID
    detokenizer: StreamingDetokenizer
    tokens_generated: int = 0
    prompt_tokens: int = 0
    start_time: float = field(default_factory=time.perf_counter)


@dataclass
class BatchedGenerationResponse:
    """Response from batch engine, tagged with command_id and task_id."""

    command_id: CommandId
    task_id: TaskId
    response: GenerationResponse


class BatchGenerationEngine:
    """Manages continuous batching using mlx_lm's BatchGenerator."""

    def __init__(
        self,
        model: Model,
        tokenizer: TokenizerWrapper,
        group: mx.distributed.Group | None = None,
        max_tokens: int = MAX_TOKENS,
        completion_batch_size: int = 32,
        prefill_batch_size: int = 8,
        prefill_step_size: int = 2048,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens
        self.active_requests: dict[int, ActiveRequest] = {}
        self._pending_inserts: list[
            tuple[CommandId, TaskId, TextGenerationTaskParams]
        ] = []
        self._pending_completions: list[
            int
        ] = []  # UIDs completed but not yet synced/removed

        self.group = group
        self.rank = group.rank() if group else 0
        self.is_distributed = group is not None and group.size() > 1

        sampler = make_sampler(temp=0.7, top_p=1.0)

        eos_tokens: set[int] = set(tokenizer.eos_token_ids or [])

        self.batch_gen: BatchGenerator = BatchGenerator(
            model=model,
            max_tokens=max_tokens,
            stop_tokens=eos_tokens,
            sampler=sampler,
            completion_batch_size=completion_batch_size,
            prefill_batch_size=prefill_batch_size,
            prefill_step_size=prefill_step_size,
        )

        logger.info(
            f"BatchGenerationEngine initialized with completion_batch_size={completion_batch_size}, "
            f"prefill_batch_size={prefill_batch_size}, distributed={self.is_distributed}"
        )

    def queue_request(
        self,
        command_id: CommandId,
        task_id: TaskId,
        task_params: TextGenerationTaskParams,
    ) -> None:
        """Queue a request for insertion. Only rank 0 should call this.

        In distributed mode, rank 0 receives tasks from the control plane and
        queues them here. The actual insertion happens in sync_and_insert_pending()
        which ensures all ranks insert the same requests together.
        """
        assert self.rank == 0, "Only rank 0 should queue requests"
        self._pending_inserts.append((command_id, task_id, task_params))
        logger.info(
            f"Queued request {command_id} for insertion (pending={len(self._pending_inserts)})"
        )

    def sync_and_insert_pending(self) -> list[int]:
        """Sync pending inserts across ranks and insert them. Returns UIDs.

        This method ensures all ranks insert the same requests in the same order.
        In non-distributed mode, it simply inserts all pending requests.
        In distributed mode, it broadcasts pending requests from rank 0 to all ranks.

        Batches all pending inserts into a single batch_gen.insert() call for
        efficient prefill batching.
        """
        inserts_to_process: list[tuple[CommandId, TaskId, TextGenerationTaskParams]]

        if not self.is_distributed:
            # Non-distributed: just insert directly from pending
            inserts_to_process = list(self._pending_inserts)
        else:
            # Distributed: broadcast pending inserts from rank 0 to all ranks
            assert self.group is not None
            pending_data = self._pending_inserts if self.rank == 0 else None
            synced_data = share_object(pending_data, self.rank, self.group)

            if synced_data is None:
                self._pending_inserts.clear()
                return []

            inserts_to_process = synced_data

        if not inserts_to_process:
            self._pending_inserts.clear()
            return []

        # Prepare all requests for batched insertion
        all_tokens: list[list[int]] = []
        all_max_tokens: list[int] = []
        all_prompt_tokens: list[int] = []
        request_info: list[tuple[CommandId, TaskId]] = []

        for cmd_id, task_id, params in inserts_to_process:
            prompt_str = apply_chat_template(self.tokenizer, params)
            tokens: list[int] = self.tokenizer.encode(
                prompt_str, add_special_tokens=False
            )
            max_tokens = params.max_output_tokens or self.max_tokens

            all_tokens.append(tokens)
            all_max_tokens.append(max_tokens)
            all_prompt_tokens.append(len(tokens))
            request_info.append((cmd_id, task_id))

        # Single batched insert for efficient prefill
        uids = self.batch_gen.insert(all_tokens, max_tokens=all_max_tokens)

        # Track all inserted requests
        for i, uid in enumerate(uids):
            cmd_id, task_id = request_info[i]
            self.active_requests[uid] = ActiveRequest(
                command_id=cmd_id,
                task_id=task_id,
                uid=uid,
                detokenizer=self.tokenizer.detokenizer,
                prompt_tokens=all_prompt_tokens[i],
            )
            logger.info(
                f"Inserted request {cmd_id} with uid={uid}, prompt_tokens={all_prompt_tokens[i]}, max_tokens={all_max_tokens[i]}"
            )

        self._pending_inserts.clear()
        return uids

    def step(self) -> list[BatchedGenerationResponse]:
        """Run one decode step. Tracks completions but does not sync - call sync_completions() at budget boundaries."""
        responses = self.batch_gen.next()
        if not responses:
            return []

        results: list[BatchedGenerationResponse] = []

        for r in responses:
            uid: int = r.uid
            req = self.active_requests.get(uid)
            if req is None:
                logger.warning(f"Received response for unknown uid={uid}")
                continue

            req.tokens_generated += 1

            # Decode the token
            token: int = r.token
            req.detokenizer.add_token(token)
            text: str = req.detokenizer.last_segment

            stats: GenerationStats | None = None
            finish_reason: FinishReason | None = None

            raw_finish_reason: str | None = r.finish_reason
            if raw_finish_reason is not None:
                # Finalize to get remaining text
                req.detokenizer.finalize()
                text = req.detokenizer.last_segment

                elapsed = time.perf_counter() - req.start_time
                generation_tps = req.tokens_generated / elapsed if elapsed > 0 else 0.0

                stats = GenerationStats(
                    prompt_tps=0.0,  # Not tracked per-request in batch mode
                    generation_tps=generation_tps,
                    prompt_tokens=req.prompt_tokens,
                    generation_tokens=req.tokens_generated,
                    peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
                )

                if raw_finish_reason == "stop":
                    finish_reason = "stop"
                elif raw_finish_reason == "length":
                    finish_reason = "length"
                else:
                    logger.warning(f"Unknown finish_reason: {raw_finish_reason}")
                    finish_reason = "stop"

                # Track completion but don't remove yet - wait for sync_completions()
                self._pending_completions.append(uid)
                logger.info(
                    f"Request {req.command_id} completed: {req.tokens_generated} tokens, {generation_tps:.2f} tps, reason={finish_reason}"
                )

            results.append(
                BatchedGenerationResponse(
                    command_id=req.command_id,
                    task_id=req.task_id,
                    response=GenerationResponse(
                        text=text,
                        token=token,
                        finish_reason=finish_reason,
                        stats=stats,
                        usage=None,
                    ),
                )
            )

        # In non-distributed mode, clean up completions immediately
        if not self.is_distributed:
            self._remove_completed()

        return results

    def sync_completions(self) -> None:
        """Sync and remove completed requests. Call at time budget boundaries in distributed mode."""
        if not self.is_distributed:
            # Non-distributed: early return if nothing to do
            if not self._pending_completions:
                return
            self._remove_completed()
            return

        # Distributed mode: ALWAYS sync to ensure all ranks participate in collective op
        # This prevents deadlock if one rank has completions and another doesn't
        assert self.group is not None
        synced_uids = share_object(
            self._pending_completions if self.rank == 0 else None,
            self.rank,
            self.group,
        )
        if synced_uids:
            self._pending_completions = synced_uids

        self._remove_completed()

    def _remove_completed(self) -> None:
        """Remove completed requests from tracking."""
        for uid in self._pending_completions:
            if uid in self.active_requests:
                del self.active_requests[uid]
        self._pending_completions.clear()

    @property
    def has_active_requests(self) -> bool:
        return bool(self.active_requests or self.batch_gen.unprocessed_prompts)

    @property
    def has_pending_inserts(self) -> bool:
        return bool(self._pending_inserts)

    @property
    def active_count(self) -> int:
        return len(self.active_requests)

    @property
    def pending_count(self) -> int:
        return len(self.batch_gen.unprocessed_prompts)

    @property
    def pending_insert_count(self) -> int:
        return len(self._pending_inserts)

    @property
    def has_pending_completions(self) -> bool:
        return bool(self._pending_completions)
src/exo/worker/engines/mlx/generator/distributed_sync.py (new file, 30 lines)
@@ -0,0 +1,30 @@
"""Distributed sync utilities using mx.distributed.all_sum() to broadcast from rank 0."""

# pyright: reportAny=false

import pickle
from typing import TypeVar, cast

import mlx.core as mx

T = TypeVar("T")


def share_object(obj: T | None, rank: int, group: mx.distributed.Group) -> T | None:
    """Broadcast object from rank 0 to all ranks. Two-phase: size then data."""
    if rank == 0:
        if obj is None:
            mx.eval(mx.distributed.all_sum(mx.array([0]), group=group))
            return None
        data = mx.array(list(pickle.dumps(obj)), dtype=mx.uint8)
        mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))
        mx.eval(mx.distributed.all_sum(data, group=group))
        return obj
    else:
        size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
        if size == 0:
            return None
        data = mx.zeros(size, dtype=mx.uint8)
        data = mx.distributed.all_sum(data, group=group)
        mx.eval(data)
        return cast(T, pickle.loads(bytes(cast(list[int], data.tolist()))))
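A minimal usage sketch for `share_object` (illustrative, not part of the diff): every rank must call it collectively and in the same order, with non-rank-0 ranks passing `None`; the `group` here is assumed to come from `mx.distributed.init()` with a configured backend.

```python
# Illustrative sketch: all ranks execute this same code path.
import mlx.core as mx

group = mx.distributed.init()  # assumes a distributed backend is configured
rank = group.rank()

# Rank 0 supplies the payload; other ranks pass None and receive a copy.
payload = {"uids": [3, 7], "note": "completed"} if rank == 0 else None
synced = share_object(payload, rank, group)
# After the call, `synced` holds the same dict on every rank
# (or None everywhere if rank 0 passed None).
```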
src/exo/worker/engines/mlx/generator/time_budget.py (new file, 104 lines)
@@ -0,0 +1,104 @@
"""Time budget iterator for controlling generation loop timing in distributed mode.

Based on mlx-lm's TimeBudget pattern - runs for a time budget then syncs,
rather than syncing every token. This reduces distributed sync overhead.
"""

import time
from typing import Iterator

import mlx.core as mx

from exo.worker.runner.bootstrap import logger

generation_stream = mx.new_stream(mx.default_device())


class TimeBudget(Iterator[None]):
    """Controls generation loop timing, syncing across ranks periodically.

    In distributed mode, periodically syncs timing across all ranks to
    dynamically adjust iteration count based on actual performance.

    In non-distributed mode, simply runs for the time budget.

    Usage:
        for _ in TimeBudget(budget=0.5):
            batch_engine.step()
            # ... process responses ...
    """

    def __init__(
        self,
        budget: float = 0.5,
        iterations: int = 25,
        sync_frequency: int = 10,
        group: mx.distributed.Group | None = None,
    ):
        """Initialize TimeBudget.

        Args:
            budget: Time budget in seconds before yielding control
            iterations: Initial number of iterations per budget period (distributed only)
            sync_frequency: How often to sync timing across ranks (distributed only)
            group: Distributed group, or None for non-distributed mode
        """
        self._budget = budget
        self._iterations = iterations
        self._sync_frequency = sync_frequency
        self._group = group
        self._is_distributed = group is not None and group.size() > 1

        # Runtime state
        self._start: float = 0.0
        self._current_iterations: int = 0
        self._loops: int = 0
        self._time_spent: float = 0.0

    def __iter__(self) -> "TimeBudget":
        self._start = time.perf_counter()
        self._current_iterations = 0
        return self

    def __next__(self) -> None:
        if not self._is_distributed:
            # Non-distributed: just check time budget
            if time.perf_counter() - self._start > self._budget:
                raise StopIteration()
            return None

        # Distributed mode: iteration-based with periodic timing sync
        self._current_iterations += 1
        if self._current_iterations > self._iterations:
            self._loops += 1
            self._time_spent += time.perf_counter() - self._start

            if self._loops % self._sync_frequency == 0:
                # Sync timing across all ranks
                assert self._group is not None
                with mx.stream(generation_stream):
                    time_array = mx.array([self._time_spent], dtype=mx.float32)
                    total_time = mx.distributed.all_sum(time_array, group=self._group)
                    mx.eval(total_time)
                    loop_time = float(total_time.item())

                avg_loop_time = loop_time / (self._group.size() * self._sync_frequency)

                if avg_loop_time > 0:
                    factor = self._budget / avg_loop_time
                    self._iterations = max(round(self._iterations * factor), 1)
                    logger.debug(
                        f"TimeBudget adjusted iterations to {self._iterations}"
                    )

                self._loops = 0
                self._time_spent = 0.0

            raise StopIteration()

        return None

    @property
    def iterations(self) -> int:
        """Current iterations per budget period."""
        return self._iterations
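How the two new pieces are meant to compose (a hedged sketch, not code from the repo — the runner's real loop lives elsewhere): rank 0 queues work, every rank syncs the inserts collectively, decode steps run inside a `TimeBudget` window, and completions are synced at the budget boundary. `handle_response` and `incoming_requests` are hypothetical names.

```python
# Illustrative driver loop; `engine` is a BatchGenerationEngine, `group` the
# mx.distributed group (or None for single-device mode).
def run_generation_loop(engine, group, incoming_requests):
    for command_id, task_id, params in incoming_requests:
        if engine.rank == 0:
            engine.queue_request(command_id, task_id, params)

    while engine.has_pending_inserts or engine.has_active_requests:
        # Collective: all ranks agree on which requests enter the batch.
        engine.sync_and_insert_pending()

        # Decode for roughly `budget` seconds, emitting per-token responses.
        for _ in TimeBudget(budget=0.5, group=group):
            for tagged in engine.step():
                handle_response(tagged)  # hypothetical: forward to the client

        # Collective: agree on finished uids before removing them.
        engine.sync_completions()
```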
@@ -1,507 +0,0 @@
"""
Model transfer via MLX distributed all_sum.

Three transfer modes:
1. Metadata file transfer: broadcast small files (config.json, tokenizer, etc.) to disk
2. Weight tensor broadcast: stream weight tensors directly into memory via all_sum
3. Full file transfer: broadcast all files (including safetensors) to disk

All functions are collective operations — every rank in the group must call them.

Protocol relies on all_sum: source has real data, receivers have zeros.
all_sum(source + zeros) = source data on all ranks.
"""

from __future__ import annotations

import json
import os
import re
import shutil
import tempfile
from functools import partial
from pathlib import Path
from typing import Any, Final, cast

import mlx.core as mx

from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelId
from exo.worker.runner.bootstrap import logger

Group = mx.distributed.Group

CHUNK_SIZE: Final[int] = 100 * 1024 * 1024  # 100 MB
_LAYER_RE: Final[re.Pattern[str]] = re.compile(r"(?:^|\.)(layers|h)\.(\d+)\.")


def _all_sum_cpu(x: mx.array, group: Group) -> mx.array:
    """all_sum on CPU stream to avoid GPU memory pressure."""
    return mx.distributed.all_sum(
        x, stream=mx.default_stream(mx.Device(mx.cpu)), group=group
    )


def _is_metadata_file(filename: str) -> bool:
    """A metadata file is anything that isn't a weight file or weight index.

    Weight indices (.safetensors.index.json) reference safetensors shard paths.
    Transferring them to a receiver that has no safetensors files is harmless
    today (load_model's glob doesn't match them), but excluding them avoids
    stale references and keeps the transfer minimal.
    """
    if filename.endswith(".safetensors"):
        return False
    return not filename.endswith(".safetensors.index.json")


def model_path_for_id(model_id: ModelId) -> Path:
    """Get model path without requiring directory to exist (unlike build_model_path)."""
    return EXO_MODELS_DIR / model_id.normalize()


def coordinate_transfer(group: Group, has_local_model: bool) -> tuple[bool, int]:
    """
    Determine if a transfer is needed and which rank is the source.

    All ranks must call this function (uses collective all_sum).

    Returns:
        (needs_transfer, source_rank) — source_rank is the lowest rank
        that has the model. needs_transfer is True if any rank is missing it.
    """
    all_sum = partial(_all_sum_cpu, group=group)
    world_size = group.size()

    # Each rank broadcasts a one-hot vector at its position if it has the model
    bitmask = mx.zeros(world_size, dtype=mx.int32)
    if has_local_model:
        bitmask = bitmask.at[group.rank()].add(1)
    summed = all_sum(bitmask)
    mx.eval(summed)

    has_model_flags: list[int] = summed.tolist()  # type: ignore[assignment]
    total_have = sum(has_model_flags)

    if total_have == 0:
        raise RuntimeError(
            "No rank has the model files — cannot transfer. "
            "At least one node must have downloaded the model."
        )

    if total_have == world_size:
        logger.info("All ranks have model files, no transfer needed")
        return False, 0

    source_rank = next(i for i, flag in enumerate(has_model_flags) if flag > 0)
    logger.info(
        f"Transfer needed: source_rank={source_rank}, "
        f"{total_have}/{world_size} ranks have model"
    )
    return True, source_rank
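An illustrative aside (not part of the deleted module): with three ranks where only ranks 0 and 2 have the model on disk, each rank contributes a one-hot vector and the all_sum yields identical flags everywhere, so every rank derives the same decision.

```python
# Hypothetical 3-rank worked example of the coordination step above:
# ranks 0 and 2 have the model locally, rank 1 does not.
flags_contributed = [
    [1, 0, 0],  # rank 0 (has model)
    [0, 0, 0],  # rank 1 (missing)
    [0, 0, 1],  # rank 2 (has model)
]
summed = [sum(col) for col in zip(*flags_contributed)]  # what all_sum returns on every rank
assert summed == [1, 0, 1]
needs_transfer = 0 < sum(summed) < len(summed)                 # True: some rank is missing it
source_rank = next(i for i, flag in enumerate(summed) if flag)  # 0, the lowest rank with the model
```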
def _broadcast_json(obj: object, group: Group, is_source: bool) -> object:
    """Broadcast a JSON-serializable object from source to all ranks."""
    all_sum = partial(_all_sum_cpu, group=group)

    data = json.dumps(obj, separators=(",", ":")).encode("utf-8") if is_source else b""

    # Broadcast length
    len_arr = mx.array([len(data) if is_source else 0], dtype=mx.int64)
    len_result = all_sum(len_arr)
    mx.eval(len_result)
    length = int(len_result.item())
    if length == 0:
        return None

    # Broadcast payload
    if is_source:
        arr = mx.array(list(data), dtype=mx.uint8)
    else:
        arr = mx.zeros(length, dtype=mx.uint8)
    result = all_sum(arr)
    mx.eval(result)
    return json.loads(bytes(cast(list[int], result.tolist())))  # pyright: ignore[reportAny]


def _build_manifest(
    model_path: Path, metadata_only: bool = False
) -> list[dict[str, str | int]]:
    """Build a list of files in the model directory with their relative paths and sizes."""
    manifest: list[dict[str, str | int]] = []
    for root, _dirs, files in os.walk(model_path):
        for fname in sorted(files):
            if metadata_only and not _is_metadata_file(fname):
                continue
            full_path = Path(root) / fname
            rel_path = str(full_path.relative_to(model_path))
            manifest.append(
                {
                    "path": rel_path,
                    "size": full_path.stat().st_size,
                }
            )
    return manifest


def _transfer_file_to_disk(
    source_path: Path,
    rel_path: str,
    file_size: int,
    group: Group,
    is_source: bool,
    dest_path: Path,
) -> None:
    """Transfer a single file chunk-by-chunk via all_sum. Source reads from disk, receivers write to dest_path."""
    all_sum = partial(_all_sum_cpu, group=group)

    if is_source:
        src_file = source_path / rel_path
        with open(src_file, "rb") as f:
            offset = 0
            while offset < file_size:
                chunk_bytes = min(CHUNK_SIZE, file_size - offset)
                data = f.read(chunk_bytes)
                if not data:
                    break
                size_arr = mx.array([len(data)], dtype=mx.int64)
                mx.eval(all_sum(size_arr))
                chunk_arr = mx.array(list(data), dtype=mx.uint8)
                result = all_sum(chunk_arr)
                mx.eval(result)
                offset += len(data)
        # Signal end of file
        mx.eval(all_sum(mx.array([0], dtype=mx.int64)))
    else:
        dst_file = dest_path / rel_path
        os.makedirs(dst_file.parent, exist_ok=True)
        with open(dst_file, "wb") as f:
            while True:
                size_arr = all_sum(mx.zeros(1, dtype=mx.int64))
                mx.eval(size_arr)
                chunk_size = int(size_arr.item())
                if chunk_size == 0:
                    break
                chunk_data = all_sum(mx.zeros(chunk_size, dtype=mx.uint8))
                mx.eval(chunk_data)
                f.write(bytes(cast(list[int], chunk_data.tolist())))


def _transfer_files_to_disk(
    model_path: Path,
    group: Group,
    is_source: bool,
    metadata_only: bool = False,
) -> None:
    """
    Transfer files from source to all receivers' disk.

    Source broadcasts a manifest then each file. Receivers write to a temp dir
    then atomically move files to model_path.
    """
    if is_source:
        source_manifest = _build_manifest(model_path, metadata_only=metadata_only)
    else:
        source_manifest = []
    manifest = cast(
        list[dict[str, str | int]],
        _broadcast_json(source_manifest if is_source else None, group, is_source),
    )

    if not manifest:
        logger.info("No files to transfer")
        return

    logger.info(
        f"Transferring {len(manifest)} files ({'metadata only' if metadata_only else 'all'})"
    )

    temp_dir: Path | None = None
    if not is_source:
        os.makedirs(model_path.parent, exist_ok=True)
        temp_dir = Path(
            tempfile.mkdtemp(
                dir=model_path.parent,
                prefix=f".transfer_{model_path.name}_",
            )
        )

    try:
        for entry in manifest:
            rel_path = str(entry["path"])
            file_size = int(entry["size"])
            logger.info(f"  {rel_path} ({file_size} bytes)")
            _transfer_file_to_disk(
                source_path=model_path,
                rel_path=rel_path,
                file_size=file_size,
                group=group,
                is_source=is_source,
                dest_path=temp_dir if temp_dir is not None else model_path,
            )

        if temp_dir is not None:
            os.makedirs(model_path, exist_ok=True)
            for entry in manifest:
                rel_path = str(entry["path"])
                src = temp_dir / rel_path
                dst = model_path / rel_path
                os.makedirs(dst.parent, exist_ok=True)
                os.replace(src, dst)
            logger.info(
                f"Transfer complete: {len(manifest)} files moved to {model_path}"
            )
    finally:
        if temp_dir is not None and temp_dir.exists():
            shutil.rmtree(temp_dir, ignore_errors=True)


def transfer_metadata_files(model_path: Path, group: Group, is_source: bool) -> None:
    """
    Transfer metadata files (config.json, tokenizer files, etc.) to receivers' disk.

    All ranks must call this function (collective operation).
    Only the designated source (is_source=True) should send; all others receive.
    """
    _transfer_files_to_disk(model_path, group, is_source=is_source, metadata_only=True)


def transfer_all_files(model_path: Path, group: Group, is_source: bool) -> None:
    """
    Transfer ALL model files (including safetensors) to receivers' disk.

    All ranks must call this function (collective operation).
    Only the designated source (is_source=True) should send; all others receive.
    """
    _transfer_files_to_disk(model_path, group, is_source=is_source, metadata_only=False)


def _parse_mx_dtype(dtype_str: str) -> mx.Dtype:
    """Convert a dtype string like 'float16' or 'mlx.core.float16' to mx.Dtype."""
    name = dtype_str.split(".")[-1]
    dtype = getattr(mx, name, None)
    if dtype is None:
        raise ValueError(f"Unknown MLX dtype: {dtype_str}")
    return dtype  # type: ignore[return-value]


def _extract_layer_index(name: str) -> int | None:
    """Extract layer index from a weight name, or None for non-layer weights.

    Matches patterns like ``model.layers.5.self_attn.q_proj.weight``
    or ``transformer.h.12.mlp.gate_proj.scales``.
    """
    m = _LAYER_RE.search(name)
    return int(m.group(2)) if m else None


class WeightBroadcastState:
    """Holds state for layer-by-layer weight broadcasting.

    Created by :func:`prepare_weight_broadcast`. Callers stream weights
    incrementally via :meth:`broadcast_non_layer_weights` and
    :meth:`broadcast_layer` so that at most one layer's worth of un-sharded
    weight data is resident at a time.
    """

    def __init__(
        self,
        meta: dict[str, dict[str, Any]],
        source_weights: dict[str, mx.array] | None,
        group: Group,
        is_source: bool,
    ) -> None:
        self.meta = meta
        self.source_weights = source_weights
        self.group = group
        self.is_source = is_source

        # Partition weight names into layer vs. non-layer
        self.layer_names: dict[int, list[str]] = {}
        self.non_layer_names: list[str] = []
        for name in sorted(meta.keys()):
            layer_idx = _extract_layer_index(name)
            if layer_idx is not None:
                self.layer_names.setdefault(layer_idx, []).append(name)
            else:
                self.non_layer_names.append(name)

        logger.info(
            f"WeightBroadcastState: {len(self.non_layer_names)} non-layer weights, "
            f"{len(self.layer_names)} layers"
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _broadcast_names(self, names: list[str]) -> dict[str, mx.array]:
        """Broadcast a specific set of weight tensors by name."""
        all_sum = partial(_all_sum_cpu, group=self.group)
        result: dict[str, mx.array] = {}
        for name in names:
            info = self.meta[name]
            shape = cast(list[int], info["s"])
            dtype = _parse_mx_dtype(cast(str, info["d"]))

            if self.is_source:
                assert self.source_weights is not None
                tensor = self.source_weights.pop(name)
                mx.eval(tensor)  # loads from disk (lazy)
            else:
                tensor = mx.zeros(shape, dtype=dtype)

            broadcasted = all_sum(tensor)
            mx.eval(broadcasted)
            result[name] = broadcasted
        return result

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def broadcast_non_layer_weights(self) -> dict[str, mx.array]:
        """Broadcast non-layer weights (embeddings, norms, lm_head)."""
        if not self.non_layer_names:
            return {}
        logger.info(
            f"Broadcasting {len(self.non_layer_names)} non-layer weight tensors"
        )
        return self._broadcast_names(self.non_layer_names)

    def broadcast_layer(self, layer_idx: int) -> dict[str, mx.array]:
        """Broadcast weights for a single transformer layer."""
        names = self.layer_names.get(layer_idx, [])
        if not names:
            return {}
        return self._broadcast_names(names)


def prepare_weight_broadcast(
    model_path: Path,
    group: Group,
    is_source: bool,
) -> WeightBroadcastState:
    """Prepare for layer-by-layer weight broadcasting.

    Source loads safetensors lazily and broadcasts weight metadata (names,
    shapes, dtypes) as JSON. Returns a :class:`WeightBroadcastState` that
    can then stream weights incrementally via ``broadcast_layer()``.

    All ranks must call this function (collective operation).
    """
    source_weights: dict[str, mx.array] | None = None
    if is_source:
        source_weights = {}
        weight_files = sorted(model_path.glob("*.safetensors"))
        if not weight_files:
            weight_files = sorted(model_path.glob("**/*.safetensors"))
        for wf in weight_files:
            try:
                loaded = cast(
                    dict[str, mx.array],
                    mx.load(str(wf), lazy=True),  # pyright: ignore[reportCallIssue]
                )
            except TypeError:
                loaded = cast(dict[str, mx.array], mx.load(str(wf)))
            source_weights.update(loaded)
        logger.info(
            f"Source loaded {len(source_weights)} weight tensors (lazy) "
            f"from {len(weight_files)} files"
        )

    # Broadcast metadata
    if is_source and source_weights is not None:
        source_meta: dict[str, dict[str, Any]] = {
            name: {"s": list(tensor.shape), "d": str(tensor.dtype)}
            for name, tensor in source_weights.items()
        }
    else:
        source_meta = {}

    meta = cast(
        dict[str, dict[str, Any]],
        _broadcast_json(source_meta if is_source else None, group, is_source),
    )

    logger.info(f"Weight broadcast prepared: {len(meta)} tensors")
    return WeightBroadcastState(meta, source_weights, group, is_source)


def broadcast_model_weights(
    model_path: Path,
    group: Group,
    is_source: bool,
) -> dict[str, mx.array]:
    """
    Broadcast model weight tensors from source rank to all receivers' memory.

    Source loads weights from .safetensors files on disk and broadcasts each
    tensor via all_sum. Receivers receive tensors directly as mx.arrays in
    memory — no disk write for weight data.

    All ranks must call this function (collective operation).
    Only the designated source (is_source=True) should send; all others receive.

    Returns:
        dict mapping weight names to mx.arrays (on all ranks).
    """
    all_sum = partial(_all_sum_cpu, group=group)

    # Source loads weights (lazy if supported, so only one tensor in memory at a time)
    weights: dict[str, mx.array] = {}
    if is_source:
        weight_files = sorted(model_path.glob("*.safetensors"))
        if not weight_files:
            weight_files = sorted(model_path.glob("**/*.safetensors"))
        for wf in weight_files:
            try:
                loaded = cast(dict[str, mx.array], mx.load(str(wf), lazy=True))  # pyright: ignore[reportCallIssue]
            except TypeError:
                loaded = cast(dict[str, mx.array], mx.load(str(wf)))
            weights.update(loaded)
        logger.info(
            f"Source loaded {len(weights)} weight tensors from {len(weight_files)} files"
        )

    # Broadcast weight metadata: {name: {shape, dtype}}
    if is_source:
        source_meta: dict[str, dict[str, Any]] = {
            name: {"s": list(tensor.shape), "d": str(tensor.dtype)}
            for name, tensor in weights.items()
        }
    else:
        source_meta = {}
    meta = cast(
        dict[str, dict[str, Any]],
        _broadcast_json(source_meta if is_source else None, group, is_source),
    )

    logger.info(f"Broadcasting {len(meta)} weight tensors")

    # Broadcast each tensor in sorted order (deterministic across ranks).
    # Source loads one tensor at a time from disk (lazy), broadcasts it,
    # then drops the reference so only one tensor is in flight at a time.
    result: dict[str, mx.array] = {}
    for i, name in enumerate(sorted(meta.keys())):
        info = meta[name]
        shape = cast(list[int], info["s"])
        dtype_str = cast(str, info["d"])
        dtype = _parse_mx_dtype(dtype_str)

        if is_source:
            tensor = weights.pop(name)  # pop to free lazy ref after broadcast
            mx.eval(tensor)  # loads from disk
        else:
            tensor = mx.zeros(shape, dtype=dtype)

        broadcasted = all_sum(tensor)
        mx.eval(broadcasted)
        result[name] = broadcasted

        if (i + 1) % 100 == 0:
            logger.info(f"  Broadcast {i + 1}/{len(meta)} tensors")

    logger.info(f"Weight broadcast complete: {len(result)} tensors")
    return result
@@ -2,7 +2,6 @@ import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast

@@ -60,13 +59,6 @@ from exo.worker.engines.mlx.auto_parallel import (
pipeline_auto_parallel,
tensor_auto_parallel,
)
from exo.worker.engines.mlx.model_transfer import (
WeightBroadcastState,
coordinate_transfer,
model_path_for_id,
prepare_weight_broadcast,
transfer_metadata_files,
)
from exo.worker.runner.bootstrap import logger

Group = mx.distributed.Group
@@ -205,7 +197,6 @@ def load_mlx_items(
bound_instance: BoundInstance,
group: Group | None,
on_timeout: TimeoutCallback | None = None,
has_local_model: bool = True,
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
@@ -220,10 +211,7 @@ def load_mlx_items(
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
bound_instance.bound_shard,
group=group,
on_timeout=on_timeout,
has_local_model=has_local_model,
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
@@ -239,89 +227,30 @@ def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
on_timeout: TimeoutCallback | None = None,
has_local_model: bool = True,
) -> tuple[nn.Module, TokenizerWrapper]:
model_id = shard_metadata.model_card.model_id
model_path = model_path_for_id(model_id)
model_path = build_model_path(shard_metadata.model_card.model_id)

# Coordinate: does any rank need a transfer?
needs_transfer, source_rank = coordinate_transfer(group, has_local_model)
is_source = group.rank() == source_rank

# Step 1: Always ensure all nodes have metadata files (config, tokenizer, etc.).
# This is cheap (~20MB, ~1s) and guarantees config.json is present for load_model().
transfer_metadata_files(model_path, group, is_source)

# Step 2: Only broadcast weights if some rank is missing the model
broadcast_state: WeightBroadcastState | None = None
if needs_transfer:
logger.info(
f"Model transfer needed (source_rank={source_rank}, "
f"is_source={is_source}, local_weights={has_local_model})"
)
broadcast_state = prepare_weight_broadcast(model_path, group, is_source)

# Create model architecture (all ranks have config.json on disk now).
# Always use lazy=True when we have broadcast state: load_model's internal
# nn.quantize skips quantization when weights dict is empty (no safetensors),
# leaving the model un-quantized. lazy=False would then mx.eval() the full
# fp16 model (~72GB for a 36B-param model), causing OOM on the receiver.
# We handle quantization ourselves below before loading broadcast weights.
use_lazy = has_local_model or broadcast_state is not None
model, _ = load_model(model_path, lazy=use_lazy, strict=False)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
pass
# TODO: See if we should quantize the model.
# def is_attention_layer(path: str) -> bool:
# path = path.lower()

# return "self_attn" in path and "layernorm" not in path

# def quant_predicate(path: str, module: nn.Module):
# if not isinstance(module, nn.Linear):
# return False

# return is_attention_layer(path)
# model, config = quantize_model(
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
# )

assert isinstance(model, nn.Module)

if broadcast_state is not None:
# When receiver has no weight files, load_model skips quantization
# (its class_predicate checks `f"{p}.scales" in weights`, which is
# always False when weights is empty). Apply quantization explicitly
# using the broadcast metadata to determine which layers are quantized,
# matching load_model's selective quantization logic exactly.
if not has_local_model:
config_path = model_path / "config.json"
with open(config_path) as f:
config = json.load(f) # pyright: ignore[reportAny]
quant_config: dict[str, Any] | None = config.get( # pyright: ignore[reportAny]
"quantization", None
)
if quant_config is not None:
logger.info(f"Applying quantization to receiver model: {quant_config}")
broadcast_weight_names = set(broadcast_state.meta.keys())

def _class_predicate(p: str, m: nn.Module) -> bool | dict[str, Any]:
# Per-layer overrides from config (e.g. "lm_head": false)
assert quant_config is not None
if p in quant_config:
return quant_config[p] # pyright: ignore[reportAny]
if not hasattr(m, "to_quantized"):
return False
# Only quantize layers whose .scales exist in broadcast weights
return f"{p}.scales" in broadcast_weight_names

group_size = int(quant_config.get("group_size", 64)) # pyright: ignore[reportAny]
bits = int(quant_config.get("bits", 4)) # pyright: ignore[reportAny]
mode: str = quant_config.get("mode", "affine") # pyright: ignore[reportAny]
nn.quantize( # pyright: ignore[reportUnknownMemberType]
model,
group_size=group_size,
bits=bits,
mode=mode,
class_predicate=_class_predicate,
)

# Broadcast and load non-layer weights (embeddings, norms, lm_head) upfront.
# These are small (~600MB) and needed before the sharding loop.
non_layer_weights = broadcast_state.broadcast_non_layer_weights()
if non_layer_weights:
model.load_weights(list(non_layer_weights.items()), strict=False)
logger.info(f"Loaded {len(non_layer_weights)} non-layer weight tensors")
del non_layer_weights

tokenizer = get_tokenizer(model_path, shard_metadata)

logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
@@ -335,43 +264,12 @@ def shard_and_load(
f"(model size: {model_size_gb:.1f}GB)"
)

# Build per-layer weight loader for streaming broadcast during sharding.
# Each layer's weights are broadcast via all_sum just before that layer is
# sharded, so at most one un-sharded layer is in memory at a time.
weight_loader_fn: Callable[[nn.Module, int], None] | None = None
if broadcast_state is not None:
_state = broadcast_state # capture for closure

def _load_layer_weights(mdl: nn.Module, layer_idx: int) -> None:
layer_weights = _state.broadcast_layer(layer_idx)
if layer_weights:
mdl.load_weights(list(layer_weights.items()), strict=False)

weight_loader_fn = _load_layer_weights

match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(
model, group, timeout_seconds, on_timeout, weight_loader_fn
)
model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
# Broadcast all layers (all_sum is collective — all ranks must
# participate) but only load weights for layers this node will
# keep after pipeline slicing. Out-of-range results are discarded,
# keeping peak memory proportional to this node's layer count.
if broadcast_state is not None:
for layer_idx in sorted(broadcast_state.layer_names.keys()):
layer_weights = broadcast_state.broadcast_layer(layer_idx)
if (
shard_metadata.start_layer
<= layer_idx
< shard_metadata.end_layer
and layer_weights
):
model.load_weights(list(layer_weights.items()), strict=False)
del layer_weights
model = pipeline_auto_parallel(model, group, shard_metadata)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
case CfgShardMetadata():
@@ -380,8 +278,6 @@ def shard_and_load(
"this metadata type is only for image generation models"
)

del broadcast_state

# TODO: Do we need this?
mx.eval(model)

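The receiver-side quantization in the hunk above hinges on one observation: the broadcast metadata already tells you which layers were quantized on disk, because a quantized layer ships a `<path>.scales` tensor alongside its weights. A hedged sketch of that predicate in isolation (illustrative names, not exo's API; the real code above also honours per-layer overrides and the `mode` field from config.json):

```python
import mlx.nn as nn


def make_scales_predicate(broadcast_weight_names: set[str]):
    """Quantize a module only if its scales were part of the broadcast."""

    def class_predicate(path: str, module: nn.Module) -> bool:
        if not hasattr(module, "to_quantized"):
            return False
        return f"{path}.scales" in broadcast_weight_names

    return class_predicate


# Applied roughly as:
#   nn.quantize(model, group_size=64, bits=4,
#               class_predicate=make_scales_predicate(names))
```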
@@ -2,7 +2,6 @@

from collections.abc import Mapping, Sequence

from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
ConnectToGroup,
@@ -17,7 +16,6 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
TextGeneration,
TransferModelToDisk,
)
from exo.shared.types.worker.downloads import (
DownloadCompleted,
@@ -36,11 +34,8 @@ from exo.shared.types.worker.runners import (
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
RunnerWarmingUp,
ShardAssignments,
)
from exo.worker.runner.runner_supervisor import RunnerSupervisor

@@ -62,7 +57,6 @@ def plan(
or _create_runner(node_id, runners, instances)
or _model_needs_download(node_id, runners, global_download_status)
or _init_distributed_backend(runners, all_runners)
or _transfer_model_to_disk(runners, all_runners, global_download_status)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
@@ -127,10 +121,6 @@ def _model_needs_download(
}

for runner in runners.values():
# Transfer-only instances don't need downloads
if runner.bound_instance.instance.shard_assignments.transfer_only:
continue

model_id = runner.bound_instance.bound_shard.model_card.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
@@ -139,15 +129,6 @@ def _model_needs_download(
(DownloadOngoing, DownloadCompleted, DownloadFailed),
)
):
# For multi-node instances, skip download if a peer already has the model.
# The model will be transferred via MLX distributed during LoadModel.
instance = runner.bound_instance.instance
is_multi_node = len(instance.shard_assignments.node_to_runner) > 1
if is_multi_node and _any_peer_has_model(
node_id, model_id, instance, global_download_status
):
continue

# We don't invalidate download_status randomly in case a file gets deleted on disk
return DownloadModel(
instance_id=runner.bound_instance.instance.instance_id,
@@ -205,43 +186,6 @@ def _init_distributed_backend(
return None


def _transfer_model_to_disk(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> TransferModelToDisk | None:
"""For transfer-only instances: after all ranks are connected, emit TransferModelToDisk."""
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments

if not shard_assignments.transfer_only:
continue

is_runner_connected = isinstance(runner.status, RunnerConnected)
all_connected_or_further = all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerShuttingDown, RunnerShutdown),
)
for global_runner_id in shard_assignments.runner_to_shard
)

if is_runner_connected and all_connected_or_further:
has_local = _node_has_download(
runner.bound_instance.bound_node_id,
shard_assignments.model_id,
global_download_status,
)
return TransferModelToDisk(
instance_id=instance.instance_id,
shard_metadata=runner.bound_instance.bound_shard,
has_local_model=has_local,
)

return None


def _load_model(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -251,97 +195,38 @@ def _load_model(
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments

# Transfer-only instances don't load models for inference
if shard_assignments.transfer_only:
all_local_downloads_complete = all(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner
)
if not all_local_downloads_complete:
continue

is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)

if is_single_node_instance:
# Single-node: require local download complete
if not _all_downloads_complete(shard_assignments, global_download_status):
continue
if isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id, has_local_model=True)
else:
# Multi-node: require at least one node to have the model downloaded.
# Nodes without the model will receive it via MLX distributed transfer
# during model loading.
if not _any_download_complete(shard_assignments, global_download_status):
continue
is_runner_waiting = isinstance(runner.status, RunnerConnected)

is_runner_waiting = isinstance(runner.status, RunnerConnected)
all_ready_for_model = all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerLoaded),
)
for global_runner_id in shard_assignments.runner_to_shard
all_ready_for_model = all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerLoaded),
)
for global_runner_id in shard_assignments.runner_to_shard
)

if is_runner_waiting and all_ready_for_model:
has_local = _node_has_download(
runner.bound_instance.bound_node_id,
shard_assignments.model_id,
global_download_status,
)
return LoadModel(
instance_id=instance.instance_id,
has_local_model=has_local,
)
if is_runner_waiting and all_ready_for_model:
return LoadModel(instance_id=instance.instance_id)

return None


def _node_has_download(
nid: NodeId,
model_id: ModelId,
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
"""Check if a specific node has completed downloading the given model."""
return any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id == model_id
for dp in global_download_status.get(nid, [])
)


def _any_peer_has_model(
node_id: NodeId,
model_id: ModelId,
instance: Instance,
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
"""Check if any other node in the instance already has the model downloaded."""
return any(
_node_has_download(nid, model_id, global_download_status)
for nid in instance.shard_assignments.node_to_runner
if nid != node_id
)


def _all_downloads_complete(
shard_assignments: ShardAssignments,
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
"""Check if ALL nodes in the instance have completed downloading the model."""
return all(
_node_has_download(nid, shard_assignments.model_id, global_download_status)
for nid in shard_assignments.node_to_runner
)


def _any_download_complete(
shard_assignments: ShardAssignments,
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
"""Check if at least one node in the instance has completed downloading the model."""
return any(
_node_has_download(nid, shard_assignments.model_id, global_download_status)
for nid in shard_assignments.node_to_runner
)


def _ready_to_warmup(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -349,11 +234,6 @@ def _ready_to_warmup(
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments

# Transfer-only instances don't go through warmup
if shard_assignments.transfer_only:
continue

shard = runner.bound_instance.bound_shard
device_rank = shard.device_rank
runner_id = runner.bound_instance.bound_runner_id
@@ -415,12 +295,14 @@ def _pending_tasks(
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
# Also skip tasks in pending to prevent duplicate forwarding with continuous batching
if task.task_id in runner.completed or task.task_id in runner.pending:
continue

# TODO: Check ordering aligns with MLX distributed's expectations.

if isinstance(runner.status, RunnerReady) and all(
# Allow forwarding tasks when runner is Ready or Running (for continuous batching)
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):

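Stripped of the runner bookkeeping, the transfer-aware variant of `_load_model` shown in this hunk reduces to a small rule: a single-node instance waits for its own download, while a multi-node instance proceeds as soon as any node has the model and tells each runner whether its own node holds a local copy, so the others can receive the weights over MLX distributed during load. A simplified illustration of that rule (hypothetical helper, not the planner's real signature):

```python
def should_load(
    is_single_node: bool, local_has_model: bool, any_node_has_model: bool
) -> tuple[bool, bool]:
    """Returns (emit_load_model, has_local_model)."""
    if is_single_node:
        # Single-node: only load once the model is fully on this node.
        return (local_has_model, True)
    if not any_node_has_model:
        # Multi-node with no copy anywhere: keep waiting for a download.
        return (False, False)
    # Multi-node with at least one copy: load now, flag whether this node has it.
    return (True, local_has_model)
```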
@@ -43,7 +43,6 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
TextGeneration,
TransferModelToDisk,
)
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import BoundInstance
@@ -83,11 +82,6 @@ from exo.worker.engines.image import (
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.model_transfer import (
coordinate_transfer,
model_path_for_id,
transfer_all_files,
)
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
detect_thinking_prompt_suffix,
@@ -198,10 +192,7 @@ def main(

if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
model, tokenizer = load_mlx_items(
bound_instance,
group,
on_timeout=on_model_load_timeout,
has_local_model=task.has_local_model,
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
@@ -517,27 +508,6 @@ def main(

current_status = RunnerReady()
logger.info("runner ready")
case TransferModelToDisk() if (
isinstance(current_status, RunnerConnected) and group is not None
):
logger.info("starting disk-to-disk model transfer")
event_sender.send(TaskAcknowledged(task_id=task.task_id))

model_path = model_path_for_id(
task.shard_metadata.model_card.model_id
)
_, source_rank = coordinate_transfer(group, task.has_local_model)
is_source = group.rank() == source_rank
transfer_all_files(model_path, group, is_source)

logger.info("disk-to-disk model transfer complete")
current_status = RunnerShuttingDown()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
current_status = RunnerShutdown()
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")

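For transfer-only instances, the `TransferModelToDisk` arm shown above boils down to three steps: agree on which rank already has the files, stream every model file across the distributed group, then report shutdown since the runner never serves inference itself. A condensed sketch of that flow (helper names taken from the imports in this hunk; simplified, no status events or error handling):

```python
from exo.worker.engines.mlx.model_transfer import (
    coordinate_transfer,
    model_path_for_id,
    transfer_all_files,
)


def handle_transfer_to_disk(task, group) -> None:
    # Agree on a source rank that has the model on disk.
    _, source_rank = coordinate_transfer(group, task.has_local_model)
    model_path = model_path_for_id(task.shard_metadata.model_card.model_id)
    # Every rank participates; the source streams, the others write to disk.
    transfer_all_files(model_path, group, group.rank() == source_rank)
```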
@@ -20,6 +20,7 @@ class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
pending: dict[TaskId, object] = field(default_factory=dict)


class OtherTask(BaseTask):

@@ -112,7 +112,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():

assert isinstance(result, LoadModel)
assert result.instance_id == INSTANCE_1_ID
assert result.has_local_model is True


def test_plan_does_not_request_download_when_shard_already_downloaded():
@@ -158,11 +157,10 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
assert not isinstance(result, plan_mod.DownloadModel)


def test_plan_loads_model_when_any_node_has_download_for_multi_node():
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
"""
For multi-node instances, LoadModel should be emitted when at least one
node has the model downloaded. Nodes without the model will receive it
via MLX distributed transfer during model loading.
LoadModel should not be emitted while some shards are still missing from
the global_download_status.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -187,7 +185,6 @@ def test_plan_loads_model_when_any_node_has_download_for_multi_node():
RUNNER_2_ID: RunnerConnected(),
}

# Only NODE_A has the model — LoadModel should still fire
global_download_status = {
NODE_A: [
DownloadCompleted(
@@ -206,42 +203,19 @@ def test_plan_loads_model_when_any_node_has_download_for_multi_node():
tasks={},
)

assert isinstance(result, LoadModel)
assert result.instance_id == INSTANCE_1_ID
assert result.has_local_model is True
assert result is None


def test_plan_does_not_load_model_when_no_node_has_download():
"""
LoadModel should not be emitted when no node has the model downloaded.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
)

bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerConnected()
)

runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
}

# No node has the model
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [],
NODE_B: [],
global_download_status = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
], # NODE_B has no downloads completed yet
}

result = plan_mod.plan(
@@ -253,57 +227,4 @@ def test_plan_does_not_load_model_when_no_node_has_download():
tasks={},
)

assert result is None


def test_plan_load_model_has_local_model_false_when_node_missing_download():
"""
For multi-node instances, when the local node does NOT have the model
but a peer does, LoadModel should be emitted with has_local_model=False.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
)

# NODE_B is the local node (bound_node_id=NODE_B), it does NOT have the model
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerConnected()
)

runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
}

# Only NODE_A has the model, NODE_B does not
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [],
}

result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)

assert isinstance(result, LoadModel)
assert result.instance_id == INSTANCE_1_ID
assert result.has_local_model is False
assert result is not None

@@ -0,0 +1,341 @@
"""
Tests for continuous batching behavior in the runner.

These tests verify that:
1. Single requests work through the batch path
2. Multiple concurrent requests batch together
3. Tokens are routed to the correct requests
4. Requests complete at different times appropriately

NOTE: These tests require the continuous-batching runner architecture
(BatchGenerationEngine) which is not yet integrated with main.
"""

# ruff: noqa: E402
# pyright: reportAny=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportAttributeAccessIssue=false
# pyright: reportInvalidTypeVarUse=false

import pytest

pytest.skip(
"continuous batching runner not yet integrated with main branch runner",
allow_module_level=True,
)

from typing import Any
from unittest.mock import MagicMock

import exo.worker.runner.runner as mlx_runner
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
TextGeneration,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import RunnerRunning
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
BatchedGenerationResponse,
)
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
)
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance


class FakeBatchEngineWithTokens:
"""
Fake batch engine that generates a specified number of tokens per request.

This simulates realistic batch generation behavior where:
- Requests are queued on insert
- Each step() call generates one token for all active requests
- Requests complete when they've generated all their tokens
"""

def __init__(self, *_args: Any, **_kwargs: Any):
self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
self._pending_inserts: list[
tuple[CommandId, TaskId, TextGenerationTaskParams]
] = []
self._uid_counter = 0
self._tokens_per_request = 3 # Default: generate 3 tokens before completing
self.rank = 0 # Fake rank for testing

def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: TextGenerationTaskParams,
) -> None:
"""Queue a request for insertion."""
self._pending_inserts.append((command_id, task_id, task_params))

def sync_and_insert_pending(self) -> list[int]:
"""Insert all pending requests."""
uids: list[int] = []
for command_id, task_id, task_params in self._pending_inserts:
uid = self._do_insert(command_id, task_id, task_params)
uids.append(uid)
self._pending_inserts.clear()
return uids

@property
def has_pending_inserts(self) -> bool:
return len(self._pending_inserts) > 0

def _do_insert(
self,
command_id: CommandId,
task_id: TaskId,
task_params: TextGenerationTaskParams | None,
) -> int:
uid = self._uid_counter
self._uid_counter += 1
# Track: (command_id, task_id, tokens_generated, max_tokens)
max_tokens = (
task_params.max_output_tokens if task_params else self._tokens_per_request
)
self._active_requests[uid] = (command_id, task_id, 0, max_tokens or 3)
return uid

def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
uids_to_remove: list[int] = []

for uid, (command_id, task_id, tokens_gen, max_tokens) in list(
self._active_requests.items()
):
tokens_gen += 1
finish_reason = "stop" if tokens_gen >= max_tokens else None
text = f"token{tokens_gen}"

if finish_reason:
uids_to_remove.append(uid)
else:
self._active_requests[uid] = (
command_id,
task_id,
tokens_gen,
max_tokens,
)

results.append(
BatchedGenerationResponse(
command_id=command_id,
task_id=task_id,
response=GenerationResponse(
token=tokens_gen,
text=text,
finish_reason=finish_reason,
usage=None,
),
)
)

for uid in uids_to_remove:
del self._active_requests[uid]

return results

@property
def has_active_requests(self) -> bool:
return len(self._active_requests) > 0

@property
def active_count(self) -> int:
return len(self._active_requests)

@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)

@property
def is_distributed(self) -> bool:
return False # Non-distributed mode for testing


class FakeGroup:
"""Fake MLX distributed group for testing."""

def size(self) -> int:
return 1 # Single node (non-distributed)


def make_nothin[T, U, V](res: T):
def nothin(*_1: U, **_2: V) -> T:
return res

return nothin


@pytest.fixture
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
"""Patch MLX dependencies and use FakeBatchEngineWithTokens."""
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup()))
monkeypatch.setattr(
mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock()))
)
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)


def _run_with_tasks(tasks: list[Task]) -> list[Event]:
"""
Run tasks through the runner, adding shutdown at the end.

Tasks are sent in order, with shutdown sent last.
The batch engine processes between task handling.
"""
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NodeId(NODE_A),
)

task_sender, task_receiver = mp_channel[Task]()
event_sender, event_receiver = mp_channel[Event]()

shutdown_task = Shutdown(
task_id=TaskId("shutdown"),
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
)

with task_sender, event_receiver:
# Send all tasks including shutdown
for t in tasks:
task_sender.send(t)
task_sender.send(shutdown_task)

# Disable cleanup methods to prevent issues
event_sender.close = lambda: None
event_sender.join = lambda: None
task_receiver.close = lambda: None
task_receiver.join = lambda: None

mlx_runner.main(bound_instance, event_sender, task_receiver)

return event_receiver.collect()


INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)


def make_chat_task(
task_id: str, command_id: str, max_tokens: int = 3
) -> TextGeneration:
return TextGeneration(
task_id=TaskId(task_id),
command_id=CommandId(command_id),
task_params=TextGenerationTaskParams(
model=MODEL_A_ID,
input=[InputMessage(role="user", content="hello")],
stream=True,
max_output_tokens=max_tokens,
),
instance_id=INSTANCE_1_ID,
)


def test_single_request_generates_tokens(patch_batch_engine: None):
"""
Verify a single request generates the expected tokens through the batch path.

Note: With the current non-blocking design, shutdown is processed before
batch steps run when all tasks are queued together. This test verifies
the runner status reflects active requests.
"""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=3)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])

# Find RunnerRunning status events - this shows the request was inserted
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]

assert len(running_events) >= 1, "Expected at least one RunnerRunning event"
assert running_events[0].runner_status.active_requests == 1


def test_runner_status_reflects_active_requests(patch_batch_engine: None):
"""Verify RunnerRunning status includes active_requests count."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])

# Find RunnerRunning status events
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]

assert len(running_events) > 0, "Expected at least one RunnerRunning event"
assert running_events[0].runner_status.active_requests == 1


def test_chat_task_acknowledged(patch_batch_engine: None):
"""Verify chat completion task is acknowledged with proper status updates."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])

# Find the chat task status events
chat_running = [
e
for e in events
if isinstance(e, TaskStatusUpdated)
and e.task_id == TaskId("chat1")
and e.task_status == TaskStatus.Running
]

assert len(chat_running) == 1, "Expected exactly one chat task Running status"


def test_multiple_requests_tracked(patch_batch_engine: None):
"""Verify multiple concurrent requests are tracked in active_requests."""
chat1 = make_chat_task("chat1", "cmd1", max_tokens=2)
chat2 = make_chat_task("chat2", "cmd2", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat1, chat2])

# Find RunnerRunning status events
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]

# Should have at least 2 RunnerRunning events (one per request inserted)
assert len(running_events) >= 2, (
f"Expected at least 2 RunnerRunning events, got {len(running_events)}"
)

# First should have 1 active request, second should have 2
assert running_events[0].runner_status.active_requests == 1
assert running_events[1].runner_status.active_requests == 2