Mirror of https://github.com/exo-explore/exo.git, synced 2026-02-12 23:21:44 -05:00

Compare commits: 17 commits, branch rust-explo ..., author alexcheema
| Author | SHA1 | Date |
|---|---|---|
| | 3425c0ef51 | |
| | 5734408157 | |
| | a34970bb5c | |
| | 02a78afb87 | |
| | e1975558c1 | |
| | 9d6d60c411 | |
| | e13e7af6e8 | |
| | fb2b0148ee | |
| | 7c0147f544 | |
| | f0d7560ec0 | |
| | a1e4d5aba1 | |
| | 2c500ab8cf | |
| | f9d7af4dbf | |
| | f70be12b04 | |
| | 8c49adc97a | |
| | e86e6a9d1e | |
| | cc33213842 | |
AGENTS.md (37 lines changed)
@@ -194,3 +194,40 @@ GitHub's API doesn't support direct image upload for PR comments. Workaround:
git push origin <branch>
```
The images still render in the PR comment because they reference the permanent commit SHA.

## Running exo Remotely via SSH (macOS mDNS)

**CRITICAL: On macOS, mDNS multicast (used for peer discovery) only works when the process runs in a proper macOS user session.** Background processes started via `nohup ... &`, `screen`, or plain SSH commands will NOT send mDNS packets and nodes will never discover each other.

### The Problem

When you SSH into a Mac and run `nohup uv run exo &`, the process runs in a detached session without access to macOS multicast networking. The exo node will start but will never discover peers, even if they're on the same network.

### The Solution: Use `open` with a `.command` wrapper

Create a `.command` script that `open` will execute in the proper macOS GUI session context:

```bash
# 1. Create wrapper script on the remote machine
ssh user@remote-mac "cat > /tmp/run_exo.command << 'SCRIPT'
#!/bin/bash
export PATH=/opt/homebrew/bin:\$HOME/.local/bin:\$PATH
export EXO_LIBP2P_NAMESPACE=your-namespace # must match across all nodes
cd ~/path/to/exo
exec uv run exo -vv 2>&1 | tee /tmp/exo.log
SCRIPT
chmod +x /tmp/run_exo.command"

# 2. Launch it via `open` (runs in macOS GUI session with proper mDNS)
ssh user@remote-mac "open /tmp/run_exo.command"

# 3. Check logs
ssh user@remote-mac "tail -f /tmp/exo.log"
```

### Key Details

- **`EXO_LIBP2P_NAMESPACE`**: All nodes in a cluster MUST use the same namespace value. The EXO.app uses a build-specific namespace (check with `ps eww <pid> | grep NAMESPACE`). If mixing dev builds with EXO.app, set the dev build's namespace to match.
- **`open *.command`**: This is the macOS equivalent of double-clicking the script in Finder. It runs in the user's GUI session with full network access.
- **Do NOT use**: `nohup ... &`, `screen -dm`, `tmux new-session -d`, or `sshpass`. These all create detached sessions where mDNS won't work.
- **Killing**: `ssh user@remote-mac "pkill -f 'python.*exo'"` works fine for stopping.
- **Dashboard**: Must be built before running: `cd dashboard && npm install && npm run build && cd ..`. Node.js is at `/opt/homebrew/bin/node` on Apple Silicon Macs.
- **Verifying cluster**: `curl -s http://localhost:52415/state | python3 -c "import json,sys; s=json.load(sys.stdin); print(len(s['topology']['nodes']), 'nodes')"` — should show 2+ nodes.
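
The namespace check from the first bullet can be scripted. A minimal sketch, assuming the EXO.app process can be found under the (hypothetical) process name `EXO` and that `pgrep` and `ps eww` behave as on stock macOS:

```bash
# Print the namespace a running EXO.app was launched with (process name "EXO" is an assumption)
ssh user@remote-mac 'ps eww $(pgrep -x EXO | head -1) | tr " " "\n" | grep "^EXO_LIBP2P_NAMESPACE="'
# Copy the printed value into EXO_LIBP2P_NAMESPACE in /tmp/run_exo.command on every dev-build node
```
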
@@ -19,6 +19,11 @@ from urllib.parse import urlencode
from loguru import logger
from transformers import AutoTokenizer

# Backoff constants for cluster settling retry
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0

# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
@@ -388,6 +393,66 @@ class PromptSizer:
        return content, tok


def fetch_and_filter_placements(
    client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
    previews_resp = client.request_json(
        "GET", "/instance/previews", params={"model_id": full_model_id}
    )
    previews = previews_resp.get("previews") or []

    selected: list[dict[str, Any]] = []
    for p in previews:
        if p.get("error") is not None:
            continue
        if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
            continue
        if not sharding_filter(str(p.get("sharding", "")), args.sharding):
            continue

        instance = p.get("instance")
        if not isinstance(instance, dict):
            continue

        n = nodes_used_in_instance(instance)
        # Skip single-node tensor ring placements; they are redundant with the single-node pipeline ring
        if n == 1 and (
            (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
            or (
                args.instance_meta == "both"
                and "jaccl" in p.get("instance_meta", "").lower()
            )
        ):
            continue

        if (
            args.skip_pipeline_jaccl
            and (
                args.instance_meta == "both"
                and "jaccl" in p.get("instance_meta", "").lower()
            )
            and (
                args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
            )
        ):
            continue

        if (
            args.skip_tensor_ring
            and (
                args.instance_meta == "both"
                and "ring" in p.get("instance_meta", "").lower()
            )
            and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
        ):
            continue

        if args.min_nodes <= n <= args.max_nodes:
            selected.append(p)

    return selected

def main() -> int:
    ap = argparse.ArgumentParser(
        prog="exo-bench",
@@ -464,6 +529,12 @@ def main() -> int:
        action="store_true",
        help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
    )
    ap.add_argument(
        "--settle-timeout",
        type=float,
        default=0,
        help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
    )
    args = ap.parse_args()

    pp_list = parse_int_list(args.pp)
@@ -487,11 +558,6 @@ def main() -> int:
    client = ExoClient(args.host, args.port, timeout_s=args.timeout)
    short_id, full_model_id = resolve_model_short_id(client, args.model)

    previews_resp = client.request_json(
        "GET", "/instance/previews", params={"model_id": full_model_id}
    )
    previews = previews_resp.get("previews") or []

    tokenizer = load_tokenizer_for_bench(full_model_id)
    if tokenizer is None:
        raise RuntimeError("[exo-bench] tokenizer load failed")

@@ -503,54 +569,20 @@ def main() -> int:
        logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
        raise

    selected = fetch_and_filter_placements(client, full_model_id, args)

    if not selected and args.settle_timeout > 0:
        backoff = _SETTLE_INITIAL_BACKOFF_S
        deadline = time.monotonic() + args.settle_timeout
        while not selected and time.monotonic() < deadline:
            remaining = deadline - time.monotonic()
            logger.warning(
                f"No valid placements yet (cluster may still be settling). "
                f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
            )
            time.sleep(min(backoff, remaining))
            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
            selected = fetch_and_filter_placements(client, full_model_id, args)

    if not selected:
        logger.error("No valid placements matched your filters.")

@@ -73,6 +73,8 @@ from exo.shared.types.api import (
|
||||
CreateInstanceResponse,
|
||||
DeleteDownloadResponse,
|
||||
DeleteInstanceResponse,
|
||||
DistributeModelParams,
|
||||
DistributeModelResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
@@ -117,6 +119,7 @@ from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteDownload,
|
||||
DeleteInstance,
|
||||
DistributeModel,
|
||||
DownloadCommand,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
@@ -142,6 +145,7 @@ from exo.shared.types.openai_responses import (
|
||||
ResponsesResponse,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
@@ -298,6 +302,7 @@ class API:
|
||||
self.app.get("/events")(self.stream_events)
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
|
||||
self.app.post("/v1/models/{model_id:path}/distribute")(self.distribute_model)
|
||||
self.app.get("/v1/traces")(self.list_traces)
|
||||
self.app.get("/v1/traces/{task_id}")(self.get_trace)
|
||||
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
|
||||
@@ -1477,6 +1482,57 @@ class API:
|
||||
await self._send_download(command)
|
||||
return DeleteDownloadResponse(command_id=command.command_id)
|
||||
|
||||
async def distribute_model(
|
||||
self, model_id: ModelId, payload: DistributeModelParams
|
||||
) -> DistributeModelResponse:
|
||||
"""Distribute model files from one node to others via MLX distributed."""
|
||||
# Find a source node that has the model downloaded
|
||||
source_node_id: NodeId | None = None
|
||||
for nid, downloads in self.state.downloads.items():
|
||||
for dp in downloads:
|
||||
if (
|
||||
isinstance(dp, DownloadCompleted)
|
||||
and dp.shard_metadata.model_card.model_id == model_id
|
||||
):
|
||||
source_node_id = nid
|
||||
break
|
||||
if source_node_id is not None:
|
||||
break
|
||||
|
||||
if source_node_id is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No node has model {model_id} downloaded",
|
||||
)
|
||||
|
||||
# Determine target nodes
|
||||
if payload.target_node_ids is not None:
|
||||
target_node_ids = [
|
||||
nid for nid in payload.target_node_ids if nid != source_node_id
|
||||
]
|
||||
else:
|
||||
target_node_ids = [
|
||||
nid for nid in self.state.topology.list_nodes() if nid != source_node_id
|
||||
]
|
||||
|
||||
if not target_node_ids:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No target nodes to distribute to",
|
||||
)
|
||||
|
||||
command = DistributeModel(
|
||||
model_id=model_id,
|
||||
source_node_id=source_node_id,
|
||||
target_node_ids=target_node_ids,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
return DistributeModelResponse(
|
||||
command_id=command.command_id,
|
||||
message=f"Distributing {model_id} from {source_node_id} to {len(target_node_ids)} node(s)",
|
||||
)
|
||||
|
||||
def _get_trace_path(self, task_id: str) -> Path:
|
||||
return EXO_TRACING_CACHE_DIR / f"trace_{task_id}.json"
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
|
||||
from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
DistributeModel,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
@@ -312,6 +313,37 @@ class Master:
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case DistributeModel():
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.types.worker.instances import InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
model_card = await ModelCard.load(command.model_id)
|
||||
all_node_ids = set(
|
||||
[command.source_node_id] + list(command.target_node_ids)
|
||||
)
|
||||
place_command = PlaceInstance(
|
||||
model_card=model_card,
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
min_nodes=len(all_node_ids),
|
||||
)
|
||||
placement = place_instance(
|
||||
place_command,
|
||||
self.state.topology,
|
||||
self.state.instances,
|
||||
self.state.node_memory,
|
||||
self.state.node_network,
|
||||
required_nodes=all_node_ids,
|
||||
)
|
||||
# Mark new instances as transfer-only
|
||||
for instance_id, instance in placement.items():
|
||||
if instance_id not in self.state.instances:
|
||||
instance.shard_assignments.transfer_only = True
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
generated_events.append(
|
||||
InputChunkReceived(
|
||||
|
||||
@@ -373,6 +373,15 @@ class DeleteDownloadResponse(CamelCaseModel):
|
||||
command_id: CommandId
|
||||
|
||||
|
||||
class DistributeModelParams(CamelCaseModel):
|
||||
target_node_ids: list[NodeId] | None = None # None = all connected nodes
|
||||
|
||||
|
||||
class DistributeModelResponse(CamelCaseModel):
|
||||
command_id: CommandId
|
||||
message: str
|
||||
|
||||
|
||||
class TraceEventResponse(CamelCaseModel):
|
||||
name: str
|
||||
start_us: int
|
||||
|
||||
@@ -77,6 +77,14 @@ class CancelDownload(BaseCommand):
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
class DistributeModel(BaseCommand):
|
||||
"""Distribute model files from one node to others via MLX distributed."""
|
||||
|
||||
model_id: ModelId
|
||||
source_node_id: NodeId
|
||||
target_node_ids: list[NodeId]
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
|
||||
|
||||
|
||||
@@ -91,6 +99,7 @@ Command = (
|
||||
| DeleteInstance
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
| DistributeModel
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class DownloadModel(BaseTask): # emitted by Worker
|
||||
|
||||
|
||||
class LoadModel(BaseTask): # emitted by Worker
|
||||
pass
|
||||
has_local_model: bool = Field(default=True)
|
||||
|
||||
|
||||
class ConnectToGroup(BaseTask): # emitted by Worker
|
||||
@@ -76,6 +76,13 @@ class ImageEdits(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class TransferModelToDisk(BaseTask): # emitted by Worker
|
||||
"""Transfer all model files from source to receivers' disk via MLX distributed."""
|
||||
|
||||
shard_metadata: ShardMetadata
|
||||
has_local_model: bool = Field(default=True)
|
||||
|
||||
|
||||
class Shutdown(BaseTask): # emitted by Worker
|
||||
runner_id: RunnerId
|
||||
|
||||
@@ -85,6 +92,7 @@ Task = (
|
||||
| DownloadModel
|
||||
| ConnectToGroup
|
||||
| LoadModel
|
||||
| TransferModelToDisk
|
||||
| StartWarmup
|
||||
| TextGeneration
|
||||
| ImageGeneration
|
||||
|
||||
@@ -84,6 +84,7 @@ class ShardAssignments(CamelCaseModel):
|
||||
model_id: ModelId
|
||||
runner_to_shard: Mapping[RunnerId, ShardMetadata]
|
||||
node_to_runner: Mapping[NodeId, RunnerId]
|
||||
transfer_only: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_runners_exist(self) -> "ShardAssignments":
|
||||
|
||||
@@ -44,6 +44,7 @@ if TYPE_CHECKING:
|
||||
from mlx_lm.models.cache import Cache
|
||||
|
||||
TimeoutCallback = Callable[[], None]
|
||||
WeightLoader = Callable[[nn.Module, int], None] | None
|
||||
|
||||
|
||||
def eval_with_timeout(
|
||||
@@ -330,6 +331,7 @@ def tensor_auto_parallel(
|
||||
group: mx.distributed.Group,
|
||||
timeout_seconds: float = 60.0,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
all_to_sharded_linear = partial(
|
||||
shard_linear,
|
||||
@@ -431,7 +433,7 @@ def tensor_auto_parallel(
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
model = tensor_parallel_sharding_strategy.shard_model(
|
||||
model, timeout_seconds, on_timeout
|
||||
model, timeout_seconds, on_timeout, weight_loader
|
||||
)
|
||||
return patch_tensor_model(model)
|
||||
|
||||
@@ -458,6 +460,7 @@ class TensorParallelShardingStrategy(ABC):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module: ...
|
||||
|
||||
|
||||
@@ -467,9 +470,12 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(LlamaModel, model)
|
||||
for layer in model.layers:
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
# Force load weights before sharding to avoid FAST_SYNCH deadlock
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
@@ -521,9 +527,12 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
for layer in model.layers:
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -596,9 +605,12 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(GLM4MoeLiteModel, model)
|
||||
for layer in model.layers: # type: ignore
|
||||
for i, layer in enumerate(model.layers): # type: ignore
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
layer = cast(Glm4MoeLiteDecoderLayer, layer)
|
||||
eval_with_timeout(
|
||||
layer.parameters(),
|
||||
@@ -738,9 +750,12 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(MiniMaxModel, model)
|
||||
for layer in model.layers:
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -778,9 +793,12 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
|
||||
for layer in model.layers:
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -902,9 +920,12 @@ class Glm4MoeShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(Glm4MoeModel, model)
|
||||
for layer in model.layers:
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -948,10 +969,13 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(GptOssMoeModel, model)
|
||||
|
||||
for layer in model.layers:
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
|
||||
src/exo/worker/engines/mlx/model_transfer.py (new file, 499 lines)
@@ -0,0 +1,499 @@
|
||||
"""
|
||||
Model transfer via MLX distributed all_sum.
|
||||
|
||||
Three transfer modes:
|
||||
1. Metadata file transfer: broadcast small files (config.json, tokenizer, etc.) to disk
|
||||
2. Weight tensor broadcast: stream weight tensors directly into memory via all_sum
|
||||
3. Full file transfer: broadcast all files (including safetensors) to disk
|
||||
|
||||
All functions are collective operations — every rank in the group must call them.
|
||||
|
||||
Protocol relies on all_sum: source has real data, receivers have zeros.
|
||||
all_sum(source + zeros) = source data on all ranks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Final, cast
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
Group = mx.distributed.Group
|
||||
|
||||
CHUNK_SIZE: Final[int] = 100 * 1024 * 1024 # 100 MB
|
||||
_LAYER_RE: Final[re.Pattern[str]] = re.compile(r"(?:^|\.)(layers|h)\.(\d+)\.")
|
||||
|
||||
|
||||
def _all_sum_cpu(x: mx.array, group: Group) -> mx.array:
|
||||
"""all_sum on CPU stream to avoid GPU memory pressure."""
|
||||
return mx.distributed.all_sum(
|
||||
x, stream=mx.default_stream(mx.Device(mx.cpu)), group=group
|
||||
)
|
||||
|
||||
|
||||
def _is_metadata_file(filename: str) -> bool:
|
||||
"""A metadata file is anything that isn't a weight file (.safetensors)."""
|
||||
return not filename.endswith(".safetensors")
|
||||
|
||||
|
||||
def model_path_for_id(model_id: ModelId) -> Path:
|
||||
"""Get model path without requiring directory to exist (unlike build_model_path)."""
|
||||
return EXO_MODELS_DIR / model_id.normalize()
|
||||
|
||||
|
||||
def coordinate_transfer(group: Group, has_local_model: bool) -> tuple[bool, int]:
|
||||
"""
|
||||
Determine if a transfer is needed and which rank is the source.
|
||||
|
||||
All ranks must call this function (uses collective all_sum).
|
||||
|
||||
Returns:
|
||||
(needs_transfer, source_rank) — source_rank is the lowest rank
|
||||
that has the model. needs_transfer is True if any rank is missing it.
|
||||
"""
|
||||
all_sum = partial(_all_sum_cpu, group=group)
|
||||
world_size = group.size()
|
||||
|
||||
# Each rank broadcasts a one-hot vector at its position if it has the model
|
||||
bitmask = mx.zeros(world_size, dtype=mx.int32)
|
||||
if has_local_model:
|
||||
bitmask = bitmask.at[group.rank()].add(1)
|
||||
summed = all_sum(bitmask)
|
||||
mx.eval(summed)
|
||||
|
||||
has_model_flags: list[int] = summed.tolist() # type: ignore[assignment]
|
||||
total_have = sum(has_model_flags)
|
||||
|
||||
if total_have == 0:
|
||||
raise RuntimeError(
|
||||
"No rank has the model files — cannot transfer. "
|
||||
"At least one node must have downloaded the model."
|
||||
)
|
||||
|
||||
if total_have == world_size:
|
||||
logger.info("All ranks have model files, no transfer needed")
|
||||
return False, 0
|
||||
|
||||
source_rank = next(i for i, flag in enumerate(has_model_flags) if flag > 0)
|
||||
logger.info(
|
||||
f"Transfer needed: source_rank={source_rank}, "
|
||||
f"{total_have}/{world_size} ranks have model"
|
||||
)
|
||||
return True, source_rank
|
||||
|
||||
|
||||
def _broadcast_json(obj: object, group: Group, is_source: bool) -> object:
|
||||
"""Broadcast a JSON-serializable object from source to all ranks."""
|
||||
all_sum = partial(_all_sum_cpu, group=group)
|
||||
|
||||
data = json.dumps(obj, separators=(",", ":")).encode("utf-8") if is_source else b""
|
||||
|
||||
# Broadcast length
|
||||
len_arr = mx.array([len(data) if is_source else 0], dtype=mx.int64)
|
||||
len_result = all_sum(len_arr)
|
||||
mx.eval(len_result)
|
||||
length = int(len_result.item())
|
||||
if length == 0:
|
||||
return None
|
||||
|
||||
# Broadcast payload
|
||||
if is_source:
|
||||
arr = mx.array(list(data), dtype=mx.uint8)
|
||||
else:
|
||||
arr = mx.zeros(length, dtype=mx.uint8)
|
||||
result = all_sum(arr)
|
||||
mx.eval(result)
|
||||
return json.loads(bytes(cast(list[int], result.tolist()))) # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
def _build_manifest(
|
||||
model_path: Path, metadata_only: bool = False
|
||||
) -> list[dict[str, str | int]]:
|
||||
"""Build a list of files in the model directory with their relative paths and sizes."""
|
||||
manifest: list[dict[str, str | int]] = []
|
||||
for root, _dirs, files in os.walk(model_path):
|
||||
for fname in sorted(files):
|
||||
if metadata_only and not _is_metadata_file(fname):
|
||||
continue
|
||||
full_path = Path(root) / fname
|
||||
rel_path = str(full_path.relative_to(model_path))
|
||||
manifest.append(
|
||||
{
|
||||
"path": rel_path,
|
||||
"size": full_path.stat().st_size,
|
||||
}
|
||||
)
|
||||
return manifest
|
||||
|
||||
|
||||
def _transfer_file_to_disk(
|
||||
source_path: Path,
|
||||
rel_path: str,
|
||||
file_size: int,
|
||||
group: Group,
|
||||
is_source: bool,
|
||||
dest_path: Path,
|
||||
) -> None:
|
||||
"""Transfer a single file chunk-by-chunk via all_sum. Source reads from disk, receivers write to dest_path."""
|
||||
all_sum = partial(_all_sum_cpu, group=group)
|
||||
|
||||
if is_source:
|
||||
src_file = source_path / rel_path
|
||||
with open(src_file, "rb") as f:
|
||||
offset = 0
|
||||
while offset < file_size:
|
||||
chunk_bytes = min(CHUNK_SIZE, file_size - offset)
|
||||
data = f.read(chunk_bytes)
|
||||
if not data:
|
||||
break
|
||||
size_arr = mx.array([len(data)], dtype=mx.int64)
|
||||
mx.eval(all_sum(size_arr))
|
||||
chunk_arr = mx.array(list(data), dtype=mx.uint8)
|
||||
result = all_sum(chunk_arr)
|
||||
mx.eval(result)
|
||||
offset += len(data)
|
||||
# Signal end of file
|
||||
mx.eval(all_sum(mx.array([0], dtype=mx.int64)))
|
||||
else:
|
||||
dst_file = dest_path / rel_path
|
||||
os.makedirs(dst_file.parent, exist_ok=True)
|
||||
with open(dst_file, "wb") as f:
|
||||
while True:
|
||||
size_arr = all_sum(mx.zeros(1, dtype=mx.int64))
|
||||
mx.eval(size_arr)
|
||||
chunk_size = int(size_arr.item())
|
||||
if chunk_size == 0:
|
||||
break
|
||||
chunk_data = all_sum(mx.zeros(chunk_size, dtype=mx.uint8))
|
||||
mx.eval(chunk_data)
|
||||
f.write(bytes(cast(list[int], chunk_data.tolist())))
|
||||
|
||||
|
||||
def _transfer_files_to_disk(
|
||||
model_path: Path,
|
||||
group: Group,
|
||||
is_source: bool,
|
||||
metadata_only: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Transfer files from source to all receivers' disk.
|
||||
|
||||
Source broadcasts a manifest then each file. Receivers write to a temp dir
|
||||
then atomically move files to model_path.
|
||||
"""
|
||||
if is_source:
|
||||
source_manifest = _build_manifest(model_path, metadata_only=metadata_only)
|
||||
else:
|
||||
source_manifest = []
|
||||
manifest = cast(
|
||||
list[dict[str, str | int]],
|
||||
_broadcast_json(source_manifest if is_source else None, group, is_source),
|
||||
)
|
||||
|
||||
if not manifest:
|
||||
logger.info("No files to transfer")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Transferring {len(manifest)} files ({'metadata only' if metadata_only else 'all'})"
|
||||
)
|
||||
|
||||
temp_dir: Path | None = None
|
||||
if not is_source:
|
||||
os.makedirs(model_path.parent, exist_ok=True)
|
||||
temp_dir = Path(
|
||||
tempfile.mkdtemp(
|
||||
dir=model_path.parent,
|
||||
prefix=f".transfer_{model_path.name}_",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
for entry in manifest:
|
||||
rel_path = str(entry["path"])
|
||||
file_size = int(entry["size"])
|
||||
logger.info(f" {rel_path} ({file_size} bytes)")
|
||||
_transfer_file_to_disk(
|
||||
source_path=model_path,
|
||||
rel_path=rel_path,
|
||||
file_size=file_size,
|
||||
group=group,
|
||||
is_source=is_source,
|
||||
dest_path=temp_dir if temp_dir is not None else model_path,
|
||||
)
|
||||
|
||||
if temp_dir is not None:
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
for entry in manifest:
|
||||
rel_path = str(entry["path"])
|
||||
src = temp_dir / rel_path
|
||||
dst = model_path / rel_path
|
||||
os.makedirs(dst.parent, exist_ok=True)
|
||||
os.replace(src, dst)
|
||||
logger.info(
|
||||
f"Transfer complete: {len(manifest)} files moved to {model_path}"
|
||||
)
|
||||
finally:
|
||||
if temp_dir is not None and temp_dir.exists():
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def transfer_metadata_files(model_path: Path, group: Group, is_source: bool) -> None:
|
||||
"""
|
||||
Transfer metadata files (config.json, tokenizer files, etc.) to receivers' disk.
|
||||
|
||||
All ranks must call this function (collective operation).
|
||||
Only the designated source (is_source=True) should send; all others receive.
|
||||
"""
|
||||
_transfer_files_to_disk(model_path, group, is_source=is_source, metadata_only=True)
|
||||
|
||||
|
||||
def transfer_all_files(model_path: Path, group: Group, is_source: bool) -> None:
|
||||
"""
|
||||
Transfer ALL model files (including safetensors) to receivers' disk.
|
||||
|
||||
All ranks must call this function (collective operation).
|
||||
Only the designated source (is_source=True) should send; all others receive.
|
||||
"""
|
||||
_transfer_files_to_disk(model_path, group, is_source=is_source, metadata_only=False)
|
||||
|
||||
|
||||
def _parse_mx_dtype(dtype_str: str) -> mx.Dtype:
|
||||
"""Convert a dtype string like 'float16' or 'mlx.core.float16' to mx.Dtype."""
|
||||
name = dtype_str.split(".")[-1]
|
||||
dtype = getattr(mx, name, None)
|
||||
if dtype is None:
|
||||
raise ValueError(f"Unknown MLX dtype: {dtype_str}")
|
||||
return dtype # type: ignore[return-value]
|
||||
|
||||
|
||||
def _extract_layer_index(name: str) -> int | None:
|
||||
"""Extract layer index from a weight name, or None for non-layer weights.
|
||||
|
||||
Matches patterns like ``model.layers.5.self_attn.q_proj.weight``
|
||||
or ``transformer.h.12.mlp.gate_proj.scales``.
|
||||
"""
|
||||
m = _LAYER_RE.search(name)
|
||||
return int(m.group(2)) if m else None
|
||||
|
||||
|
||||
class WeightBroadcastState:
|
||||
"""Holds state for layer-by-layer weight broadcasting.
|
||||
|
||||
Created by :func:`prepare_weight_broadcast`. Callers stream weights
|
||||
incrementally via :meth:`broadcast_non_layer_weights` and
|
||||
:meth:`broadcast_layer` so that at most one layer's worth of un-sharded
|
||||
weight data is resident at a time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
meta: dict[str, dict[str, Any]],
|
||||
source_weights: dict[str, mx.array] | None,
|
||||
group: Group,
|
||||
is_source: bool,
|
||||
) -> None:
|
||||
self.meta = meta
|
||||
self.source_weights = source_weights
|
||||
self.group = group
|
||||
self.is_source = is_source
|
||||
|
||||
# Partition weight names into layer vs. non-layer
|
||||
self.layer_names: dict[int, list[str]] = {}
|
||||
self.non_layer_names: list[str] = []
|
||||
for name in sorted(meta.keys()):
|
||||
layer_idx = _extract_layer_index(name)
|
||||
if layer_idx is not None:
|
||||
self.layer_names.setdefault(layer_idx, []).append(name)
|
||||
else:
|
||||
self.non_layer_names.append(name)
|
||||
|
||||
logger.info(
|
||||
f"WeightBroadcastState: {len(self.non_layer_names)} non-layer weights, "
|
||||
f"{len(self.layer_names)} layers"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _broadcast_names(self, names: list[str]) -> dict[str, mx.array]:
|
||||
"""Broadcast a specific set of weight tensors by name."""
|
||||
all_sum = partial(_all_sum_cpu, group=self.group)
|
||||
result: dict[str, mx.array] = {}
|
||||
for name in names:
|
||||
info = self.meta[name]
|
||||
shape = cast(list[int], info["s"])
|
||||
dtype = _parse_mx_dtype(cast(str, info["d"]))
|
||||
|
||||
if self.is_source:
|
||||
assert self.source_weights is not None
|
||||
tensor = self.source_weights.pop(name)
|
||||
mx.eval(tensor) # loads from disk (lazy)
|
||||
else:
|
||||
tensor = mx.zeros(shape, dtype=dtype)
|
||||
|
||||
broadcasted = all_sum(tensor)
|
||||
mx.eval(broadcasted)
|
||||
result[name] = broadcasted
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def broadcast_non_layer_weights(self) -> dict[str, mx.array]:
|
||||
"""Broadcast non-layer weights (embeddings, norms, lm_head)."""
|
||||
if not self.non_layer_names:
|
||||
return {}
|
||||
logger.info(
|
||||
f"Broadcasting {len(self.non_layer_names)} non-layer weight tensors"
|
||||
)
|
||||
return self._broadcast_names(self.non_layer_names)
|
||||
|
||||
def broadcast_layer(self, layer_idx: int) -> dict[str, mx.array]:
|
||||
"""Broadcast weights for a single transformer layer."""
|
||||
names = self.layer_names.get(layer_idx, [])
|
||||
if not names:
|
||||
return {}
|
||||
return self._broadcast_names(names)
|
||||
|
||||
|
||||
def prepare_weight_broadcast(
|
||||
model_path: Path,
|
||||
group: Group,
|
||||
is_source: bool,
|
||||
) -> WeightBroadcastState:
|
||||
"""Prepare for layer-by-layer weight broadcasting.
|
||||
|
||||
Source loads safetensors lazily and broadcasts weight metadata (names,
|
||||
shapes, dtypes) as JSON. Returns a :class:`WeightBroadcastState` that
|
||||
can then stream weights incrementally via ``broadcast_layer()``.
|
||||
|
||||
All ranks must call this function (collective operation).
|
||||
"""
|
||||
source_weights: dict[str, mx.array] | None = None
|
||||
if is_source:
|
||||
source_weights = {}
|
||||
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||
if not weight_files:
|
||||
weight_files = sorted(model_path.glob("**/*.safetensors"))
|
||||
for wf in weight_files:
|
||||
try:
|
||||
loaded = cast(
|
||||
dict[str, mx.array],
|
||||
mx.load(str(wf), lazy=True), # pyright: ignore[reportCallIssue]
|
||||
)
|
||||
except TypeError:
|
||||
loaded = cast(dict[str, mx.array], mx.load(str(wf)))
|
||||
source_weights.update(loaded)
|
||||
logger.info(
|
||||
f"Source loaded {len(source_weights)} weight tensors (lazy) "
|
||||
f"from {len(weight_files)} files"
|
||||
)
|
||||
|
||||
# Broadcast metadata
|
||||
if is_source and source_weights is not None:
|
||||
source_meta: dict[str, dict[str, Any]] = {
|
||||
name: {"s": list(tensor.shape), "d": str(tensor.dtype)}
|
||||
for name, tensor in source_weights.items()
|
||||
}
|
||||
else:
|
||||
source_meta = {}
|
||||
|
||||
meta = cast(
|
||||
dict[str, dict[str, Any]],
|
||||
_broadcast_json(source_meta if is_source else None, group, is_source),
|
||||
)
|
||||
|
||||
logger.info(f"Weight broadcast prepared: {len(meta)} tensors")
|
||||
return WeightBroadcastState(meta, source_weights, group, is_source)
|
||||
|
||||
|
||||
def broadcast_model_weights(
|
||||
model_path: Path,
|
||||
group: Group,
|
||||
is_source: bool,
|
||||
) -> dict[str, mx.array]:
|
||||
"""
|
||||
Broadcast model weight tensors from source rank to all receivers' memory.
|
||||
|
||||
Source loads weights from .safetensors files on disk and broadcasts each
|
||||
tensor via all_sum. Receivers receive tensors directly as mx.arrays in
|
||||
memory — no disk write for weight data.
|
||||
|
||||
All ranks must call this function (collective operation).
|
||||
Only the designated source (is_source=True) should send; all others receive.
|
||||
|
||||
Returns:
|
||||
dict mapping weight names to mx.arrays (on all ranks).
|
||||
"""
|
||||
all_sum = partial(_all_sum_cpu, group=group)
|
||||
|
||||
# Source loads weights (lazy if supported, so only one tensor in memory at a time)
|
||||
weights: dict[str, mx.array] = {}
|
||||
if is_source:
|
||||
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||
if not weight_files:
|
||||
weight_files = sorted(model_path.glob("**/*.safetensors"))
|
||||
for wf in weight_files:
|
||||
try:
|
||||
loaded = cast(dict[str, mx.array], mx.load(str(wf), lazy=True)) # pyright: ignore[reportCallIssue]
|
||||
except TypeError:
|
||||
loaded = cast(dict[str, mx.array], mx.load(str(wf)))
|
||||
weights.update(loaded)
|
||||
logger.info(
|
||||
f"Source loaded {len(weights)} weight tensors from {len(weight_files)} files"
|
||||
)
|
||||
|
||||
# Broadcast weight metadata: {name: {shape, dtype}}
|
||||
if is_source:
|
||||
source_meta: dict[str, dict[str, Any]] = {
|
||||
name: {"s": list(tensor.shape), "d": str(tensor.dtype)}
|
||||
for name, tensor in weights.items()
|
||||
}
|
||||
else:
|
||||
source_meta = {}
|
||||
meta = cast(
|
||||
dict[str, dict[str, Any]],
|
||||
_broadcast_json(source_meta if is_source else None, group, is_source),
|
||||
)
|
||||
|
||||
logger.info(f"Broadcasting {len(meta)} weight tensors")
|
||||
|
||||
# Broadcast each tensor in sorted order (deterministic across ranks).
|
||||
# Source loads one tensor at a time from disk (lazy), broadcasts it,
|
||||
# then drops the reference so only one tensor is in flight at a time.
|
||||
result: dict[str, mx.array] = {}
|
||||
for i, name in enumerate(sorted(meta.keys())):
|
||||
info = meta[name]
|
||||
shape = cast(list[int], info["s"])
|
||||
dtype_str = cast(str, info["d"])
|
||||
dtype = _parse_mx_dtype(dtype_str)
|
||||
|
||||
if is_source:
|
||||
tensor = weights.pop(name) # pop to free lazy ref after broadcast
|
||||
mx.eval(tensor) # loads from disk
|
||||
else:
|
||||
tensor = mx.zeros(shape, dtype=dtype)
|
||||
|
||||
broadcasted = all_sum(tensor)
|
||||
mx.eval(broadcasted)
|
||||
result[name] = broadcasted
|
||||
|
||||
if (i + 1) % 100 == 0:
|
||||
logger.info(f" Broadcast {i + 1}/{len(meta)} tensors")
|
||||
|
||||
logger.info(f"Weight broadcast complete: {len(result)} tensors")
|
||||
return result
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -59,6 +60,13 @@ from exo.worker.engines.mlx.auto_parallel import (
|
||||
pipeline_auto_parallel,
|
||||
tensor_auto_parallel,
|
||||
)
|
||||
from exo.worker.engines.mlx.model_transfer import (
|
||||
WeightBroadcastState,
|
||||
coordinate_transfer,
|
||||
model_path_for_id,
|
||||
prepare_weight_broadcast,
|
||||
transfer_metadata_files,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
Group = mx.distributed.Group
|
||||
@@ -197,6 +205,7 @@ def load_mlx_items(
|
||||
bound_instance: BoundInstance,
|
||||
group: Group | None,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
has_local_model: bool = True,
|
||||
) -> tuple[Model, TokenizerWrapper]:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
@@ -211,7 +220,10 @@ def load_mlx_items(
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(
|
||||
bound_instance.bound_shard, group=group, on_timeout=on_timeout
|
||||
bound_instance.bound_shard,
|
||||
group=group,
|
||||
on_timeout=on_timeout,
|
||||
has_local_model=has_local_model,
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
@@ -227,30 +239,69 @@ def shard_and_load(
|
||||
shard_metadata: ShardMetadata,
|
||||
group: Group,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
has_local_model: bool = True,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
model_path = build_model_path(shard_metadata.model_card.model_id)
|
||||
model_id = shard_metadata.model_card.model_id
|
||||
model_path = model_path_for_id(model_id)
|
||||
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
# Coordinate: does any rank need a transfer?
|
||||
needs_transfer, source_rank = coordinate_transfer(group, has_local_model)
|
||||
is_source = group.rank() == source_rank
|
||||
|
||||
# Step 1: Always ensure all nodes have metadata files (config, tokenizer, etc.).
|
||||
# This is cheap (~20MB, ~1s) and guarantees config.json is present for load_model().
|
||||
transfer_metadata_files(model_path, group, is_source)
|
||||
|
||||
# Step 2: Only broadcast weights if some rank is missing the model
|
||||
broadcast_state: WeightBroadcastState | None = None
|
||||
if needs_transfer:
|
||||
logger.info(
|
||||
f"Model transfer needed (source_rank={source_rank}, "
|
||||
f"is_source={is_source}, local_weights={has_local_model})"
|
||||
)
|
||||
broadcast_state = prepare_weight_broadcast(model_path, group, is_source)
|
||||
|
||||
# Create model architecture (all ranks have config.json on disk now).
|
||||
# Always use lazy=True when we have broadcast state: load_model's internal
|
||||
# nn.quantize skips quantization when weights dict is empty (no safetensors),
|
||||
# leaving the model un-quantized. lazy=False would then mx.eval() the full
|
||||
# fp16 model (~72GB for a 36B-param model), causing OOM on the receiver.
|
||||
# We handle quantization ourselves below before loading broadcast weights.
|
||||
use_lazy = has_local_model or broadcast_state is not None
|
||||
model, _ = load_model(model_path, lazy=use_lazy, strict=False)
|
||||
logger.debug(model)
|
||||
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
|
||||
pass
|
||||
# TODO: See if we should quantize the model.
|
||||
# def is_attention_layer(path: str) -> bool:
|
||||
# path = path.lower()
|
||||
|
||||
# return "self_attn" in path and "layernorm" not in path
|
||||
|
||||
# def quant_predicate(path: str, module: nn.Module):
|
||||
# if not isinstance(module, nn.Linear):
|
||||
# return False
|
||||
|
||||
# return is_attention_layer(path)
|
||||
# model, config = quantize_model(
|
||||
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
|
||||
# )
|
||||
|
||||
assert isinstance(model, nn.Module)
|
||||
|
||||
if broadcast_state is not None:
|
||||
# When receiver has no weight files, load_model skips quantization.
|
||||
# Apply it explicitly so QuantizedLinear layers match broadcast weight shapes.
|
||||
if not has_local_model:
|
||||
config_path = model_path / "config.json"
|
||||
with open(config_path) as f:
|
||||
config = json.load(f) # pyright: ignore[reportAny]
|
||||
quant_config: dict[str, int] | None = config.get( # pyright: ignore[reportAny]
|
||||
"quantization", None
|
||||
)
|
||||
if quant_config is not None:
|
||||
logger.info(f"Applying quantization to receiver model: {quant_config}")
|
||||
nn.quantize( # pyright: ignore[reportUnknownMemberType]
|
||||
model,
|
||||
group_size=quant_config.get("group_size", 64),
|
||||
bits=quant_config.get("bits", 4),
|
||||
)
|
||||
|
||||
# Broadcast and load non-layer weights (embeddings, norms, lm_head) upfront.
|
||||
# These are small (~600MB) and needed before the sharding loop.
|
||||
non_layer_weights = broadcast_state.broadcast_non_layer_weights()
|
||||
if non_layer_weights:
|
||||
model.load_weights(list(non_layer_weights.items()), strict=False)
|
||||
logger.info(f"Loaded {len(non_layer_weights)} non-layer weight tensors")
|
||||
del non_layer_weights
|
||||
|
||||
tokenizer = get_tokenizer(model_path, shard_metadata)
|
||||
|
||||
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
|
||||
@@ -264,12 +315,43 @@ def shard_and_load(
|
||||
f"(model size: {model_size_gb:.1f}GB)"
|
||||
)
|
||||
|
||||
# Build per-layer weight loader for streaming broadcast during sharding.
|
||||
# Each layer's weights are broadcast via all_sum just before that layer is
|
||||
# sharded, so at most one un-sharded layer is in memory at a time.
|
||||
weight_loader_fn: Callable[[nn.Module, int], None] | None = None
|
||||
if broadcast_state is not None:
|
||||
_state = broadcast_state # capture for closure
|
||||
|
||||
def _load_layer_weights(mdl: nn.Module, layer_idx: int) -> None:
|
||||
layer_weights = _state.broadcast_layer(layer_idx)
|
||||
if layer_weights:
|
||||
mdl.load_weights(list(layer_weights.items()), strict=False)
|
||||
|
||||
weight_loader_fn = _load_layer_weights
|
||||
|
||||
match shard_metadata:
|
||||
case TensorShardMetadata():
|
||||
logger.info(f"loading model from {model_path} with tensor parallelism")
|
||||
model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
|
||||
model = tensor_auto_parallel(
|
||||
model, group, timeout_seconds, on_timeout, weight_loader_fn
|
||||
)
|
||||
case PipelineShardMetadata():
|
||||
logger.info(f"loading model from {model_path} with pipeline parallelism")
|
||||
# Broadcast all layers (all_sum is collective — all ranks must
|
||||
# participate) but only load weights for layers this node will
|
||||
# keep after pipeline slicing. Out-of-range results are discarded,
|
||||
# keeping peak memory proportional to this node's layer count.
|
||||
if broadcast_state is not None:
|
||||
for layer_idx in sorted(broadcast_state.layer_names.keys()):
|
||||
layer_weights = broadcast_state.broadcast_layer(layer_idx)
|
||||
if (
|
||||
shard_metadata.start_layer
|
||||
<= layer_idx
|
||||
< shard_metadata.end_layer
|
||||
and layer_weights
|
||||
):
|
||||
model.load_weights(list(layer_weights.items()), strict=False)
|
||||
del layer_weights
|
||||
model = pipeline_auto_parallel(model, group, shard_metadata)
|
||||
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
|
||||
case CfgShardMetadata():
|
||||
@@ -278,6 +360,8 @@ def shard_and_load(
|
||||
"this metadata type is only for image generation models"
|
||||
)
|
||||
|
||||
del broadcast_state
|
||||
|
||||
# TODO: Do we need this?
|
||||
mx.eval(model)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
ConnectToGroup,
|
||||
@@ -16,6 +17,7 @@ from exo.shared.types.tasks import (
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
TextGeneration,
|
||||
TransferModelToDisk,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
@@ -34,8 +36,11 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerLoading,
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerShuttingDown,
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
ShardAssignments,
|
||||
)
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
@@ -57,6 +62,7 @@ def plan(
|
||||
or _create_runner(node_id, runners, instances)
|
||||
or _model_needs_download(node_id, runners, global_download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _transfer_model_to_disk(runners, all_runners, global_download_status)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
|
||||
@@ -121,6 +127,10 @@ def _model_needs_download(
|
||||
}
|
||||
|
||||
for runner in runners.values():
|
||||
# Transfer-only instances don't need downloads
|
||||
if runner.bound_instance.instance.shard_assignments.transfer_only:
|
||||
continue
|
||||
|
||||
model_id = runner.bound_instance.bound_shard.model_card.model_id
|
||||
if isinstance(runner.status, RunnerIdle) and (
|
||||
model_id not in download_status
|
||||
@@ -129,6 +139,15 @@ def _model_needs_download(
|
||||
(DownloadOngoing, DownloadCompleted, DownloadFailed),
|
||||
)
|
||||
):
|
||||
# For multi-node instances, skip download if a peer already has the model.
|
||||
# The model will be transferred via MLX distributed during LoadModel.
|
||||
instance = runner.bound_instance.instance
|
||||
is_multi_node = len(instance.shard_assignments.node_to_runner) > 1
|
||||
if is_multi_node and _any_peer_has_model(
|
||||
node_id, model_id, instance, global_download_status
|
||||
):
|
||||
continue
|
||||
|
||||
# We don't invalidate download_status randomly in case a file gets deleted on disk
|
||||
return DownloadModel(
|
||||
instance_id=runner.bound_instance.instance.instance_id,
|
||||
@@ -186,6 +205,43 @@ def _init_distributed_backend(
|
||||
return None
|
||||
|
||||
|
||||
def _transfer_model_to_disk(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> TransferModelToDisk | None:
|
||||
"""For transfer-only instances: after all ranks are connected, emit TransferModelToDisk."""
|
||||
for runner in runners.values():
|
||||
instance = runner.bound_instance.instance
|
||||
shard_assignments = instance.shard_assignments
|
||||
|
||||
if not shard_assignments.transfer_only:
|
||||
continue
|
||||
|
||||
is_runner_connected = isinstance(runner.status, RunnerConnected)
|
||||
all_connected_or_further = all(
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id, None),
|
||||
(RunnerConnected, RunnerLoading, RunnerShuttingDown, RunnerShutdown),
|
||||
)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
)
|
||||
|
||||
if is_runner_connected and all_connected_or_further:
|
||||
has_local = _node_has_download(
|
||||
runner.bound_instance.bound_node_id,
|
||||
shard_assignments.model_id,
|
||||
global_download_status,
|
||||
)
|
||||
return TransferModelToDisk(
|
||||
instance_id=instance.instance_id,
|
||||
shard_metadata=runner.bound_instance.bound_shard,
|
||||
has_local_model=has_local,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _load_model(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
@@ -195,38 +251,97 @@ def _load_model(
|
||||
instance = runner.bound_instance.instance
|
||||
shard_assignments = instance.shard_assignments
|
||||
|
||||
all_local_downloads_complete = all(
|
||||
nid in global_download_status
|
||||
and any(
|
||||
isinstance(dp, DownloadCompleted)
|
||||
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
|
||||
for dp in global_download_status[nid]
|
||||
)
|
||||
for nid in shard_assignments.node_to_runner
|
||||
)
|
||||
if not all_local_downloads_complete:
|
||||
# Transfer-only instances don't load models for inference
|
||||
if shard_assignments.transfer_only:
|
||||
continue
|
||||
|
||||
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
|
||||
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
|
||||
return LoadModel(instance_id=instance.instance_id)
|
||||
is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
|
||||
|
||||
is_runner_waiting = isinstance(runner.status, RunnerConnected)
|
||||
if is_single_node_instance:
|
||||
# Single-node: require local download complete
|
||||
if not _all_downloads_complete(shard_assignments, global_download_status):
|
||||
continue
|
||||
if isinstance(runner.status, RunnerIdle):
|
||||
return LoadModel(instance_id=instance.instance_id, has_local_model=True)
|
||||
else:
|
||||
# Multi-node: require at least one node to have the model downloaded.
|
||||
# Nodes without the model will receive it via MLX distributed transfer
|
||||
# during model loading.
|
||||
if not _any_download_complete(shard_assignments, global_download_status):
|
||||
continue
|
||||
|
||||
all_ready_for_model = all(
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id, None),
|
||||
(RunnerConnected, RunnerLoading, RunnerLoaded),
|
||||
is_runner_waiting = isinstance(runner.status, RunnerConnected)
|
||||
all_ready_for_model = all(
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id, None),
|
||||
(RunnerConnected, RunnerLoading, RunnerLoaded),
|
||||
)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
)
|
||||
|
||||
if is_runner_waiting and all_ready_for_model:
|
||||
return LoadModel(instance_id=instance.instance_id)
|
||||
if is_runner_waiting and all_ready_for_model:
|
||||
has_local = _node_has_download(
|
||||
runner.bound_instance.bound_node_id,
|
||||
shard_assignments.model_id,
|
||||
global_download_status,
|
||||
)
|
||||
return LoadModel(
|
||||
instance_id=instance.instance_id,
|
||||
has_local_model=has_local,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _node_has_download(
|
||||
nid: NodeId,
|
||||
model_id: ModelId,
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> bool:
|
||||
"""Check if a specific node has completed downloading the given model."""
|
||||
return any(
|
||||
isinstance(dp, DownloadCompleted)
|
||||
and dp.shard_metadata.model_card.model_id == model_id
|
||||
for dp in global_download_status.get(nid, [])
|
||||
)
|
||||
|
||||
|
||||
def _any_peer_has_model(
|
||||
node_id: NodeId,
|
||||
model_id: ModelId,
|
||||
instance: Instance,
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> bool:
|
||||
"""Check if any other node in the instance already has the model downloaded."""
|
||||
return any(
|
||||
_node_has_download(nid, model_id, global_download_status)
|
||||
for nid in instance.shard_assignments.node_to_runner
|
||||
if nid != node_id
|
||||
)
|
||||
|
||||
|
||||
def _all_downloads_complete(
|
||||
shard_assignments: ShardAssignments,
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> bool:
|
||||
"""Check if ALL nodes in the instance have completed downloading the model."""
|
||||
return all(
|
||||
_node_has_download(nid, shard_assignments.model_id, global_download_status)
|
||||
for nid in shard_assignments.node_to_runner
|
||||
)
|
||||
|
||||
|
||||
def _any_download_complete(
|
||||
shard_assignments: ShardAssignments,
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> bool:
|
||||
"""Check if at least one node in the instance has completed downloading the model."""
|
||||
return any(
|
||||
_node_has_download(nid, shard_assignments.model_id, global_download_status)
|
||||
for nid in shard_assignments.node_to_runner
|
||||
)
def _ready_to_warmup(
    runners: Mapping[RunnerId, RunnerSupervisor],
    all_runners: Mapping[RunnerId, RunnerStatus],
@@ -234,6 +349,11 @@ def _ready_to_warmup(
    for runner in runners.values():
        instance = runner.bound_instance.instance
        shard_assignments = instance.shard_assignments

        # Transfer-only instances don't go through warmup
        if shard_assignments.transfer_only:
            continue

        shard = runner.bound_instance.bound_shard
        device_rank = shard.device_rank
        runner_id = runner.bound_instance.bound_runner_id

@@ -43,6 +43,7 @@ from exo.shared.types.tasks import (
    TaskId,
    TaskStatus,
    TextGeneration,
    TransferModelToDisk,
)
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import BoundInstance
@@ -82,6 +83,11 @@ from exo.worker.engines.image import (
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.model_transfer import (
    coordinate_transfer,
    model_path_for_id,
    transfer_all_files,
)
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
    detect_thinking_prompt_suffix,
@@ -192,7 +198,10 @@ def main(

    if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
        model, tokenizer = load_mlx_items(
            bound_instance, group, on_timeout=on_model_load_timeout
            bound_instance,
            group,
            on_timeout=on_model_load_timeout,
            has_local_model=task.has_local_model,
        )
        logger.info(
            f"model has_tool_calling={tokenizer.has_tool_calling}"
@@ -508,6 +517,27 @@ def main(

                current_status = RunnerReady()
                logger.info("runner ready")
            case TransferModelToDisk() if (
                isinstance(current_status, RunnerConnected) and group is not None
            ):
                logger.info("starting disk-to-disk model transfer")
                event_sender.send(TaskAcknowledged(task_id=task.task_id))

                model_path = model_path_for_id(
                    task.shard_metadata.model_card.model_id
                )
                _, source_rank = coordinate_transfer(group, task.has_local_model)
                is_source = group.rank() == source_rank
                transfer_all_files(model_path, group, is_source)

                logger.info("disk-to-disk model transfer complete")
                current_status = RunnerShuttingDown()
                event_sender.send(
                    RunnerStatusUpdated(
                        runner_id=runner_id, runner_status=current_status
                    )
                )
                current_status = RunnerShutdown()
            case Shutdown():
                current_status = RunnerShuttingDown()
                logger.info("runner shutting down")
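
The transfer case above leaves the source-rank election to `coordinate_transfer`, whose internals are not part of this diff. A plausible sketch of that step, assuming each rank's `has_local_model` flag is gathered across the group and the lowest rank that holds the model becomes the source, is shown below; `elect_source_rank` and the plain list standing in for the MLX group are hypothetical.

```python
# Illustrative sketch of source-rank election; not exo's coordinate_transfer.
# A plain list stands in for the all-gather of per-rank has_local flags.
def elect_source_rank(has_local_flags: list[bool]) -> int:
    """Return the lowest rank that already has the model on disk."""
    for rank, has_local in enumerate(has_local_flags):
        if has_local:
            return rank
    raise RuntimeError("no rank has the model; nothing to transfer")


# Example: rank 1 is the only node with the model, so it becomes the source
# and every other rank receives the files (is_source=False for them).
flags = [False, True, False]
source_rank = elect_source_rank(flags)
for my_rank in range(len(flags)):
    is_source = my_rank == source_rank
    print(f"rank {my_rank}: is_source={is_source}")
```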
@@ -112,6 +112,7 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():

    assert isinstance(result, LoadModel)
    assert result.instance_id == INSTANCE_1_ID
    assert result.has_local_model is True


def test_plan_does_not_request_download_when_shard_already_downloaded():
@@ -157,10 +158,11 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
    assert not isinstance(result, plan_mod.DownloadModel)


def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
def test_plan_loads_model_when_any_node_has_download_for_multi_node():
    """
    LoadModel should not be emitted while some shards are still missing from
    the global_download_status.
    For multi-node instances, LoadModel should be emitted when at least one
    node has the model downloaded. Nodes without the model will receive it
    via MLX distributed transfer during model loading.
    """
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -185,6 +187,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
        RUNNER_2_ID: RunnerConnected(),
    }

    # Only NODE_A has the model — LoadModel should still fire
    global_download_status = {
        NODE_A: [
            DownloadCompleted(
@@ -203,19 +206,42 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
        tasks={},
    )

    assert result is None
    assert isinstance(result, LoadModel)
    assert result.instance_id == INSTANCE_1_ID
    assert result.has_local_model is True

    global_download_status = {
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
            )
        ],
        NODE_B: [
            DownloadCompleted(
                shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
            )
        ],  # NODE_B has no downloads completed yet


def test_plan_does_not_load_model_when_no_node_has_download():
    """
    LoadModel should not be emitted when no node has the model downloaded.
    """
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
    )

    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
    )
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerConnected()
    )

    runners = {RUNNER_1_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerConnected(),
        RUNNER_2_ID: RunnerConnected(),
    }

    # No node has the model
    global_download_status: dict[NodeId, list[DownloadProgress]] = {
        NODE_A: [],
        NODE_B: [],
    }

    result = plan_mod.plan(
@@ -227,4 +253,57 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
        tasks={},
    )

    assert result is not None
    assert result is None


def test_plan_load_model_has_local_model_false_when_node_missing_download():
    """
    For multi-node instances, when the local node does NOT have the model
    but a peer does, LoadModel should be emitted with has_local_model=False.
    """
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
    )

    # NODE_B is the local node (bound_node_id=NODE_B), it does NOT have the model
    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
    )
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerConnected()
    )

    runners = {RUNNER_2_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerConnected(),
        RUNNER_2_ID: RunnerConnected(),
    }

    # Only NODE_A has the model, NODE_B does not
    global_download_status: dict[NodeId, list[DownloadProgress]] = {
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
            )
        ],
        NODE_B: [],
    }

    result = plan_mod.plan(
        node_id=NODE_B,
        runners=runners,  # type: ignore
        global_download_status=global_download_status,
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert isinstance(result, LoadModel)
    assert result.instance_id == INSTANCE_1_ID
    assert result.has_local_model is False