exo/tests/headless_runner.py
Commit acb97127bf by Alex Cheema: Normalize TextGenerationTaskParams.input to list[InputMessage] (#1360)
## Motivation

With the addition of the Responses API, we introduced `str |
list[InputMessage]` as the type for `TextGenerationTaskParams.input`
since the Responses API supports sending input as a plain string. But
there was no reason to leak that flexibility past the API adapter
boundary — it just meant every downstream consumer had to do `if
isinstance(messages, str):` checks, adding complexity for no benefit.
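
For illustration, this is the sort of branch every downstream consumer needed (a sketch of the removed pattern, not the exact code from `utils_mlx.py` or `runner.py`):

```python
# Old pattern: each consumer of task_params.input re-did this normalization.
if isinstance(task_params.input, str):
    messages = [InputMessage(role="user", content=task_params.input)]
else:
    messages = task_params.input
```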

## Changes

- Changed `TextGenerationTaskParams.input` from `str |
list[InputMessage]` to `list[InputMessage]`
- Each API adapter (Chat Completions, Claude Messages, Responses) now
normalizes to `list[InputMessage]` at the boundary (see the sketch after
this list)
- Removed `isinstance(task_params.input, str)` branches in
`utils_mlx.py` and `runner.py`
- Wrapped string inputs in `[InputMessage(role="user", content=...)]` in
the warmup path and all test files
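
A minimal sketch of the boundary normalization the adapters now perform, using a hypothetical `normalize_input` helper (the name is illustrative, not the actual adapter code):

```python
def normalize_input(raw: str | list[InputMessage]) -> list[InputMessage]:
    """Wrap a bare string as a single user message; pass lists through unchanged."""
    if isinstance(raw, str):
        return [InputMessage(role="user", content=raw)]
    return raw
```

With this in place, a plain-string Responses request and a structured Chat Completions request produce the same `list[InputMessage]` shape before the task params are built.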

## Why It Works

The API adapters are the only place where we deal with raw user input
formats. By normalizing there, all downstream code (worker, runner, MLX
engine) can just assume `list[InputMessage]` and skip the type-checking
branches. The type system (`basedpyright`) catches any missed call sites
at type-check time.
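
The narrowed field, roughly (a hedged sketch; other fields and defaults are elided or assumed):

```python
class TextGenerationTaskParams(BaseModel):
    model: ModelId
    instructions: str | None = None  # assumption: optional system prompt
    input: list[InputMessage]        # was: str | list[InputMessage]
```

Any call site still passing a bare `str` now fails the `basedpyright` check instead of surfacing at runtime.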

## Test Plan

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `nix fmt` — applied
- `uv run pytest` — 174 passed, 1 skipped

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>


import socket
from typing import Literal

import anyio
from fastapi import FastAPI
from fastapi.responses import Response, StreamingResponse
from hypercorn import Config
from hypercorn.asyncio import serve  # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from pydantic import BaseModel

from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
from exo.shared.types.tasks import (
    ConnectToGroup,
    LoadModel,
    Shutdown,
    StartWarmup,
    Task,
    TextGeneration,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.instances import (
    BoundInstance,
    Instance,
    InstanceId,
    MlxJacclInstance,
    MlxRingInstance,
)
from exo.shared.types.worker.runners import (
    RunnerFailed,
    RunnerId,
    RunnerShutdown,
    ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint


class Tests(BaseModel):
    # pairs of [hostname, ip address]
    devs: list[list[str]]
    ibv_devs: list[list[str | None]] | None
    model_id: ModelId
    kind: Literal["ring", "jaccl", "both"]


iid = InstanceId("im testing here")


async def main():
    logger.info("starting cool server majig")
    cfg = Config()
    cfg.bind = "0.0.0.0:52414"
    # nb: shared.logging needs updating if any of this changes
    cfg.accesslog = "-"
    cfg.errorlog = "-"
    ev = anyio.Event()
    app = FastAPI()
    app.post("/run_test")(run_test)
    app.post("/kill")(lambda: kill(ev))
    app.get("/tb_detection")(tb_detection)
    app.get("/models")(list_models)
    await serve(
        app,  # type: ignore
        cfg,
        shutdown_trigger=lambda: ev.wait(),
    )


def kill(ev: anyio.Event):
    ev.set()
    return Response(status_code=204)


async def tb_detection():
    send, recv = channel[GatheredInfo]()
    ig = InfoGatherer(send)
    with anyio.move_on_after(1):
        await ig._monitor_system_profiler_thunderbolt_data()  # pyright: ignore[reportPrivateUsage]
    with recv:
        return recv.collect()


def list_models():
    sent = set[str]()
    for path in EXO_MODELS_DIR.rglob("model-*.safetensors"):
        if "--" not in path.parent.name:
            continue
        name = path.parent.name.replace("--", "/")
        if name in sent:
            continue
        sent.add(name)
        yield ModelId(name)


async def run_test(test: Tests):
    weird_hn = socket.gethostname()
    for dev in test.devs:
        if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
            hn = dev[0]
            break
    else:
        raise ValueError(f"{weird_hn} not in {test.devs}")

    async def run():
        logger.info(f"testing {test.model_id}")
        instances: list[Instance] = []
        if test.kind in ["ring", "both"]:
            i = await ring_instance(test, hn)
            if i is None:
                yield "no model found"
                return
            instances.append(i)
        if test.kind in ["jaccl", "both"]:
            i = await jaccl_instance(test)
            if i is None:
                yield "no model found"
                return
            instances.append(i)
        for instance in instances:
            recv = await execute_test(test, instance, hn)
            str_out = ""
            for item in recv:
                if isinstance(item, ChunkGenerated):
                    assert isinstance(item.chunk, TokenChunk)
                    str_out += item.chunk.text
                if isinstance(item, RunnerStatusUpdated) and isinstance(
                    item.runner_status, (RunnerFailed, RunnerShutdown)
                ):
                    yield str_out + "\n"
                    yield item.model_dump_json() + "\n"

    return StreamingResponse(run())


async def ring_instance(test: Tests, hn: str) -> Instance | None:
    # placeholder (TEST-NET-2) hosts; the real ring neighbours are filled in below
    hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
    world_size = len(test.devs)
    for i in range(world_size):
        if test.devs[i][0] == hn:
            hn = test.devs[i][0]
            hbn[(i - 1) % world_size] = Host(ip=test.devs[i - 1][1], port=52417)
            hbn[(i + 1) % world_size] = Host(ip=test.devs[i + 1][1], port=52417)
            hbn[i] = Host(ip="0.0.0.0", port=52417)
            break
    else:
        raise ValueError(f"{hn} not in {test.devs}")
    card = await ModelCard.load(test.model_id)
    instance = MlxRingInstance(
        instance_id=iid,
        ephemeral_port=52417,
        hosts_by_node={NodeId(hn): hbn},
        shard_assignments=ShardAssignments(
            model_id=test.model_id,
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
                RunnerId(test.devs[i][0]): PipelineShardMetadata(
                    model_card=card,
                    device_rank=i,
                    world_size=world_size,
                    start_layer=(card.n_layers // world_size) * i,
                    end_layer=min(
                        card.n_layers, (card.n_layers // world_size) * (i + 1)
                    ),
                    n_layers=min(card.n_layers, (card.n_layers // world_size) * (i + 1))
                    - (card.n_layers // world_size) * i,
                )
                for i in range(world_size)
            },
        ),
    )
    return instance


async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
    world_size = len(test.devs)
    commands: list[Task] = [
        LoadModel(instance_id=iid),
        StartWarmup(instance_id=iid),
        TextGeneration(
            task_params=TextGenerationTaskParams(
                model=test.model_id,
                instructions="You are a helpful assistant",
                input=[
                    InputMessage(
                        role="user", content="What is the capital of France?"
                    )
                ],
            ),
            command_id=CommandId("yo"),
            instance_id=iid,
        ),
        Shutdown(runner_id=RunnerId(hn), instance_id=iid),
    ]
    if world_size > 1:
        commands.insert(0, ConnectToGroup(instance_id=iid))
    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
    )
    ev_send, _ev_recv = mp_channel[Event]()
    task_send, task_recv = mp_channel[Task]()
    for command in commands:
        task_send.send(command)
    entrypoint(
        bound_instance,
        ev_send,
        task_recv,
        logger,
    )
    # TODO(evan): return ev_recv.collect()
    return []


async def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
    card = await ModelCard.load(test.model_id)
    world_size = len(test.devs)
    assert test.ibv_devs
    return MlxJacclInstance(
        instance_id=iid,
        jaccl_devices=test.ibv_devs,
        # rank 0 is always coordinator
        jaccl_coordinators={
            NodeId(host[0]): test.devs[0][1] + ":52417" for host in test.devs
        },
        shard_assignments=ShardAssignments(
            model_id=test.model_id,
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
                RunnerId(host[0]): TensorShardMetadata(
                    model_card=card,
                    device_rank=i,
                    world_size=world_size,
                    start_layer=0,
                    end_layer=card.n_layers,
                    n_layers=card.n_layers,
                )
                for i, host in enumerate(test.devs)
            },
        ),
    )


if __name__ == "__main__":
    anyio.run(main)