Files
exo/tests/headless_runner.py
ciaranbor 15f1b61f4c Rework model storage directory management (for external storage) (#1765)
## Motivation

Replace confusing EXO_MODELS_DIR/EXO_MODELS_PATH with clearer
multi-directory support, enabling automatic download spillover across
volumes.

## Changes

- EXO_MODELS_DIRS: colon-separated writable dirs (default always
prepended, first with enough space wins)
- EXO_MODELS_READ_ONLY_DIRS: colon-separated read-only dirs (protected
from deletion)
- select_download_dir(): picks writable dir by free space
- resolve_existing_model(): unified lookup across all dirs
- is_read_only_model_dir(): path-based read-only detection instead of
hardcoded flag
- Updated coordinator, worker, model cards, tests

## Why It Works

Default dir always included so zero-config behavior is unchanged. Disk
space checked at download time for automatic spillover. Read-only status
derived from path, not hardcoded.

## Test Plan

### Manual Testing

- No env vars set → identical behavior
- EXO_MODELS_DIRS=/Volumes/SSD/models → downloads to external storage
- EXO_MODELS_READ_ONLY_DIRS=/mnt/nfs → models found, deletion blocked

### Automated Testing

- 4 new tests in test_xdg_paths.py (prepend, default-only, overlap,
empty read-only)
- Existing tests updated to patch new constants
2026-03-26 17:46:46 +00:00

265 lines
8.0 KiB
Python

import socket
from typing import Literal
import anyio
from fastapi import FastAPI
from fastapi.responses import Response, StreamingResponse
from hypercorn import Config
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from pydantic import BaseModel
from exo.shared.constants import EXO_DEFAULT_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
from exo.shared.types.tasks import (
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TextGeneration,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
InstanceId,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint
class Tests(BaseModel):
# list[hostname, ip addr]
devs: list[list[str]]
ibv_devs: list[list[str | None]] | None
model_id: ModelId
kind: Literal["ring", "jaccl", "both"]
iid = InstanceId("im testing here")
async def main():
logger.info("starting cool server majig")
cfg = Config()
cfg.bind = "0.0.0.0:52414"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = "-"
cfg.errorlog = "-"
ev = anyio.Event()
app = FastAPI()
app.post("/run_test")(run_test)
app.post("/kill")(lambda: kill(ev))
app.get("/tb_detection")(tb_detection)
app.get("/models")(list_models)
await serve(
app, # type: ignore
cfg,
shutdown_trigger=lambda: ev.wait(),
)
def kill(ev: anyio.Event):
ev.set()
return Response(status_code=204)
async def tb_detection():
send, recv = channel[GatheredInfo]()
ig = InfoGatherer(send)
with anyio.move_on_after(1):
await ig._monitor_system_profiler_thunderbolt_data() # pyright: ignore[reportPrivateUsage]
with recv:
return recv.collect()
def list_models():
sent = set[str]()
for path in EXO_DEFAULT_MODELS_DIR.rglob("model-*.safetensors"):
if "--" not in path.parent.name:
continue
name = path.parent.name.replace("--", "/")
if name in sent:
continue
sent.add(name)
yield ModelId(path.parent.name.replace("--", "/"))
async def run_test(test: Tests):
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
async def run():
logger.info(f"testing {test.model_id}")
instances: list[Instance] = []
if test.kind in ["ring", "both"]:
i = await ring_instance(test, hn)
if i is None:
yield "no model found"
return
instances.append(i)
if test.kind in ["jaccl", "both"]:
i = await jaccl_instance(test)
if i is None:
yield "no model found"
return
instances.append(i)
for instance in instances:
recv = await execute_test(test, instance, hn)
str_out = ""
for item in recv:
if isinstance(item, ChunkGenerated):
assert isinstance(item.chunk, TokenChunk)
str_out += item.chunk.text
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, (RunnerFailed, RunnerShutdown)
):
yield str_out + "\n"
yield item.model_dump_json() + "\n"
return StreamingResponse(run())
async def ring_instance(test: Tests, hn: str) -> Instance | None:
hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
if test.devs[i][0] == hn:
hn = test.devs[i][0]
hbn[(i - 1) % world_size] = Host(ip=test.devs[i - 1][1], port=52417)
hbn[(i + 1) % world_size] = Host(ip=test.devs[i + 1][1], port=52417)
hbn[i] = Host(ip="0.0.0.0", port=52417)
break
else:
raise ValueError(f"{hn} not in {test.devs}")
card = await ModelCard.load(test.model_id)
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52417,
hosts_by_node={NodeId(hn): hbn},
shard_assignments=ShardAssignments(
model_id=test.model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=(card.n_layers // world_size) * i,
end_layer=min(
card.n_layers, (card.n_layers // world_size) * (i + 1)
),
n_layers=min(card.n_layers, (card.n_layers // world_size) * (i + 1))
- (card.n_layers // world_size) * i,
)
for i in range(world_size)
},
),
)
return instance
async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
world_size = len(test.devs)
commands: list[Task] = [
(LoadModel(instance_id=iid)),
(StartWarmup(instance_id=iid)),
(
TextGeneration(
task_params=TextGenerationTaskParams(
model=test.model_id,
instructions="You are a helpful assistant",
input=[
InputMessage(
role="user", content="What is the capital of France?"
)
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
),
(Shutdown(runner_id=RunnerId(hn), instance_id=iid)),
]
if world_size > 1:
commands.insert(0, ConnectToGroup(instance_id=iid))
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, _ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
for command in commands:
task_send.send(command)
entrypoint(
bound_instance,
ev_send,
task_recv,
logger,
)
# TODO(evan): return ev_recv.collect()
return []
async def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
card = await ModelCard.load(test.model_id)
world_size = len(test.devs)
assert test.ibv_devs
return MlxJacclInstance(
instance_id=iid,
jaccl_devices=test.ibv_devs,
# rank 0 is always coordinator
jaccl_coordinators={
NodeId(host[0]): test.devs[0][1] + ":52417" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=test.model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(host[0]): TensorShardMetadata(
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=0,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i, host in enumerate(test.devs)
},
),
)
if __name__ == "__main__":
anyio.run(main)