Compare commits

...

5 Commits

Author SHA1 Message Date
Sami Khan
87098c69c1 Merge main into fix/1025 2026-02-03 11:56:00 +05:00
Evan Quiney
f400b4d7c5 fix InstanceViewModel.swift (#1359)
wasn't caught when we merged the API changes
2026-02-02 18:43:27 +00:00
Evan Quiney
d97bca88e6 improve distributed testing (#1300)
Our distributed test now does a full query cycle for every model loaded
onto the relevant machine. This will help find bugs early, as it already
has found one with Qwen3 Next! I didn't write down what the error was
though. Gooooooood luck with that!

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-02 18:25:39 +00:00
Sami Khan
851cf65bfa removed animation 2026-01-06 07:07:03 +05:00
Sami Khan
49c66f139f fix: resolve issue #1025 2026-01-02 10:00:52 +05:00
8 changed files with 1252 additions and 1233 deletions

View File

@@ -6,8 +6,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/sparkle-project/Sparkle.git",
"state" : {
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
"version" : "2.9.0-beta.1"
"revision" : "afea2cda87819c960114f26e26f369a1a0945b17",
"version" : "2.9.0-beta.2"
}
}
],

View File

@@ -216,7 +216,7 @@ struct InstanceTaskViewModel: Identifiable, Equatable {
let promptPreview: String?
let errorMessage: String?
let subtitle: String?
let parameters: ChatCompletionTaskParameters?
let parameters: TextGenerationTaskParameters?
var title: String {
switch kind {

View File

@@ -1003,16 +1003,5 @@
</div>
<style>
@keyframes pulse-slow {
0%,
100% {
opacity: 0.8;
}
50% {
opacity: 1;
}
}
.animate-pulse-slow {
animation: pulse-slow 1.5s ease-in-out infinite;
}
/* Styles removed - animations were causing GPU overhead */
</style>

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,25 +1,20 @@
import multiprocessing as mp
import socket
import time
import typing
from typing import Literal
import anyio
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.responses import Response, StreamingResponse
from hypercorn import Config
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
from exo.shared.types.tasks import (
ConnectToGroup,
LoadModel,
@@ -36,9 +31,14 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.channels import channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint
@@ -46,36 +46,37 @@ from exo.worker.runner.bootstrap import entrypoint
class Tests(BaseModel):
# list[hostname, ip addr]
devs: list[list[str]]
model_id: str
kind: typing.Literal["init", "warmup", "inference"]
ibv_devs: list[list[str | None]] | None
model_id: ModelId
kind: Literal["ring", "jaccl", "both"]
mp.set_start_method("spawn", force=True)
logger_setup(None)
iid = InstanceId("im testing here")
async def main():
logger.info("starting cool server majig")
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52415"
cfg.bind = "0.0.0.0:52414"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = "-"
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
ev = anyio.Event()
app = FastAPI()
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
app.post("/run_test")(run_test)
app.post("/kill")(lambda: kill(ev))
app.get("/tb_detection")(tb_detection)
app.get("/models")(list_models)
await serve(
app, # type: ignore
cfg,
shutdown_trigger=lambda: shutdown.wait(),
shutdown_trigger=lambda: ev.wait(),
)
await anyio.sleep_forever()
# gracefully shutdown the api
shutdown.set()
def kill(ev: anyio.Event):
ev.set()
return Response(status_code=204)
async def tb_detection():
@@ -87,29 +88,19 @@ async def tb_detection():
return recv.collect()
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
def list_models():
sent = set[str]()
for path in EXO_MODELS_DIR.rglob("model-*.safetensors"):
if "--" not in path.parent.name:
continue
name = path.parent.name.replace("--", "/")
if name in sent:
continue
sent.add(name)
yield ModelId(path.parent.name.replace("--", "/"))
async def ring_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
async def run_test(test: Tests):
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
@@ -117,31 +108,67 @@ async def ring_backend(test: Tests):
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, ring_instance(test, iid, hn), hn)
async def run():
logger.info(f"testing {test.model_id}")
instances: list[Instance] = []
if test.kind in ["ring", "both"]:
i = ring_instance(test, hn)
if i is None:
yield "no model found"
return
instances.append(i)
if test.kind in ["rdma", "both"]:
i = jaccl_instance(test)
if i is None:
yield "no model found"
return
instances.append(i)
for instance in instances:
recv = await execute_test(test, instance, hn)
str_out = ""
for item in recv:
if isinstance(item, ChunkGenerated):
assert isinstance(item.chunk, TokenChunk)
str_out += item.chunk.text
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, (RunnerFailed, RunnerShutdown)
):
yield str_out + "\n"
yield item.model_dump_json() + "\n"
return StreamingResponse(run())
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
def ring_instance(test: Tests, hn: str) -> Instance | None:
hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
if test.devs[i][0] == hn:
hn = test.devs[i][0]
if i - 1 >= 0:
hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
if i + 1 < len(test.devs):
hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
hbn[i] = Host(ip="0.0.0.0", port=52416)
break
hbn[(i - 1) % world_size] = Host(ip=test.devs[i - 1][1], port=52417)
hbn[(i + 1) % world_size] = Host(ip=test.devs[i + 1][1], port=52417)
hbn[i] = Host(ip="0.0.0.0", port=52417)
break
else:
raise ValueError(f"{hn} not in {test.devs}")
card = MODEL_CARDS[test.model_id]
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
ephemeral_port=52417,
hosts_by_node={NodeId(hn): hbn},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=test.model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -163,113 +190,79 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
return instance
async def execute_test(test: Tests, instance: Instance, hn: str):
async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
world_size = len(test.devs)
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance, hn)
if world_size > 1:
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
match test.kind:
case "init":
pass
case "warmup":
send.send(StartWarmup(instance_id=iid))
case "inference":
send.send(StartWarmup(instance_id=iid))
send.send(
TextGeneration(
task_params=TextGenerationTaskParams(
model=test.model_id,
instructions="You are a helpful assistant",
input="What is the capital of France?",
),
command_id=CommandId("yo"),
instance_id=iid,
)
commands: list[Task] = [
(LoadModel(instance_id=iid)),
(StartWarmup(instance_id=iid)),
(
TextGeneration(
task_params=TextGenerationTaskParams(
model=test.model_id,
instructions="You are a helpful assistant",
input="What is the capital of France?",
),
command_id=CommandId("yo"),
instance_id=iid,
)
),
(Shutdown(runner_id=RunnerId(hn), instance_id=iid)),
]
if world_size > 1:
commands.insert(0, ConnectToGroup(instance_id=iid))
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, _ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
for command in commands:
task_send.send(command)
async def map_recv():
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
except anyio.ClosedResourceError:
pass
entrypoint(
bound_instance,
ev_send,
task_recv,
logger,
)
ret = StreamingResponse(map_recv())
ret._pls_dont_gc = _handle # type: ignore
return ret
# TODO(evan): return ev_recv.collect()
return []
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
world_size = len(test.devs)
assert test.ibv_devs
return MlxJacclInstance(
instance_id=iid,
jaccl_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
jaccl_devices=test.ibv_devs,
# rank 0 is always coordinator
jaccl_coordinators={
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
NodeId(host[0]): test.devs[0][1] + ":52417" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=test.model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
RunnerId(host[0]): TensorShardMetadata(
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=card.n_layers,
start_layer=0,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i in range(world_size)
for i, host in enumerate(test.devs)
},
),
)
def new_runner(
instance: Instance,
hn: str,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
runner_process = mp.Process(
target=entrypoint,
args=(
bound_instance,
ev_send,
task_recv,
logger,
),
)
runner_process._pls_dont_gc = (ev_send, task_recv) # type: ignore
runner_process.start()
time.sleep(0.1)
return (runner_process, ev_recv, task_send)
if __name__ == "__main__":
anyio.run(main)

54
tests/run_exo_on.sh Executable file
View File

@@ -0,0 +1,54 @@
#!/usr/bin/env bash
set -euo pipefail
[ $# -lt 1 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}
[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
echo "Deploying $commit to $# hosts..."
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
done
wait
jobs -pr | xargs -r kill 2>/dev/null || true
}
trap 'cleanup' EXIT INT TERM
colours=($'\e[31m' $'\e[32m' $'\e[33m' $'\e[34m')
reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
{
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
} <<'EOF'
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q "$1"
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
EOF
done
for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52415/models"; do sleep 1; done
done
wait

85
tests/start_distributed_test.py Executable file
View File

@@ -0,0 +1,85 @@
#!/usr/bin/env python3
import itertools
import json
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Any, cast
from urllib.request import Request, urlopen
if not (args := sys.argv[1:]):
sys.exit(
f"USAGE: {sys.argv[0]} <kind> [host1] [host2] ...\nkind is optional, and should be jaccl or ring"
)
kind = args[0] if args[0] in ("jaccl", "ring") else "both"
hosts = args[1:] if kind != "both" else args
ts = subprocess.run(
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}
ips = [ip[h] for h in hosts]
devs = [[h, ip[h]] for h in hosts]
n = len(hosts)
def get_tb(a: str) -> list[dict[str, Any]]:
with urlopen(f"http://{a}:52414/tb_detection", timeout=5) as r: # pyright: ignore[reportAny]
return json.loads(r.read()) # pyright: ignore[reportAny]
def get_models(a: str) -> set[str]:
with urlopen(f"http://{a}:52414/models", timeout=5) as r: # pyright: ignore[reportAny]
return set(json.loads(r.read())) # pyright: ignore[reportAny]
def run(h: str, a: str, body: bytes) -> None:
with urlopen(
Request(
f"http://{a}:52414/run_test",
data=body,
method="POST",
headers={"Content-Type": "application/json"},
),
timeout=300,
) as r: # pyright: ignore[reportAny]
for line in r.read().decode(errors="replace").splitlines(): # pyright: ignore[reportAny]
print(f"\n{h}@{a}: {line}", flush=True)
with ThreadPoolExecutor(n) as exctr:
if kind in ("jaccl", "both"):
payloads = list(exctr.map(get_tb, ips))
u2e = {
ident["domainUuid"]: (i, ident["rdmaInterface"])
for i, p in enumerate(payloads)
for d in p
for ident in cast(
list[dict[str, str]],
d.get("MacThunderboltIdentifiers", {}).get("idents", []), # pyright: ignore[reportAny]
)
}
edges = {
(u2e[s][0], u2e[t][0]): u2e[t][1]
for p in payloads
for d in p
for c in d.get("MacThunderboltConnections", {}).get("conns", []) # pyright: ignore[reportAny]
if (s := c["sourceUuid"]) in u2e and (t := c["sinkUuid"]) in u2e # pyright: ignore[reportAny]
}
ibv_devs = [[edges.get((i, j)) for j in range(n)] for i in range(n)]
else:
ibv_devs = None
models = set[str].intersection(*exctr.map(get_models, ips))
print("\n")
print("=" * 70)
print(f"Starting test with {models}")
print("=" * 70)
print("\n")
for model in models:
body = json.dumps(
{"devs": devs, "model_id": model, "ibv_devs": ibv_devs, "kind": kind}
).encode()
list(exctr.map(run, hosts, ips, itertools.repeat(body)))

View File

@@ -1,54 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
query() {
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}
if [[ $# -lt 2 ]]; then
echo "USAGE: $0 <test kind> [host1] [host2] ..."
exit 1
fi
kind=$1
shift
test_kinds="ring jaccl"
if ! echo "$test_kinds" | grep -q "$kind"; then
printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
exit 1
fi
hostnames=("$@")
weaved=()
ips=()
for name in "${hostnames[@]}"; do
ip=$(query "$name")
ips+=("$ip")
weaved+=("$name" "$ip")
done
devs_raw=$(printf '["%s", "%s"], ' "${weaved[@]}")
devs="[${devs_raw%, }]"
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
done