mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-24 18:28:30 -05:00
Compare commits
8 Commits
leo/test-b
...
leo/fix-st
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa788228bc | ||
|
|
9d3b1334da | ||
|
|
811a4d80bd | ||
|
|
2fe689315b | ||
|
|
644c5573ce | ||
|
|
12c3015f52 | ||
|
|
365dd68d9a | ||
|
|
d3d129581e |
@@ -75,7 +75,7 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
model_id,
|
||||
allow_patterns=["*.json", "*.py", "*.tiktoken"],
|
||||
allow_patterns=["*.json", "*.py", "*.tiktoken", "*.model"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -412,7 +412,7 @@
|
||||
<div>{col.label}</div>
|
||||
{#if col.diskAvailable != null}
|
||||
<div
|
||||
class="text-[9px] text-exo-light-gray/60 normal-case tracking-normal mt-0.5"
|
||||
class="text-[9px] text-white/70 normal-case tracking-normal mt-0.5"
|
||||
>
|
||||
{formatBytes(col.diskAvailable)} free
|
||||
</div>
|
||||
@@ -436,7 +436,7 @@
|
||||
</div>
|
||||
{#if row.prettyName}
|
||||
<div
|
||||
class="text-[10px] text-exo-light-gray/60"
|
||||
class="text-[10px] text-white/60"
|
||||
title={row.modelId}
|
||||
>
|
||||
{row.modelId}
|
||||
@@ -450,7 +450,7 @@
|
||||
title="View model details"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4 text-white/30 hover:text-white/60"
|
||||
class="w-4 h-4 text-white/60 hover:text-white/80"
|
||||
viewBox="0 0 24 24"
|
||||
fill="currentColor"
|
||||
>
|
||||
@@ -483,12 +483,12 @@
|
||||
clip-rule="evenodd"
|
||||
></path>
|
||||
</svg>
|
||||
<span class="text-xs text-exo-light-gray/70"
|
||||
<span class="text-xs text-white/70"
|
||||
>{formatBytes(cell.totalBytes)}</span
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray/40 hover:text-red-400 transition-colors mt-0.5 cursor-pointer"
|
||||
class="text-white/50 hover:text-red-400 transition-colors mt-0.5 cursor-pointer"
|
||||
onclick={() =>
|
||||
deleteDownload(col.nodeId, row.modelId)}
|
||||
title="Delete from this node"
|
||||
@@ -530,7 +530,7 @@
|
||||
).toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
<span class="text-[10px] text-exo-light-gray/60"
|
||||
<span class="text-[10px] text-white/70"
|
||||
>{formatSpeed(cell.speed)}</span
|
||||
>
|
||||
</div>
|
||||
@@ -542,7 +542,7 @@
|
||||
: "Download pending"}
|
||||
>
|
||||
{#if cell.downloaded > 0 && cell.total > 0}
|
||||
<span class="text-exo-light-gray/70 text-xs"
|
||||
<span class="text-white/70 text-xs"
|
||||
>{formatBytes(cell.downloaded)} / {formatBytes(
|
||||
cell.total,
|
||||
)}</span
|
||||
@@ -561,7 +561,7 @@
|
||||
{#if row.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray/50 hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
class="text-white/50 hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
onclick={() =>
|
||||
startDownload(col.nodeId, row.shardMetadata!)}
|
||||
title="Resume download on this node"
|
||||
@@ -581,14 +581,13 @@
|
||||
</svg>
|
||||
</button>
|
||||
{:else}
|
||||
<span class="text-exo-light-gray/40 text-[10px]"
|
||||
>paused</span
|
||||
<span class="text-white/50 text-[10px]">paused</span
|
||||
>
|
||||
{/if}
|
||||
{:else if row.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
class="text-white/50 hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
onclick={() =>
|
||||
startDownload(col.nodeId, row.shardMetadata!)}
|
||||
title="Start download on this node"
|
||||
@@ -608,8 +607,7 @@
|
||||
</svg>
|
||||
</button>
|
||||
{:else}
|
||||
<span class="text-exo-light-gray/50 text-sm">...</span
|
||||
>
|
||||
<span class="text-white/40 text-sm">...</span>
|
||||
{/if}
|
||||
</div>
|
||||
{:else if cell.kind === "failed"}
|
||||
@@ -631,7 +629,7 @@
|
||||
{#if row.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray/40 hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
class="text-white/50 hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
onclick={() =>
|
||||
startDownload(col.nodeId, row.shardMetadata!)}
|
||||
title="Retry download on this node"
|
||||
@@ -663,7 +661,7 @@
|
||||
{#if row.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100 cursor-pointer"
|
||||
class="text-white/50 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100 cursor-pointer"
|
||||
onclick={() =>
|
||||
startDownload(col.nodeId, row.shardMetadata!)}
|
||||
title="Download to this node"
|
||||
|
||||
@@ -41,7 +41,7 @@ let
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.7.dev20260220+13998a05"; in
|
||||
version = let v = "0.30.7.dev20260224+e862b122"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
@@ -49,8 +49,8 @@ let
|
||||
src = fetchFromGitHub {
|
||||
owner = "rltakashige";
|
||||
repo = "mlx-jaccl-fix-small-recv";
|
||||
rev = "13998a054715edcdc93618fb1496c79c7c25ff7c";
|
||||
hash = "sha256-fAqA3hFwNBx7FcoGnhQsIFpAIRbC2EerACm4Fvne0Cc=";
|
||||
rev = "e862b1223a2310d4cc8df1135aed42f5246bc50a";
|
||||
hash = "sha256-GosFIWxIB48Egb1MqJrR3xhsUsQeWdRk5rV93USY6wQ=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
|
||||
@@ -14,10 +14,12 @@ from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import PlaceInstance
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted, TaskStatusUpdated
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
@@ -456,3 +458,117 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
else:
|
||||
ip_part = coordinator.split(":")[0]
|
||||
assert len(ip_part.split(".")) == 4
|
||||
|
||||
|
||||
def _make_task(
|
||||
instance_id: InstanceId,
|
||||
status: TaskStatus = TaskStatus.Running,
|
||||
) -> TextGeneration:
|
||||
return TextGeneration(
|
||||
task_id=TaskId(),
|
||||
task_status=status,
|
||||
instance_id=instance_id,
|
||||
command_id=CommandId(),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=ModelId("test-model"),
|
||||
input=[InputMessage(role="user", content="hello")],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_running_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
task = _make_task(instance_id, TaskStatus.Running)
|
||||
tasks = {task.task_id: task}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – cancellation event should come before the deletion event
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], TaskStatusUpdated)
|
||||
assert events[0].task_id == task.task_id
|
||||
assert events[0].task_status == TaskStatus.Cancelled
|
||||
assert isinstance(events[1], InstanceDeleted)
|
||||
assert events[1].instance_id == instance_id
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_pending_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
task = _make_task(instance_id, TaskStatus.Pending)
|
||||
tasks = {task.task_id: task}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], TaskStatusUpdated)
|
||||
assert events[0].task_id == task.task_id
|
||||
assert events[0].task_status == TaskStatus.Cancelled
|
||||
assert isinstance(events[1], InstanceDeleted)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_ignores_completed_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
tasks = {
|
||||
t.task_id: t
|
||||
for t in [
|
||||
_make_task(instance_id, TaskStatus.Complete),
|
||||
_make_task(instance_id, TaskStatus.Failed),
|
||||
_make_task(instance_id, TaskStatus.TimedOut),
|
||||
_make_task(instance_id, TaskStatus.Cancelled),
|
||||
]
|
||||
}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – only the InstanceDeleted event, no cancellations
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_only_matching_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id_a = InstanceId()
|
||||
instance_id_b = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {
|
||||
instance_id_a: instance,
|
||||
instance_id_b: instance,
|
||||
}
|
||||
# only delete instance A, keep instance B
|
||||
target_instances: dict[InstanceId, Instance] = {instance_id_b: instance}
|
||||
|
||||
task_a = _make_task(instance_id_a, TaskStatus.Running)
|
||||
task_b = _make_task(instance_id_b, TaskStatus.Running)
|
||||
tasks = {task_a.task_id: task_a, task_b.task_id: task_b}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – only task_a should be cancelled
|
||||
cancel_events = [e for e in events if isinstance(e, TaskStatusUpdated)]
|
||||
delete_events = [e for e in events if isinstance(e, InstanceDeleted)]
|
||||
assert len(cancel_events) == 1
|
||||
assert cancel_events[0].task_id == task_a.task_id
|
||||
assert cancel_events[0].task_status == TaskStatus.Cancelled
|
||||
assert len(delete_events) == 1
|
||||
assert delete_events[0].instance_id == instance_id_a
|
||||
|
||||
@@ -852,6 +852,8 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(
|
||||
layer.self_attn.o_proj
|
||||
)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
else:
|
||||
assert isinstance(layer, Qwen3NextDecoderLayer)
|
||||
if hasattr(layer, "linear_attn"):
|
||||
|
||||
@@ -51,11 +51,33 @@ def _parse_json_calls(text: str) -> list[ToolCallItem] | None:
|
||||
return None
|
||||
|
||||
|
||||
def _try_parse_json(v: str) -> str | dict[str, Any] | list[Any]:
|
||||
stripped = v.strip()
|
||||
if (stripped.startswith("[") and stripped.endswith("]")) or (
|
||||
stripped.startswith("{") and stripped.endswith("}")
|
||||
):
|
||||
try:
|
||||
parsed: dict[str, Any] | list[Any] = json.loads(stripped) # pyright: ignore[reportAny]
|
||||
return parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
return v
|
||||
|
||||
|
||||
def _flatten(p: dict[str, Any]) -> dict[str, str]:
|
||||
return {
|
||||
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) # pyright: ignore[reportAny]
|
||||
for k, v in p.items() # pyright: ignore[reportAny]
|
||||
}
|
||||
result: dict[str, str] = {}
|
||||
for k, v in p.items(): # pyright: ignore[reportAny]
|
||||
if isinstance(v, dict):
|
||||
resolved: dict[str, Any] = {
|
||||
str(ik): _try_parse_json(str(iv)) if isinstance(iv, str) else iv # pyright: ignore[reportUnknownArgumentType]
|
||||
for ik, iv in v.items() # pyright: ignore[reportUnknownVariableType]
|
||||
}
|
||||
result[k] = json.dumps(resolved)
|
||||
elif isinstance(v, list):
|
||||
result[k] = json.dumps(v)
|
||||
else:
|
||||
result[k] = str(v) # pyright: ignore[reportAny]
|
||||
return result
|
||||
|
||||
|
||||
json_tool_parser = ToolParser(
|
||||
|
||||
@@ -8,7 +8,7 @@ from urllib.request import urlopen
|
||||
|
||||
h = sys.argv[1] if len(sys.argv) > 1 else sys.exit(f"USAGE: {sys.argv[0]} host")
|
||||
ts = subprocess.run(
|
||||
["/Applications/Tailscale.app/Contents/MacOS/Tailscale", "status"], check=True, text=True, capture_output=True
|
||||
["tailscale", "status"], check=True, text=True, capture_output=True
|
||||
).stdout.splitlines()
|
||||
ip = next(
|
||||
(sl[0] for line in ts if len(sl := line.split()) >= 2 if sl[1] == h), None
|
||||
|
||||
@@ -15,7 +15,7 @@ if not (args := sys.argv[1:]):
|
||||
kind = args[0] if args[0] in ("jaccl", "ring") else "both"
|
||||
hosts = args[1:] if kind != "both" else args
|
||||
ts = subprocess.run(
|
||||
["/Applications/Tailscale.app/Contents/MacOS/Tailscale", "status"], check=True, text=True, capture_output=True
|
||||
["tailscale", "status"], check=True, text=True, capture_output=True
|
||||
).stdout.splitlines()
|
||||
ip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}
|
||||
ips = [ip[h] for h in hosts]
|
||||
|
||||
10
uv.lock
generated
10
uv.lock
generated
@@ -378,7 +378,7 @@ dependencies = [
|
||||
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260220+13998a05", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260224+e862b122", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#e862b1223a2310d4cc8df1135aed42f5246bc50a" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1025,7 +1025,7 @@ dependencies = [
|
||||
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260220+13998a05", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260224+e862b122", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#e862b1223a2310d4cc8df1135aed42f5246bc50a" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1072,8 +1072,8 @@ cuda13 = [
|
||||
|
||||
[[package]]
|
||||
name = "mlx"
|
||||
version = "0.30.7.dev20260220+13998a05"
|
||||
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }
|
||||
version = "0.30.7.dev20260224+e862b122"
|
||||
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#e862b1223a2310d4cc8df1135aed42f5246bc50a" }
|
||||
resolution-markers = [
|
||||
"sys_platform == 'darwin'",
|
||||
]
|
||||
@@ -1108,7 +1108,7 @@ version = "0.30.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260220+13998a05", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#13998a054715edcdc93618fb1496c79c7c25ff7c" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", version = "0.30.7.dev20260224+e862b122", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#e862b1223a2310d4cc8df1135aed42f5246bc50a" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
|
||||
Reference in New Issue
Block a user