Compare commits

...

10 Commits

Author SHA1 Message Date
Alex Cheema
ccf4d91d55 add script to reproduce GPU lock issue
Repeatedly sends chat completions to Llama-3.2-1B-Instruct-4bit on a
single node and detects when a request stalls for >5s.
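The stall-detection idea in the script can be sketched as follows (a minimal illustration under assumptions, not the script itself; `timed_call` and the lambda stand in for the real chat-completion request):

```python
import time

STALL_THRESHOLD_S = 5.0

def timed_call(fn):
    """Run fn, returning its result, elapsed time, and whether it stalled."""
    t0 = time.perf_counter()
    result = fn()
    elapsed = time.perf_counter() - t0
    return result, elapsed, elapsed > STALL_THRESHOLD_S

# Stand-in for a chat-completion request; a trivial call finishes instantly:
result, elapsed, stalled = timed_call(lambda: "ok")
print(stalled)
```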

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 09:44:07 -08:00
Ryuichi Leo Takashige
811a4d80bd update mlx to resolve event leak 2026-02-24 13:03:55 +00:00
rltakashige
2fe689315b download .model files in exo bench (#1607)
## Motivation

exo bench failed again for Kimi on a machine that had never downloaded the model: the `*.model` tokenizer files were missing from the `allow_patterns` passed to `snapshot_download`.

## Test Plan

### Manual Testing
Re-ran exo bench; it worked this time.
2026-02-24 11:13:04 +00:00
Alex Cheema
644c5573ce fix: improve text contrast on downloads page (#1601)
## Summary
- Bumps opacity on dark-grey text/icons on the downloads page that were
nearly invisible against the dark background
- Informational text (GB sizes, speeds, disk free, model IDs) → full
`text-exo-light-gray`
- Interactive icons (delete, resume, retry) → `/70` at rest
- Hover-only elements (download button) → `/60`

## Test plan
- [ ] Open http://localhost:52415/downloads with models downloaded
- [ ] Verify GB downloaded amounts are clearly visible
- [ ] Verify delete trash icons are visible (not just on hover)
- [ ] Verify download speed text is readable
- [ ] Verify "paused" labels and resume buttons are visible
- [ ] Verify "X GB free" disk labels in column headers are readable

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 10:42:49 +00:00
rltakashige
12c3015f52 fix qwen moe tensor sharding (#1604)
2026-02-23 21:23:27 +00:00
rltakashige
365dd68d9a Final fixes for release (#1603)
2026-02-23 21:10:15 +00:00
Alex Cheema
d3d129581e test: verify instance deletion cancels ongoing tasks (#1508)
## Summary
- The cancellation logic for issue #1215 already exists in
`get_transition_events()` (`src/exo/master/placement.py:208-227`) — when
an instance is deleted, `TaskStatusUpdated(Cancelled)` events are
emitted for all Pending/Running tasks on that instance
- Combined with PR #1276's token-boundary cancellation in runners, the
full pipeline works end-to-end
- However, the existing test
`test_get_transition_events_delete_instance` passed `{}` for tasks, so
this path was never exercised
- This PR adds 4 tests covering the cancellation behavior:
  - Running tasks are cancelled on instance deletion
  - Pending tasks are cancelled on instance deletion
  - Completed/Failed/TimedOut/Cancelled tasks are left alone
  - Only tasks matching the deleted instance are cancelled

Closes #1215
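
The behavior under test can be sketched with simplified stand-ins for the real event and task types (the actual logic lives in `get_transition_events()` in `src/exo/master/placement.py`):

```python
from dataclasses import dataclass
from enum import Enum, auto

class TaskStatus(Enum):
    Pending = auto()
    Running = auto()
    Complete = auto()
    Cancelled = auto()

@dataclass
class Task:
    task_id: str
    instance_id: str
    status: TaskStatus

def cancel_events_for_deleted(deleted_instance_id: str, tasks: dict[str, Task]):
    """Emit a Cancelled update for every Pending/Running task on the deleted instance."""
    return [
        (t.task_id, TaskStatus.Cancelled)
        for t in tasks.values()
        if t.instance_id == deleted_instance_id
        and t.status in (TaskStatus.Pending, TaskStatus.Running)
    ]

tasks = {
    "t1": Task("t1", "A", TaskStatus.Running),
    "t2": Task("t2", "B", TaskStatus.Running),   # different instance: untouched
    "t3": Task("t3", "A", TaskStatus.Complete),  # terminal status: untouched
}
print(cancel_events_for_deleted("A", tasks))  # only t1 is cancelled
```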

## Test plan
- [x] `uv run pytest src/exo/master/tests/test_placement.py -v` — all 15
tests pass
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 20:12:23 +00:00
Alex Cheema
c90a0cec78 fix: suppress closure errors in runnersupervisor and force spawn start method (#1547)
Some errors can be thrown during shutdown; these can safely be dismissed.
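
The pattern used here wraps each close in `contextlib.suppress` so that a resource that is already closed does not abort the rest of teardown. A minimal sketch with a hypothetical stand-in channel type (the real code suppresses anyio's `ClosedResourceError`):

```python
import contextlib

class ClosedResourceError(Exception):
    """Stand-in for the error raised when closing an already-closed resource."""

class Channel:
    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        if self.closed:
            raise ClosedResourceError()
        self.closed = True

ch = Channel()
ch.close()
# A second close during shutdown would normally raise; suppress it so the
# remaining teardown steps still run:
with contextlib.suppress(ClosedResourceError):
    ch.close()
print(ch.closed)
```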

co-authored by me :)
2026-02-23 18:30:41 +00:00
Alex Cheema
e8c1337168 fix: add download/resume buttons to pending downloads (#1581)
## Summary
- Adds a **resume button** (download icon) to paused pending downloads
(those with partial progress)
- Adds a **download button** to not-started pending downloads
- Both buttons call the existing `startDownload()` function which
handles both new downloads and resuming partial ones
- Previously, paused downloads only showed a "paused" label with no
action, and not-started downloads showed "..." with no way to trigger
them

## Test plan
- [ ] Build dashboard (`cd dashboard && npm run build`)
- [ ] Start exo, navigate to Downloads tab
- [ ] Verify paused downloads show a clickable resume (download arrow)
icon below the progress bar
- [ ] Verify not-started pending downloads show a clickable download
icon
- [ ] Click both button types and confirm downloads start/resume

> Note: Screenshot could not be captured because the dashboard requires
the exo backend API to render, and exo has a pre-existing
`Keypair.generate()` startup bug on main.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 18:20:35 +00:00
Alex Cheema
7024ddcf3e fix: detect completed downloads by checking final file exists (#1582)
## Summary

Split from #1547 per review feedback.

When scanning existing download status, `get_downloaded_size()` returns
bytes from either the final file or its `.partial` counterpart — so a
`.partial` file with all bytes downloaded (e.g. process killed before
hash verification and rename) could be falsely reported as complete.

The previous approach (commit 3b54e7df) added a byte-comparison fallback
in the coordinator (`downloaded >= total > 0`), but this suffered from
the same `.partial` conflation issue.

**Fix:** Check whether the final (non-`.partial`) file actually exists
on disk before marking status as `"complete"` in `download_utils.py`.
This is the only reliable signal that hash verification passed and the
rename from `.partial` succeeded. The coordinator-level byte comparison
is removed since the source now reports correctly.

### Changes
- `download_utils.py`: Add `final_file_exists` check — only report
`"complete"` when the renamed, hash-verified file exists with matching
size
- `coordinator.py`: Revert to simple `progress.status == "complete"`
check, removing the byte-comparison fallback

**Note:** The corresponding byte-comparison workaround in #1547 should
also be removed.
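
The core of the fix can be illustrated with a small self-contained sketch (hypothetical helper names; the real logic lives in `download_utils.py`):

```python
import os
import tempfile

def get_downloaded_size(path: str) -> int:
    """Return bytes from the final file or its .partial counterpart (the conflation)."""
    for candidate in (path, path + ".partial"):
        if os.path.exists(candidate):
            return os.path.getsize(candidate)
    return 0

def status_for(path: str, expected_size: int) -> str:
    downloaded = get_downloaded_size(path)
    # The fix: require the renamed, hash-verified final file to exist on disk.
    final_file_exists = os.path.exists(path)
    return "complete" if final_file_exists and downloaded == expected_size else "not_started"

with tempfile.TemporaryDirectory() as d:
    final = os.path.join(d, "weights.safetensors")
    # All bytes present, but only as .partial (process killed before rename):
    with open(final + ".partial", "wb") as f:
        f.write(b"x" * 10)
    status = status_for(final, 10)
    print(status)  # not reported complete despite the matching byte count
```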

## Test plan
- [x] basedpyright — 0 errors
- [x] ruff check — passes
- [x] pytest — 218 passed, 1 skipped (1 pre-existing Rust bindings
failure)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 18:12:07 +00:00
15 changed files with 708 additions and 68 deletions

View File

@@ -75,7 +75,7 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
model_path = Path(
snapshot_download(
model_id,
allow_patterns=["*.json", "*.py", "*.tiktoken"],
allow_patterns=["*.json", "*.py", "*.tiktoken", "*.model"],
)
)

View File

@@ -412,7 +412,7 @@
<div>{col.label}</div>
{#if col.diskAvailable != null}
<div
class="text-[9px] text-exo-light-gray/60 normal-case tracking-normal mt-0.5"
class="text-[9px] text-white/70 normal-case tracking-normal mt-0.5"
>
{formatBytes(col.diskAvailable)} free
</div>
@@ -436,7 +436,7 @@
</div>
{#if row.prettyName}
<div
class="text-[10px] text-exo-light-gray/60"
class="text-[10px] text-white/60"
title={row.modelId}
>
{row.modelId}
@@ -450,7 +450,7 @@
title="View model details"
>
<svg
class="w-4 h-4 text-white/30 hover:text-white/60"
class="w-4 h-4 text-white/60 hover:text-white/80"
viewBox="0 0 24 24"
fill="currentColor"
>
@@ -469,11 +469,11 @@
<td class="px-4 py-3 text-center align-middle">
{#if cell.kind === "completed"}
<div
class="flex flex-col items-center gap-0.5"
class="flex flex-col items-center gap-1"
title="Completed ({formatBytes(cell.totalBytes)})"
>
<svg
class="w-5 h-5 text-green-400"
class="w-7 h-7 text-green-400"
viewBox="0 0 20 20"
fill="currentColor"
>
@@ -483,18 +483,18 @@
clip-rule="evenodd"
></path>
</svg>
<span class="text-[10px] text-exo-light-gray/70"
<span class="text-xs text-white/70"
>{formatBytes(cell.totalBytes)}</span
>
<button
type="button"
class="text-exo-light-gray/40 hover:text-red-400 transition-colors mt-0.5"
class="text-white/50 hover:text-red-400 transition-colors mt-0.5 cursor-pointer"
onclick={() =>
deleteDownload(col.nodeId, row.modelId)}
title="Delete from this node"
>
<svg
class="w-3.5 h-3.5"
class="w-5 h-5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
@@ -517,11 +517,11 @@
cell.speed,
)} - ETA {formatEta(cell.etaMs)}"
>
<span class="text-exo-yellow text-xs font-medium"
<span class="text-exo-yellow text-sm font-medium"
>{clampPercent(cell.percentage).toFixed(1)}%</span
>
<div
class="w-14 h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
class="w-16 h-2 bg-exo-black/60 rounded-sm overflow-hidden"
>
<div
class="h-full bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
@@ -530,25 +530,25 @@
).toFixed(1)}%"
></div>
</div>
<span class="text-[9px] text-exo-light-gray/60"
<span class="text-[10px] text-white/70"
>{formatSpeed(cell.speed)}</span
>
</div>
{:else if cell.kind === "pending"}
<div
class="flex flex-col items-center gap-0.5"
class="flex flex-col items-center gap-1"
title={cell.downloaded > 0
? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded`
? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded (paused)`
: "Download pending"}
>
{#if cell.downloaded > 0 && cell.total > 0}
<span class="text-exo-light-gray/70 text-[10px]"
<span class="text-white/70 text-xs"
>{formatBytes(cell.downloaded)} / {formatBytes(
cell.total,
)}</span
>
<div
class="w-full h-1 bg-white/10 rounded-full overflow-hidden"
class="w-full h-1.5 bg-white/10 rounded-full overflow-hidden"
>
<div
class="h-full bg-exo-light-gray/40 rounded-full"
@@ -558,21 +558,65 @@
).toFixed(1)}%"
></div>
</div>
<span class="text-exo-light-gray/40 text-[9px]"
>paused</span
{#if row.shardMetadata}
<button
type="button"
class="text-white/50 hover:text-exo-yellow transition-colors cursor-pointer"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Resume download on this node"
>
<svg
class="w-5 h-5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{:else}
<span class="text-white/50 text-[10px]">paused</span
>
{/if}
{:else if row.shardMetadata}
<button
type="button"
class="text-white/50 hover:text-exo-yellow transition-colors cursor-pointer"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Start download on this node"
>
<svg
class="w-6 h-6"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{:else}
<span class="text-exo-light-gray/50 text-sm">...</span
>
<span class="text-white/40 text-sm">...</span>
{/if}
</div>
{:else if cell.kind === "failed"}
<div
class="flex flex-col items-center gap-0.5"
class="flex flex-col items-center gap-1"
title="Download failed"
>
<svg
class="w-5 h-5 text-red-400"
class="w-7 h-7 text-red-400"
viewBox="0 0 20 20"
fill="currentColor"
>
@@ -585,13 +629,13 @@
{#if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/40 hover:text-exo-yellow transition-colors"
class="text-white/50 hover:text-exo-yellow transition-colors cursor-pointer"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Retry download on this node"
>
<svg
class="w-3.5 h-3.5"
class="w-5 h-5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
@@ -617,13 +661,13 @@
{#if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100"
class="text-white/50 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100 cursor-pointer"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Download to this node"
>
<svg
class="w-3.5 h-3.5"
class="w-5 h-5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"

View File

@@ -41,7 +41,7 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.7.dev20260220+13998a05"; in
version = let v = "0.30.7.dev20260224+5289547a"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
@@ -49,8 +49,8 @@ let
src = fetchFromGitHub {
owner = "rltakashige";
repo = "mlx-jaccl-fix-small-recv";
rev = "13998a054715edcdc93618fb1496c79c7c25ff7c";
hash = "sha256-fAqA3hFwNBx7FcoGnhQsIFpAIRbC2EerACm4Fvne0Cc=";
rev = "5289547ada1cddda2b9716baf6a077a906d02189";
hash = "sha256-Zp9Jln7+Fpn79OfnIdiIVYzQDpih9lHrKtKJadh+c0I=";
};
patches = [

View File

@@ -823,6 +823,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
final_file_exists = await aios.path.exists(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=shard.model_card.model_id,
repo_revision=revision,
@@ -832,7 +833,9 @@ async def download_shard(
total=Memory.from_bytes(file.size or 0),
speed=0,
eta=timedelta(0),
status="complete" if downloaded_bytes == file.size else "not_started",
status="complete"
if final_file_exists and downloaded_bytes == file.size
else "not_started",
start_time=time.time(),
)

View File

@@ -252,7 +252,7 @@ def main():
target = min(max(soft, 65535), hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
mp.set_start_method("spawn")
mp.set_start_method("spawn", force=True)
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")

View File

@@ -14,10 +14,12 @@ from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.events import InstanceCreated, InstanceDeleted, TaskStatusUpdated
from exo.shared.types.memory import Memory
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
Instance,
@@ -456,3 +458,117 @@ def test_tensor_rdma_backend_connectivity_matrix(
else:
ip_part = coordinator.split(":")[0]
assert len(ip_part.split(".")) == 4
def _make_task(
instance_id: InstanceId,
status: TaskStatus = TaskStatus.Running,
) -> TextGeneration:
return TextGeneration(
task_id=TaskId(),
task_status=status,
instance_id=instance_id,
command_id=CommandId(),
task_params=TextGenerationTaskParams(
model=ModelId("test-model"),
input=[InputMessage(role="user", content="hello")],
),
)
def test_get_transition_events_delete_instance_cancels_running_tasks(
instance: Instance,
):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
target_instances: dict[InstanceId, Instance] = {}
task = _make_task(instance_id, TaskStatus.Running)
tasks = {task.task_id: task}
# act
events = get_transition_events(current_instances, target_instances, tasks)
# assert cancellation event should come before the deletion event
assert len(events) == 2
assert isinstance(events[0], TaskStatusUpdated)
assert events[0].task_id == task.task_id
assert events[0].task_status == TaskStatus.Cancelled
assert isinstance(events[1], InstanceDeleted)
assert events[1].instance_id == instance_id
def test_get_transition_events_delete_instance_cancels_pending_tasks(
instance: Instance,
):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
target_instances: dict[InstanceId, Instance] = {}
task = _make_task(instance_id, TaskStatus.Pending)
tasks = {task.task_id: task}
# act
events = get_transition_events(current_instances, target_instances, tasks)
# assert
assert len(events) == 2
assert isinstance(events[0], TaskStatusUpdated)
assert events[0].task_id == task.task_id
assert events[0].task_status == TaskStatus.Cancelled
assert isinstance(events[1], InstanceDeleted)
def test_get_transition_events_delete_instance_ignores_completed_tasks(
instance: Instance,
):
# arrange
instance_id = InstanceId()
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
target_instances: dict[InstanceId, Instance] = {}
tasks = {
t.task_id: t
for t in [
_make_task(instance_id, TaskStatus.Complete),
_make_task(instance_id, TaskStatus.Failed),
_make_task(instance_id, TaskStatus.TimedOut),
_make_task(instance_id, TaskStatus.Cancelled),
]
}
# act
events = get_transition_events(current_instances, target_instances, tasks)
# assert only the InstanceDeleted event, no cancellations
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
def test_get_transition_events_delete_instance_cancels_only_matching_tasks(
instance: Instance,
):
# arrange
instance_id_a = InstanceId()
instance_id_b = InstanceId()
current_instances: dict[InstanceId, Instance] = {
instance_id_a: instance,
instance_id_b: instance,
}
# only delete instance A, keep instance B
target_instances: dict[InstanceId, Instance] = {instance_id_b: instance}
task_a = _make_task(instance_id_a, TaskStatus.Running)
task_b = _make_task(instance_id_b, TaskStatus.Running)
tasks = {task_a.task_id: task_a, task_b.task_id: task_b}
# act
events = get_transition_events(current_instances, target_instances, tasks)
# assert only task_a should be cancelled
cancel_events = [e for e in events if isinstance(e, TaskStatusUpdated)]
delete_events = [e for e in events if isinstance(e, InstanceDeleted)]
assert len(cancel_events) == 1
assert cancel_events[0].task_id == task_a.task_id
assert cancel_events[0].task_status == TaskStatus.Cancelled
assert len(delete_events) == 1
assert delete_events[0].instance_id == instance_id_a

View File

@@ -90,6 +90,7 @@ class ModelCard(CamelCaseModel):
base_model: str = ""
capabilities: list[str] = []
uses_cfg: bool = False
trust_remote_code: bool = True
@field_validator("tasks", mode="before")
@classmethod
@@ -137,6 +138,7 @@ class ModelCard(CamelCaseModel):
hidden_size=config_data.hidden_size or 0,
supports_tensor=config_data.supports_tensor,
tasks=[ModelTask.TextGeneration],
trust_remote_code=False,
)
await mc.save_to_custom_dir()
_card_cache[model_id] = mc

View File

@@ -128,12 +128,25 @@ class PipelineFirstLayer(CustomMlxLayer):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if self.r != 0:
import time as _time
_t0 = _time.perf_counter()
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
if self.is_prefill:
# We want to avoid GPU timeout errors by evalling the distributed operation
# so that it stays on CPU, which does not have a timeout.
mx.eval(x)
return self.original_layer(x, *args, **kwargs)
_elapsed = _time.perf_counter() - _t0
if _elapsed > 1.0:
import logging
logging.getLogger(__name__).warning(f"[PIPELINE] PipelineFirstLayer recv_like+eval took {_elapsed:.4f}s (SLOW)")
_t0_layer = _time.perf_counter() if self.r != 0 else None
result = self.original_layer(x, *args, **kwargs)
if _t0_layer is not None:
_elapsed_layer = _time.perf_counter() - _t0_layer
if _elapsed_layer > 1.0:
import logging
logging.getLogger(__name__).warning(f"[PIPELINE] PipelineFirstLayer original_layer took {_elapsed_layer:.4f}s (SLOW)")
return result
class PipelineLastLayer(CustomMlxLayer):
@@ -152,13 +165,20 @@ class PipelineLastLayer(CustomMlxLayer):
self.is_prefill: bool = False
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
import time as _time
cache = self.original_layer_signature.bind_partial(
x, *args, **kwargs
).arguments.get("cache", None)
_t0 = _time.perf_counter()
output: mx.array = self.original_layer(x, *args, **kwargs)
_elapsed = _time.perf_counter() - _t0
if _elapsed > 1.0:
import logging
logging.getLogger(__name__).warning(f"[PIPELINE] PipelineLastLayer original_layer took {_elapsed:.4f}s (SLOW)")
if self.r != self.s - 1:
_t0 = _time.perf_counter()
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
@@ -171,11 +191,20 @@ class PipelineLastLayer(CustomMlxLayer):
mx.eval(output)
if cache is not None:
mx.eval(_cache.keys) # type: ignore
_elapsed = _time.perf_counter() - _t0
if _elapsed > 1.0:
import logging
logging.getLogger(__name__).warning(f"[PIPELINE] PipelineLastLayer send+eval took {_elapsed:.4f}s (SLOW)")
if not self.is_prefill:
_t0 = _time.perf_counter()
output = mx.distributed.all_gather(output, group=self.group)[
-output.shape[0] :
]
_elapsed = _time.perf_counter() - _t0
if _elapsed > 1.0:
import logging
logging.getLogger(__name__).warning(f"[PIPELINE] PipelineLastLayer all_gather took {_elapsed:.4f}s (SLOW)")
return output
@@ -852,6 +881,8 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.o_proj = self.sharded_to_all_linear(
layer.self_attn.o_proj
)
layer.self_attn.n_heads //= self.N
layer.self_attn.n_kv_heads //= self.N
else:
assert isinstance(layer, Qwen3NextDecoderLayer)
if hasattr(layer, "linear_attn"):

View File

@@ -94,14 +94,20 @@ def prefill(
if on_prefill_progress is not None:
on_prefill_progress(processed, total)
t0 = time.perf_counter()
set_pipeline_prefill(model, is_prefill=True)
logger.warning(f"[PREFILL] set_pipeline_prefill(True) took {time.perf_counter() - t0:.4f}s")
t0 = time.perf_counter()
mx_barrier(group)
logger.info("Starting prefill")
logger.warning(f"[PREFILL] mx_barrier (pre-prefill) took {time.perf_counter() - t0:.4f}s")
logger.warning("[PREFILL] Starting prefill via stream_generate")
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
try:
t0 = time.perf_counter()
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
@@ -114,15 +120,19 @@ def prefill(
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
logger.warning(f"[PREFILL] stream_generate first yield took {time.perf_counter() - t0:.4f}s")
break # Stop after first iteration - cache is now filled
except PrefillCancelled:
set_pipeline_prefill(model, is_prefill=False)
raise
t0 = time.perf_counter()
set_pipeline_prefill(model, is_prefill=False)
logger.warning(f"[PREFILL] set_pipeline_prefill(False) took {time.perf_counter() - t0:.4f}s")
# stream_generate added 1 extra generated token to the cache, so we should trim it.
# Because of needing to roll back arrays cache, we will generate on 2 tokens so trim 1 more.
t0 = time.perf_counter()
pre_gen = deepcopy(snapshots[-2]) if has_ssm else None
for i, c in enumerate(cache):
if has_ssm and isinstance(c, (ArraysCache, RotatingKVCache)):
@@ -132,11 +142,12 @@ def prefill(
else:
assert not isinstance(c, (ArraysCache, RotatingKVCache))
c.trim(2) # pyright: ignore[reportUnknownMemberType]
logger.warning(f"[PREFILL] cache trim took {time.perf_counter() - t0:.4f}s")
elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
logger.debug(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
logger.warning(
f"[PREFILL] complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
# Exclude the last snapshot
@@ -324,6 +335,8 @@ def mlx_generate(
max_stop_len = max((len(s) for s in stop_sequences), default=0)
# Prefill cache with all tokens except the last one
logger.warning(f"[GENERATE] calling prefill with {len(prompt_tokens) - 1} tokens")
t_prefill_start = time.perf_counter()
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
model,
tokenizer,
@@ -333,6 +346,7 @@ def mlx_generate(
group,
on_prefill_progress,
)
logger.warning(f"[GENERATE] prefill() returned in {time.perf_counter() - t_prefill_start:.4f}s")
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
# stream_generate starts from the last token
@@ -348,9 +362,12 @@ def mlx_generate(
think_start = tokenizer.think_start
think_end = tokenizer.think_end
logger.info("Starting decode")
logger.warning("[GENERATE] Starting decode")
t0 = time.perf_counter()
mx_barrier(group)
logger.warning(f"[GENERATE] mx_barrier (pre-decode) took {time.perf_counter() - t0:.4f}s")
_decode_token_start = time.perf_counter()
for completion_tokens, out in enumerate(
stream_generate(
model=model,
@@ -366,6 +383,9 @@ def mlx_generate(
),
start=1,
):
_decode_token_elapsed = time.perf_counter() - _decode_token_start
if _decode_token_elapsed > 1.0:
logger.warning(f"[DECODE] token {completion_tokens} took {_decode_token_elapsed:.4f}s (SLOW)")
generated_text_parts.append(out.text)
accumulated_text += out.text
@@ -488,9 +508,12 @@ def mlx_generate(
)
if is_done:
t0 = time.perf_counter()
mx_barrier(group)
logger.warning(f"[GENERATE] mx_barrier (post-decode) took {time.perf_counter() - t0:.4f}s")
break
# Limit accumulated_text to what's needed for stop sequence detection
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
accumulated_text = accumulated_text[-max_stop_len:]
_decode_token_start = time.perf_counter()

View File

@@ -23,9 +23,7 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
TRUST_REMOTE_CODE,
)
from exo.worker.engines.mlx.constants import TRUST_REMOTE_CODE
try:
from mlx_lm.tokenizer_utils import load_tokenizer
@@ -293,7 +291,11 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
return load_tokenizer_for_model_id(
shard_metadata.model_card.model_id,
model_path,
trust_remote_code=shard_metadata.model_card.trust_remote_code,
)
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
@@ -325,7 +327,7 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
def load_tokenizer_for_model_id(
model_id: ModelId, model_path: Path
model_id: ModelId, model_path: Path, *, trust_remote_code: bool = TRUST_REMOTE_CODE
) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.
@@ -394,7 +396,7 @@ def load_tokenizer_for_model_id(
tokenizer = load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
tokenizer_config_extra={"trust_remote_code": trust_remote_code},
eos_token_ids=eos_token_ids,
)

View File

@@ -276,31 +276,38 @@ def main(
_task_id: TaskId = task.task_id,
_group: mx.distributed.Group | None = group,
) -> None:
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=PrefillProgressChunk(
model=shard_metadata.model_card.model_id,
processed_tokens=processed,
total_tokens=total,
),
)
)
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (_task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if mx_any(want_to_cancel, _group):
raise PrefillCancelled()
time.sleep(0.2)
return None
# if device_rank == 0:
# event_sender.send(
# ChunkGenerated(
# command_id=command_id,
# chunk=PrefillProgressChunk(
# model=shard_metadata.model_card.model_id,
# processed_tokens=processed,
# total_tokens=total,
# ),
# )
# )
# cancelled_tasks.update(cancel_receiver.collect())
# want_to_cancel = (_task_id in cancelled_tasks) or (
# TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
# )
# if mx_any(want_to_cancel, _group):
# raise PrefillCancelled()
try:
import time as _time
_runner_req_start = _time.perf_counter()
_check_for_debug_prompts(task_params)
# Build prompt once - used for both generation and thinking detection
_t0 = _time.perf_counter()
prompt = apply_chat_template(tokenizer, task_params)
logger.warning(f"[RUNNER] apply_chat_template took {_time.perf_counter() - _t0:.4f}s")
# Generate responses using the actual MLX generation
logger.warning("[RUNNER] calling mlx_generate")
mlx_generator = mlx_generate(
model=cast(Model, inference_model),
tokenizer=tokenizer,
@@ -332,6 +339,8 @@ def main(
completion_tokens = 0
tokens_since_last_cancel_check = check_for_cancel_every
logger.warning("[RUNNER] starting token iteration loop")
_runner_token_start = _time.perf_counter()
for response in mlx_generator:
tokens_since_last_cancel_check += 1
if tokens_since_last_cancel_check >= check_for_cancel_every:
@@ -413,6 +422,7 @@ def main(
)
raise
logger.warning(f"[RUNNER] request complete in {_time.perf_counter() - _runner_req_start:.4f}s total")
current_status = RunnerReady()
logger.info("runner ready")

View File

@@ -106,13 +106,18 @@ class RunnerSupervisor:
def shutdown(self):
logger.info("Runner supervisor shutting down")
self._tg.cancel_tasks()
self._ev_recv.close()
self._task_sender.close()
if not self._cancel_watch_runner.cancel_called:
self._cancel_watch_runner.cancel()
with contextlib.suppress(ClosedResourceError):
self._ev_recv.close()
with contextlib.suppress(ClosedResourceError):
self._task_sender.close()
with contextlib.suppress(ClosedResourceError):
self._event_sender.close()
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.close()
self.runner_process.join(5)
if not self.runner_process.is_alive():
logger.info("Runner process successfully terminated")

tmp/reproduce_gpu_lock.py (new file, 279 lines)
View File

@@ -0,0 +1,279 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# ///
"""Reproduce GPU lock issue with mlx-community/Llama-3.2-1B-Instruct-4bit.
Starts exo or mlx_lm.server, then sends repeated chat completions
until a request stalls for >5 seconds (indicating a GPU lock).
Usage:
uv run tmp/reproduce_gpu_lock.py # use exo (default)
uv run tmp/reproduce_gpu_lock.py --mlx-lm # use mlx_lm.server
"""
import argparse
import hashlib
import json
import os
import platform
import random
import signal
import subprocess
import sys
import threading
import time
import urllib.error
import urllib.request
import uuid
MODEL_ID = "mlx-community/Llama-3.2-1B-Instruct-4bit"
MODEL_PATH = os.path.expanduser("~/.exo/models/mlx-community--Llama-3.2-1B-Instruct-4bit")
STALL_THRESHOLD_S = 5.0
server_proc = None
base_url = ""
def cleanup(*_):
if server_proc and server_proc.poll() is None:
print("\nStopping server...")
server_proc.terminate()
try:
server_proc.wait(timeout=10)
except subprocess.TimeoutExpired:
server_proc.kill()
sys.exit(0)
signal.signal(signal.SIGINT, cleanup)
signal.signal(signal.SIGTERM, cleanup)
def api_get(path, timeout=30):
req = urllib.request.Request(f"{base_url}{path}")
with urllib.request.urlopen(req, timeout=timeout) as resp:
return json.loads(resp.read())
def api_post(path, body, timeout=300):
data = json.dumps(body).encode()
req = urllib.request.Request(f"{base_url}{path}", data=data, headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req, timeout=timeout) as resp:
return json.loads(resp.read())
def wait_for_api(max_wait=120):
  print("Waiting for API to be ready...", flush=True)
  start = time.time()
  while time.time() - start < max_wait:
    try:
      api_get("/v1/models", timeout=5)
      print("API is ready.", flush=True)
      return
    except Exception:
      time.sleep(2)
  print("ERROR: API did not become ready in time.", flush=True)
  cleanup()
def create_instance(max_wait=120):
  print(f"Waiting for valid placements for {MODEL_ID}...", flush=True)
  start = time.time()
  valid = []
  while time.time() - start < max_wait:
    try:
      previews = api_get(f"/instance/previews?model_id={MODEL_ID}")
      valid = [p for p in previews.get("previews", []) if p.get("error") is None and p.get("instance") is not None]
      if valid:
        break
    except Exception:
      pass
    time.sleep(3)
  if not valid:
    print("ERROR: No valid placements found after waiting.", flush=True)
    cleanup()
  instance = valid[0]["instance"]
  print(f"Creating instance (sharding={valid[0].get('sharding')}, meta={valid[0].get('instance_meta')})...", flush=True)
  resp = api_post("/instance", {"instance": instance})
  print(f"Instance creation requested: {resp.get('message')} (command_id={resp.get('command_id')})", flush=True)
  return instance.get("id") or instance.get("instance_id")
def wait_for_instance(max_wait=120):
  print("Waiting for instance to be ready...", flush=True)
  start = time.time()
  while time.time() - start < max_wait:
    try:
      state = api_get("/state", timeout=10)
      instances = state.get("instances") or state.get("model_instances") or {}
      if instances:
        print(f"Instance ready. ({len(instances)} instance(s) in state)", flush=True)
        return
    except Exception:
      pass
    time.sleep(3)
  print("WARNING: Timed out waiting for instance in state. Proceeding anyway...", flush=True)
TOPICS = [
  "the weather", "cats", "space", "pizza", "music", "ocean", "mountains",
  "robots", "books", "coffee", "trains", "clouds", "birds", "fire",
  "ice cream", "trees", "rivers", "stars", "thunder", "gardens",
]
def send_chat(request_num):
  topic = random.choice(TOPICS)
  nonce = random.randint(1000, 9999)
  body = {
    "model": MODEL_ID,
    "messages": [{"role": "user", "content": f"Say something about {topic} in one sentence. ({nonce})"}],
    "stream": False,
    "max_tokens": 64,
  }
  start = time.time()
  resp = api_post("/v1/chat/completions", body, timeout=600)
  elapsed = time.time() - start
  return elapsed, resp
def start_exo():
  global server_proc
  machine_id = hashlib.sha256(f"{platform.node()}-{uuid.getnode()}".encode()).hexdigest()[:12]
  namespace = f"gpu-lock-repro-{machine_id}"
  log_file = open("/tmp/exo_gpu_lock_repro.log", "w", buffering=1)
  print(f"\nStarting exo (namespace={namespace})...", flush=True)
  print("Log: /tmp/exo_gpu_lock_repro.log", flush=True)
  print(" tail -f /tmp/exo_gpu_lock_repro.log (in another terminal to watch)", flush=True)
  env = {**os.environ, "EXO_LIBP2P_NAMESPACE": namespace, "PYTHONUNBUFFERED": "1"}
  server_proc = subprocess.Popen(
    ["uv", "run", "exo"],
    stdout=log_file,
    stderr=subprocess.STDOUT,
    env=env,
  )
  print(f"exo started (pid={server_proc.pid})", flush=True)
  wait_for_api()
  create_instance()
  wait_for_instance()
def start_mlx_lm():
  global server_proc
  log_file = open("/tmp/mlx_lm_gpu_lock_repro.log", "w", buffering=1)
  print("\nStarting mlx_lm.server on port 8080...", flush=True)
  print(f" Model path: {MODEL_PATH}", flush=True)
  print("Log: /tmp/mlx_lm_gpu_lock_repro.log", flush=True)
  print(" tail -f /tmp/mlx_lm_gpu_lock_repro.log (in another terminal to watch)", flush=True)
  env = {**os.environ, "PYTHONUNBUFFERED": "1"}
  server_proc = subprocess.Popen(
    ["uv", "run", "mlx_lm.server", "--model", MODEL_PATH, "--port", "8080"],
    stdout=log_file,
    stderr=subprocess.STDOUT,
    env=env,
    cwd=os.path.expanduser("~/mlx-lm"),
  )
  print(f"mlx_lm.server started (pid={server_proc.pid})", flush=True)
  wait_for_api()
def chat_loop():
  print("\n" + "-" * 60, flush=True)
  print("Starting chat completion loop. Watching for stalls...", flush=True)
  print("-" * 60 + "\n", flush=True)
  timings = []
  request_num = 0
  while True:
    request_num += 1
    print(f" [#{request_num}] sending...", end="", flush=True)
    req_start = time.time()
    done_event = threading.Event()

    def print_waiting():
      while not done_event.is_set():
        if done_event.wait(5):
          break
        elapsed_so_far = time.time() - req_start
        print(f" ({elapsed_so_far:.0f}s)", end="", flush=True)

    watcher = threading.Thread(target=print_waiting, daemon=True)
    watcher.start()
    try:
      elapsed, resp = send_chat(request_num)
    except Exception as e:
      print(f" ERROR after {time.time() - req_start:.1f}s: {e}", flush=True)
      time.sleep(2)
      continue
    finally:
      done_event.set()
    timings.append(elapsed)
    try:
      content = resp["choices"][0]["message"]["content"][:80]
    except (KeyError, IndexError):
      content = "<no content>"
    print(f" {elapsed:.2f}s | {content}", flush=True)
    if elapsed > STALL_THRESHOLD_S:
      print("\n", flush=True)
      print("!" * 60, flush=True)
      print("!" * 60, flush=True)
      print("!!!", flush=True)
      print(f"!!! GPU LOCK DETECTED on request #{request_num}", flush=True)
      print(f"!!! Elapsed: {elapsed:.2f}s (threshold: {STALL_THRESHOLD_S}s)", flush=True)
      print("!!!", flush=True)
      print("!" * 60, flush=True)
      print("!" * 60, flush=True)
      print(f"\nTotal requests sent: {request_num}", flush=True)
      print(f"Average time (all): {sum(timings) / len(timings):.2f}s", flush=True)
      normal = [t for t in timings if t <= STALL_THRESHOLD_S]
      if normal:
        print(f"Average time (normal): {sum(normal) / len(normal):.2f}s", flush=True)
      print(f"Max time: {max(timings):.2f}s", flush=True)
      print(f"Min time: {min(timings):.2f}s", flush=True)
      print("\nAll timings:", flush=True)
      for i, t in enumerate(timings, 1):
        marker = " <<<< STALL" if t > STALL_THRESHOLD_S else ""
        print(f" #{i}: {t:.2f}s{marker}", flush=True)
      print(f"\nServer still running (pid={server_proc.pid}). Continuing... Ctrl+C to stop.", flush=True)
      print("-" * 60 + "\n", flush=True)
def main():
  global base_url
  parser = argparse.ArgumentParser(description="Reproduce GPU lock issue")
  parser.add_argument("--mlx-lm", action="store_true", help="Use mlx_lm.server instead of exo")
  parser.add_argument("--port", type=int, default=None, help="Override server port")
  args = parser.parse_args()
  mode = "mlx_lm" if args.mlx_lm else "exo"
  port = args.port or (8080 if args.mlx_lm else 52415)
  base_url = f"http://localhost:{port}"
  print("=" * 60, flush=True)
  print(" GPU Lock Reproduction Script", flush=True)
  print(f" Mode: {mode}", flush=True)
  print(f" Model: {MODEL_ID}", flush=True)
  print(f" API: {base_url}", flush=True)
  print(f" Stall threshold: {STALL_THRESHOLD_S}s", flush=True)
  print("=" * 60, flush=True)
  if args.mlx_lm:
    start_mlx_lm()
  else:
    start_exo()
  chat_loop()

if __name__ == "__main__":
  main()

test_trust_remote_code_attack.sh (new file)

@@ -0,0 +1,125 @@
#!/usr/bin/env bash
# Test that models added via API get trust_remote_code=false
# Run this against a running exo instance.
# Usage: ./test_trust_remote_code_attack.sh [host:port]
set -uo pipefail
HOST="${1:-localhost:52415}"
MODEL_ID="KevTheHermit/security-testing"
CUSTOM_CARDS_DIR="$HOME/.exo/custom_model_cards"
CARD_FILE="$CUSTOM_CARDS_DIR/KevTheHermit--security-testing.toml"
echo "=== Test: trust_remote_code attack via API ==="
echo "Target: $HOST"
echo ""
# Clean up RCE proof from previous runs
rm -f /tmp/exo-rce-proof.txt
# Step 0: Clean up any stale card from previous runs
if [ -f "$CARD_FILE" ]; then
  echo "[0] Removing stale card from previous run ..."
  curl -s -X DELETE \
    "http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote("'"$MODEL_ID"'", safe=""))')" >/dev/null
  rm -f "$CARD_FILE"
  echo " Done"
  echo ""
fi
# Step 1: Add the malicious model via API
echo "[1] Adding model via POST /models/add ..."
ADD_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/models/add" \
-H "Content-Type: application/json" \
-d "{\"model_id\":\"$MODEL_ID\"}")
HTTP_CODE=$(echo "$ADD_RESPONSE" | tail -1)
BODY=$(echo "$ADD_RESPONSE" | sed '$d')
echo " HTTP $HTTP_CODE"
if [ "$HTTP_CODE" -ge 400 ]; then
  echo " Model add failed (HTTP $HTTP_CODE) — that's fine if model doesn't exist on HF."
  echo " Response: $BODY"
  echo ""
  echo "RESULT: Model was rejected at add time. Attack blocked."
  exit 0
fi
# Step 2: Verify the saved TOML has trust_remote_code = false
echo ""
echo "[2] Checking saved model card TOML ..."
if [ ! -f "$CARD_FILE" ]; then
  echo " FAIL: Card file not found at $CARD_FILE"
  exit 1
fi
if grep -q 'trust_remote_code = false' "$CARD_FILE"; then
  echo " SAFE: trust_remote_code = false (fix is active)"
else
  echo " VULNERABLE: trust_remote_code is not false — remote code WILL be trusted"
fi
echo " Contents:"
cat "$CARD_FILE"
# Step 3: Place the instance
echo ""
echo "[3] Attempting POST /place_instance ..."
PLACE_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/place_instance" \
-H "Content-Type: application/json" \
-d "{\"model_id\":\"$MODEL_ID\"}")
PLACE_CODE=$(echo "$PLACE_RESPONSE" | tail -1)
PLACE_BODY=$(echo "$PLACE_RESPONSE" | sed '$d')
echo " HTTP $PLACE_CODE"
echo " Response: $PLACE_BODY"
# Step 3b: Send a chat completion to actually trigger tokenizer loading
echo ""
echo "[3b] Sending chat completion to trigger tokenizer load ..."
CHAT_RESPONSE=$(curl -s -w "\n%{http_code}" --max-time 30 -X POST "http://$HOST/v1/chat/completions" \
-H "Content-Type: application/json" \
-d "{\"model\":\"$MODEL_ID\",\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}],\"max_tokens\":1}")
CHAT_CODE=$(echo "$CHAT_RESPONSE" | tail -1)
CHAT_BODY=$(echo "$CHAT_RESPONSE" | sed '$d')
echo " HTTP $CHAT_CODE"
echo " Response: $CHAT_BODY"
echo ""
echo "[3c] Checking for RCE proof ..."
sleep 5
if [ -f /tmp/exo-rce-proof.txt ]; then
  echo " VULNERABLE: Remote code executed!"
  echo " Contents:"
  cat /tmp/exo-rce-proof.txt
else
  echo " SAFE: /tmp/exo-rce-proof.txt does not exist — remote code was NOT executed"
fi
# Step 4: Clean up — delete instance and custom model
echo ""
echo "[4] Cleaning up ..."
# Find and delete any instance for this model
INSTANCE_ID=$(curl -s "http://$HOST/state" | python3 -c "
import sys, json
state = json.load(sys.stdin)
for iid, wrapper in state.get('instances', {}).items():
  for tag, inst in wrapper.items():
    sa = inst.get('shardAssignments', {})
    if sa.get('modelId', '') == '$MODEL_ID':
      print(iid)
      sys.exit(0)
" 2>/dev/null || true)
if [ -n "$INSTANCE_ID" ]; then
  echo " Deleting instance $INSTANCE_ID ..."
  curl -s -X DELETE "http://$HOST/instance/$INSTANCE_ID" >/dev/null
  echo " Done"
else
  echo " No instance found to delete"
fi
echo " Deleting custom model card ..."
curl -s -X DELETE \
"http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote("'"$MODEL_ID"'", safe=""))')" >/dev/null
echo " Done"
echo ""
echo "=== DONE ==="

uv.lock (generated)

@@ -378,7 +378,7 @@ dependencies = [
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
-{ name = "mlx", version = "0.30.7.dev20260220+13998a05", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }, marker = "sys_platform == 'darwin'" },
+{ name = "mlx", version = "0.30.7.dev20260224+5289547a", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#5289547ada1cddda2b9716baf6a077a906d02189" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1025,7 +1025,7 @@ dependencies = [
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
-{ name = "mlx", version = "0.30.7.dev20260220+13998a05", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }, marker = "sys_platform == 'darwin'" },
+{ name = "mlx", version = "0.30.7.dev20260224+5289547a", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#5289547ada1cddda2b9716baf6a077a906d02189" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1072,8 +1072,8 @@ cuda13 = [
[[package]]
name = "mlx"
-version = "0.30.7.dev20260220+13998a05"
-source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }
+version = "0.30.7.dev20260224+5289547a"
+source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#5289547ada1cddda2b9716baf6a077a906d02189" }
resolution-markers = [
"sys_platform == 'darwin'",
]
@@ -1108,7 +1108,7 @@ version = "0.30.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
-{ name = "mlx", version = "0.30.7.dev20260220+13998a05", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }, marker = "sys_platform == 'darwin'" },
+{ name = "mlx", version = "0.30.7.dev20260224+5289547a", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#5289547ada1cddda2b9716baf6a077a906d02189" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },