mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-18 14:55:13 -05:00
Compare commits
20 Commits
architectu
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
de5c1212ee | ||
|
|
000058aece | ||
|
|
b5c4df2700 | ||
|
|
89e3159871 | ||
|
|
2765f95da1 | ||
|
|
48cc931ba4 | ||
|
|
3409f19496 | ||
|
|
aab8aa38a9 | ||
|
|
b594c2f58b | ||
|
|
d0c0085f57 | ||
|
|
60b0ba2841 | ||
|
|
2e814f8a6f | ||
|
|
773bfb6dc4 | ||
|
|
c2e9092816 | ||
|
|
4bbdfcd872 | ||
|
|
7ac5891588 | ||
|
|
e709d409d1 | ||
|
|
e15d4ef50a | ||
|
|
bfbbea4e48 | ||
|
|
8634601959 |
37
AGENTS.md
37
AGENTS.md
@@ -194,3 +194,40 @@ GitHub's API doesn't support direct image upload for PR comments. Workaround:
|
||||
git push origin <branch>
|
||||
```
|
||||
The images still render in the PR comment because they reference the permanent commit SHA.
|
||||
|
||||
## Running exo Remotely via SSH (macOS mDNS)
|
||||
|
||||
**CRITICAL: On macOS, mDNS multicast (used for peer discovery) only works when the process runs in a proper macOS user session.** Background processes started via `nohup ... &`, `screen`, or plain SSH commands will NOT send mDNS packets and nodes will never discover each other.
|
||||
|
||||
### The Problem
|
||||
When you SSH into a Mac and run `nohup uv run exo &`, the process runs in a detached session without access to macOS multicast networking. The exo node will start but will never discover peers, even if they're on the same network.
|
||||
|
||||
### The Solution: Use `open` with a `.command` wrapper
|
||||
|
||||
Create a `.command` script that `open` will execute in the proper macOS GUI session context:
|
||||
|
||||
```bash
|
||||
# 1. Create wrapper script on the remote machine
|
||||
ssh user@remote-mac "cat > /tmp/run_exo.command << 'SCRIPT'
|
||||
#!/bin/bash
|
||||
export PATH=/opt/homebrew/bin:\$HOME/.local/bin:\$PATH
|
||||
export EXO_LIBP2P_NAMESPACE=your-namespace # must match across all nodes
|
||||
cd ~/path/to/exo
|
||||
exec uv run exo -vv 2>&1 | tee /tmp/exo.log
|
||||
SCRIPT
|
||||
chmod +x /tmp/run_exo.command"
|
||||
|
||||
# 2. Launch it via `open` (runs in macOS GUI session with proper mDNS)
|
||||
ssh user@remote-mac "open /tmp/run_exo.command"
|
||||
|
||||
# 3. Check logs
|
||||
ssh user@remote-mac "tail -f /tmp/exo.log"
|
||||
```
|
||||
|
||||
### Key Details
|
||||
- **`EXO_LIBP2P_NAMESPACE`**: All nodes in a cluster MUST use the same namespace value. The EXO.app uses a build-specific namespace (check with `ps eww <pid> | grep NAMESPACE`). If mixing dev builds with EXO.app, set the dev build's namespace to match.
|
||||
- **`open *.command`**: This is the macOS equivalent of double-clicking the script in Finder. It runs in the user's GUI session with full network access.
|
||||
- **Do NOT use**: `nohup ... &`, `screen -dm`, `tmux new-session -d`, or `sshpass`. These all create detached sessions where mDNS won't work.
|
||||
- **Killing**: `ssh user@remote-mac "pkill -f 'python.*exo'"` works fine for stopping.
|
||||
- **Dashboard**: Must be built before running: `cd dashboard && npm install && npm run build && cd ..`. Node.js is at `/opt/homebrew/bin/node` on Apple Silicon Macs.
|
||||
- **Verifying cluster**: `curl -s http://localhost:52415/state | python3 -c "import json,sys; s=json.load(sys.stdin); print(len(s['topology']['nodes']), 'nodes')"` — should show 2+ nodes.
|
||||
|
||||
11
README.md
11
README.md
@@ -72,16 +72,23 @@ There are two ways to run exo:
|
||||
|
||||
### Run from Source (macOS)
|
||||
|
||||
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
|
||||
|
||||
```bash
|
||||
nix run .#exo
|
||||
```
|
||||
|
||||
**Prerequisites:**
|
||||
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
|
||||
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
|
||||
|
||||
|
||||
```bash
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
```
|
||||
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
|
||||
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
|
||||
- [node](https://github.com/nodejs/node) (for building the dashboard)
|
||||
|
||||
|
||||
```bash
|
||||
brew install uv macmon node
|
||||
```
|
||||
|
||||
@@ -185,11 +185,7 @@
|
||||
|
||||
let instanceType: string | null = null;
|
||||
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
|
||||
else if (
|
||||
instanceTag === "MlxIbvInstance" ||
|
||||
instanceTag === "MlxJacclInstance"
|
||||
)
|
||||
instanceType = "MLX RDMA";
|
||||
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
|
||||
|
||||
let sharding: string | null = null;
|
||||
const inst = instance as {
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
} | null;
|
||||
nodes?: Record<string, NodeInfo>;
|
||||
sharding?: "Pipeline" | "Tensor";
|
||||
runtime?: "MlxRing" | "MlxIbv" | "MlxJaccl";
|
||||
runtime?: "MlxRing" | "MlxJaccl";
|
||||
onLaunch?: () => void;
|
||||
tags?: string[];
|
||||
apiPreview?: PlacementPreview | null;
|
||||
@@ -348,7 +348,7 @@
|
||||
// Debug mode state
|
||||
const isDebugMode = $derived(debugMode());
|
||||
const topology = $derived(topologyData());
|
||||
const isRdma = $derived(runtime === "MlxIbv" || runtime === "MlxJaccl");
|
||||
const isRdma = $derived(runtime === "MlxJaccl");
|
||||
|
||||
// Get interface name for an IP from node data
|
||||
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
|
||||
@@ -575,7 +575,7 @@
|
||||
>
|
||||
{runtime === "MlxRing"
|
||||
? "MLX Ring"
|
||||
: runtime === "MlxIbv" || runtime === "MlxJaccl"
|
||||
: runtime === "MlxJaccl"
|
||||
? "MLX RDMA"
|
||||
: runtime}
|
||||
</span>
|
||||
|
||||
@@ -168,7 +168,7 @@ export interface ModelDownloadStatus {
|
||||
export interface PlacementPreview {
|
||||
model_id: string;
|
||||
sharding: "Pipeline" | "Tensor";
|
||||
instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
|
||||
instance_meta: "MlxRing" | "MlxJaccl";
|
||||
instance: unknown | null;
|
||||
memory_delta_by_node: Record<string, number> | null;
|
||||
error: string | null;
|
||||
@@ -219,7 +219,6 @@ interface RawStateResponse {
|
||||
string,
|
||||
{
|
||||
MlxRingInstance?: Instance;
|
||||
MlxIbvInstance?: Instance;
|
||||
MlxJacclInstance?: Instance;
|
||||
}
|
||||
>;
|
||||
@@ -250,6 +249,20 @@ interface RawStateResponse {
|
||||
>;
|
||||
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
|
||||
thunderboltBridgeCycles?: string[][];
|
||||
// MetaInstances (declarative instance constraints)
|
||||
metaInstances?: Record<string, MetaInstanceData>;
|
||||
}
|
||||
|
||||
export interface MetaInstanceData {
|
||||
metaInstanceId: string;
|
||||
modelId: string;
|
||||
sharding: string;
|
||||
instanceMeta: string;
|
||||
minNodes: number;
|
||||
nodeIds: string[] | null;
|
||||
placementError: string | null;
|
||||
consecutiveFailures: number;
|
||||
lastFailureError: string | null;
|
||||
}
|
||||
|
||||
export interface MessageAttachment {
|
||||
@@ -537,6 +550,7 @@ class AppStore {
|
||||
previewNodeFilter = $state<Set<string>>(new Set());
|
||||
lastUpdate = $state<number | null>(null);
|
||||
nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
|
||||
metaInstances = $state<Record<string, MetaInstanceData>>({});
|
||||
thunderboltBridgeCycles = $state<string[][]>([]);
|
||||
nodeThunderbolt = $state<
|
||||
Record<
|
||||
@@ -895,11 +909,7 @@ class AppStore {
|
||||
|
||||
let instanceType: string | null = null;
|
||||
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
|
||||
else if (
|
||||
instanceTag === "MlxIbvInstance" ||
|
||||
instanceTag === "MlxJacclInstance"
|
||||
)
|
||||
instanceType = "MLX RDMA";
|
||||
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
|
||||
|
||||
let sharding: string | null = null;
|
||||
const inst = instance as {
|
||||
@@ -1273,6 +1283,8 @@ class AppStore {
|
||||
this.nodeThunderbolt = data.nodeThunderbolt ?? {};
|
||||
// RDMA ctl status per node
|
||||
this.nodeRdmaCtl = data.nodeRdmaCtl ?? {};
|
||||
// MetaInstances
|
||||
this.metaInstances = data.metaInstances ?? {};
|
||||
// Thunderbolt bridge cycles
|
||||
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
|
||||
// Thunderbolt bridge status per node
|
||||
@@ -3044,6 +3056,7 @@ export const tps = () => appStore.tps;
|
||||
export const totalTokens = () => appStore.totalTokens;
|
||||
export const topologyData = () => appStore.topologyData;
|
||||
export const instances = () => appStore.instances;
|
||||
export const metaInstances = () => appStore.metaInstances;
|
||||
export const runners = () => appStore.runners;
|
||||
export const downloads = () => appStore.downloads;
|
||||
export const nodeDisk = () => appStore.nodeDisk;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -115,7 +115,7 @@
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
|
||||
let
|
||||
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
|
||||
uvLockMlxVersion = mlxPackage.version;
|
||||
in
|
||||
{
|
||||
|
||||
10
nix/mlx.nix
10
nix/mlx.nix
@@ -41,16 +41,16 @@ let
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.6"; in
|
||||
version = let v = "0.30.7.dev20260217+50487b41"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "ml-explore";
|
||||
repo = "mlx";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
|
||||
owner = "rltakashige";
|
||||
repo = "mlx-jaccl-fix-small-recv";
|
||||
rev = "50487b4141f3c951122655db3b83df5146c1fbeb";
|
||||
hash = "sha256-IL4a9vMX5nocgJU1WG4zE8hArHkHJtnh4sdYh3od5zU=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
|
||||
@@ -17,7 +17,7 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.6; sys_platform == 'darwin'",
|
||||
"mlx; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.6",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
@@ -64,6 +64,7 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
|
||||
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
|
||||
@@ -58,6 +58,21 @@
|
||||
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
|
||||
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
|
||||
mlx = ignoreMissing prev.mlx;
|
||||
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
|
||||
buildInputs = (old.buildInputs or [ ]) ++ [
|
||||
final.nvidia-cublas
|
||||
final.nvidia-cuda-nvrtc
|
||||
final.nvidia-cudnn-cu13
|
||||
final.nvidia-nccl-cu13
|
||||
];
|
||||
preFixup = ''
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cublas}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
|
||||
'';
|
||||
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
|
||||
});
|
||||
torch = ignoreMissing prev.torch;
|
||||
triton = ignoreMissing prev.triton;
|
||||
}
|
||||
@@ -74,14 +89,25 @@
|
||||
linuxOverlay
|
||||
]
|
||||
);
|
||||
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
|
||||
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
|
||||
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
|
||||
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
|
||||
"lib/python3.13/site-packages/mlx*"
|
||||
"lib/python3.13/site-packages/nvidia*"
|
||||
];
|
||||
|
||||
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
|
||||
venvIgnoreCollisions = venvCollisionPaths;
|
||||
};
|
||||
|
||||
# Virtual environment with dev dependencies for testing
|
||||
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
workspace.deps.default // {
|
||||
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
|
||||
}
|
||||
);
|
||||
)).overrideAttrs {
|
||||
venvIgnoreCollisions = venvCollisionPaths;
|
||||
};
|
||||
|
||||
mkPythonScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
|
||||
@@ -314,7 +314,17 @@ class DownloadCoordinator:
|
||||
),
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
if (
|
||||
progress.downloaded_bytes.in_bytes
|
||||
>= progress.total_bytes.in_bytes
|
||||
> 0
|
||||
):
|
||||
status = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
elif progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
|
||||
@@ -254,7 +254,7 @@ def main():
|
||||
target = min(max(soft, 65535), hard)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
mp.set_start_method("spawn", force=True)
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
logger.info("Starting EXO")
|
||||
|
||||
@@ -71,8 +71,13 @@ from exo.shared.types.api import (
|
||||
ChatCompletionResponse,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
CreateMetaInstanceParams,
|
||||
CreateMetaInstanceResponse,
|
||||
DeleteDownloadResponse,
|
||||
DeleteInstanceResponse,
|
||||
DeleteMetaInstanceResponse,
|
||||
DistributeModelParams,
|
||||
DistributeModelResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
@@ -115,8 +120,11 @@ from exo.shared.types.claude_api import (
|
||||
from exo.shared.types.commands import (
|
||||
Command,
|
||||
CreateInstance,
|
||||
CreateMetaInstance,
|
||||
DeleteDownload,
|
||||
DeleteInstance,
|
||||
DeleteMetaInstance,
|
||||
DistributeModel,
|
||||
DownloadCommand,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
@@ -129,7 +137,7 @@ from exo.shared.types.commands import (
|
||||
TaskFinished,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -138,11 +146,13 @@ from exo.shared.types.events import (
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
@@ -276,6 +286,9 @@ class API:
|
||||
self.app.get("/instance/previews")(self.get_placement_previews)
|
||||
self.app.get("/instance/{instance_id}")(self.get_instance)
|
||||
self.app.delete("/instance/{instance_id}")(self.delete_instance)
|
||||
self.app.get("/meta_instances")(self.list_meta_instances)
|
||||
self.app.post("/meta_instance")(self.create_meta_instance)
|
||||
self.app.delete("/meta_instance/{meta_instance_id}")(self.delete_meta_instance)
|
||||
self.app.get("/models")(self.get_models)
|
||||
self.app.get("/v1/models")(self.get_models)
|
||||
self.app.post("/models/add")(self.add_custom_model)
|
||||
@@ -299,18 +312,34 @@ class API:
|
||||
self.app.get("/events")(self.stream_events)
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
|
||||
self.app.post("/v1/models/{model_id:path}/distribute")(self.distribute_model)
|
||||
self.app.get("/v1/traces")(self.list_traces)
|
||||
self.app.get("/v1/traces/{task_id}")(self.get_trace)
|
||||
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
|
||||
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
model_card = await ModelCard.load(payload.model_id)
|
||||
command = PlaceInstance(
|
||||
model_card=await ModelCard.load(payload.model_id),
|
||||
model_card=model_card,
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
)
|
||||
|
||||
# Validate placement before sending — fail fast with a clear error
|
||||
# instead of silently dropping the command in the master.
|
||||
try:
|
||||
get_instance_placements(
|
||||
command,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
await self._send(command)
|
||||
|
||||
return CreateInstanceResponse(
|
||||
@@ -522,6 +551,44 @@ class API:
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
def list_meta_instances(self) -> dict[MetaInstanceId, MetaInstance]:
|
||||
return dict(self.state.meta_instances)
|
||||
|
||||
async def create_meta_instance(
|
||||
self, payload: CreateMetaInstanceParams
|
||||
) -> CreateMetaInstanceResponse:
|
||||
meta_instance = MetaInstance(
|
||||
model_id=payload.model_id,
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
node_ids=payload.node_ids,
|
||||
)
|
||||
command = CreateMetaInstance(meta_instance=meta_instance)
|
||||
await self._send(command)
|
||||
return CreateMetaInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
meta_instance_id=meta_instance.meta_instance_id,
|
||||
)
|
||||
|
||||
async def delete_meta_instance(
|
||||
self, meta_instance_id: MetaInstanceId
|
||||
) -> DeleteMetaInstanceResponse:
|
||||
meta = self.state.meta_instances.get(meta_instance_id)
|
||||
if not meta:
|
||||
raise HTTPException(status_code=404, detail="MetaInstance not found")
|
||||
|
||||
# Command processor handles cascade-deleting backing instances
|
||||
command = DeleteMetaInstance(meta_instance_id=meta_instance_id)
|
||||
await self._send(command)
|
||||
|
||||
return DeleteMetaInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
meta_instance_id=meta_instance_id,
|
||||
)
|
||||
|
||||
async def _token_chunk_stream(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
|
||||
@@ -541,10 +608,10 @@ class API:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
cancel_command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
ForwarderCommand(origin=self.node_id, command=cancel_command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
@@ -884,10 +951,10 @@ class API:
|
||||
del image_metadata[key]
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
cancel_command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
ForwarderCommand(origin=self.node_id, command=cancel_command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
@@ -970,10 +1037,10 @@ class API:
|
||||
|
||||
return (images, stats if capture_stats else None)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
cancel_command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
ForwarderCommand(origin=self.node_id, command=cancel_command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
@@ -1495,6 +1562,57 @@ class API:
|
||||
await self._send_download(command)
|
||||
return DeleteDownloadResponse(command_id=command.command_id)
|
||||
|
||||
async def distribute_model(
|
||||
self, model_id: ModelId, payload: DistributeModelParams
|
||||
) -> DistributeModelResponse:
|
||||
"""Distribute model files from one node to others via MLX distributed."""
|
||||
# Find a source node that has the model downloaded
|
||||
source_node_id: NodeId | None = None
|
||||
for nid, downloads in self.state.downloads.items():
|
||||
for dp in downloads:
|
||||
if (
|
||||
isinstance(dp, DownloadCompleted)
|
||||
and dp.shard_metadata.model_card.model_id == model_id
|
||||
):
|
||||
source_node_id = nid
|
||||
break
|
||||
if source_node_id is not None:
|
||||
break
|
||||
|
||||
if source_node_id is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No node has model {model_id} downloaded",
|
||||
)
|
||||
|
||||
# Determine target nodes
|
||||
if payload.target_node_ids is not None:
|
||||
target_node_ids = [
|
||||
nid for nid in payload.target_node_ids if nid != source_node_id
|
||||
]
|
||||
else:
|
||||
target_node_ids = [
|
||||
nid for nid in self.state.topology.list_nodes() if nid != source_node_id
|
||||
]
|
||||
|
||||
if not target_node_ids:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No target nodes to distribute to",
|
||||
)
|
||||
|
||||
command = DistributeModel(
|
||||
model_id=model_id,
|
||||
source_node_id=source_node_id,
|
||||
target_node_ids=target_node_ids,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
return DistributeModelResponse(
|
||||
command_id=command.command_id,
|
||||
message=f"Distributing {model_id} from {source_node_id} to {len(target_node_ids)} node(s)",
|
||||
)
|
||||
|
||||
def _get_trace_path(self, task_id: str) -> Path:
|
||||
return EXO_TRACING_CACHE_DIR / f"trace_{task_id}.json"
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -12,11 +13,23 @@ from exo.master.placement import (
|
||||
get_transition_events,
|
||||
place_instance,
|
||||
)
|
||||
from exo.master.process_managers import ProcessManager
|
||||
from exo.master.process_managers.instance_health import InstanceHealthReconciler
|
||||
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
|
||||
from exo.master.process_managers.node_timeout import NodeTimeoutReconciler
|
||||
from exo.master.reconcile import (
|
||||
find_unsatisfied_meta_instances,
|
||||
try_place_for_meta_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
CreateMetaInstance,
|
||||
DeleteInstance,
|
||||
DeleteMetaInstance,
|
||||
DistributeModel,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
@@ -36,8 +49,12 @@ from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
InstanceDeleted,
|
||||
JacclSideChannelData,
|
||||
JacclSideChannelGathered,
|
||||
MetaInstanceCreated,
|
||||
MetaInstanceDeleted,
|
||||
MetaInstancePlacementFailed,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStatusUpdated,
|
||||
@@ -60,7 +77,8 @@ from exo.shared.types.tasks import (
|
||||
TextGeneration as TextGenerationTask,
|
||||
)
|
||||
from exo.shared.types.worker.instances import InstanceId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
from exo.utils.channels import Receiver, Sender
|
||||
from exo.utils.event_buffer import MultiSourceBuffer
|
||||
|
||||
|
||||
@@ -84,16 +102,16 @@ class Master:
|
||||
self.local_event_receiver = local_event_receiver
|
||||
self.global_event_sender = global_event_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
send, recv = channel[Event]()
|
||||
self.event_sender: Sender[Event] = send
|
||||
self._loopback_event_receiver: Receiver[Event] = recv
|
||||
self._loopback_event_sender: Sender[ForwarderEvent] = (
|
||||
local_event_receiver.clone_sender()
|
||||
)
|
||||
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
|
||||
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
|
||||
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
||||
self._expected_ranks: dict[TaskId, set[int]] = {}
|
||||
self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {}
|
||||
self._process_managers: Sequence[ProcessManager] = [
|
||||
InstanceHealthReconciler(),
|
||||
NodeTimeoutReconciler(),
|
||||
MetaInstanceReconciler(),
|
||||
]
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
@@ -102,15 +120,12 @@ class Master:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
tg.start_soon(self._reconcile)
|
||||
finally:
|
||||
self._event_log.close()
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Stopping Master")
|
||||
@@ -292,6 +307,86 @@ class Master:
|
||||
)
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case CreateMetaInstance():
|
||||
logger.info(
|
||||
f"Creating MetaInstance for {command.meta_instance.model_id}"
|
||||
f" (min_nodes={command.meta_instance.min_nodes},"
|
||||
f" sharding={command.meta_instance.sharding})"
|
||||
)
|
||||
# Apply immediately so self.state is fresh across
|
||||
# the await below and the reconciler won't race.
|
||||
await self._apply_and_broadcast(
|
||||
MetaInstanceCreated(meta_instance=command.meta_instance)
|
||||
)
|
||||
# Immediate placement attempt for responsiveness
|
||||
model_card = await ModelCard.load(
|
||||
command.meta_instance.model_id
|
||||
)
|
||||
# Re-check: reconciler may have satisfied it during the await
|
||||
meta_id = command.meta_instance.meta_instance_id
|
||||
still_unsatisfied = any(
|
||||
m.meta_instance_id == meta_id
|
||||
for m in find_unsatisfied_meta_instances(
|
||||
self.state.meta_instances,
|
||||
self.state.instances,
|
||||
self.state.topology,
|
||||
)
|
||||
)
|
||||
if still_unsatisfied:
|
||||
result = try_place_for_meta_instance(
|
||||
command.meta_instance,
|
||||
model_card,
|
||||
self.state.topology,
|
||||
self.state.instances,
|
||||
self.state.node_memory,
|
||||
self.state.node_network,
|
||||
self.state.tasks,
|
||||
)
|
||||
generated_events.extend(result.events)
|
||||
if result.error is not None:
|
||||
generated_events.append(
|
||||
MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta_id,
|
||||
reason=result.error,
|
||||
)
|
||||
)
|
||||
case DeleteMetaInstance():
|
||||
backing_count = sum(
|
||||
1
|
||||
for inst in self.state.instances.values()
|
||||
if inst.meta_instance_id == command.meta_instance_id
|
||||
)
|
||||
logger.info(
|
||||
f"Deleting MetaInstance {command.meta_instance_id}"
|
||||
f" (cascade-deleting {backing_count} backing instance(s))"
|
||||
)
|
||||
generated_events.append(
|
||||
MetaInstanceDeleted(
|
||||
meta_instance_id=command.meta_instance_id
|
||||
)
|
||||
)
|
||||
# Cascade-delete backing instances atomically,
|
||||
# cancelling any active tasks first.
|
||||
for iid, inst in self.state.instances.items():
|
||||
if inst.meta_instance_id == command.meta_instance_id:
|
||||
for task in self.state.tasks.values():
|
||||
if (
|
||||
task.instance_id == iid
|
||||
and task.task_status
|
||||
in (
|
||||
TaskStatus.Pending,
|
||||
TaskStatus.Running,
|
||||
)
|
||||
):
|
||||
generated_events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
)
|
||||
generated_events.append(
|
||||
InstanceDeleted(instance_id=iid)
|
||||
)
|
||||
case PlaceInstance():
|
||||
placement = place_instance(
|
||||
command,
|
||||
@@ -314,6 +409,36 @@ class Master:
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case DistributeModel():
|
||||
from exo.shared.types.worker.instances import InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
model_card = await ModelCard.load(command.model_id)
|
||||
all_node_ids = set(
|
||||
[command.source_node_id] + list(command.target_node_ids)
|
||||
)
|
||||
place_command = PlaceInstance(
|
||||
model_card=model_card,
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
min_nodes=len(all_node_ids),
|
||||
)
|
||||
placement = place_instance(
|
||||
place_command,
|
||||
self.state.topology,
|
||||
self.state.instances,
|
||||
self.state.node_memory,
|
||||
self.state.node_network,
|
||||
required_nodes=all_node_ids,
|
||||
)
|
||||
# Mark new instances as transfer-only
|
||||
for instance_id, instance in placement.items():
|
||||
if instance_id not in self.state.instances:
|
||||
instance.shard_assignments.transfer_only = True
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
generated_events.append(
|
||||
InputChunkReceived(
|
||||
@@ -323,16 +448,19 @@ class Master:
|
||||
)
|
||||
case TaskCancelled():
|
||||
if (
|
||||
task_id := self.command_task_mapping.get(
|
||||
command.cancelled_command_id
|
||||
)
|
||||
) is not None:
|
||||
command.cancelled_command_id
|
||||
in self.command_task_mapping
|
||||
):
|
||||
generated_events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task_id,
|
||||
TaskDeleted(
|
||||
task_id=self.command_task_mapping[
|
||||
command.cancelled_command_id
|
||||
]
|
||||
)
|
||||
)
|
||||
del self.command_task_mapping[
|
||||
command.cancelled_command_id
|
||||
]
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
@@ -341,9 +469,10 @@ class Master:
|
||||
]
|
||||
)
|
||||
)
|
||||
self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
if command.finished_command_id in self.command_task_mapping:
|
||||
del self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
# rate limit to 1000 at a time
|
||||
@@ -354,31 +483,32 @@ class Master:
|
||||
):
|
||||
await self._send_event(IndexedEvent(idx=i, event=event))
|
||||
for event in generated_events:
|
||||
await self.event_sender.send(event)
|
||||
await self._apply_and_broadcast(event)
|
||||
except ValueError as e:
|
||||
logger.opt(exception=e).warning("Error in command processor")
|
||||
|
||||
# These plan loops are the cracks showing in our event sourcing architecture - more things could be commands
|
||||
async def _plan(self) -> None:
|
||||
async def _apply_and_broadcast(self, event: Event) -> None:
|
||||
"""Apply event to state, persist to disk, and broadcast to workers.
|
||||
|
||||
State is updated synchronously (before any await), so callers can
|
||||
rely on ``self.state`` reflecting this event immediately after the
|
||||
call. Python's cooperative scheduling guarantees no interleaving
|
||||
between the state read and write.
|
||||
"""
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
|
||||
self._event_log.append(event)
|
||||
await self._send_event(indexed)
|
||||
|
||||
async def _reconcile(self) -> None:
|
||||
while True:
|
||||
# kill broken instances
|
||||
connected_node_ids = set(self.state.topology.list_nodes())
|
||||
for instance_id, instance in self.state.instances.items():
|
||||
for node_id in instance.shard_assignments.node_to_runner:
|
||||
if node_id not in connected_node_ids:
|
||||
await self.event_sender.send(
|
||||
InstanceDeleted(instance_id=instance_id)
|
||||
)
|
||||
break
|
||||
|
||||
# time out dead nodes
|
||||
for node_id, time in self.state.last_seen.items():
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
if now - time > timedelta(seconds=30):
|
||||
logger.info(f"Manually removing node {node_id} due to inactivity")
|
||||
await self.event_sender.send(NodeTimedOut(node_id=node_id))
|
||||
|
||||
await anyio.sleep(10)
|
||||
for pm in self._process_managers:
|
||||
events = await pm.reconcile(self.state)
|
||||
for event in events:
|
||||
await self._apply_and_broadcast(event)
|
||||
await anyio.sleep(1)
|
||||
|
||||
async def _event_processor(self) -> None:
|
||||
with self.local_event_receiver as local_events:
|
||||
@@ -396,32 +526,15 @@ class Master:
|
||||
await self._handle_traces_collected(event)
|
||||
continue
|
||||
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
if isinstance(event, JacclSideChannelData):
|
||||
await self._apply_and_broadcast(event)
|
||||
await self._handle_jaccl_side_channel(event)
|
||||
continue
|
||||
|
||||
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
|
||||
if isinstance(event, NodeGatheredInfo):
|
||||
event.when = str(datetime.now(tz=timezone.utc))
|
||||
|
||||
self._event_log.append(event)
|
||||
await self._send_event(indexed)
|
||||
|
||||
async def _loopback_processor(self) -> None:
|
||||
# this would ideally not be necessary.
|
||||
# this is WAY less hacky than how I was working around this before
|
||||
local_index = 0
|
||||
with self._loopback_event_receiver as events:
|
||||
async for event in events:
|
||||
await self._loopback_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin=NodeId(f"master_{self.node_id}"),
|
||||
origin_idx=local_index,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
local_index += 1
|
||||
await self._apply_and_broadcast(event)
|
||||
|
||||
# This function is re-entrant, take care!
|
||||
async def _send_event(self, event: IndexedEvent):
|
||||
@@ -453,10 +566,49 @@ class Master:
|
||||
for trace_data in self._pending_traces[task_id].values():
|
||||
all_trace_data.extend(trace_data)
|
||||
|
||||
await self.event_sender.send(
|
||||
await self._apply_and_broadcast(
|
||||
TracesMerged(task_id=task_id, traces=all_trace_data)
|
||||
)
|
||||
|
||||
del self._pending_traces[task_id]
|
||||
if task_id in self._expected_ranks:
|
||||
del self._expected_ranks[task_id]
|
||||
|
||||
async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None:
|
||||
"""Accumulate SideChannel contributions; when all runners for an instance
|
||||
have submitted for the same sequence, emit JacclSideChannelGathered."""
|
||||
iid = event.instance_id
|
||||
seq = event.sequence
|
||||
|
||||
if iid not in self._jaccl_pending:
|
||||
self._jaccl_pending[iid] = {}
|
||||
if seq not in self._jaccl_pending[iid]:
|
||||
self._jaccl_pending[iid][seq] = {}
|
||||
self._jaccl_pending[iid][seq][event.runner_id] = event.data
|
||||
|
||||
instance = self.state.instances.get(iid)
|
||||
if instance is None:
|
||||
logger.warning(f"JacclSideChannelData for unknown instance {iid}")
|
||||
return
|
||||
|
||||
expected_runners = set(instance.shard_assignments.runner_to_shard.keys())
|
||||
submitted = set(self._jaccl_pending[iid][seq].keys())
|
||||
|
||||
logger.info(
|
||||
f"JACCL side channel: instance={iid} seq={seq} "
|
||||
f"submitted={len(submitted)}/{len(expected_runners)}"
|
||||
)
|
||||
|
||||
if submitted >= expected_runners:
|
||||
gathered = dict(self._jaccl_pending[iid][seq])
|
||||
del self._jaccl_pending[iid][seq]
|
||||
if not self._jaccl_pending[iid]:
|
||||
del self._jaccl_pending[iid]
|
||||
|
||||
await self._apply_and_broadcast(
|
||||
JacclSideChannelGathered(
|
||||
instance_id=iid,
|
||||
sequence=seq,
|
||||
gathered_data=gathered,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -6,11 +6,11 @@ from typing import Sequence
|
||||
from exo.master.placement_utils import (
|
||||
Cycle,
|
||||
filter_cycles_by_memory,
|
||||
get_largest_cycles,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_mlx_jaccl_devices_matrix,
|
||||
get_mlx_ring_hosts_by_node,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.topology import Topology
|
||||
@@ -106,23 +106,27 @@ def place_instance(
|
||||
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
|
||||
)
|
||||
|
||||
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
||||
largest_cycles = get_largest_cycles(cycles_with_sufficient_memory)
|
||||
|
||||
smallest_rdma_cycles = [
|
||||
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
|
||||
largest_rdma_cycles = [
|
||||
cycle for cycle in largest_cycles if topology.is_rdma_cycle(cycle)
|
||||
]
|
||||
|
||||
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
|
||||
smallest_cycles = smallest_rdma_cycles
|
||||
if command.instance_meta == InstanceMeta.MlxJaccl:
|
||||
if not largest_rdma_cycles:
|
||||
raise ValueError(
|
||||
"Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
|
||||
)
|
||||
largest_cycles = largest_rdma_cycles
|
||||
|
||||
cycles_with_leaf_nodes: list[Cycle] = [
|
||||
cycle
|
||||
for cycle in smallest_cycles
|
||||
for cycle in largest_cycles
|
||||
if any(topology.node_is_leaf(node_id) for node_id in cycle)
|
||||
]
|
||||
|
||||
selected_cycle = max(
|
||||
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
|
||||
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else largest_cycles,
|
||||
key=lambda cycle: sum(
|
||||
(node_memory[node_id].ram_available for node_id in cycle),
|
||||
start=Memory(),
|
||||
|
||||
@@ -37,11 +37,11 @@ def filter_cycles_by_memory(
|
||||
return filtered_cycles
|
||||
|
||||
|
||||
def get_smallest_cycles(
|
||||
def get_largest_cycles(
|
||||
cycles: list[Cycle],
|
||||
) -> list[Cycle]:
|
||||
min_nodes = min(len(cycle) for cycle in cycles)
|
||||
return [cycle for cycle in cycles if len(cycle) == min_nodes]
|
||||
max_nodes = max(len(cycle) for cycle in cycles)
|
||||
return [cycle for cycle in cycles if len(cycle) == max_nodes]
|
||||
|
||||
|
||||
def allocate_layers_proportionally(
|
||||
|
||||
12
src/exo/master/process_managers/__init__.py
Normal file
12
src/exo/master/process_managers/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from exo.shared.types.events import Event
|
||||
from exo.shared.types.state import State
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ProcessManager(Protocol):
|
||||
"""A reconciliation step that examines state and returns corrective events."""
|
||||
|
||||
async def reconcile(self, state: State) -> Sequence[Event]: ...
|
||||
62
src/exo/master/process_managers/instance_health.py
Normal file
62
src/exo/master/process_managers/instance_health.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.reconcile import instance_connections_healthy, instance_runners_failed
|
||||
from exo.shared.types.events import Event, InstanceDeleted, InstanceRetrying
|
||||
from exo.shared.types.state import State
|
||||
|
||||
MAX_INSTANCE_RETRIES = 3
|
||||
|
||||
|
||||
@final
|
||||
class InstanceHealthReconciler:
|
||||
"""Delete instances whose network connections are broken or whose runners have all failed."""
|
||||
|
||||
async def reconcile(self, state: State) -> Sequence[Event]:
|
||||
events: list[Event] = []
|
||||
for instance_id, instance in state.instances.items():
|
||||
if not instance_connections_healthy(instance, state.topology):
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
failure_error="Network connection lost",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
is_failed, error_message = instance_runners_failed(
|
||||
instance, state.runners, state.node_identities
|
||||
)
|
||||
if is_failed:
|
||||
# Retry within the same instance if backed by a MetaInstance
|
||||
mid = instance.meta_instance_id
|
||||
mi = state.meta_instances.get(mid) if mid else None
|
||||
if mid and mi and mi.consecutive_failures < MAX_INSTANCE_RETRIES:
|
||||
logger.info(
|
||||
f"Instance {instance_id} failed (attempt"
|
||||
f" {mi.consecutive_failures + 1}/{MAX_INSTANCE_RETRIES}),"
|
||||
f" retrying: {error_message}"
|
||||
)
|
||||
events.append(
|
||||
InstanceRetrying(
|
||||
instance_id=instance_id,
|
||||
meta_instance_id=mid,
|
||||
failure_error=error_message or "Runner failed",
|
||||
)
|
||||
)
|
||||
else:
|
||||
if mid and mi:
|
||||
logger.warning(
|
||||
f"Instance {instance_id} exceeded retry limit"
|
||||
f" ({MAX_INSTANCE_RETRIES}), deleting:"
|
||||
f" {error_message}"
|
||||
)
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
failure_error=error_message,
|
||||
)
|
||||
)
|
||||
return events
|
||||
92
src/exo/master/process_managers/meta_instance.py
Normal file
92
src/exo/master/process_managers/meta_instance.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
import anyio
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.reconcile import (
|
||||
find_unsatisfied_meta_instances,
|
||||
try_place_for_meta_instance,
|
||||
)
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.types.events import Event, InstanceCreated, MetaInstancePlacementFailed
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
|
||||
MODEL_CARD_LOAD_TIMEOUT_SECONDS = 10
|
||||
|
||||
|
||||
@final
|
||||
class MetaInstanceReconciler:
|
||||
"""Place instances for unsatisfied MetaInstances."""
|
||||
|
||||
async def reconcile(self, state: State) -> Sequence[Event]:
|
||||
all_events: list[Event] = []
|
||||
# Local copy for intermediate tracking — so placement of B
|
||||
# sees A's instance and doesn't double-place on same resources.
|
||||
current_instances: dict[InstanceId, Instance] = dict(state.instances)
|
||||
|
||||
unsatisfied = find_unsatisfied_meta_instances(
|
||||
state.meta_instances,
|
||||
current_instances,
|
||||
state.topology,
|
||||
)
|
||||
for meta_instance in unsatisfied:
|
||||
try:
|
||||
with anyio.fail_after(MODEL_CARD_LOAD_TIMEOUT_SECONDS):
|
||||
model_card = await ModelCard.load(meta_instance.model_id)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
f"ModelCard.load timed out for {meta_instance.model_id}, skipping this cycle"
|
||||
)
|
||||
continue
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"ModelCard.load failed for {meta_instance.model_id}: {exc}"
|
||||
)
|
||||
error = f"Failed to load model card: {exc}"
|
||||
if meta_instance.placement_error != error:
|
||||
all_events.append(
|
||||
MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta_instance.meta_instance_id,
|
||||
reason=error,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
result = try_place_for_meta_instance(
|
||||
meta_instance,
|
||||
model_card,
|
||||
state.topology,
|
||||
current_instances,
|
||||
state.node_memory,
|
||||
state.node_network,
|
||||
state.tasks,
|
||||
)
|
||||
# Update local instance map so next placement sees this one
|
||||
for event in result.events:
|
||||
if isinstance(event, InstanceCreated):
|
||||
logger.info(
|
||||
f"MetaInstance reconciler placed instance"
|
||||
f" {event.instance.instance_id} for"
|
||||
f" {meta_instance.model_id}"
|
||||
)
|
||||
current_instances[event.instance.instance_id] = event.instance
|
||||
all_events.extend(result.events)
|
||||
|
||||
# Emit placement failure if error differs from what's already in state
|
||||
if (
|
||||
result.error is not None
|
||||
and meta_instance.placement_error != result.error
|
||||
):
|
||||
logger.warning(
|
||||
f"MetaInstance placement failed for"
|
||||
f" {meta_instance.model_id}: {result.error}"
|
||||
)
|
||||
all_events.append(
|
||||
MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta_instance.meta_instance_id,
|
||||
reason=result.error,
|
||||
)
|
||||
)
|
||||
return all_events
|
||||
27
src/exo/master/process_managers/node_timeout.py
Normal file
27
src/exo/master/process_managers/node_timeout.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import final
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.events import Event, NodeTimedOut
|
||||
from exo.shared.types.state import State
|
||||
|
||||
_DEFAULT_TIMEOUT = timedelta(seconds=30)
|
||||
|
||||
|
||||
@final
|
||||
class NodeTimeoutReconciler:
|
||||
"""Time out nodes that haven't been seen recently."""
|
||||
|
||||
def __init__(self, timeout: timedelta = _DEFAULT_TIMEOUT) -> None:
|
||||
self.timeout = timeout
|
||||
|
||||
async def reconcile(self, state: State) -> Sequence[Event]:
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
events: list[Event] = []
|
||||
for node_id, last_seen in state.last_seen.items():
|
||||
if now - last_seen > self.timeout:
|
||||
logger.info(f"Removing node {node_id} due to inactivity")
|
||||
events.append(NodeTimedOut(node_id=node_id))
|
||||
return events
|
||||
244
src/exo/master/reconcile.py
Normal file
244
src/exo/master/reconcile.py
Normal file
@@ -0,0 +1,244 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import NamedTuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.placement import get_transition_events, place_instance
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import PlaceInstance
|
||||
from exo.shared.types.common import MetaInstanceId, NodeId
|
||||
from exo.shared.types.events import Event
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeIdentity, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.topology import RDMAConnection, SocketConnection
|
||||
from exo.shared.types.worker.instances import (
|
||||
BaseInstance,
|
||||
Instance,
|
||||
InstanceId,
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerFailed,
|
||||
RunnerId,
|
||||
RunnerShutdown,
|
||||
RunnerStatus,
|
||||
)
|
||||
|
||||
|
||||
class PlacementResult(NamedTuple):
|
||||
"""Result of a placement attempt: events to apply and optional error reason."""
|
||||
|
||||
events: Sequence[Event]
|
||||
error: str | None
|
||||
|
||||
|
||||
def _get_ring_order(instance: BaseInstance) -> list[NodeId]:
|
||||
"""Reconstruct ring order from shard device_rank."""
|
||||
node_ranks: list[tuple[NodeId, int]] = []
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
|
||||
shard = instance.shard_assignments.runner_to_shard[runner_id]
|
||||
node_ranks.append((node_id, shard.device_rank))
|
||||
node_ranks.sort(key=lambda x: x[1])
|
||||
return [node_id for node_id, _ in node_ranks]
|
||||
|
||||
|
||||
def _ring_connections_healthy(instance: MlxRingInstance, topology: Topology) -> bool:
|
||||
"""Check that the specific IPs used by a ring instance still exist in the topology."""
|
||||
ring = _get_ring_order(instance)
|
||||
n = len(ring)
|
||||
for node in ring:
|
||||
hosts = instance.hosts_by_node[node]
|
||||
for idx in range(n):
|
||||
host = hosts[idx]
|
||||
if host.ip in ("0.0.0.0", "198.51.100.1"):
|
||||
continue # self or placeholder
|
||||
# Real connection: node → ring[idx]. Check specific IP.
|
||||
connections = topology.get_all_connections_between(node, ring[idx])
|
||||
if not any(
|
||||
isinstance(c, SocketConnection)
|
||||
and c.sink_multiaddr.ip_address == host.ip
|
||||
for c in connections
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _jaccl_connections_healthy(instance: MlxJacclInstance, topology: Topology) -> bool:
|
||||
"""Check that the specific RDMA interfaces used by a JACCL instance still exist."""
|
||||
ring = _get_ring_order(instance)
|
||||
n = len(ring)
|
||||
for i in range(n):
|
||||
for j in range(n):
|
||||
iface = instance.jaccl_devices[i][j]
|
||||
if iface is None:
|
||||
continue
|
||||
connections = topology.get_all_connections_between(ring[i], ring[j])
|
||||
if not any(
|
||||
isinstance(c, RDMAConnection) and c.source_rdma_iface == iface
|
||||
for c in connections
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def instance_connections_healthy(instance: Instance, topology: Topology) -> bool:
|
||||
"""Check that an instance's nodes and specific connections are still in the topology."""
|
||||
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
|
||||
if not all(topology.contains_node(n) for n in instance_nodes):
|
||||
return False
|
||||
if len(instance_nodes) <= 1:
|
||||
return True
|
||||
match instance:
|
||||
case MlxRingInstance():
|
||||
return _ring_connections_healthy(instance, topology)
|
||||
case MlxJacclInstance():
|
||||
return _jaccl_connections_healthy(instance, topology)
|
||||
|
||||
|
||||
def instance_runners_failed(
|
||||
instance: Instance,
|
||||
runners: Mapping[RunnerId, RunnerStatus],
|
||||
node_identities: Mapping[NodeId, NodeIdentity],
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Check if an instance's runners have all reached terminal failure states.
|
||||
|
||||
Returns ``(True, error_message)`` when ALL runners are terminal
|
||||
(``RunnerFailed`` or ``RunnerShutdown``) and at least one is ``RunnerFailed``.
|
||||
|
||||
Returns ``(False, None)`` when runners are still active, haven't reported
|
||||
yet, or all gracefully shut down (no ``RunnerFailed``).
|
||||
"""
|
||||
instance_runner_ids = set(instance.shard_assignments.node_to_runner.values())
|
||||
|
||||
if not instance_runner_ids:
|
||||
return False, None
|
||||
|
||||
# Build reverse mapping: runner_id -> node_id
|
||||
runner_to_node: dict[RunnerId, NodeId] = {
|
||||
runner_id: node_id
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
|
||||
}
|
||||
|
||||
has_any_failed = False
|
||||
error_messages: list[str] = []
|
||||
|
||||
for runner_id in instance_runner_ids:
|
||||
status = runners.get(runner_id)
|
||||
if status is None:
|
||||
# Runner hasn't reported yet — instance is still starting
|
||||
return False, None
|
||||
if isinstance(status, RunnerFailed):
|
||||
has_any_failed = True
|
||||
if status.error_message:
|
||||
node_id = runner_to_node.get(runner_id)
|
||||
name = (
|
||||
node_identities[node_id].friendly_name
|
||||
if node_id and node_id in node_identities
|
||||
else node_id or "unknown"
|
||||
)
|
||||
error_messages.append(f"{name}: {status.error_message}")
|
||||
elif isinstance(status, RunnerShutdown):
|
||||
pass # Terminal but not a failure indicator on its own
|
||||
else:
|
||||
# Runner is still active (connecting, loading, running, etc.)
|
||||
return False, None
|
||||
|
||||
if has_any_failed:
|
||||
return True, "; ".join(error_messages) if error_messages else "Runner failed"
|
||||
|
||||
# All runners are Shutdown but none Failed — graceful shutdown, not a failure
|
||||
return False, None
|
||||
|
||||
|
||||
def instance_satisfies_meta_instance(
|
||||
meta_instance: MetaInstance,
|
||||
instance: Instance,
|
||||
) -> bool:
|
||||
"""Check if a single instance satisfies a meta-instance's constraints.
|
||||
|
||||
This is a pure constraint check (model, min_nodes, node_ids).
|
||||
Use ``instance_connections_healthy`` separately for topology health.
|
||||
"""
|
||||
if instance.shard_assignments.model_id != meta_instance.model_id:
|
||||
return False
|
||||
|
||||
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
|
||||
|
||||
if len(instance_nodes) < meta_instance.min_nodes:
|
||||
return False
|
||||
|
||||
return meta_instance.node_ids is None or set(meta_instance.node_ids).issubset(
|
||||
instance_nodes
|
||||
)
|
||||
|
||||
|
||||
def find_unsatisfied_meta_instances(
|
||||
meta_instances: Mapping[MetaInstanceId, MetaInstance],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
topology: Topology,
|
||||
) -> Sequence[MetaInstance]:
|
||||
"""Return meta-instances that have no healthy backing instance."""
|
||||
unsatisfied: list[MetaInstance] = []
|
||||
for meta_id, meta_instance in meta_instances.items():
|
||||
has_healthy_backing = any(
|
||||
instance.meta_instance_id == meta_id
|
||||
and instance_connections_healthy(instance, topology)
|
||||
for instance in instances.values()
|
||||
)
|
||||
if not has_healthy_backing:
|
||||
unsatisfied.append(meta_instance)
|
||||
return unsatisfied
|
||||
|
||||
|
||||
def try_place_for_meta_instance(
|
||||
meta_instance: MetaInstance,
|
||||
model_card: ModelCard,
|
||||
topology: Topology,
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> PlacementResult:
|
||||
"""Try to place an instance satisfying the meta-instance constraints.
|
||||
|
||||
Returns a :class:`PlacementResult` with events on success, or an error
|
||||
reason on failure.
|
||||
"""
|
||||
command = PlaceInstance(
|
||||
model_card=model_card,
|
||||
sharding=meta_instance.sharding,
|
||||
instance_meta=meta_instance.instance_meta,
|
||||
min_nodes=meta_instance.min_nodes,
|
||||
)
|
||||
try:
|
||||
target_instances = place_instance(
|
||||
command,
|
||||
topology,
|
||||
current_instances,
|
||||
node_memory,
|
||||
node_network,
|
||||
required_nodes=(
|
||||
set(meta_instance.node_ids) if meta_instance.node_ids else None
|
||||
),
|
||||
)
|
||||
# Tag the new instance with meta_instance_id
|
||||
new_instance_ids = set(target_instances.keys()) - set(current_instances.keys())
|
||||
if new_instance_ids:
|
||||
new_id = next(iter(new_instance_ids))
|
||||
target_instances[new_id] = target_instances[new_id].model_copy(
|
||||
update={"meta_instance_id": meta_instance.meta_instance_id}
|
||||
)
|
||||
return PlacementResult(
|
||||
events=list(
|
||||
get_transition_events(current_instances, target_instances, tasks)
|
||||
),
|
||||
error=None,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.debug(
|
||||
f"MetaInstance placement not possible for {meta_instance.model_id}: {e}"
|
||||
)
|
||||
return PlacementResult(events=[], error=str(e))
|
||||
778
src/exo/master/tests/test_meta_instance_edge_cases.py
Normal file
778
src/exo/master/tests/test_meta_instance_edge_cases.py
Normal file
@@ -0,0 +1,778 @@
|
||||
"""Edge-case and regression tests for MetaInstance lifecycle, concurrent operations, and error handling."""
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.master.process_managers.instance_health import (
|
||||
MAX_INSTANCE_RETRIES,
|
||||
InstanceHealthReconciler,
|
||||
)
|
||||
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
|
||||
from exo.master.reconcile import (
|
||||
find_unsatisfied_meta_instances,
|
||||
instance_connections_healthy,
|
||||
instance_runners_failed,
|
||||
instance_satisfies_meta_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, MetaInstanceId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
InstanceRetrying,
|
||||
MetaInstanceCreated,
|
||||
MetaInstanceDeleted,
|
||||
MetaInstancePlacementFailed,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import NodeIdentity
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import LoadModel, TaskId, TaskStatus
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.instances import (
|
||||
InstanceId,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerFailed,
|
||||
RunnerId,
|
||||
RunnerReady,
|
||||
ShardAssignments,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
|
||||
# --- Helpers (copied from test_reconcile.py for independence) ---
|
||||
|
||||
|
||||
def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
|
||||
return ModelCard(
|
||||
model_id=ModelId(model_id),
|
||||
storage_size=Memory.from_kb(1000),
|
||||
n_layers=10,
|
||||
hidden_size=30,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
|
||||
|
||||
def _topology(*node_ids: str, connect: bool = True) -> Topology:
|
||||
t = Topology()
|
||||
nodes = [NodeId(n) for n in node_ids]
|
||||
for n in nodes:
|
||||
t.add_node(n)
|
||||
if connect and len(nodes) > 1:
|
||||
for i in range(len(nodes)):
|
||||
j = (i + 1) % len(nodes)
|
||||
t.add_connection(
|
||||
Connection(
|
||||
source=nodes[i],
|
||||
sink=nodes[j],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
t.add_connection(
|
||||
Connection(
|
||||
source=nodes[j],
|
||||
sink=nodes[i],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
return t
|
||||
|
||||
|
||||
def _meta_instance(
|
||||
model_id: str = "test-org/test-model",
|
||||
*,
|
||||
min_nodes: int = 1,
|
||||
node_ids: list[NodeId] | None = None,
|
||||
meta_instance_id: MetaInstanceId | None = None,
|
||||
consecutive_failures: int = 0,
|
||||
last_failure_error: str | None = None,
|
||||
placement_error: str | None = None,
|
||||
) -> MetaInstance:
|
||||
return MetaInstance(
|
||||
meta_instance_id=meta_instance_id or MetaInstanceId(),
|
||||
model_id=ModelId(model_id),
|
||||
min_nodes=min_nodes,
|
||||
node_ids=node_ids,
|
||||
consecutive_failures=consecutive_failures,
|
||||
last_failure_error=last_failure_error,
|
||||
placement_error=placement_error,
|
||||
)
|
||||
|
||||
|
||||
def _instance(
|
||||
model_id: str = "test-org/test-model",
|
||||
node_ids: list[str] | None = None,
|
||||
instance_id: InstanceId | None = None,
|
||||
meta_instance_id: MetaInstanceId | None = None,
|
||||
) -> tuple[InstanceId, MlxRingInstance]:
|
||||
iid = instance_id or InstanceId()
|
||||
nodes = node_ids or ["node-a"]
|
||||
n = len(nodes)
|
||||
mc = _model_card(model_id)
|
||||
ephemeral_port = 50000
|
||||
node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
|
||||
runner_to_shard = {
|
||||
runner_id: PipelineShardMetadata(
|
||||
model_card=mc,
|
||||
device_rank=i,
|
||||
world_size=n,
|
||||
start_layer=0,
|
||||
end_layer=mc.n_layers,
|
||||
n_layers=mc.n_layers,
|
||||
)
|
||||
for i, runner_id in enumerate(node_to_runner.values())
|
||||
}
|
||||
hosts_by_node: dict[NodeId, list[Host]] = {}
|
||||
for r, node_str in enumerate(nodes):
|
||||
hosts: list[Host] = []
|
||||
for idx in range(n):
|
||||
if idx == r:
|
||||
hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
|
||||
elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
|
||||
hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
|
||||
else:
|
||||
hosts.append(Host(ip="198.51.100.1", port=0))
|
||||
hosts_by_node[NodeId(node_str)] = hosts
|
||||
return iid, MlxRingInstance(
|
||||
instance_id=iid,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId(model_id),
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
),
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
meta_instance_id=meta_instance_id,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 1. MetaInstance lifecycle edge cases
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_meta_instance_model_is_frozen():
|
||||
"""MetaInstance should be immutable (frozen model)."""
|
||||
meta = _meta_instance()
|
||||
try:
|
||||
meta.model_id = ModelId("something-else")
|
||||
raise AssertionError("Should have raised")
|
||||
except Exception:
|
||||
pass # Expected — frozen model
|
||||
|
||||
|
||||
def test_meta_instance_created_then_deleted_roundtrip():
|
||||
"""Create and delete a MetaInstance through apply — state should be clean."""
|
||||
state = State()
|
||||
meta = _meta_instance()
|
||||
state = apply(
|
||||
state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta))
|
||||
)
|
||||
assert meta.meta_instance_id in state.meta_instances
|
||||
state = apply(
|
||||
state,
|
||||
IndexedEvent(
|
||||
idx=1, event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
|
||||
),
|
||||
)
|
||||
assert meta.meta_instance_id not in state.meta_instances
|
||||
assert len(state.meta_instances) == 0
|
||||
|
||||
|
||||
def test_delete_nonexistent_meta_instance_is_safe():
|
||||
"""Deleting a MetaInstance that doesn't exist should not crash."""
|
||||
state = State()
|
||||
event = MetaInstanceDeleted(meta_instance_id=MetaInstanceId("nonexistent"))
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert len(new_state.meta_instances) == 0
|
||||
|
||||
|
||||
def test_placement_failed_for_nonexistent_meta_instance_is_safe():
|
||||
"""MetaInstancePlacementFailed for unknown ID should not crash."""
|
||||
state = State()
|
||||
event = MetaInstancePlacementFailed(
|
||||
meta_instance_id=MetaInstanceId("nonexistent"),
|
||||
reason="test",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert len(new_state.meta_instances) == 0
|
||||
|
||||
|
||||
def test_multiple_meta_instances_for_same_model():
|
||||
"""Multiple MetaInstances for the same model are tracked independently."""
|
||||
state = State()
|
||||
meta_a = _meta_instance("test-org/model-x")
|
||||
meta_b = _meta_instance("test-org/model-x")
|
||||
state = apply(
|
||||
state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta_a))
|
||||
)
|
||||
state = apply(
|
||||
state, IndexedEvent(idx=1, event=MetaInstanceCreated(meta_instance=meta_b))
|
||||
)
|
||||
assert len(state.meta_instances) == 2
|
||||
assert meta_a.meta_instance_id in state.meta_instances
|
||||
assert meta_b.meta_instance_id in state.meta_instances
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 2. Retry logic edge cases
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_retry_counter_resets_on_successful_instance_creation():
|
||||
"""When a new instance is created for a meta-instance, failures should reset."""
|
||||
meta = _meta_instance(consecutive_failures=2, last_failure_error="old")
|
||||
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
|
||||
mi = state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
# last_failure_error is preserved (for UI display)
|
||||
assert mi.last_failure_error == "old"
|
||||
|
||||
|
||||
async def test_retry_count_increments_through_full_cycle():
|
||||
"""Walk through MAX_INSTANCE_RETRIES worth of retries, then verify delete."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
topology = _topology("node-a")
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
topology=topology,
|
||||
)
|
||||
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
for idx, i in enumerate(range(MAX_INSTANCE_RETRIES)):
|
||||
# Simulate runners failing
|
||||
state_with_runners = state.model_copy(
|
||||
update={"runners": {runner_ids[0]: RunnerFailed(error_message=f"fail-{i}")}}
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state_with_runners)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceRetrying), f"iteration {i}"
|
||||
state = apply(state, IndexedEvent(idx=idx, event=events[0]))
|
||||
|
||||
# After MAX_INSTANCE_RETRIES retries, failure counter should be at max
|
||||
mi = state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == MAX_INSTANCE_RETRIES
|
||||
|
||||
# Next failure should result in deletion
|
||||
state_with_runners = state.model_copy(
|
||||
update={"runners": {runner_ids[0]: RunnerFailed(error_message="final")}}
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state_with_runners)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_respects_exact_limit():
|
||||
"""At exactly MAX_INSTANCE_RETRIES, reconciler should delete, not retry."""
|
||||
meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES)
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_at_limit_minus_one_retries():
|
||||
"""At MAX_INSTANCE_RETRIES - 1, reconciler should still retry."""
|
||||
meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES - 1)
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceRetrying)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 3. Error handling edge cases
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_runners_failed_with_empty_error_message():
|
||||
"""RunnerFailed with empty error_message should still report as failed."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
runners = {
|
||||
rid: RunnerFailed(error_message="")
|
||||
for rid in inst.shard_assignments.node_to_runner.values()
|
||||
}
|
||||
is_failed, error = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is True
|
||||
# Empty error message means we get the fallback
|
||||
assert error == "Runner failed"
|
||||
|
||||
|
||||
def test_runners_failed_with_none_error_message():
|
||||
"""RunnerFailed with None error_message should still report as failed."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
runners = {
|
||||
rid: RunnerFailed(error_message=None)
|
||||
for rid in inst.shard_assignments.node_to_runner.values()
|
||||
}
|
||||
is_failed, error = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is True
|
||||
assert error == "Runner failed"
|
||||
|
||||
|
||||
def test_runners_failed_collects_all_error_messages():
|
||||
"""With multiple failed runners, all error messages should be collected."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
runners = {
|
||||
runner_ids[0]: RunnerFailed(error_message="OOM on GPU 0"),
|
||||
runner_ids[1]: RunnerFailed(error_message="OOM on GPU 1"),
|
||||
runner_ids[2]: RunnerFailed(error_message="OOM on GPU 2"),
|
||||
}
|
||||
is_failed, error = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is True
|
||||
assert error is not None
|
||||
assert "OOM on GPU 0" in error
|
||||
assert "OOM on GPU 1" in error
|
||||
assert "OOM on GPU 2" in error
|
||||
|
||||
|
||||
def test_runners_failed_includes_friendly_name():
|
||||
"""Error messages should include node friendly names when available."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
node_id = NodeId("node-a")
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
runners = {runner_ids[0]: RunnerFailed(error_message="OOM")}
|
||||
identities = {node_id: NodeIdentity(friendly_name="My Mac Studio")}
|
||||
is_failed, error = instance_runners_failed(inst, runners, identities)
|
||||
assert is_failed is True
|
||||
assert error is not None
|
||||
assert "My Mac Studio" in error
|
||||
|
||||
|
||||
def test_instance_retrying_for_missing_instance_is_safe():
|
||||
"""InstanceRetrying for an instance not in state should not crash.
|
||||
|
||||
NOTE: When the instance is missing, the handler returns early WITHOUT
|
||||
incrementing the MetaInstance failure counter. This means stale retry
|
||||
events for already-deleted instances are silently dropped. This is
|
||||
acceptable since the InstanceDeleted handler already increments failures.
|
||||
"""
|
||||
meta = _meta_instance()
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = InstanceRetrying(
|
||||
instance_id=InstanceId("nonexistent"),
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
failure_error="crash",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
# Does not crash, but failure count is NOT incremented (early return)
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 4. Backward compatibility
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_instance_without_meta_instance_id_works():
|
||||
"""Instances created without meta_instance_id should still function normally."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
assert inst.meta_instance_id is None
|
||||
topology = _topology("node-a")
|
||||
assert instance_connections_healthy(inst, topology) is True
|
||||
|
||||
|
||||
def test_instance_deleted_without_meta_does_not_affect_meta_instances():
|
||||
"""Deleting an instance without meta_instance_id should not affect meta_instances."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"]) # no meta_instance_id
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceDeleted(instance_id=iid, failure_error="crash")
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0 # unchanged
|
||||
|
||||
|
||||
def test_satisfies_ignores_meta_instance_id_binding():
|
||||
"""instance_satisfies_meta_instance checks constraints only, not binding."""
|
||||
meta = _meta_instance()
|
||||
_, inst = _instance(node_ids=["node-a"]) # no meta_instance_id set
|
||||
# Should match on constraints (model, min_nodes) regardless of binding
|
||||
assert instance_satisfies_meta_instance(meta, inst) is True
|
||||
|
||||
|
||||
def test_find_unsatisfied_uses_binding_not_constraints():
|
||||
"""find_unsatisfied checks meta_instance_id binding, not just constraint matching."""
|
||||
meta = _meta_instance()
|
||||
# Instance matches constraints but is NOT bound to this meta_instance
|
||||
iid, inst = _instance(node_ids=["node-a"])
|
||||
topology = _topology("node-a")
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{meta.meta_instance_id: meta}, {iid: inst}, topology
|
||||
)
|
||||
# Should be unsatisfied because instance.meta_instance_id != meta.meta_instance_id
|
||||
assert list(result) == [meta]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 5. Concurrent / multi-instance scenarios
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def test_health_reconciler_handles_multiple_failing_instances():
|
||||
"""Multiple instances failing simultaneously should each get their own event."""
|
||||
meta_a = _meta_instance()
|
||||
meta_b = _meta_instance()
|
||||
iid_a, inst_a = _instance(
|
||||
node_ids=["node-a"], meta_instance_id=meta_a.meta_instance_id
|
||||
)
|
||||
iid_b, inst_b = _instance(
|
||||
node_ids=["node-b"], meta_instance_id=meta_b.meta_instance_id
|
||||
)
|
||||
runner_ids_a = list(inst_a.shard_assignments.node_to_runner.values())
|
||||
runner_ids_b = list(inst_b.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={
|
||||
meta_a.meta_instance_id: meta_a,
|
||||
meta_b.meta_instance_id: meta_b,
|
||||
},
|
||||
instances={iid_a: inst_a, iid_b: inst_b},
|
||||
runners={
|
||||
runner_ids_a[0]: RunnerFailed(error_message="OOM"),
|
||||
runner_ids_b[0]: RunnerFailed(error_message="OOM"),
|
||||
},
|
||||
topology=_topology("node-a", "node-b"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 2
|
||||
# Both should be InstanceRetrying since failures < MAX
|
||||
assert all(isinstance(e, InstanceRetrying) for e in events)
|
||||
instance_ids = {e.instance_id for e in events} # type: ignore[union-attr]
|
||||
assert instance_ids == {iid_a, iid_b}
|
||||
|
||||
|
||||
async def test_health_reconciler_mixed_healthy_and_failing():
|
||||
"""Only failing instances should produce events; healthy ones should not."""
|
||||
meta_healthy = _meta_instance()
|
||||
meta_failing = _meta_instance()
|
||||
iid_h, inst_h = _instance(
|
||||
node_ids=["node-a"], meta_instance_id=meta_healthy.meta_instance_id
|
||||
)
|
||||
iid_f, inst_f = _instance(
|
||||
node_ids=["node-b"], meta_instance_id=meta_failing.meta_instance_id
|
||||
)
|
||||
runner_ids_h = list(inst_h.shard_assignments.node_to_runner.values())
|
||||
runner_ids_f = list(inst_f.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={
|
||||
meta_healthy.meta_instance_id: meta_healthy,
|
||||
meta_failing.meta_instance_id: meta_failing,
|
||||
},
|
||||
instances={iid_h: inst_h, iid_f: inst_f},
|
||||
runners={
|
||||
runner_ids_h[0]: RunnerReady(),
|
||||
runner_ids_f[0]: RunnerFailed(error_message="crash"),
|
||||
},
|
||||
topology=_topology("node-a", "node-b"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceRetrying)
|
||||
assert events[0].instance_id == iid_f
|
||||
|
||||
|
||||
async def test_meta_instance_reconciler_empty_state():
|
||||
"""MetaInstanceReconciler with no meta_instances should produce no events."""
|
||||
state = State()
|
||||
reconciler = MetaInstanceReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 6. Placement error tracking
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_placement_failed_sets_error():
|
||||
"""MetaInstancePlacementFailed should set placement_error on the MetaInstance."""
|
||||
meta = _meta_instance()
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
reason="Not enough memory",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.placement_error == "Not enough memory"
|
||||
|
||||
|
||||
def test_instance_created_clears_placement_error():
|
||||
"""InstanceCreated should clear placement_error on the MetaInstance."""
|
||||
meta = _meta_instance(placement_error="Not enough memory")
|
||||
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
|
||||
mi = state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.placement_error is None
|
||||
|
||||
|
||||
def test_placement_error_does_not_increment_failures():
|
||||
"""Placement failures should only set placement_error, not increment consecutive_failures."""
|
||||
meta = _meta_instance()
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
reason="No resources",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
assert mi.placement_error == "No resources"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 7. State serialization roundtrip
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_state_with_meta_instances_serializes():
|
||||
"""State with meta_instances should serialize and deserialize correctly."""
|
||||
meta = _meta_instance(consecutive_failures=2, last_failure_error="test")
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
json_str = state.model_dump_json()
|
||||
restored = State.model_validate_json(json_str)
|
||||
assert meta.meta_instance_id in restored.meta_instances
|
||||
mi = restored.meta_instances[meta.meta_instance_id]
|
||||
assert mi.model_id == meta.model_id
|
||||
assert mi.consecutive_failures == 2
|
||||
assert mi.last_failure_error == "test"
|
||||
assert iid in restored.instances
|
||||
assert restored.instances[iid].meta_instance_id == meta.meta_instance_id
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 8. MetaInstanceReconciler error handling
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def test_meta_instance_reconciler_model_load_error_emits_placement_failed(
|
||||
monkeypatch: "pytest.MonkeyPatch",
|
||||
):
|
||||
"""When ModelCard.load raises, reconciler emits MetaInstancePlacementFailed."""
|
||||
import exo.master.process_managers.meta_instance as mi_mod
|
||||
|
||||
meta = _meta_instance()
|
||||
topo = _topology("node-a")
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
topology=topo,
|
||||
)
|
||||
|
||||
async def _failing_load(_model_id: ModelId) -> ModelCard:
|
||||
raise RuntimeError("Network error")
|
||||
|
||||
monkeypatch.setattr(
|
||||
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
|
||||
)
|
||||
|
||||
reconciler = MetaInstanceReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
|
||||
placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
|
||||
assert len(placement_failed) == 1
|
||||
assert "Failed to load model card" in placement_failed[0].reason
|
||||
assert meta.meta_instance_id == placement_failed[0].meta_instance_id
|
||||
|
||||
|
||||
async def test_meta_instance_reconciler_model_load_error_skips_dedup(
|
||||
monkeypatch: "pytest.MonkeyPatch",
|
||||
):
|
||||
"""When ModelCard.load error matches existing placement_error, no duplicate event."""
|
||||
import exo.master.process_managers.meta_instance as mi_mod
|
||||
|
||||
meta = _meta_instance(placement_error="Failed to load model card: Network error")
|
||||
topo = _topology("node-a")
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
topology=topo,
|
||||
)
|
||||
|
||||
async def _failing_load(_model_id: ModelId) -> ModelCard:
|
||||
raise RuntimeError("Network error")
|
||||
|
||||
monkeypatch.setattr(
|
||||
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
|
||||
)
|
||||
|
||||
reconciler = MetaInstanceReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
|
||||
# Error matches existing placement_error, so no duplicate event emitted
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
async def test_meta_instance_reconciler_continues_after_error(
|
||||
monkeypatch: "pytest.MonkeyPatch",
|
||||
):
|
||||
"""Reconciler should continue to next meta-instance after one fails to load."""
|
||||
import exo.master.process_managers.meta_instance as mi_mod
|
||||
|
||||
meta_a = _meta_instance(model_id="org/model-a")
|
||||
meta_b = _meta_instance(model_id="org/model-b")
|
||||
topo = _topology("node-a")
|
||||
state = State(
|
||||
meta_instances={
|
||||
meta_a.meta_instance_id: meta_a,
|
||||
meta_b.meta_instance_id: meta_b,
|
||||
},
|
||||
topology=topo,
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def _load_second_fails(model_id: ModelId) -> ModelCard:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise RuntimeError(f"Cannot load {model_id}")
|
||||
|
||||
monkeypatch.setattr(
|
||||
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_load_second_fails)})
|
||||
)
|
||||
|
||||
reconciler = MetaInstanceReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
|
||||
# Both meta-instances should have been attempted (not short-circuited)
|
||||
assert call_count == 2
|
||||
# Both should have placement failed events
|
||||
placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
|
||||
assert len(placement_failed) == 2
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 8. Cascade delete with task cancellation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_cascade_delete_cancels_active_tasks():
|
||||
"""Deleting a MetaInstance should cancel tasks on backing instances.
|
||||
|
||||
Regression test: previously, cascade-deleting backing instances via
|
||||
DeleteMetaInstance did not emit TaskStatusUpdated(Cancelled) for active
|
||||
tasks, leaving orphaned task references in state.
|
||||
"""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
task_id = TaskId()
|
||||
task = LoadModel(task_id=task_id, instance_id=iid, task_status=TaskStatus.Running)
|
||||
|
||||
# Build state with meta-instance, backing instance, and active task
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
tasks={task_id: task},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
|
||||
# Simulate the cascade-delete event sequence produced by main.py:
|
||||
# 1. MetaInstanceDeleted
|
||||
# 2. TaskStatusUpdated(Cancelled) for active tasks
|
||||
# 3. InstanceDeleted
|
||||
idx = 0
|
||||
state = apply(
|
||||
state,
|
||||
IndexedEvent(
|
||||
idx=idx,
|
||||
event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id),
|
||||
),
|
||||
)
|
||||
idx += 1
|
||||
state = apply(
|
||||
state,
|
||||
IndexedEvent(
|
||||
idx=idx,
|
||||
event=TaskStatusUpdated(task_id=task_id, task_status=TaskStatus.Cancelled),
|
||||
),
|
||||
)
|
||||
idx += 1
|
||||
state = apply(
|
||||
state,
|
||||
IndexedEvent(idx=idx, event=InstanceDeleted(instance_id=iid)),
|
||||
)
|
||||
|
||||
# Verify everything is cleaned up
|
||||
assert len(state.meta_instances) == 0
|
||||
assert len(state.instances) == 0
|
||||
assert state.tasks[task_id].task_status == TaskStatus.Cancelled
|
||||
|
||||
|
||||
def test_cascade_delete_skips_completed_tasks():
|
||||
"""Cascade delete should only cancel Pending/Running tasks, not completed ones."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
|
||||
running_task_id = TaskId()
|
||||
completed_task_id = TaskId()
|
||||
running_task = LoadModel(
|
||||
task_id=running_task_id, instance_id=iid, task_status=TaskStatus.Running
|
||||
)
|
||||
completed_task = LoadModel(
|
||||
task_id=completed_task_id, instance_id=iid, task_status=TaskStatus.Complete
|
||||
)
|
||||
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
tasks={running_task_id: running_task, completed_task_id: completed_task},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
|
||||
# Only the running task should be cancelled — we verify the logic pattern
|
||||
# by checking which tasks are Pending or Running
|
||||
active_tasks = [
|
||||
t
|
||||
for t in state.tasks.values()
|
||||
if t.instance_id == iid
|
||||
and t.task_status in (TaskStatus.Pending, TaskStatus.Running)
|
||||
]
|
||||
assert len(active_tasks) == 1
|
||||
assert active_tasks[0].task_id == running_task_id
|
||||
@@ -3,10 +3,10 @@ import pytest
|
||||
from exo.master.placement_utils import (
|
||||
allocate_layers_proportionally,
|
||||
filter_cycles_by_memory,
|
||||
get_largest_cycles,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_shard_assignments_for_pipeline_parallel,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.master.tests.conftest import (
|
||||
create_node_memory,
|
||||
@@ -143,7 +143,7 @@ def test_filter_multiple_cycles_by_memory():
|
||||
}
|
||||
|
||||
|
||||
def test_get_smallest_cycles():
|
||||
def test_get_largest_cycles():
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
@@ -175,12 +175,12 @@ def test_get_smallest_cycles():
|
||||
cycles = [c for c in topology.get_cycles() if len(c) != 1] # ignore singletons
|
||||
|
||||
# act
|
||||
smallest_cycles = get_smallest_cycles(cycles)
|
||||
largest_cycles = get_largest_cycles(cycles)
|
||||
|
||||
# assert
|
||||
assert len(smallest_cycles) == 1
|
||||
assert len(smallest_cycles[0]) == 2
|
||||
assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}
|
||||
assert len(largest_cycles) == 1
|
||||
assert len(largest_cycles[0]) == 3
|
||||
assert set(n for n in largest_cycles[0]) == {node_a_id, node_b_id, node_c_id}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
742
src/exo/master/tests/test_reconcile.py
Normal file
742
src/exo/master/tests/test_reconcile.py
Normal file
@@ -0,0 +1,742 @@
|
||||
from exo.master.process_managers.instance_health import InstanceHealthReconciler
|
||||
from exo.master.reconcile import (
|
||||
find_unsatisfied_meta_instances,
|
||||
instance_connections_healthy,
|
||||
instance_runners_failed,
|
||||
instance_satisfies_meta_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, MetaInstanceId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
InstanceRetrying,
|
||||
MetaInstanceCreated,
|
||||
MetaInstanceDeleted,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.instances import (
|
||||
InstanceId,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerFailed,
|
||||
RunnerId,
|
||||
RunnerLoading,
|
||||
RunnerReady,
|
||||
RunnerShutdown,
|
||||
ShardAssignments,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
|
||||
|
||||
def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
|
||||
return ModelCard(
|
||||
model_id=ModelId(model_id),
|
||||
storage_size=Memory.from_kb(1000),
|
||||
n_layers=10,
|
||||
hidden_size=30,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
|
||||
|
||||
def _topology(*node_ids: str, connect: bool = True) -> Topology:
|
||||
"""Build a topology with nodes connected in a bidirectional ring with unique IPs.
|
||||
|
||||
Node at index ``i`` gets IP ``10.0.0.{i+1}``. Edges go in both directions
|
||||
between consecutive nodes (including wrap-around).
|
||||
"""
|
||||
t = Topology()
|
||||
nodes = [NodeId(n) for n in node_ids]
|
||||
for n in nodes:
|
||||
t.add_node(n)
|
||||
if connect and len(nodes) > 1:
|
||||
for i in range(len(nodes)):
|
||||
j = (i + 1) % len(nodes)
|
||||
t.add_connection(
|
||||
Connection(
|
||||
source=nodes[i],
|
||||
sink=nodes[j],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
t.add_connection(
|
||||
Connection(
|
||||
source=nodes[j],
|
||||
sink=nodes[i],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
return t
|
||||
|
||||
|
||||
def _meta_instance(
|
||||
model_id: str = "test-org/test-model",
|
||||
*,
|
||||
min_nodes: int = 1,
|
||||
node_ids: list[NodeId] | None = None,
|
||||
meta_instance_id: MetaInstanceId | None = None,
|
||||
) -> MetaInstance:
|
||||
return MetaInstance(
|
||||
meta_instance_id=meta_instance_id or MetaInstanceId(),
|
||||
model_id=ModelId(model_id),
|
||||
min_nodes=min_nodes,
|
||||
node_ids=node_ids,
|
||||
)
|
||||
|
||||
|
||||
def _instance(
|
||||
model_id: str = "test-org/test-model",
|
||||
node_ids: list[str] | None = None,
|
||||
instance_id: InstanceId | None = None,
|
||||
meta_instance_id: MetaInstanceId | None = None,
|
||||
) -> tuple[InstanceId, MlxRingInstance]:
|
||||
"""Create a test instance with hosts_by_node matching ``_topology()`` IPs."""
|
||||
iid = instance_id or InstanceId()
|
||||
nodes = node_ids or ["node-a"]
|
||||
n = len(nodes)
|
||||
mc = _model_card(model_id)
|
||||
ephemeral_port = 50000
|
||||
node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
|
||||
runner_to_shard = {
|
||||
runner_id: PipelineShardMetadata(
|
||||
model_card=mc,
|
||||
device_rank=i,
|
||||
world_size=n,
|
||||
start_layer=0,
|
||||
end_layer=mc.n_layers,
|
||||
n_layers=mc.n_layers,
|
||||
)
|
||||
for i, runner_id in enumerate(node_to_runner.values())
|
||||
}
|
||||
# Build hosts_by_node with IPs matching _topology() convention:
|
||||
# node at index idx has IP 10.0.0.{idx+1}
|
||||
hosts_by_node: dict[NodeId, list[Host]] = {}
|
||||
for r, node_str in enumerate(nodes):
|
||||
hosts: list[Host] = []
|
||||
for idx in range(n):
|
||||
if idx == r:
|
||||
hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
|
||||
elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
|
||||
hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
|
||||
else:
|
||||
hosts.append(Host(ip="198.51.100.1", port=0))
|
||||
hosts_by_node[NodeId(node_str)] = hosts
|
||||
return iid, MlxRingInstance(
|
||||
instance_id=iid,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId(model_id),
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
),
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
meta_instance_id=meta_instance_id,
|
||||
)
|
||||
|
||||
|
||||
# --- instance_satisfies_meta_instance (pure constraint matching) ---
|
||||
|
||||
|
||||
def test_satisfies_matching_model():
|
||||
meta = _meta_instance()
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
assert instance_satisfies_meta_instance(meta, inst) is True
|
||||
|
||||
|
||||
def test_not_satisfies_wrong_model():
|
||||
meta = _meta_instance("test-org/model-a")
|
||||
_, inst = _instance("test-org/model-b")
|
||||
assert instance_satisfies_meta_instance(meta, inst) is False
|
||||
|
||||
|
||||
def test_not_satisfies_missing_required_node():
|
||||
meta = _meta_instance(node_ids=[NodeId("node-c")])
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
assert instance_satisfies_meta_instance(meta, inst) is False
|
||||
|
||||
|
||||
def test_not_satisfies_fewer_than_min_nodes():
|
||||
meta = _meta_instance(min_nodes=3)
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
assert instance_satisfies_meta_instance(meta, inst) is False
|
||||
|
||||
|
||||
def test_satisfies_with_node_ids_specified():
|
||||
meta = _meta_instance(node_ids=[NodeId("node-a"), NodeId("node-b")], min_nodes=2)
|
||||
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
|
||||
assert instance_satisfies_meta_instance(meta, inst) is True
|
||||
|
||||
|
||||
# --- instance_connections_healthy ---
|
||||
|
||||
|
||||
def test_healthy_single_node_present():
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
topology = _topology("node-a")
|
||||
assert instance_connections_healthy(inst, topology) is True
|
||||
|
||||
|
||||
def test_unhealthy_single_node_missing():
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
topology = Topology() # empty
|
||||
assert instance_connections_healthy(inst, topology) is False
|
||||
|
||||
|
||||
def test_healthy_two_node_ring():
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
topology = _topology("node-a", "node-b")
|
||||
assert instance_connections_healthy(inst, topology) is True
|
||||
|
||||
|
||||
def test_unhealthy_two_node_edge_removed():
|
||||
"""Nodes present but edge removed — ring broken."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
topology = _topology("node-a", "node-b", connect=False)
|
||||
assert instance_connections_healthy(inst, topology) is False
|
||||
|
||||
|
||||
def test_unhealthy_two_node_ip_changed():
|
||||
"""Edge exists but with a different IP than instance was configured with."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
# Build topology with different IPs than _instance() expects
|
||||
topology = Topology()
|
||||
topology.add_node(NodeId("node-a"))
|
||||
topology.add_node(NodeId("node-b"))
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=NodeId("node-a"),
|
||||
sink=NodeId("node-b"),
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.99/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=NodeId("node-b"),
|
||||
sink=NodeId("node-a"),
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.98/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
assert instance_connections_healthy(inst, topology) is False
|
||||
|
||||
|
||||
def test_healthy_three_node_ring():
|
||||
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
|
||||
topology = _topology("node-a", "node-b", "node-c")
|
||||
assert instance_connections_healthy(inst, topology) is True
|
||||
|
||||
|
||||
def test_unhealthy_three_node_one_edge_removed():
|
||||
"""Remove one edge from a three-node ring — instance unhealthy."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
|
||||
# Build topology with one direction of one edge missing
|
||||
topology = Topology()
|
||||
nodes = [NodeId("node-a"), NodeId("node-b"), NodeId("node-c")]
|
||||
for n in nodes:
|
||||
topology.add_node(n)
|
||||
# Add all edges except node-a → node-b
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=nodes[1],
|
||||
sink=nodes[0],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=nodes[1],
|
||||
sink=nodes[2],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=nodes[2],
|
||||
sink=nodes[1],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=nodes[2],
|
||||
sink=nodes[0],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=nodes[0],
|
||||
sink=nodes[2],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
|
||||
),
|
||||
)
|
||||
)
|
||||
# Missing: node-a → node-b (ip 10.0.0.2)
|
||||
assert instance_connections_healthy(inst, topology) is False
|
||||
|
||||
|
||||
def test_unhealthy_node_missing_from_topology():
|
||||
"""Instance has a node that's not in the topology at all."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
topology = _topology("node-a") # node-b not present
|
||||
assert instance_connections_healthy(inst, topology) is False
|
||||
|
||||
|
||||
def test_healthy_extra_nodes_in_topology():
|
||||
"""Extra nodes in topology don't affect instance health."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
topology = _topology("node-a", "node-b", "node-c")
|
||||
assert instance_connections_healthy(inst, topology) is True
|
||||
|
||||
|
||||
# --- find_unsatisfied_meta_instances ---
|
||||
|
||||
|
||||
def test_unsatisfied_no_meta_instances():
|
||||
result = find_unsatisfied_meta_instances({}, {}, Topology())
|
||||
assert list(result) == []
|
||||
|
||||
|
||||
def test_unsatisfied_one_satisfied():
|
||||
meta = _meta_instance()
|
||||
id_a, inst_a = _instance(meta_instance_id=meta.meta_instance_id)
|
||||
topology = _topology("node-a")
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{meta.meta_instance_id: meta},
|
||||
{id_a: inst_a},
|
||||
topology,
|
||||
)
|
||||
assert list(result) == []
|
||||
|
||||
|
||||
def test_unsatisfied_one_not_satisfied():
|
||||
meta = _meta_instance("test-org/model-x")
|
||||
id_a, inst_a = _instance("test-org/model-y")
|
||||
topology = _topology("node-a")
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{meta.meta_instance_id: meta}, {id_a: inst_a}, topology
|
||||
)
|
||||
assert list(result) == [meta]
|
||||
|
||||
|
||||
def test_unsatisfied_mix():
|
||||
meta_satisfied = _meta_instance("test-org/model-a")
|
||||
meta_unsatisfied = _meta_instance("test-org/model-b")
|
||||
id_a, inst_a = _instance(
|
||||
"test-org/model-a", meta_instance_id=meta_satisfied.meta_instance_id
|
||||
)
|
||||
topology = _topology("node-a")
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{
|
||||
meta_satisfied.meta_instance_id: meta_satisfied,
|
||||
meta_unsatisfied.meta_instance_id: meta_unsatisfied,
|
||||
},
|
||||
{id_a: inst_a},
|
||||
topology,
|
||||
)
|
||||
assert list(result) == [meta_unsatisfied]
|
||||
|
||||
|
||||
def test_unsatisfied_node_disconnect():
|
||||
meta = _meta_instance()
|
||||
id_a, inst_a = _instance(
|
||||
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
|
||||
)
|
||||
topology = _topology("node-a") # node-b disconnected
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{meta.meta_instance_id: meta},
|
||||
{id_a: inst_a},
|
||||
topology,
|
||||
)
|
||||
assert list(result) == [meta]
|
||||
|
||||
|
||||
def test_unsatisfied_edge_break():
|
||||
"""Instance exists but its connections broke — meta-instance becomes unsatisfied."""
|
||||
meta = _meta_instance()
|
||||
id_a, inst_a = _instance(
|
||||
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
|
||||
)
|
||||
topology = _topology("node-a", "node-b", connect=False) # nodes present, no edges
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{meta.meta_instance_id: meta},
|
||||
{id_a: inst_a},
|
||||
topology,
|
||||
)
|
||||
assert list(result) == [meta]
|
||||
|
||||
|
||||
def test_unsatisfied_idempotent():
|
||||
meta = _meta_instance("test-org/model-x")
|
||||
topology = _topology("node-a")
|
||||
meta_instances = {meta.meta_instance_id: meta}
|
||||
instances: dict[InstanceId, MlxRingInstance] = {}
|
||||
result_1 = list(
|
||||
find_unsatisfied_meta_instances(meta_instances, instances, topology)
|
||||
)
|
||||
result_2 = list(
|
||||
find_unsatisfied_meta_instances(meta_instances, instances, topology)
|
||||
)
|
||||
assert result_1 == result_2
|
||||
|
||||
|
||||
def test_unsatisfied_exclusive_binding():
|
||||
"""Two MetaInstances for the same model: one is bound via meta_instance_id, the other is unsatisfied."""
|
||||
meta_a = _meta_instance("test-org/model-x")
|
||||
meta_b = _meta_instance("test-org/model-x")
|
||||
id_inst, inst = _instance(
|
||||
"test-org/model-x", meta_instance_id=meta_a.meta_instance_id
|
||||
)
|
||||
topology = _topology("node-a")
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{
|
||||
meta_a.meta_instance_id: meta_a,
|
||||
meta_b.meta_instance_id: meta_b,
|
||||
},
|
||||
{id_inst: inst},
|
||||
topology,
|
||||
)
|
||||
assert list(result) == [meta_b]
|
||||
|
||||
|
||||
# --- apply handlers ---
|
||||
|
||||
|
||||
def test_apply_meta_instance_created():
|
||||
state = State()
|
||||
meta = _meta_instance()
|
||||
event = MetaInstanceCreated(meta_instance=meta)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert meta.meta_instance_id in new_state.meta_instances
|
||||
assert new_state.meta_instances[meta.meta_instance_id] == meta
|
||||
|
||||
|
||||
def test_apply_meta_instance_deleted():
|
||||
meta = _meta_instance()
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert meta.meta_instance_id not in new_state.meta_instances
|
||||
|
||||
|
||||
def test_apply_meta_instance_deleted_clears_failure_info():
|
||||
meta = _meta_instance().model_copy(
|
||||
update={"consecutive_failures": 2, "last_failure_error": "OOM"}
|
||||
)
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert meta.meta_instance_id not in new_state.meta_instances
|
||||
|
||||
|
||||
# --- instance_runners_failed ---
|
||||
|
||||
|
||||
def test_runners_failed_all_failed():
|
||||
"""All runners in RunnerFailed -> instance is failed."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
runners = {
|
||||
rid: RunnerFailed(error_message="OOM")
|
||||
for rid in inst.shard_assignments.node_to_runner.values()
|
||||
}
|
||||
is_failed, error = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is True
|
||||
assert error is not None
|
||||
assert "OOM" in error
|
||||
|
||||
|
||||
def test_runners_failed_mixed_failed_shutdown():
|
||||
"""One Failed + one Shutdown = failed."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
runners = {
|
||||
runner_ids[0]: RunnerFailed(error_message="crash"),
|
||||
runner_ids[1]: RunnerShutdown(),
|
||||
}
|
||||
is_failed, error = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is True
|
||||
assert error is not None
|
||||
assert "crash" in error
|
||||
|
||||
|
||||
def test_runners_not_failed_all_shutdown():
|
||||
"""All Shutdown (graceful) = not a failure."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
runners = {
|
||||
rid: RunnerShutdown() for rid in inst.shard_assignments.node_to_runner.values()
|
||||
}
|
||||
is_failed, _ = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is False
|
||||
|
||||
|
||||
def test_runners_not_failed_still_active():
|
||||
"""Some runners still active = not failed yet."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b"])
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
runners = {
|
||||
runner_ids[0]: RunnerFailed(error_message="OOM"),
|
||||
runner_ids[1]: RunnerLoading(),
|
||||
}
|
||||
is_failed, _ = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is False
|
||||
|
||||
|
||||
def test_runners_not_failed_no_status():
|
||||
"""Runner not yet reported = not failed."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
is_failed, _ = instance_runners_failed(inst, {}, {})
|
||||
assert is_failed is False
|
||||
|
||||
|
||||
def test_runners_not_failed_healthy():
|
||||
"""Runners in Ready state = not failed."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
runners = {
|
||||
rid: RunnerReady() for rid in inst.shard_assignments.node_to_runner.values()
|
||||
}
|
||||
is_failed, _ = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is False
|
||||
|
||||
|
||||
# --- failure tracking in apply_instance_deleted ---
|
||||
|
||||
|
||||
def test_apply_instance_deleted_tracks_failure():
|
||||
"""InstanceDeleted with failure_error increments meta instance failure count."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceDeleted(instance_id=iid, failure_error="Runner OOM")
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 1
|
||||
assert mi.last_failure_error == "Runner OOM"
|
||||
|
||||
|
||||
def test_apply_instance_deleted_increments_failure():
|
||||
"""Subsequent failures increment the counter."""
|
||||
meta = _meta_instance().model_copy(
|
||||
update={"consecutive_failures": 2, "last_failure_error": "previous error"}
|
||||
)
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceDeleted(instance_id=iid, failure_error="new error")
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 3
|
||||
assert mi.last_failure_error == "new error"
|
||||
|
||||
|
||||
def test_apply_instance_deleted_no_failure_no_tracking():
|
||||
"""InstanceDeleted without failure_error does not track."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceDeleted(instance_id=iid)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
|
||||
|
||||
def test_apply_instance_deleted_orphan_no_tracking():
|
||||
"""InstanceDeleted for orphan instance (no meta_instance_id) does not track."""
|
||||
iid, inst = _instance(node_ids=["node-a"])
|
||||
state = State(instances={iid: inst})
|
||||
event = InstanceDeleted(instance_id=iid, failure_error="crash")
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert len(new_state.meta_instances) == 0
|
||||
|
||||
|
||||
# --- InstanceRetrying ---
|
||||
|
||||
|
||||
def test_apply_instance_retrying_removes_runners():
|
||||
"""InstanceRetrying removes the instance's runners from state but keeps the instance."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(
|
||||
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
|
||||
)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
runners = {
|
||||
runner_ids[0]: RunnerFailed(error_message="OOM"),
|
||||
runner_ids[1]: RunnerShutdown(),
|
||||
}
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners=runners,
|
||||
)
|
||||
event = InstanceRetrying(
|
||||
instance_id=iid,
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
failure_error="OOM",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
# Instance still exists
|
||||
assert iid in new_state.instances
|
||||
# Runners removed
|
||||
assert runner_ids[0] not in new_state.runners
|
||||
assert runner_ids[1] not in new_state.runners
|
||||
|
||||
|
||||
def test_apply_instance_retrying_increments_failure():
|
||||
"""InstanceRetrying increments consecutive_failures on the MetaInstance."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceRetrying(
|
||||
instance_id=iid,
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
failure_error="crash",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 1
|
||||
assert mi.last_failure_error == "crash"
|
||||
|
||||
|
||||
def test_apply_instance_retrying_skips_missing_runners():
|
||||
"""InstanceRetrying doesn't assert if runners haven't reported yet."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
# No runners in state at all
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceRetrying(
|
||||
instance_id=iid,
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
failure_error="crash",
|
||||
)
|
||||
# Should not raise
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert iid in new_state.instances
|
||||
|
||||
|
||||
def test_apply_instance_created_resets_failure_counter():
|
||||
"""InstanceCreated resets consecutive_failures but preserves last_failure_error."""
|
||||
meta = _meta_instance().model_copy(
|
||||
update={"consecutive_failures": 3, "last_failure_error": "old error"}
|
||||
)
|
||||
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = InstanceCreated(instance=inst)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
assert mi.last_failure_error == "old error"
|
||||
assert mi.placement_error is None
|
||||
|
||||
|
||||
# --- InstanceHealthReconciler retry-vs-delete ---
|
||||
|
||||
|
||||
async def test_health_reconciler_retries_when_under_limit():
|
||||
"""InstanceHealthReconciler emits InstanceRetrying when consecutive_failures < 3."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceRetrying)
|
||||
assert events[0].instance_id == iid
|
||||
assert events[0].meta_instance_id == meta.meta_instance_id
|
||||
|
||||
|
||||
async def test_health_reconciler_deletes_when_limit_reached():
|
||||
"""InstanceHealthReconciler emits InstanceDeleted when consecutive_failures >= 3."""
|
||||
meta = _meta_instance().model_copy(update={"consecutive_failures": 3})
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_deletes_without_meta_instance():
|
||||
"""Instances without a MetaInstance are deleted immediately on runner failure."""
|
||||
iid, inst = _instance(node_ids=["node-a"])
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="crash")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_network_failure_always_deletes():
|
||||
"""Network failure always triggers InstanceDeleted regardless of retry count."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(
|
||||
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
|
||||
)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
topology=_topology("node-a"), # node-b missing
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
assert events[0].failure_error == "Network connection lost"
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.common import MetaInstanceId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -12,6 +12,12 @@ from exo.shared.types.events import (
|
||||
InputChunkReceived,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
InstanceRetrying,
|
||||
JacclSideChannelData,
|
||||
JacclSideChannelGathered,
|
||||
MetaInstanceCreated,
|
||||
MetaInstanceDeleted,
|
||||
MetaInstancePlacementFailed,
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
@@ -28,6 +34,7 @@ from exo.shared.types.events import (
|
||||
TracesCollected,
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.profiling import (
|
||||
NodeIdentity,
|
||||
NodeNetworkInfo,
|
||||
@@ -66,12 +73,22 @@ def event_apply(event: Event, state: State) -> State:
|
||||
| InputChunkReceived()
|
||||
| TracesCollected()
|
||||
| TracesMerged()
|
||||
| JacclSideChannelData()
|
||||
| JacclSideChannelGathered()
|
||||
): # Pass-through events that don't modify state
|
||||
return state
|
||||
case InstanceCreated():
|
||||
return apply_instance_created(event, state)
|
||||
case InstanceDeleted():
|
||||
return apply_instance_deleted(event, state)
|
||||
case InstanceRetrying():
|
||||
return apply_instance_retrying(event, state)
|
||||
case MetaInstanceCreated():
|
||||
return apply_meta_instance_created(event, state)
|
||||
case MetaInstanceDeleted():
|
||||
return apply_meta_instance_deleted(event, state)
|
||||
case MetaInstancePlacementFailed():
|
||||
return apply_meta_instance_placement_failed(event, state)
|
||||
case NodeTimedOut():
|
||||
return apply_node_timed_out(event, state)
|
||||
case NodeDownloadProgress():
|
||||
@@ -174,20 +191,123 @@ def apply_task_failed(event: TaskFailed, state: State) -> State:
|
||||
return state.model_copy(update={"tasks": new_tasks})
|
||||
|
||||
|
||||
def _update_meta_instance(
|
||||
state: State, mid: MetaInstanceId, **fields: object
|
||||
) -> Mapping[MetaInstanceId, MetaInstance]:
|
||||
mi = state.meta_instances[mid]
|
||||
return {**state.meta_instances, mid: mi.model_copy(update=fields)}
|
||||
|
||||
|
||||
def apply_instance_created(event: InstanceCreated, state: State) -> State:
|
||||
instance = event.instance
|
||||
new_instances: Mapping[InstanceId, Instance] = {
|
||||
**state.instances,
|
||||
instance.instance_id: instance,
|
||||
}
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
update: dict[str, object] = {"instances": new_instances}
|
||||
# Reset failure tracking when a new instance is created for a meta-instance
|
||||
if instance.meta_instance_id and instance.meta_instance_id in state.meta_instances:
|
||||
mi = state.meta_instances[instance.meta_instance_id]
|
||||
if mi.placement_error is not None or mi.consecutive_failures > 0:
|
||||
update["meta_instances"] = _update_meta_instance(
|
||||
state,
|
||||
instance.meta_instance_id,
|
||||
placement_error=None,
|
||||
consecutive_failures=0,
|
||||
)
|
||||
return state.model_copy(update=update)
|
||||
|
||||
|
||||
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
|
||||
deleted_instance = state.instances.get(event.instance_id)
|
||||
new_instances: Mapping[InstanceId, Instance] = {
|
||||
iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
|
||||
}
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
update: dict[str, object] = {"instances": new_instances}
|
||||
|
||||
# Track failure on the MetaInstance itself
|
||||
if (
|
||||
event.failure_error
|
||||
and deleted_instance
|
||||
and deleted_instance.meta_instance_id
|
||||
and deleted_instance.meta_instance_id in state.meta_instances
|
||||
):
|
||||
mid = deleted_instance.meta_instance_id
|
||||
mi = state.meta_instances[mid]
|
||||
update["meta_instances"] = {
|
||||
**state.meta_instances,
|
||||
mid: mi.model_copy(
|
||||
update={
|
||||
"consecutive_failures": mi.consecutive_failures + 1,
|
||||
"last_failure_error": event.failure_error,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
return state.model_copy(update=update)
|
||||
|
||||
|
||||
def apply_instance_retrying(event: InstanceRetrying, state: State) -> State:
|
||||
"""Runners failed but retry limit not reached — remove runners, keep instance."""
|
||||
instance = state.instances.get(event.instance_id)
|
||||
if instance is None:
|
||||
# Instance was already deleted (e.g. cascade from DeleteMetaInstance).
|
||||
# The InstanceDeleted handler already incremented consecutive_failures
|
||||
# on the MetaInstance, so skipping here avoids double-counting.
|
||||
return state
|
||||
|
||||
# Remove all runners belonging to this instance from state
|
||||
runner_ids_to_remove = set(instance.shard_assignments.node_to_runner.values())
|
||||
new_runners: Mapping[RunnerId, RunnerStatus] = {
|
||||
rid: rs for rid, rs in state.runners.items() if rid not in runner_ids_to_remove
|
||||
}
|
||||
|
||||
update: dict[str, object] = {"runners": new_runners}
|
||||
|
||||
# Increment failure count on the MetaInstance
|
||||
if event.meta_instance_id in state.meta_instances:
|
||||
update["meta_instances"] = _update_meta_instance(
|
||||
state,
|
||||
event.meta_instance_id,
|
||||
consecutive_failures=state.meta_instances[
|
||||
event.meta_instance_id
|
||||
].consecutive_failures
|
||||
+ 1,
|
||||
last_failure_error=event.failure_error,
|
||||
)
|
||||
|
||||
return state.model_copy(update=update)
|
||||
|
||||
|
||||
def apply_meta_instance_created(event: MetaInstanceCreated, state: State) -> State:
|
||||
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
|
||||
**state.meta_instances,
|
||||
event.meta_instance.meta_instance_id: event.meta_instance,
|
||||
}
|
||||
return state.model_copy(update={"meta_instances": new_meta})
|
||||
|
||||
|
||||
def apply_meta_instance_deleted(event: MetaInstanceDeleted, state: State) -> State:
|
||||
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
|
||||
mid: mi
|
||||
for mid, mi in state.meta_instances.items()
|
||||
if mid != event.meta_instance_id
|
||||
}
|
||||
return state.model_copy(update={"meta_instances": new_meta})
|
||||
|
||||
|
||||
def apply_meta_instance_placement_failed(
|
||||
event: MetaInstancePlacementFailed, state: State
|
||||
) -> State:
|
||||
if event.meta_instance_id not in state.meta_instances:
|
||||
return state
|
||||
return state.model_copy(
|
||||
update={
|
||||
"meta_instances": _update_meta_instance(
|
||||
state, event.meta_instance_id, placement_error=event.reason
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
|
||||
|
||||
@@ -165,7 +165,6 @@ def is_custom_card(model_id: ModelId) -> bool:
|
||||
class ConfigData(BaseModel):
|
||||
model_config = {"extra": "ignore"} # Allow unknown fields
|
||||
|
||||
model_type: str | None = None
|
||||
architectures: list[str] | None = None
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
layer_count: int = Field(
|
||||
@@ -201,7 +200,6 @@ class ConfigData(BaseModel):
|
||||
return data
|
||||
|
||||
for field in [
|
||||
"model_type",
|
||||
"architectures",
|
||||
"hidden_size",
|
||||
"num_hidden_layers",
|
||||
|
||||
@@ -6,7 +6,7 @@ from uuid import uuid4
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
@@ -262,6 +262,26 @@ class DeleteInstanceResponse(BaseModel):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class CreateMetaInstanceParams(BaseModel):
|
||||
model_id: ModelId
|
||||
sharding: Sharding = Sharding.Pipeline
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing
|
||||
min_nodes: int = 1
|
||||
node_ids: list[NodeId] | None = None
|
||||
|
||||
|
||||
class CreateMetaInstanceResponse(BaseModel):
|
||||
message: str
|
||||
command_id: CommandId
|
||||
meta_instance_id: MetaInstanceId
|
||||
|
||||
|
||||
class DeleteMetaInstanceResponse(BaseModel):
|
||||
message: str
|
||||
command_id: CommandId
|
||||
meta_instance_id: MetaInstanceId
|
||||
|
||||
|
||||
class AdvancedImageParams(BaseModel):
|
||||
seed: Annotated[int, Field(ge=0)] | None = None
|
||||
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
|
||||
@@ -366,6 +386,15 @@ class DeleteDownloadResponse(CamelCaseModel):
|
||||
command_id: CommandId
|
||||
|
||||
|
||||
class DistributeModelParams(CamelCaseModel):
|
||||
target_node_ids: list[NodeId] | None = None # None = all connected nodes
|
||||
|
||||
|
||||
class DistributeModelResponse(CamelCaseModel):
|
||||
command_id: CommandId
|
||||
message: str
|
||||
|
||||
|
||||
class TraceEventResponse(CamelCaseModel):
|
||||
name: str
|
||||
start_us: int
|
||||
|
||||
@@ -6,7 +6,8 @@ from exo.shared.types.api import (
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
@@ -52,6 +53,14 @@ class TaskCancelled(BaseCommand):
|
||||
cancelled_command_id: CommandId
|
||||
|
||||
|
||||
class CreateMetaInstance(BaseCommand):
|
||||
meta_instance: MetaInstance
|
||||
|
||||
|
||||
class DeleteMetaInstance(BaseCommand):
|
||||
meta_instance_id: MetaInstanceId
|
||||
|
||||
|
||||
class TaskFinished(BaseCommand):
|
||||
finished_command_id: CommandId
|
||||
|
||||
@@ -81,6 +90,14 @@ class CancelDownload(BaseCommand):
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
class DistributeModel(BaseCommand):
|
||||
"""Distribute model files from one node to others via MLX distributed."""
|
||||
|
||||
model_id: ModelId
|
||||
source_node_id: NodeId
|
||||
target_node_ids: list[NodeId]
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
|
||||
|
||||
|
||||
@@ -94,8 +111,11 @@ Command = (
|
||||
| CreateInstance
|
||||
| DeleteInstance
|
||||
| TaskCancelled
|
||||
| CreateMetaInstance
|
||||
| DeleteMetaInstance
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
| DistributeModel
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -42,6 +42,10 @@ class CommandId(Id):
|
||||
pass
|
||||
|
||||
|
||||
class MetaInstanceId(Id):
|
||||
"""Identifier for a MetaInstance."""
|
||||
|
||||
|
||||
class Host(CamelCaseModel):
|
||||
ip: str
|
||||
port: int
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import base64
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import final
|
||||
from typing import Annotated, final
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BeforeValidator, Field, PlainSerializer
|
||||
|
||||
from exo.shared.topology import Connection
|
||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
@@ -14,6 +17,28 @@ from exo.utils.info_gatherer.info_gatherer import GatheredInfo
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
|
||||
|
||||
|
||||
def _decode_base64_bytes(v: bytes | str) -> bytes:
|
||||
if isinstance(v, bytes):
|
||||
return v
|
||||
return base64.b64decode(v)
|
||||
|
||||
|
||||
def _encode_base64_bytes(v: bytes) -> str:
|
||||
return base64.b64encode(v).decode("ascii")
|
||||
|
||||
|
||||
Base64Bytes = Annotated[
|
||||
bytes,
|
||||
BeforeValidator(_decode_base64_bytes),
|
||||
PlainSerializer(_encode_base64_bytes, return_type=str),
|
||||
]
|
||||
"""bytes that serialize to/from base64 strings in JSON.
|
||||
|
||||
Needed because TaggedModel's wrap validator converts JSON→Python validation
|
||||
context, which breaks strict-mode bytes deserialization from JSON strings.
|
||||
"""
|
||||
|
||||
|
||||
class EventId(Id):
|
||||
"""
|
||||
Newtype around `ID`
|
||||
@@ -66,6 +91,30 @@ class InstanceCreated(BaseEvent):
|
||||
|
||||
class InstanceDeleted(BaseEvent):
|
||||
instance_id: InstanceId
|
||||
failure_error: str | None = None
|
||||
|
||||
|
||||
class MetaInstanceCreated(BaseEvent):
|
||||
meta_instance: MetaInstance
|
||||
|
||||
|
||||
class MetaInstanceDeleted(BaseEvent):
|
||||
meta_instance_id: MetaInstanceId
|
||||
|
||||
|
||||
@final
|
||||
class MetaInstancePlacementFailed(BaseEvent):
|
||||
meta_instance_id: MetaInstanceId
|
||||
reason: str
|
||||
|
||||
|
||||
@final
|
||||
class InstanceRetrying(BaseEvent):
|
||||
"""Runners failed but retry count is below the limit — restart runners, keep instance."""
|
||||
|
||||
instance_id: InstanceId
|
||||
meta_instance_id: MetaInstanceId
|
||||
failure_error: str
|
||||
|
||||
|
||||
class RunnerStatusUpdated(BaseEvent):
|
||||
@@ -132,6 +181,25 @@ class TracesMerged(BaseEvent):
|
||||
traces: list[TraceEventData]
|
||||
|
||||
|
||||
@final
|
||||
class JacclSideChannelData(BaseEvent):
|
||||
"""A runner's local contribution to a JACCL SideChannel all_gather round."""
|
||||
|
||||
instance_id: InstanceId
|
||||
runner_id: RunnerId
|
||||
sequence: int
|
||||
data: Base64Bytes
|
||||
|
||||
|
||||
@final
|
||||
class JacclSideChannelGathered(BaseEvent):
|
||||
"""Gathered result of a JACCL SideChannel all_gather round."""
|
||||
|
||||
instance_id: InstanceId
|
||||
sequence: int
|
||||
gathered_data: Mapping[RunnerId, Base64Bytes]
|
||||
|
||||
|
||||
Event = (
|
||||
TestEvent
|
||||
| TaskCreated
|
||||
@@ -141,6 +209,10 @@ Event = (
|
||||
| TaskAcknowledged
|
||||
| InstanceCreated
|
||||
| InstanceDeleted
|
||||
| InstanceRetrying
|
||||
| MetaInstanceCreated
|
||||
| MetaInstanceDeleted
|
||||
| MetaInstancePlacementFailed
|
||||
| RunnerStatusUpdated
|
||||
| RunnerDeleted
|
||||
| NodeTimedOut
|
||||
@@ -152,6 +224,8 @@ Event = (
|
||||
| TopologyEdgeDeleted
|
||||
| TracesCollected
|
||||
| TracesMerged
|
||||
| JacclSideChannelData
|
||||
| JacclSideChannelGathered
|
||||
)
|
||||
|
||||
|
||||
|
||||
25
src/exo/shared/types/meta_instance.py
Normal file
25
src/exo/shared/types/meta_instance.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import final
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.common import MetaInstanceId, NodeId
|
||||
from exo.shared.types.worker.instances import InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.pydantic_ext import FrozenModel
|
||||
|
||||
|
||||
@final
|
||||
class MetaInstance(FrozenModel):
|
||||
"""Declarative constraint: ensure an instance matching these parameters always exists."""
|
||||
|
||||
meta_instance_id: MetaInstanceId = Field(default_factory=MetaInstanceId)
|
||||
model_id: ModelId
|
||||
sharding: Sharding = Sharding.Pipeline
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing
|
||||
min_nodes: int = 1
|
||||
node_ids: list[NodeId] | None = None
|
||||
# Failure tracking
|
||||
placement_error: str | None = None
|
||||
consecutive_failures: int = 0
|
||||
last_failure_error: str | None = None
|
||||
@@ -6,7 +6,8 @@ from pydantic import ConfigDict, Field, field_serializer, field_validator
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
from exo.shared.topology import Topology, TopologySnapshot
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.common import MetaInstanceId, NodeId
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.profiling import (
|
||||
DiskUsage,
|
||||
MemoryUsage,
|
||||
@@ -41,6 +42,7 @@ class State(CamelCaseModel):
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
instances: Mapping[InstanceId, Instance] = {}
|
||||
meta_instances: Mapping[MetaInstanceId, MetaInstance] = {}
|
||||
runners: Mapping[RunnerId, RunnerStatus] = {}
|
||||
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
|
||||
tasks: Mapping[TaskId, Task] = {}
|
||||
|
||||
@@ -42,7 +42,7 @@ class DownloadModel(BaseTask): # emitted by Worker
|
||||
|
||||
|
||||
class LoadModel(BaseTask): # emitted by Worker
|
||||
pass
|
||||
has_local_model: bool = Field(default=True)
|
||||
|
||||
|
||||
class ConnectToGroup(BaseTask): # emitted by Worker
|
||||
@@ -61,7 +61,7 @@ class TextGeneration(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class CancelTask(BaseTask):
|
||||
class CancelTask(BaseTask): # emitted by Worker when master cancels a task
|
||||
cancelled_task_id: TaskId
|
||||
runner_id: RunnerId
|
||||
|
||||
@@ -82,6 +82,13 @@ class ImageEdits(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class TransferModelToDisk(BaseTask): # emitted by Worker
|
||||
"""Transfer all model files from source to receivers' disk via MLX distributed."""
|
||||
|
||||
shard_metadata: ShardMetadata
|
||||
has_local_model: bool = Field(default=True)
|
||||
|
||||
|
||||
class Shutdown(BaseTask): # emitted by Worker
|
||||
runner_id: RunnerId
|
||||
|
||||
@@ -91,6 +98,7 @@ Task = (
|
||||
| DownloadModel
|
||||
| ConnectToGroup
|
||||
| LoadModel
|
||||
| TransferModelToDisk
|
||||
| StartWarmup
|
||||
| TextGeneration
|
||||
| CancelTask
|
||||
|
||||
@@ -2,7 +2,7 @@ from enum import Enum
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from exo.shared.types.common import Host, Id, NodeId
|
||||
from exo.shared.types.common import Host, Id, MetaInstanceId, NodeId
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
@@ -19,6 +19,7 @@ class InstanceMeta(str, Enum):
|
||||
class BaseInstance(TaggedModel):
|
||||
instance_id: InstanceId
|
||||
shard_assignments: ShardAssignments
|
||||
meta_instance_id: MetaInstanceId | None = None
|
||||
|
||||
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
|
||||
return self.shard_assignments.runner_to_shard.get(runner_id, None)
|
||||
|
||||
@@ -84,6 +84,7 @@ class ShardAssignments(CamelCaseModel):
|
||||
model_id: ModelId
|
||||
runner_to_shard: Mapping[RunnerId, ShardMetadata]
|
||||
node_to_runner: Mapping[NodeId, RunnerId]
|
||||
transfer_only: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_runners_exist(self) -> "ShardAssignments":
|
||||
|
||||
@@ -125,9 +125,7 @@ class MpSender[T]:
|
||||
self._state.buffer.put(item, block=True)
|
||||
|
||||
async def send_async(self, item: T) -> None:
|
||||
await to_thread.run_sync(
|
||||
self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True
|
||||
)
|
||||
await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1))
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._state.closed.is_set():
|
||||
|
||||
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
|
||||
from mlx_lm.models.cache import Cache
|
||||
|
||||
TimeoutCallback = Callable[[], None]
|
||||
WeightLoader = Callable[[nn.Module, int], None] | None
|
||||
|
||||
|
||||
def eval_with_timeout(
|
||||
@@ -346,6 +347,7 @@ def tensor_auto_parallel(
|
||||
group: mx.distributed.Group,
|
||||
timeout_seconds: float = 60.0,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
all_to_sharded_linear = partial(
|
||||
shard_linear,
|
||||
@@ -455,7 +457,7 @@ def tensor_auto_parallel(
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
model = tensor_parallel_sharding_strategy.shard_model(
|
||||
model, timeout_seconds, on_timeout
|
||||
model, timeout_seconds, on_timeout, weight_loader
|
||||
)
|
||||
return patch_tensor_model(model)
|
||||
|
||||
@@ -482,6 +484,7 @@ class TensorParallelShardingStrategy(ABC):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module: ...
|
||||
|
||||
|
||||
@@ -491,9 +494,12 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(LlamaModel, model)
|
||||
for layer in model.layers:
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
# Force load weights before sharding to avoid FAST_SYNCH deadlock
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
@@ -545,9 +551,12 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
for layer in model.layers:
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -620,9 +629,12 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(GLM4MoeLiteModel, model)
|
||||
for layer in model.layers: # type: ignore
|
||||
for i, layer in enumerate(model.layers): # type: ignore
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
layer = cast(Glm4MoeLiteDecoderLayer, layer)
|
||||
eval_with_timeout(
|
||||
layer.parameters(),
|
||||
@@ -762,9 +774,12 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(MiniMaxModel, model)
|
||||
for layer in model.layers:
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -802,9 +817,12 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
|
||||
for layer in model.layers:
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -926,9 +944,12 @@ class Glm4MoeShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(Glm4MoeModel, model)
|
||||
for layer in model.layers:
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -972,10 +993,13 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(GptOssMoeModel, model)
|
||||
|
||||
for layer in model.layers:
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -1013,10 +1037,13 @@ class Step35ShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(Step35Model, model)
|
||||
|
||||
for layer in model.layers:
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
|
||||
507
src/exo/worker/engines/mlx/model_transfer.py
Normal file
507
src/exo/worker/engines/mlx/model_transfer.py
Normal file
@@ -0,0 +1,507 @@
|
||||
"""
|
||||
Model transfer via MLX distributed all_sum.
|
||||
|
||||
Three transfer modes:
|
||||
1. Metadata file transfer: broadcast small files (config.json, tokenizer, etc.) to disk
|
||||
2. Weight tensor broadcast: stream weight tensors directly into memory via all_sum
|
||||
3. Full file transfer: broadcast all files (including safetensors) to disk
|
||||
|
||||
All functions are collective operations — every rank in the group must call them.
|
||||
|
||||
Protocol relies on all_sum: source has real data, receivers have zeros.
|
||||
all_sum(source + zeros) = source data on all ranks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Final, cast
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
Group = mx.distributed.Group
|
||||
|
||||
CHUNK_SIZE: Final[int] = 100 * 1024 * 1024 # 100 MB
|
||||
_LAYER_RE: Final[re.Pattern[str]] = re.compile(r"(?:^|\.)(layers|h)\.(\d+)\.")
|
||||
|
||||
|
||||
def _all_sum_cpu(x: mx.array, group: Group) -> mx.array:
|
||||
"""all_sum on CPU stream to avoid GPU memory pressure."""
|
||||
return mx.distributed.all_sum(
|
||||
x, stream=mx.default_stream(mx.Device(mx.cpu)), group=group
|
||||
)
|
||||
|
||||
|
||||
def _is_metadata_file(filename: str) -> bool:
|
||||
"""A metadata file is anything that isn't a weight file or weight index.
|
||||
|
||||
Weight indices (.safetensors.index.json) reference safetensors shard paths.
|
||||
Transferring them to a receiver that has no safetensors files is harmless
|
||||
today (load_model's glob doesn't match them), but excluding them avoids
|
||||
stale references and keeps the transfer minimal.
|
||||
"""
|
||||
if filename.endswith(".safetensors"):
|
||||
return False
|
||||
return not filename.endswith(".safetensors.index.json")
|
||||
|
||||
|
||||
def model_path_for_id(model_id: ModelId) -> Path:
|
||||
"""Get model path without requiring directory to exist (unlike build_model_path)."""
|
||||
return EXO_MODELS_DIR / model_id.normalize()
|
||||
|
||||
|
||||
def coordinate_transfer(group: Group, has_local_model: bool) -> tuple[bool, int]:
|
||||
"""
|
||||
Determine if a transfer is needed and which rank is the source.
|
||||
|
||||
All ranks must call this function (uses collective all_sum).
|
||||
|
||||
Returns:
|
||||
(needs_transfer, source_rank) — source_rank is the lowest rank
|
||||
that has the model. needs_transfer is True if any rank is missing it.
|
||||
"""
|
||||
all_sum = partial(_all_sum_cpu, group=group)
|
||||
world_size = group.size()
|
||||
|
||||
# Each rank broadcasts a one-hot vector at its position if it has the model
|
||||
bitmask = mx.zeros(world_size, dtype=mx.int32)
|
||||
if has_local_model:
|
||||
bitmask = bitmask.at[group.rank()].add(1)
|
||||
summed = all_sum(bitmask)
|
||||
mx.eval(summed)
|
||||
|
||||
has_model_flags: list[int] = summed.tolist() # type: ignore[assignment]
|
||||
total_have = sum(has_model_flags)
|
||||
|
||||
if total_have == 0:
|
||||
raise RuntimeError(
|
||||
"No rank has the model files — cannot transfer. "
|
||||
"At least one node must have downloaded the model."
|
||||
)
|
||||
|
||||
if total_have == world_size:
|
||||
logger.info("All ranks have model files, no transfer needed")
|
||||
return False, 0
|
||||
|
||||
source_rank = next(i for i, flag in enumerate(has_model_flags) if flag > 0)
|
||||
logger.info(
|
||||
f"Transfer needed: source_rank={source_rank}, "
|
||||
f"{total_have}/{world_size} ranks have model"
|
||||
)
|
||||
return True, source_rank
|
||||
|
||||
|
||||
def _broadcast_json(obj: object, group: Group, is_source: bool) -> object:
|
||||
"""Broadcast a JSON-serializable object from source to all ranks."""
|
||||
all_sum = partial(_all_sum_cpu, group=group)
|
||||
|
||||
data = json.dumps(obj, separators=(",", ":")).encode("utf-8") if is_source else b""
|
||||
|
||||
# Broadcast length
|
||||
len_arr = mx.array([len(data) if is_source else 0], dtype=mx.int64)
|
||||
len_result = all_sum(len_arr)
|
||||
mx.eval(len_result)
|
||||
length = int(len_result.item())
|
||||
if length == 0:
|
||||
return None
|
||||
|
||||
# Broadcast payload
|
||||
if is_source:
|
||||
arr = mx.array(list(data), dtype=mx.uint8)
|
||||
else:
|
||||
arr = mx.zeros(length, dtype=mx.uint8)
|
||||
result = all_sum(arr)
|
||||
mx.eval(result)
|
||||
return json.loads(bytes(cast(list[int], result.tolist()))) # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
def _build_manifest(
|
||||
model_path: Path, metadata_only: bool = False
|
||||
) -> list[dict[str, str | int]]:
|
||||
"""Build a list of files in the model directory with their relative paths and sizes."""
|
||||
manifest: list[dict[str, str | int]] = []
|
||||
for root, _dirs, files in os.walk(model_path):
|
||||
for fname in sorted(files):
|
||||
if metadata_only and not _is_metadata_file(fname):
|
||||
continue
|
||||
full_path = Path(root) / fname
|
||||
rel_path = str(full_path.relative_to(model_path))
|
||||
manifest.append(
|
||||
{
|
||||
"path": rel_path,
|
||||
"size": full_path.stat().st_size,
|
||||
}
|
||||
)
|
||||
return manifest
|
||||
|
||||
|
||||
def _transfer_file_to_disk(
|
||||
source_path: Path,
|
||||
rel_path: str,
|
||||
file_size: int,
|
||||
group: Group,
|
||||
is_source: bool,
|
||||
dest_path: Path,
|
||||
) -> None:
|
||||
"""Transfer a single file chunk-by-chunk via all_sum. Source reads from disk, receivers write to dest_path."""
|
||||
all_sum = partial(_all_sum_cpu, group=group)
|
||||
|
||||
if is_source:
|
||||
src_file = source_path / rel_path
|
||||
with open(src_file, "rb") as f:
|
||||
offset = 0
|
||||
while offset < file_size:
|
||||
chunk_bytes = min(CHUNK_SIZE, file_size - offset)
|
||||
data = f.read(chunk_bytes)
|
||||
if not data:
|
||||
break
|
||||
size_arr = mx.array([len(data)], dtype=mx.int64)
|
||||
mx.eval(all_sum(size_arr))
|
||||
chunk_arr = mx.array(list(data), dtype=mx.uint8)
|
||||
result = all_sum(chunk_arr)
|
||||
mx.eval(result)
|
||||
offset += len(data)
|
||||
# Signal end of file
|
||||
mx.eval(all_sum(mx.array([0], dtype=mx.int64)))
|
||||
else:
|
||||
dst_file = dest_path / rel_path
|
||||
os.makedirs(dst_file.parent, exist_ok=True)
|
||||
with open(dst_file, "wb") as f:
|
||||
while True:
|
||||
size_arr = all_sum(mx.zeros(1, dtype=mx.int64))
|
||||
mx.eval(size_arr)
|
||||
chunk_size = int(size_arr.item())
|
||||
if chunk_size == 0:
|
||||
break
|
||||
chunk_data = all_sum(mx.zeros(chunk_size, dtype=mx.uint8))
|
||||
mx.eval(chunk_data)
|
||||
f.write(bytes(cast(list[int], chunk_data.tolist())))
|
||||
|
||||
|
||||
def _transfer_files_to_disk(
|
||||
model_path: Path,
|
||||
group: Group,
|
||||
is_source: bool,
|
||||
metadata_only: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Transfer files from source to all receivers' disk.
|
||||
|
||||
Source broadcasts a manifest then each file. Receivers write to a temp dir
|
||||
then atomically move files to model_path.
|
||||
"""
|
||||
if is_source:
|
||||
source_manifest = _build_manifest(model_path, metadata_only=metadata_only)
|
||||
else:
|
||||
source_manifest = []
|
||||
manifest = cast(
|
||||
list[dict[str, str | int]],
|
||||
_broadcast_json(source_manifest if is_source else None, group, is_source),
|
||||
)
|
||||
|
||||
if not manifest:
|
||||
logger.info("No files to transfer")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Transferring {len(manifest)} files ({'metadata only' if metadata_only else 'all'})"
|
||||
)
|
||||
|
||||
temp_dir: Path | None = None
|
||||
if not is_source:
|
||||
os.makedirs(model_path.parent, exist_ok=True)
|
||||
temp_dir = Path(
|
||||
tempfile.mkdtemp(
|
||||
dir=model_path.parent,
|
||||
prefix=f".transfer_{model_path.name}_",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
for entry in manifest:
|
||||
rel_path = str(entry["path"])
|
||||
file_size = int(entry["size"])
|
||||
logger.info(f" {rel_path} ({file_size} bytes)")
|
||||
_transfer_file_to_disk(
|
||||
source_path=model_path,
|
||||
rel_path=rel_path,
|
||||
file_size=file_size,
|
||||
group=group,
|
||||
is_source=is_source,
|
||||
dest_path=temp_dir if temp_dir is not None else model_path,
|
||||
)
|
||||
|
||||
if temp_dir is not None:
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
for entry in manifest:
|
||||
rel_path = str(entry["path"])
|
||||
src = temp_dir / rel_path
|
||||
dst = model_path / rel_path
|
||||
os.makedirs(dst.parent, exist_ok=True)
|
||||
os.replace(src, dst)
|
||||
logger.info(
|
||||
f"Transfer complete: {len(manifest)} files moved to {model_path}"
|
||||
)
|
||||
finally:
|
||||
if temp_dir is not None and temp_dir.exists():
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def transfer_metadata_files(model_path: Path, group: Group, is_source: bool) -> None:
|
||||
"""
|
||||
Transfer metadata files (config.json, tokenizer files, etc.) to receivers' disk.
|
||||
|
||||
All ranks must call this function (collective operation).
|
||||
Only the designated source (is_source=True) should send; all others receive.
|
||||
"""
|
||||
_transfer_files_to_disk(model_path, group, is_source=is_source, metadata_only=True)
|
||||
|
||||
|
||||
def transfer_all_files(model_path: Path, group: Group, is_source: bool) -> None:
|
||||
"""
|
||||
Transfer ALL model files (including safetensors) to receivers' disk.
|
||||
|
||||
All ranks must call this function (collective operation).
|
||||
Only the designated source (is_source=True) should send; all others receive.
|
||||
"""
|
||||
_transfer_files_to_disk(model_path, group, is_source=is_source, metadata_only=False)
|
||||
|
||||
|
||||
def _parse_mx_dtype(dtype_str: str) -> mx.Dtype:
|
||||
"""Convert a dtype string like 'float16' or 'mlx.core.float16' to mx.Dtype."""
|
||||
name = dtype_str.split(".")[-1]
|
||||
dtype = getattr(mx, name, None)
|
||||
if dtype is None:
|
||||
raise ValueError(f"Unknown MLX dtype: {dtype_str}")
|
||||
return dtype # type: ignore[return-value]
|
||||
|
||||
|
||||
def _extract_layer_index(name: str) -> int | None:
|
||||
"""Extract layer index from a weight name, or None for non-layer weights.
|
||||
|
||||
Matches patterns like ``model.layers.5.self_attn.q_proj.weight``
|
||||
or ``transformer.h.12.mlp.gate_proj.scales``.
|
||||
"""
|
||||
m = _LAYER_RE.search(name)
|
||||
return int(m.group(2)) if m else None
|
||||
|
||||
|
||||
class WeightBroadcastState:
|
||||
"""Holds state for layer-by-layer weight broadcasting.
|
||||
|
||||
Created by :func:`prepare_weight_broadcast`. Callers stream weights
|
||||
incrementally via :meth:`broadcast_non_layer_weights` and
|
||||
:meth:`broadcast_layer` so that at most one layer's worth of un-sharded
|
||||
weight data is resident at a time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
meta: dict[str, dict[str, Any]],
|
||||
source_weights: dict[str, mx.array] | None,
|
||||
group: Group,
|
||||
is_source: bool,
|
||||
) -> None:
|
||||
self.meta = meta
|
||||
self.source_weights = source_weights
|
||||
self.group = group
|
||||
self.is_source = is_source
|
||||
|
||||
# Partition weight names into layer vs. non-layer
|
||||
self.layer_names: dict[int, list[str]] = {}
|
||||
self.non_layer_names: list[str] = []
|
||||
for name in sorted(meta.keys()):
|
||||
layer_idx = _extract_layer_index(name)
|
||||
if layer_idx is not None:
|
||||
self.layer_names.setdefault(layer_idx, []).append(name)
|
||||
else:
|
||||
self.non_layer_names.append(name)
|
||||
|
||||
logger.info(
|
||||
f"WeightBroadcastState: {len(self.non_layer_names)} non-layer weights, "
|
||||
f"{len(self.layer_names)} layers"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _broadcast_names(self, names: list[str]) -> dict[str, mx.array]:
|
||||
"""Broadcast a specific set of weight tensors by name."""
|
||||
all_sum = partial(_all_sum_cpu, group=self.group)
|
||||
result: dict[str, mx.array] = {}
|
||||
for name in names:
|
||||
info = self.meta[name]
|
||||
shape = cast(list[int], info["s"])
|
||||
dtype = _parse_mx_dtype(cast(str, info["d"]))
|
||||
|
||||
if self.is_source:
|
||||
assert self.source_weights is not None
|
||||
tensor = self.source_weights.pop(name)
|
||||
mx.eval(tensor) # loads from disk (lazy)
|
||||
else:
|
||||
tensor = mx.zeros(shape, dtype=dtype)
|
||||
|
||||
broadcasted = all_sum(tensor)
|
||||
mx.eval(broadcasted)
|
||||
result[name] = broadcasted
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def broadcast_non_layer_weights(self) -> dict[str, mx.array]:
|
||||
"""Broadcast non-layer weights (embeddings, norms, lm_head)."""
|
||||
if not self.non_layer_names:
|
||||
return {}
|
||||
logger.info(
|
||||
f"Broadcasting {len(self.non_layer_names)} non-layer weight tensors"
|
||||
)
|
||||
return self._broadcast_names(self.non_layer_names)
|
||||
|
||||
def broadcast_layer(self, layer_idx: int) -> dict[str, mx.array]:
|
||||
"""Broadcast weights for a single transformer layer."""
|
||||
names = self.layer_names.get(layer_idx, [])
|
||||
if not names:
|
||||
return {}
|
||||
return self._broadcast_names(names)
|
||||
|
||||
|
||||
def prepare_weight_broadcast(
|
||||
model_path: Path,
|
||||
group: Group,
|
||||
is_source: bool,
|
||||
) -> WeightBroadcastState:
|
||||
"""Prepare for layer-by-layer weight broadcasting.
|
||||
|
||||
Source loads safetensors lazily and broadcasts weight metadata (names,
|
||||
shapes, dtypes) as JSON. Returns a :class:`WeightBroadcastState` that
|
||||
can then stream weights incrementally via ``broadcast_layer()``.
|
||||
|
||||
All ranks must call this function (collective operation).
|
||||
"""
|
||||
source_weights: dict[str, mx.array] | None = None
|
||||
if is_source:
|
||||
source_weights = {}
|
||||
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||
if not weight_files:
|
||||
weight_files = sorted(model_path.glob("**/*.safetensors"))
|
||||
for wf in weight_files:
|
||||
try:
|
||||
loaded = cast(
|
||||
dict[str, mx.array],
|
||||
mx.load(str(wf), lazy=True), # pyright: ignore[reportCallIssue]
|
||||
)
|
||||
except TypeError:
|
||||
loaded = cast(dict[str, mx.array], mx.load(str(wf)))
|
||||
source_weights.update(loaded)
|
||||
logger.info(
|
||||
f"Source loaded {len(source_weights)} weight tensors (lazy) "
|
||||
f"from {len(weight_files)} files"
|
||||
)
|
||||
|
||||
# Broadcast metadata
|
||||
if is_source and source_weights is not None:
|
||||
source_meta: dict[str, dict[str, Any]] = {
|
||||
name: {"s": list(tensor.shape), "d": str(tensor.dtype)}
|
||||
for name, tensor in source_weights.items()
|
||||
}
|
||||
else:
|
||||
source_meta = {}
|
||||
|
||||
meta = cast(
|
||||
dict[str, dict[str, Any]],
|
||||
_broadcast_json(source_meta if is_source else None, group, is_source),
|
||||
)
|
||||
|
||||
logger.info(f"Weight broadcast prepared: {len(meta)} tensors")
|
||||
return WeightBroadcastState(meta, source_weights, group, is_source)
|
||||
|
||||
|
||||
def broadcast_model_weights(
|
||||
model_path: Path,
|
||||
group: Group,
|
||||
is_source: bool,
|
||||
) -> dict[str, mx.array]:
|
||||
"""
|
||||
Broadcast model weight tensors from source rank to all receivers' memory.
|
||||
|
||||
Source loads weights from .safetensors files on disk and broadcasts each
|
||||
tensor via all_sum. Receivers receive tensors directly as mx.arrays in
|
||||
memory — no disk write for weight data.
|
||||
|
||||
All ranks must call this function (collective operation).
|
||||
Only the designated source (is_source=True) should send; all others receive.
|
||||
|
||||
Returns:
|
||||
dict mapping weight names to mx.arrays (on all ranks).
|
||||
"""
|
||||
all_sum = partial(_all_sum_cpu, group=group)
|
||||
|
||||
# Source loads weights (lazy if supported, so only one tensor in memory at a time)
|
||||
weights: dict[str, mx.array] = {}
|
||||
if is_source:
|
||||
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||
if not weight_files:
|
||||
weight_files = sorted(model_path.glob("**/*.safetensors"))
|
||||
for wf in weight_files:
|
||||
try:
|
||||
loaded = cast(dict[str, mx.array], mx.load(str(wf), lazy=True)) # pyright: ignore[reportCallIssue]
|
||||
except TypeError:
|
||||
loaded = cast(dict[str, mx.array], mx.load(str(wf)))
|
||||
weights.update(loaded)
|
||||
logger.info(
|
||||
f"Source loaded {len(weights)} weight tensors from {len(weight_files)} files"
|
||||
)
|
||||
|
||||
# Broadcast weight metadata: {name: {shape, dtype}}
|
||||
if is_source:
|
||||
source_meta: dict[str, dict[str, Any]] = {
|
||||
name: {"s": list(tensor.shape), "d": str(tensor.dtype)}
|
||||
for name, tensor in weights.items()
|
||||
}
|
||||
else:
|
||||
source_meta = {}
|
||||
meta = cast(
|
||||
dict[str, dict[str, Any]],
|
||||
_broadcast_json(source_meta if is_source else None, group, is_source),
|
||||
)
|
||||
|
||||
logger.info(f"Broadcasting {len(meta)} weight tensors")
|
||||
|
||||
# Broadcast each tensor in sorted order (deterministic across ranks).
|
||||
# Source loads one tensor at a time from disk (lazy), broadcasts it,
|
||||
# then drops the reference so only one tensor is in flight at a time.
|
||||
result: dict[str, mx.array] = {}
|
||||
for i, name in enumerate(sorted(meta.keys())):
|
||||
info = meta[name]
|
||||
shape = cast(list[int], info["s"])
|
||||
dtype_str = cast(str, info["d"])
|
||||
dtype = _parse_mx_dtype(dtype_str)
|
||||
|
||||
if is_source:
|
||||
tensor = weights.pop(name) # pop to free lazy ref after broadcast
|
||||
mx.eval(tensor) # loads from disk
|
||||
else:
|
||||
tensor = mx.zeros(shape, dtype=dtype)
|
||||
|
||||
broadcasted = all_sum(tensor)
|
||||
mx.eval(broadcasted)
|
||||
result[name] = broadcasted
|
||||
|
||||
if (i + 1) % 100 == 0:
|
||||
logger.info(f" Broadcast {i + 1}/{len(meta)} tensors")
|
||||
|
||||
logger.info(f"Weight broadcast complete: {len(result)} tensors")
|
||||
return result
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -59,6 +60,13 @@ from exo.worker.engines.mlx.auto_parallel import (
|
||||
pipeline_auto_parallel,
|
||||
tensor_auto_parallel,
|
||||
)
|
||||
from exo.worker.engines.mlx.model_transfer import (
|
||||
WeightBroadcastState,
|
||||
coordinate_transfer,
|
||||
model_path_for_id,
|
||||
prepare_weight_broadcast,
|
||||
transfer_metadata_files,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
Group = mx.distributed.Group
|
||||
@@ -171,6 +179,7 @@ def load_mlx_items(
|
||||
bound_instance: BoundInstance,
|
||||
group: Group | None,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
has_local_model: bool = True,
|
||||
) -> tuple[Model, TokenizerWrapper]:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
@@ -185,7 +194,10 @@ def load_mlx_items(
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(
|
||||
bound_instance.bound_shard, group=group, on_timeout=on_timeout
|
||||
bound_instance.bound_shard,
|
||||
group=group,
|
||||
on_timeout=on_timeout,
|
||||
has_local_model=has_local_model,
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
@@ -201,30 +213,89 @@ def shard_and_load(
|
||||
shard_metadata: ShardMetadata,
|
||||
group: Group,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
has_local_model: bool = True,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
model_path = build_model_path(shard_metadata.model_card.model_id)
|
||||
model_id = shard_metadata.model_card.model_id
|
||||
model_path = model_path_for_id(model_id)
|
||||
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
# Coordinate: does any rank need a transfer?
|
||||
needs_transfer, source_rank = coordinate_transfer(group, has_local_model)
|
||||
is_source = group.rank() == source_rank
|
||||
|
||||
# Step 1: Always ensure all nodes have metadata files (config, tokenizer, etc.).
|
||||
# This is cheap (~20MB, ~1s) and guarantees config.json is present for load_model().
|
||||
transfer_metadata_files(model_path, group, is_source)
|
||||
|
||||
# Step 2: Only broadcast weights if some rank is missing the model
|
||||
broadcast_state: WeightBroadcastState | None = None
|
||||
if needs_transfer:
|
||||
logger.info(
|
||||
f"Model transfer needed (source_rank={source_rank}, "
|
||||
f"is_source={is_source}, local_weights={has_local_model})"
|
||||
)
|
||||
broadcast_state = prepare_weight_broadcast(model_path, group, is_source)
|
||||
|
||||
# Create model architecture (all ranks have config.json on disk now).
|
||||
# Always use lazy=True when we have broadcast state: load_model's internal
|
||||
# nn.quantize skips quantization when weights dict is empty (no safetensors),
|
||||
# leaving the model un-quantized. lazy=False would then mx.eval() the full
|
||||
# fp16 model (~72GB for a 36B-param model), causing OOM on the receiver.
|
||||
# We handle quantization ourselves below before loading broadcast weights.
|
||||
use_lazy = has_local_model or broadcast_state is not None
|
||||
model, _ = load_model(model_path, lazy=use_lazy, strict=False)
|
||||
logger.debug(model)
|
||||
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
|
||||
pass
|
||||
# TODO: See if we should quantize the model.
|
||||
# def is_attention_layer(path: str) -> bool:
|
||||
# path = path.lower()
|
||||
|
||||
# return "self_attn" in path and "layernorm" not in path
|
||||
|
||||
# def quant_predicate(path: str, module: nn.Module):
|
||||
# if not isinstance(module, nn.Linear):
|
||||
# return False
|
||||
|
||||
# return is_attention_layer(path)
|
||||
# model, config = quantize_model(
|
||||
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
|
||||
# )
|
||||
|
||||
assert isinstance(model, nn.Module)
|
||||
|
||||
if broadcast_state is not None:
|
||||
# When receiver has no weight files, load_model skips quantization
|
||||
# (its class_predicate checks `f"{p}.scales" in weights`, which is
|
||||
# always False when weights is empty). Apply quantization explicitly
|
||||
# using the broadcast metadata to determine which layers are quantized,
|
||||
# matching load_model's selective quantization logic exactly.
|
||||
if not has_local_model:
|
||||
config_path = model_path / "config.json"
|
||||
with open(config_path) as f:
|
||||
config = json.load(f) # pyright: ignore[reportAny]
|
||||
quant_config: dict[str, Any] | None = config.get( # pyright: ignore[reportAny]
|
||||
"quantization", None
|
||||
)
|
||||
if quant_config is not None:
|
||||
logger.info(f"Applying quantization to receiver model: {quant_config}")
|
||||
broadcast_weight_names = set(broadcast_state.meta.keys())
|
||||
|
||||
def _class_predicate(p: str, m: nn.Module) -> bool | dict[str, Any]:
|
||||
# Per-layer overrides from config (e.g. "lm_head": false)
|
||||
assert quant_config is not None
|
||||
if p in quant_config:
|
||||
return quant_config[p] # pyright: ignore[reportAny]
|
||||
if not hasattr(m, "to_quantized"):
|
||||
return False
|
||||
# Only quantize layers whose .scales exist in broadcast weights
|
||||
return f"{p}.scales" in broadcast_weight_names
|
||||
|
||||
group_size = int(quant_config.get("group_size", 64)) # pyright: ignore[reportAny]
|
||||
bits = int(quant_config.get("bits", 4)) # pyright: ignore[reportAny]
|
||||
mode: str = quant_config.get("mode", "affine") # pyright: ignore[reportAny]
|
||||
nn.quantize( # pyright: ignore[reportUnknownMemberType]
|
||||
model,
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
mode=mode,
|
||||
class_predicate=_class_predicate,
|
||||
)
|
||||
|
||||
# Broadcast and load non-layer weights (embeddings, norms, lm_head) upfront.
|
||||
# These are small (~600MB) and needed before the sharding loop.
|
||||
non_layer_weights = broadcast_state.broadcast_non_layer_weights()
|
||||
if non_layer_weights:
|
||||
model.load_weights(list(non_layer_weights.items()), strict=False)
|
||||
logger.info(f"Loaded {len(non_layer_weights)} non-layer weight tensors")
|
||||
del non_layer_weights
|
||||
|
||||
tokenizer = get_tokenizer(model_path, shard_metadata)
|
||||
|
||||
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
|
||||
@@ -238,12 +309,43 @@ def shard_and_load(
|
||||
f"(model size: {model_size_gb:.1f}GB)"
|
||||
)
|
||||
|
||||
# Build per-layer weight loader for streaming broadcast during sharding.
|
||||
# Each layer's weights are broadcast via all_sum just before that layer is
|
||||
# sharded, so at most one un-sharded layer is in memory at a time.
|
||||
weight_loader_fn: Callable[[nn.Module, int], None] | None = None
|
||||
if broadcast_state is not None:
|
||||
_state = broadcast_state # capture for closure
|
||||
|
||||
def _load_layer_weights(mdl: nn.Module, layer_idx: int) -> None:
|
||||
layer_weights = _state.broadcast_layer(layer_idx)
|
||||
if layer_weights:
|
||||
mdl.load_weights(list(layer_weights.items()), strict=False)
|
||||
|
||||
weight_loader_fn = _load_layer_weights
|
||||
|
||||
match shard_metadata:
|
||||
case TensorShardMetadata():
|
||||
logger.info(f"loading model from {model_path} with tensor parallelism")
|
||||
model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
|
||||
model = tensor_auto_parallel(
|
||||
model, group, timeout_seconds, on_timeout, weight_loader_fn
|
||||
)
|
||||
case PipelineShardMetadata():
|
||||
logger.info(f"loading model from {model_path} with pipeline parallelism")
|
||||
# Broadcast all layers (all_sum is collective — all ranks must
|
||||
# participate) but only load weights for layers this node will
|
||||
# keep after pipeline slicing. Out-of-range results are discarded,
|
||||
# keeping peak memory proportional to this node's layer count.
|
||||
if broadcast_state is not None:
|
||||
for layer_idx in sorted(broadcast_state.layer_names.keys()):
|
||||
layer_weights = broadcast_state.broadcast_layer(layer_idx)
|
||||
if (
|
||||
shard_metadata.start_layer
|
||||
<= layer_idx
|
||||
< shard_metadata.end_layer
|
||||
and layer_weights
|
||||
):
|
||||
model.load_weights(list(layer_weights.items()), strict=False)
|
||||
del layer_weights
|
||||
model = pipeline_auto_parallel(model, group, shard_metadata)
|
||||
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
|
||||
case CfgShardMetadata():
|
||||
@@ -252,6 +354,8 @@ def shard_and_load(
|
||||
"this metadata type is only for image generation models"
|
||||
)
|
||||
|
||||
del broadcast_state
|
||||
|
||||
# TODO: Do we need this?
|
||||
mx.eval(model)
|
||||
|
||||
@@ -269,52 +373,19 @@ def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerW
|
||||
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
|
||||
|
||||
|
||||
def _read_model_type_from_config(model_path: Path) -> str | None:
|
||||
"""Read the model_type field from config.json at the given model path.
|
||||
|
||||
Returns None if config.json doesn't exist or doesn't contain model_type.
|
||||
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
"""
|
||||
config_path = model_path / "config.json"
|
||||
if not config_path.exists():
|
||||
return None
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
config: dict[str, Any] = json.load(f) # pyright: ignore[reportAny]
|
||||
model_type: Any = config.get("model_type")
|
||||
if model_type is None:
|
||||
text_config: Any = config.get("text_config")
|
||||
if isinstance(text_config, dict):
|
||||
model_type = text_config.get("model_type") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
return model_type if isinstance(model_type, str) else None
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return None
|
||||
Get the EOS token IDs for a model based on its ID.
|
||||
|
||||
|
||||
def get_eos_token_ids_for_model(
|
||||
model_id: ModelId, model_type: str | None = None
|
||||
) -> list[int] | None:
|
||||
"""Get the EOS token IDs for a model based on its architecture type.
|
||||
|
||||
Uses model_type from config.json when available, falls back to model_id
|
||||
string matching for backward compatibility.
|
||||
Some models require explicit EOS token configuration that isn't in their
|
||||
tokenizer config. This function returns the known EOS token IDs for such models.
|
||||
|
||||
Args:
|
||||
model_id: The HuggingFace model ID
|
||||
model_type: The model_type field from config.json (e.g., "kimi", "glm4")
|
||||
|
||||
Returns:
|
||||
List of EOS token IDs, or None if the model uses standard tokenizer config
|
||||
"""
|
||||
if model_type is not None:
|
||||
if model_type == "kimi":
|
||||
return [163586]
|
||||
elif model_type == "glm4_moe_lite":
|
||||
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
|
||||
return [154820, 154827, 154829]
|
||||
elif model_type.startswith("glm"):
|
||||
return [151336, 151329, 151338]
|
||||
|
||||
# Fallback: string matching on model_id
|
||||
model_id_lower = model_id.lower()
|
||||
if "kimi-k2" in model_id_lower:
|
||||
return [163586]
|
||||
@@ -329,10 +400,11 @@ def get_eos_token_ids_for_model(
|
||||
def load_tokenizer_for_model_id(
|
||||
model_id: ModelId, model_path: Path
|
||||
) -> TokenizerWrapper:
|
||||
"""Load tokenizer for a model given its ID and local path.
|
||||
"""
|
||||
Load tokenizer for a model given its ID and local path.
|
||||
|
||||
Uses model_type from config.json for architecture detection when available,
|
||||
falling back to model_id string matching for backward compatibility.
|
||||
This is the core tokenizer loading logic, handling special cases for different
|
||||
model families (Kimi, GLM, etc.) and transformers 5.x compatibility.
|
||||
|
||||
Args:
|
||||
model_id: The HuggingFace model ID (e.g., "moonshotai/Kimi-K2-Instruct")
|
||||
@@ -341,21 +413,11 @@ def load_tokenizer_for_model_id(
|
||||
Returns:
|
||||
TokenizerWrapper instance configured for the model
|
||||
"""
|
||||
model_type = _read_model_type_from_config(model_path)
|
||||
model_id_lower = model_id.lower()
|
||||
eos_token_ids = get_eos_token_ids_for_model(model_id, model_type=model_type)
|
||||
|
||||
is_kimi = (
|
||||
model_type == "kimi" if model_type is not None else "kimi-k2" in model_id_lower
|
||||
)
|
||||
is_gemma3 = (
|
||||
model_type == "gemma3"
|
||||
if model_type is not None
|
||||
else "gemma-3" in model_id_lower
|
||||
)
|
||||
eos_token_ids = get_eos_token_ids_for_model(model_id)
|
||||
|
||||
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
|
||||
if is_kimi:
|
||||
if "kimi-k2" in model_id_lower:
|
||||
import importlib.util
|
||||
import types
|
||||
|
||||
@@ -409,7 +471,7 @@ def load_tokenizer_for_model_id(
|
||||
eos_token_ids=eos_token_ids,
|
||||
)
|
||||
|
||||
if is_gemma3:
|
||||
if "gemma-3" in model_id_lower:
|
||||
gemma_3_eos_id = 1
|
||||
gemma_3_end_of_turn_id = 106
|
||||
if tokenizer.eos_token_ids is not None:
|
||||
@@ -616,6 +678,11 @@ def mlx_cleanup(
|
||||
|
||||
|
||||
def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||
"""Synchronize a boolean across all distributed nodes.
|
||||
|
||||
Returns True if any node has bool_=True. Uses all_sum so every
|
||||
node participates in the collective — preventing GPU deadlocks.
|
||||
"""
|
||||
if group is None:
|
||||
return bool_
|
||||
num_true = mx.distributed.all_sum(
|
||||
|
||||
@@ -24,6 +24,7 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
JacclSideChannelGathered,
|
||||
NodeGatheredInfo,
|
||||
TaskCreated,
|
||||
TaskStatusUpdated,
|
||||
@@ -33,7 +34,6 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
@@ -159,6 +159,15 @@ class Worker:
|
||||
for idx, event in indexed_events:
|
||||
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
|
||||
|
||||
# Dispatch JACCL gathered events to the relevant RunnerSupervisor
|
||||
if isinstance(event, JacclSideChannelGathered):
|
||||
for runner in self.runners.values():
|
||||
if (
|
||||
runner.bound_instance.instance.instance_id
|
||||
== event.instance_id
|
||||
):
|
||||
runner.notify_gathered(event)
|
||||
|
||||
# Buffer input image chunks for image editing
|
||||
if isinstance(event, InputChunkReceived):
|
||||
cmd_id = event.command_id
|
||||
@@ -225,22 +234,15 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
runner = self.runners.pop(runner_id)
|
||||
try:
|
||||
with fail_after(3):
|
||||
await runner.start_task(task)
|
||||
await self.runners.pop(runner_id).start_task(task)
|
||||
except TimeoutError:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
finally:
|
||||
runner.shutdown()
|
||||
case CancelTask(
|
||||
cancelled_task_id=cancelled_task_id, runner_id=runner_id
|
||||
):
|
||||
await self.runners[runner_id].cancel_task(cancelled_task_id)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
@@ -278,18 +280,18 @@ class Worker:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self._start_runner_task(modified_task)
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
case task:
|
||||
await self._start_runner_task(task)
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
|
||||
def shutdown(self):
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _start_runner_task(self, task: Task):
|
||||
if (instance := self.state.instances.get(task.instance_id)) is not None:
|
||||
await self.runners[
|
||||
instance.shard_assignments.node_to_runner[self.node_id]
|
||||
].start_task(task)
|
||||
def _task_to_runner_id(self, task: Task):
|
||||
instance = self.state.instances[task.instance_id]
|
||||
return instance.shard_assignments.node_to_runner[self.node_id]
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
@@ -17,6 +18,7 @@ from exo.shared.types.tasks import (
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
TextGeneration,
|
||||
TransferModelToDisk,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
@@ -35,8 +37,11 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerLoading,
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerShuttingDown,
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
ShardAssignments,
|
||||
)
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
@@ -56,9 +61,10 @@ def plan(
|
||||
return (
|
||||
_cancel_tasks(runners, tasks)
|
||||
or _kill_runner(runners, all_runners, instances)
|
||||
or _create_runner(node_id, runners, instances)
|
||||
or _create_runner(node_id, runners, instances, all_runners)
|
||||
or _model_needs_download(node_id, runners, global_download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _transfer_model_to_disk(runners, all_runners, global_download_status)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
|
||||
@@ -75,6 +81,12 @@ def _kill_runner(
|
||||
if (instance_id := runner.bound_instance.instance.instance_id) not in instances:
|
||||
return Shutdown(instance_id=instance_id, runner_id=runner_id)
|
||||
|
||||
# Master removed our runner from state (retry signal) and process is dead
|
||||
if runner_id not in all_runners and isinstance(
|
||||
runner.status, (RunnerFailed, RunnerShutdown)
|
||||
):
|
||||
return Shutdown(instance_id=instance_id, runner_id=runner_id)
|
||||
|
||||
for (
|
||||
global_runner_id
|
||||
) in runner.bound_instance.instance.shard_assignments.node_to_runner.values():
|
||||
@@ -92,6 +104,7 @@ def _create_runner(
|
||||
node_id: NodeId,
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
) -> CreateRunner | None:
|
||||
for instance in instances.values():
|
||||
runner_id = instance.shard_assignments.node_to_runner.get(node_id, None)
|
||||
@@ -101,6 +114,16 @@ def _create_runner(
|
||||
if runner_id in runners:
|
||||
continue
|
||||
|
||||
# Don't create while any peer runner is in a terminal state — wait for
|
||||
# the master to emit InstanceRetrying which removes them from state.
|
||||
has_terminal_peer = any(
|
||||
isinstance(all_runners.get(peer_rid), (RunnerFailed, RunnerShutdown))
|
||||
for peer_rid in instance.shard_assignments.node_to_runner.values()
|
||||
if peer_rid != runner_id
|
||||
)
|
||||
if has_terminal_peer:
|
||||
continue
|
||||
|
||||
shard = instance.shard(runner_id)
|
||||
assert shard is not None
|
||||
|
||||
@@ -123,6 +146,10 @@ def _model_needs_download(
|
||||
}
|
||||
|
||||
for runner in runners.values():
|
||||
# Transfer-only instances don't need downloads
|
||||
if runner.bound_instance.instance.shard_assignments.transfer_only:
|
||||
continue
|
||||
|
||||
model_id = runner.bound_instance.bound_shard.model_card.model_id
|
||||
if isinstance(runner.status, RunnerIdle) and (
|
||||
model_id not in download_status
|
||||
@@ -131,6 +158,15 @@ def _model_needs_download(
|
||||
(DownloadOngoing, DownloadCompleted, DownloadFailed),
|
||||
)
|
||||
):
|
||||
# For multi-node instances, skip download if a peer already has the model.
|
||||
# The model will be transferred via MLX distributed during LoadModel.
|
||||
instance = runner.bound_instance.instance
|
||||
is_multi_node = len(instance.shard_assignments.node_to_runner) > 1
|
||||
if is_multi_node and _any_peer_has_model(
|
||||
node_id, model_id, instance, global_download_status
|
||||
):
|
||||
continue
|
||||
|
||||
# We don't invalidate download_status randomly in case a file gets deleted on disk
|
||||
return DownloadModel(
|
||||
instance_id=runner.bound_instance.instance.instance_id,
|
||||
@@ -188,6 +224,43 @@ def _init_distributed_backend(
|
||||
return None
|
||||
|
||||
|
||||
def _transfer_model_to_disk(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> TransferModelToDisk | None:
|
||||
"""For transfer-only instances: after all ranks are connected, emit TransferModelToDisk."""
|
||||
for runner in runners.values():
|
||||
instance = runner.bound_instance.instance
|
||||
shard_assignments = instance.shard_assignments
|
||||
|
||||
if not shard_assignments.transfer_only:
|
||||
continue
|
||||
|
||||
is_runner_connected = isinstance(runner.status, RunnerConnected)
|
||||
all_connected_or_further = all(
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id, None),
|
||||
(RunnerConnected, RunnerLoading, RunnerShuttingDown, RunnerShutdown),
|
||||
)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
)
|
||||
|
||||
if is_runner_connected and all_connected_or_further:
|
||||
has_local = _node_has_download(
|
||||
runner.bound_instance.bound_node_id,
|
||||
shard_assignments.model_id,
|
||||
global_download_status,
|
||||
)
|
||||
return TransferModelToDisk(
|
||||
instance_id=instance.instance_id,
|
||||
shard_metadata=runner.bound_instance.bound_shard,
|
||||
has_local_model=has_local,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _load_model(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
@@ -197,38 +270,97 @@ def _load_model(
|
||||
instance = runner.bound_instance.instance
|
||||
shard_assignments = instance.shard_assignments
|
||||
|
||||
all_local_downloads_complete = all(
|
||||
nid in global_download_status
|
||||
and any(
|
||||
isinstance(dp, DownloadCompleted)
|
||||
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
|
||||
for dp in global_download_status[nid]
|
||||
)
|
||||
for nid in shard_assignments.node_to_runner
|
||||
)
|
||||
if not all_local_downloads_complete:
|
||||
# Transfer-only instances don't load models for inference
|
||||
if shard_assignments.transfer_only:
|
||||
continue
|
||||
|
||||
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
|
||||
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
|
||||
return LoadModel(instance_id=instance.instance_id)
|
||||
is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
|
||||
|
||||
is_runner_waiting = isinstance(runner.status, RunnerConnected)
|
||||
if is_single_node_instance:
|
||||
# Single-node: require local download complete
|
||||
if not _all_downloads_complete(shard_assignments, global_download_status):
|
||||
continue
|
||||
if isinstance(runner.status, RunnerIdle):
|
||||
return LoadModel(instance_id=instance.instance_id, has_local_model=True)
|
||||
else:
|
||||
# Multi-node: require at least one node to have the model downloaded.
|
||||
# Nodes without the model will receive it via MLX distributed transfer
|
||||
# during model loading.
|
||||
if not _any_download_complete(shard_assignments, global_download_status):
|
||||
continue
|
||||
|
||||
all_ready_for_model = all(
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id, None),
|
||||
(RunnerConnected, RunnerLoading, RunnerLoaded),
|
||||
is_runner_waiting = isinstance(runner.status, RunnerConnected)
|
||||
all_ready_for_model = all(
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id, None),
|
||||
(RunnerConnected, RunnerLoading, RunnerLoaded),
|
||||
)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
)
|
||||
|
||||
if is_runner_waiting and all_ready_for_model:
|
||||
return LoadModel(instance_id=instance.instance_id)
|
||||
if is_runner_waiting and all_ready_for_model:
|
||||
has_local = _node_has_download(
|
||||
runner.bound_instance.bound_node_id,
|
||||
shard_assignments.model_id,
|
||||
global_download_status,
|
||||
)
|
||||
return LoadModel(
|
||||
instance_id=instance.instance_id,
|
||||
has_local_model=has_local,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _node_has_download(
|
||||
nid: NodeId,
|
||||
model_id: ModelId,
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> bool:
|
||||
"""Check if a specific node has completed downloading the given model."""
|
||||
return any(
|
||||
isinstance(dp, DownloadCompleted)
|
||||
and dp.shard_metadata.model_card.model_id == model_id
|
||||
for dp in global_download_status.get(nid, [])
|
||||
)
|
||||
|
||||
|
||||
def _any_peer_has_model(
|
||||
node_id: NodeId,
|
||||
model_id: ModelId,
|
||||
instance: Instance,
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> bool:
|
||||
"""Check if any other node in the instance already has the model downloaded."""
|
||||
return any(
|
||||
_node_has_download(nid, model_id, global_download_status)
|
||||
for nid in instance.shard_assignments.node_to_runner
|
||||
if nid != node_id
|
||||
)
|
||||
|
||||
|
||||
def _all_downloads_complete(
|
||||
shard_assignments: ShardAssignments,
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> bool:
|
||||
"""Check if ALL nodes in the instance have completed downloading the model."""
|
||||
return all(
|
||||
_node_has_download(nid, shard_assignments.model_id, global_download_status)
|
||||
for nid in shard_assignments.node_to_runner
|
||||
)
|
||||
|
||||
|
||||
def _any_download_complete(
|
||||
shard_assignments: ShardAssignments,
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> bool:
|
||||
"""Check if at least one node in the instance has completed downloading the model."""
|
||||
return any(
|
||||
_node_has_download(nid, shard_assignments.model_id, global_download_status)
|
||||
for nid in shard_assignments.node_to_runner
|
||||
)
|
||||
|
||||
|
||||
def _ready_to_warmup(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
@@ -236,6 +368,11 @@ def _ready_to_warmup(
|
||||
for runner in runners.values():
|
||||
instance = runner.bound_instance.instance
|
||||
shard_assignments = instance.shard_assignments
|
||||
|
||||
# Transfer-only instances don't go through warmup
|
||||
if shard_assignments.transfer_only:
|
||||
continue
|
||||
|
||||
shard = runner.bound_instance.bound_shard
|
||||
device_rank = shard.device_rank
|
||||
runner_id = runner.bound_instance.bound_runner_id
|
||||
@@ -310,7 +447,8 @@ def _pending_tasks(
|
||||
def _cancel_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Task | None:
|
||||
) -> CancelTask | None:
|
||||
"""Find a cancelled task that hasn't been sent to the runner yet."""
|
||||
for task in tasks.values():
|
||||
if task.task_status != TaskStatus.Cancelled:
|
||||
continue
|
||||
|
||||
@@ -17,6 +17,7 @@ def entrypoint(
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
_logger: "loguru.Logger",
|
||||
pipe_fifo_paths: tuple[str, str] | None = None,
|
||||
) -> None:
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
if fast_synch_override == "on" or (
|
||||
@@ -30,6 +31,16 @@ def entrypoint(
|
||||
else:
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
||||
|
||||
# Open JACCL FIFOs by path and set env vars for C++ SideChannel.
|
||||
# Named pipes (FIFOs) work across multiprocessing spawn (macOS default).
|
||||
if pipe_fifo_paths is not None:
|
||||
fifo_c2p, fifo_p2c = pipe_fifo_paths
|
||||
# C++ reads gathered data from p2c (PIPE_IN), writes local data to c2p (PIPE_OUT)
|
||||
pipe_in_fd = os.open(fifo_p2c, os.O_RDONLY)
|
||||
pipe_out_fd = os.open(fifo_c2p, os.O_WRONLY)
|
||||
os.environ["MLX_JACCL_PIPE_IN"] = str(pipe_in_fd)
|
||||
os.environ["MLX_JACCL_PIPE_OUT"] = str(pipe_out_fd)
|
||||
|
||||
global logger
|
||||
logger = _logger
|
||||
|
||||
@@ -56,7 +67,9 @@ def entrypoint(
|
||||
try:
|
||||
event_sender.close()
|
||||
task_receiver.close()
|
||||
cancel_receiver.close()
|
||||
finally:
|
||||
event_sender.join()
|
||||
task_receiver.join()
|
||||
cancel_receiver.join()
|
||||
logger.info("bye from the runner")
|
||||
|
||||
@@ -42,6 +42,7 @@ from exo.shared.types.tasks import (
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
TextGeneration,
|
||||
TransferModelToDisk,
|
||||
)
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
@@ -81,6 +82,11 @@ from exo.worker.engines.image import (
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.model_transfer import (
|
||||
coordinate_transfer,
|
||||
model_path_for_id,
|
||||
transfer_all_files,
|
||||
)
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
detect_thinking_prompt_suffix,
|
||||
@@ -201,7 +207,10 @@ def main(
|
||||
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
inference_model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
bound_instance,
|
||||
group,
|
||||
on_timeout=on_model_load_timeout,
|
||||
has_local_model=task.has_local_model,
|
||||
)
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
|
||||
@@ -243,7 +252,7 @@ def main(
|
||||
assert inference_model
|
||||
assert tokenizer
|
||||
|
||||
t = time.monotonic()
|
||||
t = time.perf_counter()
|
||||
toks = warmup_inference(
|
||||
model=inference_model,
|
||||
tokenizer=tokenizer,
|
||||
@@ -251,7 +260,7 @@ def main(
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
check_for_cancel_every = min(
|
||||
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
|
||||
math.ceil(toks / max(time.perf_counter() - t, 0.001)), 100
|
||||
)
|
||||
if group is not None:
|
||||
check_for_cancel_every = int(
|
||||
@@ -535,6 +544,27 @@ def main(
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case TransferModelToDisk() if (
|
||||
isinstance(current_status, RunnerConnected) and group is not None
|
||||
):
|
||||
logger.info("starting disk-to-disk model transfer")
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
model_path = model_path_for_id(
|
||||
task.shard_metadata.model_card.model_id
|
||||
)
|
||||
_, source_rank = coordinate_transfer(group, task.has_local_model)
|
||||
is_source = group.rank() == source_rank
|
||||
transfer_all_files(model_path, group, is_source)
|
||||
|
||||
logger.info("disk-to-disk model transfer complete")
|
||||
current_status = RunnerShuttingDown()
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
current_status = RunnerShutdown()
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import contextlib
|
||||
import os
|
||||
import signal
|
||||
import struct
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from multiprocessing import Process
|
||||
from typing import Self
|
||||
|
||||
@@ -14,12 +18,14 @@ from loguru import logger
|
||||
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
JacclSideChannelData,
|
||||
JacclSideChannelGathered,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnecting,
|
||||
RunnerFailed,
|
||||
@@ -34,6 +40,26 @@ from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
|
||||
from exo.worker.runner.bootstrap import entrypoint
|
||||
|
||||
|
||||
def _pipe_read_exact(fd: int, n: int) -> bytes | None:
|
||||
"""Read exactly n bytes from a file descriptor. Returns None on EOF."""
|
||||
data = b""
|
||||
while len(data) < n:
|
||||
chunk = os.read(fd, n - len(data))
|
||||
if not chunk:
|
||||
return None
|
||||
data += chunk
|
||||
return data
|
||||
|
||||
|
||||
def _pipe_write_all(fd: int, data: bytes) -> None:
|
||||
"""Write all bytes to a file descriptor."""
|
||||
view = memoryview(data)
|
||||
while view:
|
||||
written = os.write(fd, view)
|
||||
view = view[written:]
|
||||
|
||||
|
||||
PREFILL_TIMEOUT_SECONDS = 60
|
||||
DECODE_TIMEOUT_SECONDS = 5
|
||||
|
||||
@@ -46,12 +72,21 @@ class RunnerSupervisor:
|
||||
initialize_timeout: float
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_cancel_sender: MpSender[TaskId]
|
||||
_event_sender: Sender[Event]
|
||||
_pipe_read_fd: int | None = None # Python reads runner's pipe output
|
||||
_pipe_write_fd: int | None = None # Python writes gathered data to runner
|
||||
_child_pipe_fds: tuple[int, int] | None = None # fds to close after fork
|
||||
_fifo_dir: str | None = None # Temp dir for FIFO files (for cleanup)
|
||||
_fifo_c2p: str | None = None # FIFO path: C++ writes → Python reads
|
||||
_fifo_p2c: str | None = None # FIFO path: Python writes → C++ reads
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||
_gathered_waiters: dict[
|
||||
int, tuple[anyio.Event, JacclSideChannelGathered | None]
|
||||
] = field(default_factory=dict, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -65,6 +100,23 @@ class RunnerSupervisor:
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||
|
||||
# For MlxJaccl instances, create named pipes (FIFOs) for SideChannel relay.
|
||||
# Named pipes work across multiprocessing.Process spawn (macOS default).
|
||||
# FIFO c2p: C++ writes local data → Python reads it
|
||||
# FIFO p2c: Python writes gathered data → C++ reads it
|
||||
fifo_dir: str | None = None
|
||||
fifo_c2p: str | None = None
|
||||
fifo_p2c: str | None = None
|
||||
pipe_fifo_paths: tuple[str, str] | None = None
|
||||
|
||||
if isinstance(bound_instance.instance, MlxJacclInstance):
|
||||
fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_")
|
||||
fifo_c2p = os.path.join(fifo_dir, "c2p") # C++ → Python
|
||||
fifo_p2c = os.path.join(fifo_dir, "p2c") # Python → C++
|
||||
os.mkfifo(fifo_c2p)
|
||||
os.mkfifo(fifo_p2c)
|
||||
pipe_fifo_paths = (fifo_c2p, fifo_p2c)
|
||||
|
||||
runner_process = Process(
|
||||
target=entrypoint,
|
||||
args=(
|
||||
@@ -73,6 +125,7 @@ class RunnerSupervisor:
|
||||
task_recv,
|
||||
cancel_recv,
|
||||
logger,
|
||||
pipe_fifo_paths,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
@@ -88,21 +141,54 @@ class RunnerSupervisor:
|
||||
_task_sender=task_sender,
|
||||
_cancel_sender=cancel_sender,
|
||||
_event_sender=event_sender,
|
||||
_fifo_dir=fifo_dir,
|
||||
_fifo_c2p=fifo_c2p,
|
||||
_fifo_p2c=fifo_p2c,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
await self._forward_events()
|
||||
|
||||
if self._fifo_c2p is not None and self._fifo_p2c is not None:
|
||||
# Open FIFOs from parent side. These block until child opens the other end,
|
||||
# so we run them in threads concurrently to avoid deadlock.
|
||||
fifo_c2p = self._fifo_c2p
|
||||
fifo_p2c = self._fifo_p2c
|
||||
|
||||
async def open_read() -> None:
|
||||
self._pipe_read_fd = await to_thread.run_sync(
|
||||
partial(os.open, fifo_c2p, os.O_RDONLY)
|
||||
)
|
||||
|
||||
async def open_write() -> None:
|
||||
self._pipe_write_fd = await to_thread.run_sync(
|
||||
partial(os.open, fifo_p2c, os.O_WRONLY)
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as open_tg:
|
||||
open_tg.start_soon(open_read)
|
||||
open_tg.start_soon(open_write)
|
||||
|
||||
logger.info(
|
||||
f"JACCL pipe relay: FIFOs opened (read_fd={self._pipe_read_fd}, write_fd={self._pipe_write_fd})"
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(self._pipe_relay)
|
||||
tg.start_soon(self._forward_events)
|
||||
else:
|
||||
await self._forward_events()
|
||||
|
||||
def shutdown(self):
|
||||
logger.info("Runner supervisor shutting down")
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
self._event_sender.close()
|
||||
self._close_pipe_fds()
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
@@ -140,6 +226,7 @@ class RunnerSupervisor:
|
||||
await event.wait()
|
||||
|
||||
async def cancel_task(self, task_id: TaskId):
|
||||
"""Send a cancellation signal to the runner process."""
|
||||
if task_id in self.completed:
|
||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||
return
|
||||
@@ -181,6 +268,110 @@ class RunnerSupervisor:
|
||||
for tid in self.pending:
|
||||
self.pending[tid].set()
|
||||
|
||||
def _close_pipe_fds(self) -> None:
|
||||
if self._pipe_read_fd is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(self._pipe_read_fd)
|
||||
self._pipe_read_fd = None
|
||||
if self._pipe_write_fd is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(self._pipe_write_fd)
|
||||
self._pipe_write_fd = None
|
||||
if self._child_pipe_fds is not None:
|
||||
for fd in self._child_pipe_fds:
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(fd)
|
||||
self._child_pipe_fds = None
|
||||
# Clean up FIFO files
|
||||
if self._fifo_c2p is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(self._fifo_c2p)
|
||||
self._fifo_c2p = None
|
||||
if self._fifo_p2c is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(self._fifo_p2c)
|
||||
self._fifo_p2c = None
|
||||
if self._fifo_dir is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.rmdir(self._fifo_dir)
|
||||
self._fifo_dir = None
|
||||
|
||||
async def _pipe_relay(self) -> None:
|
||||
"""Relay JACCL SideChannel all_gather rounds between runner pipes and exo events."""
|
||||
assert self._pipe_read_fd is not None
|
||||
assert self._pipe_write_fd is not None
|
||||
read_fd = self._pipe_read_fd
|
||||
write_fd = self._pipe_write_fd
|
||||
sequence = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
# 1. Read local data from runner: [uint32 size][size bytes]
|
||||
header = await to_thread.run_sync(partial(_pipe_read_exact, read_fd, 4))
|
||||
if header is None:
|
||||
logger.info("JACCL pipe relay: runner closed pipe (EOF)")
|
||||
break
|
||||
data_size: int = struct.unpack("<I", header)[0] # pyright: ignore[reportAny]
|
||||
local_data = await to_thread.run_sync(
|
||||
partial(_pipe_read_exact, read_fd, data_size)
|
||||
)
|
||||
if local_data is None:
|
||||
logger.warning("JACCL pipe relay: EOF reading data payload")
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"JACCL pipe relay: read {data_size} bytes from runner, seq={sequence}"
|
||||
)
|
||||
|
||||
# 2. Emit JacclSideChannelData event
|
||||
waiter = anyio.Event()
|
||||
self._gathered_waiters[sequence] = (waiter, None)
|
||||
await self._event_sender.send(
|
||||
JacclSideChannelData(
|
||||
instance_id=self.bound_instance.instance.instance_id,
|
||||
runner_id=self.bound_instance.bound_runner_id,
|
||||
sequence=sequence,
|
||||
data=local_data,
|
||||
)
|
||||
)
|
||||
|
||||
# 3. Wait for gathered result
|
||||
await waiter.wait()
|
||||
_, gathered_event = self._gathered_waiters.pop(sequence)
|
||||
assert gathered_event is not None
|
||||
|
||||
# 4. Order gathered data by runner rank and concatenate
|
||||
instance = self.bound_instance.instance
|
||||
assert isinstance(instance, MlxJacclInstance)
|
||||
runner_order = list(instance.shard_assignments.runner_to_shard.keys())
|
||||
ordered_data = b"".join(
|
||||
gathered_event.gathered_data[rid] for rid in runner_order
|
||||
)
|
||||
|
||||
# 5. Write gathered data to runner: [uint32 total_size][total_size bytes]
|
||||
total_size = len(ordered_data)
|
||||
response = struct.pack("<I", total_size) + ordered_data
|
||||
await to_thread.run_sync(partial(_pipe_write_all, write_fd, response))
|
||||
|
||||
logger.info(
|
||||
f"JACCL pipe relay: wrote {total_size} bytes to runner, seq={sequence}"
|
||||
)
|
||||
sequence += 1
|
||||
except OSError as e:
|
||||
logger.warning(f"JACCL pipe relay: OS error: {e}")
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).error("JACCL pipe relay: unexpected error")
|
||||
|
||||
def notify_gathered(self, event: JacclSideChannelGathered) -> None:
|
||||
"""Called by the worker when a JacclSideChannelGathered event arrives."""
|
||||
seq = event.sequence
|
||||
if seq not in self._gathered_waiters:
|
||||
logger.warning(f"JACCL: received gathered event for unknown sequence {seq}")
|
||||
return
|
||||
waiter, _ = self._gathered_waiters[seq]
|
||||
self._gathered_waiters[seq] = (waiter, event)
|
||||
waiter.set()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.runner_process.is_alive():
|
||||
logger.warning("RunnerSupervisor was not stopped cleanly.")
|
||||
|
||||
@@ -24,7 +24,6 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
|
||||
# Files needed for tokenizer functionality
|
||||
TOKENIZER_FILE_PATTERNS = [
|
||||
"config.json",
|
||||
"tokenizer.json",
|
||||
"tokenizer_config.json",
|
||||
"special_tokens_map.json",
|
||||
@@ -339,9 +338,6 @@ async def test_kimi_tokenizer_specifically():
|
||||
# Verify EOS token is set
|
||||
assert eos_token_ids == [163586], "Kimi EOS token should be [163586]"
|
||||
|
||||
# Verify architecture-based detection gives same result
|
||||
assert get_eos_token_ids_for_model(model_id, model_type="kimi") == [163586]
|
||||
|
||||
|
||||
# Test GLM tokenizer since it also has special handling
|
||||
@pytest.mark.asyncio
|
||||
@@ -382,10 +378,3 @@ async def test_glm_tokenizer_specifically():
|
||||
151329,
|
||||
151338,
|
||||
], "GLM EOS tokens should be correct"
|
||||
|
||||
# Verify architecture-based detection gives same result
|
||||
assert get_eos_token_ids_for_model(model_id, model_type="glm4") == [
|
||||
151336,
|
||||
151329,
|
||||
151338,
|
||||
]
|
||||
|
||||
@@ -112,6 +112,7 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
|
||||
|
||||
assert isinstance(result, LoadModel)
|
||||
assert result.instance_id == INSTANCE_1_ID
|
||||
assert result.has_local_model is True
|
||||
|
||||
|
||||
def test_plan_does_not_request_download_when_shard_already_downloaded():
|
||||
@@ -157,10 +158,11 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
|
||||
assert not isinstance(result, plan_mod.DownloadModel)
|
||||
|
||||
|
||||
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
def test_plan_loads_model_when_any_node_has_download_for_multi_node():
|
||||
"""
|
||||
LoadModel should not be emitted while some shards are still missing from
|
||||
the global_download_status.
|
||||
For multi-node instances, LoadModel should be emitted when at least one
|
||||
node has the model downloaded. Nodes without the model will receive it
|
||||
via MLX distributed transfer during model loading.
|
||||
"""
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
@@ -185,6 +187,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
RUNNER_2_ID: RunnerConnected(),
|
||||
}
|
||||
|
||||
# Only NODE_A has the model — LoadModel should still fire
|
||||
global_download_status = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
@@ -203,19 +206,42 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert result is None
|
||||
assert isinstance(result, LoadModel)
|
||||
assert result.instance_id == INSTANCE_1_ID
|
||||
assert result.has_local_model is True
|
||||
|
||||
global_download_status = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
|
||||
)
|
||||
], # NODE_B has no downloads completed yet
|
||||
|
||||
def test_plan_does_not_load_model_when_no_node_has_download():
|
||||
"""
|
||||
LoadModel should not be emitted when no node has the model downloaded.
|
||||
"""
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
instance = get_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
|
||||
)
|
||||
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerConnected()
|
||||
)
|
||||
|
||||
runners = {RUNNER_1_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerConnected(),
|
||||
RUNNER_2_ID: RunnerConnected(),
|
||||
}
|
||||
|
||||
# No node has the model
|
||||
global_download_status: dict[NodeId, list[DownloadProgress]] = {
|
||||
NODE_A: [],
|
||||
NODE_B: [],
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
@@ -227,4 +253,57 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_plan_load_model_has_local_model_false_when_node_missing_download():
|
||||
"""
|
||||
For multi-node instances, when the local node does NOT have the model
|
||||
but a peer does, LoadModel should be emitted with has_local_model=False.
|
||||
"""
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
instance = get_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
|
||||
)
|
||||
|
||||
# NODE_B is the local node (bound_node_id=NODE_B), it does NOT have the model
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerConnected()
|
||||
)
|
||||
|
||||
runners = {RUNNER_2_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerConnected(),
|
||||
RUNNER_2_ID: RunnerConnected(),
|
||||
}
|
||||
|
||||
# Only NODE_A has the model, NODE_B does not
|
||||
global_download_status: dict[NodeId, list[DownloadProgress]] = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [],
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert isinstance(result, LoadModel)
|
||||
assert result.instance_id == INSTANCE_1_ID
|
||||
assert result.has_local_model is False
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
# Check tasks are complete before runner is ever ready.
|
||||
import unittest.mock
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
import exo.worker.runner.runner as mlx_runner
|
||||
@@ -117,6 +115,12 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
|
||||
|
||||
# Mock mx.distributed.all_gather so MockGroup doesn't hit real MLX C++ bindings.
|
||||
def _mock_all_gather(x: object, **_kw: object) -> object:
|
||||
return x
|
||||
|
||||
monkeypatch.setattr(mlx_runner.mx.distributed, "all_gather", _mock_all_gather)
|
||||
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
||||
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
|
||||
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
|
||||
@@ -178,16 +182,15 @@ def _run(tasks: Iterable[Task]):
|
||||
# this is some c++ nonsense
|
||||
task_receiver.close = nothin
|
||||
task_receiver.join = nothin
|
||||
with unittest.mock.patch(
|
||||
"exo.worker.runner.runner.mx.distributed.all_gather",
|
||||
make_nothin(mx.array([1])),
|
||||
):
|
||||
mlx_runner.main(
|
||||
bound_instance,
|
||||
event_sender, # pyright: ignore[reportArgumentType]
|
||||
task_receiver,
|
||||
cancel_receiver,
|
||||
)
|
||||
cancel_receiver.close = nothin
|
||||
cancel_receiver.join = nothin
|
||||
|
||||
mlx_runner.main(
|
||||
bound_instance,
|
||||
event_sender, # pyright: ignore[reportArgumentType]
|
||||
task_receiver,
|
||||
cancel_receiver,
|
||||
)
|
||||
|
||||
return event_sender.events
|
||||
|
||||
|
||||
40
uv.lock
generated
40
uv.lock
generated
@@ -377,8 +377,8 @@ dependencies = [
|
||||
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -416,7 +416,7 @@ requires-dist = [
|
||||
{ name = "hypercorn", specifier = ">=0.18.0" },
|
||||
{ name = "loguru", specifier = ">=0.7.3" },
|
||||
{ name = "mflux", specifier = "==0.15.5" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks" },
|
||||
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
|
||||
{ name = "mlx-lm", specifier = "==0.30.6" },
|
||||
{ name = "msgspec", specifier = ">=0.19.0" },
|
||||
@@ -1020,8 +1020,8 @@ dependencies = [
|
||||
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1048,18 +1048,12 @@ wheels = [
|
||||
name = "mlx"
|
||||
version = "0.30.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
|
||||
resolution-markers = [
|
||||
"sys_platform == 'linux'",
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ae/5b/e460e144a34d5529e010056cccf50b538d56ed001473bc6b246018fd58cb/mlx-0.30.6-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:ed86f8bffc174c2f259ca589ea25464c96cf69d1bb457074a2bf2ef53737e54f", size = 573515, upload-time = "2026-02-06T03:45:23.405Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/60/25/69833fefb9a3fef30b56792b1bcd022496c4fea83e45411d289b77ef7546/mlx-0.30.6-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:c52294958269e20f300639a17c1900ca8fc737d859ddda737f9811e94bd040e5", size = 573516, upload-time = "2026-02-06T03:45:24.618Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/6a/7e7fbeebc5cb51b6a5eba96b263a6298707bcbdc059f4b0b73e088bc3dea/mlx-0.30.6-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:b5b6636f7c49a4d86d8ec82643b972f45a144a7a9f3a967b27b2e6e22cf71e6a", size = 573592, upload-time = "2026-02-06T03:45:25.928Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/93/06/280f6f2ba80520a7109730425eda0d966658793aa0d02d8be8d351f75253/mlx-0.30.6-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:67e6c9e30a9faeacc209917ef5523177cf9b086914b6b5d83ff886e4294b727d", size = 622011, upload-time = "2026-02-06T03:45:28.165Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fe/35/f872afbee9c079cc69924d9e9c46f5663adb7da58cba3511db082dd307c1/mlx-0.30.6-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:47db8b16fcb6f6c5a47c0bdb24ed377b41237017ac93aa6cb6aa206c9bdf82e4", size = 663650, upload-time = "2026-02-06T03:45:30.315Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/60/23/361dc7a5797634e4d7e9bdd6564c6b28f9b1246672632def2f91bf066b18/mlx-0.30.6-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:78804a89dcff4a838f7c2da72392fe87a523e95122a3c840e53df019122aad45", size = 575028, upload-time = "2026-02-06T03:45:31.549Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/69/1854484d414171586814dfbe8def95f75c4ea2c7341ba13ba8ee675f7c62/mlx-0.30.6-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:ec13584ab069665cc7ad34a05494d9291cd623aef6ae96be48875fc87cfc25d6", size = 575026, upload-time = "2026-02-06T03:45:33.072Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6b/b8/3adbc441924209a7e4c568308b2a0b54bd09aee6a68db5bae85304791e54/mlx-0.30.6-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:b2c5e8a090a753ef99a1380a4d059c983083f36198864f6df9faaf1223d083df", size = 575041, upload-time = "2026-02-06T03:45:34.814Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/54/9d9e06804fb2088202a2cdf60458e00b221f71420bea285720b60f9e82b5/mlx-0.30.6-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:9ceddede4af0de31d1f6b3099f70e5469d60cd7c546975dedbdbeab3519cab3f", size = 624002, upload-time = "2026-02-06T03:45:36Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/92/3140a15a50cb1f9267a6552171e1dfa577861de53e093124bc43707f2a0e/mlx-0.30.6-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:4a6ffd2d16728cf95f63a1b555d7c2eaeea686a0e6b73228bd265411cb5d77a4", size = 663569, upload-time = "2026-02-06T03:45:37.242Z" },
|
||||
]
|
||||
@@ -1072,6 +1066,14 @@ cuda13 = [
|
||||
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx"
|
||||
version = "0.30.7.dev20260217+50487b41"
|
||||
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }
|
||||
resolution-markers = [
|
||||
"sys_platform == 'darwin'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-cpu"
|
||||
version = "0.30.6"
|
||||
@@ -1102,7 +1104,7 @@ version = "0.30.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1114,16 +1116,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-metal"
|
||||
version = "0.30.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/85/44406b521f920248fad621334d4dc15e77660a494edf890e7cbee33bf38d/mlx_metal-0.30.6-py3-none-macosx_14_0_arm64.whl", hash = "sha256:ea6d0c973def9a5b4f652cc77036237db3f88c9d0af63701d76b5fddde99b820", size = 38437818, upload-time = "2026-02-06T03:44:56.19Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/cb/10a516995f7d0c154b0d7e633c54b51e96977a86a355105b6474cfcbe0d0/mlx_metal-0.30.6-py3-none-macosx_15_0_arm64.whl", hash = "sha256:0f8cb94634d07e06a372d6ad9a090f38a18bab1ff19a140aede60eacf707bb94", size = 38433701, upload-time = "2026-02-06T03:44:59.678Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4c/7d/70cb272f7373c334709f210ed8420511fc9d64d05a7a646c0b3b94c29c04/mlx_metal-0.30.6-py3-none-macosx_26_0_arm64.whl", hash = "sha256:d761ae26304f2c4b454eeea7f612a56919d9e5e57dbb1dc0788f8e34aa6f41c2", size = 47718448, upload-time = "2026-02-06T03:45:03.133Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "more-itertools"
|
||||
version = "10.8.0"
|
||||
|
||||
Reference in New Issue
Block a user