mirror of
https://github.com/exo-explore/exo.git
synced 2026-04-17 20:40:35 -04:00
## Motivation Replace confusing EXO_MODELS_DIR/EXO_MODELS_PATH with clearer multi-directory support, enabling automatic download spillover across volumes. ## Changes - EXO_MODELS_DIRS: colon-separated writable dirs (default always prepended, first with enough space wins) - EXO_MODELS_READ_ONLY_DIRS: colon-separated read-only dirs (protected from deletion) - select_download_dir(): picks writable dir by free space - resolve_existing_model(): unified lookup across all dirs - is_read_only_model_dir(): path-based read-only detection instead of hardcoded flag - Updated coordinator, worker, model cards, tests ## Why It Works Default dir always included so zero-config behavior is unchanged. Disk space checked at download time for automatic spillover. Read-only status derived from path, not hardcoded. ## Test Plan ### Manual Testing - No env vars set → identical behavior - EXO_MODELS_DIRS=/Volumes/SSD/models → downloads to external storage - EXO_MODELS_READ_ONLY_DIRS=/mnt/nfs → models found, deletion blocked ### Automated Testing - 4 new tests in test_xdg_paths.py (prepend, default-only, overlap, empty read-only) - Existing tests updated to patch new constants
265 lines
8.0 KiB
Python
265 lines
8.0 KiB
Python
import socket
|
|
from typing import Literal
|
|
|
|
import anyio
|
|
from fastapi import FastAPI
|
|
from fastapi.responses import Response, StreamingResponse
|
|
from hypercorn import Config
|
|
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
|
|
from loguru import logger
|
|
from pydantic import BaseModel
|
|
|
|
from exo.shared.constants import EXO_DEFAULT_MODELS_DIR
|
|
from exo.shared.models.model_cards import ModelCard, ModelId
|
|
from exo.shared.types.chunks import TokenChunk
|
|
from exo.shared.types.commands import CommandId
|
|
from exo.shared.types.common import Host, NodeId
|
|
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
|
|
from exo.shared.types.tasks import (
|
|
ConnectToGroup,
|
|
LoadModel,
|
|
Shutdown,
|
|
StartWarmup,
|
|
Task,
|
|
TextGeneration,
|
|
)
|
|
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
|
from exo.shared.types.worker.instances import (
|
|
BoundInstance,
|
|
Instance,
|
|
InstanceId,
|
|
MlxJacclInstance,
|
|
MlxRingInstance,
|
|
)
|
|
from exo.shared.types.worker.runners import (
|
|
RunnerFailed,
|
|
RunnerId,
|
|
RunnerShutdown,
|
|
ShardAssignments,
|
|
)
|
|
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
|
from exo.utils.channels import channel, mp_channel
|
|
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
|
from exo.worker.runner.bootstrap import entrypoint
|
|
|
|
|
|
class Tests(BaseModel):
    """Request body accepted by the ``/run_test`` endpoint."""

    # list[hostname, ip addr]; list position is used as the device rank.
    devs: list[list[str]]
    # Passed through unchanged as ``MlxJacclInstance.jaccl_devices``. Only
    # jaccl runs need it, so it defaults to None instead of being required
    # (explicit ``null`` in the request body is still accepted).
    ibv_devs: list[list[str | None]] | None = None
    model_id: ModelId
    # Which instance flavour(s) to build and run.
    kind: Literal["ring", "jaccl", "both"]
|
|
|
|
|
|
# Fixed instance id reused by every instance and task this script builds
# (ring_instance, jaccl_instance and execute_test all reference it).
iid = InstanceId("im testing here")
|
|
|
|
|
|
async def main():
    """Run the HTTP server on 0.0.0.0:52414 until ``POST /kill`` is received."""
    logger.info("starting cool server majig")

    server_config = Config()
    server_config.bind = "0.0.0.0:52414"
    # nb: shared.logging needs updating if any of this changes
    server_config.accesslog = "-"
    server_config.errorlog = "-"

    # Set by the /kill handler; awaited by hypercorn as its shutdown trigger.
    stop_requested = anyio.Event()

    app = FastAPI()
    routes = (
        (app.post, "/run_test", run_test),
        (app.post, "/kill", lambda: kill(stop_requested)),
        (app.get, "/tb_detection", tb_detection),
        (app.get, "/models", list_models),
    )
    for register, route, handler in routes:
        register(route)(handler)

    await serve(
        app,  # type: ignore
        server_config,
        shutdown_trigger=stop_requested.wait,
    )
