## Motivation

With the addition of the Responses API, we introduced `str | list[InputMessage]` as the type for `TextGenerationTaskParams.input` since the Responses API supports sending input as a plain string. But there was no reason to leak that flexibility past the API adapter boundary — it just meant every downstream consumer had to do `if isinstance(messages, str):` checks, adding complexity for no benefit.

## Changes

- Changed `TextGenerationTaskParams.input` from `str | list[InputMessage]` to `list[InputMessage]`
- Each API adapter (Chat Completions, Claude Messages, Responses) now normalizes to `list[InputMessage]` at the boundary
- Removed `isinstance(task_params.input, str)` branches in `utils_mlx.py` and `runner.py`
- Wrapped string inputs in `[InputMessage(role="user", content=...)]` in the warmup path and all test files

## Why It Works

The API adapters are the only place where we deal with raw user input formats. By normalizing there, all downstream code (worker, runner, MLX engine) can just assume `list[InputMessage]` and skip the type-checking branches. The type system (`basedpyright`) catches any missed call sites at compile time.

## Test Plan

### Automated Testing

- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `nix fmt` — applied
- `uv run pytest` — 174 passed, 1 skipped

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
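For illustration, the boundary normalization described above is roughly the following; the `normalize_input` helper name is made up for this sketch, but `InputMessage` is the real type and the wrapping matches what the warmup path and tests now do:

```python
from exo.shared.types.text_generation import InputMessage


def normalize_input(raw: str | list[InputMessage]) -> list[InputMessage]:
    # Hypothetical adapter-side helper: a bare string becomes a single user
    # message, so code past the adapters only ever sees list[InputMessage].
    if isinstance(raw, str):
        return [InputMessage(role="user", content=raw)]
    return raw
```

Downstream consumers (worker, runner, MLX engine) can then iterate the messages directly, with no `isinstance` branches.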
265 lines
8.0 KiB
Python
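# Standalone test harness: serves a small FastAPI app with /run_test, /kill,
# /tb_detection and /models endpoints for driving ring / jaccl generation tests on this node.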
import socket

from typing import Literal

import anyio
from fastapi import FastAPI
from fastapi.responses import Response, StreamingResponse
from hypercorn import Config
from hypercorn.asyncio import serve  # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from pydantic import BaseModel

from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
from exo.shared.types.tasks import (
    ConnectToGroup,
    LoadModel,
    Shutdown,
    StartWarmup,
    Task,
    TextGeneration,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.instances import (
    BoundInstance,
    Instance,
    InstanceId,
    MlxJacclInstance,
    MlxRingInstance,
)
from exo.shared.types.worker.runners import (
    RunnerFailed,
    RunnerId,
    RunnerShutdown,
    ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint


class Tests(BaseModel):
    # list[hostname, ip addr]
    devs: list[list[str]]
    ibv_devs: list[list[str | None]] | None
    model_id: ModelId
    kind: Literal["ring", "jaccl", "both"]


iid = InstanceId("im testing here")


async def main():
    logger.info("starting cool server majig")
    cfg = Config()
    cfg.bind = "0.0.0.0:52414"
    # nb: shared.logging needs updating if any of this changes
    cfg.accesslog = "-"
    cfg.errorlog = "-"
    ev = anyio.Event()
    app = FastAPI()
    app.post("/run_test")(run_test)
    app.post("/kill")(lambda: kill(ev))
    app.get("/tb_detection")(tb_detection)
    app.get("/models")(list_models)
    await serve(
        app,  # type: ignore
        cfg,
        shutdown_trigger=lambda: ev.wait(),
    )


def kill(ev: anyio.Event):
    ev.set()
    return Response(status_code=204)


async def tb_detection():
    send, recv = channel[GatheredInfo]()
    ig = InfoGatherer(send)
    with anyio.move_on_after(1):
        await ig._monitor_system_profiler_thunderbolt_data()  # pyright: ignore[reportPrivateUsage]
    with recv:
        return recv.collect()


# Yields the ids of models already downloaded locally (directories named "org--model").
def list_models():
    sent = set[str]()
    for path in EXO_MODELS_DIR.rglob("model-*.safetensors"):
        if "--" not in path.parent.name:
            continue
        name = path.parent.name.replace("--", "/")
        if name in sent:
            continue
        sent.add(name)
        yield ModelId(path.parent.name.replace("--", "/"))


async def run_test(test: Tests):
    # Figure out which entry in test.devs is this machine.
    weird_hn = socket.gethostname()
    for dev in test.devs:
        if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
            hn = dev[0]
            break
    else:
        raise ValueError(f"{weird_hn} not in {test.devs}")

    async def run():
        logger.info(f"testing {test.model_id}")

        instances: list[Instance] = []
        if test.kind in ["ring", "both"]:
            i = await ring_instance(test, hn)
            if i is None:
                yield "no model found"
                return
            instances.append(i)
        if test.kind in ["jaccl", "both"]:
            i = await jaccl_instance(test)
            if i is None:
                yield "no model found"
                return
            instances.append(i)

        for instance in instances:
            recv = await execute_test(test, instance, hn)

            str_out = ""

            for item in recv:
                if isinstance(item, ChunkGenerated):
                    assert isinstance(item.chunk, TokenChunk)
                    str_out += item.chunk.text

                if isinstance(item, RunnerStatusUpdated) and isinstance(
                    item.runner_status, (RunnerFailed, RunnerShutdown)
                ):
                    yield str_out + "\n"
                    yield item.model_dump_json() + "\n"

    return StreamingResponse(run())


# Builds an MlxRingInstance: this node binds locally and only needs the real
# addresses of its two ring neighbours; the other hosts keep a placeholder address.
async def ring_instance(test: Tests, hn: str) -> Instance | None:
    hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
    world_size = len(test.devs)
    for i in range(world_size):
        if test.devs[i][0] == hn:
            hn = test.devs[i][0]
            hbn[(i - 1) % world_size] = Host(ip=test.devs[i - 1][1], port=52417)
            hbn[(i + 1) % world_size] = Host(ip=test.devs[i + 1][1], port=52417)
            hbn[i] = Host(ip="0.0.0.0", port=52417)
            break
    else:
        raise ValueError(f"{hn} not in {test.devs}")

    card = await ModelCard.load(test.model_id)
    instance = MlxRingInstance(
        instance_id=iid,
        ephemeral_port=52417,
        hosts_by_node={NodeId(hn): hbn},
        shard_assignments=ShardAssignments(
            model_id=test.model_id,
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
                RunnerId(test.devs[i][0]): PipelineShardMetadata(
                    model_card=card,
                    device_rank=i,
                    world_size=world_size,
                    start_layer=(card.n_layers // world_size) * i,
                    end_layer=min(
                        card.n_layers, (card.n_layers // world_size) * (i + 1)
                    ),
                    n_layers=min(card.n_layers, (card.n_layers // world_size) * (i + 1))
                    - (card.n_layers // world_size) * i,
                )
                for i in range(world_size)
            },
        ),
    )

    return instance


# Queues the full task sequence for the runner, then hands control to the runner entrypoint.
async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
    world_size = len(test.devs)
    commands: list[Task] = [
        (LoadModel(instance_id=iid)),
        (StartWarmup(instance_id=iid)),
        (
            TextGeneration(
                task_params=TextGenerationTaskParams(
                    model=test.model_id,
                    instructions="You are a helpful assistant",
                    input=[
                        InputMessage(
                            role="user", content="What is the capital of France?"
                        )
                    ],
                ),
                command_id=CommandId("yo"),
                instance_id=iid,
            )
        ),
        (Shutdown(runner_id=RunnerId(hn), instance_id=iid)),
    ]
    if world_size > 1:
        commands.insert(0, ConnectToGroup(instance_id=iid))
    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
    )
    ev_send, _ev_recv = mp_channel[Event]()
    task_send, task_recv = mp_channel[Task]()

    for command in commands:
        task_send.send(command)

    entrypoint(
        bound_instance,
        ev_send,
        task_recv,
        logger,
    )

    # TODO(evan): return ev_recv.collect()
    return []


# Builds an MlxJacclInstance where every runner holds a full-depth tensor shard.
async def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
    card = await ModelCard.load(test.model_id)
    world_size = len(test.devs)
    assert test.ibv_devs

    return MlxJacclInstance(
        instance_id=iid,
        jaccl_devices=test.ibv_devs,
        # rank 0 is always coordinator
        jaccl_coordinators={
            NodeId(host[0]): test.devs[0][1] + ":52417" for host in test.devs
        },
        shard_assignments=ShardAssignments(
            model_id=test.model_id,
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
                RunnerId(host[0]): TensorShardMetadata(
                    model_card=card,
                    device_rank=i,
                    world_size=world_size,
                    start_layer=0,
                    end_layer=card.n_layers,
                    n_layers=card.n_layers,
                )
                for i, host in enumerate(test.devs)
            },
        ),
    )


if __name__ == "__main__":
    anyio.run(main)