Compare commits

..

5 Commits

Author SHA1 Message Date
rltakashige
5c8a237940 Handle model timeouts (#1177)
- Add eval with a timeout.
- Add fast synch flag

## Motivation

Because of the experimental FAST SYNCH flag, some models may not work.
This PR catches when this occurs and allows users to specify a run
without fast synch

## Changes

- Adds a flag to enable or disable fast synch (--fast-synch and
--no-fast-synch)
- Adds a heuristic timeout
- Reduces exo_bench default timeout to 10 minutes.

## Why It Works

Heuristic timeout assumes normal loading times on Mac devices (60 +
model size in gb / 5: e.g. DeepSeek takes up to 120 seconds to load on
tensor parallel, and timeout is set to 60 + 120 = 180s.

We could raise this value if necessary.

## Test Plan

### Manual Testing
Catches that GPT OSS fails to load in Tensor RDMA
Can launch with --no-fast-synch flag to launch GPT OSS.

**GPT OSS 20B**
TP with fast synch
<img width="3064" height="456" alt="image"
src="https://github.com/user-attachments/assets/f6e25cd8-8621-4e99-99fe-292ee05c4035"
/>

TP without fast synch
<img width="3098" height="496" alt="image"
src="https://github.com/user-attachments/assets/d36453d9-6686-4cfe-aa7c-a7d458369d4d"
/>
[Note: the performance is really not great as fast synch is off]

(As a sanity check)
PP with fast synch
<img width="3124" height="496" alt="image"
src="https://github.com/user-attachments/assets/e97d4547-c6fa-483d-badb-4b371b900b4c"
/>

PP without fast synch
<img width="3078" height="508" alt="image"
src="https://github.com/user-attachments/assets/b2e20dfd-4b0e-4295-8a92-417dfe745c28"
/>

PP without RDMA
<img width="3070" height="498" alt="image"
src="https://github.com/user-attachments/assets/a8509d68-0aef-4cda-bca5-a67d39a0801e"
/>

TP without RDMA
<img width="3068" height="496" alt="image"
src="https://github.com/user-attachments/assets/b5691429-89f4-4369-bcf2-8fde2ad7154a"
/>
2026-01-16 20:25:12 +00:00
rltakashige
745343c705 Return error responses for Chat Completions (#1173)
- Error chunks
- Use error handling in exo_bench.py

## Motivation

Return when an error occurs so that generation stops. Adding timeouts is
a separate TODO for model loading and chat completions.

## Changes

- Return HTTP exceptions as JSON responses in an OpenAI compatible
format.
- Context manager for generation to catch and return error messages.
- Use error handling in exo_bench.py.

## Test Plan

### Manual Testing
Manually tested that exo_bench returns on failures within and outside
generation

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-16 19:24:37 +00:00
Alex Cheema
5e28664c41 Fix draft release detection (attempt 3) (#1176)
## Motivation

Previous fix still failed in CI. Suspecting permissions issue with
GITHUB_TOKEN not being able to see draft releases via API.

## Changes

1. Add explicit `permissions: contents: write` to the job
2. Use `gh release list` first to check if draft exists (this uses a
different code path that might work better)
3. Add debug echo statements

## Test Plan

Delete v1.0.63 tag and re-push after merging.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 17:26:06 +00:00
Alex Cheema
ae0a804ccb Fix draft release detection query (#1175)
## Motivation

Fixes the draft release detection that failed on the v1.0.63 release
attempt.

## Changes

The jq query was piped to `head -1` which truncated multi-line JSON
output to just `{`, causing the empty check to fail.

Changed to use `first // empty` in jq instead.

## Test Plan

Tested locally:
```bash
GITHUB_REF_NAME="v1.0.63"
gh api repos/exo-explore/exo/releases --jq "[.[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")] | first // empty"
# Returns the full draft release JSON (2711 chars)
```

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 17:05:24 +00:00
Alex Cheema
07cf2c1aa1 Add GitHub releases with Sparkle release notes integration (#1172)
## Motivation

Closes #1140

Currently releases are uploaded to S3 for Sparkle updates but there's no
GitHub Release created, and Sparkle update dialogs don't show release
notes. Users have no visibility into what changed.

## Changes

- Added release workflow documentation comment at top of `build-app.yml`
- Added "Fetch release notes for Sparkle" step that converts markdown
from draft GitHub release to HTML
- Added "Inject release notes into appcast" step that embeds HTML in
appcast.xml with CDATA
- Added "Publish GitHub Release" step that attaches DMG and publishes
the draft

## Why It Works

- Sparkle's `<description>` tag supports HTML wrapped in CDATA for
rendering in update dialogs
- GitHub's markdown API (`/markdown`) converts the release notes to HTML
with proper formatting
- Draft releases allow writing polished notes before the build, then the
workflow publishes them automatically
- The workflow fails if no draft release exists, ensuring release notes
are always provided

## Test Plan

### Manual Testing
1. Create a draft GitHub release for a new tag with markdown release
notes
2. Push the tag to trigger the workflow
3. Verify the GitHub release is published with DMG attached
4. Download appcast.xml from S3 and verify
`<description><![CDATA[...]]></description>` contains HTML
5. Test Sparkle update dialog on macOS to confirm release notes appear

### Automated Testing
No automated tests added - this is CI workflow configuration.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 16:47:33 +00:00
18 changed files with 561 additions and 432 deletions

View File

@@ -1,5 +1,16 @@
name: Build EXO macOS DMG
# Release workflow:
# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
# 3. This workflow builds, signs, and notarizes the DMG
# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
# 5. DMG and appcast.xml are uploaded to S3
# 6. The draft GitHub Release is published with the DMG attached
#
# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
# If no draft exists, a release is auto-created with generated notes.
on:
workflow_dispatch:
push:
@@ -11,8 +22,10 @@ on:
jobs:
build-macos-app:
runs-on: "macos-26"
permissions:
contents: write
env:
SPARKLE_VERSION: 2.8.1
SPARKLE_VERSION: 2.9.0-beta.1
SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
@@ -87,6 +100,52 @@ jobs:
exit 1
fi
- name: Fetch and validate release notes
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Find draft release by name using gh release list (more reliable with default token)
echo "Looking for draft release named '$GITHUB_REF_NAME'..."
DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
if [[ -z "$DRAFT_EXISTS" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
echo "Please create a draft release with release notes before pushing the tag."
exit 1
fi
# Fetch full release details via API to get body and ID
echo "Found draft release, fetching details..."
RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
# Extract release notes
NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "Draft release has no notes (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: Draft release exists but has no release notes"
echo "Please add release notes to the draft release before pushing the tag."
exit 1
fi
# Save release ID for later publishing
RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
echo "$NOTES" > /tmp/release_notes.md
echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
# ============================================================
# Install dependencies
# ============================================================
@@ -304,6 +363,28 @@ jobs:
$CHANNEL_FLAG \
.
- name: Inject release notes into appcast
if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
env:
RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
run: |
# Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
export NOTES=$(cat "$RELEASE_NOTES_FILE")
# Insert description after the enclosure tag for this version
awk '
/<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
print
print " <description sparkle:format=\"markdown\"><![CDATA["
print ENVIRON["NOTES"]
print " ]]></description>"
next
}
{ print }
' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
echo "Injected markdown release notes for version $RELEASE_VERSION"
# ============================================================
# Upload artifacts
# ============================================================
@@ -336,3 +417,26 @@ jobs:
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
fi
- name: Publish GitHub Release
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
# Update the draft release with the tag and upload DMG
gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
-f tag_name="$GITHUB_REF_NAME" \
-F draft=false
gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
echo "Published release $GITHUB_REF_NAME with DMG attached"
else
# Alpha without draft release - create one with auto-generated notes
gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
--title "$GITHUB_REF_NAME" \
--generate-notes \
--prerelease
echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
fi

View File

@@ -585,7 +585,7 @@
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.8.1;
minimumVersion = 2.9.0-beta.1;
};
};
/* End XCRemoteSwiftPackageReference section */

View File

@@ -6,8 +6,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/sparkle-project/Sparkle.git",
"state" : {
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
"version" : "2.8.1"
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
"version" : "2.9.0-beta.1"
}
}
],

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import argparse
import contextlib
import http.client
import json
import os
@@ -26,7 +27,7 @@ class ExoHttpError(RuntimeError):
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
@@ -104,22 +105,46 @@ def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
@@ -299,6 +324,12 @@ def main() -> int:
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
@@ -320,7 +351,7 @@ def main() -> int:
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
@@ -399,7 +430,7 @@ def main() -> int:
):
continue
if 0 < n <= args.max_nodes:
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
if not selected:
@@ -441,7 +472,13 @@ def main() -> int:
)
client.request_json("POST", "/instance", body={"instance": instance})
wait_for_instance_ready(client, instance_id)
try:
wait_for_instance_ready(client, instance_id)
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize placement: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
continue
time.sleep(1)

View File

@@ -205,6 +205,14 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
anyio.run(node.run)
logger.info("EXO Shutdown complete")
@@ -218,6 +226,7 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
def parse(cls) -> Self:
@@ -259,6 +268,20 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",
action="store_true",
dest="fast_synch",
default=None,
help="Force MLX FAST_SYNCH on (for JACCL backend)",
)
fast_synch_group.add_argument(
"--no-fast-synch",
action="store_false",
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -1,13 +1,14 @@
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import cast
import anyio
from anyio import create_task_group
from anyio import BrokenResourceError, create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
@@ -29,6 +30,8 @@ from exo.shared.types.api import (
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
GenerationStats,
ModelList,
@@ -49,7 +52,12 @@ from exo.shared.types.commands import (
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
IndexedEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
@@ -115,6 +123,7 @@ class API:
self.paused_ev: anyio.Event = anyio.Event()
self.app = FastAPI()
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
@@ -145,6 +154,20 @@ class API:
self.paused_ev.set()
self.paused_ev = anyio.Event()
def _setup_exception_handlers(self) -> None:
@self.app.exception_handler(HTTPException)
async def http_exception_handler( # pyright: ignore[reportUnusedFunction]
_: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
def _setup_cors(self) -> None:
self.app.add_middleware(
CORSMiddleware,
@@ -406,6 +429,18 @@ class API:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
@@ -426,6 +461,12 @@ class API:
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
@@ -463,6 +504,12 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
@@ -607,14 +654,14 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if (
isinstance(event, ChunkGenerated)
and event.command_id in self._chat_completion_queues
):
if isinstance(event, ChunkGenerated):
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -49,22 +49,20 @@ def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def _assign_layers_by_ram(
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
) -> ShardAssignments:
"""Assign layers proportionally based on available RAM."""
):
cycle_memory = sum(
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
cycle_memory = sum(
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
layers_assigned = 0
for i, node in enumerate(selected_cycle):
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
@@ -79,6 +77,7 @@ def _assign_layers_by_ram(
node_layers = max(1, node_layers)
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_meta=model_meta,
device_rank=i,
@@ -87,143 +86,18 @@ def _assign_layers_by_ram(
end_layer=layers_assigned + node_layers,
n_layers=total_layers,
)
runner_to_shard[runner_id] = shard
node_to_runner[node.node_id] = runner_id
layers_assigned += node_layers
return ShardAssignments(
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
def _reserve_base_layers(world_size: int, total_layers: int) -> dict[int, int]:
"""Reserve 1 layer per node to ensure connectivity."""
assignments = {i: 0 for i in range(world_size)}
remaining_layers = total_layers
for i in range(world_size):
assignments[i] = 1
remaining_layers -= 1
if remaining_layers < 0:
logger.warning(
"Fewer layers than nodes! Reducing to 1 layer per node where possible."
)
assignments = {i: 1 if i < total_layers else 0 for i in range(world_size)}
remaining_layers = 0
return assignments
def _distribute_layers_by_bandwidth(
selected_cycle: list[NodeWithProfile],
assignments: dict[int, int],
remaining_layers: int,
model_meta: ModelMetadata,
) -> None:
"""Distribute remaining layers based on bandwidth and RAM capacity."""
indexed_nodes = list(enumerate(selected_cycle))
sorted_nodes = sorted(
indexed_nodes,
key=lambda x: x[1].node_profile.memory_bandwidth or 0,
reverse=True,
)
for original_idx, node in sorted_nodes:
if remaining_layers <= 0:
break
layer_size_bytes = model_meta.storage_size.in_bytes / model_meta.n_layers
max_layers_by_ram = int(
node.node_profile.memory.ram_available.in_bytes // layer_size_bytes
)
can_take = max(0, max_layers_by_ram - assignments[original_idx])
take = min(can_take, remaining_layers)
assignments[original_idx] += take
remaining_layers -= take
if remaining_layers > 0:
logger.warning(
"All nodes maxed out on RAM estimation, dumping remaining layers on fastest nodes."
)
for original_idx, _ in sorted_nodes:
assignments[original_idx] += 1
remaining_layers -= 1
if remaining_layers == 0:
break
def _create_shard_assignments(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
assignments: dict[int, int],
) -> ShardAssignments:
"""Create shard assignments from layer assignments."""
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
current_start = 0
for i, node in enumerate(selected_cycle):
count = assignments[i]
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_meta=model_meta,
device_rank=i,
world_size=world_size,
start_layer=current_start,
end_layer=current_start + count,
n_layers=model_meta.n_layers,
)
runner_to_shard[runner_id] = shard
node_to_runner[node.node_id] = runner_id
current_start += count
return ShardAssignments(
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
def _assign_layers_by_bandwidth(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
) -> ShardAssignments:
"""Assign layers based on memory bandwidth."""
logger.info("Using bandwidth-aware shard assignment")
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
assignments = _reserve_base_layers(world_size, total_layers)
remaining_layers = total_layers - sum(assignments.values())
if remaining_layers > 0:
_distribute_layers_by_bandwidth(
selected_cycle, assignments, remaining_layers, model_meta
)
return _create_shard_assignments(model_meta, selected_cycle, assignments)
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
has_bandwidth = all(
node.node_profile.memory_bandwidth is not None for node in selected_cycle
)
if not has_bandwidth:
logger.info(
"Bandwidth data missing for some nodes, falling back to RAM-proportional assignment"
)
return _assign_layers_by_ram(model_meta, selected_cycle)
return _assign_layers_by_bandwidth(model_meta, selected_cycle)
return shard_assignments
def get_shard_assignments_for_tensor_parallel(

View File

@@ -0,0 +1,107 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.worker.tests.constants import MODEL_A_ID
def test_http_exception_handler_formats_openai_style() -> None:
"""Test that HTTPException is converted to OpenAI-style error format."""
from exo.master.api import API
app = FastAPI()
# Setup exception handler
api = object.__new__(API)
api.app = app
api._setup_exception_handlers() # pyright: ignore[reportPrivateUsage]
# Add test routes that raise HTTPException
@app.get("/test-error")
async def _test_error() -> None:
raise HTTPException(status_code=500, detail="Test error message")
@app.get("/test-not-found")
async def _test_not_found() -> None:
raise HTTPException(status_code=404, detail="Resource not found")
client = TestClient(app)
# Test 500 error
response = client.get("/test-error")
assert response.status_code == 500
data: dict[str, Any] = response.json()
assert "error" in data
assert data["error"]["message"] == "Test error message"
assert data["error"]["type"] == "Internal Server Error"
assert data["error"]["code"] == 500
# Test 404 error
response = client.get("/test-not-found")
assert response.status_code == 404
data = response.json()
assert "error" in data
assert data["error"]["message"] == "Resource not found"
assert data["error"]["type"] == "Not Found"
assert data["error"]["code"] == 404
def test_finish_reason_includes_error() -> None:
valid_reasons = get_args(FinishReason)
assert "error" in valid_reasons
def test_token_chunk_with_error_fields() -> None:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="",
token_id=0,
finish_reason="error",
error_message="Something went wrong",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Something went wrong"
def test_token_chunk_without_error() -> None:
chunk = TokenChunk(
idx=1,
model=MODEL_A_ID,
text="Hello",
token_id=42,
finish_reason=None,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
def test_error_response_construction() -> None:
error_response = ErrorResponse(
error=ErrorInfo(
message="Generation failed",
type="InternalServerError",
code=500,
)
)
assert error_response.error.message == "Generation failed"
assert error_response.error.code == 500
def test_normal_finish_reasons_still_work() -> None:
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="done",
token_id=100,
finish_reason=reason, # type: ignore[arg-type]
)
assert chunk.finish_reason == reason

View File

@@ -397,106 +397,3 @@ def test_get_mlx_jaccl_coordinators(
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
def test_get_shard_assignments_bandwidth_aware(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
# Create nodes with identical RAM (plenty of it)
# Using 1GB to ensure no RAM constraints (model is small)
node_a = create_node(1024 * 1024 * 1024, node_a_id)
node_b = create_node(1024 * 1024 * 1024, node_b_id)
node_c = create_node(1024 * 1024 * 1024, node_c_id)
# Set Bandwidths: A=400 (Fastest), B=200, C=100 (Slowest)
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
node_a.node_profile.memory_bandwidth = 400_000_000_000
node_b.node_profile.memory_bandwidth = 200_000_000_000
node_c.node_profile.memory_bandwidth = 100_000_000_000
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
# Needs full cycle edges for get_cycles/get_shard_assignments if strict?
# Actually get_cycles just looks for cycles.
# But let's follow the pattern of other tests if they add bidirectional.
# checking test_filter_cycles_by_memory, it adds both directions.
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=30, # 30 layers
storage_size=Memory.from_kb(
300
), # 10KB per layer. Nodes have 100MB RAM (100*1024 in create_node usually means KB? other tests use 1000*1024).
# create_node arg is likely KB or Bytes.
# test_filter_cycles_by_memory: create_node(1000 * 1024, ...) -> Memory.from_bytes(1) passes.
# Let's assume create_node takes Bytes or KB consistently.
# If I give 100*1024*1024 bytes = 100MB.
# Model storage = 300KB.
# So capacity is definitely not an issue.
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
# Depending on how get_cycles works and order of addition, we might get multiple cycles.
# filtering by memory usually done in master.
# Here we just pick one.
selected_cycle = cycles[0]
# act
shard_assignments = get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline
)
# assert
runner_id_a = shard_assignments.node_to_runner[node_a_id]
runner_id_b = shard_assignments.node_to_runner[node_b_id]
runner_id_c = shard_assignments.node_to_runner[node_c_id]
# Get layer counts
layers_a = (
shard_assignments.runner_to_shard[runner_id_a].end_layer
- shard_assignments.runner_to_shard[runner_id_a].start_layer
)
layers_b = (
shard_assignments.runner_to_shard[runner_id_b].end_layer
- shard_assignments.runner_to_shard[runner_id_b].start_layer
)
layers_c = (
shard_assignments.runner_to_shard[runner_id_c].end_layer
- shard_assignments.runner_to_shard[runner_id_c].start_layer
)
# Check total
assert layers_a + layers_b + layers_c == 30
# Check that the fastest node (A with 400GB/s) gets saturated first.
# With strict greedy assignment and plenty of RAM:
# 1. Reserve: A=1, B=1, C=1. Remaining=27.
# 2. Sort: [A, B, C]
# 3. A takes min(remaining=27, capacity=huge) = 27.
# 4. A=28, B=1, C=1.
assert layers_a == 28
assert layers_b == 1
assert layers_c == 1

View File

@@ -11,10 +11,21 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call"
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
]
class ErrorInfo(BaseModel):
message: str
type: str
param: str | None = None
code: int
class ErrorResponse(BaseModel):
error: ErrorInfo
class ModelListModel(BaseModel):
id: str
object: str = "model"

View File

@@ -22,6 +22,7 @@ class TokenChunk(BaseChunk):
token_id: int
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
error_message: str | None = None
class ImageChunk(BaseChunk):

View File

@@ -57,7 +57,6 @@ class NodePerformanceProfile(CamelCaseModel):
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
memory_bandwidth: int | None = None
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile

View File

@@ -2,7 +2,9 @@ import json
import os
import resource
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -82,6 +84,45 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
)
class ModelLoadingTimeoutError(Exception):
pass
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -188,7 +229,9 @@ def initialize_mlx(
def load_mlx_items(
bound_instance: BoundInstance, group: Group | None
bound_instance: BoundInstance,
group: Group | None,
on_timeout: TimeoutCallback | None = None,
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
@@ -202,7 +245,9 @@ def load_mlx_items(
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
@@ -216,6 +261,7 @@ def load_mlx_items(
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
@@ -252,7 +298,15 @@ def shard_and_load(
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
mx.eval(model.parameters())
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb / 5
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
# TODO: Do we need this?
mx.eval(model)

View File

@@ -17,15 +17,23 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
fast_synch_override != "off"
and (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
)
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
global logger
logger = _logger
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main

View File

@@ -1,6 +1,8 @@
import time
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache
from typing import cast
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -13,6 +15,7 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -20,6 +23,7 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -48,6 +52,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
@@ -57,6 +62,33 @@ from exo.worker.engines.mlx.utils_mlx import (
from exo.worker.runner.bootstrap import logger
@contextmanager
def send_error_chunk_on_exception(
event_sender: MpSender[Event],
command_id: CommandId,
model_id: ModelId,
device_rank: int,
):
try:
yield
except Exception as e:
logger.error(e)
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
@@ -118,7 +150,20 @@ def main(
)
)
model, tokenizer = load_mlx_items(bound_instance, group)
def on_model_load_timeout() -> None:
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id,
runner_status=RunnerFailed(
error_message="Model loading timed out"
),
)
)
time.sleep(0.5)
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
current_status = RunnerLoaded()
logger.info("runner loaded")
@@ -135,7 +180,7 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
model=cast(Model, model),
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
@@ -148,8 +193,6 @@ def main(
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert model
assert tokenizer
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -158,41 +201,47 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
with send_error_chunk_on_exception(
event_sender,
command_id,
shard_metadata.model_meta.model_id,
shard_metadata.device_rank,
):
assert model
assert tokenizer
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=cast(Model, model),
tokenizer=tokenizer,
task=task_params,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# TODO: Add tool call parser here
# TODO: Add tool call parser here
for response in mlx_generator:
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
for response in mlx_generator:
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
)
# case TokenizedResponse():
# TODO: something here ig
current_status = RunnerReady()
logger.info("runner ready")

View File

@@ -0,0 +1,50 @@
# pyright: reportAny=false
from unittest.mock import MagicMock
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import ChunkGenerated
from exo.worker.runner.runner import send_error_chunk_on_exception
from exo.worker.tests.constants import MODEL_A_ID
def test_send_error_chunk_on_exception_no_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
_ = 1 + 1
event_sender.send.assert_not_called()
def test_send_error_chunk_on_exception_catches_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
raise ValueError("test error")
event_sender.send.assert_called_once()
call_args = event_sender.send.call_args[0][0]
assert isinstance(call_args, ChunkGenerated)
assert call_args.command_id == command_id
assert isinstance(call_args.chunk, TokenChunk)
assert call_args.chunk.finish_reason == "error"
assert call_args.chunk.error_message == "test error"
def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=1
):
raise ValueError("test error")
event_sender.send.assert_not_called()

View File

@@ -4,7 +4,6 @@ import platform
from typing import Any, Callable, Coroutine
import anyio
from anyio import to_thread
from loguru import logger
from exo.shared.types.memory import Memory
@@ -25,61 +24,8 @@ from .system_info import (
get_friendly_name,
get_model_and_chip,
get_network_interfaces,
profile_memory_bandwidth,
)
# Module-level cache for memory bandwidth (doesn't change at runtime)
_cached_bandwidth: int | None = None
_bandwidth_profiled: bool = False
_bandwidth_profiling_task: asyncio.Task[int | None] | None = None
async def profile_bandwidth_once() -> int | None:
"""Profile bandwidth once in a background thread and cache the result.
This function is non-blocking - it runs the profiling in a thread pool.
Subsequent calls return the cached result immediately.
"""
global _cached_bandwidth, _bandwidth_profiled, _bandwidth_profiling_task
# Already profiled, return cached value
if _bandwidth_profiled:
return _cached_bandwidth
# Profiling already in progress, wait for it
if _bandwidth_profiling_task is not None:
return await _bandwidth_profiling_task
# Start profiling in background thread
async def _do_profile() -> int | None:
global _cached_bandwidth, _bandwidth_profiled
try:
logger.info("Starting memory bandwidth profiling in background thread...")
bandwidth = await to_thread.run_sync(profile_memory_bandwidth, cancellable=True)
_cached_bandwidth = bandwidth
_bandwidth_profiled = True
if bandwidth:
logger.info(f"Memory bandwidth profiled: {bandwidth / 1e9:.1f} GB/s")
else:
logger.warning("Memory bandwidth profiling returned None")
return bandwidth
except Exception as e:
logger.opt(exception=e).error("Memory bandwidth profiling failed")
_bandwidth_profiled = True # Mark as done to avoid retrying
return None
_bandwidth_profiling_task = asyncio.create_task(_do_profile())
return await _bandwidth_profiling_task
def get_memory_bandwidth_cached() -> int | None:
"""Return cached bandwidth or None if not yet profiled.
This is a non-blocking synchronous function that returns immediately.
Call profile_bandwidth_once() first to trigger profiling.
"""
return _cached_bandwidth if _bandwidth_profiled else None
async def get_metrics_async() -> Metrics | None:
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
@@ -125,8 +71,6 @@ async def start_polling_node_metrics(
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
poll_interval_s = 1.0
bandwidth_profile_started = False
while True:
try:
metrics = await get_metrics_async()
@@ -141,15 +85,6 @@ async def start_polling_node_metrics(
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
memory_profile = get_memory_profile()
# Start bandwidth profiling in background on first poll (non-blocking)
if not bandwidth_profile_started:
bandwidth_profile_started = True
# Fire and forget - don't await, let it run in background
asyncio.create_task(profile_bandwidth_once())
# Use cached bandwidth (None until profiling completes)
memory_bandwidth = get_memory_bandwidth_cached()
await callback(
NodePerformanceProfile(
model_id=model_id,
@@ -157,7 +92,6 @@ async def start_polling_node_metrics(
friendly_name=friendly_name,
network_interfaces=network_interfaces,
memory=memory_profile,
memory_bandwidth=memory_bandwidth,
system=SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,

View File

@@ -1,6 +1,5 @@
import socket
import sys
import time
from subprocess import CalledProcessError
import psutil
@@ -82,68 +81,3 @@ async def get_model_and_chip() -> tuple[str, str]:
chip = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
return (model, chip)
def profile_memory_bandwidth() -> int | None:
"""
Profile device memory bandwidth using MLX GPU operations.
Uses a large array copy on the GPU to measure unified memory bandwidth.
Returns measured bandwidth in bytes/second, or None if MLX is unavailable.
"""
try:
import mlx.core as mx
if not mx.metal.is_available():
return None
# Use 2GB buffer to better saturate memory bandwidth
# Use 2D shape to avoid potential issues with very large 1D arrays
size_bytes = 2 * 1024 * 1024 * 1024
side = int((size_bytes // 4) ** 0.5) # Square 2D array of float32
shape = (side, side)
actual_bytes = side * side * 4
bytes_transferred = actual_bytes * 2 # read + write
# Warm-up: run the full benchmark operation multiple times to stabilize GPU
for _ in range(3):
src = mx.random.uniform(shape=shape, dtype=mx.float32)
mx.eval(src)
dst = src + 0.0
mx.eval(dst)
mx.synchronize()
del src, dst
# Benchmark: measure time to copy array
best_bandwidth = 0.0
num_runs = 4
for _ in range(num_runs):
src = mx.random.uniform(shape=shape, dtype=mx.float32)
mx.eval(src)
mx.synchronize()
# Time the copy operation (src + 0.0 forces read of src, write of dst)
start = time.perf_counter()
dst = src + 0.0
mx.eval(dst)
mx.synchronize()
end = time.perf_counter()
bandwidth = bytes_transferred / (end - start)
best_bandwidth = max(best_bandwidth, bandwidth)
del src, dst
return int(best_bandwidth)
except Exception:
return None
def get_memory_bandwidth(_chip_id: str) -> int | None:
"""
Returns measured memory bandwidth in bytes/second.
Uses MLX GPU operations for accurate unified memory bandwidth measurement.
"""
return profile_memory_bandwidth()