mirror of https://github.com/exo-explore/exo.git
synced 2026-01-26 23:10:01 -05:00

Compare commits: rust-explo ... improve-di (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | cc09ba01e4 |  |
@@ -7,6 +7,7 @@ from exo.shared.types.tasks import Task
 from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
 from exo.shared.types.worker.runners import RunnerFailed
 from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
+from exo.worker.tests.patches import load_null_model

 logger: "loguru.Logger" = loguru.logger

@@ -16,6 +17,8 @@ def entrypoint(
     event_sender: MpSender[Event],
     task_receiver: MpReceiver[Task],
     _logger: "loguru.Logger",
+    *,
+    _load_null_models: bool = False,
 ) -> None:
     fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
     if fast_synch_override == "on" or (
@@ -29,6 +32,13 @@ def entrypoint(
     else:
         os.environ["MLX_METAL_FAST_SYNCH"] = "0"

+    p = None
+    if _load_null_models:
+        from unittest.mock import patch
+
+        p = patch("mlx_lm.utils.load_model", new=load_null_model)
+        p.start()
+
     global logger
     logger = _logger

@@ -52,6 +62,8 @@ def entrypoint(
             )
         )
     finally:
+        if p is not None:
+            p.stop()
         try:
             event_sender.close()
             task_receiver.close()
src/exo/worker/tests/patches.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+# type: ignore
+
+import importlib
+import json
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from exo.worker.engines.mlx import Model
+
+
+def load_null_model(path: Path, **_: object) -> "tuple[Model, dict[str, Any]]":
+    with open(path / "config.json", "r") as f:
+        cfg = json.load(f)
+    model, args = _get_classes(cfg)
+    model = model(args.from_dict(cfg))
+    return model, cfg
+
+
+def _get_classes(config: dict):
+    """
+    Retrieve the model and model args classes based on the configuration.
+
+    Args:
+        config (dict): The model configuration.
+
+    Returns:
+        A tuple containing the Model class and the ModelArgs class.
+    """
+    model_type = config["model_type"]
+    model_type = MODEL_REMAPPING.get(model_type, model_type)
+    try:
+        arch = importlib.import_module(f"mlx_lm.models.{model_type}")
+    except ImportError:
+        msg = f"Model type {model_type} not supported."
+        raise ValueError(msg) from None
+
+    return arch.Model, arch.ModelArgs
+
+
+MODEL_REMAPPING = {
+    "mistral": "llama",
+    "llava": "mistral3",
+    "phi-msft": "phixtral",
+    "falcon_mamba": "mamba",
+    "kimi_k2": "deepseek_v3",
+    "qwen2_5_vl": "qwen2_vl",
+    "minimax_m2": "minimax",
+    "iquestcoder": "llama",
+}
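For context, a minimal sketch of how this helper is wired in, mirroring the entrypoint change above; only the placeholder body is hypothetical:

```python
# Minimal sketch of the null-model patch lifecycle, as the entrypoint
# diff above applies it when _load_null_models=True.
from unittest.mock import patch

from exo.worker.tests.patches import load_null_model

p = patch("mlx_lm.utils.load_model", new=load_null_model)
p.start()
try:
    pass  # placeholder: the runner loads and serves the (null) model here
finally:
    p.stop()  # always undo the patch, as the new finally: block does
```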
@@ -1,7 +1,6 @@
 import multiprocessing as mp
 import socket
 import time
-import typing

 import anyio
 from fastapi import FastAPI
@@ -11,16 +10,12 @@ from hypercorn.asyncio import serve  # pyright: ignore[reportUnknownVariableType]
 from loguru import logger
 from pydantic import BaseModel

-from exo.download.impl_shard_downloader import (
-    build_full_shard,
-    exo_shard_downloader,
-)
 from exo.shared.logging import InterceptLogger, logger_setup
 from exo.shared.models.model_cards import MODEL_CARDS, ModelId
 from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
 from exo.shared.types.commands import CommandId
 from exo.shared.types.common import Host, NodeId
-from exo.shared.types.events import Event
+from exo.shared.types.events import Event, RunnerStatusUpdated
 from exo.shared.types.tasks import (
     ChatCompletion,
     ConnectToGroup,
@@ -36,18 +31,17 @@ from exo.shared.types.worker.instances import (
     MlxJacclInstance,
     MlxRingInstance,
 )
-from exo.shared.types.worker.runners import RunnerId, ShardAssignments
+from exo.shared.types.worker.runners import RunnerFailed, RunnerId, ShardAssignments
 from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
 from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
 from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
 from exo.worker.runner.bootstrap import entrypoint

+MODEL_CARDS = {"haha": MODEL_CARDS["qwen3-coder-480b-a35b-8bit"]}

 class Tests(BaseModel):
     # list[hostname, ip addr]
     devs: list[list[str]]
-    model_id: str
-    kind: typing.Literal["init", "warmup", "inference"]


 mp.set_start_method("spawn", force=True)
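With `model_id` and `kind` dropped from `Tests`, a request body now carries only the device list. A sketch of the slimmed-down payload, with made-up hostnames and addresses:

```python
# Hypothetical Tests payload after this change (hostnames/IPs invented).
# Models are no longer chosen per request; run_test iterates the
# MODEL_CARDS override above instead.
body = Tests(
    devs=[
        ["mac-studio-1", "100.64.0.1"],  # [hostname, ip addr]
        ["mac-studio-2", "100.64.0.2"],
    ]
)
print(body.model_dump_json())  # -> {"devs": [["mac-studio-1", ...], ...]}
```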
@@ -56,16 +50,14 @@ logger_setup(None)

 async def main():
     logger.info("starting cool server majig")
-    await assert_downloads()
     cfg = Config()
-    cfg.bind = "0.0.0.0:52415"
+    cfg.bind = "0.0.0.0:8000"
     # nb: shared.logging needs updating if any of this changes
     cfg.accesslog = "-"
     cfg.errorlog = "-"
     cfg.logger_class = InterceptLogger
     app = FastAPI()
-    app.post("/ring")(ring_backend)
-    app.post("/jaccl")(jaccl_backend)
+    app.post("/run_test")(run_test)
     app.post("/tb_detection")(tb_detection)
     shutdown = anyio.Event()
     await serve(
@@ -87,28 +79,7 @@ async def tb_detection():
     return recv.collect()


-async def assert_downloads():
-    sd = exo_shard_downloader()
-    # await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
-    await sd.ensure_shard(
-        await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
-    )
-    await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
-    await sd.ensure_shard(
-        await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
-    )
-    await sd.ensure_shard(
-        await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
-    )
-    await sd.ensure_shard(
-        await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
-    )
-    await sd.ensure_shard(
-        await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
-    )
-
-
-async def ring_backend(test: Tests):
+async def run_test(test: Tests):
     iid = InstanceId(str(hash(str(test.devs))))
     weird_hn = socket.gethostname()
     for dev in test.devs:
@@ -117,10 +88,30 @@ async def ring_backend(test: Tests):
             break
     else:
         raise ValueError(f"{weird_hn} not in {test.devs}")
-    return await execute_test(test, ring_instance(test, iid, hn), hn)
+
+    async def run():
+        for card in MODEL_CARDS.values():
+            for instance in (
+                ring_instance(test, card.model_id, iid, hn),
+                jaccl_instance(test, card.model_id, iid),
+            ):
+                recv = await execute_test(test, instance, hn)
+
+                with recv:
+                    try:
+                        async for item in recv:
+                            yield item.model_dump_json() + "\n"
+                            if isinstance(item, RunnerStatusUpdated) and isinstance(
+                                item.runner_status, RunnerFailed
+                            ):
+                                return
+                    except anyio.ClosedResourceError:
+                        pass
+
+    return StreamingResponse(run())


-def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
+def ring_instance(test: Tests, model_id: ModelId, iid: InstanceId, hn: str) -> Instance:
     hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
     world_size = len(test.devs)
     for i in range(world_size):
@@ -135,13 +126,13 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
     else:
         raise ValueError(f"{hn} not in {test.devs}")

-    card = MODEL_CARDS[test.model_id]
+    card = next(card for card in MODEL_CARDS.values() if card.model_id == model_id)
     instance = MlxRingInstance(
         instance_id=iid,
         ephemeral_port=52416,
         hosts_by_node={NodeId(hn): hbn},
         shard_assignments=ShardAssignments(
-            model_id=ModelId(test.model_id),
+            model_id=model_id,
             node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
             runner_to_shard={
                 RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -163,7 +154,7 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
     return instance


-async def execute_test(test: Tests, instance: Instance, hn: str):
+async def execute_test(test: Tests, instance: Instance, hn: str) -> MpReceiver[Event]:
     world_size = len(test.devs)
     iid = InstanceId(str(hash(str(test.devs))))
     _handle, recv, send = new_runner(instance, hn)
@@ -171,60 +162,33 @@ async def execute_test(test: Tests, instance: Instance, hn: str):
     send.send(ConnectToGroup(instance_id=iid))
     send.send(LoadModel(instance_id=iid))

-    match test.kind:
-        case "init":
-            pass
-        case "warmup":
-            send.send(StartWarmup(instance_id=iid))
-        case "inference":
-            send.send(StartWarmup(instance_id=iid))
-            send.send(
-                ChatCompletion(
-                    task_params=ChatCompletionTaskParams(
-                        model=test.model_id,
-                        messages=[
-                            ChatCompletionMessage(
-                                role="system", content="You are a helpful assistant"
-                            ),
-                            ChatCompletionMessage(
-                                role="user", content="What is the capital of France?"
-                            ),
-                        ],
-                    ),
-                    command_id=CommandId("yo"),
-                    instance_id=iid,
-                )
-            )
+    for card in MODEL_CARDS.values():
+        send.send(StartWarmup(instance_id=iid))
+        send.send(
+            ChatCompletion(
+                task_params=ChatCompletionTaskParams(
+                    model=card.model_id,
+                    messages=[
+                        ChatCompletionMessage(
+                            role="system", content="You are a helpful assistant"
+                        ),
+                        ChatCompletionMessage(
+                            role="user", content="What is the capital of France?"
+                        ),
+                    ],
+                ),
+                command_id=CommandId("yo"),
+                instance_id=iid,
+            )
+        )

     send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))

-    async def map_recv():
-        with recv:
-            try:
-                async for item in recv:
-                    yield item.model_dump_json() + "\n"
-            except anyio.ClosedResourceError:
-                pass
-
-    ret = StreamingResponse(map_recv())
-    ret._pls_dont_gc = _handle  # type: ignore
-    return ret
+    return recv


-async def jaccl_backend(test: Tests):
-    iid = InstanceId(str(hash(str(test.devs))))
-    weird_hn = socket.gethostname()
-    for dev in test.devs:
-        if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
-            hn = dev[0]
-            break
-    else:
-        raise ValueError(f"{weird_hn} not in {test.devs}")
-    return await execute_test(test, jaccl_instance(test, iid), hn)
-
-
-def jaccl_instance(test: Tests, iid: InstanceId):
-    card = MODEL_CARDS[test.model_id]
+def jaccl_instance(test: Tests, model_id: ModelId, iid: InstanceId):
+    card = next(card for card in MODEL_CARDS.values() if card.model_id == model_id)
     world_size = len(test.devs)

     return MlxJacclInstance(
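The return-type change above means execute_test now hands back the raw event channel and leaves draining to the HTTP layer. A minimal consumer, using only the channel operations this diff itself relies on:

```python
# Minimal consumer of the MpReceiver[Event] that execute_test now
# returns, modeled on the run() generator in run_test above.
import anyio


async def drain(recv: "MpReceiver[Event]") -> None:
    with recv:  # close the channel when done
        try:
            async for event in recv:  # events arrive as the runner emits them
                print(event.model_dump_json())
        except anyio.ClosedResourceError:
            pass  # sender side already shut down; nothing left to read
```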
@@ -235,7 +199,7 @@ def jaccl_instance(test: Tests, iid: InstanceId):
             NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
         },
         shard_assignments=ShardAssignments(
-            model_id=ModelId(test.model_id),
+            model_id=model_id,
             node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
             runner_to_shard={
                 RunnerId(test.devs[i][0]): TensorShardMetadata(
@@ -270,6 +234,7 @@ def new_runner(
             task_recv,
             logger,
         ),
+        kwargs={"_load_null_models": True},
     )
     runner_process._pls_dont_gc = (ev_send, task_recv)  # type: ignore
     runner_process.start()

@@ -6,19 +6,8 @@ query() {
 	tailscale status | awk -v find="$1" '$2 == find { print $1 }'
 }

-if [[ $# -lt 2 ]]; then
-	echo "USAGE: $0 <test kind> [host1] [host2] ..."
-	exit 1
-fi
-
-
-kind=$1
-shift
-
-test_kinds="ring jaccl"
-
-if ! echo "$test_kinds" | grep -q "$kind"; then
-	printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
+if [[ $# -lt 1 ]]; then
+	echo "USAGE: $0 [host1] [host2] ..."
 	exit 1
 fi

@@ -34,23 +23,12 @@ done
 devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
 devs="[${devs_raw%, }]"

-model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
-
-for model_id in "${model_ids[@]}"; do
-	for i in "${!ips[@]}"; do
-		{
-			req="{
-				\"model_id\": \"${model_id}\",
-				\"devs\": ${devs},
-				\"kind\": \"inference\"
-			}"
-			echo "req $req"
-			curl -sN \
-				-X POST "http://${ips[$i]}:52415/${kind}" \
-				-H "Content-Type: application/json" -d "$req" \
-				2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
-		} &
-	done
-	wait
-done
+for i in "${!ips[@]}"; do
+	{
+		curl -sN \
+			-X POST "http://${ips[$i]}:8000/run_test" \
+			-H "Content-Type: application/json" -d "{\"devs\": ${devs}}" \
+			2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
+	} &
+done

 wait
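For reference, a rough Python equivalent of one per-host curl call above, assuming the `requests` package is available; hostnames and addresses are placeholders:

```python
# Rough Python equivalent of the per-host curl invocation above.
# Hostnames/IPs are invented; the server is the FastAPI app on :8000.
import json

import requests

devs = [["host-a", "100.64.0.1"], ["host-b", "100.64.0.2"]]
resp = requests.post(
    "http://100.64.0.1:8000/run_test",
    json={"devs": devs},
    stream=True,  # the endpoint streams newline-delimited JSON events
    timeout=None,
)
for line in resp.iter_lines():
    if line:
        print(json.loads(line))  # one serialized Event per line
```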