mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-19 07:17:30 -05:00
Compare commits
18 Commits
alexcheema
...
add-glm5-s
| Author | SHA1 | Date |
|---|---|---|
| | 21b594b176 | |
| | c8997217cf | |
| | 490d2e46ba | |
| | facf2d4d03 | |
| | a962a28afc | |
| | db79c350c1 | |
| | d6301ed593 | |
| | 6d1ca6689b | |
| | c01b6fff21 | |
| | 8392e78afe | |
| | 86735ece78 | |
| | 2759e92334 | |
| | 131fb141a6 | |
| | 2d8bfc2e3c | |
| | 042999f728 | |
| | b61dc2eb35 | |
| | 36a7115b6f | |
| | 0b7d88b43b | |
27
.github/workflows/pipeline.yml
vendored
@@ -8,33 +8,6 @@ on:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
typecheck:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: false
|
||||
|
||||
- uses: cachix/install-nix-action@v31
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
|
||||
- uses: cachix/cachix-action@v14
|
||||
name: Configure Cachix
|
||||
with:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Load nix develop environment
|
||||
run: nix run github:nicknovitski/nix-develop/v1
|
||||
|
||||
- name: Sync dependencies
|
||||
run: uv sync --all-packages
|
||||
|
||||
- name: Run type checker
|
||||
run: uv run basedpyright --project pyproject.toml
|
||||
|
||||
nix:
|
||||
name: Build and check (${{ matrix.system }})
|
||||
runs-on: ${{ matrix.runner }}
|
||||
|
||||
@@ -5,21 +5,21 @@
|
||||
[X] Fetching download status of all models on start
|
||||
[X] Deduplication of tasks in plan_step.
|
||||
[X] resolve_allow_patterns should just be wildcard now.
|
||||
[] no mx_barrier in generate.py mlx_generate at the end.
|
||||
[X] no mx_barrier in generate.py mlx_generate at the end.
|
||||
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
|
||||
[] GPTOSS support dropped in auto_parallel.py.
|
||||
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[X] GPTOSS support dropped in auto_parallel.py.
|
||||
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
|
||||
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
|
||||
[] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[X] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[X] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
|
||||
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[X] We put cache limit back in utils_mlx.py.
|
||||
[] topology.py remove_node removes the connections after checking if node is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
|
||||
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
|
||||
[X] topology.py remove_node removes the connections after checking if node is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
|
||||
[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
|
||||
[] try-except in _command_processor only excepts ValueError. This was silently failing, leading to un-debuggable errors (we had a KeyError that was happening). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
|
||||
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
|
||||
[X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
|
||||
|
||||
11
README.md
@@ -72,16 +72,23 @@ There are two ways to run exo:
|
||||
|
||||
### Run from Source (macOS)
|
||||
|
||||
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
|
||||
|
||||
```bash
|
||||
nix run .#exo
|
||||
```
|
||||
|
||||
**Prerequisites:**
|
||||
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
|
||||
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
|
||||
|
||||
|
||||
```bash
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
```
|
||||
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
|
||||
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
|
||||
- [node](https://github.com/nodejs/node) (for building the dashboard)
|
||||
|
||||
|
||||
```bash
|
||||
brew install uv macmon node
|
||||
```
|
||||
|
||||
@@ -126,11 +126,37 @@ final class ExoProcessController: ObservableObject {
|
||||
return
|
||||
}
|
||||
process.terminationHandler = nil
|
||||
if process.isRunning {
|
||||
process.terminate()
|
||||
}
|
||||
self.process = nil
|
||||
status = .stopped
|
||||
|
||||
guard process.isRunning else {
|
||||
self.process = nil
|
||||
return
|
||||
}
|
||||
|
||||
let proc = process
|
||||
self.process = nil
|
||||
|
||||
Task.detached {
|
||||
proc.interrupt()
|
||||
|
||||
for _ in 0..<50 {
|
||||
if !proc.isRunning { return }
|
||||
try? await Task.sleep(nanoseconds: 100_000_000)
|
||||
}
|
||||
|
||||
if proc.isRunning {
|
||||
proc.terminate()
|
||||
}
|
||||
|
||||
for _ in 0..<30 {
|
||||
if !proc.isRunning { return }
|
||||
try? await Task.sleep(nanoseconds: 100_000_000)
|
||||
}
|
||||
|
||||
if proc.isRunning {
|
||||
kill(proc.processIdentifier, SIGKILL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func restart() {
|
||||
|
||||
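The stop path in the hunk above escalates in three stages: interrupt, then terminate after 50 polls of 100 ms, then SIGKILL after a further 30 polls. The same pattern can be sketched in Python for readers who want to try it outside the app. This is a minimal illustration on POSIX, not the app's code; `wait_for_exit` and `stop_gracefully` are hypothetical names, and only the poll counts and interval mirror the Swift hunk.

```python
import signal
import subprocess
import time


def wait_for_exit(proc: subprocess.Popen, attempts: int, interval_s: float) -> bool:
    """Poll the process up to `attempts` times; return True once it has exited."""
    for _ in range(attempts):
        if proc.poll() is not None:
            return True
        time.sleep(interval_s)
    return proc.poll() is not None


def stop_gracefully(
    proc: subprocess.Popen,
    interrupt_polls: int = 50,   # 50 x 100 ms, as in the Swift hunk
    terminate_polls: int = 30,   # 30 x 100 ms, as in the Swift hunk
    interval_s: float = 0.1,
) -> str:
    """Escalate SIGINT -> SIGTERM -> SIGKILL; return the stage that ended the process."""
    if proc.poll() is not None:
        return "already-exited"

    proc.send_signal(signal.SIGINT)      # polite request first
    if wait_for_exit(proc, interrupt_polls, interval_s):
        return "interrupt"

    proc.terminate()                     # SIGTERM on POSIX
    if wait_for_exit(proc, terminate_polls, interval_s):
        return "terminate"

    proc.kill()                          # SIGKILL cannot be caught or ignored
    proc.wait()
    return "kill"
```

Polling in short intervals, rather than sleeping for the whole grace period, lets the caller return as soon as the child exits.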
7
bench/bench.toml
Normal file
@@ -0,0 +1,7 @@
# Canary benchmark manifest
#
# Lists the suite files to include. Each file defines benchmarks
# with shared constraints, topology, and default args.
include = [
  "single-m3-ultra.toml",
]
@@ -288,6 +288,151 @@ def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def run_planning_phase(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
preview: dict[str, Any],
|
||||
danger_delete: bool,
|
||||
timeout: float,
|
||||
settle_deadline: float | None,
|
||||
) -> None:
|
||||
"""Check disk space and ensure model is downloaded before benchmarking."""
|
||||
# Get model size from /models
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
model_bytes = 0
|
||||
for m in models.get("data", []):
|
||||
if m.get("hugging_face_id") == full_model_id:
|
||||
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
|
||||
break
|
||||
|
||||
if not model_bytes:
|
||||
logger.warning(
|
||||
f"Could not determine size for {full_model_id}, skipping disk check"
|
||||
)
|
||||
return
|
||||
|
||||
# Get nodes from preview
|
||||
inner = unwrap_instance(preview["instance"])
|
||||
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
|
||||
for node_id in node_ids:
|
||||
node_downloads = downloads.get(node_id, [])
|
||||
|
||||
# Check if model already downloaded on this node
|
||||
already_downloaded = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
for p in node_downloads
|
||||
)
|
||||
if already_downloaded:
|
||||
continue
|
||||
|
||||
# Wait for disk info if settle_deadline is set
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.info(
|
||||
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
state = client.request_json("GET", "/state")
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
|
||||
if not disk_info:
|
||||
logger.warning(f"No disk info for {node_id}, skipping space check")
|
||||
continue
|
||||
|
||||
avail = disk_info.get("available", {}).get("inBytes", 0)
|
||||
if avail >= model_bytes:
|
||||
continue
|
||||
|
||||
if not danger_delete:
|
||||
raise RuntimeError(
|
||||
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
|
||||
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
|
||||
)
|
||||
|
||||
# Delete from smallest to largest
|
||||
completed = [
|
||||
(
|
||||
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
],
|
||||
p["DownloadCompleted"]["totalBytes"]["inBytes"],
|
||||
)
|
||||
for p in node_downloads
|
||||
if "DownloadCompleted" in p
|
||||
]
|
||||
for del_model, size in sorted(completed, key=lambda x: x[1]):
|
||||
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
|
||||
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
|
||||
avail += size
|
||||
if avail >= model_bytes:
|
||||
break
|
||||
|
||||
if avail < model_bytes:
|
||||
raise RuntimeError(f"Could not free enough space on {node_id}")
|
||||
|
||||
# Start downloads (idempotent)
|
||||
for node_id in node_ids:
|
||||
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
|
||||
shard = runner_to_shard[runner_id]
|
||||
client.request_json(
|
||||
"POST",
|
||||
"/download/start",
|
||||
body={
|
||||
"targetNodeId": node_id,
|
||||
"shardMetadata": shard,
|
||||
},
|
||||
)
|
||||
logger.info(f"Started download on {node_id}")
|
||||
|
||||
# Wait for downloads
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
all_done = True
|
||||
for node_id in node_ids:
|
||||
done = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
|
||||
"modelCard"
|
||||
]["modelId"]
|
||||
== full_model_id
|
||||
for p in downloads.get(node_id, [])
|
||||
)
|
||||
failed = [
|
||||
p["DownloadFailed"]["errorMessage"]
|
||||
for p in downloads.get(node_id, [])
|
||||
if "DownloadFailed" in p
|
||||
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
]
|
||||
if failed:
|
||||
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
|
||||
if not done:
|
||||
all_done = False
|
||||
if all_done:
|
||||
return
|
||||
time.sleep(1)
|
||||
|
||||
raise TimeoutError("Downloads did not complete in time")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
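The space-freeing step in `run_planning_phase` above sorts completed downloads by size and deletes from the smallest upward until the benchmark model fits. Isolated from the HTTP calls, the selection logic might look like the sketch below. `plan_deletions` is a hypothetical helper, not part of the diff; it only decides what to delete and leaves the `DELETE` requests to the caller.

```python
def plan_deletions(
    completed: list[tuple[str, int]],  # (model_id, size_in_bytes) per completed download
    available: int,                    # free bytes on the node right now
    needed: int,                       # bytes required by the benchmark model
) -> list[str]:
    """Return the model ids to delete, smallest first, until `needed` bytes fit."""
    to_delete: list[str] = []
    for model_id, size in sorted(completed, key=lambda item: item[1]):
        if available >= needed:
            break
        to_delete.append(model_id)
        available += size  # assume the full size is reclaimed on deletion
    if available < needed:
        raise RuntimeError("could not free enough space")
    return to_delete
```

For example, `plan_deletions([("a", 30), ("b", 10), ("c", 20)], available=5, needed=30)` selects `"b"` and then `"c"`. Separating the plan from the deletion also makes it possible to log or confirm the list before anything is removed.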
@@ -535,6 +680,11 @@ def main() -> int:
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--danger-delete-downloads",
|
||||
action="store_true",
|
||||
help="Delete existing models from smallest to largest to make room for benchmark model.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
@@ -569,13 +719,16 @@ def main() -> int:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
settle_deadline = (
|
||||
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||
)
|
||||
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected and args.settle_timeout > 0:
|
||||
if not selected and settle_deadline:
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
deadline = time.monotonic() + args.settle_timeout
|
||||
while not selected and time.monotonic() < deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
while not selected and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.warning(
|
||||
f"No valid placements yet (cluster may still be settling). "
|
||||
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
|
||||
@@ -607,6 +760,16 @@ def main() -> int:
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
logger.info("Planning phase: checking downloads...")
|
||||
run_planning_phase(
|
||||
client,
|
||||
full_model_id,
|
||||
selected[0],
|
||||
args.danger_delete_downloads,
|
||||
args.timeout,
|
||||
settle_deadline,
|
||||
)
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
|
||||
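The hunks above compute `settle_deadline` once and reuse it for both the placement retry loop and the disk-info wait, so the two waits share one time budget. A minimal sketch of that pattern follows. `poll_until` is a hypothetical helper, and the backoff defaults are placeholders, because the values of `_SETTLE_INITIAL_BACKOFF_S`, `_SETTLE_BACKOFF_MULTIPLIER`, and `_SETTLE_MAX_BACKOFF_S` are not shown in this diff.

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def poll_until(
    check: Callable[[], Optional[T]],
    deadline: Optional[float],            # absolute time.monotonic() value, or None to try once
    initial_backoff_s: float = 1.0,       # placeholder value
    multiplier: float = 2.0,              # placeholder value
    max_backoff_s: float = 10.0,          # placeholder value
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> Optional[T]:
    """Call `check` until it returns something truthy or the deadline passes."""
    backoff = initial_backoff_s
    result = check()
    while not result and deadline is not None and clock() < deadline:
        remaining = deadline - clock()
        sleep(max(0.0, min(backoff, remaining)))  # never sleep past the deadline
        backoff = min(backoff * multiplier, max_backoff_s)
        result = check()
    return result
```

Passing the same `deadline` to two consecutive `poll_until` calls reproduces the shared budget: whatever the first wait consumes is no longer available to the second.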
189
bench/single-m3-ultra.toml
Normal file
@@ -0,0 +1,189 @@
|
||||
# Single-node M3 Ultra benchmarks
|
||||
#
|
||||
# Shared constraints applied to ALL benchmarks in this file.
|
||||
constraints = [
|
||||
"All(MacOsBuild(=25D125))",
|
||||
"Hosts(=1)",
|
||||
"All(Chip(m3_ultra))",
|
||||
"All(GpuCores(=80))",
|
||||
]
|
||||
|
||||
[topology]
|
||||
type = "none"
|
||||
|
||||
# Default args merged into each benchmark's args (benchmark-level args win).
|
||||
[defaults]
|
||||
pp = [512, 2048, 8192, 16384]
|
||||
tg = 128
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-Flash-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-6bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-30B-A3B-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-0.6B-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-0.6B-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-30B-A3B-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-Flash-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-Flash-5bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-Flash-6bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-5bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
|
||||
extra_constraints = ["All(Memory(>=96GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/llama-3.3-70b-instruct-fp16"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.5-Air-8bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.5-Air-bf16"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-4bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/MiniMax-M2.1-3bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/MiniMax-M2.1-8bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-Next-bf16"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Step-3.5-Flash-4bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Step-3.5-Flash-6bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Step-3.5-Flash-8Bit"
|
||||
extra_constraints = ["All(Memory(>=256GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/DeepSeek-V3.1-4bit"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-6bit"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/GLM-4.7-8bit-gs32"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
|
||||
[[benchmark]]
|
||||
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
|
||||
extra_constraints = ["All(Memory(>=512GiB))"]
|
||||
@@ -265,6 +265,7 @@
|
||||
|
||||
function handleSubmit() {
|
||||
if ((!message.trim() && uploadedFiles.length === 0) || loading) return;
|
||||
if (isEditOnlyWithoutImage) return;
|
||||
|
||||
const content = message.trim();
|
||||
const files = [...uploadedFiles];
|
||||
@@ -289,7 +290,11 @@
|
||||
if (imageFile.preview) {
|
||||
editImage(content, imageFile.preview);
|
||||
}
|
||||
} else if (isImageModel() && content) {
|
||||
} else if (
|
||||
currentModel &&
|
||||
modelSupportsTextToImage(currentModel) &&
|
||||
content
|
||||
) {
|
||||
// Use image generation for text-to-image models
|
||||
generateImage(content);
|
||||
} else {
|
||||
|
||||
@@ -225,6 +225,7 @@
|
||||
}
|
||||
|
||||
function handleDeleteClick(messageId: string) {
|
||||
if (loading) return;
|
||||
deleteConfirmId = messageId;
|
||||
}
|
||||
|
||||
@@ -255,7 +256,7 @@
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col gap-4 sm:gap-6 {className}">
|
||||
{#each messageList as message (message.id)}
|
||||
{#each messageList as message, i (message.id)}
|
||||
<div
|
||||
class="group flex {message.role === 'user'
|
||||
? 'justify-end'
|
||||
@@ -317,9 +318,11 @@
|
||||
<!-- Delete confirmation -->
|
||||
<div class="bg-red-500/10 border border-red-500/30 rounded-lg p-3">
|
||||
<p class="text-xs text-red-400 mb-3">
|
||||
Delete this message{message.role === "user"
|
||||
? " and all responses after it"
|
||||
: ""}?
|
||||
{#if i === messageList.length - 1}
|
||||
Delete this message?
|
||||
{:else}
|
||||
Delete this message and all messages after it?
|
||||
{/if}
|
||||
</p>
|
||||
<div class="flex gap-2 justify-end">
|
||||
<button
|
||||
@@ -751,8 +754,13 @@
|
||||
<!-- Delete button -->
|
||||
<button
|
||||
onclick={() => handleDeleteClick(message.id)}
|
||||
class="p-1.5 text-exo-light-gray hover:text-red-400 transition-colors rounded hover:bg-red-500/10 cursor-pointer"
|
||||
title="Delete message"
|
||||
disabled={loading}
|
||||
class="p-1.5 transition-colors rounded {loading
|
||||
? 'text-exo-light-gray/30 cursor-not-allowed'
|
||||
: 'text-exo-light-gray hover:text-red-400 hover:bg-red-500/10 cursor-pointer'}"
|
||||
title={loading
|
||||
? "Cannot delete while generating"
|
||||
: "Delete message"}
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
|
||||
@@ -185,11 +185,7 @@
|
||||
|
||||
let instanceType: string | null = null;
|
||||
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
|
||||
else if (
|
||||
instanceTag === "MlxIbvInstance" ||
|
||||
instanceTag === "MlxJacclInstance"
|
||||
)
|
||||
instanceType = "MLX RDMA";
|
||||
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
|
||||
|
||||
let sharding: string | null = null;
|
||||
const inst = instance as {
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
} | null;
|
||||
nodes?: Record<string, NodeInfo>;
|
||||
sharding?: "Pipeline" | "Tensor";
|
||||
runtime?: "MlxRing" | "MlxIbv" | "MlxJaccl";
|
||||
runtime?: "MlxRing" | "MlxJaccl";
|
||||
onLaunch?: () => void;
|
||||
tags?: string[];
|
||||
apiPreview?: PlacementPreview | null;
|
||||
@@ -348,7 +348,7 @@
|
||||
// Debug mode state
|
||||
const isDebugMode = $derived(debugMode());
|
||||
const topology = $derived(topologyData());
|
||||
const isRdma = $derived(runtime === "MlxIbv" || runtime === "MlxJaccl");
|
||||
const isRdma = $derived(runtime === "MlxJaccl");
|
||||
|
||||
// Get interface name for an IP from node data
|
||||
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
|
||||
@@ -575,7 +575,7 @@
|
||||
>
|
||||
{runtime === "MlxRing"
|
||||
? "MLX Ring"
|
||||
: runtime === "MlxIbv" || runtime === "MlxJaccl"
|
||||
: runtime === "MlxJaccl"
|
||||
? "MLX RDMA"
|
||||
: runtime}
|
||||
</span>
|
||||
|
||||
@@ -168,7 +168,7 @@ export interface ModelDownloadStatus {
|
||||
export interface PlacementPreview {
|
||||
model_id: string;
|
||||
sharding: "Pipeline" | "Tensor";
|
||||
instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
|
||||
instance_meta: "MlxRing" | "MlxJaccl";
|
||||
instance: unknown | null;
|
||||
memory_delta_by_node: Record<string, number> | null;
|
||||
error: string | null;
|
||||
@@ -219,7 +219,6 @@ interface RawStateResponse {
|
||||
string,
|
||||
{
|
||||
MlxRingInstance?: Instance;
|
||||
MlxIbvInstance?: Instance;
|
||||
MlxJacclInstance?: Instance;
|
||||
}
|
||||
>;
|
||||
@@ -250,6 +249,20 @@ interface RawStateResponse {
|
||||
>;
|
||||
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
|
||||
thunderboltBridgeCycles?: string[][];
|
||||
// MetaInstances (declarative instance constraints)
|
||||
metaInstances?: Record<string, MetaInstanceData>;
|
||||
}
|
||||
|
||||
export interface MetaInstanceData {
|
||||
metaInstanceId: string;
|
||||
modelId: string;
|
||||
sharding: string;
|
||||
instanceMeta: string;
|
||||
minNodes: number;
|
||||
nodeIds: string[] | null;
|
||||
placementError: string | null;
|
||||
consecutiveFailures: number;
|
||||
lastFailureError: string | null;
|
||||
}
|
||||
|
||||
export interface MessageAttachment {
|
||||
@@ -537,6 +550,7 @@ class AppStore {
|
||||
previewNodeFilter = $state<Set<string>>(new Set());
|
||||
lastUpdate = $state<number | null>(null);
|
||||
nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
|
||||
metaInstances = $state<Record<string, MetaInstanceData>>({});
|
||||
thunderboltBridgeCycles = $state<string[][]>([]);
|
||||
nodeThunderbolt = $state<
|
||||
Record<
|
||||
@@ -895,11 +909,7 @@ class AppStore {
|
||||
|
||||
let instanceType: string | null = null;
|
||||
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
|
||||
else if (
|
||||
instanceTag === "MlxIbvInstance" ||
|
||||
instanceTag === "MlxJacclInstance"
|
||||
)
|
||||
instanceType = "MLX RDMA";
|
||||
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
|
||||
|
||||
let sharding: string | null = null;
|
||||
const inst = instance as {
|
||||
@@ -1273,6 +1283,8 @@ class AppStore {
|
||||
this.nodeThunderbolt = data.nodeThunderbolt ?? {};
|
||||
// RDMA ctl status per node
|
||||
this.nodeRdmaCtl = data.nodeRdmaCtl ?? {};
|
||||
// MetaInstances
|
||||
this.metaInstances = data.metaInstances ?? {};
|
||||
// Thunderbolt bridge cycles
|
||||
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
|
||||
// Thunderbolt bridge status per node
|
||||
@@ -3044,6 +3056,7 @@ export const tps = () => appStore.tps;
|
||||
export const totalTokens = () => appStore.totalTokens;
|
||||
export const topologyData = () => appStore.topologyData;
|
||||
export const instances = () => appStore.instances;
|
||||
export const metaInstances = () => appStore.metaInstances;
|
||||
export const runners = () => appStore.runners;
|
||||
export const downloads = () => appStore.downloads;
|
||||
export const nodeDisk = () => appStore.nodeDisk;
|
||||
|
||||
File diff suppressed because it is too large
@@ -115,7 +115,7 @@
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
|
||||
let
|
||||
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
|
||||
uvLockMlxVersion = mlxPackage.version;
|
||||
in
|
||||
{
|
||||
|
||||
10
nix/mlx.nix
@@ -41,16 +41,16 @@ let
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.6"; in
|
||||
version = let v = "0.30.7.dev20260217+50487b41"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "ml-explore";
|
||||
repo = "mlx";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
|
||||
owner = "rltakashige";
|
||||
repo = "mlx-jaccl-fix-small-recv";
|
||||
rev = "50487b4141f3c951122655db3b83df5146c1fbeb";
|
||||
hash = "sha256-IL4a9vMX5nocgJU1WG4zE8hArHkHJtnh4sdYh3od5zU=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
|
||||
@@ -17,7 +17,7 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.6; sys_platform == 'darwin'",
|
||||
"mlx; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.6",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
@@ -64,6 +64,7 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
|
||||
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
@@ -132,7 +133,7 @@ markers = [
|
||||
env = [
|
||||
"EXO_TESTS=1"
|
||||
]
|
||||
addopts = "-m 'not slow'"
|
||||
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
|
||||
filterwarnings = [
|
||||
"ignore:builtin type Swig:DeprecationWarning",
|
||||
]
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
|
||||
# Override overlay to inject Nix-built components
|
||||
exoOverlay = final: prev: {
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel.
|
||||
# Preserve passthru so mkVirtualEnv can resolve dependency groups.
|
||||
# Copy .pyi stub + py.typed marker so basedpyright can find the types.
|
||||
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
|
||||
pname = "exo-pyo3-bindings";
|
||||
version = "0.1.0";
|
||||
@@ -22,6 +24,12 @@
|
||||
# Install from pre-built wheel
|
||||
nativeBuildInputs = [ final.pyprojectWheelHook ];
|
||||
dontStrip = true;
|
||||
passthru = prev.exo-pyo3-bindings.passthru or { };
|
||||
postInstall = ''
|
||||
local siteDir=$out/${final.python.sitePackages}/exo_pyo3_bindings
|
||||
cp ${inputs.self}/rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi $siteDir/
|
||||
touch $siteDir/py.typed
|
||||
'';
|
||||
};
|
||||
};
|
||||
|
||||
@@ -29,17 +37,47 @@
|
||||
|
||||
# Overlay to provide build systems and custom packages
|
||||
buildSystemsOverlay = final: prev: {
|
||||
# Use our pure Nix-built MLX with Metal support
|
||||
mlx = self'.packages.mlx;
|
||||
|
||||
# mlx-lm is a git dependency that needs setuptools
|
||||
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
});
|
||||
} // lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
# Use our pure Nix-built MLX with Metal support (macOS only)
|
||||
mlx = self'.packages.mlx;
|
||||
};
|
||||
|
||||
# Additional overlay for Linux-specific fixes (type checking env).
|
||||
# Native wheels have shared lib dependencies we don't need at type-check time.
|
||||
linuxOverlay = final: prev:
|
||||
let
|
||||
ignoreMissing = drv: drv.overrideAttrs { autoPatchelfIgnoreMissingDeps = [ "*" ]; };
|
||||
nvidiaPackages = lib.filterAttrs (name: _: lib.hasPrefix "nvidia-" name) prev;
|
||||
in
|
||||
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
|
||||
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
|
||||
mlx = ignoreMissing prev.mlx;
|
||||
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
|
||||
buildInputs = (old.buildInputs or [ ]) ++ [
|
||||
final.nvidia-cublas
|
||||
final.nvidia-cuda-nvrtc
|
||||
final.nvidia-cudnn-cu13
|
||||
final.nvidia-nccl-cu13
|
||||
];
|
||||
preFixup = ''
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cublas}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
|
||||
'';
|
||||
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
|
||||
});
|
||||
torch = ignoreMissing prev.torch;
|
||||
triton = ignoreMissing prev.triton;
|
||||
}
|
||||
);
|
||||
|
||||
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
|
||||
inherit python;
|
||||
}).overrideScope (
|
||||
@@ -48,16 +86,28 @@
|
||||
overlay
|
||||
exoOverlay
|
||||
buildSystemsOverlay
|
||||
linuxOverlay
|
||||
]
|
||||
);
|
||||
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
|
||||
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
|
||||
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
|
||||
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
|
||||
"lib/python3.13/site-packages/mlx*"
|
||||
"lib/python3.13/site-packages/nvidia*"
|
||||
];
|
||||
|
||||
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
|
||||
venvIgnoreCollisions = venvCollisionPaths;
|
||||
};
|
||||
|
||||
# Virtual environment with dev dependencies for testing
|
||||
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
workspace.deps.default // {
|
||||
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
|
||||
}
|
||||
);
|
||||
)).overrideAttrs {
|
||||
venvIgnoreCollisions = venvCollisionPaths;
|
||||
};
|
||||
|
||||
mkPythonScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
@@ -118,6 +168,21 @@
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}
|
||||
touch $out
|
||||
'';
|
||||
|
||||
# Hermetic basedpyright type checking
|
||||
typecheck = pkgs.runCommand "typecheck"
|
||||
{
|
||||
nativeBuildInputs = [
|
||||
testVenv
|
||||
pkgs.basedpyright
|
||||
];
|
||||
}
|
||||
''
|
||||
cd ${inputs.self}
|
||||
export HOME=$TMPDIR
|
||||
basedpyright --pythonpath ${testVenv}/bin/python
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/GLM-5-4bit"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "4bit"
|
||||
base_model = "GLM 5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 418621403136
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/GLM-5-8bit-MXFP8"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 767273926656
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/GLM-5-MXFP4-Q8"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "MXFP4-Q8"
|
||||
base_model = "GLM 5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405480321024
|
||||
12
resources/inference_model_cards/mlx-community--GLM-5.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/GLM-5"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "bf16"
|
||||
base_model = "GLM 5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 1487822475264
|
||||
@@ -14,6 +14,7 @@ from exo.download.download_utils import (
|
||||
map_repo_download_progress_to_download_progress_data,
|
||||
)
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
@@ -63,6 +64,9 @@ class DownloadCoordinator:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
self.shard_downloader.on_progress(self._download_progress_callback)
|
||||
|
||||
def _model_dir(self, model_id: ModelId) -> str:
|
||||
return str(EXO_MODELS_DIR / model_id.normalize())
|
||||
|
||||
async def _download_progress_callback(
|
||||
self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
@@ -74,6 +78,7 @@ class DownloadCoordinator:
|
||||
shard_metadata=callback_shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = completed
|
||||
await self.event_sender.send(
|
||||
@@ -93,6 +98,7 @@ class DownloadCoordinator:
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = ongoing
|
||||
await self.event_sender.send(
|
||||
@@ -170,7 +176,11 @@ class DownloadCoordinator:
|
||||
return
|
||||
|
||||
# Emit pending status
|
||||
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
|
||||
progress = DownloadPending(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = progress
|
||||
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
|
||||
|
||||
@@ -184,6 +194,7 @@ class DownloadCoordinator:
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = completed
|
||||
await self.event_sender.send(
|
||||
@@ -206,6 +217,7 @@ class DownloadCoordinator:
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
initial_progress
|
||||
),
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
@@ -219,6 +231,7 @@ class DownloadCoordinator:
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
error_message=str(e),
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = failed
|
||||
await self.event_sender.send(
|
||||
@@ -253,6 +266,7 @@ class DownloadCoordinator:
|
||||
pending = DownloadPending(
|
||||
shard_metadata=current_status.shard_metadata,
|
||||
node_id=self.node_id,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=pending)
|
||||
@@ -295,11 +309,28 @@ class DownloadCoordinator:
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
model_directory=self._model_dir(
|
||||
progress.shard.model_card.model_id
|
||||
),
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
if (
|
||||
progress.downloaded_bytes.in_bytes
|
||||
>= progress.total_bytes.in_bytes
|
||||
> 0
|
||||
):
|
||||
status = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
elif progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
model_directory=self._model_dir(
|
||||
progress.shard.model_card.model_id
|
||||
),
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
@@ -308,6 +339,9 @@ class DownloadCoordinator:
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
model_directory=self._model_dir(
|
||||
progress.shard.model_card.model_id
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -136,6 +136,8 @@ class Node:
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
@@ -147,8 +149,6 @@ class Node:
|
||||
if self.api:
|
||||
tg.start_soon(self.api.run)
|
||||
tg.start_soon(self._elect_loop)
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
|
||||
def shutdown(self):
|
||||
# if this is our second call to shutdown, just sys.exit
|
||||
@@ -254,7 +254,7 @@ def main():
|
||||
target = min(max(soft, 65535), hard)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
mp.set_start_method("spawn", force=True)
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
logger.info("Starting EXO")
|
||||
|
||||
@@ -17,6 +17,7 @@ from exo.shared.types.api import (
|
||||
LogprobsContentItem,
|
||||
StreamingChoiceResponse,
|
||||
ToolCall,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
@@ -125,6 +126,8 @@ async def generate_chat_stream(
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate Chat Completions API streaming events from chunks."""
|
||||
last_usage: Usage | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_response = ErrorResponse(
|
||||
@@ -138,6 +141,8 @@ async def generate_chat_stream(
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_call_deltas = [
|
||||
ToolCall(
|
||||
@@ -161,12 +166,15 @@ async def generate_chat_stream(
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
],
|
||||
usage=last_usage,
|
||||
)
|
||||
yield f"data: {tool_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
chunk_response = chunk_to_response(chunk, command_id)
|
||||
if chunk.finish_reason is not None:
|
||||
chunk_response = chunk_response.model_copy(update={"usage": last_usage})
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
@@ -176,7 +184,9 @@ async def generate_chat_stream(
|
||||
async def collect_chat_response(
|
||||
command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ChatCompletionResponse:
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ChatCompletionResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
@@ -184,6 +194,7 @@ async def collect_chat_response(
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
error_message: str | None = None
|
||||
last_usage: Usage | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
@@ -193,6 +204,8 @@ async def collect_chat_response(
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.logprob is not None:
|
||||
@@ -223,7 +236,7 @@ async def collect_chat_response(
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
return ChatCompletionResponse(
|
||||
yield ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
@@ -241,4 +254,6 @@ async def collect_chat_response(
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
)
|
||||
usage=last_usage,
|
||||
).model_dump_json()
|
||||
return
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.shared.types.api import FinishReason, Usage
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeContentBlock,
|
||||
@@ -161,12 +161,14 @@ async def collect_claude_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ClaudeMessagesResponse:
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ClaudeMessagesResponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_use_blocks: list[ClaudeToolUseBlock] = []
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_stats = None
|
||||
last_usage: Usage | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
@@ -174,6 +176,8 @@ async def collect_claude_response(
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
for tool in chunk.tool_calls:
|
||||
tool_use_blocks.append(
|
||||
@@ -183,12 +187,10 @@ async def collect_claude_response(
|
||||
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
|
||||
)
|
||||
)
|
||||
last_stats = chunk.stats or last_stats
|
||||
stop_reason = "tool_use"
|
||||
continue
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
@@ -208,11 +210,11 @@ async def collect_claude_response(
|
||||
if not content:
|
||||
content.append(ClaudeTextBlock(text=""))
|
||||
|
||||
# Use actual usage data from stats if available
|
||||
input_tokens = last_stats.prompt_tokens if last_stats else 0
|
||||
output_tokens = last_stats.generation_tokens if last_stats else 0
|
||||
# Use actual usage data if available
|
||||
input_tokens = last_usage.prompt_tokens if last_usage else 0
|
||||
output_tokens = last_usage.completion_tokens if last_usage else 0
|
||||
|
||||
return ClaudeMessagesResponse(
|
||||
yield ClaudeMessagesResponse(
|
||||
id=f"msg_{command_id}",
|
||||
model=model,
|
||||
content=content,
|
||||
@@ -221,7 +223,8 @@ async def collect_claude_response(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
),
|
||||
)
|
||||
).model_dump_json()
|
||||
return
|
||||
|
||||
|
||||
async def generate_claude_stream(
|
||||
@@ -249,7 +252,7 @@ async def generate_claude_stream(
|
||||
|
||||
output_tokens = 0
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_stats = None
|
||||
last_usage: Usage | None = None
|
||||
next_block_index = 1 # text block is 0, tool blocks start at 1
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
@@ -257,8 +260,9 @@ async def generate_claude_stream(
|
||||
# Close text block and bail
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
last_stats = chunk.stats or last_stats
|
||||
stop_reason = "tool_use"
|
||||
|
||||
# Emit tool_use content blocks
|
||||
@@ -290,7 +294,6 @@ async def generate_claude_stream(
|
||||
continue
|
||||
|
||||
output_tokens += 1 # Count each chunk as one token
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# content_block_delta
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
@@ -302,9 +305,9 @@ async def generate_claude_stream(
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
|
||||
# Use actual token count from stats if available
|
||||
if last_stats is not None:
|
||||
output_tokens = last_stats.generation_tokens
|
||||
# Use actual token count from usage if available
|
||||
if last_usage is not None:
|
||||
output_tokens = last_usage.completion_tokens
|
||||
|
||||
# content_block_stop for text block
|
||||
block_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator
|
||||
from itertools import count
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.api import Usage
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.openai_responses import (
|
||||
@@ -121,13 +122,15 @@ async def collect_responses_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ResponsesResponse:
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ResponsesResponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ResponsesResponse."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
accumulated_text = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_stats = None
|
||||
last_usage: Usage | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
@@ -135,32 +138,32 @@ async def collect_responses_response(
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
for tool in chunk.tool_calls:
|
||||
function_call_items.append(
|
||||
ResponseFunctionCallItem(
|
||||
id=f"fc_{tool.id}",
|
||||
call_id=f"call_{tool.id}",
|
||||
id=tool.id,
|
||||
call_id=tool.id,
|
||||
name=tool.name,
|
||||
arguments=tool.arguments,
|
||||
)
|
||||
)
|
||||
last_stats = chunk.stats or last_stats
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
if error_message is not None:
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Create usage from stats if available
|
||||
# Create usage from usage data if available
|
||||
usage = None
|
||||
if last_stats is not None:
|
||||
if last_usage is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
input_tokens=last_usage.prompt_tokens,
|
||||
output_tokens=last_usage.completion_tokens,
|
||||
total_tokens=last_usage.total_tokens,
|
||||
)
|
||||
|
||||
output: list[ResponseItem] = [
|
||||
@@ -172,14 +175,15 @@ async def collect_responses_response(
|
||||
]
|
||||
output.extend(function_call_items)
|
||||
|
||||
return ResponsesResponse(
|
||||
yield ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="completed",
|
||||
output=output,
|
||||
output_text=accumulated_text,
|
||||
usage=usage,
|
||||
)
|
||||
).model_dump_json()
|
||||
return
|
||||
|
||||
|
||||
async def generate_responses_stream(
|
||||
@@ -235,15 +239,16 @@ async def generate_responses_stream(
|
||||
|
||||
accumulated_text = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_stats = None
|
||||
last_usage: Usage | None = None
|
||||
next_output_index = 1 # message item is at 0
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
last_stats = chunk.stats or last_stats
|
||||
for tool in chunk.tool_calls:
|
||||
fc_id = f"fc_{tool.id}"
|
||||
call_id = f"call_{tool.id}"
|
||||
@@ -302,7 +307,6 @@ async def generate_responses_stream(
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# response.output_text.delta
|
||||
delta_event = ResponseTextDeltaEvent(
|
||||
@@ -346,13 +350,13 @@ async def generate_responses_stream(
|
||||
)
|
||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
||||
|
||||
# Create usage from stats if available
|
||||
# Create usage from usage data if available
|
||||
usage = None
|
||||
if last_stats is not None:
|
||||
if last_usage is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
input_tokens=last_usage.prompt_tokens,
|
||||
output_tokens=last_usage.completion_tokens,
|
||||
total_tokens=last_usage.total_tokens,
|
||||
)
|
||||
|
||||
# response.completed
|
||||
|
||||
@@ -71,8 +71,11 @@ from exo.shared.types.api import (
|
||||
ChatCompletionResponse,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
CreateMetaInstanceParams,
|
||||
CreateMetaInstanceResponse,
|
||||
DeleteDownloadResponse,
|
||||
DeleteInstanceResponse,
|
||||
DeleteMetaInstanceResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
@@ -115,8 +118,10 @@ from exo.shared.types.claude_api import (
|
||||
from exo.shared.types.commands import (
|
||||
Command,
|
||||
CreateInstance,
|
||||
CreateMetaInstance,
|
||||
DeleteDownload,
|
||||
DeleteInstance,
|
||||
DeleteMetaInstance,
|
||||
DownloadCommand,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
@@ -125,10 +130,11 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
StartDownload,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -137,6 +143,7 @@ from exo.shared.types.events import (
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
@@ -275,6 +282,9 @@ class API:
|
||||
self.app.get("/instance/previews")(self.get_placement_previews)
|
||||
self.app.get("/instance/{instance_id}")(self.get_instance)
|
||||
self.app.delete("/instance/{instance_id}")(self.delete_instance)
|
||||
self.app.get("/meta_instances")(self.list_meta_instances)
|
||||
self.app.post("/meta_instance")(self.create_meta_instance)
|
||||
self.app.delete("/meta_instance/{meta_instance_id}")(self.delete_meta_instance)
|
||||
self.app.get("/models")(self.get_models)
|
||||
self.app.get("/v1/models")(self.get_models)
|
||||
self.app.post("/models/add")(self.add_custom_model)
|
||||
@@ -304,12 +314,27 @@ class API:
|
||||
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
model_card = await ModelCard.load(payload.model_id)
|
||||
command = PlaceInstance(
|
||||
model_card=await ModelCard.load(payload.model_id),
|
||||
model_card=model_card,
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
)
|
||||
|
||||
# Validate placement before sending — fail fast with a clear error
|
||||
# instead of silently dropping the command in the master.
|
||||
try:
|
||||
get_instance_placements(
|
||||
command,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
await self._send(command)
|
||||
|
||||
return CreateInstanceResponse(
|
||||
@@ -521,6 +546,44 @@ class API:
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
def list_meta_instances(self) -> dict[MetaInstanceId, MetaInstance]:
|
||||
return dict(self.state.meta_instances)
|
||||
|
||||
async def create_meta_instance(
|
||||
self, payload: CreateMetaInstanceParams
|
||||
) -> CreateMetaInstanceResponse:
|
||||
meta_instance = MetaInstance(
|
||||
model_id=payload.model_id,
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
node_ids=payload.node_ids,
|
||||
)
|
||||
command = CreateMetaInstance(meta_instance=meta_instance)
|
||||
await self._send(command)
|
||||
return CreateMetaInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
meta_instance_id=meta_instance.meta_instance_id,
|
||||
)
|
||||
|
||||
async def delete_meta_instance(
|
||||
self, meta_instance_id: MetaInstanceId
|
||||
) -> DeleteMetaInstanceResponse:
|
||||
meta = self.state.meta_instances.get(meta_instance_id)
|
||||
if not meta:
|
||||
raise HTTPException(status_code=404, detail="MetaInstance not found")
|
||||
|
||||
# Command processor handles cascade-deleting backing instances
|
||||
command = DeleteMetaInstance(meta_instance_id=meta_instance_id)
|
||||
await self._send(command)
|
||||
|
||||
return DeleteMetaInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
meta_instance_id=meta_instance_id,
|
||||
)
|
||||
|
||||
async def _token_chunk_stream(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
|
||||
@@ -540,16 +603,14 @@ class API:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# TODO: TaskCancelled
|
||||
"""
|
||||
self.command_sender.send_nowait(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
"""
|
||||
cancel_command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=cancel_command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
command = TaskFinished(finished_command_id=command_id)
|
||||
await self._send(command)
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
if command_id in self._text_generation_queues:
|
||||
del self._text_generation_queues[command_id]
|
||||
|
||||
@@ -644,11 +705,14 @@ class API:
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
return await collect_chat_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_chat_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionRequest
|
||||
@@ -664,8 +728,7 @@ class API:
|
||||
command = TextGeneration(task_params=task_params)
|
||||
await self._send(command)
|
||||
|
||||
response = await self._collect_text_generation_with_stats(command.command_id)
|
||||
return response
|
||||
return await self._collect_text_generation_with_stats(command.command_id)
|
||||
|
||||
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
|
||||
"""Validate a text model exists and return the resolved model ID.
|
||||
@@ -883,6 +946,11 @@ class API:
|
||||
del image_metadata[key]
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
cancel_command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=cancel_command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
@@ -964,6 +1032,11 @@ class API:
|
||||
|
||||
return (images, stats if capture_stats else None)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
cancel_command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=cancel_command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
@@ -1221,12 +1294,15 @@ class API:
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
return await collect_claude_response(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_claude_response(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
async def openai_responses(
|
||||
self, payload: ResponsesRequest
|
||||
@@ -1254,11 +1330,15 @@ class API:
|
||||
},
|
||||
)
|
||||
|
||||
return await collect_responses_response(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_responses_response(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -12,11 +13,22 @@ from exo.master.placement import (
|
||||
get_transition_events,
|
||||
place_instance,
|
||||
)
|
||||
from exo.master.process_managers import ProcessManager
|
||||
from exo.master.process_managers.instance_health import InstanceHealthReconciler
|
||||
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
|
||||
from exo.master.process_managers.node_timeout import NodeTimeoutReconciler
|
||||
from exo.master.reconcile import (
|
||||
find_unsatisfied_meta_instances,
|
||||
try_place_for_meta_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
CreateMetaInstance,
|
||||
DeleteInstance,
|
||||
DeleteMetaInstance,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
@@ -24,6 +36,7 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
SendInputChunk,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
TextGeneration,
|
||||
@@ -35,10 +48,15 @@ from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
InstanceDeleted,
|
||||
JacclSideChannelData,
|
||||
JacclSideChannelGathered,
|
||||
MetaInstanceCreated,
|
||||
MetaInstanceDeleted,
|
||||
MetaInstancePlacementFailed,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStatusUpdated,
|
||||
TraceEventData,
|
||||
TracesCollected,
|
||||
TracesMerged,
|
||||
@@ -58,7 +76,8 @@ from exo.shared.types.tasks import (
|
||||
TextGeneration as TextGenerationTask,
|
||||
)
|
||||
from exo.shared.types.worker.instances import InstanceId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
from exo.utils.channels import Receiver, Sender
|
||||
from exo.utils.event_buffer import MultiSourceBuffer
|
||||
|
||||
|
||||
@@ -82,16 +101,16 @@ class Master:
|
||||
self.local_event_receiver = local_event_receiver
|
||||
self.global_event_sender = global_event_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
send, recv = channel[Event]()
|
||||
self.event_sender: Sender[Event] = send
|
||||
self._loopback_event_receiver: Receiver[Event] = recv
|
||||
self._loopback_event_sender: Sender[ForwarderEvent] = (
|
||||
local_event_receiver.clone_sender()
|
||||
)
|
||||
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
|
||||
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
|
||||
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
||||
self._expected_ranks: dict[TaskId, set[int]] = {}
|
||||
self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {}
|
||||
self._process_managers: Sequence[ProcessManager] = [
|
||||
InstanceHealthReconciler(),
|
||||
NodeTimeoutReconciler(),
|
||||
MetaInstanceReconciler(),
|
||||
]
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
@@ -100,15 +119,12 @@ class Master:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
tg.start_soon(self._reconcile)
|
||||
finally:
|
||||
self._event_log.close()
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Stopping Master")
|
||||
@@ -279,7 +295,7 @@ class Master:
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
for cmd in cancel_unnecessary_downloads(
|
||||
placement, self.state.downloads
|
||||
@@ -290,6 +306,86 @@ class Master:
|
||||
)
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case CreateMetaInstance():
|
||||
logger.info(
|
||||
f"Creating MetaInstance for {command.meta_instance.model_id}"
|
||||
f" (min_nodes={command.meta_instance.min_nodes},"
|
||||
f" sharding={command.meta_instance.sharding})"
|
||||
)
|
||||
# Apply immediately so self.state is fresh across
|
||||
# the await below and the reconciler won't race.
|
||||
await self._apply_and_broadcast(
|
||||
MetaInstanceCreated(meta_instance=command.meta_instance)
|
||||
)
|
||||
# Immediate placement attempt for responsiveness
|
||||
model_card = await ModelCard.load(
|
||||
command.meta_instance.model_id
|
||||
)
|
||||
# Re-check: reconciler may have satisfied it during the await
|
||||
meta_id = command.meta_instance.meta_instance_id
|
||||
still_unsatisfied = any(
|
||||
m.meta_instance_id == meta_id
|
||||
for m in find_unsatisfied_meta_instances(
|
||||
self.state.meta_instances,
|
||||
self.state.instances,
|
||||
self.state.topology,
|
||||
)
|
||||
)
|
||||
if still_unsatisfied:
|
||||
result = try_place_for_meta_instance(
|
||||
command.meta_instance,
|
||||
model_card,
|
||||
self.state.topology,
|
||||
self.state.instances,
|
||||
self.state.node_memory,
|
||||
self.state.node_network,
|
||||
self.state.tasks,
|
||||
)
|
||||
generated_events.extend(result.events)
|
||||
if result.error is not None:
|
||||
generated_events.append(
|
||||
MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta_id,
|
||||
reason=result.error,
|
||||
)
|
||||
)
|
||||
case DeleteMetaInstance():
|
||||
backing_count = sum(
|
||||
1
|
||||
for inst in self.state.instances.values()
|
||||
if inst.meta_instance_id == command.meta_instance_id
|
||||
)
|
||||
logger.info(
|
||||
f"Deleting MetaInstance {command.meta_instance_id}"
|
||||
f" (cascade-deleting {backing_count} backing instance(s))"
|
||||
)
|
||||
generated_events.append(
|
||||
MetaInstanceDeleted(
|
||||
meta_instance_id=command.meta_instance_id
|
||||
)
|
||||
)
|
||||
# Cascade-delete backing instances atomically,
|
||||
# cancelling any active tasks first.
|
||||
for iid, inst in self.state.instances.items():
|
||||
if inst.meta_instance_id == command.meta_instance_id:
|
||||
for task in self.state.tasks.values():
|
||||
if (
|
||||
task.instance_id == iid
|
||||
and task.task_status
|
||||
in (
|
||||
TaskStatus.Pending,
|
||||
TaskStatus.Running,
|
||||
)
|
||||
):
|
||||
generated_events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
)
|
||||
generated_events.append(
|
||||
InstanceDeleted(instance_id=iid)
|
||||
)
|
||||
case PlaceInstance():
|
||||
placement = place_instance(
|
||||
command,
|
||||
@@ -299,7 +395,7 @@ class Master:
|
||||
self.state.node_network,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case CreateInstance():
|
||||
@@ -309,7 +405,7 @@ class Master:
|
||||
self.state.instances,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
@@ -319,6 +415,21 @@ class Master:
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
case TaskCancelled():
|
||||
if (
|
||||
command.cancelled_command_id
|
||||
in self.command_task_mapping
|
||||
):
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
task_id=self.command_task_mapping[
|
||||
command.cancelled_command_id
|
||||
]
|
||||
)
|
||||
)
|
||||
del self.command_task_mapping[
|
||||
command.cancelled_command_id
|
||||
]
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
@@ -341,31 +452,32 @@ class Master:
|
||||
):
|
||||
await self._send_event(IndexedEvent(idx=i, event=event))
|
||||
for event in generated_events:
|
||||
await self.event_sender.send(event)
|
||||
await self._apply_and_broadcast(event)
|
||||
except ValueError as e:
|
||||
logger.opt(exception=e).warning("Error in command processor")
|
||||
|
||||
# These plan loops are the cracks showing in our event sourcing architecture - more things could be commands
|
||||
async def _plan(self) -> None:
|
||||
async def _apply_and_broadcast(self, event: Event) -> None:
|
||||
"""Apply event to state, persist to disk, and broadcast to workers.
|
||||
|
||||
State is updated synchronously (before any await), so callers can
|
||||
rely on ``self.state`` reflecting this event immediately after the
|
||||
call. Python's cooperative scheduling guarantees no interleaving
|
||||
between the state read and write.
|
||||
"""
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
|
||||
self._event_log.append(event)
|
||||
await self._send_event(indexed)
|
||||
|
||||
async def _reconcile(self) -> None:
|
||||
while True:
|
||||
# kill broken instances
|
||||
connected_node_ids = set(self.state.topology.list_nodes())
|
||||
for instance_id, instance in self.state.instances.items():
|
||||
for node_id in instance.shard_assignments.node_to_runner:
|
||||
if node_id not in connected_node_ids:
|
||||
await self.event_sender.send(
|
||||
InstanceDeleted(instance_id=instance_id)
|
||||
)
|
||||
break
|
||||
|
||||
# time out dead nodes
|
||||
for node_id, time in self.state.last_seen.items():
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
if now - time > timedelta(seconds=30):
|
||||
logger.info(f"Manually removing node {node_id} due to inactivity")
|
||||
await self.event_sender.send(NodeTimedOut(node_id=node_id))
|
||||
|
||||
await anyio.sleep(10)
|
||||
for pm in self._process_managers:
|
||||
events = await pm.reconcile(self.state)
|
||||
for event in events:
|
||||
await self._apply_and_broadcast(event)
|
||||
await anyio.sleep(1)
|
||||
|
||||
async def _event_processor(self) -> None:
|
||||
with self.local_event_receiver as local_events:
|
||||
@@ -383,32 +495,15 @@ class Master:
|
||||
await self._handle_traces_collected(event)
|
||||
continue
|
||||
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
if isinstance(event, JacclSideChannelData):
|
||||
await self._apply_and_broadcast(event)
|
||||
await self._handle_jaccl_side_channel(event)
|
||||
continue
|
||||
|
||||
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
|
||||
if isinstance(event, NodeGatheredInfo):
|
||||
event.when = str(datetime.now(tz=timezone.utc))
|
||||
|
||||
self._event_log.append(event)
|
||||
await self._send_event(indexed)
|
||||
|
||||
async def _loopback_processor(self) -> None:
|
||||
# this would ideally not be necessary.
|
||||
# this is WAY less hacky than how I was working around this before
|
||||
local_index = 0
|
||||
with self._loopback_event_receiver as events:
|
||||
async for event in events:
|
||||
await self._loopback_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin=NodeId(f"master_{self.node_id}"),
|
||||
origin_idx=local_index,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
local_index += 1
|
||||
await self._apply_and_broadcast(event)
|
||||
|
||||
# This function is re-entrant, take care!
|
||||
async def _send_event(self, event: IndexedEvent):
|
||||
@@ -440,10 +535,49 @@ class Master:
|
||||
for trace_data in self._pending_traces[task_id].values():
|
||||
all_trace_data.extend(trace_data)
|
||||
|
||||
await self.event_sender.send(
|
||||
await self._apply_and_broadcast(
|
||||
TracesMerged(task_id=task_id, traces=all_trace_data)
|
||||
)
|
||||
|
||||
del self._pending_traces[task_id]
|
||||
if task_id in self._expected_ranks:
|
||||
del self._expected_ranks[task_id]
|
||||
|
||||
async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None:
|
||||
"""Accumulate SideChannel contributions; when all runners for an instance
|
||||
have submitted for the same sequence, emit JacclSideChannelGathered."""
|
||||
iid = event.instance_id
|
||||
seq = event.sequence
|
||||
|
||||
if iid not in self._jaccl_pending:
|
||||
self._jaccl_pending[iid] = {}
|
||||
if seq not in self._jaccl_pending[iid]:
|
||||
self._jaccl_pending[iid][seq] = {}
|
||||
self._jaccl_pending[iid][seq][event.runner_id] = event.data
|
||||
|
||||
instance = self.state.instances.get(iid)
|
||||
if instance is None:
|
||||
logger.warning(f"JacclSideChannelData for unknown instance {iid}")
|
||||
return
|
||||
|
||||
expected_runners = set(instance.shard_assignments.runner_to_shard.keys())
|
||||
submitted = set(self._jaccl_pending[iid][seq].keys())
|
||||
|
||||
logger.info(
|
||||
f"JACCL side channel: instance={iid} seq={seq} "
|
||||
f"submitted={len(submitted)}/{len(expected_runners)}"
|
||||
)
|
||||
|
||||
if submitted >= expected_runners:
|
||||
gathered = dict(self._jaccl_pending[iid][seq])
|
||||
del self._jaccl_pending[iid][seq]
|
||||
if not self._jaccl_pending[iid]:
|
||||
del self._jaccl_pending[iid]
|
||||
|
||||
await self._apply_and_broadcast(
|
||||
JacclSideChannelGathered(
|
||||
instance_id=iid,
|
||||
sequence=seq,
|
||||
gathered_data=gathered,
|
||||
)
|
||||
)
|
||||
|
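Note on `_handle_jaccl_side_channel`: the gather logic can be sketched in isolation. This is a minimal illustration under stated assumptions, not the project's implementation. `GatherBuffer` and `submit` are hypothetical names, and plain `str`/`int` keys stand in for `InstanceId`, the sequence number, and `RunnerId`.

```python
from __future__ import annotations

from collections import defaultdict


class GatherBuffer:
    """Hypothetical stand-in for the master's _jaccl_pending bookkeeping."""

    def __init__(self) -> None:
        # instance id -> sequence number -> runner id -> payload
        self._pending: defaultdict[str, dict[int, dict[str, bytes]]] = defaultdict(dict)

    def submit(
        self,
        instance_id: str,
        sequence: int,
        runner_id: str,
        data: bytes,
        expected_runners: set[str],
    ) -> dict[str, bytes] | None:
        """Record one contribution; return the round once all expected runners submitted."""
        round_ = self._pending[instance_id].setdefault(sequence, {})
        round_[runner_id] = data
        # Superset comparison, as in the diff: the round completes once every
        # expected runner is present, even if other runners also submitted.
        if set(round_) >= expected_runners:
            gathered = dict(round_)
            del self._pending[instance_id][sequence]
            if not self._pending[instance_id]:
                del self._pending[instance_id]
            return gathered
        return None
```

Keying by sequence number keeps rounds independent, so a runner that is already on round 1 does not disturb the completion of round 0.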
||||
@@ -6,11 +6,11 @@ from typing import Sequence
|
||||
from exo.master.placement_utils import (
|
||||
Cycle,
|
||||
filter_cycles_by_memory,
|
||||
get_largest_cycles,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_mlx_jaccl_devices_matrix,
|
||||
get_mlx_ring_hosts_by_node,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.topology import Topology
|
||||
@@ -22,9 +22,15 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
@@ -100,23 +106,27 @@ def place_instance(
|
||||
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
|
||||
)
|
||||
|
||||
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
||||
largest_cycles = get_largest_cycles(cycles_with_sufficient_memory)
|
||||
|
||||
smallest_rdma_cycles = [
|
||||
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
|
||||
largest_rdma_cycles = [
|
||||
cycle for cycle in largest_cycles if topology.is_rdma_cycle(cycle)
|
||||
]
|
||||
|
||||
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
|
||||
smallest_cycles = smallest_rdma_cycles
|
||||
if command.instance_meta == InstanceMeta.MlxJaccl:
|
||||
if not largest_rdma_cycles:
|
||||
raise ValueError(
|
||||
"Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
|
||||
)
|
||||
largest_cycles = largest_rdma_cycles
|
||||
|
||||
cycles_with_leaf_nodes: list[Cycle] = [
|
||||
cycle
|
||||
for cycle in smallest_cycles
|
||||
for cycle in largest_cycles
|
||||
if any(topology.node_is_leaf(node_id) for node_id in cycle)
|
||||
]
|
||||
|
||||
selected_cycle = max(
|
||||
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
|
||||
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else largest_cycles,
|
||||
key=lambda cycle: sum(
|
||||
(node_memory[node_id].ram_available for node_id in cycle),
|
||||
start=Memory(),
|
||||
@@ -186,6 +196,7 @@ def delete_instance(
|
||||
def get_transition_events(
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
target_instances: Mapping[InstanceId, Instance],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Sequence[Event]:
|
||||
events: list[Event] = []
|
||||
|
||||
@@ -201,6 +212,18 @@ def get_transition_events(
|
||||
# find instances to delete
|
||||
for instance_id in current_instances:
|
||||
if instance_id not in target_instances:
|
||||
for task in tasks.values():
|
||||
if task.instance_id == instance_id and task.task_status in [
|
||||
TaskStatus.Pending,
|
||||
TaskStatus.Running,
|
||||
]:
|
||||
events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
)
|
||||
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
|
||||
@@ -37,11 +37,11 @@ def filter_cycles_by_memory(
|
||||
return filtered_cycles
|
||||
|
||||
|
||||
def get_smallest_cycles(
|
||||
def get_largest_cycles(
|
||||
cycles: list[Cycle],
|
||||
) -> list[Cycle]:
|
||||
min_nodes = min(len(cycle) for cycle in cycles)
|
||||
return [cycle for cycle in cycles if len(cycle) == min_nodes]
|
||||
max_nodes = max(len(cycle) for cycle in cycles)
|
||||
return [cycle for cycle in cycles if len(cycle) == max_nodes]
|
||||
|
||||
|
||||
def allocate_layers_proportionally(
|
||||
|
||||
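Note on `get_largest_cycles`: the selection keeps every cycle tied for the maximum node count. A minimal sketch follows, assuming plain lists of node-name strings as a hypothetical stand-in for `Cycle`; `largest_cycles` is an illustrative name, not the project's function.

```python
def largest_cycles(cycles: list[list[str]]) -> list[list[str]]:
    """Keep every cycle whose length equals the maximum length."""
    # max() raises ValueError on an empty iterable, so callers need
    # at least one candidate cycle before calling this.
    max_nodes = max(len(cycle) for cycle in cycles)
    return [cycle for cycle in cycles if len(cycle) == max_nodes]


candidates = [["a"], ["a", "b"], ["c", "d"]]
selected = largest_cycles(candidates)  # both two-node cycles survive the filter
```

Ties are preserved on purpose: the caller in `place_instance` then picks among the remaining cycles by available memory.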
12
src/exo/master/process_managers/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from exo.shared.types.events import Event
|
||||
from exo.shared.types.state import State
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ProcessManager(Protocol):
|
||||
"""A reconciliation step that examines state and returns corrective events."""
|
||||
|
||||
async def reconcile(self, state: State) -> Sequence[Event]: ...
|
||||
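Note on the `ProcessManager` protocol: because it is a `typing.Protocol`, a reconciler satisfies it structurally, without inheriting from it. The sketch below assumes simplified stand-ins for `State` and the event type; `UnhealthyInstanceReconciler` and `reconcile_once` are hypothetical names used only for illustration.

```python
import asyncio
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Protocol, runtime_checkable


@dataclass(frozen=True)
class State:  # hypothetical stand-in for exo.shared.types.state.State
    healthy: dict[str, bool] = field(default_factory=dict)


@dataclass(frozen=True)
class InstanceDeleted:  # hypothetical stand-in for the real event type
    instance_id: str


@runtime_checkable
class ProcessManager(Protocol):
    async def reconcile(self, state: State) -> Sequence[InstanceDeleted]: ...


class UnhealthyInstanceReconciler:
    """Matches ProcessManager by shape only; no base class required."""

    async def reconcile(self, state: State) -> Sequence[InstanceDeleted]:
        return [InstanceDeleted(iid) for iid, ok in state.healthy.items() if not ok]


async def reconcile_once(
    managers: Sequence[ProcessManager], state: State
) -> list[InstanceDeleted]:
    """Run each manager in order and collect the corrective events."""
    events: list[InstanceDeleted] = []
    for pm in managers:
        events.extend(await pm.reconcile(state))
    return events
```

Two caveats: `isinstance` against a `runtime_checkable` protocol only checks that a `reconcile` attribute exists, not its signature or that it is a coroutine function. Also, this sketch only collects events, whereas the master's `_reconcile` loop in the diff applies each event immediately after a manager returns it.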
62
src/exo/master/process_managers/instance_health.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.reconcile import instance_connections_healthy, instance_runners_failed
|
||||
from exo.shared.types.events import Event, InstanceDeleted, InstanceRetrying
|
||||
from exo.shared.types.state import State
|
||||
|
||||
MAX_INSTANCE_RETRIES = 3
|
||||
|
||||
|
||||
@final
|
||||
class InstanceHealthReconciler:
|
||||
"""Delete instances whose network connections are broken or whose runners have all failed."""
|
||||
|
||||
async def reconcile(self, state: State) -> Sequence[Event]:
|
||||
events: list[Event] = []
|
||||
for instance_id, instance in state.instances.items():
|
||||
if not instance_connections_healthy(instance, state.topology):
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
failure_error="Network connection lost",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
is_failed, error_message = instance_runners_failed(
|
||||
instance, state.runners, state.node_identities
|
||||
)
|
||||
if is_failed:
|
||||
# Retry within the same instance if backed by a MetaInstance
|
||||
mid = instance.meta_instance_id
|
||||
mi = state.meta_instances.get(mid) if mid else None
|
||||
if mid and mi and mi.consecutive_failures < MAX_INSTANCE_RETRIES:
|
||||
logger.info(
|
||||
f"Instance {instance_id} failed (attempt"
|
||||
f" {mi.consecutive_failures + 1}/{MAX_INSTANCE_RETRIES}),"
|
||||
f" retrying: {error_message}"
|
||||
)
|
||||
events.append(
|
||||
InstanceRetrying(
|
||||
instance_id=instance_id,
|
||||
meta_instance_id=mid,
|
||||
failure_error=error_message or "Runner failed",
|
||||
)
|
||||
)
|
||||
else:
|
||||
if mid and mi:
|
||||
logger.warning(
|
||||
f"Instance {instance_id} exceeded retry limit"
|
||||
f" ({MAX_INSTANCE_RETRIES}), deleting:"
|
||||
f" {error_message}"
|
||||
)
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
failure_error=error_message,
|
||||
)
|
||||
)
|
||||
return events
|
||||
92
src/exo/master/process_managers/meta_instance.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
import anyio
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.reconcile import (
|
||||
find_unsatisfied_meta_instances,
|
||||
try_place_for_meta_instance,
|
||||
)
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.types.events import Event, InstanceCreated, MetaInstancePlacementFailed
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
|
||||
MODEL_CARD_LOAD_TIMEOUT_SECONDS = 10
|
||||
|
||||
|
||||
@final
|
||||
class MetaInstanceReconciler:
|
||||
"""Place instances for unsatisfied MetaInstances."""
|
||||
|
||||
async def reconcile(self, state: State) -> Sequence[Event]:
|
||||
all_events: list[Event] = []
|
||||
# Local copy for intermediate tracking — so placement of B
|
||||
# sees A's instance and doesn't double-place on same resources.
|
||||
current_instances: dict[InstanceId, Instance] = dict(state.instances)
|
||||
|
||||
unsatisfied = find_unsatisfied_meta_instances(
|
||||
state.meta_instances,
|
||||
current_instances,
|
||||
state.topology,
|
||||
)
|
||||
for meta_instance in unsatisfied:
|
||||
try:
|
||||
with anyio.fail_after(MODEL_CARD_LOAD_TIMEOUT_SECONDS):
|
||||
model_card = await ModelCard.load(meta_instance.model_id)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
f"ModelCard.load timed out for {meta_instance.model_id}, skipping this cycle"
|
||||
)
|
||||
continue
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"ModelCard.load failed for {meta_instance.model_id}: {exc}"
|
||||
)
|
||||
error = f"Failed to load model card: {exc}"
|
||||
if meta_instance.placement_error != error:
|
||||
all_events.append(
|
||||
MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta_instance.meta_instance_id,
|
||||
reason=error,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
result = try_place_for_meta_instance(
|
||||
meta_instance,
|
||||
model_card,
|
||||
state.topology,
|
||||
current_instances,
|
||||
state.node_memory,
|
||||
state.node_network,
|
||||
state.tasks,
|
||||
)
|
||||
# Update local instance map so next placement sees this one
|
||||
for event in result.events:
|
||||
if isinstance(event, InstanceCreated):
|
||||
logger.info(
|
||||
f"MetaInstance reconciler placed instance"
|
||||
f" {event.instance.instance_id} for"
|
||||
f" {meta_instance.model_id}"
|
||||
)
|
||||
current_instances[event.instance.instance_id] = event.instance
|
||||
all_events.extend(result.events)
|
||||
|
||||
# Emit placement failure if error differs from what's already in state
|
||||
if (
|
||||
result.error is not None
|
||||
and meta_instance.placement_error != result.error
|
||||
):
|
||||
logger.warning(
|
||||
f"MetaInstance placement failed for"
|
||||
f" {meta_instance.model_id}: {result.error}"
|
||||
)
|
||||
all_events.append(
|
||||
MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta_instance.meta_instance_id,
|
||||
reason=result.error,
|
||||
)
|
||||
)
|
||||
return all_events
|
||||
27
src/exo/master/process_managers/node_timeout.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import final
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.events import Event, NodeTimedOut
|
||||
from exo.shared.types.state import State
|
||||
|
||||
_DEFAULT_TIMEOUT = timedelta(seconds=30)
|
||||
|
||||
|
||||
@final
|
||||
class NodeTimeoutReconciler:
|
||||
"""Time out nodes that haven't been seen recently."""
|
||||
|
||||
def __init__(self, timeout: timedelta = _DEFAULT_TIMEOUT) -> None:
|
||||
self.timeout = timeout
|
||||
|
||||
async def reconcile(self, state: State) -> Sequence[Event]:
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
events: list[Event] = []
|
||||
for node_id, last_seen in state.last_seen.items():
|
||||
if now - last_seen > self.timeout:
|
||||
logger.info(f"Removing node {node_id} due to inactivity")
|
||||
events.append(NodeTimedOut(node_id=node_id))
|
||||
return events
|
||||
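Note on `NodeTimeoutReconciler`: the timeout check reduces to a pure function once the current time is passed in, which makes the boundary behavior easy to test. A minimal sketch, assuming string node ids; `timed_out_nodes` is a hypothetical helper and not part of the diff.

```python
from datetime import datetime, timedelta, timezone


def timed_out_nodes(
    last_seen: dict[str, datetime],
    now: datetime,
    timeout: timedelta = timedelta(seconds=30),
) -> list[str]:
    """Return ids of nodes whose last sighting is strictly older than `timeout`."""
    return [node_id for node_id, seen in last_seen.items() if now - seen > timeout]


now = datetime(2026, 1, 1, tzinfo=timezone.utc)
seen = {
    "fresh": now - timedelta(seconds=5),
    "edge": now - timedelta(seconds=30),   # exactly at the limit: not timed out
    "stale": now - timedelta(seconds=31),
}
stale = timed_out_nodes(seen, now)
```

The comparison is strict (`>`), matching the reconciler above, so a node seen exactly `timeout` ago is kept. All timestamps should be timezone-aware: subtracting a naive `datetime` from an aware one raises `TypeError`.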
244
src/exo/master/reconcile.py
Normal file
@@ -0,0 +1,244 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import NamedTuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.placement import get_transition_events, place_instance
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import PlaceInstance
|
||||
from exo.shared.types.common import MetaInstanceId, NodeId
|
||||
from exo.shared.types.events import Event
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeIdentity, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.topology import RDMAConnection, SocketConnection
|
||||
from exo.shared.types.worker.instances import (
|
||||
BaseInstance,
|
||||
Instance,
|
||||
InstanceId,
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerFailed,
|
||||
RunnerId,
|
||||
RunnerShutdown,
|
||||
RunnerStatus,
|
||||
)
|
||||
|
||||
|
||||
class PlacementResult(NamedTuple):
|
||||
"""Result of a placement attempt: events to apply and optional error reason."""
|
||||
|
||||
events: Sequence[Event]
|
||||
error: str | None
|
||||
|
||||
|
||||
def _get_ring_order(instance: BaseInstance) -> list[NodeId]:
|
||||
"""Reconstruct ring order from shard device_rank."""
|
||||
node_ranks: list[tuple[NodeId, int]] = []
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
|
||||
shard = instance.shard_assignments.runner_to_shard[runner_id]
|
||||
node_ranks.append((node_id, shard.device_rank))
|
||||
node_ranks.sort(key=lambda x: x[1])
|
||||
return [node_id for node_id, _ in node_ranks]
|
||||
|
||||
|
||||
def _ring_connections_healthy(instance: MlxRingInstance, topology: Topology) -> bool:
|
||||
"""Check that the specific IPs used by a ring instance still exist in the topology."""
|
||||
ring = _get_ring_order(instance)
|
||||
n = len(ring)
|
||||
for node in ring:
|
||||
hosts = instance.hosts_by_node[node]
|
||||
for idx in range(n):
|
||||
host = hosts[idx]
|
||||
if host.ip in ("0.0.0.0", "198.51.100.1"):
|
||||
continue # self or placeholder
|
||||
# Real connection: node → ring[idx]. Check specific IP.
|
||||
connections = topology.get_all_connections_between(node, ring[idx])
|
||||
if not any(
|
||||
isinstance(c, SocketConnection)
|
||||
and c.sink_multiaddr.ip_address == host.ip
|
||||
for c in connections
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _jaccl_connections_healthy(instance: MlxJacclInstance, topology: Topology) -> bool:
|
||||
"""Check that the specific RDMA interfaces used by a JACCL instance still exist."""
|
||||
ring = _get_ring_order(instance)
|
||||
n = len(ring)
|
||||
for i in range(n):
|
||||
for j in range(n):
|
||||
iface = instance.jaccl_devices[i][j]
|
||||
if iface is None:
|
||||
continue
|
||||
connections = topology.get_all_connections_between(ring[i], ring[j])
|
||||
if not any(
|
||||
isinstance(c, RDMAConnection) and c.source_rdma_iface == iface
|
||||
for c in connections
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def instance_connections_healthy(instance: Instance, topology: Topology) -> bool:
|
||||
"""Check that an instance's nodes and specific connections are still in the topology."""
|
||||
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
|
||||
if not all(topology.contains_node(n) for n in instance_nodes):
|
||||
return False
|
||||
if len(instance_nodes) <= 1:
|
||||
return True
|
||||
match instance:
|
||||
case MlxRingInstance():
|
||||
return _ring_connections_healthy(instance, topology)
|
||||
case MlxJacclInstance():
|
||||
return _jaccl_connections_healthy(instance, topology)
|
||||
|
||||
|
||||
def instance_runners_failed(
|
||||
instance: Instance,
|
||||
runners: Mapping[RunnerId, RunnerStatus],
|
||||
node_identities: Mapping[NodeId, NodeIdentity],
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Check if an instance's runners have all reached terminal failure states.
|
||||
|
||||
Returns ``(True, error_message)`` when ALL runners are terminal
|
||||
(``RunnerFailed`` or ``RunnerShutdown``) and at least one is ``RunnerFailed``.
|
||||
|
||||
Returns ``(False, None)`` when runners are still active, haven't reported
|
||||
yet, or all gracefully shut down (no ``RunnerFailed``).
|
||||
"""
|
||||
instance_runner_ids = set(instance.shard_assignments.node_to_runner.values())
|
||||
|
||||
if not instance_runner_ids:
|
||||
return False, None
|
||||
|
||||
# Build reverse mapping: runner_id -> node_id
|
||||
runner_to_node: dict[RunnerId, NodeId] = {
|
||||
runner_id: node_id
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
|
||||
}
|
||||
|
||||
has_any_failed = False
|
||||
error_messages: list[str] = []
|
||||
|
||||
for runner_id in instance_runner_ids:
|
||||
status = runners.get(runner_id)
|
||||
if status is None:
|
||||
# Runner hasn't reported yet — instance is still starting
|
||||
return False, None
|
||||
if isinstance(status, RunnerFailed):
|
||||
has_any_failed = True
|
||||
if status.error_message:
|
||||
node_id = runner_to_node.get(runner_id)
|
||||
name = (
|
||||
node_identities[node_id].friendly_name
|
||||
if node_id and node_id in node_identities
|
||||
else node_id or "unknown"
|
||||
)
|
||||
error_messages.append(f"{name}: {status.error_message}")
|
||||
elif isinstance(status, RunnerShutdown):
|
||||
pass # Terminal but not a failure indicator on its own
|
||||
else:
|
||||
# Runner is still active (connecting, loading, running, etc.)
|
||||
return False, None
|
||||
|
||||
if has_any_failed:
|
||||
return True, "; ".join(error_messages) if error_messages else "Runner failed"
|
||||
|
||||
# All runners are Shutdown but none Failed — graceful shutdown, not a failure
|
||||
return False, None
|
||||
|
||||
|
||||
def instance_satisfies_meta_instance(
|
||||
meta_instance: MetaInstance,
|
||||
instance: Instance,
|
||||
) -> bool:
|
||||
"""Check if a single instance satisfies a meta-instance's constraints.
|
||||
|
||||
This is a pure constraint check (model, min_nodes, node_ids).
|
||||
Use ``instance_connections_healthy`` separately for topology health.
|
||||
"""
|
||||
if instance.shard_assignments.model_id != meta_instance.model_id:
|
||||
return False
|
||||
|
||||
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
|
||||
|
||||
if len(instance_nodes) < meta_instance.min_nodes:
|
||||
return False
|
||||
|
||||
return meta_instance.node_ids is None or set(meta_instance.node_ids).issubset(
|
||||
instance_nodes
|
||||
)
|
||||
|
||||
|
||||
def find_unsatisfied_meta_instances(
|
||||
meta_instances: Mapping[MetaInstanceId, MetaInstance],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
topology: Topology,
|
||||
) -> Sequence[MetaInstance]:
|
||||
"""Return meta-instances that have no healthy backing instance."""
|
||||
unsatisfied: list[MetaInstance] = []
|
||||
for meta_id, meta_instance in meta_instances.items():
|
||||
has_healthy_backing = any(
|
||||
instance.meta_instance_id == meta_id
|
||||
and instance_connections_healthy(instance, topology)
|
||||
for instance in instances.values()
|
||||
)
|
||||
if not has_healthy_backing:
|
||||
unsatisfied.append(meta_instance)
|
||||
return unsatisfied
|
||||
|
||||
|
||||
def try_place_for_meta_instance(
|
||||
meta_instance: MetaInstance,
|
||||
model_card: ModelCard,
|
||||
topology: Topology,
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> PlacementResult:
|
||||
"""Try to place an instance satisfying the meta-instance constraints.
|
||||
|
||||
Returns a :class:`PlacementResult` with events on success, or an error
|
||||
reason on failure.
|
||||
"""
|
||||
command = PlaceInstance(
|
||||
model_card=model_card,
|
||||
sharding=meta_instance.sharding,
|
||||
instance_meta=meta_instance.instance_meta,
|
||||
min_nodes=meta_instance.min_nodes,
|
||||
)
|
||||
try:
|
||||
target_instances = place_instance(
|
||||
command,
|
||||
topology,
|
||||
current_instances,
|
||||
node_memory,
|
||||
node_network,
|
||||
required_nodes=(
|
||||
set(meta_instance.node_ids) if meta_instance.node_ids else None
|
||||
),
|
||||
)
|
||||
# Tag the new instance with meta_instance_id
|
||||
new_instance_ids = set(target_instances.keys()) - set(current_instances.keys())
|
||||
if new_instance_ids:
|
||||
new_id = next(iter(new_instance_ids))
|
||||
target_instances[new_id] = target_instances[new_id].model_copy(
|
||||
update={"meta_instance_id": meta_instance.meta_instance_id}
|
||||
)
|
||||
return PlacementResult(
|
||||
events=list(
|
||||
get_transition_events(current_instances, target_instances, tasks)
|
||||
),
|
||||
error=None,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.debug(
|
||||
f"MetaInstance placement not possible for {meta_instance.model_id}: {e}"
|
||||
)
|
||||
return PlacementResult(events=[], error=str(e))
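The pattern in `try_place_for_meta_instance` (diff the target keys against the current keys to find the new instance, tag it, and convert a `ValueError` into an error result) can be sketched on its own. This is an illustration under stated assumptions: `Inst`, `Result`, and `place_and_tag` are hypothetical names, and `place` stands in for whatever placement function is used.

```python
from dataclasses import dataclass, field, replace
from typing import Callable, Dict, Optional


@dataclass(frozen=True)
class Inst:
    """Hypothetical instance record."""

    name: str
    meta_instance_id: Optional[str] = None


@dataclass
class Result:
    """Hypothetical analogue of PlacementResult: targets on success, error otherwise."""

    targets: Dict[str, Inst] = field(default_factory=dict)
    error: Optional[str] = None


def place_and_tag(
    current: Dict[str, Inst],
    place: Callable[[Dict[str, Inst]], Dict[str, Inst]],
    meta_id: str,
) -> Result:
    try:
        # place() is assumed to raise ValueError when placement is impossible.
        targets = dict(place(current))
    except ValueError as e:
        return Result(error=str(e))
    # Keys present in the target but absent from the current state are new.
    new_ids = set(targets) - set(current)
    if new_ids:
        new_id = next(iter(new_ids))
        # Frozen records are copied with the binding set, not mutated.
        targets[new_id] = replace(targets[new_id], meta_instance_id=meta_id)
    return Result(targets=targets)
```

Returning an error value keeps an impossible placement from propagating as an exception, so the caller can record the reason and move on.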
|
||||
@@ -4,7 +4,11 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
|
||||
-from exo.master.adapters.claude import collect_claude_response, generate_claude_stream
+from exo.master.adapters.claude import (
+    ClaudeMessagesResponse,
+    collect_claude_response,
+    generate_claude_stream,
+)
|
||||
from exo.shared.types.api import ToolCallItem
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId, ModelId
|
||||
@@ -17,6 +21,18 @@ async def _chunks_to_stream(
|
||||
yield chunk
|
||||
|
||||
|
||||
async def _collect_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ClaudeMessagesResponse:
|
||||
"""Helper to consume the async generator and parse the JSON response."""
|
||||
parts: list[str] = []
|
||||
async for part in collect_claude_response(command_id, model, chunk_stream):
|
||||
parts.append(part)
|
||||
return ClaudeMessagesResponse.model_validate_json("".join(parts))
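The helper above drains an async generator of string fragments and parses the joined text once. A minimal standalone sketch of the same pattern, assuming a hypothetical producer `fake_parts` and plain `json.loads` in place of the Pydantic model:

```python
import asyncio
import json
from collections.abc import AsyncGenerator


async def fake_parts() -> AsyncGenerator[str, None]:
    # Hypothetical producer: yields one JSON document in arbitrary fragments.
    for part in ('{"type": "mes', 'sage", "content": []}'):
        yield part


async def collect_json(stream: AsyncGenerator[str, None]) -> dict:
    # Drain the generator first, then parse the concatenated text once;
    # individual fragments are not valid JSON on their own.
    parts = [part async for part in stream]
    return json.loads("".join(parts))


result = asyncio.run(collect_json(fake_parts()))
```

Parsing only after the stream is exhausted is what lets the fragments split at arbitrary positions.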
|
||||
|
||||
|
||||
MODEL = ModelId("test-model")
|
||||
COMMAND_ID = CommandId("cmd_test123")
|
||||
|
||||
@@ -47,7 +63,7 @@ class TestCollectClaudeResponseToolUse:
|
||||
],
|
||||
),
|
||||
]
|
||||
-response = await collect_claude_response(
+response = await _collect_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
@@ -77,7 +93,7 @@ class TestCollectClaudeResponseToolUse:
|
||||
],
|
||||
),
|
||||
]
|
||||
-response = await collect_claude_response(
+response = await _collect_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
@@ -102,7 +118,7 @@ class TestCollectClaudeResponseToolUse:
|
||||
],
|
||||
),
|
||||
]
|
||||
-response = await collect_claude_response(
+response = await _collect_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
@@ -116,7 +132,7 @@ class TestCollectClaudeResponseToolUse:
|
||||
|
||||
async def test_no_content_produces_empty_text_block(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = []
|
||||
-response = await collect_claude_response(
+response = await _collect_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
assert len(response.content) == 1
|
||||
|
||||
778
src/exo/master/tests/test_meta_instance_edge_cases.py
Normal file
@@ -0,0 +1,778 @@
|
||||
"""Edge-case and regression tests for MetaInstance lifecycle, concurrent operations, and error handling."""
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.master.process_managers.instance_health import (
|
||||
MAX_INSTANCE_RETRIES,
|
||||
InstanceHealthReconciler,
|
||||
)
|
||||
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
|
||||
from exo.master.reconcile import (
|
||||
find_unsatisfied_meta_instances,
|
||||
instance_connections_healthy,
|
||||
instance_runners_failed,
|
||||
instance_satisfies_meta_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, MetaInstanceId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
InstanceRetrying,
|
||||
MetaInstanceCreated,
|
||||
MetaInstanceDeleted,
|
||||
MetaInstancePlacementFailed,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import NodeIdentity
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import LoadModel, TaskId, TaskStatus
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.instances import (
|
||||
InstanceId,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerFailed,
|
||||
RunnerId,
|
||||
RunnerReady,
|
||||
ShardAssignments,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
|
||||
# --- Helpers (copied from test_reconcile.py for independence) ---
|
||||
|
||||
|
||||
def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
|
||||
return ModelCard(
|
||||
model_id=ModelId(model_id),
|
||||
storage_size=Memory.from_kb(1000),
|
||||
n_layers=10,
|
||||
hidden_size=30,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
|
||||
|
||||
def _topology(*node_ids: str, connect: bool = True) -> Topology:
|
||||
t = Topology()
|
||||
nodes = [NodeId(n) for n in node_ids]
|
||||
for n in nodes:
|
||||
t.add_node(n)
|
||||
if connect and len(nodes) > 1:
|
||||
for i in range(len(nodes)):
|
||||
j = (i + 1) % len(nodes)
|
||||
t.add_connection(
|
||||
Connection(
|
||||
source=nodes[i],
|
||||
sink=nodes[j],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
t.add_connection(
|
||||
Connection(
|
||||
source=nodes[j],
|
||||
sink=nodes[i],
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
return t
|
||||
|
||||
|
||||
def _meta_instance(
|
||||
model_id: str = "test-org/test-model",
|
||||
*,
|
||||
min_nodes: int = 1,
|
||||
node_ids: list[NodeId] | None = None,
|
||||
meta_instance_id: MetaInstanceId | None = None,
|
||||
consecutive_failures: int = 0,
|
||||
last_failure_error: str | None = None,
|
||||
placement_error: str | None = None,
|
||||
) -> MetaInstance:
|
||||
return MetaInstance(
|
||||
meta_instance_id=meta_instance_id or MetaInstanceId(),
|
||||
model_id=ModelId(model_id),
|
||||
min_nodes=min_nodes,
|
||||
node_ids=node_ids,
|
||||
consecutive_failures=consecutive_failures,
|
||||
last_failure_error=last_failure_error,
|
||||
placement_error=placement_error,
|
||||
)
|
||||
|
||||
|
||||
def _instance(
|
||||
model_id: str = "test-org/test-model",
|
||||
node_ids: list[str] | None = None,
|
||||
instance_id: InstanceId | None = None,
|
||||
meta_instance_id: MetaInstanceId | None = None,
|
||||
) -> tuple[InstanceId, MlxRingInstance]:
|
||||
iid = instance_id or InstanceId()
|
||||
nodes = node_ids or ["node-a"]
|
||||
n = len(nodes)
|
||||
mc = _model_card(model_id)
|
||||
ephemeral_port = 50000
|
||||
node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
|
||||
runner_to_shard = {
|
||||
runner_id: PipelineShardMetadata(
|
||||
model_card=mc,
|
||||
device_rank=i,
|
||||
world_size=n,
|
||||
start_layer=0,
|
||||
end_layer=mc.n_layers,
|
||||
n_layers=mc.n_layers,
|
||||
)
|
||||
for i, runner_id in enumerate(node_to_runner.values())
|
||||
}
|
||||
hosts_by_node: dict[NodeId, list[Host]] = {}
|
||||
for r, node_str in enumerate(nodes):
|
||||
hosts: list[Host] = []
|
||||
for idx in range(n):
|
||||
if idx == r:
|
||||
hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
|
||||
elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
|
||||
hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
|
||||
else:
|
||||
hosts.append(Host(ip="198.51.100.1", port=0))
|
||||
hosts_by_node[NodeId(node_str)] = hosts
|
||||
return iid, MlxRingInstance(
|
||||
instance_id=iid,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId(model_id),
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
),
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
meta_instance_id=meta_instance_id,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 1. MetaInstance lifecycle edge cases
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_meta_instance_model_is_frozen():
|
||||
"""MetaInstance should be immutable (frozen model)."""
|
||||
meta = _meta_instance()
|
||||
# Assigning to a field of a frozen model must raise; pytest.raises fails
# the test if no exception is raised.
with pytest.raises(Exception):
    meta.model_id = ModelId("something-else")
|
||||
|
||||
|
||||
def test_meta_instance_created_then_deleted_roundtrip():
|
||||
"""Create and delete a MetaInstance through apply — state should be clean."""
|
||||
state = State()
|
||||
meta = _meta_instance()
|
||||
state = apply(
|
||||
state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta))
|
||||
)
|
||||
assert meta.meta_instance_id in state.meta_instances
|
||||
state = apply(
|
||||
state,
|
||||
IndexedEvent(
|
||||
idx=1, event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
|
||||
),
|
||||
)
|
||||
assert meta.meta_instance_id not in state.meta_instances
|
||||
assert len(state.meta_instances) == 0
|
||||
|
||||
|
||||
def test_delete_nonexistent_meta_instance_is_safe():
|
||||
"""Deleting a MetaInstance that doesn't exist should not crash."""
|
||||
state = State()
|
||||
event = MetaInstanceDeleted(meta_instance_id=MetaInstanceId("nonexistent"))
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert len(new_state.meta_instances) == 0
|
||||
|
||||
|
||||
def test_placement_failed_for_nonexistent_meta_instance_is_safe():
|
||||
"""MetaInstancePlacementFailed for unknown ID should not crash."""
|
||||
state = State()
|
||||
event = MetaInstancePlacementFailed(
|
||||
meta_instance_id=MetaInstanceId("nonexistent"),
|
||||
reason="test",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert len(new_state.meta_instances) == 0
|
||||
|
||||
|
||||
def test_multiple_meta_instances_for_same_model():
|
||||
"""Multiple MetaInstances for the same model are tracked independently."""
|
||||
state = State()
|
||||
meta_a = _meta_instance("test-org/model-x")
|
||||
meta_b = _meta_instance("test-org/model-x")
|
||||
state = apply(
|
||||
state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta_a))
|
||||
)
|
||||
state = apply(
|
||||
state, IndexedEvent(idx=1, event=MetaInstanceCreated(meta_instance=meta_b))
|
||||
)
|
||||
assert len(state.meta_instances) == 2
|
||||
assert meta_a.meta_instance_id in state.meta_instances
|
||||
assert meta_b.meta_instance_id in state.meta_instances
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 2. Retry logic edge cases
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_retry_counter_resets_on_successful_instance_creation():
|
||||
"""When a new instance is created for a meta-instance, failures should reset."""
|
||||
meta = _meta_instance(consecutive_failures=2, last_failure_error="old")
|
||||
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
|
||||
mi = state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
# last_failure_error is preserved (for UI display)
|
||||
assert mi.last_failure_error == "old"
|
||||
|
||||
|
||||
async def test_retry_count_increments_through_full_cycle():
|
||||
"""Walk through MAX_INSTANCE_RETRIES worth of retries, then verify delete."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
topology = _topology("node-a")
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
topology=topology,
|
||||
)
|
||||
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
for idx, i in enumerate(range(MAX_INSTANCE_RETRIES)):
|
||||
# Simulate runners failing
|
||||
state_with_runners = state.model_copy(
|
||||
update={"runners": {runner_ids[0]: RunnerFailed(error_message=f"fail-{i}")}}
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state_with_runners)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceRetrying), f"iteration {i}"
|
||||
state = apply(state, IndexedEvent(idx=idx, event=events[0]))
|
||||
|
||||
# After MAX_INSTANCE_RETRIES retries, failure counter should be at max
|
||||
mi = state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == MAX_INSTANCE_RETRIES
|
||||
|
||||
# Next failure should result in deletion
|
||||
state_with_runners = state.model_copy(
|
||||
update={"runners": {runner_ids[0]: RunnerFailed(error_message="final")}}
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state_with_runners)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_respects_exact_limit():
|
||||
"""At exactly MAX_INSTANCE_RETRIES, reconciler should delete, not retry."""
|
||||
meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES)
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_at_limit_minus_one_retries():
|
||||
"""At MAX_INSTANCE_RETRIES - 1, reconciler should still retry."""
|
||||
meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES - 1)
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceRetrying)
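The two boundary tests above pin down the retry rule: below the limit the reconciler retries, at the limit it deletes. A minimal sketch of that decision, assuming a hypothetical `next_action` helper and a placeholder limit (the real value of `MAX_INSTANCE_RETRIES` is not shown in this diff):

```python
MAX_RETRIES = 3  # placeholder; the real MAX_INSTANCE_RETRIES is not shown here


def next_action(consecutive_failures: int, max_retries: int = MAX_RETRIES) -> str:
    # Strictly below the limit: retry. At or above the limit: delete.
    return "retry" if consecutive_failures < max_retries else "delete"
```

The strict `<` comparison is what makes a count of exactly `max_retries` a deletion, which is the off-by-one these tests guard against.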
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 3. Error handling edge cases
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_runners_failed_with_empty_error_message():
|
||||
"""RunnerFailed with empty error_message should still report as failed."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
runners = {
|
||||
rid: RunnerFailed(error_message="")
|
||||
for rid in inst.shard_assignments.node_to_runner.values()
|
||||
}
|
||||
is_failed, error = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is True
|
||||
# Empty error message means we get the fallback
|
||||
assert error == "Runner failed"
|
||||
|
||||
|
||||
def test_runners_failed_with_none_error_message():
|
||||
"""RunnerFailed with None error_message should still report as failed."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
runners = {
|
||||
rid: RunnerFailed(error_message=None)
|
||||
for rid in inst.shard_assignments.node_to_runner.values()
|
||||
}
|
||||
is_failed, error = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is True
|
||||
assert error == "Runner failed"
|
||||
|
||||
|
||||
def test_runners_failed_collects_all_error_messages():
|
||||
"""With multiple failed runners, all error messages should be collected."""
|
||||
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
runners = {
|
||||
runner_ids[0]: RunnerFailed(error_message="OOM on GPU 0"),
|
||||
runner_ids[1]: RunnerFailed(error_message="OOM on GPU 1"),
|
||||
runner_ids[2]: RunnerFailed(error_message="OOM on GPU 2"),
|
||||
}
|
||||
is_failed, error = instance_runners_failed(inst, runners, {})
|
||||
assert is_failed is True
|
||||
assert error is not None
|
||||
assert "OOM on GPU 0" in error
|
||||
assert "OOM on GPU 1" in error
|
||||
assert "OOM on GPU 2" in error
|
||||
|
||||
|
||||
def test_runners_failed_includes_friendly_name():
|
||||
"""Error messages should include node friendly names when available."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
node_id = NodeId("node-a")
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
runners = {runner_ids[0]: RunnerFailed(error_message="OOM")}
|
||||
identities = {node_id: NodeIdentity(friendly_name="My Mac Studio")}
|
||||
is_failed, error = instance_runners_failed(inst, runners, identities)
|
||||
assert is_failed is True
|
||||
assert error is not None
|
||||
assert "My Mac Studio" in error
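The error-handling tests above describe how per-runner messages are aggregated: empty or missing messages fall back to `"Runner failed"`, multiple messages are all kept, and friendly names appear when known. A minimal sketch of that behaviour follows; `summarize_failures` is a hypothetical helper, and the `"<name>: <message>"` layout is an assumption made for the sketch, since the real formatting is not shown in this section.

```python
from typing import Dict, Optional


def summarize_failures(
    errors_by_node: Dict[str, Optional[str]],
    friendly_names: Optional[Dict[str, str]] = None,
) -> str:
    names = friendly_names or {}
    messages = []
    for node_id, message in errors_by_node.items():
        if not message:
            continue  # empty or None messages contribute nothing
        # Assumed layout: prefer the friendly name, fall back to the node id.
        messages.append(f"{names.get(node_id, node_id)}: {message}")
    # Join with "; " as in the source; use the fallback when nothing was collected.
    return "; ".join(messages) if messages else "Runner failed"
```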
|
||||
|
||||
|
||||
def test_instance_retrying_for_missing_instance_is_safe():
|
||||
"""InstanceRetrying for an instance not in state should not crash.
|
||||
|
||||
NOTE: When the instance is missing, the handler returns early WITHOUT
|
||||
incrementing the MetaInstance failure counter. This means stale retry
|
||||
events for already-deleted instances are silently dropped. This is
|
||||
acceptable since the InstanceDeleted handler already increments failures.
|
||||
"""
|
||||
meta = _meta_instance()
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = InstanceRetrying(
|
||||
instance_id=InstanceId("nonexistent"),
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
failure_error="crash",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
# Does not crash, but failure count is NOT incremented (early return)
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 4. Backward compatibility
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_instance_without_meta_instance_id_works():
|
||||
"""Instances created without meta_instance_id should still function normally."""
|
||||
_, inst = _instance(node_ids=["node-a"])
|
||||
assert inst.meta_instance_id is None
|
||||
topology = _topology("node-a")
|
||||
assert instance_connections_healthy(inst, topology) is True
|
||||
|
||||
|
||||
def test_instance_deleted_without_meta_does_not_affect_meta_instances():
|
||||
"""Deleting an instance without meta_instance_id should not affect meta_instances."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"]) # no meta_instance_id
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceDeleted(instance_id=iid, failure_error="crash")
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0 # unchanged
|
||||
|
||||
|
||||
def test_satisfies_ignores_meta_instance_id_binding():
|
||||
"""instance_satisfies_meta_instance checks constraints only, not binding."""
|
||||
meta = _meta_instance()
|
||||
_, inst = _instance(node_ids=["node-a"]) # no meta_instance_id set
|
||||
# Should match on constraints (model, min_nodes) regardless of binding
|
||||
assert instance_satisfies_meta_instance(meta, inst) is True
|
||||
|
||||
|
||||
def test_find_unsatisfied_uses_binding_not_constraints():
|
||||
"""find_unsatisfied checks meta_instance_id binding, not just constraint matching."""
|
||||
meta = _meta_instance()
|
||||
# Instance matches constraints but is NOT bound to this meta_instance
|
||||
iid, inst = _instance(node_ids=["node-a"])
|
||||
topology = _topology("node-a")
|
||||
result = find_unsatisfied_meta_instances(
|
||||
{meta.meta_instance_id: meta}, {iid: inst}, topology
|
||||
)
|
||||
# Should be unsatisfied because instance.meta_instance_id != meta.meta_instance_id
|
||||
assert list(result) == [meta]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 5. Concurrent / multi-instance scenarios
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def test_health_reconciler_handles_multiple_failing_instances():
|
||||
"""Multiple instances failing simultaneously should each get their own event."""
|
||||
meta_a = _meta_instance()
|
||||
meta_b = _meta_instance()
|
||||
iid_a, inst_a = _instance(
|
||||
node_ids=["node-a"], meta_instance_id=meta_a.meta_instance_id
|
||||
)
|
||||
iid_b, inst_b = _instance(
|
||||
node_ids=["node-b"], meta_instance_id=meta_b.meta_instance_id
|
||||
)
|
||||
runner_ids_a = list(inst_a.shard_assignments.node_to_runner.values())
|
||||
runner_ids_b = list(inst_b.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={
|
||||
meta_a.meta_instance_id: meta_a,
|
||||
meta_b.meta_instance_id: meta_b,
|
||||
},
|
||||
instances={iid_a: inst_a, iid_b: inst_b},
|
||||
runners={
|
||||
runner_ids_a[0]: RunnerFailed(error_message="OOM"),
|
||||
runner_ids_b[0]: RunnerFailed(error_message="OOM"),
|
||||
},
|
||||
topology=_topology("node-a", "node-b"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 2
|
||||
# Both should be InstanceRetrying since failures < MAX
|
||||
assert all(isinstance(e, InstanceRetrying) for e in events)
|
||||
instance_ids = {e.instance_id for e in events} # type: ignore[union-attr]
|
||||
assert instance_ids == {iid_a, iid_b}
|
||||
|
||||
|
||||
async def test_health_reconciler_mixed_healthy_and_failing():
|
||||
"""Only failing instances should produce events; healthy ones should not."""
|
||||
meta_healthy = _meta_instance()
|
||||
meta_failing = _meta_instance()
|
||||
iid_h, inst_h = _instance(
|
||||
node_ids=["node-a"], meta_instance_id=meta_healthy.meta_instance_id
|
||||
)
|
||||
iid_f, inst_f = _instance(
|
||||
node_ids=["node-b"], meta_instance_id=meta_failing.meta_instance_id
|
||||
)
|
||||
runner_ids_h = list(inst_h.shard_assignments.node_to_runner.values())
|
||||
runner_ids_f = list(inst_f.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={
|
||||
meta_healthy.meta_instance_id: meta_healthy,
|
||||
meta_failing.meta_instance_id: meta_failing,
|
||||
},
|
||||
instances={iid_h: inst_h, iid_f: inst_f},
|
||||
runners={
|
||||
runner_ids_h[0]: RunnerReady(),
|
||||
runner_ids_f[0]: RunnerFailed(error_message="crash"),
|
||||
},
|
||||
topology=_topology("node-a", "node-b"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceRetrying)
|
||||
assert events[0].instance_id == iid_f
|
||||
|
||||
|
||||
async def test_meta_instance_reconciler_empty_state():
|
||||
"""MetaInstanceReconciler with no meta_instances should produce no events."""
|
||||
state = State()
|
||||
reconciler = MetaInstanceReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 6. Placement error tracking
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_placement_failed_sets_error():
|
||||
"""MetaInstancePlacementFailed should set placement_error on the MetaInstance."""
|
||||
meta = _meta_instance()
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
reason="Not enough memory",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.placement_error == "Not enough memory"
|
||||
|
||||
|
||||
def test_instance_created_clears_placement_error():
|
||||
"""InstanceCreated should clear placement_error on the MetaInstance."""
|
||||
meta = _meta_instance(placement_error="Not enough memory")
|
||||
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
|
||||
mi = state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.placement_error is None
|
||||
|
||||
|
||||
def test_placement_error_does_not_increment_failures():
|
||||
"""Placement failures should only set placement_error, not increment consecutive_failures."""
|
||||
meta = _meta_instance()
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = MetaInstancePlacementFailed(
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
reason="No resources",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
assert mi.placement_error == "No resources"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 7. State serialization roundtrip
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_state_with_meta_instances_serializes():
|
||||
"""State with meta_instances should serialize and deserialize correctly."""
|
||||
meta = _meta_instance(consecutive_failures=2, last_failure_error="test")
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
json_str = state.model_dump_json()
|
||||
restored = State.model_validate_json(json_str)
|
||||
assert meta.meta_instance_id in restored.meta_instances
|
||||
mi = restored.meta_instances[meta.meta_instance_id]
|
||||
assert mi.model_id == meta.model_id
|
||||
assert mi.consecutive_failures == 2
|
||||
assert mi.last_failure_error == "test"
|
||||
assert iid in restored.instances
|
||||
assert restored.instances[iid].meta_instance_id == meta.meta_instance_id
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 8. MetaInstanceReconciler error handling
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def test_meta_instance_reconciler_model_load_error_emits_placement_failed(
|
||||
monkeypatch: "pytest.MonkeyPatch",
|
||||
):
|
||||
"""When ModelCard.load raises, reconciler emits MetaInstancePlacementFailed."""
|
||||
import exo.master.process_managers.meta_instance as mi_mod
|
||||
|
||||
meta = _meta_instance()
|
||||
topo = _topology("node-a")
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
topology=topo,
|
||||
)
|
||||
|
||||
async def _failing_load(_model_id: ModelId) -> ModelCard:
|
||||
raise RuntimeError("Network error")
|
||||
|
||||
monkeypatch.setattr(
|
||||
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
|
||||
)
|
||||
|
||||
reconciler = MetaInstanceReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
|
||||
placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
|
||||
assert len(placement_failed) == 1
|
||||
assert "Failed to load model card" in placement_failed[0].reason
|
||||
assert meta.meta_instance_id == placement_failed[0].meta_instance_id
|
||||
|
||||
|
||||
async def test_meta_instance_reconciler_model_load_error_skips_dedup(
|
||||
monkeypatch: "pytest.MonkeyPatch",
|
||||
):
|
||||
"""When ModelCard.load error matches existing placement_error, no duplicate event."""
|
||||
import exo.master.process_managers.meta_instance as mi_mod
|
||||
|
||||
meta = _meta_instance(placement_error="Failed to load model card: Network error")
|
||||
topo = _topology("node-a")
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
topology=topo,
|
||||
)
|
||||
|
||||
async def _failing_load(_model_id: ModelId) -> ModelCard:
|
||||
raise RuntimeError("Network error")
|
||||
|
||||
monkeypatch.setattr(
|
||||
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
|
||||
)
|
||||
|
||||
reconciler = MetaInstanceReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
|
||||
# Error matches existing placement_error, so no duplicate event emitted
|
||||
assert len(events) == 0
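The dedup behaviour tested above (no event when the new reason equals the stored `placement_error`) can be sketched as a small pure function. `placement_failure_event` and the dict-shaped event are hypothetical and used only for illustration:

```python
from typing import Optional


def placement_failure_event(
    current_error: Optional[str], new_reason: str
) -> Optional[dict]:
    # Emit nothing when the stored error already matches, so a reconcile
    # loop that runs repeatedly does not append identical events.
    if current_error == new_reason:
        return None
    return {"type": "MetaInstancePlacementFailed", "reason": new_reason}
```

Because the comparison is on the full reason string, any change in the underlying error text produces a fresh event.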
|
||||
|
||||
|
||||
async def test_meta_instance_reconciler_continues_after_error(
|
||||
monkeypatch: "pytest.MonkeyPatch",
|
||||
):
|
||||
"""Reconciler should continue to next meta-instance after one fails to load."""
|
||||
import exo.master.process_managers.meta_instance as mi_mod
|
||||
|
||||
meta_a = _meta_instance(model_id="org/model-a")
|
||||
meta_b = _meta_instance(model_id="org/model-b")
|
||||
topo = _topology("node-a")
|
||||
state = State(
|
||||
meta_instances={
|
||||
meta_a.meta_instance_id: meta_a,
|
||||
meta_b.meta_instance_id: meta_b,
|
||||
},
|
||||
topology=topo,
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def _load_always_fails(model_id: ModelId) -> ModelCard:
    # Every load fails; call_count records that each meta-instance was tried.
    nonlocal call_count
    call_count += 1
    raise RuntimeError(f"Cannot load {model_id}")

monkeypatch.setattr(
    mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_load_always_fails)})
|
||||
)
|
||||
|
||||
reconciler = MetaInstanceReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
|
||||
# Both meta-instances should have been attempted (not short-circuited)
|
||||
assert call_count == 2
|
||||
# Both should have placement failed events
|
||||
placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
|
||||
assert len(placement_failed) == 2
|
||||
|
||||
|
||||
# =============================================================================
# 8. Cascade delete with task cancellation
# =============================================================================


def test_cascade_delete_cancels_active_tasks():
    """Deleting a MetaInstance should cancel tasks on backing instances.

    Regression test: previously, cascade-deleting backing instances via
    DeleteMetaInstance did not emit TaskStatusUpdated(Cancelled) for active
    tasks, leaving orphaned task references in state.
    """
    meta = _meta_instance()
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    task_id = TaskId()
    task = LoadModel(task_id=task_id, instance_id=iid, task_status=TaskStatus.Running)

    # Build state with meta-instance, backing instance, and active task
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
        tasks={task_id: task},
        topology=_topology("node-a"),
    )

    # Simulate the cascade-delete event sequence produced by main.py:
    # 1. MetaInstanceDeleted
    # 2. TaskStatusUpdated(Cancelled) for active tasks
    # 3. InstanceDeleted
    idx = 0
    state = apply(
        state,
        IndexedEvent(
            idx=idx,
            event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id),
        ),
    )
    idx += 1
    state = apply(
        state,
        IndexedEvent(
            idx=idx,
            event=TaskStatusUpdated(task_id=task_id, task_status=TaskStatus.Cancelled),
        ),
    )
    idx += 1
    state = apply(
        state,
        IndexedEvent(idx=idx, event=InstanceDeleted(instance_id=iid)),
    )

    # Verify everything is cleaned up
    assert len(state.meta_instances) == 0
    assert len(state.instances) == 0
    assert state.tasks[task_id].task_status == TaskStatus.Cancelled


def test_cascade_delete_skips_completed_tasks():
    """Cascade delete should only cancel Pending/Running tasks, not completed ones."""
    meta = _meta_instance()
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)

    running_task_id = TaskId()
    completed_task_id = TaskId()
    running_task = LoadModel(
        task_id=running_task_id, instance_id=iid, task_status=TaskStatus.Running
    )
    completed_task = LoadModel(
        task_id=completed_task_id, instance_id=iid, task_status=TaskStatus.Complete
    )

    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
        tasks={running_task_id: running_task, completed_task_id: completed_task},
        topology=_topology("node-a"),
    )

    # Only the running task should be cancelled — we verify the logic pattern
    # by checking which tasks are Pending or Running
    active_tasks = [
        t
        for t in state.tasks.values()
        if t.instance_id == iid
        and t.task_status in (TaskStatus.Pending, TaskStatus.Running)
    ]
    assert len(active_tasks) == 1
    assert active_tasks[0].task_id == running_task_id
@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
     target_instances = {instance_id: instance}

     # act
-    events = get_transition_events(current_instances, target_instances)
+    events = get_transition_events(current_instances, target_instances, {})

     # assert
     assert len(events) == 0
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
     target_instances: dict[InstanceId, Instance] = {instance_id: instance}

     # act
-    events = get_transition_events(current_instances, target_instances)
+    events = get_transition_events(current_instances, target_instances, {})

     # assert
     assert len(events) == 1
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
     target_instances: dict[InstanceId, Instance] = {}

     # act
-    events = get_transition_events(current_instances, target_instances)
+    events = get_transition_events(current_instances, target_instances, {})

     # assert
     assert len(events) == 1
@@ -3,10 +3,10 @@ import pytest
 from exo.master.placement_utils import (
     allocate_layers_proportionally,
     filter_cycles_by_memory,
+    get_largest_cycles,
     get_mlx_jaccl_coordinators,
     get_shard_assignments,
     get_shard_assignments_for_pipeline_parallel,
-    get_smallest_cycles,
 )
 from exo.master.tests.conftest import (
     create_node_memory,
@@ -143,7 +143,7 @@ def test_filter_multiple_cycles_by_memory():
     }


-def test_get_smallest_cycles():
+def test_get_largest_cycles():
     # arrange
     node_a_id = NodeId()
     node_b_id = NodeId()
@@ -175,12 +175,12 @@ def test_get_smallest_cycles():
     cycles = [c for c in topology.get_cycles() if len(c) != 1]  # ignore singletons

     # act
-    smallest_cycles = get_smallest_cycles(cycles)
+    largest_cycles = get_largest_cycles(cycles)

     # assert
-    assert len(smallest_cycles) == 1
-    assert len(smallest_cycles[0]) == 2
-    assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}
+    assert len(largest_cycles) == 1
+    assert len(largest_cycles[0]) == 3
+    assert set(n for n in largest_cycles[0]) == {node_a_id, node_b_id, node_c_id}


 @pytest.mark.parametrize(
742
src/exo/master/tests/test_reconcile.py
Normal file
@@ -0,0 +1,742 @@
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.reconcile import (
    find_unsatisfied_meta_instances,
    instance_connections_healthy,
    instance_runners_failed,
    instance_satisfies_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, MetaInstanceId, NodeId
from exo.shared.types.events import (
    IndexedEvent,
    InstanceCreated,
    InstanceDeleted,
    InstanceRetrying,
    MetaInstanceCreated,
    MetaInstanceDeleted,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
    InstanceId,
    MlxRingInstance,
)
from exo.shared.types.worker.runners import (
    RunnerFailed,
    RunnerId,
    RunnerLoading,
    RunnerReady,
    RunnerShutdown,
    ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata


def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
    return ModelCard(
        model_id=ModelId(model_id),
        storage_size=Memory.from_kb(1000),
        n_layers=10,
        hidden_size=30,
        supports_tensor=True,
        tasks=[ModelTask.TextGeneration],
    )


def _topology(*node_ids: str, connect: bool = True) -> Topology:
    """Build a topology with nodes connected in a bidirectional ring with unique IPs.

    Node at index ``i`` gets IP ``10.0.0.{i+1}``. Edges go in both directions
    between consecutive nodes (including wrap-around).
    """
    t = Topology()
    nodes = [NodeId(n) for n in node_ids]
    for n in nodes:
        t.add_node(n)
    if connect and len(nodes) > 1:
        for i in range(len(nodes)):
            j = (i + 1) % len(nodes)
            t.add_connection(
                Connection(
                    source=nodes[i],
                    sink=nodes[j],
                    edge=SocketConnection(
                        sink_multiaddr=Multiaddr(
                            address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
                        )
                    ),
                )
            )
            t.add_connection(
                Connection(
                    source=nodes[j],
                    sink=nodes[i],
                    edge=SocketConnection(
                        sink_multiaddr=Multiaddr(
                            address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
                        )
                    ),
                )
            )
    return t


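The ring convention described in the `_topology()` docstring can be sketched without any exo types. The helper below is hypothetical (`ring_edges` is not part of exo). It only lists the directed edges and sink IPs that the test helper would create, assuming the node at index `i` is reachable at `10.0.0.{i+1}`.

```python
def ring_edges(node_ids: list[str]) -> list[tuple[str, str, str]]:
    """Return (source, sink, sink_ip) triples for a bidirectional ring.

    Hypothetical illustration only: mirrors the IP convention of the
    _topology() test helper, where node index i maps to 10.0.0.{i+1}.
    """
    n = len(node_ids)
    edges: list[tuple[str, str, str]] = []
    if n < 2:
        # A single node (or none) has no ring edges.
        return edges
    for i in range(n):
        j = (i + 1) % n  # next node, wrapping around at the end
        edges.append((node_ids[i], node_ids[j], f"10.0.0.{j + 1}"))
        edges.append((node_ids[j], node_ids[i], f"10.0.0.{i + 1}"))
    return edges
```

For two nodes the loop visits the same pair twice, so the list contains each directed edge twice; whether exo's `Topology` deduplicates such edges is not shown in this diff.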
def _meta_instance(
    model_id: str = "test-org/test-model",
    *,
    min_nodes: int = 1,
    node_ids: list[NodeId] | None = None,
    meta_instance_id: MetaInstanceId | None = None,
) -> MetaInstance:
    return MetaInstance(
        meta_instance_id=meta_instance_id or MetaInstanceId(),
        model_id=ModelId(model_id),
        min_nodes=min_nodes,
        node_ids=node_ids,
    )


def _instance(
    model_id: str = "test-org/test-model",
    node_ids: list[str] | None = None,
    instance_id: InstanceId | None = None,
    meta_instance_id: MetaInstanceId | None = None,
) -> tuple[InstanceId, MlxRingInstance]:
    """Create a test instance with hosts_by_node matching ``_topology()`` IPs."""
    iid = instance_id or InstanceId()
    nodes = node_ids or ["node-a"]
    n = len(nodes)
    mc = _model_card(model_id)
    ephemeral_port = 50000
    node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
    runner_to_shard = {
        runner_id: PipelineShardMetadata(
            model_card=mc,
            device_rank=i,
            world_size=n,
            start_layer=0,
            end_layer=mc.n_layers,
            n_layers=mc.n_layers,
        )
        for i, runner_id in enumerate(node_to_runner.values())
    }
    # Build hosts_by_node with IPs matching _topology() convention:
    # node at index idx has IP 10.0.0.{idx+1}
    hosts_by_node: dict[NodeId, list[Host]] = {}
    for r, node_str in enumerate(nodes):
        hosts: list[Host] = []
        for idx in range(n):
            if idx == r:
                hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
            elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
                hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
            else:
                hosts.append(Host(ip="198.51.100.1", port=0))
        hosts_by_node[NodeId(node_str)] = hosts
    return iid, MlxRingInstance(
        instance_id=iid,
        shard_assignments=ShardAssignments(
            model_id=ModelId(model_id),
            runner_to_shard=runner_to_shard,
            node_to_runner=node_to_runner,
        ),
        hosts_by_node=hosts_by_node,
        ephemeral_port=ephemeral_port,
        meta_instance_id=meta_instance_id,
    )


# --- instance_satisfies_meta_instance (pure constraint matching) ---


def test_satisfies_matching_model():
    meta = _meta_instance()
    _, inst = _instance(node_ids=["node-a"])
    assert instance_satisfies_meta_instance(meta, inst) is True


def test_not_satisfies_wrong_model():
    meta = _meta_instance("test-org/model-a")
    _, inst = _instance("test-org/model-b")
    assert instance_satisfies_meta_instance(meta, inst) is False


def test_not_satisfies_missing_required_node():
    meta = _meta_instance(node_ids=[NodeId("node-c")])
    _, inst = _instance(node_ids=["node-a", "node-b"])
    assert instance_satisfies_meta_instance(meta, inst) is False


def test_not_satisfies_fewer_than_min_nodes():
    meta = _meta_instance(min_nodes=3)
    _, inst = _instance(node_ids=["node-a", "node-b"])
    assert instance_satisfies_meta_instance(meta, inst) is False


def test_satisfies_with_node_ids_specified():
    meta = _meta_instance(node_ids=[NodeId("node-a"), NodeId("node-b")], min_nodes=2)
    _, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
    assert instance_satisfies_meta_instance(meta, inst) is True


# --- instance_connections_healthy ---


def test_healthy_single_node_present():
    _, inst = _instance(node_ids=["node-a"])
    topology = _topology("node-a")
    assert instance_connections_healthy(inst, topology) is True


def test_unhealthy_single_node_missing():
    _, inst = _instance(node_ids=["node-a"])
    topology = Topology()  # empty
    assert instance_connections_healthy(inst, topology) is False


def test_healthy_two_node_ring():
    _, inst = _instance(node_ids=["node-a", "node-b"])
    topology = _topology("node-a", "node-b")
    assert instance_connections_healthy(inst, topology) is True


def test_unhealthy_two_node_edge_removed():
    """Nodes present but edge removed — ring broken."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    topology = _topology("node-a", "node-b", connect=False)
    assert instance_connections_healthy(inst, topology) is False


def test_unhealthy_two_node_ip_changed():
    """Edge exists but with a different IP than instance was configured with."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    # Build topology with different IPs than _instance() expects
    topology = Topology()
    topology.add_node(NodeId("node-a"))
    topology.add_node(NodeId("node-b"))
    topology.add_connection(
        Connection(
            source=NodeId("node-a"),
            sink=NodeId("node-b"),
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/192.168.99.99/tcp/50000")
            ),
        )
    )
    topology.add_connection(
        Connection(
            source=NodeId("node-b"),
            sink=NodeId("node-a"),
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/192.168.99.98/tcp/50000")
            ),
        )
    )
    assert instance_connections_healthy(inst, topology) is False


def test_healthy_three_node_ring():
    _, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
    topology = _topology("node-a", "node-b", "node-c")
    assert instance_connections_healthy(inst, topology) is True


def test_unhealthy_three_node_one_edge_removed():
    """Remove one edge from a three-node ring — instance unhealthy."""
    _, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
    # Build topology with one direction of one edge missing
    topology = Topology()
    nodes = [NodeId("node-a"), NodeId("node-b"), NodeId("node-c")]
    for n in nodes:
        topology.add_node(n)
    # Add all edges except node-a → node-b
    topology.add_connection(
        Connection(
            source=nodes[1],
            sink=nodes[0],
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
            ),
        )
    )
    topology.add_connection(
        Connection(
            source=nodes[1],
            sink=nodes[2],
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
            ),
        )
    )
    topology.add_connection(
        Connection(
            source=nodes[2],
            sink=nodes[1],
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/50000")
            ),
        )
    )
    topology.add_connection(
        Connection(
            source=nodes[2],
            sink=nodes[0],
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
            ),
        )
    )
    topology.add_connection(
        Connection(
            source=nodes[0],
            sink=nodes[2],
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
            ),
        )
    )
    # Missing: node-a → node-b (ip 10.0.0.2)
    assert instance_connections_healthy(inst, topology) is False


def test_unhealthy_node_missing_from_topology():
    """Instance has a node that's not in the topology at all."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    topology = _topology("node-a")  # node-b not present
    assert instance_connections_healthy(inst, topology) is False


def test_healthy_extra_nodes_in_topology():
    """Extra nodes in topology don't affect instance health."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    topology = _topology("node-a", "node-b", "node-c")
    assert instance_connections_healthy(inst, topology) is True


# --- find_unsatisfied_meta_instances ---


def test_unsatisfied_no_meta_instances():
    result = find_unsatisfied_meta_instances({}, {}, Topology())
    assert list(result) == []


def test_unsatisfied_one_satisfied():
    meta = _meta_instance()
    id_a, inst_a = _instance(meta_instance_id=meta.meta_instance_id)
    topology = _topology("node-a")
    result = find_unsatisfied_meta_instances(
        {meta.meta_instance_id: meta},
        {id_a: inst_a},
        topology,
    )
    assert list(result) == []


def test_unsatisfied_one_not_satisfied():
    meta = _meta_instance("test-org/model-x")
    id_a, inst_a = _instance("test-org/model-y")
    topology = _topology("node-a")
    result = find_unsatisfied_meta_instances(
        {meta.meta_instance_id: meta}, {id_a: inst_a}, topology
    )
    assert list(result) == [meta]


def test_unsatisfied_mix():
    meta_satisfied = _meta_instance("test-org/model-a")
    meta_unsatisfied = _meta_instance("test-org/model-b")
    id_a, inst_a = _instance(
        "test-org/model-a", meta_instance_id=meta_satisfied.meta_instance_id
    )
    topology = _topology("node-a")
    result = find_unsatisfied_meta_instances(
        {
            meta_satisfied.meta_instance_id: meta_satisfied,
            meta_unsatisfied.meta_instance_id: meta_unsatisfied,
        },
        {id_a: inst_a},
        topology,
    )
    assert list(result) == [meta_unsatisfied]


def test_unsatisfied_node_disconnect():
    meta = _meta_instance()
    id_a, inst_a = _instance(
        node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
    )
    topology = _topology("node-a")  # node-b disconnected
    result = find_unsatisfied_meta_instances(
        {meta.meta_instance_id: meta},
        {id_a: inst_a},
        topology,
    )
    assert list(result) == [meta]


def test_unsatisfied_edge_break():
    """Instance exists but its connections broke — meta-instance becomes unsatisfied."""
    meta = _meta_instance()
    id_a, inst_a = _instance(
        node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
    )
    topology = _topology("node-a", "node-b", connect=False)  # nodes present, no edges
    result = find_unsatisfied_meta_instances(
        {meta.meta_instance_id: meta},
        {id_a: inst_a},
        topology,
    )
    assert list(result) == [meta]


def test_unsatisfied_idempotent():
    meta = _meta_instance("test-org/model-x")
    topology = _topology("node-a")
    meta_instances = {meta.meta_instance_id: meta}
    instances: dict[InstanceId, MlxRingInstance] = {}
    result_1 = list(
        find_unsatisfied_meta_instances(meta_instances, instances, topology)
    )
    result_2 = list(
        find_unsatisfied_meta_instances(meta_instances, instances, topology)
    )
    assert result_1 == result_2


def test_unsatisfied_exclusive_binding():
    """Two MetaInstances for the same model: one is bound via meta_instance_id, the other is unsatisfied."""
    meta_a = _meta_instance("test-org/model-x")
    meta_b = _meta_instance("test-org/model-x")
    id_inst, inst = _instance(
        "test-org/model-x", meta_instance_id=meta_a.meta_instance_id
    )
    topology = _topology("node-a")
    result = find_unsatisfied_meta_instances(
        {
            meta_a.meta_instance_id: meta_a,
            meta_b.meta_instance_id: meta_b,
        },
        {id_inst: inst},
        topology,
    )
    assert list(result) == [meta_b]


# --- apply handlers ---


def test_apply_meta_instance_created():
    state = State()
    meta = _meta_instance()
    event = MetaInstanceCreated(meta_instance=meta)
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    assert meta.meta_instance_id in new_state.meta_instances
    assert new_state.meta_instances[meta.meta_instance_id] == meta


def test_apply_meta_instance_deleted():
    meta = _meta_instance()
    state = State(meta_instances={meta.meta_instance_id: meta})
    event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    assert meta.meta_instance_id not in new_state.meta_instances


def test_apply_meta_instance_deleted_clears_failure_info():
    meta = _meta_instance().model_copy(
        update={"consecutive_failures": 2, "last_failure_error": "OOM"}
    )
    state = State(meta_instances={meta.meta_instance_id: meta})
    event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    assert meta.meta_instance_id not in new_state.meta_instances


# --- instance_runners_failed ---


def test_runners_failed_all_failed():
    """All runners in RunnerFailed -> instance is failed."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    runners = {
        rid: RunnerFailed(error_message="OOM")
        for rid in inst.shard_assignments.node_to_runner.values()
    }
    is_failed, error = instance_runners_failed(inst, runners, {})
    assert is_failed is True
    assert error is not None
    assert "OOM" in error


def test_runners_failed_mixed_failed_shutdown():
    """One Failed + one Shutdown = failed."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    runners = {
        runner_ids[0]: RunnerFailed(error_message="crash"),
        runner_ids[1]: RunnerShutdown(),
    }
    is_failed, error = instance_runners_failed(inst, runners, {})
    assert is_failed is True
    assert error is not None
    assert "crash" in error


def test_runners_not_failed_all_shutdown():
    """All Shutdown (graceful) = not a failure."""
    _, inst = _instance(node_ids=["node-a"])
    runners = {
        rid: RunnerShutdown() for rid in inst.shard_assignments.node_to_runner.values()
    }
    is_failed, _ = instance_runners_failed(inst, runners, {})
    assert is_failed is False


def test_runners_not_failed_still_active():
    """Some runners still active = not failed yet."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    runners = {
        runner_ids[0]: RunnerFailed(error_message="OOM"),
        runner_ids[1]: RunnerLoading(),
    }
    is_failed, _ = instance_runners_failed(inst, runners, {})
    assert is_failed is False


def test_runners_not_failed_no_status():
    """Runner not yet reported = not failed."""
    _, inst = _instance(node_ids=["node-a"])
    is_failed, _ = instance_runners_failed(inst, {}, {})
    assert is_failed is False


def test_runners_not_failed_healthy():
    """Runners in Ready state = not failed."""
    _, inst = _instance(node_ids=["node-a"])
    runners = {
        rid: RunnerReady() for rid in inst.shard_assignments.node_to_runner.values()
    }
    is_failed, _ = instance_runners_failed(inst, runners, {})
    assert is_failed is False


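The six `instance_runners_failed` tests are consistent with one simple rule: an instance counts as failed only when every expected runner has reported a terminal state and at least one of those states is a failure. The sketch below is an assumption inferred from the tests, not the implementation in `exo.master.reconcile`. `RunnerState` and `runners_failed` are hypothetical stand-ins, and the real function also returns an error message and takes a third argument.

```python
from enum import Enum


class RunnerState(Enum):
    """Hypothetical stand-in for exo's runner status classes."""

    LOADING = "loading"
    READY = "ready"
    FAILED = "failed"
    SHUTDOWN = "shutdown"


# States from which a runner will not recover on its own.
TERMINAL = {RunnerState.FAILED, RunnerState.SHUTDOWN}


def runners_failed(expected: list[str], reported: dict[str, RunnerState]) -> bool:
    """Return True if the instance should be treated as failed."""
    statuses = [reported.get(runner_id) for runner_id in expected]
    if any(status is None for status in statuses):
        return False  # some runner has not reported yet
    if any(status not in TERMINAL for status in statuses):
        return False  # some runner is still loading or ready
    # All runners are terminal; a graceful all-Shutdown is not a failure.
    return any(status is RunnerState.FAILED for status in statuses)
```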
# --- failure tracking in apply_instance_deleted ---


def test_apply_instance_deleted_tracks_failure():
    """InstanceDeleted with failure_error increments meta instance failure count."""
    meta = _meta_instance()
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
    )
    event = InstanceDeleted(instance_id=iid, failure_error="Runner OOM")
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 1
    assert mi.last_failure_error == "Runner OOM"


def test_apply_instance_deleted_increments_failure():
    """Subsequent failures increment the counter."""
    meta = _meta_instance().model_copy(
        update={"consecutive_failures": 2, "last_failure_error": "previous error"}
    )
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
    )
    event = InstanceDeleted(instance_id=iid, failure_error="new error")
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 3
    assert mi.last_failure_error == "new error"


def test_apply_instance_deleted_no_failure_no_tracking():
    """InstanceDeleted without failure_error does not track."""
    meta = _meta_instance()
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
    )
    event = InstanceDeleted(instance_id=iid)
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 0


def test_apply_instance_deleted_orphan_no_tracking():
    """InstanceDeleted for orphan instance (no meta_instance_id) does not track."""
    iid, inst = _instance(node_ids=["node-a"])
    state = State(instances={iid: inst})
    event = InstanceDeleted(instance_id=iid, failure_error="crash")
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    assert len(new_state.meta_instances) == 0


# --- InstanceRetrying ---


def test_apply_instance_retrying_removes_runners():
    """InstanceRetrying removes the instance's runners from state but keeps the instance."""
    meta = _meta_instance()
    iid, inst = _instance(
        node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
    )
    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    runners = {
        runner_ids[0]: RunnerFailed(error_message="OOM"),
        runner_ids[1]: RunnerShutdown(),
    }
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
        runners=runners,
    )
    event = InstanceRetrying(
        instance_id=iid,
        meta_instance_id=meta.meta_instance_id,
        failure_error="OOM",
    )
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    # Instance still exists
    assert iid in new_state.instances
    # Runners removed
    assert runner_ids[0] not in new_state.runners
    assert runner_ids[1] not in new_state.runners


def test_apply_instance_retrying_increments_failure():
    """InstanceRetrying increments consecutive_failures on the MetaInstance."""
    meta = _meta_instance()
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
    )
    event = InstanceRetrying(
        instance_id=iid,
        meta_instance_id=meta.meta_instance_id,
        failure_error="crash",
    )
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 1
    assert mi.last_failure_error == "crash"


def test_apply_instance_retrying_skips_missing_runners():
    """InstanceRetrying doesn't assert if runners haven't reported yet."""
    meta = _meta_instance()
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    # No runners in state at all
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
    )
    event = InstanceRetrying(
        instance_id=iid,
        meta_instance_id=meta.meta_instance_id,
        failure_error="crash",
    )
    # Should not raise
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    assert iid in new_state.instances


def test_apply_instance_created_resets_failure_counter():
    """InstanceCreated resets consecutive_failures but preserves last_failure_error."""
    meta = _meta_instance().model_copy(
        update={"consecutive_failures": 3, "last_failure_error": "old error"}
    )
    _, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    state = State(meta_instances={meta.meta_instance_id: meta})
    event = InstanceCreated(instance=inst)
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 0
    assert mi.last_failure_error == "old error"
    assert mi.placement_error is None


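Taken together, the failure-tracking tests describe a small state machine on the MetaInstance: a failure (either `InstanceRetrying` or `InstanceDeleted` carrying a `failure_error`) increments the counter and records the error, while `InstanceCreated` resets the counter, keeps the last error, and clears `placement_error`. The sketch below models only that bookkeeping. `MetaStatus`, `on_failure`, and `on_instance_created` are hypothetical names, and the real handlers in `exo.shared.apply` operate on the full `State`.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class MetaStatus:
    """Hypothetical subset of the MetaInstance fields used by the tests."""

    consecutive_failures: int = 0
    last_failure_error: Optional[str] = None
    placement_error: Optional[str] = None


def on_failure(meta: MetaStatus, error: str) -> MetaStatus:
    """Record one more consecutive failure and remember its error."""
    return replace(
        meta,
        consecutive_failures=meta.consecutive_failures + 1,
        last_failure_error=error,
    )


def on_instance_created(meta: MetaStatus) -> MetaStatus:
    """Reset the counter and placement error; keep the last failure error."""
    return replace(meta, consecutive_failures=0, placement_error=None)
```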
# --- InstanceHealthReconciler retry-vs-delete ---
|
||||
|
||||
|
||||
async def test_health_reconciler_retries_when_under_limit():
|
||||
"""InstanceHealthReconciler emits InstanceRetrying when consecutive_failures < 3."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceRetrying)
|
||||
assert events[0].instance_id == iid
|
||||
assert events[0].meta_instance_id == meta.meta_instance_id
|
||||
|
||||
|
||||
async def test_health_reconciler_deletes_when_limit_reached():
|
||||
"""InstanceHealthReconciler emits InstanceDeleted when consecutive_failures >= 3."""
|
||||
meta = _meta_instance().model_copy(update={"consecutive_failures": 3})
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_deletes_without_meta_instance():
|
||||
"""Instances without a MetaInstance are deleted immediately on runner failure."""
|
||||
iid, inst = _instance(node_ids=["node-a"])
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="crash")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_network_failure_always_deletes():
|
||||
"""Network failure always triggers InstanceDeleted regardless of retry count."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(
|
||||
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
|
||||
)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
topology=_topology("node-a"), # node-b missing
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
assert events[0].failure_error == "Network connection lost"
|
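The four reconciler tests above pin down a small decision table. A minimal sketch of that table, using only the standard library, might look like the following. The names `decide`, `Retry`, `Delete`, and `MAX_CONSECUTIVE_FAILURES` are hypothetical stand-ins and not the project's `InstanceHealthReconciler` API; the limit of 3 is taken from the test docstrings, and passing the runner error through on deletion is an assumption.

```python
from dataclasses import dataclass
from typing import Optional, Union

MAX_CONSECUTIVE_FAILURES = 3  # assumption: limit implied by the test docstrings


@dataclass(frozen=True)
class Retry:
    """Stand-in for InstanceRetrying (hypothetical, simplified)."""
    instance_id: str
    failure_error: str


@dataclass(frozen=True)
class Delete:
    """Stand-in for InstanceDeleted (hypothetical, simplified)."""
    instance_id: str
    failure_error: str


def decide(
    instance_id: str,
    consecutive_failures: Optional[int],  # None: no MetaInstance owns the instance
    runner_error: Optional[str],          # None: no runner has failed
    all_nodes_connected: bool,
) -> Union[Retry, Delete, None]:
    """Return the event a health check would emit, or None when healthy."""
    # A node missing from the topology deletes the instance regardless of retry count.
    if not all_nodes_connected:
        return Delete(instance_id, "Network connection lost")
    if runner_error is None:
        return None
    # Retry only when a MetaInstance exists and is still under the limit.
    if consecutive_failures is not None and consecutive_failures < MAX_CONSECUTIVE_FAILURES:
        return Retry(instance_id, runner_error)
    return Delete(instance_id, runner_error)
```

Each branch corresponds to one of the tests: under the limit, at the limit, no MetaInstance, and a missing node.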
||||
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.common import MetaInstanceId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -12,6 +12,12 @@ from exo.shared.types.events import (
|
||||
InputChunkReceived,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
InstanceRetrying,
|
||||
JacclSideChannelData,
|
||||
JacclSideChannelGathered,
|
||||
MetaInstanceCreated,
|
||||
MetaInstanceDeleted,
|
||||
MetaInstancePlacementFailed,
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
@@ -28,6 +34,7 @@ from exo.shared.types.events import (
|
||||
TracesCollected,
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.profiling import (
|
||||
NodeIdentity,
|
||||
NodeNetworkInfo,
|
||||
@@ -66,12 +73,22 @@ def event_apply(event: Event, state: State) -> State:
|
||||
| InputChunkReceived()
|
||||
| TracesCollected()
|
||||
| TracesMerged()
|
||||
| JacclSideChannelData()
|
||||
| JacclSideChannelGathered()
|
||||
): # Pass-through events that don't modify state
|
||||
return state
|
||||
case InstanceCreated():
|
||||
return apply_instance_created(event, state)
|
||||
case InstanceDeleted():
|
||||
return apply_instance_deleted(event, state)
|
||||
case InstanceRetrying():
|
||||
return apply_instance_retrying(event, state)
|
||||
case MetaInstanceCreated():
|
||||
return apply_meta_instance_created(event, state)
|
||||
case MetaInstanceDeleted():
|
||||
return apply_meta_instance_deleted(event, state)
|
||||
case MetaInstancePlacementFailed():
|
||||
return apply_meta_instance_placement_failed(event, state)
|
||||
case NodeTimedOut():
|
||||
return apply_node_timed_out(event, state)
|
||||
case NodeDownloadProgress():
|
||||
@@ -174,20 +191,123 @@ def apply_task_failed(event: TaskFailed, state: State) -> State:
|
||||
return state.model_copy(update={"tasks": new_tasks})
|
||||
|
||||
|
||||
def _update_meta_instance(
|
||||
state: State, mid: MetaInstanceId, **fields: object
|
||||
) -> Mapping[MetaInstanceId, MetaInstance]:
|
||||
mi = state.meta_instances[mid]
|
||||
return {**state.meta_instances, mid: mi.model_copy(update=fields)}
|
||||
|
||||
|
||||
def apply_instance_created(event: InstanceCreated, state: State) -> State:
|
||||
instance = event.instance
|
||||
new_instances: Mapping[InstanceId, Instance] = {
|
||||
**state.instances,
|
||||
instance.instance_id: instance,
|
||||
}
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
update: dict[str, object] = {"instances": new_instances}
|
||||
# Reset failure tracking when a new instance is created for a meta-instance
|
||||
if instance.meta_instance_id and instance.meta_instance_id in state.meta_instances:
|
||||
mi = state.meta_instances[instance.meta_instance_id]
|
||||
if mi.placement_error is not None or mi.consecutive_failures > 0:
|
||||
update["meta_instances"] = _update_meta_instance(
|
||||
state,
|
||||
instance.meta_instance_id,
|
||||
placement_error=None,
|
||||
consecutive_failures=0,
|
||||
)
|
||||
return state.model_copy(update=update)
|
||||
|
||||
|
||||
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
|
||||
deleted_instance = state.instances.get(event.instance_id)
|
||||
new_instances: Mapping[InstanceId, Instance] = {
|
||||
iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
|
||||
}
|
||||
return state.model_copy(update={"instances": new_instances})
|
||||
update: dict[str, object] = {"instances": new_instances}
|
||||
|
||||
# Track failure on the MetaInstance itself
|
||||
if (
|
||||
event.failure_error
|
||||
and deleted_instance
|
||||
and deleted_instance.meta_instance_id
|
||||
and deleted_instance.meta_instance_id in state.meta_instances
|
||||
):
|
||||
mid = deleted_instance.meta_instance_id
|
||||
mi = state.meta_instances[mid]
|
||||
update["meta_instances"] = {
|
||||
**state.meta_instances,
|
||||
mid: mi.model_copy(
|
||||
update={
|
||||
"consecutive_failures": mi.consecutive_failures + 1,
|
||||
"last_failure_error": event.failure_error,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
return state.model_copy(update=update)
|
||||
|
||||
|
||||
def apply_instance_retrying(event: InstanceRetrying, state: State) -> State:
|
||||
"""Runners failed but retry limit not reached — remove runners, keep instance."""
|
||||
instance = state.instances.get(event.instance_id)
|
||||
if instance is None:
|
||||
# Instance was already deleted (e.g. cascade from DeleteMetaInstance).
|
||||
# The InstanceDeleted handler already incremented consecutive_failures
|
||||
# on the MetaInstance, so skipping here avoids double-counting.
|
||||
return state
|
||||
|
||||
# Remove all runners belonging to this instance from state
|
||||
runner_ids_to_remove = set(instance.shard_assignments.node_to_runner.values())
|
||||
new_runners: Mapping[RunnerId, RunnerStatus] = {
|
||||
rid: rs for rid, rs in state.runners.items() if rid not in runner_ids_to_remove
|
||||
}
|
||||
|
||||
update: dict[str, object] = {"runners": new_runners}
|
||||
|
||||
# Increment failure count on the MetaInstance
|
||||
if event.meta_instance_id in state.meta_instances:
|
||||
update["meta_instances"] = _update_meta_instance(
|
||||
state,
|
||||
event.meta_instance_id,
|
||||
consecutive_failures=state.meta_instances[
|
||||
event.meta_instance_id
|
||||
].consecutive_failures
|
||||
+ 1,
|
||||
last_failure_error=event.failure_error,
|
||||
)
|
||||
|
||||
return state.model_copy(update=update)
|
||||
|
||||
|
||||
def apply_meta_instance_created(event: MetaInstanceCreated, state: State) -> State:
|
||||
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
|
||||
**state.meta_instances,
|
||||
event.meta_instance.meta_instance_id: event.meta_instance,
|
||||
}
|
||||
return state.model_copy(update={"meta_instances": new_meta})
|
||||
|
||||
|
||||
def apply_meta_instance_deleted(event: MetaInstanceDeleted, state: State) -> State:
|
||||
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
|
||||
mid: mi
|
||||
for mid, mi in state.meta_instances.items()
|
||||
if mid != event.meta_instance_id
|
||||
}
|
||||
return state.model_copy(update={"meta_instances": new_meta})
|
||||
|
||||
|
||||
def apply_meta_instance_placement_failed(
|
||||
event: MetaInstancePlacementFailed, state: State
|
||||
) -> State:
|
||||
if event.meta_instance_id not in state.meta_instances:
|
||||
return state
|
||||
return state.model_copy(
|
||||
update={
|
||||
"meta_instances": _update_meta_instance(
|
||||
state, event.meta_instance_id, placement_error=event.reason
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
|
||||
@@ -218,11 +338,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
key: value for key, value in state.downloads.items() if key != event.node_id
|
||||
}
|
||||
# Clean up all granular node mappings
|
||||
node_identities = {
|
||||
key: value
|
||||
for key, value in state.node_identities.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
node_memory = {
|
||||
key: value for key, value in state.node_memory.items() if key != event.node_id
|
||||
}
|
||||
@@ -263,7 +378,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
"downloads": downloads,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
"node_identities": node_identities,
|
||||
"node_memory": node_memory,
|
||||
"node_disk": node_disk,
|
||||
"node_system": node_system,
|
||||
|
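The failure-tracking handlers in this file (`apply_instance_created`, `apply_instance_deleted`, `apply_instance_retrying`) all follow the same copy-and-update pattern. A rough standalone sketch of the counter transitions, using `dataclasses.replace` in place of pydantic's `model_copy(update=...)`, is shown below. `Meta`, `on_retrying`, `on_deleted`, and `on_created` are hypothetical names for illustration only.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class Meta:
    """Simplified stand-in for the failure-tracking fields of MetaInstance."""
    consecutive_failures: int = 0
    last_failure_error: Optional[str] = None
    placement_error: Optional[str] = None


def on_retrying(meta: Meta, error: str) -> Meta:
    # A retry always counts as one more consecutive failure.
    return replace(
        meta,
        consecutive_failures=meta.consecutive_failures + 1,
        last_failure_error=error,
    )


def on_deleted(meta: Meta, error: Optional[str]) -> Meta:
    # A deletion without a failure_error (e.g. user-requested) is not a failure.
    if error is None:
        return meta
    return replace(
        meta,
        consecutive_failures=meta.consecutive_failures + 1,
        last_failure_error=error,
    )


def on_created(meta: Meta) -> Meta:
    # A successful placement resets the counter and placement_error,
    # but leaves last_failure_error in place.
    if meta.placement_error is not None or meta.consecutive_failures > 0:
        return replace(meta, placement_error=None, consecutive_failures=0)
    return meta


meta = on_retrying(on_retrying(Meta(), "OOM"), "OOM")
recovered = on_created(meta)
```

Keeping `last_failure_error` after a reset matches the `"old error"` assertion in the tests earlier in this diff.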
||||
@@ -182,6 +182,7 @@ class ConfigData(BaseModel):
|
||||
def supports_tensor(self) -> bool:
|
||||
return self.architectures in [
|
||||
["Glm4MoeLiteForCausalLM"],
|
||||
["GlmMoeDsaForCausalLM"],
|
||||
["DeepseekV32ForCausalLM"],
|
||||
["DeepseekV3ForCausalLM"],
|
||||
["Qwen3NextForCausalLM"],
|
||||
|
||||
@@ -3,11 +3,10 @@ from collections.abc import Generator
|
||||
from typing import Annotated, Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
@@ -228,13 +227,6 @@ class PlaceInstanceParams(BaseModel):
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing
|
||||
min_nodes: int = 1
|
||||
|
||||
@field_validator("sharding", "instance_meta", mode="plain")
|
||||
@classmethod
|
||||
def use_default(cls, v: object):
|
||||
if not v or not isinstance(v, (Sharding, InstanceMeta)):
|
||||
raise PydanticUseDefault()
|
||||
return v
|
||||
|
||||
|
||||
class CreateInstanceParams(BaseModel):
|
||||
instance: Instance
|
||||
@@ -270,6 +262,26 @@ class DeleteInstanceResponse(BaseModel):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class CreateMetaInstanceParams(BaseModel):
|
||||
model_id: ModelId
|
||||
sharding: Sharding = Sharding.Pipeline
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing
|
||||
min_nodes: int = 1
|
||||
node_ids: list[NodeId] | None = None
|
||||
|
||||
|
||||
class CreateMetaInstanceResponse(BaseModel):
|
||||
message: str
|
||||
command_id: CommandId
|
||||
meta_instance_id: MetaInstanceId
|
||||
|
||||
|
||||
class DeleteMetaInstanceResponse(BaseModel):
|
||||
message: str
|
||||
command_id: CommandId
|
||||
meta_instance_id: MetaInstanceId
|
||||
|
||||
|
||||
class AdvancedImageParams(BaseModel):
|
||||
seed: Annotated[int, Field(ge=0)] | None = None
|
||||
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
|
||||
|
||||
@@ -6,7 +6,8 @@ from exo.shared.types.api import (
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
@@ -48,6 +49,18 @@ class DeleteInstance(BaseCommand):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class TaskCancelled(BaseCommand):
|
||||
cancelled_command_id: CommandId
|
||||
|
||||
|
||||
class CreateMetaInstance(BaseCommand):
|
||||
meta_instance: MetaInstance
|
||||
|
||||
|
||||
class DeleteMetaInstance(BaseCommand):
|
||||
meta_instance_id: MetaInstanceId
|
||||
|
||||
|
||||
class TaskFinished(BaseCommand):
|
||||
finished_command_id: CommandId
|
||||
|
||||
@@ -89,6 +102,9 @@ Command = (
|
||||
| PlaceInstance
|
||||
| CreateInstance
|
||||
| DeleteInstance
|
||||
| TaskCancelled
|
||||
| CreateMetaInstance
|
||||
| DeleteMetaInstance
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
)
|
||||
|
||||
@@ -42,6 +42,10 @@ class CommandId(Id):
|
||||
pass
|
||||
|
||||
|
||||
class MetaInstanceId(Id):
|
||||
"""Identifier for a MetaInstance."""
|
||||
|
||||
|
||||
class Host(CamelCaseModel):
|
||||
ip: str
|
||||
port: int
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import base64
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import final
|
||||
from typing import Annotated, final
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BeforeValidator, Field, PlainSerializer
|
||||
|
||||
from exo.shared.topology import Connection
|
||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
@@ -14,6 +17,28 @@ from exo.utils.info_gatherer.info_gatherer import GatheredInfo
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
|
||||
|
||||
|
||||
def _decode_base64_bytes(v: bytes | str) -> bytes:
|
||||
if isinstance(v, bytes):
|
||||
return v
|
||||
return base64.b64decode(v)
|
||||
|
||||
|
||||
def _encode_base64_bytes(v: bytes) -> str:
|
||||
return base64.b64encode(v).decode("ascii")
|
||||
|
||||
|
||||
Base64Bytes = Annotated[
|
||||
bytes,
|
||||
BeforeValidator(_decode_base64_bytes),
|
||||
PlainSerializer(_encode_base64_bytes, return_type=str),
|
||||
]
|
||||
"""bytes that serialize to/from base64 strings in JSON.
|
||||
|
||||
Needed because TaggedModel's wrap validator converts JSON→Python validation
|
||||
context, which breaks strict-mode bytes deserialization from JSON strings.
|
||||
"""
|
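The two helpers behind `Base64Bytes` can be exercised on their own. The sketch below copies their logic into standalone functions (renamed without the leading underscore, which is an illustrative choice) and round-trips a payload through `json`, which cannot serialize raw `bytes` directly. It does not involve pydantic or `TaggedModel`.

```python
import base64
import json
from typing import Union


def decode_base64_bytes(v: Union[bytes, str]) -> bytes:
    # Already-decoded bytes pass through; strings are treated as base64 text.
    if isinstance(v, bytes):
        return v
    return base64.b64decode(v)


def encode_base64_bytes(v: bytes) -> str:
    return base64.b64encode(v).decode("ascii")


payload = bytes([0, 255, 16, 128])  # not valid UTF-8, so it needs an encoding step
wire = json.dumps({"data": encode_base64_bytes(payload)})
restored = decode_base64_bytes(json.loads(wire)["data"])
```

Because the decoder passes `bytes` through unchanged, it is safe to apply it to values that were never serialized.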
||||
|
||||
|
||||
class EventId(Id):
|
||||
"""
|
||||
Newtype around `ID`
|
||||
@@ -66,6 +91,30 @@ class InstanceCreated(BaseEvent):
|
||||
|
||||
class InstanceDeleted(BaseEvent):
|
||||
instance_id: InstanceId
|
||||
failure_error: str | None = None
|
||||
|
||||
|
||||
class MetaInstanceCreated(BaseEvent):
|
||||
meta_instance: MetaInstance
|
||||
|
||||
|
||||
class MetaInstanceDeleted(BaseEvent):
|
||||
meta_instance_id: MetaInstanceId
|
||||
|
||||
|
||||
@final
|
||||
class MetaInstancePlacementFailed(BaseEvent):
|
||||
meta_instance_id: MetaInstanceId
|
||||
reason: str
|
||||
|
||||
|
||||
@final
|
||||
class InstanceRetrying(BaseEvent):
|
||||
"""Runners failed but retry count is below the limit — restart runners, keep instance."""
|
||||
|
||||
instance_id: InstanceId
|
||||
meta_instance_id: MetaInstanceId
|
||||
failure_error: str
|
||||
|
||||
|
||||
class RunnerStatusUpdated(BaseEvent):
|
||||
@@ -132,6 +181,25 @@ class TracesMerged(BaseEvent):
|
||||
traces: list[TraceEventData]
|
||||
|
||||
|
||||
@final
|
||||
class JacclSideChannelData(BaseEvent):
|
||||
"""A runner's local contribution to a JACCL SideChannel all_gather round."""
|
||||
|
||||
instance_id: InstanceId
|
||||
runner_id: RunnerId
|
||||
sequence: int
|
||||
data: Base64Bytes
|
||||
|
||||
|
||||
@final
|
||||
class JacclSideChannelGathered(BaseEvent):
|
||||
"""Gathered result of a JACCL SideChannel all_gather round."""
|
||||
|
||||
instance_id: InstanceId
|
||||
sequence: int
|
||||
gathered_data: Mapping[RunnerId, Base64Bytes]
|
||||
|
||||
|
||||
Event = (
|
||||
TestEvent
|
||||
| TaskCreated
|
||||
@@ -141,6 +209,10 @@ Event = (
|
||||
| TaskAcknowledged
|
||||
| InstanceCreated
|
||||
| InstanceDeleted
|
||||
| InstanceRetrying
|
||||
| MetaInstanceCreated
|
||||
| MetaInstanceDeleted
|
||||
| MetaInstancePlacementFailed
|
||||
| RunnerStatusUpdated
|
||||
| RunnerDeleted
|
||||
| NodeTimedOut
|
||||
@@ -152,6 +224,8 @@ Event = (
|
||||
| TopologyEdgeDeleted
|
||||
| TracesCollected
|
||||
| TracesMerged
|
||||
| JacclSideChannelData
|
||||
| JacclSideChannelGathered
|
||||
)
|
||||
|
||||
|
||||
|
||||
25
src/exo/shared/types/meta_instance.py
Normal file
@@ -0,0 +1,25 @@
|
from typing import final

from pydantic import Field

from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import FrozenModel


@final
class MetaInstance(FrozenModel):
    """Declarative constraint: ensure an instance matching these parameters always exists."""

    meta_instance_id: MetaInstanceId = Field(default_factory=MetaInstanceId)
    model_id: ModelId
    sharding: Sharding = Sharding.Pipeline
    instance_meta: InstanceMeta = InstanceMeta.MlxRing
    min_nodes: int = 1
    node_ids: list[NodeId] | None = None
    # Failure tracking
    placement_error: str | None = None
    consecutive_failures: int = 0
    last_failure_error: str | None = None
||||
@@ -6,7 +6,8 @@ from pydantic import ConfigDict, Field, field_serializer, field_validator
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
from exo.shared.topology import Topology, TopologySnapshot
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.common import MetaInstanceId, NodeId
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.profiling import (
|
||||
DiskUsage,
|
||||
MemoryUsage,
|
||||
@@ -41,6 +42,7 @@ class State(CamelCaseModel):
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
instances: Mapping[InstanceId, Instance] = {}
|
||||
meta_instances: Mapping[MetaInstanceId, MetaInstance] = {}
|
||||
runners: Mapping[RunnerId, RunnerStatus] = {}
|
||||
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
|
||||
tasks: Mapping[TaskId, Task] = {}
|
||||
|
||||
@@ -24,6 +24,7 @@ class TaskStatus(str, Enum):
|
||||
Complete = "Complete"
|
||||
TimedOut = "TimedOut"
|
||||
Failed = "Failed"
|
||||
Cancelled = "Cancelled"
|
||||
|
||||
|
||||
class BaseTask(TaggedModel):
|
||||
@@ -60,6 +61,11 @@ class TextGeneration(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class CancelTask(BaseTask): # emitted by Worker when master cancels a task
|
||||
cancelled_task_id: TaskId
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
class ImageGeneration(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageGenerationTaskParams
|
||||
@@ -87,6 +93,7 @@ Task = (
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| TextGeneration
|
||||
| CancelTask
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| Shutdown
|
||||
|
||||
@@ -26,6 +26,7 @@ class DownloadProgressData(CamelCaseModel):
|
||||
class BaseDownloadProgress(TaggedModel):
|
||||
node_id: NodeId
|
||||
shard_metadata: ShardMetadata
|
||||
model_directory: str = ""
|
||||
|
||||
|
||||
class DownloadPending(BaseDownloadProgress):
|
||||
|
||||
@@ -2,7 +2,7 @@ from enum import Enum
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from exo.shared.types.common import Host, Id, NodeId
|
||||
from exo.shared.types.common import Host, Id, MetaInstanceId, NodeId
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
@@ -19,6 +19,7 @@ class InstanceMeta(str, Enum):
|
||||
class BaseInstance(TaggedModel):
|
||||
instance_id: InstanceId
|
||||
shard_assignments: ShardAssignments
|
||||
meta_instance_id: MetaInstanceId | None = None
|
||||
|
||||
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
|
||||
return self.shard_assignments.runner_to_shard.get(runner_id, None)
|
||||
|
||||
@@ -62,6 +62,7 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
class ToolCallResponse(BaseRunnerResponse):
|
||||
tool_calls: list[ToolCallItem]
|
||||
usage: Usage | None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import sys
|
||||
|
||||
|
||||
def print_startup_banner(port: int) -> None:
|
||||
"""Print a prominent startup banner with API endpoint information."""
|
||||
dashboard_url = f"http://localhost:{port}"
|
||||
banner = f"""
|
||||
╔═══════════════════════════════════════════════════════════════════════╗
|
||||
@@ -27,4 +29,4 @@ def print_startup_banner(port: int) -> None:
|
||||
|
||||
"""
|
||||
|
||||
print(banner)
|
||||
print(banner, file=sys.stderr)
|
||||
|
||||
@@ -306,7 +306,7 @@ def mlx_generate(
|
||||
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
||||
|
||||
mx_barrier(group)
|
||||
logger.info("Ready to prefill")
|
||||
logger.info("Starting prefill")
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
|
||||
@@ -393,10 +393,11 @@ def mlx_generate(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
total_prompt_tokens = len(all_prompt_tokens)
|
||||
usage = Usage(
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=int(out.prompt_tokens) + completion_tokens,
|
||||
total_tokens=total_prompt_tokens + completion_tokens,
|
||||
prompt_tokens_details=PromptTokensDetails(
|
||||
cached_tokens=prefix_hit_length
|
||||
),
|
||||
|
||||
@@ -64,8 +64,6 @@ from exo.worker.runner.bootstrap import logger
|
||||
Group = mx.distributed.Group
|
||||
|
||||
|
||||
# TODO: Test this
|
||||
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
|
||||
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
return Memory.from_float_kb(
|
||||
(model_shard_meta.end_layer - model_shard_meta.start_layer)
|
||||
@@ -83,30 +81,6 @@ class ModelLoadingTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0),
|
||||
stream=mx.default_stream(mx.Device(mx.cpu)),
|
||||
group=group,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def broadcast_from_zero(value: int, group: Group | None = None):
|
||||
if group is None:
|
||||
return value
|
||||
|
||||
if group.rank() == 0:
|
||||
a = mx.array([value], dtype=mx.int32)
|
||||
else:
|
||||
a = mx.array([0], dtype=mx.int32)
|
||||
|
||||
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
|
||||
mx.eval(m)
|
||||
return int(m.item())
|
||||
|
||||
|
||||
class HostList(RootModel[list[str]]):
|
||||
@classmethod
|
||||
def from_hosts(cls, hosts: list[Host]) -> "HostList":
|
||||
@@ -311,7 +285,7 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
model_id_lower = model_id.lower()
|
||||
if "kimi-k2" in model_id_lower:
|
||||
return [163586]
|
||||
elif "glm-4.7-flash" in model_id_lower:
|
||||
elif "glm-5" in model_id_lower or "glm-4.7" in model_id_lower:
|
||||
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
|
||||
return [154820, 154827, 154829]
|
||||
elif "glm" in model_id_lower:
|
||||
@@ -379,7 +353,13 @@ def load_tokenizer_for_model_id(
|
||||
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
|
||||
hf_tokenizer.encode = _patched_encode
|
||||
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
|
||||
return TokenizerWrapper(
|
||||
hf_tokenizer,
|
||||
eos_token_ids=eos_token_ids,
|
||||
tool_call_start="<|tool_calls_section_begin|>",
|
||||
tool_call_end="<|tool_calls_section_end|>",
|
||||
tool_parser=_parse_kimi_tool_calls,
|
||||
)
|
||||
|
||||
tokenizer = load_tokenizer(
|
||||
model_path,
|
||||
@@ -591,3 +571,66 @@ def mlx_cleanup(
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||
"""Synchronize a boolean across all distributed nodes.
|
||||
|
||||
Returns True if any node has bool_=True. Uses all_sum so every
|
||||
node participates in the collective — preventing GPU deadlocks.
|
||||
"""
|
||||
if group is None:
|
||||
return bool_
|
||||
num_true = mx.distributed.all_sum(
|
||||
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
mx.eval(num_true)
|
||||
return num_true.item() > 0
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None):
|
||||
if group is None:
|
||||
return
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
)
|
||||
|
||||
|
def _parse_kimi_tool_calls(text: str):
    import regex as re

    # kimi has a fixed function naming scheme, with a json formatted arg
    #   functions.multiply:0<|tool_call_argument_begin|>{"a": 2, "b": 3}
    _func_name_regex = re.compile(
        r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
    )
    _func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
    _tool_call_split_regex = re.compile(
        r"<\|tool_call_begin\|>(.*?)<\|tool_call_end\|>", re.DOTALL
    )

    def _parse_single_tool(text: str) -> dict[str, Any]:
        func_name_match = _func_name_regex.search(text)
        if func_name_match is None:
            raise ValueError("No tool call found.")
        tool_call_id = func_name_match.group(1)  # e.g. "functions.get_weather:0"
        func_name = func_name_match.group(2)  # e.g. "get_weather"

        func_args_match = _func_arg_regex.search(text)
        if func_args_match is None:
            raise ValueError("No tool call arguments found.")
        func_args = func_args_match.group(1)
        try:
            arg_dct = json.loads(func_args)  # pyright: ignore[reportAny]
        except Exception:
            arg_dct = None

        return dict(id=tool_call_id, name=func_name, arguments=arg_dct)

    tool_matches = _tool_call_split_regex.findall(text)
    if tool_matches:
        return [_parse_single_tool(match) for match in tool_matches]  # pyright: ignore[reportAny]
    else:
        return [_parse_single_tool(text)]
||||
|
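To see what `_parse_kimi_tool_calls` extracts, here is a standalone sketch that reuses the same three patterns on a made-up sample. It swaps the third-party `regex` module for the standard-library `re` (an assumption that these particular patterns behave the same in both), and the function names and sample text are illustrative only.

```python
import json
import re
from typing import Any

NAME_RE = re.compile(
    r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
ARG_RE = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
SPLIT_RE = re.compile(r"<\|tool_call_begin\|>(.*?)<\|tool_call_end\|>", re.DOTALL)


def parse_single(text: str) -> dict[str, Any]:
    name_match = NAME_RE.search(text)
    if name_match is None:
        raise ValueError("No tool call found.")
    args_match = ARG_RE.search(text)
    if args_match is None:
        raise ValueError("No tool call arguments found.")
    try:
        arguments = json.loads(args_match.group(1))
    except Exception:
        arguments = None  # malformed JSON is reported as missing arguments
    # group(1) is the full call id, group(2) the bare function name
    return {"id": name_match.group(1), "name": name_match.group(2), "arguments": arguments}


def parse_tool_calls(text: str) -> list[dict[str, Any]]:
    # Prefer explicitly delimited calls; otherwise treat the text as one call.
    parts = SPLIT_RE.findall(text)
    return [parse_single(p) for p in parts] if parts else [parse_single(text)]


sample = (
    '<|tool_call_begin|>functions.get_weather:0'
    '<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|>'
    '<|tool_call_begin|>functions.multiply:1'
    '<|tool_call_argument_begin|>{"a": 2, "b": 3}<|tool_call_end|>'
)
calls = parse_tool_calls(sample)
```

Note that arguments which fail to parse as JSON become `None` rather than raising, so callers likely need to handle that case.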
||||
@@ -24,6 +24,7 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
JacclSideChannelGathered,
|
||||
NodeGatheredInfo,
|
||||
TaskCreated,
|
||||
TaskStatusUpdated,
|
||||
@@ -158,6 +159,15 @@ class Worker:
|
||||
for idx, event in indexed_events:
|
||||
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
|
||||
|
||||
# Dispatch JACCL gathered events to the relevant RunnerSupervisor
|
||||
if isinstance(event, JacclSideChannelGathered):
|
||||
for runner in self.runners.values():
|
||||
if (
|
||||
runner.bound_instance.instance.instance_id
|
||||
== event.instance_id
|
||||
):
|
||||
runner.notify_gathered(event)
|
||||
|
||||
# Buffer input image chunks for image editing
|
||||
if isinstance(event, InputChunkReceived):
|
||||
cmd_id = event.command_id
|
||||
@@ -320,8 +330,6 @@ class Worker:
|
||||
for event in self.out_for_delivery.copy().values():
|
||||
await self.local_event_sender.send(event)
|
||||
|
||||
## Op Executors
|
||||
|
||||
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
|
||||
"""Creates and stores a new AssignedRunner with initial downloading status."""
|
||||
runner = RunnerSupervisor.create(
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
@@ -34,6 +35,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerLoading,
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
@@ -53,13 +55,14 @@ def plan(
|
||||
) -> Task | None:
|
||||
# Python short circuiting OR logic should evaluate these sequentially.
|
||||
return (
|
||||
_kill_runner(runners, all_runners, instances)
|
||||
or _create_runner(node_id, runners, instances)
|
||||
_cancel_tasks(runners, tasks)
|
||||
or _kill_runner(runners, all_runners, instances)
|
||||
or _create_runner(node_id, runners, instances, all_runners)
|
||||
or _model_needs_download(node_id, runners, global_download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
|
||||
)
|
||||
|
||||
|
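The `plan` function relies on Python's short-circuiting `or` to pick the first planner that returns a task, which is why placing `_cancel_tasks` first gives cancellation the highest priority. A toy sketch of that behaviour follows; the planner names and string results are placeholders, not the real `Task` types.

```python
from typing import Callable, Optional

calls: list[str] = []


def make_planner(name: str, result: Optional[str]) -> Callable[[], Optional[str]]:
    """Build a fake planner that records when it is evaluated."""
    def planner() -> Optional[str]:
        calls.append(name)
        return result
    return planner


cancel_tasks = make_planner("cancel", None)         # nothing to cancel
kill_runner = make_planner("kill", "Shutdown")      # first planner with work to do
create_runner = make_planner("create", "CreateRunner")

# Evaluated left to right; stops at the first truthy result.
task = cancel_tasks() or kill_runner() or create_runner()
```

One caveat of this idiom: `or` tests truthiness rather than `is not None`, so it only works as intended when every real task object is truthy.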
||||
@@ -73,6 +76,12 @@ def _kill_runner(
|
||||
if (instance_id := runner.bound_instance.instance.instance_id) not in instances:
|
||||
return Shutdown(instance_id=instance_id, runner_id=runner_id)
|
||||
|
||||
# Master removed our runner from state (retry signal) and process is dead
|
||||
if runner_id not in all_runners and isinstance(
|
||||
runner.status, (RunnerFailed, RunnerShutdown)
|
||||
):
|
||||
return Shutdown(instance_id=instance_id, runner_id=runner_id)
|
||||
|
||||
for (
|
||||
global_runner_id
|
||||
) in runner.bound_instance.instance.shard_assignments.node_to_runner.values():
|
||||
@@ -90,6 +99,7 @@ def _create_runner(
|
||||
node_id: NodeId,
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
) -> CreateRunner | None:
|
||||
for instance in instances.values():
|
||||
runner_id = instance.shard_assignments.node_to_runner.get(node_id, None)
|
||||
@@ -99,6 +109,16 @@ def _create_runner(
|
||||
if runner_id in runners:
|
||||
continue
|
||||
|
||||
# Don't create while any peer runner is in a terminal state — wait for
|
||||
# the master to emit InstanceRetrying which removes them from state.
|
||||
has_terminal_peer = any(
|
||||
isinstance(all_runners.get(peer_rid), (RunnerFailed, RunnerShutdown))
|
||||
for peer_rid in instance.shard_assignments.node_to_runner.values()
|
||||
if peer_rid != runner_id
|
||||
)
|
||||
if has_terminal_peer:
|
||||
continue
|
||||
|
||||
shard = instance.shard(runner_id)
|
||||
assert shard is not None
|
||||
|
||||
@@ -270,7 +290,7 @@ def _pending_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions
|
||||
@@ -284,7 +304,7 @@ def _pending_tasks(
|
||||
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
|
||||
cmd_id = task.command_id
|
||||
expected = task.task_params.total_input_chunks
|
||||
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
|
||||
received = len(input_chunk_buffer.get(cmd_id, {}))
|
||||
if received < expected:
|
||||
continue # Wait for all chunks to arrive
|
||||
|
||||
@@ -292,16 +312,34 @@ def _pending_tasks(
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
# the task status _should_ be set to completed by the LAST runner
|
||||
# it is currently set by the first
|
||||
# this is definitely a hack
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributed's expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
):
|
||||
return task
|
||||
|
||||
|
||||
def _cancel_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> CancelTask | None:
|
||||
"""Find a cancelled task that hasn't been sent to the runner yet."""
|
||||
for task in tasks.values():
|
||||
if task.task_status != TaskStatus.Cancelled:
|
||||
continue
|
||||
for runner_id, runner in runners.items():
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
if task.task_id in runner.cancelled:
|
||||
continue
|
||||
return CancelTask(
|
||||
instance_id=task.instance_id,
|
||||
cancelled_task_id=task.task_id,
|
||||
runner_id=runner_id,
|
||||
)
|
||||
|
||||
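The selection rule in `_cancel_tasks` (first cancelled task whose cancellation has not yet been forwarded to a runner of the same instance) can be sketched with plain data classes. `SketchTask`, `SketchRunner`, and `next_cancel` are hypothetical names for illustration; the real function returns a `CancelTask` object rather than a tuple.

```python
from dataclasses import dataclass, field


@dataclass
class SketchTask:
    task_id: str
    instance_id: str
    cancelled: bool = False  # stands in for task_status == TaskStatus.Cancelled


@dataclass
class SketchRunner:
    instance_id: str
    cancelled: set = field(default_factory=set)  # task ids already forwarded


def next_cancel(runners: dict, tasks: dict):
    """Return (task_id, runner_id) for the first cancel still to be sent, else None."""
    for task in tasks.values():
        if not task.cancelled:
            continue
        for runner_id, runner in runners.items():
            if task.instance_id != runner.instance_id:
                continue
            if task.task_id in runner.cancelled:
                continue  # this runner has already been told
            return (task.task_id, runner_id)
    return None
```

Because each call returns at most one action, the caller is expected to record the cancel on the runner (as `cancel_task` does via `self.cancelled.add`) before planning again; otherwise the same pair would be returned repeatedly.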
@@ -3,7 +3,7 @@ import os
|
||||
import loguru
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
@@ -15,7 +15,9 @@ def entrypoint(
    bound_instance: BoundInstance,
    event_sender: MpSender[Event],
    task_receiver: MpReceiver[Task],
    cancel_receiver: MpReceiver[TaskId],
    _logger: "loguru.Logger",
    pipe_fifo_paths: tuple[str, str] | None = None,
) -> None:
    fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
    if fast_synch_override == "on" or (
@@ -29,6 +31,16 @@ def entrypoint(
    else:
        os.environ["MLX_METAL_FAST_SYNCH"] = "0"

    # Open JACCL FIFOs by path and set env vars for C++ SideChannel.
    # Named pipes (FIFOs) work across multiprocessing spawn (macOS default).
    if pipe_fifo_paths is not None:
        fifo_c2p, fifo_p2c = pipe_fifo_paths
        # C++ reads gathered data from p2c (PIPE_IN), writes local data to c2p (PIPE_OUT)
        pipe_in_fd = os.open(fifo_p2c, os.O_RDONLY)
        pipe_out_fd = os.open(fifo_c2p, os.O_WRONLY)
        os.environ["MLX_JACCL_PIPE_IN"] = str(pipe_in_fd)
        os.environ["MLX_JACCL_PIPE_OUT"] = str(pipe_out_fd)

    global logger
    logger = _logger

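Opening a FIFO blocks until the other end is opened, so the open order on both sides matters. The sketch below reproduces the child's order from the hunk above (read end of `p2c`, then write end of `c2p`) and has the parent open its two ends concurrently. It is an illustration under assumptions: a thread stands in for the runner process, and `child_side` / `fifo_roundtrip` are hypothetical names.

```python
import os
import shutil
import tempfile
import threading


def child_side(c2p_path: str, p2c_path: str) -> None:
    """Stand-in for the runner: same open order as the entrypoint."""
    fd_in = os.open(p2c_path, os.O_RDONLY)   # blocks until the parent opens p2c for writing
    fd_out = os.open(c2p_path, os.O_WRONLY)  # blocks until the parent opens c2p for reading
    try:
        os.write(fd_out, os.read(fd_in, 64).upper())
    finally:
        os.close(fd_in)
        os.close(fd_out)


def fifo_roundtrip(payload: bytes) -> bytes:
    """Send a short payload (at most 64 bytes) to the child and return its reply."""
    fifo_dir = tempfile.mkdtemp(prefix="fifo_sketch_")
    c2p = os.path.join(fifo_dir, "c2p")
    p2c = os.path.join(fifo_dir, "p2c")
    os.mkfifo(c2p)
    os.mkfifo(p2c)

    child = threading.Thread(target=child_side, args=(c2p, p2c))
    child.start()

    fds: dict[str, int] = {}

    def open_into(key: str, path: str, flags: int) -> None:
        fds[key] = os.open(path, flags)

    # Open both parent ends concurrently so neither blocking open can starve the other.
    openers = [
        threading.Thread(target=open_into, args=("read", c2p, os.O_RDONLY)),
        threading.Thread(target=open_into, args=("write", p2c, os.O_WRONLY)),
    ]
    for opener in openers:
        opener.start()
    for opener in openers:
        opener.join()

    try:
        os.write(fds["write"], payload)
        reply = os.read(fds["read"], 64)
    finally:
        os.close(fds["read"])
        os.close(fds["write"])
        child.join()
        shutil.rmtree(fifo_dir)
    return reply
```

If the parent instead opened `c2p` for reading first and sequentially, it would wait for a writer that the child only opens after its own `p2c` open returns, and both sides would block.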
@@ -38,7 +50,7 @@ def entrypoint(
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
main(bound_instance, event_sender, task_receiver, cancel_receiver)
|
||||
except ClosedResourceError:
|
||||
logger.warning("Runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
@@ -55,7 +67,9 @@ def entrypoint(
|
||||
try:
|
||||
event_sender.close()
|
||||
task_receiver.close()
|
||||
cancel_receiver.close()
|
||||
finally:
|
||||
event_sender.join()
|
||||
task_receiver.join()
|
||||
cancel_receiver.join()
|
||||
logger.info("bye from the runner")
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import base64
|
||||
import json
|
||||
import math
|
||||
import resource
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
from typing import Any, Callable, Literal
|
||||
from typing import Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
@@ -15,7 +15,6 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
|
||||
from exo.shared.models.model_cards import ModelId, ModelTask
|
||||
@@ -88,9 +87,12 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
mx_any,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
from .tool_parsers import ToolParser, make_mlx_parser
|
||||
|
||||
|
||||
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
|
||||
"""Check if this node is the primary output node for image generation.
|
||||
@@ -112,6 +114,7 @@ def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
):
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
|
||||
@@ -129,11 +132,16 @@ def main(
|
||||
time.sleep(timeout)
|
||||
|
||||
setup_start_time = time.time()
|
||||
cancelled_tasks = set[TaskId]()
|
||||
|
||||
model: Model | DistributedImageModel | None = None
|
||||
# type checker was unhappy with me - splitting these fixed it
|
||||
inference_model: Model | None = None
|
||||
image_model: DistributedImageModel | None = None
|
||||
tokenizer = None
|
||||
tool_parser: ToolParser | None = None
|
||||
group = None
|
||||
kv_prefix_cache: KVPrefixCache | None = None
|
||||
check_for_cancel_every: int | None = None
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
@@ -146,6 +154,7 @@ def main(
|
||||
if task.task_id in seen:
|
||||
logger.warning("repeat task - potential error")
|
||||
seen.add(task.task_id)
|
||||
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
@@ -191,19 +200,28 @@ def main(
|
||||
time.sleep(0.5)
|
||||
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
model, tokenizer = load_mlx_items(
|
||||
inference_model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
|
||||
)
|
||||
if tokenizer.has_tool_calling:
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
tool_parser = make_mlx_parser(
|
||||
tokenizer.tool_call_start,
|
||||
tokenizer.tool_call_end,
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(group)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
model = initialize_image_model(bound_instance)
|
||||
image_model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
@@ -211,8 +229,6 @@ def main(
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
@@ -224,16 +240,31 @@ def main(
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert not isinstance(model, DistributedImageModel)
|
||||
assert inference_model
|
||||
assert tokenizer
|
||||
|
||||
t = time.perf_counter()
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
model=inference_model,
|
||||
tokenizer=tokenizer,
|
||||
group=group,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
check_for_cancel_every = min(
|
||||
math.ceil(toks / max(time.perf_counter() - t, 0.001)), 100
|
||||
)
|
||||
if group is not None:
|
||||
check_for_cancel_every = int(
|
||||
mx.max(
|
||||
mx.distributed.all_gather(
|
||||
mx.array([check_for_cancel_every]), group=group
|
||||
)
|
||||
).item()
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"runner checking for cancellation every {check_for_cancel_every} tokens"
|
||||
)
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
@@ -241,8 +272,8 @@ def main(
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
image = warmup_image_generator(model=model)
|
||||
assert image_model
|
||||
image = warmup_image_generator(model=image_model)
|
||||
if image is not None:
|
||||
logger.info(f"warmed up by generating {image.size} image")
|
||||
else:
|
||||
@@ -262,9 +293,9 @@ def main(
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
assert model and not isinstance(model, DistributedImageModel)
|
||||
assert inference_model
|
||||
assert tokenizer
|
||||
assert check_for_cancel_every
|
||||
|
||||
try:
|
||||
_check_for_debug_prompts(task_params)
|
||||
@@ -274,7 +305,7 @@ def main(
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
model=inference_model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
@@ -289,34 +320,25 @@ def main(
|
||||
mlx_generator, tokenizer
|
||||
)
|
||||
|
||||
# Kimi-K2 has tool call sections - we don't care about them
|
||||
if "kimi" in shard_metadata.model_card.model_id.lower():
|
||||
mlx_generator = filter_kimi_tokens(mlx_generator)
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
# GLM models need patched parser (upstream has bug with None regex match)
|
||||
elif "glm" in shard_metadata.model_card.model_id.lower():
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
elif isinstance(model, GptOssModel):
|
||||
if isinstance(inference_model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
if tokenizer.has_tool_calling and not isinstance(
|
||||
model, GptOssModel
|
||||
):
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
mlx_generator = parse_tool_calls(
|
||||
mlx_generator,
|
||||
tokenizer.tool_call_start,
|
||||
tokenizer.tool_call_end,
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
elif tool_parser:
|
||||
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
|
||||
|
||||
completion_tokens = 0
|
||||
tokens_since_last_cancel_check = 0
|
||||
for response in mlx_generator:
|
||||
tokens_since_last_cancel_check += 1
|
||||
if tokens_since_last_cancel_check >= check_for_cancel_every:
|
||||
tokens_since_last_cancel_check = 0
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
want_to_cancel = (task.task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
if mx_any(want_to_cancel, group):
|
||||
break
|
||||
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
completion_tokens += 1
|
||||
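The cancellation cadence used in this file has two parts: an interval derived from warmup throughput (about one check per second of decoding, capped at 100 tokens), and a generation loop that polls only every `check_for_cancel_every` tokens. The following is a simplified single-process sketch; `cancel_check_interval`, `generate_until_cancelled`, and `poll_cancelled` are hypothetical names, and the distributed agreement step (`all_gather` / `mx_any`) is left out.

```python
import math
from collections.abc import Callable, Iterable, Iterator


def cancel_check_interval(tokens: int, elapsed_seconds: float, cap: int = 100) -> int:
    """Tokens between cancel checks: warmup tokens per second, capped."""
    tokens_per_second = tokens / max(elapsed_seconds, 0.001)
    return min(math.ceil(tokens_per_second), cap)


def generate_until_cancelled(
    responses: Iterable[int],
    check_every: int,
    poll_cancelled: Callable[[], bool],
) -> Iterator[int]:
    """Yield responses, polling for cancellation every `check_every` items."""
    since_last_check = 0
    for response in responses:
        since_last_check += 1
        if since_last_check >= check_every:
            since_last_check = 0
            if poll_cancelled():
                break  # the response that triggered the check is dropped
        yield response
```

One thing this makes visible: if warmup generated zero tokens the interval would be 0, which the `assert check_for_cancel_every` in the chat-completion branch would reject.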
@@ -364,6 +386,7 @@ def main(
|
||||
tool_calls=response.tool_calls,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
usage=response.usage,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -388,7 +411,7 @@ def main(
|
||||
case ImageGeneration(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
assert image_model
|
||||
logger.info(f"received image generation request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -401,7 +424,9 @@ def main(
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
for response in generate_image(
|
||||
model=image_model, task=task_params
|
||||
):
|
||||
is_primary_output = _is_primary_output_node(shard_metadata)
|
||||
|
||||
if is_primary_output:
|
||||
@@ -451,7 +476,7 @@ def main(
|
||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
assert image_model
|
||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -464,7 +489,9 @@ def main(
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
for response in generate_image(
|
||||
model=image_model, task=task_params
|
||||
):
|
||||
if _is_primary_output_node(shard_metadata):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
@@ -523,14 +550,20 @@ def main(
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
|
||||
was_cancelled = (task.task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
if not was_cancelled:
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del model, tokenizer, group
|
||||
del inference_model, image_model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
@@ -544,21 +577,8 @@ def get_gpt_oss_encoding():
|
||||
return encoding
|
||||
|
||||
|
||||
def filter_kimi_tokens(
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
) -> Generator[GenerationResponse]:
|
||||
for resp in responses:
|
||||
assert isinstance(resp, GenerationResponse)
|
||||
if (
|
||||
resp.text == "<|tool_calls_section_begin|>"
|
||||
or resp.text == "<|tool_calls_section_end|>"
|
||||
):
|
||||
continue
|
||||
yield resp
|
||||
|
||||
|
||||
def parse_gpt_oss(
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
encoding = get_gpt_oss_encoding()
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
@@ -615,9 +635,9 @@ def parse_gpt_oss(
|
||||
|
||||
|
||||
def parse_thinking_models(
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
responses: Generator[GenerationResponse],
|
||||
tokenizer: TokenizerWrapper,
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""
|
||||
For models that inject thinking tags in the prompt (like GLM-4.7),
|
||||
prepend the thinking tag to the output stream so the frontend
|
||||
@@ -738,218 +758,55 @@ def _process_image_response(
|
||||
|
||||
|
||||
def parse_tool_calls(
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
tool_call_start: str,
|
||||
tool_call_end: str,
|
||||
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
|
||||
responses: Generator[GenerationResponse], tool_parser: ToolParser
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
in_tool_call = False
|
||||
tool_call_text_parts: list[str] = []
|
||||
for response in responses:
|
||||
assert isinstance(response, GenerationResponse)
|
||||
# assumption: the tool call start is one token
|
||||
if response.text == tool_call_start:
|
||||
if response.text.startswith(tool_parser.start_parsing):
|
||||
in_tool_call = True
|
||||
continue
|
||||
# assumption: the tool call end is one token
|
||||
if in_tool_call and response.text == tool_call_end:
|
||||
try:
|
||||
# tool_parser returns an arbitrarily nested python dictionary
|
||||
# we actually don't want the python dictionary, we just want to
|
||||
# parse the top level { function: ..., arguments: ... } structure
|
||||
# as we're just gonna hand it back to the api anyway
|
||||
parsed = tool_parser("".join(tool_call_text_parts).strip())
|
||||
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
|
||||
if isinstance(parsed, list):
|
||||
tools = [_validate_single_tool(tool) for tool in parsed]
|
||||
else:
|
||||
tools = [_validate_single_tool(parsed)]
|
||||
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
|
||||
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
ValidationError,
|
||||
ValueError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
# ValueError: our parsers raise this for malformed tool calls
|
||||
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
|
||||
logger.opt(exception=e).warning("tool call parsing failed")
|
||||
# assumption: talking about tool calls, not making a tool call
|
||||
response.text = (
|
||||
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
|
||||
)
|
||||
yield response
|
||||
|
||||
in_tool_call = False
|
||||
tool_call_text_parts = []
|
||||
continue
|
||||
|
||||
if in_tool_call:
|
||||
tool_call_text_parts.append(response.text)
|
||||
if response.text.endswith(tool_parser.end_parsing):
|
||||
# parse the actual tool calls from the tool call text
|
||||
parsed = tool_parser.parse_tool_calls(
|
||||
"".join(tool_call_text_parts).strip()
|
||||
)
|
||||
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
|
||||
if parsed is not None:
|
||||
yield ToolCallResponse(
|
||||
tool_calls=parsed, usage=response.usage, stats=response.stats
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
|
||||
)
|
||||
response.text = "".join(tool_call_text_parts)
|
||||
yield response
|
||||
|
||||
in_tool_call = False
|
||||
tool_call_text_parts = []
|
||||
continue
|
||||
|
||||
if response.finish_reason is not None:
|
||||
logger.info(
|
||||
"toll call parsing interrupted, yield partial tool call as text"
|
||||
"tool call parsing interrupted, yield partial tool call as text"
|
||||
)
|
||||
yield GenerationResponse(
|
||||
text=tool_call_start + "".join(tool_call_text_parts),
|
||||
token=0,
|
||||
finish_reason=response.finish_reason,
|
||||
usage=None,
|
||||
response = response.model_copy(
|
||||
update={
|
||||
"text": "".join(tool_call_text_parts),
|
||||
"token": 0,
|
||||
}
|
||||
)
|
||||
yield response
|
||||
|
||||
continue
|
||||
|
||||
# fallthrough
|
||||
yield response
|
||||
|
||||
|
||||
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
"""
|
||||
Version of to-be-upstreamed kimi-k2 tool parser
|
||||
"""
|
||||
import ast
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
# kimi has a fixed function naming scheme, with a json formatted arg
|
||||
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
|
||||
# Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
|
||||
_func_name_regex = re.compile(
|
||||
r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
|
||||
)
|
||||
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
|
||||
|
||||
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
|
||||
tool_call_start = "<|tool_call_begin|>"
|
||||
tool_call_end = "<|tool_call_end|>"
|
||||
|
||||
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
|
||||
try:
|
||||
return json.loads(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return ast.literal_eval(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: Any | None = None):
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
original_func_name = func_name_match.group(1)
|
||||
tool_id = func_name_match.group(2)
|
||||
# strip off the `functions.` prefix, if it exists.
|
||||
func_name = original_func_name[original_func_name.find(".") + 1 :]
|
||||
|
||||
func_args_match = _func_arg_regex.search(text)
|
||||
if func_args_match is None:
|
||||
raise ValueError(f"Could not parse function args from tool call: {text!r}")
|
||||
func_args = func_args_match.group(1)
|
||||
# the args should be valid json - no need to check against our tools to deserialize
|
||||
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
|
||||
|
||||
return dict(
|
||||
id=f"{original_func_name}:{tool_id}",
|
||||
name=func_name,
|
||||
arguments=arg_dct, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
tokenizer._tool_call_start = tool_call_start
|
||||
tokenizer._tool_call_end = tool_call_end
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
|
||||
"""
|
||||
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
|
||||
"""
|
||||
import ast
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
|
||||
_func_arg_regex = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>(?:\n|\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
tool_call_start = "<tool_call>"
|
||||
tool_call_end = "</tool_call>"
|
||||
|
||||
def _is_string_type(
|
||||
tool_name: str,
|
||||
arg_name: str,
|
||||
tools: list[Any] | None,
|
||||
) -> bool:
|
||||
if tools is None:
|
||||
return False
|
||||
for tool in tools: # pyright: ignore[reportAny]
|
||||
func = tool["function"] # pyright: ignore[reportAny]
|
||||
if func["name"] == tool_name:
|
||||
params = func["parameters"] # pyright: ignore[reportAny]
|
||||
if params is None:
|
||||
return False
|
||||
props = params.get("properties", {}) # pyright: ignore[reportAny]
|
||||
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
|
||||
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
|
||||
return arg_type == "string" # pyright: ignore[reportAny]
|
||||
return False
|
||||
|
||||
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
|
||||
try:
|
||||
return json.loads(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return ast.literal_eval(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: list[Any] | None = None):
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
func_name = func_name_match.group(1)
|
||||
|
||||
pairs = _func_arg_regex.findall(text)
|
||||
arg_dct: dict[str, Any] = {}
|
||||
for key, value in pairs: # pyright: ignore[reportAny]
|
||||
arg_key = key.strip() # pyright: ignore[reportAny]
|
||||
arg_val = value.strip() # pyright: ignore[reportAny]
|
||||
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
|
||||
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
|
||||
arg_dct[arg_key] = arg_val
|
||||
return dict(name=func_name, arguments=arg_dct)
|
||||
|
||||
tokenizer._tool_call_start = tool_call_start
|
||||
tokenizer._tool_call_end = tool_call_end
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
|
||||
if (
|
||||
((name := obj.get("name")) is not None)
|
||||
and ((args := obj.get("arguments")) is not None)
|
||||
and isinstance(name, str)
|
||||
):
|
||||
raw_id: object = obj.get("id")
|
||||
extra = {"id": str(raw_id)} if raw_id is not None else {}
|
||||
return ToolCallItem(
|
||||
**extra,
|
||||
name=name,
|
||||
arguments=json.dumps(args),
|
||||
)
|
||||
else:
|
||||
raise ValidationError
|
||||
|
||||
|
||||
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
|
||||
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
|
||||
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
|
||||
|
||||
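The reworked `parse_tool_calls` buffers streamed text between a start marker and an end marker and then hands the buffer to a parser object. The sketch below shows that buffering pattern on plain strings. It is an approximation under assumptions: `SketchToolParser` only mirrors the three attribute names used in the diff (`start_parsing`, `end_parsing`, `parse_tool_calls`), the JSON parser is a stand-in, and the handling of a stream that ends mid-call is simplified.

```python
import json
from collections.abc import Callable, Iterable, Iterator
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchToolParser:
    start_parsing: str
    end_parsing: str
    parse_tool_calls: Callable[[str], Optional[list]]


def split_tool_calls(
    chunks: Iterable[str], parser: SketchToolParser
) -> Iterator[tuple]:
    """Yield ("text", str) or ("tool_calls", list) items from streamed chunks."""
    in_call = False
    parts: list = []
    for text in chunks:
        if not in_call and text.startswith(parser.start_parsing):
            in_call = True
        if in_call:
            parts.append(text)
            if text.endswith(parser.end_parsing):
                buffered = "".join(parts)
                parsed = parser.parse_tool_calls(buffered.strip())
                if parsed is not None:
                    yield ("tool_calls", parsed)
                else:
                    yield ("text", buffered)  # parsing failed: pass the raw text through
                in_call = False
                parts = []
            continue
        yield ("text", text)
    if parts:
        yield ("text", "".join(parts))  # stream ended inside a tool call


def parse_json_call(text: str) -> Optional[list]:
    """Stand-in parser: one JSON object between <tool_call> markers."""
    body = text.removeprefix("<tool_call>").removesuffix("</tool_call>")
    try:
        obj = json.loads(body)
    except json.JSONDecodeError:
        return None
    return [obj] if isinstance(obj, dict) else None
```

Returning `None` on failure, rather than raising, is what lets the caller fall back to emitting the buffered text unchanged.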
@@ -1,6 +1,10 @@
|
||||
import contextlib
|
||||
import os
|
||||
import signal
|
||||
import struct
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from multiprocessing import Process
|
||||
from typing import Self
|
||||
|
||||
@@ -14,12 +18,14 @@ from loguru import logger
|
||||
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
JacclSideChannelData,
|
||||
JacclSideChannelGathered,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnecting,
|
||||
RunnerFailed,
|
||||
@@ -34,6 +40,26 @@ from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.worker.runner.bootstrap import entrypoint


def _pipe_read_exact(fd: int, n: int) -> bytes | None:
    """Read exactly n bytes from a file descriptor. Returns None on EOF."""
    data = b""
    while len(data) < n:
        chunk = os.read(fd, n - len(data))
        if not chunk:
            return None
        data += chunk
    return data


def _pipe_write_all(fd: int, data: bytes) -> None:
    """Write all bytes to a file descriptor."""
    view = memoryview(data)
    while view:
        written = os.write(fd, view)
        view = view[written:]


PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5

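Helpers like `_pipe_read_exact` and `_pipe_write_all` are typically paired with a framing scheme so the reader knows how many bytes to expect. The wire format of the relay is not visible in this part of the diff, so the sketch below assumes a 4-byte little-endian length prefix purely for illustration; `send_frame` and `recv_frame` are hypothetical names.

```python
import os
import struct
from typing import Optional


def read_exact(fd: int, n: int) -> Optional[bytes]:
    """Read exactly n bytes; None if EOF arrives first."""
    data = b""
    while len(data) < n:
        chunk = os.read(fd, n - len(data))
        if not chunk:
            return None
        data += chunk
    return data


def write_all(fd: int, data: bytes) -> None:
    """Keep writing until every byte has been accepted by the descriptor."""
    view = memoryview(data)
    while view:
        written = os.write(fd, view)
        view = view[written:]


def send_frame(fd: int, payload: bytes) -> None:
    # Assumed framing: 4-byte little-endian length, then the payload.
    write_all(fd, struct.pack("<I", len(payload)) + payload)


def recv_frame(fd: int) -> Optional[bytes]:
    header = read_exact(fd, 4)
    if header is None:
        return None  # peer closed the pipe
    (length,) = struct.unpack("<I", header)
    return read_exact(fd, length)
```

A quick check with an anonymous pipe: after `r, w = os.pipe()` and `send_frame(w, b"abc")`, `recv_frame(r)` returns the payload, and once `w` is closed it returns `None`.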
@@ -46,10 +72,21 @@ class RunnerSupervisor:
    initialize_timeout: float
    _ev_recv: MpReceiver[Event]
    _task_sender: MpSender[Task]
    _cancel_sender: MpSender[TaskId]
    _event_sender: Sender[Event]
    _pipe_read_fd: int | None = None  # Python reads runner's pipe output
    _pipe_write_fd: int | None = None  # Python writes gathered data to runner
    _child_pipe_fds: tuple[int, int] | None = None  # fds to close after fork
    _fifo_dir: str | None = None  # Temp dir for FIFO files (for cleanup)
    _fifo_c2p: str | None = None  # FIFO path: C++ writes → Python reads
    _fifo_p2c: str | None = None  # FIFO path: Python writes → C++ reads
    status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
    pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
    completed: set[TaskId] = field(default_factory=set, init=False)
    cancelled: set[TaskId] = field(default_factory=set, init=False)
    _gathered_waiters: dict[
        int, tuple[anyio.Event, JacclSideChannelGathered | None]
    ] = field(default_factory=dict, init=False)

    @classmethod
    def create(
@@ -60,8 +97,25 @@ class RunnerSupervisor:
|
||||
initialize_timeout: float = 400,
|
||||
) -> Self:
|
||||
ev_send, ev_recv = mp_channel[Event]()
|
||||
# A task is kind of a runner command
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||
|
||||
# For MlxJaccl instances, create named pipes (FIFOs) for SideChannel relay.
|
||||
# Named pipes work across multiprocessing.Process spawn (macOS default).
|
||||
# FIFO c2p: C++ writes local data → Python reads it
|
||||
# FIFO p2c: Python writes gathered data → C++ reads it
|
||||
fifo_dir: str | None = None
|
||||
fifo_c2p: str | None = None
|
||||
fifo_p2c: str | None = None
|
||||
pipe_fifo_paths: tuple[str, str] | None = None
|
||||
|
||||
if isinstance(bound_instance.instance, MlxJacclInstance):
|
||||
fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_")
|
||||
fifo_c2p = os.path.join(fifo_dir, "c2p") # C++ → Python
|
||||
fifo_p2c = os.path.join(fifo_dir, "p2c") # Python → C++
|
||||
os.mkfifo(fifo_c2p)
|
||||
os.mkfifo(fifo_p2c)
|
||||
pipe_fifo_paths = (fifo_c2p, fifo_p2c)
|
||||
|
||||
runner_process = Process(
|
||||
target=entrypoint,
|
||||
@@ -69,7 +123,9 @@ class RunnerSupervisor:
|
||||
bound_instance,
|
||||
ev_send,
|
||||
task_recv,
|
||||
cancel_recv,
|
||||
logger,
|
||||
pipe_fifo_paths,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
@@ -83,20 +139,56 @@ class RunnerSupervisor:
|
||||
initialize_timeout=initialize_timeout,
|
||||
_ev_recv=ev_recv,
|
||||
_task_sender=task_sender,
|
||||
_cancel_sender=cancel_sender,
|
||||
_event_sender=event_sender,
|
||||
_fifo_dir=fifo_dir,
|
||||
_fifo_c2p=fifo_c2p,
|
||||
_fifo_p2c=fifo_p2c,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
await self._forward_events()
|
||||
|
||||
if self._fifo_c2p is not None and self._fifo_p2c is not None:
|
||||
# Open FIFOs from parent side. These block until child opens the other end,
|
||||
# so we run them in threads concurrently to avoid deadlock.
|
||||
fifo_c2p = self._fifo_c2p
|
||||
fifo_p2c = self._fifo_p2c
|
||||
|
||||
async def open_read() -> None:
|
||||
self._pipe_read_fd = await to_thread.run_sync(
|
||||
partial(os.open, fifo_c2p, os.O_RDONLY)
|
||||
)
|
||||
|
||||
async def open_write() -> None:
|
||||
self._pipe_write_fd = await to_thread.run_sync(
|
||||
partial(os.open, fifo_p2c, os.O_WRONLY)
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as open_tg:
|
||||
open_tg.start_soon(open_read)
|
||||
open_tg.start_soon(open_write)
|
||||
|
||||
logger.info(
|
||||
f"JACCL pipe relay: FIFOs opened (read_fd={self._pipe_read_fd}, write_fd={self._pipe_write_fd})"
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(self._pipe_relay)
|
||||
tg.start_soon(self._forward_events)
|
||||
else:
|
||||
await self._forward_events()
|
||||
|
||||
def shutdown(self):
|
||||
logger.info("Runner supervisor shutting down")
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
self._event_sender.close()
|
||||
self._close_pipe_fds()
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process successfully terminated")
|
||||
@@ -112,14 +204,6 @@ class RunnerSupervisor:
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.pending:
|
||||
logger.warning(
|
||||
@@ -141,6 +225,18 @@ class RunnerSupervisor:
|
||||
return
|
||||
await event.wait()
|
||||
|
||||
async def cancel_task(self, task_id: TaskId):
|
||||
"""Send a cancellation signal to the runner process."""
|
||||
if task_id in self.completed:
|
||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||
return
|
||||
self.cancelled.add(task_id)
|
||||
with anyio.move_on_after(0.5) as scope:
|
||||
await self._cancel_sender.send_async(task_id)
|
||||
if scope.cancel_called:
|
||||
logger.error("RunnerSupervisor cancel pipe blocked")
|
||||
await self._check_runner(TimeoutError("cancel pipe blocked"))
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
try:
|
||||
@@ -172,6 +268,110 @@ class RunnerSupervisor:
|
||||
for tid in self.pending:
|
||||
self.pending[tid].set()
|
||||
|
||||
def _close_pipe_fds(self) -> None:
|
||||
if self._pipe_read_fd is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(self._pipe_read_fd)
|
||||
self._pipe_read_fd = None
|
||||
if self._pipe_write_fd is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(self._pipe_write_fd)
|
||||
self._pipe_write_fd = None
|
||||
if self._child_pipe_fds is not None:
|
||||
for fd in self._child_pipe_fds:
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(fd)
|
||||
self._child_pipe_fds = None
|
||||
# Clean up FIFO files
|
||||
if self._fifo_c2p is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(self._fifo_c2p)
|
||||
self._fifo_c2p = None
|
||||
if self._fifo_p2c is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(self._fifo_p2c)
|
||||
self._fifo_p2c = None
|
||||
if self._fifo_dir is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.rmdir(self._fifo_dir)
|
||||
self._fifo_dir = None
|
||||
|
||||
async def _pipe_relay(self) -> None:
|
||||
"""Relay JACCL SideChannel all_gather rounds between runner pipes and exo events."""
|
||||
assert self._pipe_read_fd is not None
|
||||
assert self._pipe_write_fd is not None
|
||||
read_fd = self._pipe_read_fd
|
||||
write_fd = self._pipe_write_fd
|
||||
sequence = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
# 1. Read local data from runner: [uint32 size][size bytes]
|
||||
header = await to_thread.run_sync(partial(_pipe_read_exact, read_fd, 4))
|
||||
if header is None:
|
||||
logger.info("JACCL pipe relay: runner closed pipe (EOF)")
|
||||
break
|
||||
data_size: int = struct.unpack("<I", header)[0] # pyright: ignore[reportAny]
|
||||
local_data = await to_thread.run_sync(
|
||||
partial(_pipe_read_exact, read_fd, data_size)
|
||||
)
|
||||
if local_data is None:
|
||||
logger.warning("JACCL pipe relay: EOF reading data payload")
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"JACCL pipe relay: read {data_size} bytes from runner, seq={sequence}"
|
||||
)
|
||||
|
||||
# 2. Emit JacclSideChannelData event
|
||||
waiter = anyio.Event()
|
||||
self._gathered_waiters[sequence] = (waiter, None)
|
||||
await self._event_sender.send(
|
||||
JacclSideChannelData(
|
||||
instance_id=self.bound_instance.instance.instance_id,
|
||||
runner_id=self.bound_instance.bound_runner_id,
|
||||
sequence=sequence,
|
||||
data=local_data,
|
||||
)
|
||||
)
|
||||
|
||||
# 3. Wait for gathered result
|
||||
await waiter.wait()
|
||||
_, gathered_event = self._gathered_waiters.pop(sequence)
|
||||
assert gathered_event is not None
|
||||
|
||||
# 4. Order gathered data by runner rank and concatenate
|
||||
instance = self.bound_instance.instance
|
||||
assert isinstance(instance, MlxJacclInstance)
|
||||
runner_order = list(instance.shard_assignments.runner_to_shard.keys())
|
||||
ordered_data = b"".join(
|
||||
gathered_event.gathered_data[rid] for rid in runner_order
|
||||
)
|
||||
|
||||
# 5. Write gathered data to runner: [uint32 total_size][total_size bytes]
|
||||
total_size = len(ordered_data)
|
||||
response = struct.pack("<I", total_size) + ordered_data
|
||||
await to_thread.run_sync(partial(_pipe_write_all, write_fd, response))
|
||||
|
||||
logger.info(
|
||||
f"JACCL pipe relay: wrote {total_size} bytes to runner, seq={sequence}"
|
||||
)
|
||||
sequence += 1
|
||||
except OSError as e:
|
||||
logger.warning(f"JACCL pipe relay: OS error: {e}")
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).error("JACCL pipe relay: unexpected error")
|
||||
|
||||
def notify_gathered(self, event: JacclSideChannelGathered) -> None:
|
||||
"""Called by the worker when a JacclSideChannelGathered event arrives."""
|
||||
seq = event.sequence
|
||||
if seq not in self._gathered_waiters:
|
||||
logger.warning(f"JACCL: received gathered event for unknown sequence {seq}")
|
||||
return
|
||||
waiter, _ = self._gathered_waiters[seq]
|
||||
self._gathered_waiters[seq] = (waiter, event)
|
||||
waiter.set()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.runner_process.is_alive():
|
||||
logger.warning("RunnerSupervisor was not stopped cleanly.")
|
||||
|
||||
72
src/exo/worker/runner/tool_parsers.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
from exo.shared.types.api import ToolCallItem
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolParser:
|
||||
start_parsing: str
|
||||
end_parsing: str
|
||||
parse_tool_calls: Callable[[str], list[ToolCallItem] | None]
|
||||
|
||||
|
||||
def make_mlx_parser(
|
||||
tool_call_start: str,
|
||||
tool_call_end: str,
|
||||
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
|
||||
) -> ToolParser:
|
||||
def parse_tool_calls(text: str) -> list[ToolCallItem] | None:
|
||||
try:
|
||||
text = text.removeprefix(tool_call_start)
|
||||
text = text.removesuffix(tool_call_end)
|
||||
parsed = tool_parser(text)
|
||||
if isinstance(parsed, list):
|
||||
return [ToolCallItem.model_validate(_flatten(p)) for p in parsed]
|
||||
else:
|
||||
return [ToolCallItem.model_validate(_flatten(parsed))]
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return ToolParser(
|
||||
start_parsing=tool_call_start,
|
||||
end_parsing=tool_call_end,
|
||||
parse_tool_calls=parse_tool_calls,
|
||||
)
|
||||
|
||||
|
||||
# TODO / example code:
|
||||
def _parse_json_calls(text: str) -> list[ToolCallItem] | None:
|
||||
try:
|
||||
text = text.removeprefix("<tool_call>")
|
||||
text = text.removesuffix("</tool_call>")
|
||||
top_level = {
|
||||
k: json.dumps(v) if isinstance(v, (dict, list)) else v
|
||||
for k, v in json.loads(text).items() # pyright: ignore[reportAny]
|
||||
}
|
||||
return [ToolCallItem.model_validate(top_level)]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _flatten(p: dict[str, Any]) -> dict[str, str]:
|
||||
return {
|
||||
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) # pyright: ignore[reportAny]
|
||||
for k, v in p.items() # pyright: ignore[reportAny]
|
||||
}
|
||||
|
||||
|
||||
json_tool_parser = ToolParser(
|
||||
start_parsing="<tool_call>",
|
||||
end_parsing="</tool_call>",
|
||||
parse_tool_calls=_parse_json_calls,
|
||||
)
|
||||
|
||||
|
||||
def infer_tool_parser(chat_template: str) -> ToolParser | None:
|
||||
"""Attempt to auto-infer a tool parser from the chat template."""
|
||||
if "<tool_call>" in chat_template and "tool_call.name" in chat_template:
|
||||
return json_tool_parser
|
||||
return None
|
||||
@@ -19,6 +19,7 @@ from exo.shared.types.tasks import (
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
TextGeneration,
|
||||
)
|
||||
@@ -113,6 +114,13 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
|
||||
|
||||
# Mock mx.distributed.all_gather so MockGroup doesn't hit real MLX C++ bindings.
|
||||
def _mock_all_gather(x: object, **_kw: object) -> object:
|
||||
return x
|
||||
|
||||
monkeypatch.setattr(mlx_runner.mx.distributed, "all_gather", _mock_all_gather)
|
||||
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
||||
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
|
||||
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
|
||||
@@ -163,6 +171,7 @@ def _run(tasks: Iterable[Task]):
|
||||
)
|
||||
|
||||
task_sender, task_receiver = mp_channel[Task]()
|
||||
_cancel_sender, cancel_receiver = mp_channel[TaskId]()
|
||||
event_sender = EventCollector()
|
||||
|
||||
with task_sender:
|
||||
@@ -173,8 +182,15 @@ def _run(tasks: Iterable[Task]):
|
||||
# this is some c++ nonsense
|
||||
task_receiver.close = nothin
|
||||
task_receiver.join = nothin
|
||||
cancel_receiver.close = nothin
|
||||
cancel_receiver.join = nothin
|
||||
|
||||
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]
|
||||
mlx_runner.main(
|
||||
bound_instance,
|
||||
event_sender, # pyright: ignore[reportArgumentType]
|
||||
task_receiver,
|
||||
cancel_receiver,
|
||||
)
|
||||
|
||||
return event_sender.events
|
||||
|
||||
|
||||
@@ -5,12 +5,13 @@ from typing import Any
|
||||
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
|
||||
from exo.worker.runner.runner import parse_tool_calls
|
||||
from exo.worker.runner.tool_parsers import make_mlx_parser
|
||||
|
||||
|
||||
def _make_responses(
|
||||
texts: list[str],
|
||||
finish_on_last: bool = True,
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""Create a sequence of GenerationResponses from text strings."""
|
||||
for i, text in enumerate(texts):
|
||||
is_last = i == len(texts) - 1
|
||||
@@ -22,10 +23,13 @@ def _make_responses(
|
||||
)
|
||||
|
||||
|
||||
def _dummy_parser(text: str) -> dict[str, Any]:
|
||||
def _dummier_parser(text: str) -> dict[str, Any]:
|
||||
return {"name": "test_fn", "arguments": {"arg": text}}
|
||||
|
||||
|
||||
_dummy_parser = make_mlx_parser("<tool_call>", "</tool_call>", _dummier_parser)
|
||||
|
||||
|
||||
class TestParseToolCalls:
|
||||
"""Tests for parse_tool_calls generator."""
|
||||
|
||||
@@ -35,8 +39,6 @@ class TestParseToolCalls:
|
||||
results = list(
|
||||
parse_tool_calls(
|
||||
_make_responses(texts, finish_on_last=False),
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
_dummy_parser,
|
||||
)
|
||||
)
|
||||
@@ -50,8 +52,6 @@ class TestParseToolCalls:
|
||||
results = list(
|
||||
parse_tool_calls(
|
||||
_make_responses(texts),
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
_dummy_parser,
|
||||
)
|
||||
)
|
||||
@@ -76,9 +76,7 @@ class TestParseToolCalls:
|
||||
results = list(
|
||||
parse_tool_calls(
|
||||
_make_responses(texts, finish_on_last=False),
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
_failing_parser,
|
||||
make_mlx_parser("<tool_call>", "</tool_call>", _failing_parser),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
40
uv.lock
generated
@@ -377,8 +377,8 @@ dependencies = [
|
||||
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -416,7 +416,7 @@ requires-dist = [
|
||||
{ name = "hypercorn", specifier = ">=0.18.0" },
|
||||
{ name = "loguru", specifier = ">=0.7.3" },
|
||||
{ name = "mflux", specifier = "==0.15.5" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks" },
|
||||
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
|
||||
{ name = "mlx-lm", specifier = "==0.30.6" },
|
||||
{ name = "msgspec", specifier = ">=0.19.0" },
|
||||
@@ -1020,8 +1020,8 @@ dependencies = [
|
||||
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1048,18 +1048,12 @@ wheels = [
|
||||
name = "mlx"
|
||||
version = "0.30.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
|
||||
resolution-markers = [
|
||||
"sys_platform == 'linux'",
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ae/5b/e460e144a34d5529e010056cccf50b538d56ed001473bc6b246018fd58cb/mlx-0.30.6-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:ed86f8bffc174c2f259ca589ea25464c96cf69d1bb457074a2bf2ef53737e54f", size = 573515, upload-time = "2026-02-06T03:45:23.405Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/60/25/69833fefb9a3fef30b56792b1bcd022496c4fea83e45411d289b77ef7546/mlx-0.30.6-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:c52294958269e20f300639a17c1900ca8fc737d859ddda737f9811e94bd040e5", size = 573516, upload-time = "2026-02-06T03:45:24.618Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/6a/7e7fbeebc5cb51b6a5eba96b263a6298707bcbdc059f4b0b73e088bc3dea/mlx-0.30.6-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:b5b6636f7c49a4d86d8ec82643b972f45a144a7a9f3a967b27b2e6e22cf71e6a", size = 573592, upload-time = "2026-02-06T03:45:25.928Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/93/06/280f6f2ba80520a7109730425eda0d966658793aa0d02d8be8d351f75253/mlx-0.30.6-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:67e6c9e30a9faeacc209917ef5523177cf9b086914b6b5d83ff886e4294b727d", size = 622011, upload-time = "2026-02-06T03:45:28.165Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fe/35/f872afbee9c079cc69924d9e9c46f5663adb7da58cba3511db082dd307c1/mlx-0.30.6-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:47db8b16fcb6f6c5a47c0bdb24ed377b41237017ac93aa6cb6aa206c9bdf82e4", size = 663650, upload-time = "2026-02-06T03:45:30.315Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/60/23/361dc7a5797634e4d7e9bdd6564c6b28f9b1246672632def2f91bf066b18/mlx-0.30.6-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:78804a89dcff4a838f7c2da72392fe87a523e95122a3c840e53df019122aad45", size = 575028, upload-time = "2026-02-06T03:45:31.549Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/69/1854484d414171586814dfbe8def95f75c4ea2c7341ba13ba8ee675f7c62/mlx-0.30.6-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:ec13584ab069665cc7ad34a05494d9291cd623aef6ae96be48875fc87cfc25d6", size = 575026, upload-time = "2026-02-06T03:45:33.072Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6b/b8/3adbc441924209a7e4c568308b2a0b54bd09aee6a68db5bae85304791e54/mlx-0.30.6-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:b2c5e8a090a753ef99a1380a4d059c983083f36198864f6df9faaf1223d083df", size = 575041, upload-time = "2026-02-06T03:45:34.814Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/54/9d9e06804fb2088202a2cdf60458e00b221f71420bea285720b60f9e82b5/mlx-0.30.6-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:9ceddede4af0de31d1f6b3099f70e5469d60cd7c546975dedbdbeab3519cab3f", size = 624002, upload-time = "2026-02-06T03:45:36Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/92/3140a15a50cb1f9267a6552171e1dfa577861de53e093124bc43707f2a0e/mlx-0.30.6-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:4a6ffd2d16728cf95f63a1b555d7c2eaeea686a0e6b73228bd265411cb5d77a4", size = 663569, upload-time = "2026-02-06T03:45:37.242Z" },
|
||||
]
|
||||
@@ -1072,6 +1066,14 @@ cuda13 = [
|
||||
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx"
|
||||
version = "0.30.7.dev20260217+50487b41"
|
||||
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }
|
||||
resolution-markers = [
|
||||
"sys_platform == 'darwin'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-cpu"
|
||||
version = "0.30.6"
|
||||
@@ -1102,7 +1104,7 @@ version = "0.30.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1114,16 +1116,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-metal"
|
||||
version = "0.30.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/85/44406b521f920248fad621334d4dc15e77660a494edf890e7cbee33bf38d/mlx_metal-0.30.6-py3-none-macosx_14_0_arm64.whl", hash = "sha256:ea6d0c973def9a5b4f652cc77036237db3f88c9d0af63701d76b5fddde99b820", size = 38437818, upload-time = "2026-02-06T03:44:56.19Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/cb/10a516995f7d0c154b0d7e633c54b51e96977a86a355105b6474cfcbe0d0/mlx_metal-0.30.6-py3-none-macosx_15_0_arm64.whl", hash = "sha256:0f8cb94634d07e06a372d6ad9a090f38a18bab1ff19a140aede60eacf707bb94", size = 38433701, upload-time = "2026-02-06T03:44:59.678Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4c/7d/70cb272f7373c334709f210ed8420511fc9d64d05a7a646c0b3b94c29c04/mlx_metal-0.30.6-py3-none-macosx_26_0_arm64.whl", hash = "sha256:d761ae26304f2c4b454eeea7f612a56919d9e5e57dbb1dc0788f8e34aa6f41c2", size = 47718448, upload-time = "2026-02-06T03:45:03.133Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "more-itertools"
|
||||
version = "10.8.0"
|
||||
|
||||