|
|
|
|
|
|
def kill(ev: anyio.Event):
    """Set *ev* (the server's shutdown trigger in ``main``) and reply 204 No Content."""
    no_content = 204
    ev.set()
    return Response(status_code=no_content)
|
|
|
|
|
|
async def tb_detection():
    """Run the Thunderbolt system-profiler probe for at most one second.

    Returns whatever ``collect()`` yields from the receiving end of the
    channel the probe wrote into.
    """
    sender, receiver = channel[GatheredInfo]()
    gatherer = InfoGatherer(sender)

    # Bound the probe: cancel it if it has not finished after 1 second.
    with anyio.move_on_after(1):
        await gatherer._monitor_system_profiler_thunderbolt_data()  # pyright: ignore[reportPrivateUsage]

    with receiver:
        return receiver.collect()
|
|
|
|
|
|
def list_models():
    """Yield the model ids found under ``EXO_DEFAULT_MODELS_DIR``.

    A model is detected by any ``model-*.safetensors`` file (searched
    recursively) whose parent directory name contains ``--``. That directory
    name is mapped back to a model id by replacing ``--`` with ``/``. Each id
    is yielded at most once, even if it has several matching files.

    NOTE(review): only the default models dir is scanned, so models stored in
    any additional configured model directories are not listed here. Confirm
    this is intended.
    """
    seen = set[str]()
    for path in EXO_DEFAULT_MODELS_DIR.rglob("model-*.safetensors"):
        dir_name = path.parent.name
        if "--" not in dir_name:
            continue
        name = dir_name.replace("--", "/")
        if name in seen:
            continue
        seen.add(name)
        # Reuse the already-computed id instead of rebuilding it.
        yield ModelId(name)
|
|
|
|
|
|
async def run_test(test: Tests):
    """Resolve this host within ``test.devs`` and stream the test output.

    The local hostname matches a device when either name is a prefix of the
    other. The first such device's name is used as this node's id.

    Raises:
        ValueError: if no entry of ``test.devs`` matches the local hostname.
    """
    local_name = socket.gethostname()
    hn = next(
        (
            dev[0]
            for dev in test.devs
            if local_name.startswith(dev[0]) or dev[0].startswith(local_name)
        ),
        None,
    )
    if hn is None:
        raise ValueError(f"{local_name} not in {test.devs}")

    async def run():
        logger.info(f"testing {test.model_id}")

        instances: list[Instance] = []

        if test.kind in ("ring", "both"):
            ring = await ring_instance(test, hn)
            if ring is None:
                yield "no model found"
                return
            instances.append(ring)

        if test.kind in ("jaccl", "both"):
            jaccl = await jaccl_instance(test)
            if jaccl is None:
                yield "no model found"
                return
            instances.append(jaccl)

        for instance in instances:
            events = await execute_test(test, instance, hn)
            generated = ""
            for event in events:
                if isinstance(event, ChunkGenerated):
                    assert isinstance(event.chunk, TokenChunk)
                    generated += event.chunk.text

                runner_stopped = isinstance(event, RunnerStatusUpdated) and isinstance(
                    event.runner_status, (RunnerFailed, RunnerShutdown)
                )
                if runner_stopped:
                    # Emit the text accumulated so far, then the status event itself.
                    yield generated + "\n"
                    yield event.model_dump_json() + "\n"

    return StreamingResponse(run())
|
|
|
|
|
|
def _pipeline_layer_range(n_layers: int, world_size: int, rank: int) -> tuple[int, int]:
    """Return the half-open ``(start, end)`` layer range owned by pipeline *rank*.

    Layers are handed out in chunks of ``n_layers // world_size``. The last
    rank also takes the remainder, so every layer is assigned to exactly one
    rank even when ``n_layers`` is not divisible by ``world_size``.
    """
    per_rank = n_layers // world_size
    start = per_rank * rank
    end = n_layers if rank == world_size - 1 else per_rank * (rank + 1)
    return start, end


