Compare commits

..

11 Commits

Author SHA1 Message Date
Alex Cheema
b9c64f94d0 Merge remote-tracking branch 'origin/main' into releases/v1.0.67 2026-01-27 22:29:07 -08:00
Alex Cheema
4f24e33d30 Merge branch 'main' into releases/v1.0.65 2026-01-26 14:01:15 -08:00
Alex Cheema
a9ee2204ef Merge remote-tracking branch 'origin/main' into releases/v1.0.65 2026-01-26 11:59:41 -08:00
Alex Cheema
054b296a51 Merge remote-tracking branch 'origin/main' into releases/v1.0.65 2026-01-26 09:59:18 -08:00
Alex Cheema
281aaeb013 better message for configuring local network macOS app 2026-01-24 06:32:15 -08:00
Alex Cheema
10fdc439a5 Merge branch 'main' into releases/v1.0.65 2026-01-24 05:36:15 -08:00
Alex Cheema
78a8c06d57 Add 30s delay to wait for macOS network setup
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 15:29:22 -08:00
Alex Cheema
4c0c6dcae9 Merge branch 'main' into releases/v1.0.65 2026-01-23 20:48:51 +00:00
Alex Cheema
d885600a4c Merge branch 'main' into releases/v1.0.64 2026-01-23 20:40:26 +00:00
Alex Cheema
55b67e2be2 Merge branch 'main' into releases/v1.0.64 2026-01-23 01:40:20 +00:00
Ryuichi Leo Takashige
30cfad9b68 Use custom fork 2026-01-22 22:19:35 +00:00
9 changed files with 1303 additions and 1346 deletions

View File

@@ -5,7 +5,7 @@
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
</picture>
exo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
<p align="center">
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>

View File

