Mirror of https://github.com/exo-explore/exo.git
Synced 2026-02-17 09:33:15 -05:00

Compare commits: 12 commits, author alexcheema, branch splitting-...
| Author | SHA1 | Date |
|---|---|---|
|  | d7e3d4e693 |  |
|  | 6d1ca6689b |  |
|  | c01b6fff21 |  |
|  | 8392e78afe |  |
|  | 86735ece78 |  |
|  | 2759e92334 |  |
|  | 131fb141a6 |  |
|  | 2d8bfc2e3c |  |
|  | 042999f728 |  |
|  | b61dc2eb35 |  |
|  | 36a7115b6f |  |
|  | 0b7d88b43b |  |
.github/workflows/pipeline.yml (vendored): 27 changed lines
@@ -8,33 +8,6 @@ on:
      - main

jobs:
  typecheck:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          lfs: false

      - uses: cachix/install-nix-action@v31
        with:
          nix_path: nixpkgs=channel:nixos-unstable

      - uses: cachix/cachix-action@v14
        name: Configure Cachix
        with:
          name: exo
          authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"

      - name: Load nix develop environment
        run: nix run github:nicknovitski/nix-develop/v1

      - name: Sync dependencies
        run: uv sync --all-packages

      - name: Run type checker
        run: uv run basedpyright --project pyproject.toml

  nix:
    name: Build and check (${{ matrix.system }})
    runs-on: ${{ matrix.runner }}

@@ -276,23 +276,24 @@ class BatchGenerator:
        logprobs: mx.array
        finish_reason: Optional[str]

    unprocessed_prompts: List[Any]

    def __init__(
        self,
        model: nn.Module,
        model,
        max_tokens: int = ...,
        stop_tokens: Optional[set[int]] = ...,
        stop_tokens: Optional[set] = ...,
        sampler: Optional[Callable[[mx.array], mx.array]] = ...,
        completion_batch_size: int = ...,
        prefill_batch_size: int = ...,
        prefill_step_size: int = ...,
    ) -> None: ...
    def insert(
        self, prompts: List[List[int]], max_tokens: Union[List[int], int, None] = ...
    ) -> List[int]: ...
    def stats(self) -> BatchStats: ...
    def next(self) -> List[Response]: ...
        self, prompts, max_tokens: Union[List[int], int, None] = ...
    ):  # -> list[Any]:
        ...
    def stats(self):  # -> BatchStats:
        ...
    def next(self):  # -> list[Any]:
        ...


def batch_generate(
    model,

@@ -39,11 +39,11 @@ class StreamingDetokenizer:
    """

    __slots__ = ...
    def reset(self) -> None: ...
    def add_token(self, token: int) -> None: ...
    def finalize(self) -> None: ...
    def reset(self): ...
    def add_token(self, token): ...
    def finalize(self): ...
    @property
    def last_segment(self) -> str:
    def last_segment(self):
        """Return the last segment of readable text since last time this property was accessed."""

class NaiveStreamingDetokenizer(StreamingDetokenizer):

AGENTS.md: 39 changed lines
@@ -116,49 +116,10 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable

## Model Storage

Downloaded models are stored in `~/.exo/models/` (not the standard HuggingFace cache location).

## Creating Model Instances via API

When testing with the API, you must first create a model instance before sending chat completions:

```bash
# 1. Get instance previews for a model
curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b"

# 2. Create an instance from the first valid preview
INSTANCE=$(curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | jq -c '.previews[] | select(.error == null) | .instance' | head -n1)
curl -X POST http://localhost:52415/instance -H 'Content-Type: application/json' -d "{\"instance\": $INSTANCE}"

# 3. Wait for the runner to become ready (check logs for "runner ready")

# 4. Send chat completions using the full model ID
curl -X POST http://localhost:52415/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "mlx-community/Llama-3.2-1B-Instruct-4bit", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}'
```

## Logs

Exo logs are stored in `~/.exo/exo.log`. This is useful for debugging runner crashes and distributed issues.

## Testing

Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.

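With `asyncio_mode = "auto"`, async test functions need no `@pytest.mark.asyncio` decorator. As a minimal illustration (a hypothetical test, not taken from the repo; only the `EXO_TESTS=1` behaviour is stated above):

```python
import os


async def test_exo_tests_env_var_is_set():
    # EXO_TESTS=1 is exported by the test harness, so code under test can
    # detect that it is running inside pytest.
    assert os.environ.get("EXO_TESTS") == "1"
```
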
### Distributed Testing

When running distributed tests across multiple machines, use `EXO_LIBP2P_NAMESPACE` to isolate your test cluster from other exo instances on the same network:

```bash
# On each machine in the test cluster, use the same unique namespace
EXO_LIBP2P_NAMESPACE=my-test-cluster uv run exo
```

This prevents your test cluster from discovering and interfering with production or other developers' exo clusters.

## Dashboard UI Testing & Screenshots

### Building and Running the Dashboard

@@ -5,21 +5,21 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in generate.py mlx_generate at the end.
[X] no mx_barrier in generate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed: "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all", which became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed: "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all", which became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if the node id is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections, I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does, but we can't upgrade as it breaks other things.)
[X] topology.py remove_node removes the connections after checking if the node id is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections, I guess?
[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does, but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only catches ValueError. This was silently failing, leading to un-debuggable errors (we had a KeyError that was happening). Changed this to catch Exception instead of ValueError (see the sketch after this list). See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor or checks whether this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).

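The _command_processor change above is easier to see in code. A minimal sketch of the broadened handler (names and structure are assumptions; the real processor in exo dispatches many command types):

```python
import logging

logger = logging.getLogger(__name__)


async def _command_processor(commands, handle):
    """Process commands forever; one bad command must not kill the loop."""
    async for command in commands:
        try:
            await handle(command)
        except Exception:
            # Previously only ValueError was caught, so a KeyError (or any other
            # exception) escaped and silently killed command processing.
            logger.exception("command processing failed: %r", command)
```
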
bench/bench.toml (new file): 7 lines
@@ -0,0 +1,7 @@
# Canary benchmark manifest
#
# Lists the suite files to include. Each file defines benchmarks
# with shared constraints, topology, and default args.
include = [
  "single-m3-ultra.toml",
]

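How exo-bench consumes this manifest is not shown in this diff. As a rough sketch, resolving the `include` list relative to the manifest might look like the following (the function name and path handling are assumptions, not the repo's actual loader):

```python
import tomllib
from pathlib import Path
from typing import Any


def load_manifest(path: Path) -> list[dict[str, Any]]:
    """Load bench.toml and return the parsed suite files it includes."""
    manifest = tomllib.loads(path.read_text())
    suites: list[dict[str, Any]] = []
    for name in manifest.get("include", []):
        # Included suite files are resolved relative to the manifest itself.
        suite_path = path.parent / name
        suites.append(tomllib.loads(suite_path.read_text()))
    return suites
```
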
@@ -288,6 +288,151 @@ def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]
    raise ValueError(f"Model not found in /models: {model_arg}")


def run_planning_phase(
    client: ExoClient,
    full_model_id: str,
    preview: dict[str, Any],
    danger_delete: bool,
    timeout: float,
    settle_deadline: float | None,
) -> None:
    """Check disk space and ensure model is downloaded before benchmarking."""
    # Get model size from /models
    models = client.request_json("GET", "/models") or {}
    model_bytes = 0
    for m in models.get("data", []):
        if m.get("hugging_face_id") == full_model_id:
            model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
            break

    if not model_bytes:
        logger.warning(
            f"Could not determine size for {full_model_id}, skipping disk check"
        )
        return

    # Get nodes from preview
    inner = unwrap_instance(preview["instance"])
    node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
    runner_to_shard = inner["shardAssignments"]["runnerToShard"]

    state = client.request_json("GET", "/state")
    downloads = state.get("downloads", {})
    node_disk = state.get("nodeDisk", {})

    for node_id in node_ids:
        node_downloads = downloads.get(node_id, [])

        # Check if model already downloaded on this node
        already_downloaded = any(
            "DownloadCompleted" in p
            and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
                "modelId"
            ]
            == full_model_id
            for p in node_downloads
        )
        if already_downloaded:
            continue

        # Wait for disk info if settle_deadline is set
        disk_info = node_disk.get(node_id, {})
        backoff = _SETTLE_INITIAL_BACKOFF_S
        while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
            remaining = settle_deadline - time.monotonic()
            logger.info(
                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
            )
            time.sleep(min(backoff, remaining))
            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
            state = client.request_json("GET", "/state")
            node_disk = state.get("nodeDisk", {})
            disk_info = node_disk.get(node_id, {})

        if not disk_info:
            logger.warning(f"No disk info for {node_id}, skipping space check")
            continue

        avail = disk_info.get("available", {}).get("inBytes", 0)
        if avail >= model_bytes:
            continue

        if not danger_delete:
            raise RuntimeError(
                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
            )

        # Delete from smallest to largest
        completed = [
            (
                unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
                    "modelId"
                ],
                p["DownloadCompleted"]["totalBytes"]["inBytes"],
            )
            for p in node_downloads
            if "DownloadCompleted" in p
        ]
        for del_model, size in sorted(completed, key=lambda x: x[1]):
            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
            avail += size
            if avail >= model_bytes:
                break

        if avail < model_bytes:
            raise RuntimeError(f"Could not free enough space on {node_id}")

    # Start downloads (idempotent)
    for node_id in node_ids:
        runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
        shard = runner_to_shard[runner_id]
        client.request_json(
            "POST",
            "/download/start",
            body={
                "targetNodeId": node_id,
                "shardMetadata": shard,
            },
        )
        logger.info(f"Started download on {node_id}")

    # Wait for downloads
    start = time.time()
    while time.time() - start < timeout:
        state = client.request_json("GET", "/state")
        downloads = state.get("downloads", {})
        all_done = True
        for node_id in node_ids:
            done = any(
                "DownloadCompleted" in p
                and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
                    "modelCard"
                ]["modelId"]
                == full_model_id
                for p in downloads.get(node_id, [])
            )
            failed = [
                p["DownloadFailed"]["errorMessage"]
                for p in downloads.get(node_id, [])
                if "DownloadFailed" in p
                and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
                    "modelId"
                ]
                == full_model_id
            ]
            if failed:
                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
            if not done:
                all_done = False
        if all_done:
            return
        time.sleep(1)

    raise TimeoutError("Downloads did not complete in time")


def placement_filter(instance_meta: str, wanted: str) -> bool:
    s = (instance_meta or "").lower()
    if wanted == "both":

@@ -535,6 +680,11 @@ def main() -> int:
        default=0,
        help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
    )
    ap.add_argument(
        "--danger-delete-downloads",
        action="store_true",
        help="Delete existing models from smallest to largest to make room for benchmark model.",
    )
    args = ap.parse_args()

    pp_list = parse_int_list(args.pp)

@@ -569,13 +719,16 @@ def main() -> int:
        logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
        raise

    settle_deadline = (
        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
    )

    selected = fetch_and_filter_placements(client, full_model_id, args)

    if not selected and args.settle_timeout > 0:
    if not selected and settle_deadline:
        backoff = _SETTLE_INITIAL_BACKOFF_S
        deadline = time.monotonic() + args.settle_timeout
        while not selected and time.monotonic() < deadline:
            remaining = deadline - time.monotonic()
        while not selected and time.monotonic() < settle_deadline:
            remaining = settle_deadline - time.monotonic()
            logger.warning(
                f"No valid placements yet (cluster may still be settling). "
                f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."

@@ -607,6 +760,16 @@ def main() -> int:
    if args.dry_run:
        return 0

    logger.info("Planning phase: checking downloads...")
    run_planning_phase(
        client,
        full_model_id,
        selected[0],
        args.danger_delete_downloads,
        args.timeout,
        settle_deadline,
    )

    all_rows: list[dict[str, Any]] = []

    for preview in selected:

bench/single-m3-ultra.toml (new file): 189 lines
@@ -0,0 +1,189 @@
# Single-node M3 Ultra benchmarks
#
# Shared constraints applied to ALL benchmarks in this file.
constraints = [
  "All(MacOsBuild(=25D125))",
  "Hosts(=1)",
  "All(Chip(m3_ultra))",
  "All(GpuCores(=80))",
]

[topology]
type = "none"

# Default args merged into each benchmark's args (benchmark-level args win).
[defaults]
pp = [512, 2048, 8192, 16384]
tg = 128

[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-0.6B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-0.6B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/llama-3.3-70b-instruct-fp16"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.5-Air-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.5-Air-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/MiniMax-M2.1-3bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/MiniMax-M2.1-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/Step-3.5-Flash-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/Step-3.5-Flash-6bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/Step-3.5-Flash-8Bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-6bit"
extra_constraints = ["All(Memory(>=512GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-8bit-gs32"
extra_constraints = ["All(Memory(>=512GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
extra_constraints = ["All(Memory(>=512GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]

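The `[defaults]` table above is merged into each benchmark's own args, with benchmark-level values taking precedence. A minimal sketch of that merge rule (hypothetical helper; the actual merge code is not part of this diff):

```python
from typing import Any


def merge_args(defaults: dict[str, Any], benchmark_args: dict[str, Any]) -> dict[str, Any]:
    """Suite defaults first, then benchmark-level args so they win on conflicts."""
    return {**defaults, **benchmark_args}


# e.g. {"pp": [512, 2048, 8192, 16384], "tg": 128} merged with {"tg": 64}
# yields {"pp": [512, 2048, 8192, 16384], "tg": 64}.
```
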
@@ -1 +0,0 @@
collect_ignore = ["tests/start_distributed_test.py"]

@@ -265,6 +265,7 @@

  function handleSubmit() {
    if ((!message.trim() && uploadedFiles.length === 0) || loading) return;
    if (isEditOnlyWithoutImage) return;

    const content = message.trim();
    const files = [...uploadedFiles];

@@ -289,7 +290,11 @@
      if (imageFile.preview) {
        editImage(content, imageFile.preview);
      }
    } else if (isImageModel() && content) {
    } else if (
      currentModel &&
      modelSupportsTextToImage(currentModel) &&
      content
    ) {
      // Use image generation for text-to-image models
      generateImage(content);
    } else {

@@ -225,6 +225,7 @@
  }

  function handleDeleteClick(messageId: string) {
    if (loading) return;
    deleteConfirmId = messageId;
  }

@@ -255,7 +256,7 @@
</script>

<div class="flex flex-col gap-4 sm:gap-6 {className}">
  {#each messageList as message (message.id)}
  {#each messageList as message, i (message.id)}
    <div
      class="group flex {message.role === 'user'
        ? 'justify-end'

@@ -317,9 +318,11 @@
          <!-- Delete confirmation -->
          <div class="bg-red-500/10 border border-red-500/30 rounded-lg p-3">
            <p class="text-xs text-red-400 mb-3">
              Delete this message{message.role === "user"
                ? " and all responses after it"
                : ""}?
              {#if i === messageList.length - 1}
                Delete this message?
              {:else}
                Delete this message and all messages after it?
              {/if}
            </p>
            <div class="flex gap-2 justify-end">
              <button

@@ -751,8 +754,13 @@
          <!-- Delete button -->
          <button
            onclick={() => handleDeleteClick(message.id)}
            class="p-1.5 text-exo-light-gray hover:text-red-400 transition-colors rounded hover:bg-red-500/10 cursor-pointer"
            title="Delete message"
            disabled={loading}
            class="p-1.5 transition-colors rounded {loading
              ? 'text-exo-light-gray/30 cursor-not-allowed'
              : 'text-exo-light-gray hover:text-red-400 hover:bg-red-500/10 cursor-pointer'}"
            title={loading
              ? "Cannot delete while generating"
              : "Delete message"}
          >
            <svg
              class="w-3.5 h-3.5"

@@ -14,7 +14,9 @@

      # Override overlay to inject Nix-built components
      exoOverlay = final: prev: {
        # Replace workspace exo_pyo3_bindings with Nix-built wheel
        # Replace workspace exo_pyo3_bindings with Nix-built wheel.
        # Preserve passthru so mkVirtualEnv can resolve dependency groups.
        # Copy .pyi stub + py.typed marker so basedpyright can find the types.
        exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
          pname = "exo-pyo3-bindings";
          version = "0.1.0";

@@ -22,6 +24,12 @@
          # Install from pre-built wheel
          nativeBuildInputs = [ final.pyprojectWheelHook ];
          dontStrip = true;
          passthru = prev.exo-pyo3-bindings.passthru or { };
          postInstall = ''
            local siteDir=$out/${final.python.sitePackages}/exo_pyo3_bindings
            cp ${inputs.self}/rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi $siteDir/
            touch $siteDir/py.typed
          '';
        };
      };

@@ -29,17 +37,32 @@

      # Overlay to provide build systems and custom packages
      buildSystemsOverlay = final: prev: {
        # Use our pure Nix-built MLX with Metal support
        mlx = self'.packages.mlx;

        # mlx-lm is a git dependency that needs setuptools
        mlx-lm = prev.mlx-lm.overrideAttrs (old: {
          nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
            final.setuptools
          ];
        });
      } // lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
        # Use our pure Nix-built MLX with Metal support (macOS only)
        mlx = self'.packages.mlx;
      };

      # Additional overlay for Linux-specific fixes (type checking env).
      # Native wheels have shared lib dependencies we don't need at type-check time.
      linuxOverlay = final: prev:
        let
          ignoreMissing = drv: drv.overrideAttrs { autoPatchelfIgnoreMissingDeps = [ "*" ]; };
          nvidiaPackages = lib.filterAttrs (name: _: lib.hasPrefix "nvidia-" name) prev;
        in
        lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
          (lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
            mlx = ignoreMissing prev.mlx;
            torch = ignoreMissing prev.torch;
            triton = ignoreMissing prev.triton;
          }
        );

      pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
        inherit python;
      }).overrideScope (

@@ -48,6 +71,7 @@
          overlay
          exoOverlay
          buildSystemsOverlay
          linuxOverlay
        ]
      );
      exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;

@@ -118,6 +142,21 @@
        ${pkgs.ruff}/bin/ruff check ${inputs.self}
        touch $out
      '';

      # Hermetic basedpyright type checking
      typecheck = pkgs.runCommand "typecheck"
        {
          nativeBuildInputs = [
            testVenv
            pkgs.basedpyright
          ];
        }
        ''
          cd ${inputs.self}
          export HOME=$TMPDIR
          basedpyright --pythonpath ${testVenv}/bin/python
          touch $out
        '';
    };
  };
}

@@ -6,7 +6,7 @@ use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
    future::Future,
    pin::{Pin, pin},
    pin::Pin,
    task::{Context, Poll},
};

@@ -33,8 +33,6 @@ where

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let waker = cx.waker();
        Python::with_gil(|py| {
            py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
        })
        Python::attach(|py| py.detach(|| self.project().0.poll(&mut Context::from_waker(waker))))
    }
}

@@ -1,240 +0,0 @@
|
||||
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
|
||||
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
|
||||
//!
|
||||
//! Pattern examples include:
|
||||
//! - Async task handles: with GC-integrated cleanup
|
||||
//! - Sync/async callbacks from python: with propper eventloop handling
|
||||
//!
|
||||
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
|
||||
//! - Store mutable fields in tokio's `Mutex<T>`
|
||||
//! - For async code: take `&self` and `.lock().await`
|
||||
//! - For sync code: take `&mut self` and `.get_mut()`
|
||||
|
||||
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::BoxFuture;
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::{
|
||||
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
|
||||
fn needs_tokio_runtime() {
|
||||
tokio::runtime::Handle::current();
|
||||
}
|
||||
|
||||
type SyncCallback = Box<dyn Fn() + Send + Sync>;
|
||||
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
|
||||
|
||||
enum AsyncTaskMessage {
|
||||
SyncCallback(SyncCallback),
|
||||
AsyncCallback(AsyncCallback),
|
||||
}
|
||||
|
||||
async fn async_task(
|
||||
sender: mpsc::UnboundedSender<()>,
|
||||
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
|
||||
) {
|
||||
log::info!("RUST: async task started");
|
||||
|
||||
// task state
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(1));
|
||||
|
||||
let mut sync_cbs: Vec<SyncCallback> = vec![];
|
||||
let mut async_cbs: Vec<AsyncCallback> = vec![];
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// handle incoming messages from task-handle
|
||||
message = receiver.recv() => {
|
||||
// handle closed channel by exiting
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming event
|
||||
match message {
|
||||
AsyncTaskMessage::SyncCallback(cb) => {
|
||||
sync_cbs.push(cb);
|
||||
}
|
||||
AsyncTaskMessage::AsyncCallback(cb) => {
|
||||
async_cbs.push(cb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle all other events
|
||||
_ = interval.tick() => {
|
||||
log::info!("RUST: async task tick");
|
||||
|
||||
// call back all sync callbacks
|
||||
for cb in &sync_cbs {
|
||||
cb();
|
||||
}
|
||||
|
||||
// call back all async callbacks
|
||||
for cb in &async_cbs {
|
||||
cb().await;
|
||||
}
|
||||
|
||||
// send event on unbounded channel
|
||||
sender.send(()).expect("handle receiver cannot be closed/dropped");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: async task stopped");
|
||||
}
|
||||
|
||||
// #[gen_stub_pyclass]
|
||||
#[pyclass(name = "AsyncTaskHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyAsyncTaskHandle {
|
||||
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyAsyncTaskHandle {
|
||||
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_mut()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn new(
|
||||
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sender: Some(sender),
|
||||
receiver,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyAsyncTaskHandle {
|
||||
#[new]
|
||||
fn py_new(py: Python<'_>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channel TOWARDS our task
|
||||
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
|
||||
|
||||
// create communication channel FROM our task
|
||||
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
|
||||
|
||||
// perform necessary setup within tokio context - or it crashes
|
||||
let () = get_runtime().block_on(async { needs_tokio_runtime() });
|
||||
|
||||
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
|
||||
_ = get_runtime().spawn_with_scope(py, async move {
|
||||
async_task(t_sender, t_receiver).await;
|
||||
});
|
||||
Ok(Self::new(h_sender, h_receiver))
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_sync_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], None]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
|
||||
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_async_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
|
||||
let c = Python::with_gil(|py| callback.clone_ref(py));
|
||||
async move {
|
||||
if let Some(f) = Python::with_gil(|py| {
|
||||
let coroutine = c.call0(py).write_unraisable_with(py)?;
|
||||
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
|
||||
.write_unraisable_with(py)
|
||||
}) {
|
||||
_ = f.await.write_unraisable();
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn receive_unit(&mut self) -> PyResult<()> {
|
||||
self.receiver
|
||||
.recv()
|
||||
.await
|
||||
.ok_or(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
))
|
||||
}
|
||||
|
||||
fn drain_units(&mut self) -> PyResult<i32> {
|
||||
let mut cnt = 0;
|
||||
loop {
|
||||
match self.receiver.try_recv() {
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
return Err(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
));
|
||||
}
|
||||
Err(TryRecvError::Empty) => return Ok(cnt),
|
||||
Ok(()) => {
|
||||
cnt += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyAsyncTaskHandle>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -17,7 +17,6 @@
|
||||
|
||||
extern crate core;
|
||||
mod allow_threading;
|
||||
mod examples;
|
||||
pub(crate) mod networking;
|
||||
pub(crate) mod pylibp2p;
|
||||
|
||||
@@ -25,7 +24,6 @@ use crate::networking::networking_submodule;
|
||||
use crate::pylibp2p::ident::ident_submodule;
|
||||
use crate::pylibp2p::multiaddr::multiaddr_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
@@ -36,14 +34,10 @@ pub(crate) mod r#const {
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
use std::marker::Tuple;
|
||||
|
||||
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
|
||||
Fn<Args, Output = Output> + Send + 'static;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
@@ -51,7 +45,6 @@ pub(crate) mod ext {
|
||||
use crate::allow_threading::AllowThreads;
|
||||
use extend::ext;
|
||||
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Py, PyErr, PyResult, Python};
|
||||
use tokio::runtime::Runtime;
|
||||
@@ -62,7 +55,7 @@ pub(crate) mod ext {
|
||||
#[ext(pub, name = ByteArrayExt)]
|
||||
impl [u8] {
|
||||
fn pybytes(&self) -> Py<PyBytes> {
|
||||
Python::with_gil(|py| PyBytes::new(py, self).unbind())
|
||||
Python::attach(|py| PyBytes::new(py, self).unbind())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,7 +91,7 @@ pub(crate) mod ext {
|
||||
#[ext(pub, name = PyResultExt)]
|
||||
impl<T> PyResult<T> {
|
||||
fn write_unraisable(self) -> Option<T> {
|
||||
Python::with_gil(|py| self.write_unraisable_with(py))
|
||||
Python::attach(|py| self.write_unraisable_with(py))
|
||||
}
|
||||
|
||||
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
|
||||
@@ -175,24 +168,6 @@ pub(crate) mod ext {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
use std::marker::Sized;
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct ClonePy<T>(pub Py<T>);
|
||||
|
||||
impl<T> Clone for ClonePy<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Python::with_gil(|py| Self(self.0.clone_ref(py)))
|
||||
}
|
||||
}
|
||||
|
||||
/// A Python module implemented in Rust. The name of this function must match
|
||||
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
|
||||
/// import the module.
|
||||
|
||||
@@ -11,9 +11,9 @@ use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt a
|
||||
use crate::pyclass;
|
||||
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
use libp2p::swarm::SwarmEvent;
|
||||
use libp2p::{gossipsub, mdns};
|
||||
use networking::discovery;
|
||||
use networking::swarm::create_swarm;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
@@ -25,7 +25,7 @@ use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
|
||||
mod exception {
|
||||
use pyo3::types::PyTuple;
|
||||
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
|
||||
use pyo3::{exceptions::PyException, prelude::*};
|
||||
use pyo3_stub_gen::derive::*;
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
@@ -155,7 +155,6 @@ async fn networking_task(
|
||||
) {
|
||||
use SwarmEvent::*;
|
||||
use ToTask::*;
|
||||
use mdns::Event::*;
|
||||
use networking::swarm::BehaviourEvent::*;
|
||||
|
||||
log::info!("RUST: networking task started");
|
||||
@@ -485,7 +484,7 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
|
||||
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubPublish {
|
||||
topic,
|
||||
|
||||
@@ -24,8 +24,8 @@ use libp2p::{
|
||||
swarm::{NetworkBehaviour, SwarmEvent},
|
||||
tcp, yamux,
|
||||
};
|
||||
use std::error::Error;
|
||||
use std::time::Duration;
|
||||
use std::{error::Error, hash::Hash};
|
||||
use tokio::{io, io::AsyncBufReadExt, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use crate::ext::MultiaddrExt;
|
||||
use crate::keep_alive;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures::FutureExt;
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
use delegate::delegate;
|
||||
use libp2p::swarm::handler::ConnectionEvent;
|
||||
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
|
||||
/// the connection alive.
|
||||
#[derive(Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct ConnectionHandler(dummy::ConnectionHandler);
|
||||
|
||||
impl ConnectionHandler {
|
||||
pub fn new() -> Self {
|
||||
ConnectionHandler(dummy::ConnectionHandler)
|
||||
}
|
||||
}
|
||||
|
||||
impl handler::ConnectionHandler for ConnectionHandler {
|
||||
// delegate types and implementation mostly to dummy handler
|
||||
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
|
||||
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
|
||||
type InboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
|
||||
type OutboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
|
||||
type InboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
|
||||
type OutboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
|
||||
|
||||
delegate! {
|
||||
to self.0 {
|
||||
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
|
||||
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
|
||||
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
|
||||
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
|
||||
}
|
||||
}
|
||||
|
||||
// specifically override this to force connection to stay alive
|
||||
fn connection_keep_alive(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
@@ -3,19 +3,7 @@
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(unboxed_closures)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
pub mod discovery;
|
||||
pub mod keep_alive;
|
||||
pub mod swarm;
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
@@ -54,11 +42,3 @@ pub(crate) mod ext {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
#![allow(dead_code)]
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ from exo.shared.types.api import (
|
||||
LogprobsContentItem,
|
||||
StreamingChoiceResponse,
|
||||
ToolCall,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
@@ -125,6 +126,8 @@ async def generate_chat_stream(
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate Chat Completions API streaming events from chunks."""
|
||||
last_usage: Usage | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_response = ErrorResponse(
|
||||
@@ -138,6 +141,8 @@ async def generate_chat_stream(
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_call_deltas = [
|
||||
ToolCall(
|
||||
@@ -161,12 +166,15 @@ async def generate_chat_stream(
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
],
|
||||
usage=last_usage,
|
||||
)
|
||||
yield f"data: {tool_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
chunk_response = chunk_to_response(chunk, command_id)
|
||||
if chunk.finish_reason is not None:
|
||||
chunk_response = chunk_response.model_copy(update={"usage": last_usage})
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
@@ -176,7 +184,9 @@ async def generate_chat_stream(
|
||||
async def collect_chat_response(
|
||||
command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ChatCompletionResponse:
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ChatCompletionResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
@@ -184,6 +194,7 @@ async def collect_chat_response(
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
error_message: str | None = None
|
||||
last_usage: Usage | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
@@ -193,6 +204,8 @@ async def collect_chat_response(
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.logprob is not None:
|
||||
@@ -223,7 +236,7 @@ async def collect_chat_response(
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
return ChatCompletionResponse(
|
||||
yield ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
@@ -241,4 +254,6 @@ async def collect_chat_response(
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
)
|
||||
usage=last_usage,
|
||||
).model_dump_json()
|
||||
return
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.shared.types.api import FinishReason, Usage
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeContentBlock,
|
||||
@@ -161,12 +161,14 @@ async def collect_claude_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ClaudeMessagesResponse:
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_use_blocks: list[ClaudeToolUseBlock] = []
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_stats = None
|
||||
last_usage: Usage | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
@@ -174,6 +176,8 @@ async def collect_claude_response(
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
for tool in chunk.tool_calls:
|
||||
tool_use_blocks.append(
|
||||
@@ -183,12 +187,10 @@ async def collect_claude_response(
|
||||
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
|
||||
)
|
||||
)
|
||||
last_stats = chunk.stats or last_stats
|
||||
stop_reason = "tool_use"
|
||||
continue
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
@@ -208,11 +210,11 @@ async def collect_claude_response(
|
||||
if not content:
|
||||
content.append(ClaudeTextBlock(text=""))
|
||||
|
||||
# Use actual usage data from stats if available
|
||||
input_tokens = last_stats.prompt_tokens if last_stats else 0
|
||||
output_tokens = last_stats.generation_tokens if last_stats else 0
|
||||
# Use actual usage data if available
|
||||
input_tokens = last_usage.prompt_tokens if last_usage else 0
|
||||
output_tokens = last_usage.completion_tokens if last_usage else 0
|
||||
|
||||
return ClaudeMessagesResponse(
|
||||
yield ClaudeMessagesResponse(
|
||||
id=f"msg_{command_id}",
|
||||
model=model,
|
||||
content=content,
|
||||
@@ -221,7 +223,8 @@ async def collect_claude_response(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
),
|
||||
)
|
||||
).model_dump_json()
|
||||
return
|
||||
|
||||
|
||||
async def generate_claude_stream(
|
||||
@@ -249,7 +252,7 @@ async def generate_claude_stream(
|
||||
|
||||
output_tokens = 0
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_stats = None
|
||||
last_usage: Usage | None = None
|
||||
next_block_index = 1 # text block is 0, tool blocks start at 1
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
@@ -257,8 +260,9 @@ async def generate_claude_stream(
|
||||
# Close text block and bail
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
last_stats = chunk.stats or last_stats
|
||||
stop_reason = "tool_use"
|
||||
|
||||
# Emit tool_use content blocks
|
||||
@@ -290,7 +294,6 @@ async def generate_claude_stream(
|
||||
continue
|
||||
|
||||
output_tokens += 1 # Count each chunk as one token
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# content_block_delta
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
@@ -302,9 +305,9 @@ async def generate_claude_stream(
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
|
||||
# Use actual token count from stats if available
|
||||
if last_stats is not None:
|
||||
output_tokens = last_stats.generation_tokens
|
||||
# Use actual token count from usage if available
|
||||
if last_usage is not None:
|
||||
output_tokens = last_usage.completion_tokens
|
||||
|
||||
# content_block_stop for text block
|
||||
block_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator
|
||||
from itertools import count
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.api import Usage
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.openai_responses import (
|
||||
@@ -121,13 +122,15 @@ async def collect_responses_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ResponsesResponse:
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ResponsesResponse."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
accumulated_text = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_stats = None
|
||||
last_usage: Usage | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
@@ -135,32 +138,32 @@ async def collect_responses_response(
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
for tool in chunk.tool_calls:
|
||||
function_call_items.append(
|
||||
ResponseFunctionCallItem(
|
||||
id=f"fc_{tool.id}",
|
||||
call_id=f"call_{tool.id}",
|
||||
id=tool.id,
|
||||
call_id=tool.id,
|
||||
name=tool.name,
|
||||
arguments=tool.arguments,
|
||||
)
|
||||
)
|
||||
last_stats = chunk.stats or last_stats
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
if error_message is not None:
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Create usage from stats if available
|
||||
# Create usage from usage data if available
|
||||
usage = None
|
||||
if last_stats is not None:
|
||||
if last_usage is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
input_tokens=last_usage.prompt_tokens,
|
||||
output_tokens=last_usage.completion_tokens,
|
||||
total_tokens=last_usage.total_tokens,
|
||||
)
|
||||
|
||||
output: list[ResponseItem] = [
|
||||
@@ -172,14 +175,15 @@ async def collect_responses_response(
|
||||
]
|
||||
output.extend(function_call_items)
|
||||
|
||||
return ResponsesResponse(
|
||||
yield ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="completed",
|
||||
output=output,
|
||||
output_text=accumulated_text,
|
||||
usage=usage,
|
||||
)
|
||||
).model_dump_json()
|
||||
return
|
||||
|
||||
|
||||
async def generate_responses_stream(
|
||||
@@ -235,15 +239,16 @@ async def generate_responses_stream(
|
||||
|
||||
accumulated_text = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_stats = None
|
||||
last_usage: Usage | None = None
|
||||
next_output_index = 1 # message item is at 0
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
break
|
||||
|
||||
last_usage = chunk.usage or last_usage
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
last_stats = chunk.stats or last_stats
|
||||
for tool in chunk.tool_calls:
|
||||
fc_id = f"fc_{tool.id}"
|
||||
call_id = f"call_{tool.id}"
|
||||
@@ -302,7 +307,6 @@ async def generate_responses_stream(
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# response.output_text.delta
|
||||
delta_event = ResponseTextDeltaEvent(
|
||||
@@ -346,13 +350,13 @@ async def generate_responses_stream(
|
||||
)
|
||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
||||
|
||||
# Create usage from stats if available
|
||||
# Create usage from usage data if available
|
||||
usage = None
|
||||
if last_stats is not None:
|
||||
if last_usage is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
input_tokens=last_usage.prompt_tokens,
|
||||
output_tokens=last_usage.completion_tokens,
|
||||
total_tokens=last_usage.total_tokens,
|
||||
)
|
||||
|
||||
# response.completed
|
||||
|
||||
@@ -125,6 +125,7 @@ from exo.shared.types.commands import (
PlaceInstance,
SendInputChunk,
StartDownload,
TaskCancelled,
TaskFinished,
TextGeneration,
)
@@ -540,16 +541,14 @@ class API:
break

except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
"""
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._text_generation_queues:
del self._text_generation_queues[command_id]
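The cancellation path above hinges on anyio semantics: once the surrounding scope is cancelled, any further `await` is also cancelled, so the `TaskCancelled` notification is sent inside a shielded scope. A self-contained sketch of the pattern (the `notify` coroutine is a hypothetical stand-in for the command sender):

```python
import anyio

async def notify(msg: str) -> None:
    # Hypothetical stand-in for command_sender.send(...)
    await anyio.sleep(0)
    print("sent:", msg)

async def worker() -> None:
    try:
        await anyio.sleep(10)  # the work being cancelled
    except anyio.get_cancelled_exc_class():
        # Without the shield, this send would itself be cancelled.
        with anyio.CancelScope(shield=True):
            await notify("TaskCancelled")
        raise
    finally:
        with anyio.CancelScope(shield=True):
            await notify("TaskFinished")

async def main() -> None:
    async with anyio.create_task_group() as tg:
        tg.start_soon(worker)
        await anyio.sleep(0.1)
        tg.cancel_scope.cancel()

anyio.run(main)
```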
@@ -644,11 +643,14 @@ class API:
"X-Accel-Buffering": "no",
},
)

return await collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
)
else:
return StreamingResponse(
collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)

async def bench_chat_completions(
self, payload: BenchChatCompletionRequest
@@ -664,8 +666,7 @@ class API:
command = TextGeneration(task_params=task_params)
await self._send(command)

response = await self._collect_text_generation_with_stats(command.command_id)
return response
return await self._collect_text_generation_with_stats(command.command_id)

async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
"""Validate a text model exists and return the resolved model ID.
@@ -883,6 +884,11 @@ class API:
del image_metadata[key]

except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -964,6 +970,11 @@ class API:

return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -1221,12 +1232,15 @@ class API:
"X-Accel-Buffering": "no",
},
)

return await collect_claude_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
else:
return StreamingResponse(
collect_claude_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)

async def openai_responses(
self, payload: ResponsesRequest
@@ -1254,11 +1268,15 @@ class API:
},
)

return await collect_responses_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
else:
return StreamingResponse(
collect_responses_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)

def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
@@ -24,6 +24,7 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskCancelled,
TaskFinished,
TestCommand,
TextGeneration,
@@ -39,6 +40,7 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
TraceEventData,
TracesCollected,
TracesMerged,
@@ -279,7 +281,7 @@ class Master:
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
for cmd in cancel_unnecessary_downloads(
placement, self.state.downloads
@@ -299,7 +301,7 @@ class Master:
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case CreateInstance():
@@ -309,7 +311,7 @@ class Master:
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
@@ -319,6 +321,18 @@ class Master:
chunk=chunk,
)
)
case TaskCancelled():
if (
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -327,10 +341,9 @@ class Master:
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
self.command_task_mapping.pop(
command.finished_command_id, None
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
# rate limit to 1000 at a time
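A compact sketch of the command-to-task bookkeeping the master performs above, assuming a plain dict mapping command ids to task ids (the names here are illustrative, not the real exo types):

```python
from enum import Enum

class TaskStatus(str, Enum):
    Running = "Running"
    Cancelled = "Cancelled"

command_task_mapping: dict[str, str] = {"cmd-1": "task-1"}

def on_task_cancelled(cancelled_command_id: str) -> list[tuple[str, TaskStatus]]:
    events: list[tuple[str, TaskStatus]] = []
    # Only emit an event if the command is still mapped to a live task.
    if (task_id := command_task_mapping.get(cancelled_command_id)) is not None:
        events.append((task_id, TaskStatus.Cancelled))
    return events

def on_task_finished(finished_command_id: str) -> None:
    # pop(..., None) replaces the "check then del" pattern from the old code.
    command_task_mapping.pop(finished_command_id, None)
```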
@@ -22,9 +22,15 @@ from exo.shared.types.commands import (
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.events import (
Event,
InstanceCreated,
InstanceDeleted,
TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
@@ -186,6 +192,7 @@ def delete_instance(
def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],
tasks: Mapping[TaskId, Task],
) -> Sequence[Event]:
events: list[Event] = []

@@ -201,6 +208,18 @@ def get_transition_events(
# find instances to delete
for instance_id in current_instances:
if instance_id not in target_instances:
for task in tasks.values():
if task.instance_id == instance_id and task.task_status in [
TaskStatus.Pending,
TaskStatus.Running,
]:
events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task.task_id,
)
)

events.append(
InstanceDeleted(
instance_id=instance_id,
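Roughly how the extended signature is exercised: when an instance disappears from the target placement, any of its Pending or Running tasks are cancelled before the `InstanceDeleted` event. A simplified, self-contained sketch with stand-in types:

```python
from dataclasses import dataclass
from enum import Enum

class TaskStatus(str, Enum):
    Pending = "Pending"
    Running = "Running"
    Cancelled = "Cancelled"

@dataclass
class Task:
    task_id: str
    instance_id: str
    task_status: TaskStatus

def transition_events(
    current: dict[str, object],
    target: dict[str, object],
    tasks: dict[str, Task],
) -> list[tuple[str, str]]:
    events: list[tuple[str, str]] = []
    for instance_id in current:
        if instance_id not in target:
            # Cancel in-flight work bound to the instance being torn down.
            for task in tasks.values():
                if task.instance_id == instance_id and task.task_status in (
                    TaskStatus.Pending,
                    TaskStatus.Running,
                ):
                    events.append(("TaskStatusUpdated:Cancelled", task.task_id))
            events.append(("InstanceDeleted", instance_id))
    return events

# Deleting the only instance cancels its running task first:
print(transition_events(
    {"inst-1": object()}, {}, {"t1": Task("t1", "inst-1", TaskStatus.Running)}
))
```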
@@ -4,7 +4,11 @@ import json
from collections.abc import AsyncGenerator
from typing import Any, cast

from exo.master.adapters.claude import collect_claude_response, generate_claude_stream
from exo.master.adapters.claude import (
ClaudeMessagesResponse,
collect_claude_response,
generate_claude_stream,
)
from exo.shared.types.api import ToolCallItem
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId, ModelId
@@ -17,6 +21,18 @@ async def _chunks_to_stream(
yield chunk


async def _collect_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ClaudeMessagesResponse:
"""Helper to consume the async generator and parse the JSON response."""
parts: list[str] = []
async for part in collect_claude_response(command_id, model, chunk_stream):
parts.append(part)
return ClaudeMessagesResponse.model_validate_json("".join(parts))


MODEL = ModelId("test-model")
COMMAND_ID = CommandId("cmd_test123")

@@ -47,7 +63,7 @@ class TestCollectClaudeResponseToolUse:
],
),
]
response = await collect_claude_response(
response = await _collect_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)

@@ -77,7 +93,7 @@ class TestCollectClaudeResponseToolUse:
],
),
]
response = await collect_claude_response(
response = await _collect_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)

@@ -102,7 +118,7 @@ class TestCollectClaudeResponseToolUse:
],
),
]
response = await collect_claude_response(
response = await _collect_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)

@@ -116,7 +132,7 @@ class TestCollectClaudeResponseToolUse:

async def test_no_content_produces_empty_text_block(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = []
response = await collect_claude_response(
response = await _collect_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert len(response.content) == 1
@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
target_instances = {instance_id: instance}

# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})

# assert
assert len(events) == 0
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
target_instances: dict[InstanceId, Instance] = {instance_id: instance}

# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})

# assert
assert len(events) == 1
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
target_instances: dict[InstanceId, Instance] = {}

# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})

# assert
assert len(events) == 1
@@ -184,19 +184,10 @@ def apply_instance_created(event: InstanceCreated, state: State) -> State:


def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
deleted_instance = state.instances.get(event.instance_id)
new_instances: Mapping[InstanceId, Instance] = {
iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
}
runner_ids_to_remove: set[RunnerId] = set()
if deleted_instance is not None:
runner_ids_to_remove = set(
deleted_instance.shard_assignments.runner_to_shard.keys()
)
new_runners: Mapping[RunnerId, RunnerStatus] = {
rid: rs for rid, rs in state.runners.items() if rid not in runner_ids_to_remove
}
return state.model_copy(update={"instances": new_instances, "runners": new_runners})
return state.model_copy(update={"instances": new_instances})


def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
@@ -227,11 +218,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
key: value for key, value in state.downloads.items() if key != event.node_id
}
# Clean up all granular node mappings
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
@@ -272,7 +258,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"downloads": downloads,
"topology": topology,
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_disk": node_disk,
"node_system": node_system,
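The `apply_*` handlers above follow a reducer style: they never mutate `State` in place, each returns a copy with only the affected fields replaced. A tiny hedged illustration of the pydantic `model_copy(update=...)` pattern, using a toy State:

```python
from pydantic import BaseModel

class State(BaseModel):
    instances: dict[str, str] = {}
    runners: dict[str, str] = {}

def apply_delete(state: State, instance_id: str) -> State:
    new_instances = {k: v for k, v in state.instances.items() if k != instance_id}
    # Only the "instances" field changes; everything else is carried over.
    return state.model_copy(update={"instances": new_instances})

s0 = State(instances={"a": "x"}, runners={"r": "ready"})
s1 = apply_delete(s0, "a")
assert s0.instances and not s1.instances and s1.runners == s0.runners
```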
@@ -1,142 +0,0 @@
|
||||
from exo.shared.apply import apply_instance_deleted
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.tests.conftest import get_pipeline_shard_metadata
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import InstanceDeleted
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.instances import InstanceId, MlxRingInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerId,
|
||||
RunnerReady,
|
||||
ShardAssignments,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
INSTANCE_2_ID,
|
||||
MODEL_A_ID,
|
||||
MODEL_B_ID,
|
||||
NODE_A,
|
||||
NODE_B,
|
||||
RUNNER_1_ID,
|
||||
RUNNER_2_ID,
|
||||
)
|
||||
|
||||
|
||||
def _make_instance(
|
||||
instance_id: InstanceId,
|
||||
model_id: ModelId,
|
||||
node_to_runner: dict[NodeId, RunnerId],
|
||||
runner_to_shard: dict[RunnerId, ShardMetadata],
|
||||
) -> MlxRingInstance:
|
||||
return MlxRingInstance(
|
||||
instance_id=instance_id,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=model_id,
|
||||
node_to_runner=node_to_runner,
|
||||
runner_to_shard=runner_to_shard,
|
||||
),
|
||||
hosts_by_node={},
|
||||
ephemeral_port=50000,
|
||||
)
|
||||
|
||||
|
||||
def test_instance_deleted_removes_runners():
|
||||
"""Deleting an instance must also remove its runner entries from state."""
|
||||
shard = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0)
|
||||
instance = _make_instance(
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
{NODE_A: RUNNER_1_ID},
|
||||
{RUNNER_1_ID: shard},
|
||||
)
|
||||
state = State(
|
||||
instances={INSTANCE_1_ID: instance},
|
||||
runners={RUNNER_1_ID: RunnerReady()},
|
||||
)
|
||||
|
||||
new_state = apply_instance_deleted(
|
||||
InstanceDeleted(instance_id=INSTANCE_1_ID), state
|
||||
)
|
||||
|
||||
assert INSTANCE_1_ID not in new_state.instances
|
||||
assert RUNNER_1_ID not in new_state.runners
|
||||
|
||||
|
||||
def test_instance_deleted_removes_only_its_runners():
|
||||
"""Deleting one instance must not remove runners belonging to another."""
|
||||
shard_a = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0)
|
||||
shard_b = get_pipeline_shard_metadata(MODEL_B_ID, device_rank=0)
|
||||
instance_1 = _make_instance(
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
{NODE_A: RUNNER_1_ID},
|
||||
{RUNNER_1_ID: shard_a},
|
||||
)
|
||||
instance_2 = _make_instance(
|
||||
INSTANCE_2_ID,
|
||||
MODEL_B_ID,
|
||||
{NODE_B: RUNNER_2_ID},
|
||||
{RUNNER_2_ID: shard_b},
|
||||
)
|
||||
state = State(
|
||||
instances={INSTANCE_1_ID: instance_1, INSTANCE_2_ID: instance_2},
|
||||
runners={RUNNER_1_ID: RunnerReady(), RUNNER_2_ID: RunnerReady()},
|
||||
)
|
||||
|
||||
new_state = apply_instance_deleted(
|
||||
InstanceDeleted(instance_id=INSTANCE_1_ID), state
|
||||
)
|
||||
|
||||
assert INSTANCE_1_ID not in new_state.instances
|
||||
assert RUNNER_1_ID not in new_state.runners
|
||||
# Instance 2 and its runner must remain
|
||||
assert INSTANCE_2_ID in new_state.instances
|
||||
assert RUNNER_2_ID in new_state.runners
|
||||
|
||||
|
||||
def test_instance_deleted_multi_node_removes_all_runners():
|
||||
"""Deleting a multi-node instance removes all of its runners."""
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
instance = _make_instance(
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
{NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
{RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
|
||||
)
|
||||
state = State(
|
||||
instances={INSTANCE_1_ID: instance},
|
||||
runners={RUNNER_1_ID: RunnerReady(), RUNNER_2_ID: RunnerReady()},
|
||||
)
|
||||
|
||||
new_state = apply_instance_deleted(
|
||||
InstanceDeleted(instance_id=INSTANCE_1_ID), state
|
||||
)
|
||||
|
||||
assert INSTANCE_1_ID not in new_state.instances
|
||||
assert RUNNER_1_ID not in new_state.runners
|
||||
assert RUNNER_2_ID not in new_state.runners
|
||||
assert len(new_state.runners) == 0
|
||||
|
||||
|
||||
def test_instance_deleted_unknown_id_is_noop_for_runners():
|
||||
"""Deleting a non-existent instance should not affect runners."""
|
||||
shard = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0)
|
||||
instance = _make_instance(
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
{NODE_A: RUNNER_1_ID},
|
||||
{RUNNER_1_ID: shard},
|
||||
)
|
||||
unknown_id = InstanceId("99999999-9999-4999-8999-999999999999")
|
||||
state = State(
|
||||
instances={INSTANCE_1_ID: instance},
|
||||
runners={RUNNER_1_ID: RunnerReady()},
|
||||
)
|
||||
|
||||
new_state = apply_instance_deleted(InstanceDeleted(instance_id=unknown_id), state)
|
||||
|
||||
# Everything should remain untouched
|
||||
assert INSTANCE_1_ID in new_state.instances
|
||||
assert RUNNER_1_ID in new_state.runners
|
||||
@@ -3,8 +3,7 @@ from collections.abc import Generator
from typing import Annotated, Any, Literal
from uuid import uuid4

from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from pydantic import BaseModel, Field

from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
@@ -228,13 +227,6 @@ class PlaceInstanceParams(BaseModel):
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1

@field_validator("sharding", "instance_meta", mode="plain")
@classmethod
def use_default(cls, v: object):
if not v or not isinstance(v, (Sharding, InstanceMeta)):
raise PydanticUseDefault()
return v


class CreateInstanceParams(BaseModel):
instance: Instance
@@ -48,6 +48,10 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId


class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId


class TaskFinished(BaseCommand):
finished_command_id: CommandId

@@ -89,6 +93,7 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskCancelled
| TaskFinished
| SendInputChunk
)
@@ -24,6 +24,7 @@ class TaskStatus(str, Enum):
Complete = "Complete"
TimedOut = "TimedOut"
Failed = "Failed"
Cancelled = "Cancelled"


class BaseTask(TaggedModel):
@@ -60,6 +61,11 @@ class TextGeneration(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)


class CancelTask(BaseTask):
cancelled_task_id: TaskId
runner_id: RunnerId


class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -87,6 +93,7 @@ Task = (
| LoadModel
| StartWarmup
| TextGeneration
| CancelTask
| ImageGeneration
| ImageEdits
| Shutdown
@@ -62,6 +62,7 @@ class PartialImageResponse(BaseRunnerResponse):
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
usage: Usage | None
stats: GenerationStats | None = None


class FinishedResponse(BaseRunnerResponse):
@@ -50,9 +50,7 @@ class RunnerReady(BaseRunnerStatus):


class RunnerRunning(BaseRunnerStatus):
"""Runner is processing requests and can accept more (continuous batching)."""

active_requests: int = 0
pass


class RunnerShuttingDown(BaseRunnerStatus):
@@ -1,5 +1,7 @@
import sys


def print_startup_banner(port: int) -> None:
"""Print a prominent startup banner with API endpoint information."""
dashboard_url = f"http://localhost:{port}"
banner = f"""
╔═══════════════════════════════════════════════════════════════════════╗
@@ -27,4 +29,4 @@ def print_startup_banner(port: int) -> None:

"""

print(banner)
print(banner, file=sys.stderr)
@@ -125,7 +125,9 @@ class MpSender[T]:
self._state.buffer.put(item, block=True)

async def send_async(self, item: T) -> None:
await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1))
await to_thread.run_sync(
self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True
)

def close(self) -> None:
if not self._state.closed.is_set():
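The `abandon_on_cancel=True` flag above changes what happens when the awaiting task is cancelled: instead of waiting for the blocking `send` to finish in the worker thread, anyio abandons the thread and lets the cancellation propagate immediately. A small hedged sketch of the difference:

```python
import time

import anyio
from anyio import CapacityLimiter, to_thread

def blocking_send() -> None:
    time.sleep(2)  # stands in for a blocking queue.put()

async def cancelled_caller() -> None:
    with anyio.move_on_after(0.1):
        # With abandon_on_cancel=True the cancellation takes effect right away;
        # without it, anyio would wait the full 2 s for the thread to return.
        await to_thread.run_sync(
            blocking_send, limiter=CapacityLimiter(1), abandon_on_cancel=True
        )

anyio.run(cancelled_caller)
```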
@@ -1,317 +0,0 @@
|
||||
"""Batch generation engine using mlx_lm's BatchGenerator for continuous batching."""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import BatchGenerator
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper
|
||||
|
||||
from exo.shared.types.api import FinishReason, GenerationStats
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import TaskId
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import MAX_TOKENS
|
||||
from exo.worker.engines.mlx.generator.distributed_sync import share_object
|
||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingInsert:
|
||||
"""Pre-tokenized request ready for batch insertion."""
|
||||
|
||||
command_id: CommandId
|
||||
task_id: TaskId
|
||||
tokens: list[int]
|
||||
max_tokens: int
|
||||
prompt_tokens: int
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveRequest:
|
||||
"""Tracks an active request in the batch."""
|
||||
|
||||
command_id: CommandId
|
||||
task_id: TaskId
|
||||
uid: int # BatchGenerator's internal ID
|
||||
detokenizer: StreamingDetokenizer
|
||||
tokens_generated: int = 0
|
||||
prompt_tokens: int = 0
|
||||
start_time: float = field(default_factory=time.perf_counter)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchedGenerationResponse:
|
||||
"""Response from batch engine, tagged with command_id and task_id."""
|
||||
|
||||
command_id: CommandId
|
||||
task_id: TaskId
|
||||
response: GenerationResponse
|
||||
|
||||
|
||||
class BatchGenerationEngine:
|
||||
"""Manages continuous batching using mlx_lm's BatchGenerator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
group: mx.distributed.Group | None = None,
|
||||
max_tokens: int = MAX_TOKENS,
|
||||
completion_batch_size: int = 32,
|
||||
prefill_batch_size: int = 8,
|
||||
prefill_step_size: int = 2048,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.max_tokens = max_tokens
|
||||
self.active_requests: dict[int, ActiveRequest] = {}
|
||||
self._pending_inserts: list[PendingInsert] = []
|
||||
self._pending_completions: list[
|
||||
int
|
||||
] = [] # UIDs completed but not yet synced/removed
|
||||
|
||||
self.group = group
|
||||
self.rank = group.rank() if group else 0
|
||||
self.is_distributed = group is not None and group.size() > 1
|
||||
|
||||
sampler = make_sampler(temp=0.7, top_p=1.0)
|
||||
|
||||
eos_tokens: set[int] = set(tokenizer.eos_token_ids or [])
|
||||
|
||||
self.batch_gen: BatchGenerator = BatchGenerator(
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
stop_tokens=eos_tokens,
|
||||
sampler=sampler,
|
||||
completion_batch_size=completion_batch_size,
|
||||
prefill_batch_size=prefill_batch_size,
|
||||
prefill_step_size=prefill_step_size,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"BatchGenerationEngine initialized with completion_batch_size={completion_batch_size}, "
|
||||
f"prefill_batch_size={prefill_batch_size}, distributed={self.is_distributed}"
|
||||
)
|
||||
|
||||
def queue_request(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: TextGenerationTaskParams,
|
||||
) -> str:
|
||||
"""Queue a pre-tokenized request for insertion. Only rank 0 should call this.
|
||||
|
||||
Tokenization happens here (eagerly) so that sync_and_insert_pending()
|
||||
only does the lightweight batch_gen.insert() call, keeping the decode
|
||||
thread unblocked for as long as possible.
|
||||
|
||||
Returns the prompt string for caller use (e.g. thinking-mode detection).
|
||||
"""
|
||||
assert self.rank == 0, "Only rank 0 should queue requests"
|
||||
prompt_str = apply_chat_template(self.tokenizer, task_params)
|
||||
tokens: list[int] = self.tokenizer.encode(prompt_str, add_special_tokens=False)
|
||||
max_tokens = task_params.max_output_tokens or self.max_tokens
|
||||
self._pending_inserts.append(
|
||||
PendingInsert(
|
||||
command_id=command_id,
|
||||
task_id=task_id,
|
||||
tokens=tokens,
|
||||
max_tokens=max_tokens,
|
||||
prompt_tokens=len(tokens),
|
||||
temperature=task_params.temperature,
|
||||
top_p=task_params.top_p,
|
||||
top_k=task_params.top_k,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"Queued request {command_id} for insertion (pending={len(self._pending_inserts)}, prompt_tokens={len(tokens)})"
|
||||
)
|
||||
return prompt_str
|
||||
|
||||
def sync_and_insert_pending(self) -> list[int]:
|
||||
"""Sync pre-tokenized pending inserts across ranks and insert them. Returns UIDs.
|
||||
|
||||
Tokens are already prepared by queue_request(), so this method only does
|
||||
the lightweight batch_gen.insert() call plus distributed sync if needed.
|
||||
"""
|
||||
inserts_to_process: list[PendingInsert]
|
||||
|
||||
if not self.is_distributed:
|
||||
# Non-distributed: just insert directly from pending
|
||||
inserts_to_process = list(self._pending_inserts)
|
||||
else:
|
||||
# Distributed: broadcast pre-tokenized inserts from rank 0 to all ranks
|
||||
assert self.group is not None
|
||||
inserts_to_process = share_object(
|
||||
self._pending_inserts if self.rank == 0 else None,
|
||||
self.rank,
|
||||
self.group,
|
||||
)
|
||||
|
||||
if not inserts_to_process:
|
||||
self._pending_inserts.clear()
|
||||
return []
|
||||
|
||||
# Update sampler from per-request parameters (last request wins for batch)
|
||||
last = inserts_to_process[-1]
|
||||
self.batch_gen.sampler = make_sampler( # pyright: ignore[reportAttributeAccessIssue]
|
||||
temp=last.temperature if last.temperature is not None else 0.7,
|
||||
top_p=last.top_p if last.top_p is not None else 1.0,
|
||||
top_k=last.top_k if last.top_k is not None else 0,
|
||||
)
|
||||
|
||||
# Single batched insert for efficient prefill — tokens already prepared
|
||||
all_tokens = [p.tokens for p in inserts_to_process]
|
||||
all_max_tokens = [p.max_tokens for p in inserts_to_process]
|
||||
uids = self.batch_gen.insert(all_tokens, max_tokens=all_max_tokens)
|
||||
|
||||
# Track all inserted requests
|
||||
for i, uid in enumerate(uids):
|
||||
p = inserts_to_process[i]
|
||||
self.active_requests[uid] = ActiveRequest(
|
||||
command_id=p.command_id,
|
||||
task_id=p.task_id,
|
||||
uid=uid,
|
||||
detokenizer=self.tokenizer.detokenizer,
|
||||
prompt_tokens=p.prompt_tokens,
|
||||
)
|
||||
logger.info(
|
||||
f"Inserted request {p.command_id} with uid={uid}, prompt_tokens={p.prompt_tokens}, max_tokens={p.max_tokens}"
|
||||
)
|
||||
|
||||
self._pending_inserts.clear()
|
||||
return uids
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
"""Run one decode step. Tracks completions but does not sync - call sync_completions() at budget boundaries."""
|
||||
responses = self.batch_gen.next()
|
||||
if not responses:
|
||||
return []
|
||||
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
|
||||
for r in responses:
|
||||
uid: int = r.uid
|
||||
req = self.active_requests.get(uid)
|
||||
if req is None:
|
||||
logger.warning(f"Received response for unknown uid={uid}")
|
||||
continue
|
||||
|
||||
req.tokens_generated += 1
|
||||
|
||||
# Decode the token
|
||||
token: int = r.token
|
||||
req.detokenizer.add_token(token)
|
||||
text: str = req.detokenizer.last_segment
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
raw_finish_reason: str | None = r.finish_reason
|
||||
if raw_finish_reason is not None:
|
||||
# Finalize to get remaining text
|
||||
req.detokenizer.finalize()
|
||||
text = req.detokenizer.last_segment
|
||||
|
||||
elapsed = time.perf_counter() - req.start_time
|
||||
generation_tps = req.tokens_generated / elapsed if elapsed > 0 else 0.0
|
||||
|
||||
stats = GenerationStats(
|
||||
prompt_tps=0.0, # Not tracked per-request in batch mode
|
||||
generation_tps=generation_tps,
|
||||
prompt_tokens=req.prompt_tokens,
|
||||
generation_tokens=req.tokens_generated,
|
||||
peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
|
||||
)
|
||||
|
||||
if raw_finish_reason in get_args(FinishReason):
|
||||
finish_reason = raw_finish_reason # pyright: ignore[reportAssignmentType]
|
||||
else:
|
||||
logger.warning(f"Unknown finish_reason: {raw_finish_reason}")
|
||||
finish_reason = "stop"
|
||||
|
||||
# Track completion but don't remove yet - wait for sync_completions()
|
||||
self._pending_completions.append(uid)
|
||||
logger.info(
|
||||
f"Request {req.command_id} completed: {req.tokens_generated} tokens, {generation_tps:.2f} tps, reason={finish_reason}"
|
||||
)
|
||||
|
||||
results.append(
|
||||
BatchedGenerationResponse(
|
||||
command_id=req.command_id,
|
||||
task_id=req.task_id,
|
||||
response=GenerationResponse(
|
||||
text=text,
|
||||
token=token,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
usage=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# In non-distributed mode, clean up completions immediately
|
||||
if not self.is_distributed:
|
||||
self._remove_completed()
|
||||
|
||||
return results
|
||||
|
||||
def sync_completions(self) -> None:
|
||||
"""Sync and remove completed requests. Call at time budget boundaries in distributed mode."""
|
||||
if not self.is_distributed:
|
||||
# Non-distributed: early return if nothing to do
|
||||
if not self._pending_completions:
|
||||
return
|
||||
self._remove_completed()
|
||||
return
|
||||
|
||||
# Distributed mode: ALWAYS sync to ensure all ranks participate in collective op
|
||||
# This prevents deadlock if one rank has completions and another doesn't
|
||||
assert self.group is not None
|
||||
self._pending_completions = share_object(
|
||||
self._pending_completions if self.rank == 0 else None,
|
||||
self.rank,
|
||||
self.group,
|
||||
)
|
||||
self._remove_completed()
|
||||
|
||||
def _remove_completed(self) -> None:
|
||||
"""Remove completed requests from tracking."""
|
||||
for uid in self._pending_completions:
|
||||
if uid in self.active_requests:
|
||||
del self.active_requests[uid]
|
||||
self._pending_completions.clear()
|
||||
|
||||
@property
|
||||
def has_active_requests(self) -> bool:
|
||||
return bool(self.active_requests or self.batch_gen.unprocessed_prompts)
|
||||
|
||||
@property
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return bool(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self.active_requests)
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return len(self.batch_gen.unprocessed_prompts)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def has_pending_completions(self) -> bool:
|
||||
return bool(self._pending_completions)
|
||||
@@ -1,34 +0,0 @@
|
||||
"""Distributed sync utilities using mx.distributed.all_sum() to broadcast from rank 0."""
|
||||
|
||||
# pyright: reportAny=false
|
||||
|
||||
import pickle
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def share_object[T](obj: T | None, rank: int, group: mx.distributed.Group) -> T:
|
||||
"""Broadcast object from rank 0 to all ranks. Two-phase: size then data.
|
||||
|
||||
Rank 0 must always provide a non-None object. Non-rank-0 callers pass None
|
||||
(they are receivers only). Use mx_barrier() instead if no data needs to be shared.
|
||||
"""
|
||||
if rank == 0:
|
||||
assert obj is not None, (
|
||||
"Rank 0 must provide data; use mx_barrier() to sync without data"
|
||||
)
|
||||
data = mx.array(list(pickle.dumps(obj)), dtype=mx.uint8)
|
||||
mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))
|
||||
mx.eval(mx.distributed.all_sum(data, group=group))
|
||||
return obj
|
||||
else:
|
||||
size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
|
||||
if size == 0:
|
||||
raise RuntimeError(
|
||||
"share_object received size=0 from rank 0 — protocol violation"
|
||||
)
|
||||
data = mx.zeros(size, dtype=mx.uint8)
|
||||
data = mx.distributed.all_sum(data, group=group)
|
||||
mx.eval(data)
|
||||
return cast(T, pickle.loads(bytes(cast(list[int], data.tolist()))))
|
||||
@@ -306,7 +306,7 @@ def mlx_generate(
|
||||
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
||||
|
||||
mx_barrier(group)
|
||||
logger.info("Ready to prefill")
|
||||
logger.info("Starting prefill")
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
|
||||
@@ -393,10 +393,11 @@ def mlx_generate(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
total_prompt_tokens = len(all_prompt_tokens)
|
||||
usage = Usage(
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=int(out.prompt_tokens) + completion_tokens,
|
||||
total_tokens=total_prompt_tokens + completion_tokens,
|
||||
prompt_tokens_details=PromptTokensDetails(
|
||||
cached_tokens=prefix_hit_length
|
||||
),
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
"""Time budget iterator for controlling generation loop timing in distributed mode.
|
||||
|
||||
Based on mlx-lm's TimeBudget pattern - runs for a time budget then syncs,
|
||||
rather than syncing every token. This reduces distributed sync overhead.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Iterator
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
class TimeBudget(Iterator[None]):
|
||||
"""Controls generation loop timing, syncing across ranks periodically.
|
||||
|
||||
In distributed mode, periodically syncs timing across all ranks to
|
||||
dynamically adjust iteration count based on actual performance.
|
||||
|
||||
In non-distributed mode, simply runs for the time budget.
|
||||
|
||||
Usage:
|
||||
for _ in TimeBudget(budget=0.5):
|
||||
batch_engine.step()
|
||||
# ... process responses ...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
budget: float = 0.5,
|
||||
iterations: int = 25,
|
||||
sync_frequency: int = 10,
|
||||
group: mx.distributed.Group | None = None,
|
||||
):
|
||||
"""Initialize TimeBudget.
|
||||
|
||||
Args:
|
||||
budget: Time budget in seconds before yielding control
|
||||
iterations: Initial number of iterations per budget period (distributed only)
|
||||
sync_frequency: How often to sync timing across ranks (distributed only)
|
||||
group: Distributed group, or None for non-distributed mode
|
||||
"""
|
||||
self._budget = budget
|
||||
self._iterations = iterations
|
||||
self._sync_frequency = sync_frequency
|
||||
self._group = group
|
||||
self._is_distributed = group is not None and group.size() > 1
|
||||
|
||||
# Runtime state
|
||||
self._start: float = 0.0
|
||||
self._current_iterations: int = 0
|
||||
self._loops: int = 0
|
||||
self._time_spent: float = 0.0
|
||||
|
||||
def __iter__(self) -> "TimeBudget":
|
||||
self._start = time.perf_counter()
|
||||
self._current_iterations = 0
|
||||
return self
|
||||
|
||||
def __next__(self) -> None:
|
||||
if not self._is_distributed:
|
||||
# Non-distributed: just check time budget
|
||||
if time.perf_counter() - self._start > self._budget:
|
||||
raise StopIteration()
|
||||
return None
|
||||
|
||||
# Distributed mode: iteration-based with periodic timing sync
|
||||
self._current_iterations += 1
|
||||
if self._current_iterations > self._iterations:
|
||||
self._loops += 1
|
||||
self._time_spent += time.perf_counter() - self._start
|
||||
|
||||
if self._loops % self._sync_frequency == 0:
|
||||
# Sync timing across all ranks
|
||||
assert self._group is not None
|
||||
with mx.stream(generation_stream):
|
||||
time_array = mx.array([self._time_spent], dtype=mx.float32)
|
||||
total_time = mx.distributed.all_sum(time_array, group=self._group)
|
||||
mx.eval(total_time)
|
||||
loop_time = float(total_time.item())
|
||||
|
||||
avg_loop_time = loop_time / (self._group.size() * self._sync_frequency)
|
||||
|
||||
if avg_loop_time > 0:
|
||||
factor = self._budget / avg_loop_time
|
||||
self._iterations = max(round(self._iterations * factor), 1)
|
||||
logger.debug(
|
||||
f"TimeBudget adjusted iterations to {self._iterations}"
|
||||
)
|
||||
|
||||
self._loops = 0
|
||||
self._time_spent = 0.0
|
||||
|
||||
raise StopIteration()
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def iterations(self) -> int:
|
||||
"""Current iterations per budget period."""
|
||||
return self._iterations
|
||||
@@ -64,8 +64,6 @@ from exo.worker.runner.bootstrap import logger
|
||||
Group = mx.distributed.Group
|
||||
|
||||
|
||||
# TODO: Test this
|
||||
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
|
||||
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
return Memory.from_float_kb(
|
||||
(model_shard_meta.end_layer - model_shard_meta.start_layer)
|
||||
@@ -83,30 +81,6 @@ class ModelLoadingTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0),
|
||||
stream=mx.default_stream(mx.Device(mx.cpu)),
|
||||
group=group,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def broadcast_from_zero(value: int, group: Group | None = None):
|
||||
if group is None:
|
||||
return value
|
||||
|
||||
if group.rank() == 0:
|
||||
a = mx.array([value], dtype=mx.int32)
|
||||
else:
|
||||
a = mx.array([0], dtype=mx.int32)
|
||||
|
||||
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
|
||||
mx.eval(m)
|
||||
return int(m.item())
|
||||
|
||||
|
||||
class HostList(RootModel[list[str]]):
|
||||
@classmethod
|
||||
def from_hosts(cls, hosts: list[Host]) -> "HostList":
|
||||
@@ -379,7 +353,13 @@ def load_tokenizer_for_model_id(
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]

hf_tokenizer.encode = _patched_encode
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
return TokenizerWrapper(
hf_tokenizer,
eos_token_ids=eos_token_ids,
tool_call_start="<|tool_calls_section_begin|>",
tool_call_end="<|tool_calls_section_end|>",
tool_parser=_parse_kimi_tool_calls,
)

tokenizer = load_tokenizer(
model_path,
@@ -591,3 +571,61 @@ def mlx_cleanup(
import gc

gc.collect()


def mx_any(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() > 0


def mx_barrier(group: Group | None):
if group is None:
return
mx.eval(
mx.distributed.all_sum(
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)


def _parse_kimi_tool_calls(text: str):
import regex as re

# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0<|tool_call_argument_begin|>{"a": 2, "b": 3}
_func_name_regex = re.compile(
r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
_tool_call_split_regex = re.compile(
r"<\|tool_call_begin\|>(.*?)<\|tool_call_end\|>", re.DOTALL
)

def _parse_single_tool(text: str) -> dict[str, Any]:
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError("No tool call found.")
tool_call_id = func_name_match.group(1) # e.g. "functions.get_weather:0"
func_name = func_name_match.group(2) # e.g. "get_weather"

func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError("No tool call arguments found.")
func_args = func_args_match.group(1)
try:
arg_dct = json.loads(func_args) # pyright: ignore[reportAny]
except Exception:
arg_dct = None

return dict(id=tool_call_id, name=func_name, arguments=arg_dct)

tool_matches = _tool_call_split_regex.findall(text)
if tool_matches:
return [_parse_single_tool(match) for match in tool_matches] # pyright: ignore[reportAny]
else:
return [_parse_single_tool(text)]
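To make the parsing contract above concrete, here is a hedged, standalone example of the Kimi-style tool-call text the parser expects and what it extracts. It reuses the same regexes with stdlib `re` (the real code imports the `regex` package); the sample string is illustrative:

```python
import json
import re  # the real code uses the `regex` package; stdlib re suffices here

sample = (
    "<|tool_call_begin|>functions.get_weather:0"
    "<|tool_call_argument_begin|>{\"city\": \"Paris\"}<|tool_call_end|>"
)

# Split out one tool call, then pull the call id, bare name, and JSON args.
call = re.search(
    r"<\|tool_call_begin\|>(.*?)<\|tool_call_end\|>", sample, re.DOTALL
).group(1)
name = re.search(
    r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", call, re.DOTALL
)
args = re.search(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", call, re.DOTALL).group(1)

# -> {'id': 'functions.get_weather:0', 'name': 'get_weather', 'arguments': {'city': 'Paris'}}
print(dict(id=name.group(1), name=name.group(2), arguments=json.loads(args)))
```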
@@ -33,6 +33,7 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
@@ -172,127 +173,123 @@ class Worker:
|
||||
async def plan_step(self):
|
||||
while True:
|
||||
await anyio.sleep(0.1)
|
||||
# Drain all available tasks before sleeping again.
|
||||
# This ensures concurrent request arrivals are dispatched
|
||||
# rapidly rather than one-per-100ms.
|
||||
while True:
|
||||
task: Task | None = plan(
|
||||
self.node_id,
|
||||
self.runners,
|
||||
self.state.downloads,
|
||||
self.state.instances,
|
||||
self.state.runners,
|
||||
self.state.tasks,
|
||||
self.input_chunk_buffer,
|
||||
self.input_chunk_counts,
|
||||
)
|
||||
if task is None:
|
||||
break
|
||||
task: Task | None = plan(
|
||||
self.node_id,
|
||||
self.runners,
|
||||
self.state.downloads,
|
||||
self.state.instances,
|
||||
self.state.runners,
|
||||
self.state.tasks,
|
||||
self.input_chunk_buffer,
|
||||
self.input_chunk_counts,
|
||||
)
|
||||
if task is None:
|
||||
continue
|
||||
|
||||
# Gate DownloadModel on backoff BEFORE emitting TaskCreated
|
||||
# to prevent flooding the event log with useless events
|
||||
if isinstance(task, DownloadModel):
|
||||
model_id = task.shard_metadata.model_card.model_id
|
||||
if not self._download_backoff.should_proceed(model_id):
|
||||
break
|
||||
# Gate DownloadModel on backoff BEFORE emitting TaskCreated
|
||||
# to prevent flooding the event log with useless events
|
||||
if isinstance(task, DownloadModel):
|
||||
model_id = task.shard_metadata.model_card.model_id
|
||||
if not self._download_backoff.should_proceed(model_id):
|
||||
continue
|
||||
|
||||
logger.info(f"Worker plan: {task.__class__.__name__}")
|
||||
assert task.task_status
|
||||
await self.event_sender.send(
|
||||
TaskCreated(task_id=task.task_id, task=task)
|
||||
)
|
||||
logger.info(f"Worker plan: {task.__class__.__name__}")
|
||||
assert task.task_status
|
||||
await self.event_sender.send(TaskCreated(task_id=task.task_id, task=task))
|
||||
|
||||
# let's not kill the worker if a runner is unresponsive
|
||||
match task:
|
||||
case CreateRunner():
|
||||
self._create_supervisor(task)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id,
|
||||
task_status=TaskStatus.Complete,
|
||||
)
|
||||
# let's not kill the worker if a runner is unresponsive
|
||||
match task:
|
||||
case CreateRunner():
|
||||
self._create_supervisor(task)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
case DownloadModel(shard_metadata=shard):
|
||||
model_id = shard.model_card.model_id
|
||||
self._download_backoff.record_attempt(model_id)
|
||||
)
|
||||
case DownloadModel(shard_metadata=shard):
|
||||
model_id = shard.model_card.model_id
|
||||
self._download_backoff.record_attempt(model_id)
|
||||
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id,
|
||||
command=StartDownload(
|
||||
target_node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
),
|
||||
)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id,
|
||||
task_status=TaskStatus.Running,
|
||||
)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
try:
|
||||
with fail_after(3):
|
||||
await self.runners.pop(runner_id).start_task(task)
|
||||
except TimeoutError:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id,
|
||||
task_status=TaskStatus.TimedOut,
|
||||
)
|
||||
)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
chunks = self.input_chunk_buffer.get(cmd_id, {})
|
||||
assembled = "".join(chunks[i] for i in range(len(chunks)))
|
||||
logger.info(
|
||||
f"Assembled input image from {len(chunks)} chunks, "
|
||||
f"total size: {len(assembled)} bytes"
|
||||
)
|
||||
# Create modified task with assembled image data
|
||||
modified_task = ImageEdits(
|
||||
task_id=task.task_id,
|
||||
command_id=task.command_id,
|
||||
instance_id=task.instance_id,
|
||||
task_status=task.task_status,
|
||||
task_params=ImageEditsTaskParams(
|
||||
image_data=assembled,
|
||||
total_input_chunks=task.task_params.total_input_chunks,
|
||||
prompt=task.task_params.prompt,
|
||||
model=task.task_params.model,
|
||||
n=task.task_params.n,
|
||||
quality=task.task_params.quality,
|
||||
output_format=task.task_params.output_format,
|
||||
response_format=task.task_params.response_format,
|
||||
size=task.task_params.size,
|
||||
image_strength=task.task_params.image_strength,
|
||||
bench=task.task_params.bench,
|
||||
stream=task.task_params.stream,
|
||||
partial_images=task.task_params.partial_images,
|
||||
advanced_params=task.task_params.advanced_params,
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id,
|
||||
command=StartDownload(
|
||||
target_node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
),
|
||||
)
|
||||
# Cleanup buffers
|
||||
if cmd_id in self.input_chunk_buffer:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Running
|
||||
)
|
||||
case task:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
task
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
runner = self.runners.pop(runner_id)
|
||||
try:
|
||||
with fail_after(3):
|
||||
await runner.start_task(task)
|
||||
except TimeoutError:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
finally:
|
||||
runner.shutdown()
|
||||
case CancelTask(
|
||||
cancelled_task_id=cancelled_task_id, runner_id=runner_id
|
||||
):
|
||||
await self.runners[runner_id].cancel_task(cancelled_task_id)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
chunks = self.input_chunk_buffer.get(cmd_id, {})
|
||||
assembled = "".join(chunks[i] for i in range(len(chunks)))
|
||||
logger.info(
|
||||
f"Assembled input image from {len(chunks)} chunks, "
|
||||
f"total size: {len(assembled)} bytes"
|
||||
)
|
||||
# Create modified task with assembled image data
|
||||
modified_task = ImageEdits(
|
||||
task_id=task.task_id,
|
||||
command_id=task.command_id,
|
||||
instance_id=task.instance_id,
|
||||
task_status=task.task_status,
|
||||
task_params=ImageEditsTaskParams(
|
||||
image_data=assembled,
|
||||
total_input_chunks=task.task_params.total_input_chunks,
|
||||
prompt=task.task_params.prompt,
|
||||
model=task.task_params.model,
|
||||
n=task.task_params.n,
|
||||
quality=task.task_params.quality,
|
||||
output_format=task.task_params.output_format,
|
||||
response_format=task.task_params.response_format,
|
||||
size=task.task_params.size,
|
||||
image_strength=task.task_params.image_strength,
|
||||
bench=task.task_params.bench,
|
||||
stream=task.task_params.stream,
|
||||
partial_images=task.task_params.partial_images,
|
||||
advanced_params=task.task_params.advanced_params,
|
||||
),
|
||||
)
|
||||
# Cleanup buffers
|
||||
if cmd_id in self.input_chunk_buffer:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self._start_runner_task(modified_task)
|
||||
case task:
|
||||
await self._start_runner_task(task)
|
||||
|
||||
def shutdown(self):
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
def _task_to_runner_id(self, task: Task):
|
||||
instance = self.state.instances[task.instance_id]
|
||||
return instance.shard_assignments.node_to_runner[self.node_id]
|
||||
async def _start_runner_task(self, task: Task):
|
||||
if (instance := self.state.instances.get(task.instance_id)) is not None:
|
||||
await self.runners[
|
||||
instance.shard_assignments.node_to_runner[self.node_id]
|
||||
].start_task(task)
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
@@ -331,8 +328,6 @@ class Worker:
|
||||
for event in self.out_for_delivery.copy().values():
|
||||
await self.local_event_sender.send(event)
|
||||
|
||||
## Op Executors
|
||||
|
||||
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
|
||||
"""Creates and stores a new AssignedRunner with initial downloading status."""
|
||||
runner = RunnerSupervisor.create(
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence

from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
CancelTask,
ConnectToGroup,
CreateRunner,
DownloadModel,
@@ -53,13 +54,14 @@ def plan(
) -> Task | None:
# Python short circuiting OR logic should evaluate these sequentially.
return (
_kill_runner(runners, all_runners, instances)
_cancel_tasks(runners, tasks)
or _kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(node_id, runners, global_download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
)
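The planner above relies on Python's `or` short-circuiting: each helper returns either a `Task` or `None`, and the first non-`None` result wins, so `_cancel_tasks` now takes priority over everything else. A generic sketch of the pattern (the names here are illustrative, not the exo helpers):

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")

def first_match(*candidates: Callable[[], Optional[T]]) -> Optional[T]:
    # Equivalent to `a() or b() or c()` when values are never falsy-but-valid.
    for candidate in candidates:
        result = candidate()
        if result is not None:
            return result
    return None

# Ordering encodes priority: cancellation beats new work.
task = first_match(lambda: None, lambda: "CancelTask", lambda: "TextGeneration")
assert task == "CancelTask"
```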
@@ -270,7 +272,7 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
@@ -284,7 +286,7 @@ def _pending_tasks(
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
received = len(input_chunk_buffer.get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive

@@ -292,18 +294,33 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue

# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
# Also skip tasks in pending to prevent duplicate forwarding with continuous batching
if task.task_id in runner.completed or task.task_id in runner.pending:
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
if task.task_id in runner.completed:
continue

# TODO: Check ordering aligns with MLX distributed's expectations.

# Allow forwarding tasks when runner is Ready or Running (for continuous batching)
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
return task


def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue
for runner_id, runner in runners.items():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
if task.task_id in runner.cancelled:
continue
return CancelTask(
instance_id=task.instance_id,
cancelled_task_id=task.task_id,
runner_id=runner_id,
)
@@ -3,7 +3,7 @@ import os
import loguru

from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,6 +15,7 @@ def entrypoint(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -38,7 +39,7 @@ def entrypoint(
try:
from exo.worker.runner.runner import main

main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver, cancel_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:
File diff suppressed because it is too large
@@ -47,9 +47,11 @@ class RunnerSupervisor:
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_cancel_sender: MpSender[TaskId]
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -60,8 +62,8 @@ class RunnerSupervisor:
|
||||
initialize_timeout: float = 400,
|
||||
) -> Self:
|
||||
ev_send, ev_recv = mp_channel[Event]()
|
||||
# A task is kind of a runner command
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||
|
||||
runner_process = Process(
|
||||
target=entrypoint,
|
||||
@@ -69,6 +71,7 @@ class RunnerSupervisor:
|
||||
bound_instance,
|
||||
ev_send,
|
||||
task_recv,
|
||||
cancel_recv,
|
||||
logger,
|
||||
),
|
||||
daemon=True,
|
||||
@@ -83,6 +86,7 @@ class RunnerSupervisor:
|
||||
initialize_timeout=initialize_timeout,
|
||||
_ev_recv=ev_recv,
|
||||
_task_sender=task_sender,
|
||||
_cancel_sender=cancel_sender,
|
||||
_event_sender=event_sender,
|
||||
)
|
||||
|
||||
@@ -97,6 +101,8 @@ class RunnerSupervisor:
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
@@ -112,14 +118,6 @@ class RunnerSupervisor:
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.pending:
|
||||
logger.warning(
|
||||
@@ -141,6 +139,17 @@ class RunnerSupervisor:
|
||||
return
|
||||
await event.wait()
|
||||
|
||||
async def cancel_task(self, task_id: TaskId):
|
||||
if task_id in self.completed:
|
||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||
return
|
||||
self.cancelled.add(task_id)
|
||||
with anyio.move_on_after(0.5) as scope:
|
||||
await self._cancel_sender.send_async(task_id)
|
||||
if scope.cancel_called:
|
||||
logger.error("RunnerSupervisor cancel pipe blocked")
|
||||
await self._check_runner(TimeoutError("cancel pipe blocked"))
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
try:
|
||||
@@ -148,11 +157,7 @@ class RunnerSupervisor:
|
||||
if isinstance(event, RunnerStatusUpdated):
|
||||
self.status = event.runner_status
|
||||
if isinstance(event, TaskAcknowledged):
|
||||
# Signal start_task() to return, but keep the entry
|
||||
# in self.pending so _pending_tasks won't re-dispatch.
|
||||
pending_event = self.pending.get(event.task_id)
|
||||
if pending_event is not None:
|
||||
pending_event.set()
|
||||
self.pending.pop(event.task_id).set()
|
||||
continue
|
||||
if (
|
||||
isinstance(event, TaskStatusUpdated)
|
||||
@@ -170,8 +175,6 @@ class RunnerSupervisor:
|
||||
),
|
||||
)
|
||||
self.completed.add(event.task_id)
|
||||
# Clean up from pending now that it's fully complete
|
||||
self.pending.pop(event.task_id, None)
|
||||
await self._event_sender.send(event)
|
||||
except (ClosedResourceError, BrokenResourceError) as e:
|
||||
await self._check_runner(e)
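
A minimal sketch (not part of the diff) of how the new cancellation path might be exercised from the supervisor's side, assuming an async caller; `supervisor` is a placeholder for an already-created RunnerSupervisor, and only the methods and fields added in this diff are used:

from exo.shared.types.tasks import TaskId

async def request_cancel(supervisor, task_id: TaskId) -> None:
    # cancel_task() logs and returns if the runner already reported the task
    # complete; otherwise the id is recorded in supervisor.cancelled and sent
    # over the cancel pipe, with a 0.5 s timeout that triggers a runner health
    # check if the pipe is blocked.
    await supervisor.cancel_task(task_id)

During shutdown the supervisor also pushes the sentinel TaskId("CANCEL_CURRENT_TASK") over the same pipe (see the close path above) so an in-flight generation can be interrupted before the process is joined.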
|
||||
|
||||
72
src/exo/worker/runner/tool_parsers.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
from exo.shared.types.api import ToolCallItem
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolParser:
|
||||
start_parsing: str
|
||||
end_parsing: str
|
||||
parse_tool_calls: Callable[[str], list[ToolCallItem] | None]
|
||||
|
||||
|
||||
def make_mlx_parser(
|
||||
tool_call_start: str,
|
||||
tool_call_end: str,
|
||||
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
|
||||
) -> ToolParser:
|
||||
def parse_tool_calls(text: str) -> list[ToolCallItem] | None:
|
||||
try:
|
||||
text = text.removeprefix(tool_call_start)
|
||||
text = text.removesuffix(tool_call_end)
|
||||
parsed = tool_parser(text)
|
||||
if isinstance(parsed, list):
|
||||
return [ToolCallItem.model_validate(_flatten(p)) for p in parsed]
|
||||
else:
|
||||
return [ToolCallItem.model_validate(_flatten(parsed))]
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return ToolParser(
|
||||
start_parsing=tool_call_start,
|
||||
end_parsing=tool_call_end,
|
||||
parse_tool_calls=parse_tool_calls,
|
||||
)
|
||||
|
||||
|
||||
# TODO / example code:
|
||||
def _parse_json_calls(text: str) -> list[ToolCallItem] | None:
|
||||
try:
|
||||
text = text.removeprefix("<tool_call>")
|
||||
text = text.removesuffix("</tool_call>")
|
||||
top_level = {
|
||||
k: json.dumps(v) if isinstance(v, (dict, list)) else v
|
||||
for k, v in json.loads(text).items() # pyright: ignore[reportAny]
|
||||
}
|
||||
return [ToolCallItem.model_validate(top_level)]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _flatten(p: dict[str, Any]) -> dict[str, str]:
|
||||
return {
|
||||
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) # pyright: ignore[reportAny]
|
||||
for k, v in p.items() # pyright: ignore[reportAny]
|
||||
}
|
||||
|
||||
|
||||
json_tool_parser = ToolParser(
|
||||
start_parsing="<tool_call>",
|
||||
end_parsing="</tool_call>",
|
||||
parse_tool_calls=_parse_json_calls,
|
||||
)
|
||||
|
||||
|
||||
def infer_tool_parser(chat_template: str) -> ToolParser | None:
|
||||
"""Attempt to auto-infer a tool parser from the chat template."""
|
||||
if "<tool_call>" in chat_template and "tool_call.name" in chat_template:
|
||||
return json_tool_parser
|
||||
return None
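
A hedged usage sketch of the new tool-parser helpers (not part of the diff): the chat template and model output strings below are invented, and ToolCallItem's field names are assumed to mirror the keys of the parsed payload ("name", "arguments").

from exo.worker.runner.tool_parsers import infer_tool_parser

# A Qwen-style template containing both markers infer_tool_parser() looks for.
chat_template = "...<tool_call>{{ tool_call.name }}...</tool_call>..."
parser = infer_tool_parser(chat_template)  # both markers present -> json_tool_parser

output = '<tool_call>{"name": "get_weather", "arguments": {"city": "SF"}}</tool_call>'
if parser is not None:
    calls = parser.parse_tool_calls(output)  # list[ToolCallItem] or None on parse failure
    if calls:
        print(calls[0].name, calls[0].arguments)  # arguments arrives as a JSON string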
|
||||
@@ -20,7 +20,6 @@ class FakeRunnerSupervisor:
|
||||
bound_instance: BoundInstance
|
||||
status: RunnerStatus
|
||||
completed: set[TaskId] = field(default_factory=set)
|
||||
pending: dict[TaskId, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
class OtherTask(BaseTask):
|
||||
|
||||
@@ -1,388 +0,0 @@
|
||||
"""
|
||||
Tests for continuous batching behavior in the runner.
|
||||
|
||||
These tests verify that:
|
||||
1. Single requests work through the batch path
|
||||
2. Multiple concurrent requests batch together
|
||||
3. Tokens are routed to the correct requests
|
||||
4. Requests complete at different times appropriately
|
||||
|
||||
NOTE: These tests require the continuous-batching runner architecture
|
||||
(BatchGenerationEngine) which is not yet integrated with main.
|
||||
"""
|
||||
|
||||
# ruff: noqa: E402
|
||||
# pyright: reportAny=false
|
||||
# pyright: reportUnknownArgumentType=false
|
||||
# pyright: reportUnknownMemberType=false
|
||||
# pyright: reportAttributeAccessIssue=false
|
||||
# pyright: reportInvalidTypeVarUse=false
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
import exo.worker.runner.runner as mlx_runner
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
RunnerStatusUpdated,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.shared.types.worker.runners import RunnerRunning
|
||||
from exo.utils.channels import mp_channel
|
||||
from exo.worker.engines.mlx.generator.batch_engine import (
|
||||
BatchedGenerationResponse,
|
||||
)
|
||||
from exo.worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
NODE_A,
|
||||
RUNNER_1_ID,
|
||||
)
|
||||
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
|
||||
|
||||
|
||||
class FakeBatchEngineWithTokens:
|
||||
"""
|
||||
Fake batch engine that generates a specified number of tokens per request.
|
||||
|
||||
This simulates realistic batch generation behavior where:
|
||||
- Requests are queued on insert
|
||||
- Each step() call generates one token for all active requests
|
||||
- Requests complete when they've generated all their tokens
|
||||
"""
|
||||
|
||||
def __init__(self, *_args: Any, **_kwargs: Any):
|
||||
self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
|
||||
self._pending_inserts: list[
|
||||
tuple[CommandId, TaskId, TextGenerationTaskParams]
|
||||
] = []
|
||||
self._uid_counter = 0
|
||||
self._tokens_per_request = 3 # Default: generate 3 tokens before completing
|
||||
self.rank = 0 # Fake rank for testing
|
||||
|
||||
def queue_request(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: TextGenerationTaskParams,
|
||||
) -> str:
|
||||
"""Queue a request for insertion."""
|
||||
self._pending_inserts.append((command_id, task_id, task_params))
|
||||
return ""
|
||||
|
||||
def sync_and_insert_pending(self) -> list[int]:
|
||||
"""Insert all pending requests."""
|
||||
uids: list[int] = []
|
||||
for command_id, task_id, task_params in self._pending_inserts:
|
||||
uid = self._do_insert(command_id, task_id, task_params)
|
||||
uids.append(uid)
|
||||
self._pending_inserts.clear()
|
||||
return uids
|
||||
|
||||
@property
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return len(self._pending_inserts) > 0
|
||||
|
||||
def _do_insert(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: TextGenerationTaskParams | None,
|
||||
) -> int:
|
||||
uid = self._uid_counter
|
||||
self._uid_counter += 1
|
||||
# Track: (command_id, task_id, tokens_generated, max_tokens)
|
||||
max_tokens = (
|
||||
task_params.max_output_tokens if task_params else self._tokens_per_request
|
||||
)
|
||||
self._active_requests[uid] = (command_id, task_id, 0, max_tokens or 3)
|
||||
return uid
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
uids_to_remove: list[int] = []
|
||||
|
||||
for uid, (command_id, task_id, tokens_gen, max_tokens) in list(
|
||||
self._active_requests.items()
|
||||
):
|
||||
tokens_gen += 1
|
||||
finish_reason = "stop" if tokens_gen >= max_tokens else None
|
||||
text = f"token{tokens_gen}"
|
||||
|
||||
if finish_reason:
|
||||
uids_to_remove.append(uid)
|
||||
else:
|
||||
self._active_requests[uid] = (
|
||||
command_id,
|
||||
task_id,
|
||||
tokens_gen,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
results.append(
|
||||
BatchedGenerationResponse(
|
||||
command_id=command_id,
|
||||
task_id=task_id,
|
||||
response=GenerationResponse(
|
||||
token=tokens_gen,
|
||||
text=text,
|
||||
finish_reason=finish_reason,
|
||||
usage=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
for uid in uids_to_remove:
|
||||
del self._active_requests[uid]
|
||||
|
||||
return results
|
||||
|
||||
@property
|
||||
def has_active_requests(self) -> bool:
|
||||
return len(self._active_requests) > 0
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self._active_requests)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending_inserts)
|
||||
|
||||
def sync_completions(self) -> None:
|
||||
pass # Completions already removed in step()
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return False # Non-distributed mode for testing
|
||||
|
||||
|
||||
class MockTokenizer:
|
||||
"""Mock tokenizer with tool calling disabled."""
|
||||
|
||||
tool_parser = None
|
||||
tool_call_start = None
|
||||
tool_call_end = None
|
||||
has_tool_calling = False
|
||||
has_thinking = False
|
||||
|
||||
|
||||
class FakeGroup:
|
||||
"""Fake MLX distributed group for testing."""
|
||||
|
||||
def rank(self) -> int:
|
||||
return 0
|
||||
|
||||
def size(self) -> int:
|
||||
return 1 # Single node (non-distributed)
|
||||
|
||||
|
||||
def make_nothin[T, U, V](res: T):
|
||||
def nothin(*_1: U, **_2: V) -> T:
|
||||
return res
|
||||
|
||||
return nothin
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Patch MLX dependencies and use FakeBatchEngineWithTokens."""
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup()))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
|
||||
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)
|
||||
|
||||
|
||||
class EventCollector:
|
||||
"""Collects events directly into a list to avoid mp_channel flakiness."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.events: list[Event] = []
|
||||
|
||||
def send(self, event: Event) -> None:
|
||||
self.events.append(event)
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
def join(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _run_with_tasks(tasks: list[Task]) -> list[Event]:
|
||||
"""
|
||||
Run tasks through the runner, adding shutdown at the end.
|
||||
|
||||
Tasks are sent in order, with shutdown sent last.
|
||||
The batch engine processes between task handling.
|
||||
"""
|
||||
bound_instance = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
runner_id=RUNNER_1_ID,
|
||||
node_id=NodeId(NODE_A),
|
||||
)
|
||||
|
||||
task_sender, task_receiver = mp_channel[Task]()
|
||||
event_collector = EventCollector()
|
||||
|
||||
shutdown_task = Shutdown(
|
||||
task_id=TaskId("shutdown"),
|
||||
instance_id=INSTANCE_1_ID,
|
||||
runner_id=RUNNER_1_ID,
|
||||
)
|
||||
|
||||
with task_sender:
|
||||
# Send all tasks including shutdown
|
||||
for t in tasks:
|
||||
task_sender.send(t)
|
||||
task_sender.send(shutdown_task)
|
||||
|
||||
# Disable cleanup methods to prevent issues
|
||||
task_receiver.close = lambda: None
|
||||
task_receiver.join = lambda: None
|
||||
|
||||
mlx_runner.main(bound_instance, event_collector, task_receiver) # type: ignore[arg-type]
|
||||
|
||||
return event_collector.events
|
||||
|
||||
|
||||
INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
|
||||
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
|
||||
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)
|
||||
|
||||
|
||||
def make_chat_task(
|
||||
task_id: str, command_id: str, max_tokens: int = 3
|
||||
) -> TextGeneration:
|
||||
return TextGeneration(
|
||||
task_id=TaskId(task_id),
|
||||
command_id=CommandId(command_id),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=MODEL_A_ID,
|
||||
input=[InputMessage(role="user", content="hello")],
|
||||
stream=True,
|
||||
max_output_tokens=max_tokens,
|
||||
),
|
||||
instance_id=INSTANCE_1_ID,
|
||||
)
|
||||
|
||||
|
||||
def test_single_request_generates_tokens(patch_batch_engine: None):
|
||||
"""
|
||||
Verify a single request generates the expected tokens through the batch path.
|
||||
|
||||
Tokens are generated during the generation loop (not during shutdown drain).
|
||||
The task completes after all tokens are generated.
|
||||
"""
|
||||
chat_task = make_chat_task("chat1", "cmd1", max_tokens=3)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
|
||||
|
||||
# Verify ChunkGenerated events are emitted for all tokens
|
||||
chunk_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, ChunkGenerated) and e.command_id == CommandId("cmd1")
|
||||
]
|
||||
assert len(chunk_events) == 3, (
|
||||
f"Expected 3 ChunkGenerated events, got {len(chunk_events)}"
|
||||
)
|
||||
|
||||
# Last chunk should have finish_reason="stop"
|
||||
last_chunk = chunk_events[-1].chunk
|
||||
assert isinstance(last_chunk, TokenChunk)
|
||||
assert last_chunk.finish_reason == "stop"
|
||||
|
||||
# Task should be marked complete after tokens are generated
|
||||
chat_complete = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, TaskStatusUpdated)
|
||||
and e.task_id == TaskId("chat1")
|
||||
and e.task_status == TaskStatus.Complete
|
||||
]
|
||||
assert len(chat_complete) == 1, "Expected exactly one chat task Complete status"
|
||||
|
||||
|
||||
def test_runner_status_reflects_active_requests(patch_batch_engine: None):
|
||||
"""Verify RunnerRunning status includes active_requests count."""
|
||||
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
|
||||
|
||||
# Find RunnerRunning status events
|
||||
running_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, RunnerStatusUpdated)
|
||||
and isinstance(e.runner_status, RunnerRunning)
|
||||
]
|
||||
|
||||
assert len(running_events) > 0, "Expected at least one RunnerRunning event"
|
||||
assert running_events[0].runner_status.active_requests == 1
|
||||
|
||||
|
||||
def test_chat_task_acknowledged(patch_batch_engine: None):
|
||||
"""Verify chat completion task is acknowledged with proper status updates."""
|
||||
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
|
||||
|
||||
# Find the chat task status events
|
||||
chat_running = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, TaskStatusUpdated)
|
||||
and e.task_id == TaskId("chat1")
|
||||
and e.task_status == TaskStatus.Running
|
||||
]
|
||||
|
||||
assert len(chat_running) == 1, "Expected exactly one chat task Running status"
|
||||
|
||||
|
||||
def test_multiple_requests_generate_tokens(patch_batch_engine: None):
|
||||
"""Verify multiple requests each generate their expected tokens."""
|
||||
chat1 = make_chat_task("chat1", "cmd1", max_tokens=2)
|
||||
chat2 = make_chat_task("chat2", "cmd2", max_tokens=2)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat1, chat2])
|
||||
|
||||
# Both requests should generate their expected number of tokens
|
||||
cmd1_chunks = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, ChunkGenerated) and e.command_id == CommandId("cmd1")
|
||||
]
|
||||
cmd2_chunks = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, ChunkGenerated) and e.command_id == CommandId("cmd2")
|
||||
]
|
||||
|
||||
assert len(cmd1_chunks) == 2, f"Expected 2 chunks for cmd1, got {len(cmd1_chunks)}"
|
||||
assert len(cmd2_chunks) == 2, f"Expected 2 chunks for cmd2, got {len(cmd2_chunks)}"
|
||||
|
||||
# Both tasks should be completed
|
||||
completed_task_ids = {
|
||||
e.task_id
|
||||
for e in events
|
||||
if isinstance(e, TaskStatusUpdated)
|
||||
and e.task_status == TaskStatus.Complete
|
||||
and e.task_id in (TaskId("chat1"), TaskId("chat2"))
|
||||
}
|
||||
assert TaskId("chat1") in completed_task_ids
|
||||
assert TaskId("chat2") in completed_task_ids
|
||||
@@ -1,719 +0,0 @@
|
||||
"""
|
||||
Edge-case tests for continuous batching in the runner.
|
||||
|
||||
Tests cover:
|
||||
1. Concurrent requests with overlapping tool calls
|
||||
2. Requests that finish mid-generation with 'length' reason
|
||||
3. Multiple requests finishing on the same step() call
|
||||
4. Batch of 5+ simultaneous completions
|
||||
"""
|
||||
|
||||
# ruff: noqa: E402
|
||||
# pyright: reportAny=false
|
||||
# pyright: reportUnknownArgumentType=false
|
||||
# pyright: reportUnknownMemberType=false
|
||||
# pyright: reportAttributeAccessIssue=false
|
||||
# pyright: reportInvalidTypeVarUse=false
|
||||
# pyright: reportPrivateUsage=false
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import exo.worker.runner.runner as mlx_runner
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.shared.types.chunks import TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
RunnerStatusUpdated,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.shared.types.worker.runners import RunnerReady
|
||||
from exo.utils.channels import mp_channel
|
||||
from exo.worker.engines.mlx.generator.batch_engine import (
|
||||
BatchedGenerationResponse,
|
||||
)
|
||||
from exo.worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
NODE_A,
|
||||
RUNNER_1_ID,
|
||||
)
|
||||
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake batch engines
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScriptedBatchEngine:
|
||||
"""Batch engine driven by scripted per-request token sequences.
|
||||
|
||||
Each request produces a predefined list of (text, finish_reason) pairs.
|
||||
One step() call pops one token per active request.
|
||||
"""
|
||||
|
||||
def __init__(self, *_args: Any, **_kwargs: Any):
|
||||
self._active: dict[
|
||||
int, tuple[CommandId, TaskId, list[tuple[str, FinishReason | None]]]
|
||||
] = {}
|
||||
self._pending: list[tuple[CommandId, TaskId, TextGenerationTaskParams]] = []
|
||||
self._uid = 0
|
||||
self.rank = 0
|
||||
# map command_id -> scripted tokens, set externally before tasks arrive
|
||||
self.scripts: dict[str, list[tuple[str, FinishReason | None]]] = {}
|
||||
|
||||
def queue_request(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: TextGenerationTaskParams,
|
||||
) -> str:
|
||||
self._pending.append((command_id, task_id, task_params))
|
||||
return ""
|
||||
|
||||
def sync_and_insert_pending(self) -> list[int]:
|
||||
uids: list[int] = []
|
||||
for cmd_id, task_id, _params in self._pending:
|
||||
uid = self._uid
|
||||
self._uid += 1
|
||||
script = list(self.scripts.get(str(cmd_id), [("tok", "stop")]))
|
||||
self._active[uid] = (cmd_id, task_id, script)
|
||||
uids.append(uid)
|
||||
self._pending.clear()
|
||||
return uids
|
||||
|
||||
@property
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return bool(self._pending)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending)
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
done: list[int] = []
|
||||
for uid, (cmd_id, task_id, script) in self._active.items():
|
||||
if not script:
|
||||
continue
|
||||
text, finish_reason = script.pop(0)
|
||||
results.append(
|
||||
BatchedGenerationResponse(
|
||||
command_id=cmd_id,
|
||||
task_id=task_id,
|
||||
response=GenerationResponse(
|
||||
token=0, text=text, finish_reason=finish_reason, usage=None
|
||||
),
|
||||
)
|
||||
)
|
||||
if finish_reason is not None:
|
||||
done.append(uid)
|
||||
for uid in done:
|
||||
del self._active[uid]
|
||||
return results
|
||||
|
||||
@property
|
||||
def has_active_requests(self) -> bool:
|
||||
return bool(self._active)
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self._active)
|
||||
|
||||
def sync_completions(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class FakeBatchEngineWithTokens:
|
||||
"""Generates N tokens per request (reused from the main test file)."""
|
||||
|
||||
def __init__(self, *_args: Any, **_kwargs: Any):
|
||||
self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
|
||||
self._pending_inserts: list[
|
||||
tuple[CommandId, TaskId, TextGenerationTaskParams]
|
||||
] = []
|
||||
self._uid_counter = 0
|
||||
self.rank = 0
|
||||
|
||||
def queue_request(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: TextGenerationTaskParams,
|
||||
) -> str:
|
||||
self._pending_inserts.append((command_id, task_id, task_params))
|
||||
return ""
|
||||
|
||||
def sync_and_insert_pending(self) -> list[int]:
|
||||
uids: list[int] = []
|
||||
for command_id, task_id, task_params in self._pending_inserts:
|
||||
uid = self._uid_counter
|
||||
self._uid_counter += 1
|
||||
max_tokens = task_params.max_output_tokens or 3
|
||||
self._active_requests[uid] = (command_id, task_id, 0, max_tokens)
|
||||
uids.append(uid)
|
||||
self._pending_inserts.clear()
|
||||
return uids
|
||||
|
||||
@property
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return bool(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending_inserts)
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
done: list[int] = []
|
||||
for uid, (cmd_id, task_id, tokens_gen, max_tokens) in list(
|
||||
self._active_requests.items()
|
||||
):
|
||||
tokens_gen += 1
|
||||
finish = "stop" if tokens_gen >= max_tokens else None
|
||||
results.append(
|
||||
BatchedGenerationResponse(
|
||||
command_id=cmd_id,
|
||||
task_id=task_id,
|
||||
response=GenerationResponse(
|
||||
token=tokens_gen,
|
||||
text=f"token{tokens_gen}",
|
||||
finish_reason=finish,
|
||||
usage=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
if finish:
|
||||
done.append(uid)
|
||||
else:
|
||||
self._active_requests[uid] = (cmd_id, task_id, tokens_gen, max_tokens)
|
||||
for uid in done:
|
||||
del self._active_requests[uid]
|
||||
return results
|
||||
|
||||
@property
|
||||
def has_active_requests(self) -> bool:
|
||||
return bool(self._active_requests)
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self._active_requests)
|
||||
|
||||
def sync_completions(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock tokenizers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockTokenizer:
|
||||
tool_parser = None
|
||||
tool_call_start = None
|
||||
tool_call_end = None
|
||||
has_tool_calling = False
|
||||
has_thinking = False
|
||||
|
||||
|
||||
class MockToolTokenizer:
|
||||
"""Tokenizer with tool calling enabled for testing."""
|
||||
|
||||
has_tool_calling = True
|
||||
has_thinking = False
|
||||
tool_call_start = "<tool>"
|
||||
tool_call_end = "</tool>"
|
||||
|
||||
@staticmethod
|
||||
def _tool_parser(text: str) -> dict[str, Any]:
|
||||
return json.loads(text)
|
||||
|
||||
|
||||
class FakeGroup:
|
||||
def rank(self) -> int:
|
||||
return 0
|
||||
|
||||
def size(self) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event collector & runner helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EventCollector:
|
||||
def __init__(self) -> None:
|
||||
self.events: list[Event] = []
|
||||
|
||||
def send(self, event: Event) -> None:
|
||||
self.events.append(event)
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
def join(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def make_nothin[T, U, V](res: T):
|
||||
def nothin(*_1: U, **_2: V) -> T:
|
||||
return res
|
||||
|
||||
return nothin
|
||||
|
||||
|
||||
INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
|
||||
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
|
||||
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)
|
||||
SETUP_TASKS: list[Task] = [INIT_TASK, LOAD_TASK, WARMUP_TASK]
|
||||
|
||||
|
||||
def make_chat_task(
|
||||
task_id: str, command_id: str, max_tokens: int = 3
|
||||
) -> TextGeneration:
|
||||
return TextGeneration(
|
||||
task_id=TaskId(task_id),
|
||||
command_id=CommandId(command_id),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=MODEL_A_ID,
|
||||
input=[InputMessage(role="user", content="hello")],
|
||||
stream=True,
|
||||
max_output_tokens=max_tokens,
|
||||
),
|
||||
instance_id=INSTANCE_1_ID,
|
||||
)
|
||||
|
||||
|
||||
def _run_with_tasks(
|
||||
tasks: list[Task],
|
||||
engine_cls: type = FakeBatchEngineWithTokens,
|
||||
tokenizer_cls: type = MockTokenizer,
|
||||
engine_instance: Any | None = None,
|
||||
) -> list[Event]:
|
||||
"""Run tasks through the runner with configurable engine and tokenizer."""
|
||||
bound = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
runner_id=RUNNER_1_ID,
|
||||
node_id=NodeId(NODE_A),
|
||||
)
|
||||
task_sender, task_receiver = mp_channel[Task]()
|
||||
collector = EventCollector()
|
||||
shutdown = Shutdown(
|
||||
task_id=TaskId("shutdown"),
|
||||
instance_id=INSTANCE_1_ID,
|
||||
runner_id=RUNNER_1_ID,
|
||||
)
|
||||
|
||||
import exo.worker.runner.runner as r
|
||||
|
||||
orig_init_mlx = r.initialize_mlx
|
||||
orig_load = r.load_mlx_items
|
||||
orig_warmup = r.warmup_inference
|
||||
orig_check = r._check_for_debug_prompts
|
||||
orig_engine = r.BatchGenerationEngine
|
||||
|
||||
r.initialize_mlx = make_nothin(FakeGroup())
|
||||
r.load_mlx_items = make_nothin((MagicMock(), tokenizer_cls))
|
||||
r.warmup_inference = make_nothin(1)
|
||||
r._check_for_debug_prompts = make_nothin(None)
|
||||
if engine_instance is not None:
|
||||
r.BatchGenerationEngine = lambda *_a, **_kw: engine_instance # pyright: ignore[reportUnknownLambdaType]
|
||||
else:
|
||||
r.BatchGenerationEngine = engine_cls
|
||||
|
||||
try:
|
||||
with task_sender:
|
||||
for t in tasks:
|
||||
task_sender.send(t)
|
||||
task_sender.send(shutdown)
|
||||
task_receiver.close = lambda: None
|
||||
task_receiver.join = lambda: None
|
||||
r.main(bound, collector, task_receiver) # pyright: ignore[reportArgumentType]
|
||||
finally:
|
||||
r.initialize_mlx = orig_init_mlx
|
||||
r.load_mlx_items = orig_load
|
||||
r.warmup_inference = orig_warmup
|
||||
r._check_for_debug_prompts = orig_check
|
||||
r.BatchGenerationEngine = orig_engine
|
||||
|
||||
return collector.events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for querying events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def chunks_for(events: list[Event], command_id: str) -> list[ChunkGenerated]:
|
||||
return [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, ChunkGenerated) and e.command_id == CommandId(command_id)
|
||||
]
|
||||
|
||||
|
||||
def completed_task_ids(events: list[Event]) -> set[TaskId]:
|
||||
return {
|
||||
e.task_id
|
||||
for e in events
|
||||
if isinstance(e, TaskStatusUpdated) and e.task_status == TaskStatus.Complete
|
||||
}
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test 1: Concurrent requests with overlapping tool calls
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_concurrent_tool_calls_and_normal_text():
|
||||
"""Two concurrent requests: one emits normal text, the other a tool call.
|
||||
|
||||
Verifies that:
|
||||
- The normal request produces TokenChunks with its text
|
||||
- The tool-call request produces a ToolCallChunk
|
||||
- Both tasks complete
|
||||
"""
|
||||
engine = ScriptedBatchEngine()
|
||||
# cmd_normal: 2 normal tokens then stop
|
||||
engine.scripts["cmd_normal"] = [
|
||||
("hello", None),
|
||||
(" world", "stop"),
|
||||
]
|
||||
# cmd_tool: tool_start, body, tool_end (suppressed), then finish
|
||||
engine.scripts["cmd_tool"] = [
|
||||
("<tool>", None), # swallowed by tracker
|
||||
('{"name":"get_weather","arguments":{"city":"SF"}}', None), # accumulated
|
||||
("</tool>", None), # triggers ToolCallChunk emission
|
||||
("done", "stop"), # normal trailing token
|
||||
]
|
||||
|
||||
chat_normal = make_chat_task("t_normal", "cmd_normal", max_tokens=100)
|
||||
chat_tool = make_chat_task("t_tool", "cmd_tool", max_tokens=100)
|
||||
|
||||
events = _run_with_tasks(
|
||||
[*SETUP_TASKS, chat_normal, chat_tool],
|
||||
tokenizer_cls=MockToolTokenizer,
|
||||
engine_instance=engine,
|
||||
)
|
||||
|
||||
# Normal request: all chunks should be TokenChunk
|
||||
normal_chunks = chunks_for(events, "cmd_normal")
|
||||
assert len(normal_chunks) == 2
|
||||
assert all(isinstance(c.chunk, TokenChunk) for c in normal_chunks)
|
||||
assert normal_chunks[-1].chunk.finish_reason == "stop"
|
||||
|
||||
# Tool-call request
|
||||
tool_chunks = chunks_for(events, "cmd_tool")
|
||||
# <tool> → swallowed, body → accumulated, </tool> → ToolCallChunk, "done" → TokenChunk
|
||||
tool_call_events = [c for c in tool_chunks if isinstance(c.chunk, ToolCallChunk)]
|
||||
token_events = [c for c in tool_chunks if isinstance(c.chunk, TokenChunk)]
|
||||
|
||||
assert len(tool_call_events) == 1, (
|
||||
f"Expected 1 ToolCallChunk, got {len(tool_call_events)}"
|
||||
)
|
||||
tc_chunk = tool_call_events[0].chunk
|
||||
assert isinstance(tc_chunk, ToolCallChunk)
|
||||
assert tc_chunk.tool_calls[0].name == "get_weather"
|
||||
assert json.loads(tc_chunk.tool_calls[0].arguments) == {"city": "SF"}
|
||||
|
||||
assert len(token_events) == 1, "Expected 1 trailing TokenChunk after tool call"
|
||||
assert token_events[0].chunk.finish_reason == "stop"
|
||||
|
||||
# Both tasks should complete
|
||||
done = completed_task_ids(events)
|
||||
assert TaskId("t_normal") in done
|
||||
assert TaskId("t_tool") in done
|
||||
|
||||
|
||||
def test_tool_call_interrupted_by_finish_reason():
|
||||
"""Tool call in progress when finish_reason fires — partial text emitted."""
|
||||
engine = ScriptedBatchEngine()
|
||||
engine.scripts["cmd1"] = [
|
||||
("<tool>", None),
|
||||
('{"name":"f"', "stop"), # finish while inside tool call
|
||||
]
|
||||
|
||||
chat = make_chat_task("t1", "cmd1", max_tokens=100)
|
||||
events = _run_with_tasks(
|
||||
[*SETUP_TASKS, chat],
|
||||
tokenizer_cls=MockToolTokenizer,
|
||||
engine_instance=engine,
|
||||
)
|
||||
|
||||
chunks = chunks_for(events, "cmd1")
|
||||
assert len(chunks) == 1
|
||||
chunk = chunks[0].chunk
|
||||
assert isinstance(chunk, TokenChunk)
|
||||
# The interrupted tool call should be emitted as raw text
|
||||
assert "<tool>" in chunk.text
|
||||
assert '{"name":"f"' in chunk.text
|
||||
assert chunk.finish_reason == "stop"
|
||||
|
||||
assert TaskId("t1") in completed_task_ids(events)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test 2: Request finishing with 'length' reason (timeout mid-generation)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_request_finishes_with_length_reason():
|
||||
"""Request that hits max_tokens limit and finishes with 'length'."""
|
||||
engine = ScriptedBatchEngine()
|
||||
engine.scripts["cmd1"] = [
|
||||
("tok1", None),
|
||||
("tok2", None),
|
||||
("tok3", "length"), # hit the token limit
|
||||
]
|
||||
|
||||
chat = make_chat_task("t1", "cmd1", max_tokens=100)
|
||||
events = _run_with_tasks(
|
||||
[*SETUP_TASKS, chat],
|
||||
engine_instance=engine,
|
||||
)
|
||||
|
||||
chunks = chunks_for(events, "cmd1")
|
||||
assert len(chunks) == 3
|
||||
|
||||
# Last chunk should have finish_reason="length"
|
||||
assert isinstance(chunks[-1].chunk, TokenChunk)
|
||||
assert chunks[-1].chunk.finish_reason == "length"
|
||||
|
||||
# Earlier chunks should have no finish_reason
|
||||
for c in chunks[:-1]:
|
||||
assert isinstance(c.chunk, TokenChunk)
|
||||
assert c.chunk.finish_reason is None
|
||||
|
||||
assert TaskId("t1") in completed_task_ids(events)
|
||||
|
||||
|
||||
def test_mixed_finish_reasons_across_requests():
|
||||
"""Two requests finishing with different reasons: 'stop' and 'length'."""
|
||||
engine = ScriptedBatchEngine()
|
||||
engine.scripts["cmd_stop"] = [("a", None), ("b", "stop")]
|
||||
engine.scripts["cmd_len"] = [("x", None), ("y", "length")]
|
||||
|
||||
chat1 = make_chat_task("t_stop", "cmd_stop", max_tokens=100)
|
||||
chat2 = make_chat_task("t_len", "cmd_len", max_tokens=100)
|
||||
|
||||
events = _run_with_tasks(
|
||||
[*SETUP_TASKS, chat1, chat2],
|
||||
engine_instance=engine,
|
||||
)
|
||||
|
||||
stop_chunks = chunks_for(events, "cmd_stop")
|
||||
len_chunks = chunks_for(events, "cmd_len")
|
||||
|
||||
assert stop_chunks[-1].chunk.finish_reason == "stop"
|
||||
assert len_chunks[-1].chunk.finish_reason == "length"
|
||||
|
||||
done = completed_task_ids(events)
|
||||
assert TaskId("t_stop") in done
|
||||
assert TaskId("t_len") in done
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test 3: Multiple finish reasons in rapid succession (same step)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_all_requests_finish_on_same_step():
|
||||
"""Three requests that all finish on the same step() call.
|
||||
|
||||
This tests that the runner and _process_generation_results correctly
|
||||
handle multiple completions in a single step.
|
||||
"""
|
||||
engine = ScriptedBatchEngine()
|
||||
# All three produce exactly 1 token and finish
|
||||
engine.scripts["cmd_a"] = [("alpha", "stop")]
|
||||
engine.scripts["cmd_b"] = [("beta", "stop")]
|
||||
engine.scripts["cmd_c"] = [("gamma", "stop")]
|
||||
|
||||
tasks = [
|
||||
*SETUP_TASKS,
|
||||
make_chat_task("ta", "cmd_a", max_tokens=100),
|
||||
make_chat_task("tb", "cmd_b", max_tokens=100),
|
||||
make_chat_task("tc", "cmd_c", max_tokens=100),
|
||||
]
|
||||
events = _run_with_tasks([*tasks], engine_instance=engine)
|
||||
|
||||
for cmd_id, expected_text in [
|
||||
("cmd_a", "alpha"),
|
||||
("cmd_b", "beta"),
|
||||
("cmd_c", "gamma"),
|
||||
]:
|
||||
c = chunks_for(events, cmd_id)
|
||||
assert len(c) == 1, f"Expected 1 chunk for {cmd_id}, got {len(c)}"
|
||||
assert isinstance(c[0].chunk, TokenChunk)
|
||||
assert c[0].chunk.text == expected_text
|
||||
assert c[0].chunk.finish_reason == "stop"
|
||||
|
||||
done = completed_task_ids(events)
|
||||
assert TaskId("ta") in done
|
||||
assert TaskId("tb") in done
|
||||
assert TaskId("tc") in done
|
||||
|
||||
# Runner should reach RunnerReady at least after warmup.
|
||||
# With inline task processing, later requests may be inserted into the
|
||||
# batch before the generation loop exits, so the runner can stay
|
||||
# RunnerRunning until Shutdown without an intermediate RunnerReady.
|
||||
ready_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, RunnerStatusUpdated)
|
||||
and isinstance(e.runner_status, RunnerReady)
|
||||
]
|
||||
assert len(ready_events) >= 1, "Expected RunnerReady at least after warmup"
|
||||
|
||||
|
||||
def test_staggered_completions_in_batch():
|
||||
"""Four requests with different token counts — they complete at different steps.
|
||||
|
||||
Verifies each request gets the right number of chunks and the runner
|
||||
tracks active_requests correctly as requests drain.
|
||||
"""
|
||||
engine = ScriptedBatchEngine()
|
||||
engine.scripts["c1"] = [("a", "stop")] # finishes step 1
|
||||
engine.scripts["c2"] = [("a", None), ("b", "stop")] # finishes step 2
|
||||
engine.scripts["c3"] = [("a", None), ("b", None), ("c", "stop")] # finishes step 3
|
||||
engine.scripts["c4"] = [
|
||||
("a", None),
|
||||
("b", None),
|
||||
("c", None),
|
||||
("d", "stop"),
|
||||
] # finishes step 4
|
||||
|
||||
tasks = [
|
||||
*SETUP_TASKS,
|
||||
make_chat_task("t1", "c1", max_tokens=100),
|
||||
make_chat_task("t2", "c2", max_tokens=100),
|
||||
make_chat_task("t3", "c3", max_tokens=100),
|
||||
make_chat_task("t4", "c4", max_tokens=100),
|
||||
]
|
||||
events = _run_with_tasks([*tasks], engine_instance=engine)
|
||||
|
||||
assert len(chunks_for(events, "c1")) == 1
|
||||
assert len(chunks_for(events, "c2")) == 2
|
||||
assert len(chunks_for(events, "c3")) == 3
|
||||
assert len(chunks_for(events, "c4")) == 4
|
||||
|
||||
done = completed_task_ids(events)
|
||||
for tid in ["t1", "t2", "t3", "t4"]:
|
||||
assert TaskId(tid) in done, f"Task {tid} should be complete"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test 4: Batch of 5+ simultaneous completions
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup()))
|
||||
monkeypatch.setattr(
|
||||
mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MockTokenizer))
|
||||
)
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
|
||||
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)
|
||||
|
||||
|
||||
def test_five_simultaneous_completions(patch_batch_engine: None):
|
||||
"""Five requests submitted together, all generating tokens and completing."""
|
||||
chats = [make_chat_task(f"t{i}", f"cmd{i}", max_tokens=2) for i in range(5)]
|
||||
events = _run_with_tasks([*SETUP_TASKS, *chats])
|
||||
|
||||
for i in range(5):
|
||||
c = chunks_for(events, f"cmd{i}")
|
||||
assert len(c) == 2, f"Expected 2 chunks for cmd{i}, got {len(c)}"
|
||||
assert c[-1].chunk.finish_reason == "stop"
|
||||
|
||||
done = completed_task_ids(events)
|
||||
for i in range(5):
|
||||
assert TaskId(f"t{i}") in done
|
||||
|
||||
|
||||
def test_eight_requests_staggered(patch_batch_engine: None):
|
||||
"""Eight requests with varying token counts, verifying all complete correctly."""
|
||||
chats = [make_chat_task(f"t{i}", f"cmd{i}", max_tokens=i + 1) for i in range(8)]
|
||||
events = _run_with_tasks([*SETUP_TASKS, *chats])
|
||||
|
||||
for i in range(8):
|
||||
c = chunks_for(events, f"cmd{i}")
|
||||
expected = i + 1
|
||||
assert len(c) == expected, (
|
||||
f"Expected {expected} chunks for cmd{i}, got {len(c)}"
|
||||
)
|
||||
assert c[-1].chunk.finish_reason == "stop"
|
||||
|
||||
done = completed_task_ids(events)
|
||||
for i in range(8):
|
||||
assert TaskId(f"t{i}") in done
|
||||
|
||||
# Verify runner transitions back to ready after all requests complete
|
||||
# Find the last RunnerReady before shutdown
|
||||
ready_events = [
|
||||
(idx, e)
|
||||
for idx, e in enumerate(events)
|
||||
if isinstance(e, RunnerStatusUpdated)
|
||||
and isinstance(e.runner_status, RunnerReady)
|
||||
]
|
||||
shutdown_idx = next(
|
||||
idx
|
||||
for idx, e in enumerate(events)
|
||||
if isinstance(e, TaskStatusUpdated)
|
||||
and e.task_id == TaskId("shutdown")
|
||||
and e.task_status == TaskStatus.Running
|
||||
)
|
||||
# There should be a RunnerReady event between generation and shutdown
|
||||
ready_before_shutdown = [idx for idx, _ in ready_events if idx < shutdown_idx]
|
||||
assert len(ready_before_shutdown) >= 1, (
|
||||
"Expected RunnerReady between generation completion and shutdown"
|
||||
)
|
||||
|
||||
|
||||
def test_ten_simultaneous_single_token():
|
||||
"""Ten requests that each produce exactly one token — all finish on step 1."""
|
||||
engine = ScriptedBatchEngine()
|
||||
for i in range(10):
|
||||
engine.scripts[f"cmd{i}"] = [(f"word{i}", "stop")]
|
||||
|
||||
chats = [make_chat_task(f"t{i}", f"cmd{i}", max_tokens=100) for i in range(10)]
|
||||
events = _run_with_tasks([*SETUP_TASKS, *chats], engine_instance=engine)
|
||||
|
||||
for i in range(10):
|
||||
c = chunks_for(events, f"cmd{i}")
|
||||
assert len(c) == 1
|
||||
assert isinstance(c[0].chunk, TokenChunk)
|
||||
assert c[0].chunk.text == f"word{i}"
|
||||
assert c[0].chunk.finish_reason == "stop"
|
||||
|
||||
done = completed_task_ids(events)
|
||||
assert len(done & {TaskId(f"t{i}") for i in range(10)}) == 10
|
||||
@@ -1,12 +1,13 @@
|
||||
# Check tasks are complete before runner is ever ready.
|
||||
import unittest.mock
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
import exo.worker.runner.runner as mlx_runner
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -39,7 +40,6 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import mp_channel
|
||||
from exo.worker.engines.mlx.generator.batch_engine import BatchedGenerationResponse
|
||||
|
||||
from ...constants import (
|
||||
CHAT_COMPLETION_TASK_ID,
|
||||
@@ -116,7 +116,16 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngine)
|
||||
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
|
||||
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
||||
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
|
||||
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
|
||||
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
|
||||
|
||||
def fake_generate(*_1: object, **_2: object):
|
||||
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
|
||||
|
||||
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
|
||||
|
||||
|
||||
# Use a fake event_sender to remove test flakiness.
|
||||
@@ -139,7 +148,6 @@ class MockTokenizer:
|
||||
tool_call_start = None
|
||||
tool_call_end = None
|
||||
has_tool_calling = False
|
||||
has_thinking = False
|
||||
|
||||
|
||||
class MockGroup:
|
||||
@@ -150,70 +158,6 @@ class MockGroup:
|
||||
return 1
|
||||
|
||||
|
||||
class FakeBatchEngine:
|
||||
"""Fake batch engine that generates a single 'hi' token per request."""
|
||||
|
||||
def __init__(self, *_args: object, **_kwargs: object):
|
||||
self._active_requests: dict[int, tuple[CommandId, TaskId]] = {}
|
||||
self._pending_inserts: list[tuple[CommandId, TaskId, object]] = []
|
||||
self._uid_counter = 0
|
||||
self.rank = 0
|
||||
|
||||
def queue_request(
|
||||
self, command_id: CommandId, task_id: TaskId, task_params: object
|
||||
) -> str:
|
||||
self._pending_inserts.append((command_id, task_id, task_params))
|
||||
return ""
|
||||
|
||||
def sync_and_insert_pending(self) -> list[int]:
|
||||
uids: list[int] = []
|
||||
for cmd_id, task_id, _params in self._pending_inserts:
|
||||
uid = self._uid_counter
|
||||
self._uid_counter += 1
|
||||
self._active_requests[uid] = (cmd_id, task_id)
|
||||
uids.append(uid)
|
||||
self._pending_inserts.clear()
|
||||
return uids
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
for _uid, (cmd_id, task_id) in list(self._active_requests.items()):
|
||||
results.append(
|
||||
BatchedGenerationResponse(
|
||||
command_id=cmd_id,
|
||||
task_id=task_id,
|
||||
response=GenerationResponse(
|
||||
token=0, text="hi", finish_reason="stop", usage=None
|
||||
),
|
||||
)
|
||||
)
|
||||
self._active_requests.clear()
|
||||
return results
|
||||
|
||||
def sync_completions(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def has_active_requests(self) -> bool:
|
||||
return bool(self._active_requests)
|
||||
|
||||
@property
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return bool(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self._active_requests)
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _run(tasks: Iterable[Task]):
|
||||
bound_instance = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
@@ -223,6 +167,7 @@ def _run(tasks: Iterable[Task]):
|
||||
)
|
||||
|
||||
task_sender, task_receiver = mp_channel[Task]()
|
||||
_cancel_sender, cancel_receiver = mp_channel[TaskId]()
|
||||
event_sender = EventCollector()
|
||||
|
||||
with task_sender:
|
||||
@@ -233,8 +178,16 @@ def _run(tasks: Iterable[Task]):
|
||||
# this is some c++ nonsense
|
||||
task_receiver.close = nothin
|
||||
task_receiver.join = nothin
|
||||
|
||||
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]
|
||||
with unittest.mock.patch(
|
||||
"exo.worker.runner.runner.mx.distributed.all_gather",
|
||||
make_nothin(mx.array([1])),
|
||||
):
|
||||
mlx_runner.main(
|
||||
bound_instance,
|
||||
event_sender, # pyright: ignore[reportArgumentType]
|
||||
task_receiver,
|
||||
cancel_receiver,
|
||||
)
|
||||
|
||||
return event_sender.events
|
||||
|
||||
@@ -279,22 +232,17 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
TaskAcknowledged(task_id=WARMUP_TASK_ID),
|
||||
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
|
||||
# CHAT TASK: queued, tokens generated, then completed
|
||||
TaskStatusUpdated(
|
||||
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
|
||||
),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
|
||||
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
|
||||
RunnerStatusUpdated(
|
||||
runner_id=RUNNER_1_ID,
|
||||
runner_status=RunnerRunning(active_requests=1),
|
||||
),
|
||||
# Generation loop produces token and completes the task
|
||||
expected_chunk,
|
||||
TaskStatusUpdated(
|
||||
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
|
||||
),
|
||||
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
|
||||
# SHUTDOWN
|
||||
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
|
||||
RunnerStatusUpdated(
|
||||
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
|
||||
@@ -303,6 +251,7 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
TaskStatusUpdated(
|
||||
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
|
||||
),
|
||||
# SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -5,12 +5,13 @@ from typing import Any
|
||||
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
|
||||
from exo.worker.runner.runner import parse_tool_calls
|
||||
from exo.worker.runner.tool_parsers import make_mlx_parser
|
||||
|
||||
|
||||
def _make_responses(
|
||||
texts: list[str],
|
||||
finish_on_last: bool = True,
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""Create a sequence of GenerationResponses from text strings."""
|
||||
for i, text in enumerate(texts):
|
||||
is_last = i == len(texts) - 1
|
||||
@@ -22,10 +23,13 @@ def _make_responses(
|
||||
)
|
||||
|
||||
|
||||
def _dummy_parser(text: str) -> dict[str, Any]:
|
||||
def _dummier_parser(text: str) -> dict[str, Any]:
|
||||
return {"name": "test_fn", "arguments": {"arg": text}}
|
||||
|
||||
|
||||
_dummy_parser = make_mlx_parser("<tool_call>", "</tool_call>", _dummier_parser)
|
||||
|
||||
|
||||
class TestParseToolCalls:
|
||||
"""Tests for parse_tool_calls generator."""
|
||||
|
||||
@@ -35,8 +39,6 @@ class TestParseToolCalls:
|
||||
results = list(
|
||||
parse_tool_calls(
|
||||
_make_responses(texts, finish_on_last=False),
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
_dummy_parser,
|
||||
)
|
||||
)
|
||||
@@ -50,8 +52,6 @@ class TestParseToolCalls:
|
||||
results = list(
|
||||
parse_tool_calls(
|
||||
_make_responses(texts),
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
_dummy_parser,
|
||||
)
|
||||
)
|
||||
@@ -76,9 +76,7 @@ class TestParseToolCalls:
|
||||
results = list(
|
||||
parse_tool_calls(
|
||||
_make_responses(texts, finish_on_last=False),
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
_failing_parser,
|
||||
make_mlx_parser("<tool_call>", "</tool_call>", _failing_parser),
|
||||
)
|
||||
)
|