mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-24 18:28:30 -05:00
Compare commits
10 Commits
fix-instan
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ccf4d91d55 | ||
|
|
811a4d80bd | ||
|
|
2fe689315b | ||
|
|
644c5573ce | ||
|
|
12c3015f52 | ||
|
|
365dd68d9a | ||
|
|
d3d129581e | ||
|
|
c90a0cec78 | ||
|
|
e8c1337168 | ||
|
|
7024ddcf3e |
@@ -75,7 +75,7 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
model_id,
|
||||
allow_patterns=["*.json", "*.py", "*.tiktoken"],
|
||||
allow_patterns=["*.json", "*.py", "*.tiktoken", "*.model"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -412,7 +412,7 @@
|
||||
<div>{col.label}</div>
|
||||
{#if col.diskAvailable != null}
|
||||
<div
|
||||
class="text-[9px] text-exo-light-gray/60 normal-case tracking-normal mt-0.5"
|
||||
class="text-[9px] text-white/70 normal-case tracking-normal mt-0.5"
|
||||
>
|
||||
{formatBytes(col.diskAvailable)} free
|
||||
</div>
|
||||
@@ -436,7 +436,7 @@
|
||||
</div>
|
||||
{#if row.prettyName}
|
||||
<div
|
||||
class="text-[10px] text-exo-light-gray/60"
|
||||
class="text-[10px] text-white/60"
|
||||
title={row.modelId}
|
||||
>
|
||||
{row.modelId}
|
||||
@@ -450,7 +450,7 @@
|
||||
title="View model details"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4 text-white/30 hover:text-white/60"
|
||||
class="w-4 h-4 text-white/60 hover:text-white/80"
|
||||
viewBox="0 0 24 24"
|
||||
fill="currentColor"
|
||||
>
|
||||
@@ -469,11 +469,11 @@
|
||||
<td class="px-4 py-3 text-center align-middle">
|
||||
{#if cell.kind === "completed"}
|
||||
<div
|
||||
class="flex flex-col items-center gap-0.5"
|
||||
class="flex flex-col items-center gap-1"
|
||||
title="Completed ({formatBytes(cell.totalBytes)})"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-green-400"
|
||||
class="w-7 h-7 text-green-400"
|
||||
viewBox="0 0 20 20"
|
||||
fill="currentColor"
|
||||
>
|
||||
@@ -483,18 +483,18 @@
|
||||
clip-rule="evenodd"
|
||||
></path>
|
||||
</svg>
|
||||
<span class="text-[10px] text-exo-light-gray/70"
|
||||
<span class="text-xs text-white/70"
|
||||
>{formatBytes(cell.totalBytes)}</span
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray/40 hover:text-red-400 transition-colors mt-0.5"
|
||||
class="text-white/50 hover:text-red-400 transition-colors mt-0.5 cursor-pointer"
|
||||
onclick={() =>
|
||||
deleteDownload(col.nodeId, row.modelId)}
|
||||
title="Delete from this node"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
class="w-5 h-5"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
@@ -517,11 +517,11 @@
|
||||
cell.speed,
|
||||
)} - ETA {formatEta(cell.etaMs)}"
|
||||
>
|
||||
<span class="text-exo-yellow text-xs font-medium"
|
||||
<span class="text-exo-yellow text-sm font-medium"
|
||||
>{clampPercent(cell.percentage).toFixed(1)}%</span
|
||||
>
|
||||
<div
|
||||
class="w-14 h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
|
||||
class="w-16 h-2 bg-exo-black/60 rounded-sm overflow-hidden"
|
||||
>
|
||||
<div
|
||||
class="h-full bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
|
||||
@@ -530,25 +530,25 @@
|
||||
).toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
<span class="text-[9px] text-exo-light-gray/60"
|
||||
<span class="text-[10px] text-white/70"
|
||||
>{formatSpeed(cell.speed)}</span
|
||||
>
|
||||
</div>
|
||||
{:else if cell.kind === "pending"}
|
||||
<div
|
||||
class="flex flex-col items-center gap-0.5"
|
||||
class="flex flex-col items-center gap-1"
|
||||
title={cell.downloaded > 0
|
||||
? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded`
|
||||
? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded (paused)`
|
||||
: "Download pending"}
|
||||
>
|
||||
{#if cell.downloaded > 0 && cell.total > 0}
|
||||
<span class="text-exo-light-gray/70 text-[10px]"
|
||||
<span class="text-white/70 text-xs"
|
||||
>{formatBytes(cell.downloaded)} / {formatBytes(
|
||||
cell.total,
|
||||
)}</span
|
||||
>
|
||||
<div
|
||||
class="w-full h-1 bg-white/10 rounded-full overflow-hidden"
|
||||
class="w-full h-1.5 bg-white/10 rounded-full overflow-hidden"
|
||||
>
|
||||
<div
|
||||
class="h-full bg-exo-light-gray/40 rounded-full"
|
||||
@@ -558,21 +558,65 @@
|
||||
).toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
<span class="text-exo-light-gray/40 text-[9px]"
|
||||
>paused</span
|
||||
{#if row.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-white/50 hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
onclick={() =>
|
||||
startDownload(col.nodeId, row.shardMetadata!)}
|
||||
title="Resume download on this node"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
></path>
|
||||
</svg>
|
||||
</button>
|
||||
{:else}
|
||||
<span class="text-white/50 text-[10px]">paused</span
|
||||
>
|
||||
{/if}
|
||||
{:else if row.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-white/50 hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
onclick={() =>
|
||||
startDownload(col.nodeId, row.shardMetadata!)}
|
||||
title="Start download on this node"
|
||||
>
|
||||
<svg
|
||||
class="w-6 h-6"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
></path>
|
||||
</svg>
|
||||
</button>
|
||||
{:else}
|
||||
<span class="text-exo-light-gray/50 text-sm">...</span
|
||||
>
|
||||
<span class="text-white/40 text-sm">...</span>
|
||||
{/if}
|
||||
</div>
|
||||
{:else if cell.kind === "failed"}
|
||||
<div
|
||||
class="flex flex-col items-center gap-0.5"
|
||||
class="flex flex-col items-center gap-1"
|
||||
title="Download failed"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-red-400"
|
||||
class="w-7 h-7 text-red-400"
|
||||
viewBox="0 0 20 20"
|
||||
fill="currentColor"
|
||||
>
|
||||
@@ -585,13 +629,13 @@
|
||||
{#if row.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray/40 hover:text-exo-yellow transition-colors"
|
||||
class="text-white/50 hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
onclick={() =>
|
||||
startDownload(col.nodeId, row.shardMetadata!)}
|
||||
title="Retry download on this node"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
class="w-5 h-5"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
@@ -617,13 +661,13 @@
|
||||
{#if row.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100"
|
||||
class="text-white/50 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100 cursor-pointer"
|
||||
onclick={() =>
|
||||
startDownload(col.nodeId, row.shardMetadata!)}
|
||||
title="Download to this node"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
class="w-5 h-5"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
|
||||
@@ -41,7 +41,7 @@ let
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.7.dev20260220+13998a05"; in
|
||||
version = let v = "0.30.7.dev20260224+5289547a"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
@@ -49,8 +49,8 @@ let
|
||||
src = fetchFromGitHub {
|
||||
owner = "rltakashige";
|
||||
repo = "mlx-jaccl-fix-small-recv";
|
||||
rev = "13998a054715edcdc93618fb1496c79c7c25ff7c";
|
||||
hash = "sha256-fAqA3hFwNBx7FcoGnhQsIFpAIRbC2EerACm4Fvne0Cc=";
|
||||
rev = "5289547ada1cddda2b9716baf6a077a906d02189";
|
||||
hash = "sha256-Zp9Jln7+Fpn79OfnIdiIVYzQDpih9lHrKtKJadh+c0I=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
|
||||
@@ -823,6 +823,7 @@ async def download_shard(
|
||||
|
||||
for file in filtered_file_list:
|
||||
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
|
||||
final_file_exists = await aios.path.exists(target_dir / file.path)
|
||||
file_progress[file.path] = RepoFileDownloadProgress(
|
||||
repo_id=shard.model_card.model_id,
|
||||
repo_revision=revision,
|
||||
@@ -832,7 +833,9 @@ async def download_shard(
|
||||
total=Memory.from_bytes(file.size or 0),
|
||||
speed=0,
|
||||
eta=timedelta(0),
|
||||
status="complete" if downloaded_bytes == file.size else "not_started",
|
||||
status="complete"
|
||||
if final_file_exists and downloaded_bytes == file.size
|
||||
else "not_started",
|
||||
start_time=time.time(),
|
||||
)
|
||||
|
||||
|
||||
@@ -252,7 +252,7 @@ def main():
|
||||
target = min(max(soft, 65535), hard)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
mp.set_start_method("spawn", force=True)
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
logger.info("Starting EXO")
|
||||
|
||||
@@ -168,12 +168,7 @@ from exo.shared.types.openai_responses import (
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceMeta,
|
||||
MlxJacclInstance,
|
||||
)
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
@@ -518,14 +513,6 @@ class API:
|
||||
shard_assignments = instance.shard_assignments
|
||||
placement_node_ids = list(shard_assignments.node_to_runner.keys())
|
||||
|
||||
# Derive instance_meta from the actual instance type, since
|
||||
# place_instance() may override it (e.g., single-node → MlxRing)
|
||||
actual_instance_meta = (
|
||||
InstanceMeta.MlxJaccl
|
||||
if isinstance(instance, MlxJacclInstance)
|
||||
else InstanceMeta.MlxRing
|
||||
)
|
||||
|
||||
memory_delta_by_node: dict[str, int] = {}
|
||||
if placement_node_ids:
|
||||
total_bytes = model_card.storage_size.in_bytes
|
||||
@@ -538,14 +525,14 @@ class API:
|
||||
if (
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
actual_instance_meta,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=actual_instance_meta,
|
||||
instance_meta=instance_meta,
|
||||
instance=instance,
|
||||
memory_delta_by_node=memory_delta_by_node or None,
|
||||
error=None,
|
||||
@@ -555,7 +542,7 @@ class API:
|
||||
(
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
actual_instance_meta,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -14,10 +14,12 @@ from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import PlaceInstance
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted, TaskStatusUpdated
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
@@ -456,3 +458,117 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
else:
|
||||
ip_part = coordinator.split(":")[0]
|
||||
assert len(ip_part.split(".")) == 4
|
||||
|
||||
|
||||
def _make_task(
|
||||
instance_id: InstanceId,
|
||||
status: TaskStatus = TaskStatus.Running,
|
||||
) -> TextGeneration:
|
||||
return TextGeneration(
|
||||
task_id=TaskId(),
|
||||
task_status=status,
|
||||
instance_id=instance_id,
|
||||
command_id=CommandId(),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=ModelId("test-model"),
|
||||
input=[InputMessage(role="user", content="hello")],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_running_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
task = _make_task(instance_id, TaskStatus.Running)
|
||||
tasks = {task.task_id: task}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – cancellation event should come before the deletion event
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], TaskStatusUpdated)
|
||||
assert events[0].task_id == task.task_id
|
||||
assert events[0].task_status == TaskStatus.Cancelled
|
||||
assert isinstance(events[1], InstanceDeleted)
|
||||
assert events[1].instance_id == instance_id
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_pending_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
task = _make_task(instance_id, TaskStatus.Pending)
|
||||
tasks = {task.task_id: task}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], TaskStatusUpdated)
|
||||
assert events[0].task_id == task.task_id
|
||||
assert events[0].task_status == TaskStatus.Cancelled
|
||||
assert isinstance(events[1], InstanceDeleted)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_ignores_completed_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
tasks = {
|
||||
t.task_id: t
|
||||
for t in [
|
||||
_make_task(instance_id, TaskStatus.Complete),
|
||||
_make_task(instance_id, TaskStatus.Failed),
|
||||
_make_task(instance_id, TaskStatus.TimedOut),
|
||||
_make_task(instance_id, TaskStatus.Cancelled),
|
||||
]
|
||||
}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – only the InstanceDeleted event, no cancellations
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_only_matching_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id_a = InstanceId()
|
||||
instance_id_b = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {
|
||||
instance_id_a: instance,
|
||||
instance_id_b: instance,
|
||||
}
|
||||
# only delete instance A, keep instance B
|
||||
target_instances: dict[InstanceId, Instance] = {instance_id_b: instance}
|
||||
|
||||
task_a = _make_task(instance_id_a, TaskStatus.Running)
|
||||
task_b = _make_task(instance_id_b, TaskStatus.Running)
|
||||
tasks = {task_a.task_id: task_a, task_b.task_id: task_b}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – only task_a should be cancelled
|
||||
cancel_events = [e for e in events if isinstance(e, TaskStatusUpdated)]
|
||||
delete_events = [e for e in events if isinstance(e, InstanceDeleted)]
|
||||
assert len(cancel_events) == 1
|
||||
assert cancel_events[0].task_id == task_a.task_id
|
||||
assert cancel_events[0].task_status == TaskStatus.Cancelled
|
||||
assert len(delete_events) == 1
|
||||
assert delete_events[0].instance_id == instance_id_a
|
||||
|
||||
@@ -90,6 +90,7 @@ class ModelCard(CamelCaseModel):
|
||||
base_model: str = ""
|
||||
capabilities: list[str] = []
|
||||
uses_cfg: bool = False
|
||||
trust_remote_code: bool = True
|
||||
|
||||
@field_validator("tasks", mode="before")
|
||||
@classmethod
|
||||
@@ -137,6 +138,7 @@ class ModelCard(CamelCaseModel):
|
||||
hidden_size=config_data.hidden_size or 0,
|
||||
supports_tensor=config_data.supports_tensor,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
trust_remote_code=False,
|
||||
)
|
||||
await mc.save_to_custom_dir()
|
||||
_card_cache[model_id] = mc
|
||||
|
||||
@@ -128,12 +128,25 @@ class PipelineFirstLayer(CustomMlxLayer):
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
if self.r != 0:
|
||||
import time as _time
|
||||
_t0 = _time.perf_counter()
|
||||
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
|
||||
if self.is_prefill:
|
||||
# We want to avoid GPU timeout errors by evalling the distributed operation
|
||||
# so that it stays on CPU, which does not have a timeout.
|
||||
mx.eval(x)
|
||||
return self.original_layer(x, *args, **kwargs)
|
||||
_elapsed = _time.perf_counter() - _t0
|
||||
if _elapsed > 1.0:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"[PIPELINE] PipelineFirstLayer recv_like+eval took {_elapsed:.4f}s (SLOW)")
|
||||
_t0_layer = _time.perf_counter() if self.r != 0 else None
|
||||
result = self.original_layer(x, *args, **kwargs)
|
||||
if _t0_layer is not None:
|
||||
_elapsed_layer = _time.perf_counter() - _t0_layer
|
||||
if _elapsed_layer > 1.0:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"[PIPELINE] PipelineFirstLayer original_layer took {_elapsed_layer:.4f}s (SLOW)")
|
||||
return result
|
||||
|
||||
|
||||
class PipelineLastLayer(CustomMlxLayer):
|
||||
@@ -152,13 +165,20 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
self.is_prefill: bool = False
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
import time as _time
|
||||
cache = self.original_layer_signature.bind_partial(
|
||||
x, *args, **kwargs
|
||||
).arguments.get("cache", None)
|
||||
|
||||
_t0 = _time.perf_counter()
|
||||
output: mx.array = self.original_layer(x, *args, **kwargs)
|
||||
_elapsed = _time.perf_counter() - _t0
|
||||
if _elapsed > 1.0:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"[PIPELINE] PipelineLastLayer original_layer took {_elapsed:.4f}s (SLOW)")
|
||||
|
||||
if self.r != self.s - 1:
|
||||
_t0 = _time.perf_counter()
|
||||
output = mx.distributed.send(
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
@@ -171,11 +191,20 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
mx.eval(output)
|
||||
if cache is not None:
|
||||
mx.eval(_cache.keys) # type: ignore
|
||||
_elapsed = _time.perf_counter() - _t0
|
||||
if _elapsed > 1.0:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"[PIPELINE] PipelineLastLayer send+eval took {_elapsed:.4f}s (SLOW)")
|
||||
|
||||
if not self.is_prefill:
|
||||
_t0 = _time.perf_counter()
|
||||
output = mx.distributed.all_gather(output, group=self.group)[
|
||||
-output.shape[0] :
|
||||
]
|
||||
_elapsed = _time.perf_counter() - _t0
|
||||
if _elapsed > 1.0:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"[PIPELINE] PipelineLastLayer all_gather took {_elapsed:.4f}s (SLOW)")
|
||||
|
||||
return output
|
||||
|
||||
@@ -852,6 +881,8 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(
|
||||
layer.self_attn.o_proj
|
||||
)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
else:
|
||||
assert isinstance(layer, Qwen3NextDecoderLayer)
|
||||
if hasattr(layer, "linear_attn"):
|
||||
|
||||
@@ -94,14 +94,20 @@ def prefill(
|
||||
if on_prefill_progress is not None:
|
||||
on_prefill_progress(processed, total)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
set_pipeline_prefill(model, is_prefill=True)
|
||||
logger.warning(f"[PREFILL] set_pipeline_prefill(True) took {time.perf_counter() - t0:.4f}s")
|
||||
|
||||
t0 = time.perf_counter()
|
||||
mx_barrier(group)
|
||||
logger.info("Starting prefill")
|
||||
logger.warning(f"[PREFILL] mx_barrier (pre-prefill) took {time.perf_counter() - t0:.4f}s")
|
||||
|
||||
logger.warning("[PREFILL] Starting prefill via stream_generate")
|
||||
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
try:
|
||||
t0 = time.perf_counter()
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
@@ -114,15 +120,19 @@ def prefill(
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=progress_callback,
|
||||
):
|
||||
logger.warning(f"[PREFILL] stream_generate first yield took {time.perf_counter() - t0:.4f}s")
|
||||
break # Stop after first iteration - cache is now filled
|
||||
except PrefillCancelled:
|
||||
set_pipeline_prefill(model, is_prefill=False)
|
||||
raise
|
||||
|
||||
t0 = time.perf_counter()
|
||||
set_pipeline_prefill(model, is_prefill=False)
|
||||
logger.warning(f"[PREFILL] set_pipeline_prefill(False) took {time.perf_counter() - t0:.4f}s")
|
||||
|
||||
# stream_generate added 1 extra generated token to the cache, so we should trim it.
|
||||
# Because of needing to roll back arrays cache, we will generate on 2 tokens so trim 1 more.
|
||||
t0 = time.perf_counter()
|
||||
pre_gen = deepcopy(snapshots[-2]) if has_ssm else None
|
||||
for i, c in enumerate(cache):
|
||||
if has_ssm and isinstance(c, (ArraysCache, RotatingKVCache)):
|
||||
@@ -132,11 +142,12 @@ def prefill(
|
||||
else:
|
||||
assert not isinstance(c, (ArraysCache, RotatingKVCache))
|
||||
c.trim(2) # pyright: ignore[reportUnknownMemberType]
|
||||
logger.warning(f"[PREFILL] cache trim took {time.perf_counter() - t0:.4f}s")
|
||||
|
||||
elapsed = time.perf_counter() - start_time
|
||||
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
|
||||
logger.debug(
|
||||
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
|
||||
logger.warning(
|
||||
f"[PREFILL] complete: {num_tokens} tokens in {elapsed:.2f}s "
|
||||
f"({tokens_per_sec:.1f} tok/s)"
|
||||
)
|
||||
# Exclude the last snapshot
|
||||
@@ -324,6 +335,8 @@ def mlx_generate(
|
||||
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
logger.warning(f"[GENERATE] calling prefill with {len(prompt_tokens) - 1} tokens")
|
||||
t_prefill_start = time.perf_counter()
|
||||
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
|
||||
model,
|
||||
tokenizer,
|
||||
@@ -333,6 +346,7 @@ def mlx_generate(
|
||||
group,
|
||||
on_prefill_progress,
|
||||
)
|
||||
logger.warning(f"[GENERATE] prefill() returned in {time.perf_counter() - t_prefill_start:.4f}s")
|
||||
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
|
||||
|
||||
# stream_generate starts from the last token
|
||||
@@ -348,9 +362,12 @@ def mlx_generate(
|
||||
think_start = tokenizer.think_start
|
||||
think_end = tokenizer.think_end
|
||||
|
||||
logger.info("Starting decode")
|
||||
logger.warning("[GENERATE] Starting decode")
|
||||
t0 = time.perf_counter()
|
||||
mx_barrier(group)
|
||||
logger.warning(f"[GENERATE] mx_barrier (pre-decode) took {time.perf_counter() - t0:.4f}s")
|
||||
|
||||
_decode_token_start = time.perf_counter()
|
||||
for completion_tokens, out in enumerate(
|
||||
stream_generate(
|
||||
model=model,
|
||||
@@ -366,6 +383,9 @@ def mlx_generate(
|
||||
),
|
||||
start=1,
|
||||
):
|
||||
_decode_token_elapsed = time.perf_counter() - _decode_token_start
|
||||
if _decode_token_elapsed > 1.0:
|
||||
logger.warning(f"[DECODE] token {completion_tokens} took {_decode_token_elapsed:.4f}s (SLOW)")
|
||||
generated_text_parts.append(out.text)
|
||||
accumulated_text += out.text
|
||||
|
||||
@@ -488,9 +508,12 @@ def mlx_generate(
|
||||
)
|
||||
|
||||
if is_done:
|
||||
t0 = time.perf_counter()
|
||||
mx_barrier(group)
|
||||
logger.warning(f"[GENERATE] mx_barrier (post-decode) took {time.perf_counter() - t0:.4f}s")
|
||||
break
|
||||
|
||||
# Limit accumulated_text to what's needed for stop sequence detection
|
||||
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
|
||||
accumulated_text = accumulated_text[-max_stop_len:]
|
||||
_decode_token_start = time.perf_counter()
|
||||
|
||||
@@ -23,9 +23,7 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
TRUST_REMOTE_CODE,
|
||||
)
|
||||
from exo.worker.engines.mlx.constants import TRUST_REMOTE_CODE
|
||||
|
||||
try:
|
||||
from mlx_lm.tokenizer_utils import load_tokenizer
|
||||
@@ -293,7 +291,11 @@ def shard_and_load(
|
||||
|
||||
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
|
||||
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
|
||||
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
|
||||
return load_tokenizer_for_model_id(
|
||||
shard_metadata.model_card.model_id,
|
||||
model_path,
|
||||
trust_remote_code=shard_metadata.model_card.trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
@@ -325,7 +327,7 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
|
||||
|
||||
def load_tokenizer_for_model_id(
|
||||
model_id: ModelId, model_path: Path
|
||||
model_id: ModelId, model_path: Path, *, trust_remote_code: bool = TRUST_REMOTE_CODE
|
||||
) -> TokenizerWrapper:
|
||||
"""
|
||||
Load tokenizer for a model given its ID and local path.
|
||||
@@ -394,7 +396,7 @@ def load_tokenizer_for_model_id(
|
||||
|
||||
tokenizer = load_tokenizer(
|
||||
model_path,
|
||||
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
|
||||
tokenizer_config_extra={"trust_remote_code": trust_remote_code},
|
||||
eos_token_ids=eos_token_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -276,31 +276,38 @@ def main(
|
||||
_task_id: TaskId = task.task_id,
|
||||
_group: mx.distributed.Group | None = group,
|
||||
) -> None:
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=PrefillProgressChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
processed_tokens=processed,
|
||||
total_tokens=total,
|
||||
),
|
||||
)
|
||||
)
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
want_to_cancel = (_task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
if mx_any(want_to_cancel, _group):
|
||||
raise PrefillCancelled()
|
||||
time.sleep(0.2)
|
||||
return None
|
||||
# if device_rank == 0:
|
||||
# event_sender.send(
|
||||
# ChunkGenerated(
|
||||
# command_id=command_id,
|
||||
# chunk=PrefillProgressChunk(
|
||||
# model=shard_metadata.model_card.model_id,
|
||||
# processed_tokens=processed,
|
||||
# total_tokens=total,
|
||||
# ),
|
||||
# )
|
||||
# )
|
||||
# cancelled_tasks.update(cancel_receiver.collect())
|
||||
# want_to_cancel = (_task_id in cancelled_tasks) or (
|
||||
# TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
# )
|
||||
# if mx_any(want_to_cancel, _group):
|
||||
# raise PrefillCancelled()
|
||||
|
||||
try:
|
||||
import time as _time
|
||||
_runner_req_start = _time.perf_counter()
|
||||
_check_for_debug_prompts(task_params)
|
||||
|
||||
# Build prompt once - used for both generation and thinking detection
|
||||
_t0 = _time.perf_counter()
|
||||
prompt = apply_chat_template(tokenizer, task_params)
|
||||
logger.warning(f"[RUNNER] apply_chat_template took {_time.perf_counter() - _t0:.4f}s")
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
logger.warning("[RUNNER] calling mlx_generate")
|
||||
mlx_generator = mlx_generate(
|
||||
model=cast(Model, inference_model),
|
||||
tokenizer=tokenizer,
|
||||
@@ -332,6 +339,8 @@ def main(
|
||||
|
||||
completion_tokens = 0
|
||||
tokens_since_last_cancel_check = check_for_cancel_every
|
||||
logger.warning("[RUNNER] starting token iteration loop")
|
||||
_runner_token_start = _time.perf_counter()
|
||||
for response in mlx_generator:
|
||||
tokens_since_last_cancel_check += 1
|
||||
if tokens_since_last_cancel_check >= check_for_cancel_every:
|
||||
@@ -413,6 +422,7 @@ def main(
|
||||
)
|
||||
raise
|
||||
|
||||
logger.warning(f"[RUNNER] request complete in {_time.perf_counter() - _runner_req_start:.4f}s total")
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
|
||||
|
||||
@@ -106,13 +106,18 @@ class RunnerSupervisor:
|
||||
def shutdown(self):
|
||||
logger.info("Runner supervisor shutting down")
|
||||
self._tg.cancel_tasks()
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
if not self._cancel_watch_runner.cancel_called:
|
||||
self._cancel_watch_runner.cancel()
|
||||
with contextlib.suppress(ClosedResourceError):
|
||||
self._ev_recv.close()
|
||||
with contextlib.suppress(ClosedResourceError):
|
||||
self._task_sender.close()
|
||||
with contextlib.suppress(ClosedResourceError):
|
||||
self._event_sender.close()
|
||||
with contextlib.suppress(ClosedResourceError):
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
with contextlib.suppress(ClosedResourceError):
|
||||
self._cancel_sender.close()
|
||||
self.runner_process.join(5)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
|
||||
279
tmp/reproduce_gpu_lock.py
Normal file
279
tmp/reproduce_gpu_lock.py
Normal file
@@ -0,0 +1,279 @@
|
||||
#!/usr/bin/env python3
|
||||
# /// script
|
||||
# requires-python = ">=3.11"
|
||||
# ///
|
||||
"""Reproduce GPU lock issue with mlx-community/Llama-3.2-1B-Instruct-4bit.
|
||||
|
||||
Starts exo or mlx_lm.server, then sends repeated chat completions
|
||||
until a request stalls for >5 seconds (indicating a GPU lock).
|
||||
|
||||
Usage:
|
||||
uv run tmp/reproduce_gpu_lock.py # use exo (default)
|
||||
uv run tmp/reproduce_gpu_lock.py --mlx-lm # use mlx_lm.server
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
import uuid
|
||||
|
||||
MODEL_ID = "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
MODEL_PATH = os.path.expanduser("~/.exo/models/mlx-community--Llama-3.2-1B-Instruct-4bit")
|
||||
STALL_THRESHOLD_S = 5.0
|
||||
|
||||
server_proc = None
|
||||
base_url = ""
|
||||
|
||||
|
||||
def cleanup(*_):
|
||||
if server_proc and server_proc.poll() is None:
|
||||
print("\nStopping server...")
|
||||
server_proc.terminate()
|
||||
try:
|
||||
server_proc.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
server_proc.kill()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, cleanup)
|
||||
signal.signal(signal.SIGTERM, cleanup)
|
||||
|
||||
|
||||
def api_get(path, timeout=30):
    """GET `path` relative to the module-level base_url and decode the JSON body."""
    request = urllib.request.Request(f"{base_url}{path}")
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read())
|
||||
|
||||
|
||||
def api_post(path, body, timeout=300):
    """POST `body` as JSON to `path` relative to base_url; return the decoded reply."""
    payload = json.dumps(body).encode()
    request = urllib.request.Request(
        f"{base_url}{path}",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read())
|
||||
|
||||
|
||||
def wait_for_api(max_wait=120):
    """Poll /v1/models until the server answers, or exit via cleanup() after max_wait seconds."""
    print("Waiting for API to be ready...", flush=True)
    deadline = time.time() + max_wait
    while time.time() < deadline:
        try:
            api_get("/v1/models", timeout=5)
        except Exception:
            # Server not up yet; back off briefly and retry.
            time.sleep(2)
        else:
            print("API is ready.", flush=True)
            return
    print("ERROR: API did not become ready in time.", flush=True)
    cleanup()
|
||||
|
||||
|
||||
def create_instance(max_wait=120):
    """Wait for a valid placement preview for MODEL_ID, then request instance creation.

    Polls /instance/previews until a preview with no error and a concrete
    instance appears (or max_wait elapses, in which case cleanup() exits).
    Returns the new instance's id, or None if the API response omits it.
    """
    print(f"Waiting for valid placements for {MODEL_ID}...", flush=True)
    deadline = time.time() + max_wait
    candidates = []
    while time.time() < deadline:
        try:
            previews = api_get(f"/instance/previews?model_id={MODEL_ID}")
        except Exception:
            pass  # transient API error; retry below
        else:
            candidates = [
                p
                for p in previews.get("previews", [])
                if p.get("error") is None and p.get("instance") is not None
            ]
            if candidates:
                break
        time.sleep(3)
    if not candidates:
        print("ERROR: No valid placements found after waiting.", flush=True)
        cleanup()

    chosen = candidates[0]
    instance = chosen["instance"]
    print(f"Creating instance (sharding={chosen.get('sharding')}, meta={chosen.get('instance_meta')})...", flush=True)
    resp = api_post("/instance", {"instance": instance})
    print(f"Instance creation requested: {resp.get('message')} (command_id={resp.get('command_id')})", flush=True)
    return instance.get("id") or instance.get("instance_id")
|
||||
|
||||
|
||||
def wait_for_instance(max_wait=120):
    """Poll /state until at least one instance is listed; warn (but proceed) on timeout."""
    print("Waiting for instance to be ready...", flush=True)
    deadline = time.time() + max_wait
    while time.time() < deadline:
        try:
            state = api_get("/state", timeout=10)
        except Exception:
            pass  # transient API error; retry below
        else:
            # State key differs across versions; accept either spelling.
            instances = state.get("instances") or state.get("model_instances") or {}
            if instances:
                print(f"Instance ready. ({len(instances)} instance(s) in state)", flush=True)
                return
        time.sleep(3)
    print("WARNING: Timed out waiting for instance in state. Proceeding anyway...", flush=True)
|
||||
|
||||
|
||||
# Prompt topics; combining a random topic with a random nonce keeps every
# request's prompt distinct (defeats any prompt-level caching).
TOPICS = [
    "the weather", "cats", "space", "pizza", "music", "ocean", "mountains",
    "robots", "books", "coffee", "trains", "clouds", "birds", "fire",
    "ice cream", "trees", "rivers", "stars", "thunder", "gardens",
]


def send_chat(request_num):
    """Send one non-streaming chat completion; return (elapsed_seconds, response_dict)."""
    topic = random.choice(TOPICS)
    nonce = random.randint(1000, 9999)
    payload = {
        "model": MODEL_ID,
        "messages": [{"role": "user", "content": f"Say something about {topic} in one sentence. ({nonce})"}],
        "stream": False,
        "max_tokens": 64,
    }
    started = time.time()
    response = api_post("/v1/chat/completions", payload, timeout=600)
    return time.time() - started, response
|
||||
|
||||
|
||||
def start_exo():
    """Launch exo in an isolated libp2p namespace, then wait for API + instance."""
    global server_proc
    # Stable per-machine id so repeated runs reuse the same private namespace.
    machine_id = hashlib.sha256(f"{platform.node()}-{uuid.getnode()}".encode()).hexdigest()[:12]
    namespace = f"gpu-lock-repro-{machine_id}"

    # Line-buffered log file; handle intentionally stays open for the child's lifetime.
    log_file = open("/tmp/exo_gpu_lock_repro.log", "w", buffering=1)
    print(f"\nStarting exo (namespace={namespace})...", flush=True)
    print("Log: /tmp/exo_gpu_lock_repro.log", flush=True)
    print(" tail -f /tmp/exo_gpu_lock_repro.log (in another terminal to watch)", flush=True)
    child_env = {**os.environ, "EXO_LIBP2P_NAMESPACE": namespace, "PYTHONUNBUFFERED": "1"}
    server_proc = subprocess.Popen(
        ["uv", "run", "exo"],
        stdout=log_file,
        stderr=subprocess.STDOUT,
        env=child_env,
    )
    print(f"exo started (pid={server_proc.pid})", flush=True)

    wait_for_api()
    create_instance()
    wait_for_instance()
|
||||
|
||||
|
||||
def start_mlx_lm():
    """Launch mlx_lm.server on port 8080 against MODEL_PATH, then wait for its API."""
    global server_proc
    # Line-buffered log file; handle intentionally stays open for the child's lifetime.
    log_file = open("/tmp/mlx_lm_gpu_lock_repro.log", "w", buffering=1)
    print("\nStarting mlx_lm.server on port 8080...", flush=True)
    print(f" Model path: {MODEL_PATH}", flush=True)
    print("Log: /tmp/mlx_lm_gpu_lock_repro.log", flush=True)
    print(" tail -f /tmp/mlx_lm_gpu_lock_repro.log (in another terminal to watch)", flush=True)
    child_env = {**os.environ, "PYTHONUNBUFFERED": "1"}
    server_proc = subprocess.Popen(
        ["uv", "run", "mlx_lm.server", "--model", MODEL_PATH, "--port", "8080"],
        stdout=log_file,
        stderr=subprocess.STDOUT,
        env=child_env,
        cwd=os.path.expanduser("~/mlx-lm"),
    )
    print(f"mlx_lm.server started (pid={server_proc.pid})", flush=True)
    wait_for_api()
|
||||
|
||||
|
||||
def chat_loop():
    """Send chat completions forever, flagging any request slower than STALL_THRESHOLD_S.

    Each in-flight request gets a daemon watcher thread that prints elapsed
    time every 5 seconds so a stall is visible live. On a stall, a detailed
    report is printed and the loop continues (Ctrl+C to stop).
    """
    print("\n" + "-" * 60, flush=True)
    print("Starting chat completion loop. Watching for stalls...", flush=True)
    print("-" * 60 + "\n", flush=True)

    timings = []
    request_num = 0

    while True:
        request_num += 1
        print(f" [#{request_num}] sending...", end="", flush=True)
        req_start = time.time()

        done_event = threading.Event()

        def print_waiting():
            # Event.wait(5) returns True once set (request finished) and
            # False when the 5s interval elapses — no separate is_set()
            # check is needed.
            while not done_event.wait(5):
                elapsed_so_far = time.time() - req_start
                print(f" ({elapsed_so_far:.0f}s)", end="", flush=True)

        watcher = threading.Thread(target=print_waiting, daemon=True)
        watcher.start()

        try:
            elapsed, resp = send_chat(request_num)
        except Exception as e:
            print(f" ERROR after {time.time() - req_start:.1f}s: {e}", flush=True)
            time.sleep(2)
            continue
        finally:
            # Runs on both success and error paths (the except previously
            # also called set(); the finally alone suffices).
            done_event.set()

        timings.append(elapsed)
        try:
            content = resp["choices"][0]["message"]["content"][:80]
        except (KeyError, IndexError):
            content = "<no content>"

        print(f" {elapsed:.2f}s | {content}", flush=True)

        if elapsed > STALL_THRESHOLD_S:
            _report_stall(request_num, elapsed, timings)


def _report_stall(request_num, elapsed, timings):
    """Print the GPU-lock banner plus timing statistics for all requests so far."""
    print("\n", flush=True)
    print("!" * 60, flush=True)
    print("!" * 60, flush=True)
    print("!!!", flush=True)
    print(f"!!! GPU LOCK DETECTED on request #{request_num}", flush=True)
    print(f"!!! Elapsed: {elapsed:.2f}s (threshold: {STALL_THRESHOLD_S}s)", flush=True)
    print("!!!", flush=True)
    print("!" * 60, flush=True)
    print("!" * 60, flush=True)
    print(f"\nTotal requests sent: {request_num}", flush=True)
    print(f"Average time (all): {sum(timings) / len(timings):.2f}s", flush=True)
    # Statistics restricted to non-stalled requests, for comparison.
    normal = [t for t in timings if t <= STALL_THRESHOLD_S]
    if normal:
        print(f"Average time (normal): {sum(normal) / len(normal):.2f}s", flush=True)
    print(f"Max time: {max(timings):.2f}s", flush=True)
    print(f"Min time: {min(timings):.2f}s", flush=True)
    print("\nAll timings:", flush=True)
    for i, t in enumerate(timings, 1):
        marker = " <<<< STALL" if t > STALL_THRESHOLD_S else ""
        print(f" #{i}: {t:.2f}s{marker}", flush=True)
    print(f"\nServer still running (pid={server_proc.pid}). Continuing... Ctrl+C to stop.", flush=True)
    print("-" * 60 + "\n", flush=True)
|
||||
|
||||
|
||||
def main():
    """Parse CLI flags, start the chosen server backend, and run the request loop."""
    global base_url

    parser = argparse.ArgumentParser(description="Reproduce GPU lock issue")
    parser.add_argument("--mlx-lm", action="store_true", help="Use mlx_lm.server instead of exo")
    parser.add_argument("--port", type=int, default=None, help="Override server port")
    args = parser.parse_args()

    use_mlx = args.mlx_lm
    # Default ports: 8080 for mlx_lm.server, 52415 for exo.
    port = args.port or (8080 if use_mlx else 52415)
    base_url = f"http://localhost:{port}"

    mode = "mlx_lm" if use_mlx else "exo"
    for line in (
        "=" * 60,
        " GPU Lock Reproduction Script",
        f" Mode: {mode}",
        f" Model: {MODEL_ID}",
        f" API: {base_url}",
        f" Stall threshold: {STALL_THRESHOLD_S}s",
        "=" * 60,
    ):
        print(line, flush=True)

    if use_mlx:
        start_mlx_lm()
    else:
        start_exo()

    chat_loop()


if __name__ == "__main__":
    main()
|
||||
125
tmp/test_trust_remote_code_attack.sh
Executable file
125
tmp/test_trust_remote_code_attack.sh
Executable file
@@ -0,0 +1,125 @@
|
||||
#!/usr/bin/env bash
# Test that models added via API get trust_remote_code=false
# Run this against a running exo instance.
# Usage: ./test_trust_remote_code_attack.sh [host:port]

set -uo pipefail

HOST="${1:-localhost:52415}"
MODEL_ID="KevTheHermit/security-testing"
CUSTOM_CARDS_DIR="$HOME/.exo/custom_model_cards"
CARD_FILE="$CUSTOM_CARDS_DIR/KevTheHermit--security-testing.toml"

# URL-encode MODEL_ID for use in a URL path ("/" -> "%2F").
# Passed via argv instead of splicing into the python source, which was
# fragile quoting and injection-prone.
urlencoded_model_id() {
  python3 -c 'import sys, urllib.parse; print(urllib.parse.quote(sys.argv[1], safe=""))' "$MODEL_ID"
}

# Remove the custom model card via the API (used by step 0 and step 4).
delete_custom_card() {
  curl -s -X DELETE "http://$HOST/models/custom/$(urlencoded_model_id)" >/dev/null
}

echo "=== Test: trust_remote_code attack via API ==="
echo "Target: $HOST"
echo ""

# Clean up RCE proof from previous runs
rm -f /tmp/exo-rce-proof.txt

# Step 0: Clean up any stale card from previous runs
if [ -f "$CARD_FILE" ]; then
  echo "[0] Removing stale card from previous run ..."
  delete_custom_card
  rm -f "$CARD_FILE"
  echo " Done"
  echo ""
fi

# Step 1: Add the malicious model via API
echo "[1] Adding model via POST /models/add ..."
ADD_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/models/add" \
  -H "Content-Type: application/json" \
  -d "{\"model_id\":\"$MODEL_ID\"}")
# curl appends the status code as a final line; split body and code.
HTTP_CODE=$(echo "$ADD_RESPONSE" | tail -1)
BODY=$(echo "$ADD_RESPONSE" | sed '$d')
echo " HTTP $HTTP_CODE"

if [ "$HTTP_CODE" -ge 400 ]; then
  echo " Model add failed (HTTP $HTTP_CODE) — that's fine if model doesn't exist on HF."
  echo " Response: $BODY"
  echo ""
  echo "RESULT: Model was rejected at add time. Attack blocked."
  exit 0
fi

# Step 2: Verify the saved TOML has trust_remote_code = false
echo ""
echo "[2] Checking saved model card TOML ..."
if [ ! -f "$CARD_FILE" ]; then
  echo " FAIL: Card file not found at $CARD_FILE"
  exit 1
fi

if grep -q 'trust_remote_code = false' "$CARD_FILE"; then
  echo " SAFE: trust_remote_code = false (fix is active)"
else
  echo " VULNERABLE: trust_remote_code is not false — remote code WILL be trusted"
fi
echo " Contents:"
cat "$CARD_FILE"

# Step 3: Place the instance
echo ""
echo "[3] Attempting POST /place_instance ..."
PLACE_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "http://$HOST/place_instance" \
  -H "Content-Type: application/json" \
  -d "{\"model_id\":\"$MODEL_ID\"}")
PLACE_CODE=$(echo "$PLACE_RESPONSE" | tail -1)
PLACE_BODY=$(echo "$PLACE_RESPONSE" | sed '$d')
echo " HTTP $PLACE_CODE"
echo " Response: $PLACE_BODY"

# Step 3b: Send a chat completion to actually trigger tokenizer loading
echo ""
echo "[3b] Sending chat completion to trigger tokenizer load ..."
CHAT_RESPONSE=$(curl -s -w "\n%{http_code}" --max-time 30 -X POST "http://$HOST/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d "{\"model\":\"$MODEL_ID\",\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}],\"max_tokens\":1}")
CHAT_CODE=$(echo "$CHAT_RESPONSE" | tail -1)
CHAT_BODY=$(echo "$CHAT_RESPONSE" | sed '$d')
echo " HTTP $CHAT_CODE"
echo " Response: $CHAT_BODY"
echo ""
echo "[3c] Checking for RCE proof ..."
# Give any payload a moment to run before checking for its marker file.
sleep 5
if [ -f /tmp/exo-rce-proof.txt ]; then
  echo " VULNERABLE: Remote code executed!"
  echo " Contents:"
  cat /tmp/exo-rce-proof.txt
else
  echo " SAFE: /tmp/exo-rce-proof.txt does not exist — remote code was NOT executed"
fi

# Step 4: Clean up — delete instance and custom model
echo ""
echo "[4] Cleaning up ..."

# Find and delete any instance for this model
INSTANCE_ID=$(curl -s "http://$HOST/state" | python3 -c "
import sys, json
state = json.load(sys.stdin)
for iid, wrapper in state.get('instances', {}).items():
    for tag, inst in wrapper.items():
        sa = inst.get('shardAssignments', {})
        if sa.get('modelId', '') == '$MODEL_ID':
            print(iid)
            sys.exit(0)
" 2>/dev/null || true)

if [ -n "$INSTANCE_ID" ]; then
  echo " Deleting instance $INSTANCE_ID ..."
  curl -s -X DELETE "http://$HOST/instance/$INSTANCE_ID" >/dev/null
  echo " Done"
else
  echo " No instance found to delete"
fi

echo " Deleting custom model card ..."
delete_custom_card
echo " Done"

echo ""
echo "=== DONE ==="
|
||||
10
uv.lock
generated
10
uv.lock
generated
@@ -378,7 +378,7 @@ dependencies = [
|
||||
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260220+13998a05", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260224+5289547a", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#5289547ada1cddda2b9716baf6a077a906d02189" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1025,7 +1025,7 @@ dependencies = [
|
||||
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260220+13998a05", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260224+5289547a", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#5289547ada1cddda2b9716baf6a077a906d02189" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1072,8 +1072,8 @@ cuda13 = [
|
||||
|
||||
[[package]]
|
||||
name = "mlx"
|
||||
version = "0.30.7.dev20260220+13998a05"
|
||||
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }
|
||||
version = "0.30.7.dev20260224+5289547a"
|
||||
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#5289547ada1cddda2b9716baf6a077a906d02189" }
|
||||
resolution-markers = [
|
||||
"sys_platform == 'darwin'",
|
||||
]
|
||||
@@ -1108,7 +1108,7 @@ version = "0.30.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260220+13998a05", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260224+5289547a", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#5289547ada1cddda2b9716baf6a077a906d02189" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
|
||||
Reference in New Issue
Block a user