@@ -18,6 +18,9 @@ enum NetworkSetupHelper {
set -euo pipefail
# Wait for macOS to finish network setup after boot
sleep 30
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
@@ -80,7 +83,7 @@ enum NetworkSetupHelper {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
alert.alertStyle = .informational
alert.addButton(withTitle: "Install")
alert.addButton(withTitle: "Not Now")

View File

@@ -17,8 +17,8 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.4; sys_platform == 'darwin'",
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
"mlx @ git+https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",

View File

@@ -165,6 +165,7 @@ def mlx_distributed_init(
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)

View File

@@ -1,5 +1,7 @@
import multiprocessing as mp
import socket
from typing import Literal
import time
import typing
import anyio
from fastapi import FastAPI
@@ -9,14 +11,16 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from loguru import logger
from pydantic import BaseModel
from exo.shared.constants import EXO_MODELS_DIR
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
from exo.shared.types.events import Event
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -32,14 +36,9 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
ShardAssignments,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import channel, mp_channel
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint
@@ -47,32 +46,36 @@ from exo.worker.runner.bootstrap import entrypoint
class Tests(BaseModel):
# list[hostname, ip addr]
devs: list[list[str]]
rdma_devs: list[list[str | None]] | None
model_id: ModelId
kind: Literal["ring", "rdma", "both"]
model_id: str
kind: typing.Literal["init", "warmup", "inference"]
mp.set_start_method("spawn", force=True)
logger_setup(None)
iid = InstanceId("im testing here")
async def main():
logger.info("starting cool server majig")
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:8000"
cfg.bind = "0.0.0.0:52415"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = "-"
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
app = FastAPI()
app.post("/run_test")(run_test)
app.get("/tb_detection")(tb_detection)
app.get("/models")(list_models)
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
await serve(
app, # type: ignore
cfg,
shutdown_trigger=lambda: shutdown.wait(),
)
await anyio.sleep_forever()
# gracefully shutdown the api
shutdown.set()
async def tb_detection():
@@ -84,19 +87,29 @@ async def tb_detection():
return recv.collect()
def list_models():
sent = set[str]()
for path in EXO_MODELS_DIR.rglob("model-*.safetensors"):
if "--" not in path.parent.name:
continue
name = path.parent.name.replace("--", "/")
if name in sent:
continue
sent.add(name)
yield ModelId(path.parent.name.replace("--", "/"))
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
async def run_test(test: Tests):
async def ring_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
@@ -104,44 +117,10 @@ async def run_test(test: Tests):
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
async def run():
logger.info(f"testing {test.model_id}")
instances: list[Instance] = []
if test.kind in ["ring", "both"]:
i = ring_instance(test, hn)
if i is None:
yield "no model found"
return
instances.append(i)
if test.kind in ["rdma", "both"]:
i = jaccl_instance(test)
if i is None:
yield "no model found"
return
instances.append(i)
for instance in instances:
recv = await execute_test(test, instance, hn)
str_out = ""
for item in recv:
if isinstance(item, ChunkGenerated):
assert isinstance(item.chunk, TokenChunk)
str_out += item.chunk.text
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, (RunnerFailed, RunnerShutdown)
):
yield str_out + "\n"
yield item.model_dump_json() + "\n"
return StreamingResponse(run())
return await execute_test(test, ring_instance(test, iid, hn), hn)
def ring_instance(test: Tests, hn: str) -> Instance | None:
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
@@ -156,17 +135,13 @@ def ring_instance(test: Tests, hn: str) -> Instance | None:
else:
raise ValueError(f"{hn} not in {test.devs}")
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
card = MODEL_CARDS[test.model_id]
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
hosts_by_node={NodeId(hn): hbn},
shard_assignments=ShardAssignments(
model_id=test.model_id,
model_id=ModelId(test.model_id),
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -188,84 +163,119 @@ def ring_instance(test: Tests, hn: str) -> Instance | None:
return instance
async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
async def execute_test(test: Tests, instance: Instance, hn: str):
world_size = len(test.devs)
commands: list[Task] = [
(LoadModel(instance_id=iid)),
(StartWarmup(instance_id=iid)),
(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
max_tokens=50,
),
command_id=CommandId("yo"),
instance_id=iid,
)
),
(Shutdown(runner_id=RunnerId(hn), instance_id=iid)),
]
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance, hn)
if world_size > 1:
commands.insert(0,ConnectToGroup(instance_id=iid))
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
match test.kind:
case "init":
pass
case "warmup":
send.send(StartWarmup(instance_id=iid))
case "inference":
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
)
send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
async def map_recv():
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
except anyio.ClosedResourceError:
pass
ret = StreamingResponse(map_recv())
ret._pls_dont_gc = _handle # type: ignore
return ret
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
world_size = len(test.devs)
return MlxJacclInstance(
instance_id=iid,
jaccl_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
# rank 0 is always coordinator
jaccl_coordinators={
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=card.n_layers,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i in range(world_size)
},
),
)
def new_runner(
instance: Instance,
hn: str,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
for command in commands:
task_send.send(command)
entrypoint(
bound_instance,
ev_send,
task_recv,
logger,
)
return ev_recv.collect()
def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
world_size = len(test.devs)
assert test.rdma_devs
return MlxJacclInstance(
instance_id=iid,
jaccl_devices=test.rdma_devs,
# rank 0 is always coordinator
jaccl_coordinators={
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=test.model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(host[0]): TensorShardMetadata(
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=0,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i, host in enumerate(test.devs)
},
runner_process = mp.Process(
target=entrypoint,
args=(
bound_instance,
ev_send,
task_recv,
logger,
),
)
runner_process._pls_dont_gc = (ev_send, task_recv) # type: ignore
runner_process.start()
time.sleep(0.1)
return (runner_process, ev_recv, task_send)
if __name__ == "__main__":
anyio.run(main)

View File

@@ -1,27 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
[ $# -eq 0 ] && { echo "Usage: $0 host1 [host2 ...]"; exit 1; }
[ -z "$(git status --porcelain)" ] || { echo "Uncommitted changes"; exit 1; }
commit=$(git rev-parse HEAD)
git fetch -q
git branch -r --contains "$commit" | grep -q origin || { echo "Not pushed to origin"; exit 1; }
echo "Deploying $commit to $# hosts..."
for host; do
ssh -o BatchMode=yes -o ConnectTimeout=10 -o ServerAliveInterval=30 "$host@$host" "
cd exo &&
git fetch -q &&
git checkout -q $commit &&
exec uv run tests/headless_runner.py
" &
done
for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:8000/models" >/dev/null 2>&1; do sleep 1; done
done
uv run tests/start_distributed_test.py "$@"

View File

@@ -1,81 +0,0 @@
#!/usr/bin/env python3
import itertools
import json
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Any, cast
from urllib.request import Request, urlopen
if not (args := sys.argv[1:]):
sys.exit(
f"USAGE: {sys.argv[0]} <kind> [host1] [host2] ...\nkind is optional, and should be rdma or ring"
)
kind = args[0] if args[0] in ("rdma", "ring") else "both"
hosts = args[1:] if kind != "both" else args
ts = subprocess.run(
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}
ips = [ip[h] for h in hosts]
devs = [[h, ip[h]] for h in hosts]
n = len(hosts)
def get_tb(a: str) -> list[dict[str, Any]]:
with urlopen(f"http://{a}:8000/tb_detection", timeout=5) as r:# pyright: ignore[reportAny]
return json.loads(r.read()) # pyright: ignore[reportAny]
def get_models(a: str) -> set[str]:
with urlopen(f"http://{a}:8000/models", timeout=5) as r: # pyright: ignore[reportAny]
return set(json.loads(r.read())) # pyright: ignore[reportAny]
def run(h: str, a: str, body: bytes) -> None:
with urlopen(
Request(
f"http://{a}:8000/run_test",
data=body,
method="POST",
headers={"Content-Type": "application/json"},
),
timeout=300,
) as r: # pyright: ignore[reportAny]
for line in r.read().decode(errors="replace").splitlines(): # pyright: ignore[reportAny]
print(f"\n{h}@{a}: {line}", flush=True)
with ThreadPoolExecutor(n) as exctr:
if kind in ("rdma", "both"):
payloads = list(exctr.map(get_tb, ips))
u2e = {
ident["domainUuid"]: (i, ident["rdmaInterface"])
for i, p in enumerate(payloads)
for d in p
for ident in cast(
list[dict[str, str]],
d.get("MacThunderboltIdentifiers", {}).get("idents", []), # pyright: ignore[reportAny]
)
}
edges = {
(u2e[s][0], u2e[t][0]): u2e[t][1]
for p in payloads
for d in p
for c in d.get("MacThunderboltConnections", {}).get("conns", []) # pyright: ignore[reportAny]
if (s := c["sourceUuid"]) in u2e and (t := c["sinkUuid"]) in u2e # pyright: ignore[reportAny]
}
rdma_devs = [[edges.get((i, j)) for j in range(n)] for i in range(n)]
else:
rdma_devs = None
models = set[str].intersection(*exctr.map(get_models, ips))
print(f"Starting test with {models}")
for model in models:
body = json.dumps(
{"devs": devs, "model_id": model, "rdma_devs": rdma_devs, "kind": kind}
).encode()
list(exctr.map(run, hosts, ips, itertools.repeat(body)))

56
tests/start_distributed_test.sh Executable file
View File

@@ -0,0 +1,56 @@
#!/usr/bin/env bash
set -euo pipefail
query() {
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}
if [[ $# -lt 2 ]]; then
echo "USAGE: $0 <test kind> [host1] [host2] ..."
exit 1
fi
kind=$1
shift
test_kinds="ring jaccl"
if ! echo "$test_kinds" | grep -q "$kind"; then
printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
exit 1
fi
hostnames=("$@")
weaved=()
ips=()
for name in "${hostnames[@]}"; do
ip=$(query "$name")
ips+=("$ip")
weaved+=("$name" "$ip")
done
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
devs="[${devs_raw%, }]"
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
done

2181
uv.lock generated
View File

File diff suppressed because it is too large Load Diff