async def ring_instance(test: Tests, hn: str) -> Instance | None:
    """Build an MLX ring instance with pipeline shards over ``test.devs``, as seen from ``hn``.

    Only ``hn``'s entry in ``hosts_by_node`` is populated:
    - its own slot is bound to ``0.0.0.0``;
    - the previous and next ranks (wrapping around the ring) get their real IPs;
    - every other slot keeps the placeholder ``198.51.100.0`` (TEST-NET-2).

    Layers are split evenly across ranks, and the last rank takes any remainder.

    Raises:
        ValueError: if ``hn`` is not the hostname of any entry in ``test.devs``.
    """
    port = 52417
    world_size = len(test.devs)
    hbn = [Host(ip="198.51.100.0", port=port) for _ in test.devs]

    for rank, dev in enumerate(test.devs):
        if dev[0] == hn:
            # Wrap both neighbours with modulo: the last rank's successor is
            # rank 0 (indexing devs[rank + 1] directly would overflow).
            prev_rank = (rank - 1) % world_size
            next_rank = (rank + 1) % world_size
            hbn[prev_rank] = Host(ip=test.devs[prev_rank][1], port=port)
            hbn[next_rank] = Host(ip=test.devs[next_rank][1], port=port)
            hbn[rank] = Host(ip="0.0.0.0", port=port)
            break
    else:
        raise ValueError(f"{hn} not in {test.devs}")

    card = await ModelCard.load(test.model_id)
    layer_ranges = [
        _pipeline_layer_range(card.n_layers, world_size, rank)
        for rank in range(world_size)
    ]

    instance = MlxRingInstance(
        instance_id=iid,
        ephemeral_port=port,
        hosts_by_node={NodeId(hn): hbn},
        shard_assignments=ShardAssignments(
            model_id=test.model_id,
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
                RunnerId(dev[0]): PipelineShardMetadata(
                    model_card=card,
                    device_rank=rank,
                    world_size=world_size,
                    start_layer=start,
                    end_layer=end,
                    n_layers=end - start,
                )
                for rank, (dev, (start, end)) in enumerate(
                    zip(test.devs, layer_ranges)
                )
            },
        ),
    )

    return instance
|
|
|
|
|
|
async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
    """Queue a fixed task sequence for *instance* and hand it to ``entrypoint``.

    The sequence is LoadModel, StartWarmup, one TextGeneration prompt, then
    Shutdown. When ``test.devs`` has more than one device, it is prefixed with
    ConnectToGroup.

    Events are not collected yet (see TODO), so this always returns ``[]``.
    """
    prompt = TextGenerationTaskParams(
        model=test.model_id,
        instructions="You are a helpful assistant",
        input=[InputMessage(role="user", content="What is the capital of France?")],
    )
    tasks: list[Task] = [
        LoadModel(instance_id=iid),
        StartWarmup(instance_id=iid),
        TextGeneration(
            task_params=prompt,
            command_id=CommandId("yo"),
            instance_id=iid,
        ),
        Shutdown(runner_id=RunnerId(hn), instance_id=iid),
    ]
    if len(test.devs) > 1:
        tasks.insert(0, ConnectToGroup(instance_id=iid))

    bound = BoundInstance(
        instance=instance,
        bound_runner_id=RunnerId(hn),
        bound_node_id=NodeId(hn),
    )
    # Keep the receiving end bound to a local for the duration of the call,
    # even though nothing reads from it yet.
    event_sender, _event_receiver = mp_channel[Event]()
    task_sender, task_receiver = mp_channel[Task]()

    for task in tasks:
        task_sender.send(task)

    entrypoint(bound, event_sender, task_receiver, logger)

    # TODO(evan): return the events collected from _event_receiver instead of []
    return []
|
|
|
|
|
|
async def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
    """Build an MLX JACCL instance with tensor shards over every device in ``test.devs``.

    Every rank is given all layers (``start_layer=0``, ``end_layer=n_layers``).
    The device at ``test.devs[0]`` is used as the coordinator for all nodes.

    Raises:
        ValueError: if ``test.ibv_devs`` is missing or empty.
    """
    # Validate before loading the model card, and with an explicit raise rather
    # than ``assert``, so the check is not stripped under ``python -O``.
    if not test.ibv_devs:
        raise ValueError(f"jaccl tests require ibv_devs, got {test.ibv_devs!r}")

    card = await ModelCard.load(test.model_id)
    world_size = len(test.devs)
    # rank 0 is always coordinator
    coordinator = f"{test.devs[0][1]}:52417"

    return MlxJacclInstance(
        instance_id=iid,
        jaccl_devices=test.ibv_devs,
        jaccl_coordinators={NodeId(host[0]): coordinator for host in test.devs},
        shard_assignments=ShardAssignments(
            model_id=test.model_id,
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
                RunnerId(host[0]): TensorShardMetadata(
                    model_card=card,
                    device_rank=i,
                    world_size=world_size,
                    start_layer=0,
                    end_layer=card.n_layers,
                    n_layers=card.n_layers,
                )
                for i, host in enumerate(test.devs)
            },
        ),
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the HTTP server (see main) until POST /kill.
    anyio.run(main)
|