Compare commits


3 Commits

Evan · 16d2252b3f · yay · 2026-01-28 12:38:53 +00:00

Alex Cheema · f1a2d054ec · Update tagline to "Run frontier AI locally" (#1313) · 2026-01-28 12:38:14 +00:00

    - Update README tagline from "Run your own AI cluster at home with everyday devices" to "Run frontier AI locally"

Alex Cheema · b3c8f85fc8 · Update MLX to 0.30.4 (#1311) · 2026-01-28 04:30:21 -08:00

    ## Summary
    - Bump mlx from 0.30.3 to 0.30.4

    ## Test plan
    - [x] `uv lock` succeeds
    - [x] Type checking passes (`uv run basedpyright`)
    - [x] Run inference tests

    🤖 Generated with [Claude Code](https://claude.com/claude-code)

    ---------

    Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
8 changed files with 1247 additions and 1252 deletions

View File

@@ -5,7 +5,7 @@
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
</picture>
-exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
+exo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).
<p align="center">
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>

View File

@@ -18,9 +18,6 @@ enum NetworkSetupHelper {
set -euo pipefail
-# Wait for macOS to finish network setup after boot
-sleep 30
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
@@ -83,7 +80,7 @@ enum NetworkSetupHelper {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
alert.alertStyle = .informational
alert.addButton(withTitle: "Install")
alert.addButton(withTitle: "Not Now")

View File

@@ -17,8 +17,8 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx @ git+https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx==0.30.4; sys_platform == 'darwin'",
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
"mlx-lm",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",

View File

@@ -165,7 +165,6 @@ def mlx_distributed_init(
    jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]

-    # TODO: update once upstream fixes
    logger.info(
        f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
    )

View File

@@ -1,7 +1,6 @@
import multiprocessing as mp
import socket
import time
-import typing
import anyio
from fastapi import FastAPI
@@ -11,16 +10,14 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
    build_full_shard,
    exo_shard_downloader,
)
+from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
+from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
-from exo.shared.types.events import Event
+from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
from exo.shared.types.tasks import (
    ChatCompletion,
    ConnectToGroup,
@@ -36,7 +33,12 @@ from exo.shared.types.worker.instances import (
    MlxJacclInstance,
    MlxRingInstance,
)
-from exo.shared.types.worker.runners import RunnerId, ShardAssignments
+from exo.shared.types.worker.runners import (
+    RunnerFailed,
+    RunnerId,
+    RunnerShutdown,
+    ShardAssignments,
+)
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
@@ -46,8 +48,7 @@ from exo.worker.runner.bootstrap import entrypoint
class Tests(BaseModel):
    # list[hostname, ip addr]
    devs: list[list[str]]
-    model_id: str
-    kind: typing.Literal["init", "warmup", "inference"]
+    model_id: ModelId


mp.set_start_method("spawn", force=True)
@@ -56,26 +57,20 @@ logger_setup(None)
async def main():
    logger.info("starting cool server majig")
    await assert_downloads()

    cfg = Config()
-    cfg.bind = "0.0.0.0:52415"
+    cfg.bind = "0.0.0.0:8000"
    # nb: shared.logging needs updating if any of this changes
    cfg.accesslog = "-"
    cfg.errorlog = "-"
    cfg.logger_class = InterceptLogger

    app = FastAPI()
-    app.post("/ring")(ring_backend)
-    app.post("/jaccl")(jaccl_backend)
-    app.post("/tb_detection")(tb_detection)
    shutdown = anyio.Event()
+    app.post("/run_test")(run_test)
+    app.get("/tb_detection")(tb_detection)
+    app.get("/models")(list_models)

    await serve(
        app,  # type: ignore
        cfg,
        shutdown_trigger=lambda: shutdown.wait(),
    )
    await anyio.sleep_forever()
-    # gracefully shutdown the api
-    shutdown.set()
async def tb_detection():
@@ -87,29 +82,20 @@ async def tb_detection():
    return recv.collect()


async def assert_downloads():
    sd = exo_shard_downloader()
    # await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
-    await sd.ensure_shard(
-        await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
-    )
-    await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
-    await sd.ensure_shard(
-        await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
-    )
-    await sd.ensure_shard(
-        await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
-    )
-    await sd.ensure_shard(
-        await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
-    )
-    await sd.ensure_shard(
-        await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
-    )


+def list_models():
+    sent = set[str]()
+    for path in EXO_MODELS_DIR.rglob("model-*.safetensors"):
+        if "--" not in path.parent.name:
+            continue
+        name = path.parent.name.replace("--", "/")
+        if name in sent:
+            continue
+        sent.add(name)
+        yield ModelId(path.parent.name.replace("--", "/"))


-async def ring_backend(test: Tests):
-    iid = InstanceId(str(hash(str(test.devs))))
+async def run_test(test: Tests):
+    iid = InstanceId((str(test.devs)))
    weird_hn = socket.gethostname()
    for dev in test.devs:
        if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
@@ -117,10 +103,39 @@ async def ring_backend(test: Tests):
            break
    else:
        raise ValueError(f"{weird_hn} not in {test.devs}")
-    return await execute_test(test, ring_instance(test, iid, hn), hn)

+    async def run():
+        logger.info(f"testing {test.model_id}")
+        for instance in (
+            ring_instance(test, iid, hn),
+            jaccl_instance(test, iid),
+        ):
+            if instance is None:
+                continue
+            recv = await execute_test(test, instance, hn)
+            str_out = ""
+            with recv:
+                try:
+                    async for item in recv:
+                        if isinstance(item, ChunkGenerated):
+                            assert isinstance(item.chunk, TokenChunk)
+                            str_out += item.chunk.text
+                        if isinstance(item, RunnerStatusUpdated) and isinstance(
+                            item.runner_status, (RunnerFailed, RunnerShutdown)
+                        ):
+                            yield str_out + "\n"
+                        yield item.model_dump_json() + "\n"
+                except anyio.ClosedResourceError:
+                    pass
+
+    return StreamingResponse(run())


-def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
+def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance | None:
    hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
    world_size = len(test.devs)
    for i in range(world_size):
@@ -135,13 +150,17 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
    else:
        raise ValueError(f"{hn} not in {test.devs}")

-    card = MODEL_CARDS[test.model_id]
+    card = next(
+        (card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
+    )
+    if card is None:
+        return None

    instance = MlxRingInstance(
        instance_id=iid,
        ephemeral_port=52416,
        hosts_by_node={NodeId(hn): hbn},
        shard_assignments=ShardAssignments(
-            model_id=ModelId(test.model_id),
+            model_id=test.model_id,
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
                RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -163,68 +182,43 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
    return instance


-async def execute_test(test: Tests, instance: Instance, hn: str):
+async def execute_test(test: Tests, instance: Instance, hn: str) -> MpReceiver[Event]:
    world_size = len(test.devs)
-    iid = InstanceId(str(hash(str(test.devs))))
+    iid = InstanceId(str(test.devs))
    _handle, recv, send = new_runner(instance, hn)

    if world_size > 1:
        send.send(ConnectToGroup(instance_id=iid))
    send.send(LoadModel(instance_id=iid))
-    match test.kind:
-        case "init":
-            pass
-        case "warmup":
-            send.send(StartWarmup(instance_id=iid))
-        case "inference":
-            send.send(StartWarmup(instance_id=iid))
-            send.send(
-                ChatCompletion(
-                    task_params=ChatCompletionTaskParams(
-                        model=test.model_id,
-                        messages=[
-                            ChatCompletionMessage(
-                                role="system", content="You are a helpful assistant"
-                            ),
-                            ChatCompletionMessage(
-                                role="user", content="What is the capital of France?"
-                            ),
-                        ],
-                    ),
-                    command_id=CommandId("yo"),
-                    instance_id=iid,
-                )
-            )
+    send.send(StartWarmup(instance_id=iid))
+    send.send(
+        ChatCompletion(
+            task_params=ChatCompletionTaskParams(
+                model=test.model_id,
+                messages=[
+                    ChatCompletionMessage(
+                        role="system", content="You are a helpful assistant"
+                    ),
+                    ChatCompletionMessage(
+                        role="user", content="What is the capital of France?"
+                    ),
+                ],
+                max_tokens=50,
+            ),
+            command_id=CommandId("yo"),
+            instance_id=iid,
+        )
+    )
    send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))

-    async def map_recv():
-        with recv:
-            try:
-                async for item in recv:
-                    yield item.model_dump_json() + "\n"
-            except anyio.ClosedResourceError:
-                pass
-
-    ret = StreamingResponse(map_recv())
-    ret._pls_dont_gc = _handle  # type: ignore
-    return ret
+    return recv


-async def jaccl_backend(test: Tests):
-    iid = InstanceId(str(hash(str(test.devs))))
-    weird_hn = socket.gethostname()
-    for dev in test.devs:
-        if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
-            hn = dev[0]
-            break
-    else:
-        raise ValueError(f"{weird_hn} not in {test.devs}")
-    return await execute_test(test, jaccl_instance(test, iid), hn)


-def jaccl_instance(test: Tests, iid: InstanceId):
-    card = MODEL_CARDS[test.model_id]
+def jaccl_instance(test: Tests, iid: InstanceId) -> MlxJacclInstance | None:
+    card = next(
+        (card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
+    )
+    if card is None:
+        return None

    world_size = len(test.devs)
    return MlxJacclInstance(
@@ -235,18 +229,18 @@ def jaccl_instance(test: Tests, iid: InstanceId):
            NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
        },
        shard_assignments=ShardAssignments(
-            model_id=ModelId(test.model_id),
+            model_id=test.model_id,
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
-                RunnerId(test.devs[i][0]): TensorShardMetadata(
+                RunnerId(host[0]): TensorShardMetadata(
                    model_card=card,
                    device_rank=i,
                    world_size=world_size,
-                    start_layer=card.n_layers,
+                    start_layer=0,
                    end_layer=card.n_layers,
                    n_layers=card.n_layers,
                )
-                for i in range(world_size)
+                for i, host in enumerate(test.devs)
            },
        ),
    )
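Net effect of this file: the old per-backend POST endpoints (/ring, /jaccl) collapse into one POST /run_test that tries both a ring and a JACCL instance, streams generated token text plus JSON-encoded events, and a new GET /models reports which weights are already on disk. A minimal client, sketched with placeholder hostnames and IPs (the payload shape follows the `Tests` model above):

```python
# Sketch: call one node's POST /run_test directly and stream its output.
import json
from urllib.request import Request, urlopen

body = json.dumps({
    "devs": [["mac1", "100.64.0.1"], ["mac2", "100.64.0.2"]],  # [hostname, ip] pairs
    "model_id": "qwen3-30b",  # must match a model_id returned by GET /models
}).encode()

req = Request(
    "http://100.64.0.1:8000/run_test",
    data=body,
    method="POST",
    headers={"Content-Type": "application/json"},
)
with urlopen(req, timeout=300) as resp:
    for raw_line in resp:  # newline-delimited token text and event JSON
        print(raw_line.decode(errors="replace"), end="")
```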

tests/start_distributed_test.py (new executable file, 56 lines)
View File

@@ -0,0 +1,56 @@
#!/usr/bin/env python3
import itertools
import json
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from urllib.request import Request, urlopen

hosts = sys.argv[1:]
if not hosts:
    sys.exit(f"USAGE: {sys.argv[0]} [host1] [host2] ...")

ts = subprocess.run(
    ["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = {line.split()[1]: line.split()[0] for line in ts if len(line.split()) >= 2}
ips = [ip[h] for h in hosts]
devs = [[h, ip[h]] for h in hosts]


def get_models(a: str) -> set[str]:
    try:
        r = urlopen(f"http://{a}:8000/models", timeout=5)  # pyright: ignore[reportAny]
        return set(json.loads(r.read()))  # pyright: ignore[reportAny]
    except Exception:
        return set()


with ThreadPoolExecutor(len(ips)) as ex:
    models = set[str].intersection(*ex.map(get_models, ips))


def run(h: str, a: str, body: bytes):
    try:
        r = urlopen(  # pyright: ignore[reportAny]
            Request(
                f"http://{a}:8000/run_test",
                data=body,
                method="POST",
                headers={"Content-Type": "application/json"},
            ),
            timeout=300,
        )
        for line in r.read().decode(errors="replace").splitlines():  # pyright: ignore[reportAny]
            print(f"\n{h}@{a}: {line}", flush=True)
    except Exception:
        print(f"\n{h}@{a}: request to {h} failed", flush=True)
        raise


print(f"Starting test with {models}")
for model in models:
    body = json.dumps({"devs": devs, "model_id": model}).encode()
    with ThreadPoolExecutor(max_workers=len(hosts)) as ex:
        for _ in ex.map(run, hosts, ips, itertools.repeat(body)):
            pass
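This script replaces the deleted bash harness below: invocation follows the embedded USAGE string (e.g. `./tests/start_distributed_test.py host1 host2`), each host must appear in `tailscale status`, and the script runs every model that all nodes report via GET /models, POSTing the same /run_test body to each host in parallel and prefixing every streamed line with `host@ip`.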

View File

@@ -1,56 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail

query() {
    tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}

if [[ $# -lt 2 ]]; then
    echo "USAGE: $0 <test kind> [host1] [host2] ..."
    exit 1
fi

kind=$1
shift

test_kinds="ring jaccl"
if ! echo "$test_kinds" | grep -q "$kind"; then
    printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
    exit 1
fi

hostnames=("$@")
weaved=()
ips=()
for name in "${hostnames[@]}"; do
    ip=$(query "$name")
    ips+=("$ip")
    weaved+=("$name" "$ip")
done

devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
devs="[${devs_raw%, }]"

model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")

for model_id in "${model_ids[@]}"; do
    for i in "${!ips[@]}"; do
        {
            req="{
                \"model_id\": \"${model_id}\",
                \"devs\": ${devs},
                \"kind\": \"inference\"
            }"
            echo "req $req"
            curl -sN \
                -X POST "http://${ips[$i]}:52415/${kind}" \
                -H "Content-Type: application/json" -d "$req" \
                2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
        } &
    done
    wait
done

uv.lock (generated, 2181 lines changed)
View File

File diff suppressed because it is too large Load Diff