mirror of https://github.com/exo-explore/exo.git (synced 2026-02-19 23:36:30 -05:00)

Compare commits: test-scree...meta-insta (2 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 21c363e997 | |
| | b1c0e3116d | |
```diff
@@ -20,7 +20,6 @@ from harness import (
     instance_id_from_instance,
     nodes_used_in_instance,
     resolve_model_short_id,
-    run_planning_phase,
     settle_and_fetch_placements,
     wait_for_instance_gone,
     wait_for_instance_ready,
@@ -963,21 +962,6 @@ Examples:
 
     selected.sort(key=_placement_sort_key)
     preview = selected[0]
-
-    settle_deadline = (
-        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
-    )
-
-    print("Planning phase: checking downloads...", file=log)
-    run_planning_phase(
-        exo,
-        full_model_id,
-        preview,
-        args.danger_delete_downloads,
-        args.timeout,
-        settle_deadline,
-    )
-
     instance = preview["instance"]
     instance_id = instance_id_from_instance(instance)
     sharding = str(preview["sharding"])
```
```diff
@@ -35,7 +35,6 @@ from harness import (
     instance_id_from_instance,
     nodes_used_in_instance,
     resolve_model_short_id,
-    run_planning_phase,
    settle_and_fetch_placements,
    wait_for_instance_gone,
    wait_for_instance_ready,
@@ -333,20 +332,6 @@ def main() -> int:
     if args.dry_run:
         return 0
 
-    settle_deadline = (
-        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
-    )
-
-    logger.info("Planning phase: checking downloads...")
-    run_planning_phase(
-        client,
-        full_model_id,
-        selected[0],
-        args.danger_delete_downloads,
-        args.timeout,
-        settle_deadline,
-    )
-
     all_rows: list[dict[str, Any]] = []
 
     for preview in selected:
```
bench/harness.py (150 deletions)
```diff
@@ -282,151 +282,6 @@ def settle_and_fetch_placements(
     return selected
 
 
-def run_planning_phase(
-    client: ExoClient,
-    full_model_id: str,
-    preview: dict[str, Any],
-    danger_delete: bool,
-    timeout: float,
-    settle_deadline: float | None,
-) -> None:
-    """Check disk space and ensure model is downloaded before benchmarking."""
-    # Get model size from /models
-    models = client.request_json("GET", "/models") or {}
-    model_bytes = 0
-    for m in models.get("data", []):
-        if m.get("hugging_face_id") == full_model_id:
-            model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
-            break
-
-    if not model_bytes:
-        logger.warning(
-            f"Could not determine size for {full_model_id}, skipping disk check"
-        )
-        return
-
-    # Get nodes from preview
-    inner = unwrap_instance(preview["instance"])
-    node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
-    runner_to_shard = inner["shardAssignments"]["runnerToShard"]
-
-    state = client.request_json("GET", "/state")
-    downloads = state.get("downloads", {})
-    node_disk = state.get("nodeDisk", {})
-
-    for node_id in node_ids:
-        node_downloads = downloads.get(node_id, [])
-
-        # Check if model already downloaded on this node
-        already_downloaded = any(
-            "DownloadCompleted" in p
-            and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
-                "modelId"
-            ]
-            == full_model_id
-            for p in node_downloads
-        )
-        if already_downloaded:
-            continue
-
-        # Wait for disk info if settle_deadline is set
-        disk_info = node_disk.get(node_id, {})
-        backoff = _SETTLE_INITIAL_BACKOFF_S
-        while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
-            remaining = settle_deadline - time.monotonic()
-            logger.info(
-                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
-            )
-            time.sleep(min(backoff, remaining))
-            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
-            state = client.request_json("GET", "/state")
-            node_disk = state.get("nodeDisk", {})
-            disk_info = node_disk.get(node_id, {})
-
-        if not disk_info:
-            logger.warning(f"No disk info for {node_id}, skipping space check")
-            continue
-
-        avail = disk_info.get("available", {}).get("inBytes", 0)
-        if avail >= model_bytes:
-            continue
-
-        if not danger_delete:
-            raise RuntimeError(
-                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
-                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
-            )
-
-        # Delete from smallest to largest
-        completed = [
-            (
-                unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
-                    "modelId"
-                ],
-                p["DownloadCompleted"]["totalBytes"]["inBytes"],
-            )
-            for p in node_downloads
-            if "DownloadCompleted" in p
-        ]
-        for del_model, size in sorted(completed, key=lambda x: x[1]):
-            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
-            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
-            avail += size
-            if avail >= model_bytes:
-                break
-
-        if avail < model_bytes:
-            raise RuntimeError(f"Could not free enough space on {node_id}")
-
-    # Start downloads (idempotent)
-    for node_id in node_ids:
-        runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
-        shard = runner_to_shard[runner_id]
-        client.request_json(
-            "POST",
-            "/download/start",
-            body={
-                "targetNodeId": node_id,
-                "shardMetadata": shard,
-            },
-        )
-        logger.info(f"Started download on {node_id}")
-
-    # Wait for downloads
-    start = time.time()
-    while time.time() - start < timeout:
-        state = client.request_json("GET", "/state")
-        downloads = state.get("downloads", {})
-        all_done = True
-        for node_id in node_ids:
-            done = any(
-                "DownloadCompleted" in p
-                and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
-                    "modelCard"
-                ]["modelId"]
-                == full_model_id
-                for p in downloads.get(node_id, [])
-            )
-            failed = [
-                p["DownloadFailed"]["errorMessage"]
-                for p in downloads.get(node_id, [])
-                if "DownloadFailed" in p
-                and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
-                    "modelId"
-                ]
-                == full_model_id
-            ]
-            if failed:
-                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
-            if not done:
-                all_done = False
-        if all_done:
-            return
-        time.sleep(1)
-
-    raise TimeoutError("Downloads did not complete in time")
-
-
 def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
     ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
     ap.add_argument(
```
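The removed planning phase waited for per-node disk info by polling `/state` with capped exponential backoff. A minimal sketch of that polling pattern; the constant values below are assumptions, since `_SETTLE_INITIAL_BACKOFF_S`, `_SETTLE_BACKOFF_MULTIPLIER`, and `_SETTLE_MAX_BACKOFF_S` are defined elsewhere in bench/harness.py and not shown in this diff:

```python
import time

# Assumed values; the real constants live elsewhere in bench/harness.py.
_SETTLE_INITIAL_BACKOFF_S = 0.5
_SETTLE_BACKOFF_MULTIPLIER = 2.0
_SETTLE_MAX_BACKOFF_S = 5.0


def poll_with_backoff(fetch, deadline: float | None):
    """Retry fetch() with capped exponential backoff until it returns a
    truthy value or the deadline (a time.monotonic() timestamp) passes."""
    result = fetch()
    backoff = _SETTLE_INITIAL_BACKOFF_S
    while not result and deadline and time.monotonic() < deadline:
        remaining = deadline - time.monotonic()
        # Never sleep past the deadline, and never a negative duration.
        time.sleep(max(0.0, min(backoff, remaining)))
        backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
        result = fetch()
    return result


# Example: a value that only appears on the third poll.
_polls = iter([{}, {}, {"available": {"inBytes": 1 << 30}}])
print(poll_with_backoff(lambda: next(_polls), time.monotonic() + 10.0))
```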
```diff
@@ -470,8 +325,3 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
         default=0,
         help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
     )
-    ap.add_argument(
-        "--danger-delete-downloads",
-        action="store_true",
-        help="Delete existing models from smallest to largest to make room for benchmark model.",
-    )
```
```diff
@@ -338,7 +338,17 @@ class DownloadCoordinator:
                 ),
             )
         elif progress.status in ["in_progress", "not_started"]:
-            if progress.downloaded_bytes_this_session.in_bytes == 0:
+            if (
+                progress.downloaded_bytes.in_bytes
+                >= progress.total_bytes.in_bytes
+                > 0
+            ):
+                status = DownloadCompleted(
+                    node_id=self.node_id,
+                    shard_metadata=progress.shard,
+                    total_bytes=progress.total_bytes,
+                )
+            elif progress.downloaded_bytes_this_session.in_bytes == 0:
                 status = DownloadPending(
                     node_id=self.node_id,
                     shard_metadata=progress.shard,
```
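A note on the new condition: Python chains comparisons, so `a >= b > 0` means `a >= b and b > 0`. The added branch therefore reports a download as already complete when all bytes are on disk from a previous session (even though nothing was transferred this session), while an unknown zero-byte total can never qualify. A self-contained sketch of the predicate, using a hypothetical `Progress` stand-in for the coordinator's progress object:

```python
from dataclasses import dataclass


@dataclass
class Progress:  # hypothetical stand-in for the coordinator's progress type
    downloaded_bytes: int
    total_bytes: int
    downloaded_this_session: int


def is_already_complete(p: Progress) -> bool:
    # Chained comparison: downloaded >= total AND total > 0.
    # The "> 0" guard stops an unknown (0-byte) total counting as complete.
    return p.downloaded_bytes >= p.total_bytes > 0


assert is_already_complete(Progress(100, 100, 0))      # resumed, fully on disk
assert not is_already_complete(Progress(0, 0, 0))      # total not known yet
assert not is_already_complete(Progress(50, 100, 50))  # still in progress
```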
```diff
@@ -258,7 +258,7 @@ def main():
     target = min(max(soft, 65535), hard)
     resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
 
-    mp.set_start_method("spawn")
+    mp.set_start_method("spawn", force=True)
     # TODO: Refactor the current verbosity system
     logger_setup(EXO_LOG, args.verbosity)
     logger.info("Starting EXO")
```
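`multiprocessing.set_start_method` raises `RuntimeError` if a start method has already been fixed for the process (for example by an importing library or an earlier call); `force=True` overrides the existing choice instead of failing. A minimal sketch of the difference:

```python
import multiprocessing as mp

if __name__ == "__main__":
    mp.set_start_method("spawn")       # first call succeeds
    try:
        mp.set_start_method("spawn")   # second call without force raises
    except RuntimeError as exc:
        print(f"without force: {exc}")  # "context has already been set"
    mp.set_start_method("spawn", force=True)  # override always succeeds
```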
```diff
@@ -31,7 +31,6 @@ from exo.shared.types.openai_responses import (
     ResponseOutputText,
     ResponsesRequest,
     ResponsesResponse,
-    ResponsesStreamEvent,
     ResponseTextDeltaEvent,
     ResponseTextDoneEvent,
     ResponseUsage,
@@ -39,11 +38,6 @@ from exo.shared.types.openai_responses import (
 from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
 
 
-def _format_sse(event: ResponsesStreamEvent) -> str:
-    """Format a streaming event as an SSE message."""
-    return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
-
-
 def _extract_content(content: str | list[ResponseContentPart]) -> str:
     """Extract plain text from a content field that may be a string or list of parts."""
     if isinstance(content, str):
```
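The deleted `_format_sse` helper derived the SSE `event:` line from the event object's own `type` field; the hunks below replace each call with an f-string that hardcodes the event name at the call site. A minimal sketch of the server-sent-events framing both forms produce (assuming, as in the removed helper, that each event exposes `type` and a pydantic-style `model_dump_json()`):

```python
def format_sse(event_name: str, payload_json: str) -> str:
    # One SSE message: an "event:" line, a "data:" line, then a blank line.
    return f"event: {event_name}\ndata: {payload_json}\n\n"


# Helper style (removed): the wire name always tracks event.type.
# Inline style (added): the wire name is a string literal per call site,
# so it can drift from event.type if the two are edited independently.
print(format_sse("response.created", '{"sequence_number": 0}'), end="")
```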
```diff
@@ -225,13 +219,13 @@ async def generate_responses_stream(
     created_event = ResponseCreatedEvent(
         sequence_number=next(seq), response=initial_response
     )
-    yield _format_sse(created_event)
+    yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
 
     # response.in_progress
     in_progress_event = ResponseInProgressEvent(
         sequence_number=next(seq), response=initial_response
     )
-    yield _format_sse(in_progress_event)
+    yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
 
     # response.output_item.added
     initial_item = ResponseMessageItem(
@@ -242,7 +236,7 @@ async def generate_responses_stream(
     item_added = ResponseOutputItemAddedEvent(
         sequence_number=next(seq), output_index=0, item=initial_item
     )
-    yield _format_sse(item_added)
+    yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
 
     # response.content_part.added
     initial_part = ResponseOutputText(text="")
@@ -253,7 +247,7 @@ async def generate_responses_stream(
         content_index=0,
         part=initial_part,
     )
-    yield _format_sse(part_added)
+    yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
 
     accumulated_text = ""
     function_call_items: list[ResponseFunctionCallItem] = []
@@ -287,7 +281,7 @@ async def generate_responses_stream(
                 output_index=next_output_index,
                 item=fc_item,
             )
-            yield _format_sse(fc_added)
+            yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
 
             # response.function_call_arguments.delta
             args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -296,7 +290,7 @@ async def generate_responses_stream(
                 output_index=next_output_index,
                 delta=tool.arguments,
             )
-            yield _format_sse(args_delta)
+            yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
 
             # response.function_call_arguments.done
             args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -306,7 +300,7 @@ async def generate_responses_stream(
                 name=tool.name,
                 arguments=tool.arguments,
             )
-            yield _format_sse(args_done)
+            yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
 
             # response.output_item.done
             fc_done_item = ResponseFunctionCallItem(
@@ -321,7 +315,7 @@ async def generate_responses_stream(
                 output_index=next_output_index,
                 item=fc_done_item,
            )
-            yield _format_sse(fc_item_done)
+            yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
 
             function_call_items.append(fc_done_item)
             next_output_index += 1
@@ -337,7 +331,7 @@ async def generate_responses_stream(
                 content_index=0,
                 delta=chunk.text,
             )
-            yield _format_sse(delta_event)
+            yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
 
     # response.output_text.done
     text_done = ResponseTextDoneEvent(
@@ -347,7 +341,7 @@ async def generate_responses_stream(
         content_index=0,
         text=accumulated_text,
     )
-    yield _format_sse(text_done)
+    yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
 
     # response.content_part.done
     final_part = ResponseOutputText(text=accumulated_text)
@@ -358,7 +352,7 @@ async def generate_responses_stream(
         content_index=0,
         part=final_part,
     )
-    yield _format_sse(part_done)
+    yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
 
     # response.output_item.done
     final_message_item = ResponseMessageItem(
@@ -369,7 +363,7 @@ async def generate_responses_stream(
     item_done = ResponseOutputItemDoneEvent(
         sequence_number=next(seq), output_index=0, item=final_message_item
     )
-    yield _format_sse(item_done)
+    yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
 
     # Create usage from usage data if available
     usage = None
@@ -394,4 +388,4 @@ async def generate_responses_stream(
     completed_event = ResponseCompletedEvent(
         sequence_number=next(seq), response=final_response
     )
-    yield _format_sse(completed_event)
+    yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
```
```diff
@@ -98,11 +98,16 @@ class RunnerSupervisor:
 
     def shutdown(self):
         logger.info("Runner supervisor shutting down")
-        self._ev_recv.close()
-        self._task_sender.close()
-        self._event_sender.close()
-        self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
-        self._cancel_sender.close()
+        with contextlib.suppress(ClosedResourceError):
+            self._ev_recv.close()
+        with contextlib.suppress(ClosedResourceError):
+            self._task_sender.close()
+        with contextlib.suppress(ClosedResourceError):
+            self._event_sender.close()
+        with contextlib.suppress(ClosedResourceError):
+            self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
+        with contextlib.suppress(ClosedResourceError):
+            self._cancel_sender.close()
         self.runner_process.join(5)
         if not self.runner_process.is_alive():
             logger.info("Runner process succesfully terminated")
```
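Each channel operation now sits in its own `contextlib.suppress(ClosedResourceError)` block, making shutdown idempotent: one already-closed stream no longer raises out of `shutdown()` and skips the closes after it, which the earlier unguarded sequence could do. A self-contained sketch of the pattern (`ClosedResourceError` is stubbed here for the demo; in exo it is anyio's exception for operations on closed streams):

```python
import contextlib


class ClosedResourceError(Exception):
    """Stub for anyio.ClosedResourceError, so this sketch runs standalone."""


class Channel:
    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        if self.closed:
            raise ClosedResourceError
        self.closed = True


channels = [Channel(), Channel(), Channel()]
channels[0].close()  # simulate a channel closed earlier, e.g. by the runner

# Each close gets its own suppress block, so a failure in one
# cannot skip the closes that follow it.
for ch in channels:
    with contextlib.suppress(ClosedResourceError):
        ch.close()

assert all(ch.closed for ch in channels)
```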