mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-27 15:33:26 -05:00
Compare commits
2 Commits
leo/add-lo
...
improve-di
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
93ce089ac4 | ||
|
|
991d278119 |
2
justfile
2
justfile
@@ -1,7 +1,7 @@
|
||||
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
|
||||
|
||||
fmt:
|
||||
nix fmt
|
||||
treefmt || nix fmt
|
||||
|
||||
lint:
|
||||
uv run ruff check --fix
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import multiprocessing as mp
|
||||
import socket
|
||||
import time
|
||||
import typing
|
||||
|
||||
import anyio
|
||||
from fastapi import FastAPI
|
||||
@@ -11,16 +10,14 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.download.impl_shard_downloader import (
|
||||
build_full_shard,
|
||||
exo_shard_downloader,
|
||||
)
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.logging import InterceptLogger, logger_setup
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
|
||||
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.commands import CommandId
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.events import Event
|
||||
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
@@ -36,7 +33,12 @@ from exo.shared.types.worker.instances import (
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerFailed,
|
||||
RunnerId,
|
||||
RunnerShutdown,
|
||||
ShardAssignments,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
@@ -46,8 +48,7 @@ from exo.worker.runner.bootstrap import entrypoint
|
||||
class Tests(BaseModel):
|
||||
# list[hostname, ip addr]
|
||||
devs: list[list[str]]
|
||||
model_id: str
|
||||
kind: typing.Literal["init", "warmup", "inference"]
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
@@ -56,26 +57,20 @@ logger_setup(None)
|
||||
|
||||
async def main():
|
||||
logger.info("starting cool server majig")
|
||||
await assert_downloads()
|
||||
cfg = Config()
|
||||
cfg.bind = "0.0.0.0:52415"
|
||||
cfg.bind = "0.0.0.0:8000"
|
||||
# nb: shared.logging needs updating if any of this changes
|
||||
cfg.accesslog = "-"
|
||||
cfg.errorlog = "-"
|
||||
cfg.logger_class = InterceptLogger
|
||||
app = FastAPI()
|
||||
app.post("/ring")(ring_backend)
|
||||
app.post("/jaccl")(jaccl_backend)
|
||||
app.post("/tb_detection")(tb_detection)
|
||||
shutdown = anyio.Event()
|
||||
app.post("/run_test")(run_test)
|
||||
app.get("/tb_detection")(tb_detection)
|
||||
app.get("/models")(list_models)
|
||||
await serve(
|
||||
app, # type: ignore
|
||||
cfg,
|
||||
shutdown_trigger=lambda: shutdown.wait(),
|
||||
)
|
||||
await anyio.sleep_forever()
|
||||
# gracefully shutdown the api
|
||||
shutdown.set()
|
||||
|
||||
|
||||
async def tb_detection():
|
||||
@@ -87,29 +82,20 @@ async def tb_detection():
|
||||
return recv.collect()
|
||||
|
||||
|
||||
async def assert_downloads():
|
||||
sd = exo_shard_downloader()
|
||||
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
|
||||
)
|
||||
def list_models():
|
||||
sent = set[str]()
|
||||
for path in EXO_MODELS_DIR.rglob("model-*.safetensors"):
|
||||
if "--" not in path.parent.name:
|
||||
continue
|
||||
name = path.parent.name.replace("--", "/")
|
||||
if name in sent:
|
||||
continue
|
||||
sent.add(name)
|
||||
yield ModelId(path.parent.name.replace("--", "/"))
|
||||
|
||||
|
||||
async def ring_backend(test: Tests):
|
||||
iid = InstanceId(str(hash(str(test.devs))))
|
||||
async def run_test(test: Tests):
|
||||
iid = InstanceId((str(test.devs)))
|
||||
weird_hn = socket.gethostname()
|
||||
for dev in test.devs:
|
||||
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
|
||||
@@ -117,10 +103,39 @@ async def ring_backend(test: Tests):
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"{weird_hn} not in {test.devs}")
|
||||
return await execute_test(test, ring_instance(test, iid, hn), hn)
|
||||
|
||||
async def run():
|
||||
logger.info(f"testing {test.model_id}")
|
||||
|
||||
for instance in (
|
||||
ring_instance(test, iid, hn),
|
||||
jaccl_instance(test, iid),
|
||||
):
|
||||
if instance is None:
|
||||
continue
|
||||
recv = await execute_test(test, instance, hn)
|
||||
|
||||
str_out = ""
|
||||
|
||||
with recv:
|
||||
try:
|
||||
async for item in recv:
|
||||
if isinstance(item, ChunkGenerated):
|
||||
assert isinstance(item.chunk, TokenChunk)
|
||||
str_out += item.chunk.text
|
||||
|
||||
if isinstance(item, RunnerStatusUpdated) and isinstance(
|
||||
item.runner_status, (RunnerFailed, RunnerShutdown)
|
||||
):
|
||||
yield str_out + "\n"
|
||||
yield item.model_dump_json() + "\n"
|
||||
except anyio.ClosedResourceError:
|
||||
pass
|
||||
|
||||
return StreamingResponse(run())
|
||||
|
||||
|
||||
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
|
||||
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance | None:
|
||||
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
|
||||
world_size = len(test.devs)
|
||||
for i in range(world_size):
|
||||
@@ -135,13 +150,17 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
|
||||
else:
|
||||
raise ValueError(f"{hn} not in {test.devs}")
|
||||
|
||||
card = MODEL_CARDS[test.model_id]
|
||||
card = next(
|
||||
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
|
||||
)
|
||||
if card is None:
|
||||
return None
|
||||
instance = MlxRingInstance(
|
||||
instance_id=iid,
|
||||
ephemeral_port=52416,
|
||||
hosts_by_node={NodeId(hn): hbn},
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId(test.model_id),
|
||||
model_id=test.model_id,
|
||||
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
|
||||
runner_to_shard={
|
||||
RunnerId(test.devs[i][0]): PipelineShardMetadata(
|
||||
@@ -163,68 +182,43 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
|
||||
return instance
|
||||
|
||||
|
||||
async def execute_test(test: Tests, instance: Instance, hn: str):
|
||||
async def execute_test(test: Tests, instance: Instance, hn: str) -> MpReceiver[Event]:
|
||||
world_size = len(test.devs)
|
||||
iid = InstanceId(str(hash(str(test.devs))))
|
||||
iid = InstanceId(str(test.devs))
|
||||
_handle, recv, send = new_runner(instance, hn)
|
||||
if world_size > 1:
|
||||
send.send(ConnectToGroup(instance_id=iid))
|
||||
send.send(LoadModel(instance_id=iid))
|
||||
|
||||
match test.kind:
|
||||
case "init":
|
||||
pass
|
||||
case "warmup":
|
||||
send.send(StartWarmup(instance_id=iid))
|
||||
case "inference":
|
||||
send.send(StartWarmup(instance_id=iid))
|
||||
send.send(
|
||||
ChatCompletion(
|
||||
task_params=ChatCompletionTaskParams(
|
||||
model=test.model_id,
|
||||
messages=[
|
||||
ChatCompletionMessage(
|
||||
role="system", content="You are a helpful assistant"
|
||||
),
|
||||
ChatCompletionMessage(
|
||||
role="user", content="What is the capital of France?"
|
||||
),
|
||||
],
|
||||
send.send(StartWarmup(instance_id=iid))
|
||||
send.send(
|
||||
ChatCompletion(
|
||||
task_params=ChatCompletionTaskParams(
|
||||
model=test.model_id,
|
||||
messages=[
|
||||
ChatCompletionMessage(
|
||||
role="system", content="You are a helpful assistant"
|
||||
),
|
||||
command_id=CommandId("yo"),
|
||||
instance_id=iid,
|
||||
)
|
||||
)
|
||||
|
||||
ChatCompletionMessage(
|
||||
role="user", content="What is the capital of France?"
|
||||
),
|
||||
],
|
||||
max_tokens=50,
|
||||
),
|
||||
command_id=CommandId("yo"),
|
||||
instance_id=iid,
|
||||
)
|
||||
)
|
||||
send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
|
||||
|
||||
async def map_recv():
|
||||
with recv:
|
||||
try:
|
||||
async for item in recv:
|
||||
yield item.model_dump_json() + "\n"
|
||||
except anyio.ClosedResourceError:
|
||||
pass
|
||||
|
||||
ret = StreamingResponse(map_recv())
|
||||
ret._pls_dont_gc = _handle # type: ignore
|
||||
return ret
|
||||
return recv
|
||||
|
||||
|
||||
async def jaccl_backend(test: Tests):
|
||||
iid = InstanceId(str(hash(str(test.devs))))
|
||||
weird_hn = socket.gethostname()
|
||||
for dev in test.devs:
|
||||
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
|
||||
hn = dev[0]
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"{weird_hn} not in {test.devs}")
|
||||
return await execute_test(test, jaccl_instance(test, iid), hn)
|
||||
|
||||
|
||||
def jaccl_instance(test: Tests, iid: InstanceId):
|
||||
card = MODEL_CARDS[test.model_id]
|
||||
def jaccl_instance(test: Tests, iid: InstanceId) -> MlxJacclInstance | None:
|
||||
card = next(
|
||||
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
|
||||
)
|
||||
if card is None:
|
||||
return None
|
||||
world_size = len(test.devs)
|
||||
|
||||
return MlxJacclInstance(
|
||||
@@ -235,18 +229,18 @@ def jaccl_instance(test: Tests, iid: InstanceId):
|
||||
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
|
||||
},
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId(test.model_id),
|
||||
model_id=test.model_id,
|
||||
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
|
||||
runner_to_shard={
|
||||
RunnerId(test.devs[i][0]): TensorShardMetadata(
|
||||
RunnerId(host[0]): TensorShardMetadata(
|
||||
model_card=card,
|
||||
device_rank=i,
|
||||
world_size=world_size,
|
||||
start_layer=card.n_layers,
|
||||
start_layer=0,
|
||||
end_layer=card.n_layers,
|
||||
n_layers=card.n_layers,
|
||||
)
|
||||
for i in range(world_size)
|
||||
for i, host in enumerate(test.devs)
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
56
tests/start_distributed_test.py
Executable file
56
tests/start_distributed_test.py
Executable file
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
import itertools
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
hosts = sys.argv[1:]
|
||||
if not hosts:
|
||||
sys.exit(f"USAGE: {sys.argv[0]} [host1] [host2] ...")
|
||||
|
||||
ts = subprocess.run(
|
||||
["tailscale", "status"], check=True, text=True, capture_output=True
|
||||
).stdout.splitlines()
|
||||
ip = {line.split()[1]: line.split()[0] for line in ts if len(line.split()) >= 2}
|
||||
ips = [ip[h] for h in hosts]
|
||||
devs = [[h, ip[h]] for h in hosts]
|
||||
|
||||
|
||||
def get_models(a: str) -> set[str]:
|
||||
try:
|
||||
r = urlopen(f"http://{a}:8000/models", timeout=5) # pyright: ignore[reportAny]
|
||||
return set(json.loads(r.read())) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
|
||||
with ThreadPoolExecutor(len(ips)) as ex:
|
||||
models = set[str].intersection(*ex.map(get_models, ips))
|
||||
|
||||
|
||||
def run(h: str, a: str, body: bytes):
|
||||
try:
|
||||
r = urlopen( # pyright: ignore[reportAny]
|
||||
Request(
|
||||
f"http://{a}:8000/run_test",
|
||||
data=body,
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
),
|
||||
timeout=300,
|
||||
)
|
||||
for line in r.read().decode(errors="replace").splitlines(): # pyright: ignore[reportAny]
|
||||
print(f"\n{h}@{a}: {line}", flush=True)
|
||||
except Exception:
|
||||
print(f"\n{h}@{a}: request to {h} failed", flush=True)
|
||||
raise
|
||||
|
||||
|
||||
print(f"Starting test with {models}")
|
||||
for model in models:
|
||||
body = json.dumps({"devs": devs, "model_id": model}).encode()
|
||||
with ThreadPoolExecutor(max_workers=len(hosts)) as ex:
|
||||
for _ in ex.map(run, hosts, ips, itertools.repeat(body)):
|
||||
pass
|
||||
@@ -1,56 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
query() {
|
||||
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
|
||||
}
|
||||
|
||||
if [[ $# -lt 2 ]]; then
|
||||
echo "USAGE: $0 <test kind> [host1] [host2] ..."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
kind=$1
|
||||
shift
|
||||
|
||||
test_kinds="ring jaccl"
|
||||
|
||||
if ! echo "$test_kinds" | grep -q "$kind"; then
|
||||
printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
hostnames=("$@")
|
||||
weaved=()
|
||||
ips=()
|
||||
for name in "${hostnames[@]}"; do
|
||||
ip=$(query "$name")
|
||||
ips+=("$ip")
|
||||
weaved+=("$name" "$ip")
|
||||
done
|
||||
|
||||
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
|
||||
devs="[${devs_raw%, }]"
|
||||
|
||||
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
|
||||
|
||||
for model_id in "${model_ids[@]}"; do
|
||||
for i in "${!ips[@]}"; do
|
||||
{
|
||||
req="{
|
||||
\"model_id\": \"${model_id}\",
|
||||
\"devs\": ${devs},
|
||||
\"kind\": \"inference\"
|
||||
}"
|
||||
echo "req $req"
|
||||
curl -sN \
|
||||
-X POST "http://${ips[$i]}:52415/${kind}" \
|
||||
-H "Content-Type: application/json" -d "$req" \
|
||||
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
|
||||
} &
|
||||
done
|
||||
wait
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user