mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-20 07:46:42 -05:00
Compare commits
2 Commits
meta-insta
...
test-scree
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
42e1e7322b | ||
|
|
aa3f106fb9 |
@@ -20,6 +20,7 @@ from harness import (
|
|||||||
instance_id_from_instance,
|
instance_id_from_instance,
|
||||||
nodes_used_in_instance,
|
nodes_used_in_instance,
|
||||||
resolve_model_short_id,
|
resolve_model_short_id,
|
||||||
|
run_planning_phase,
|
||||||
settle_and_fetch_placements,
|
settle_and_fetch_placements,
|
||||||
wait_for_instance_gone,
|
wait_for_instance_gone,
|
||||||
wait_for_instance_ready,
|
wait_for_instance_ready,
|
||||||
@@ -962,6 +963,21 @@ Examples:
|
|||||||
|
|
||||||
selected.sort(key=_placement_sort_key)
|
selected.sort(key=_placement_sort_key)
|
||||||
preview = selected[0]
|
preview = selected[0]
|
||||||
|
|
||||||
|
settle_deadline = (
|
||||||
|
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Planning phase: checking downloads...", file=log)
|
||||||
|
run_planning_phase(
|
||||||
|
exo,
|
||||||
|
full_model_id,
|
||||||
|
preview,
|
||||||
|
args.danger_delete_downloads,
|
||||||
|
args.timeout,
|
||||||
|
settle_deadline,
|
||||||
|
)
|
||||||
|
|
||||||
instance = preview["instance"]
|
instance = preview["instance"]
|
||||||
instance_id = instance_id_from_instance(instance)
|
instance_id = instance_id_from_instance(instance)
|
||||||
sharding = str(preview["sharding"])
|
sharding = str(preview["sharding"])
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from harness import (
|
|||||||
instance_id_from_instance,
|
instance_id_from_instance,
|
||||||
nodes_used_in_instance,
|
nodes_used_in_instance,
|
||||||
resolve_model_short_id,
|
resolve_model_short_id,
|
||||||
|
run_planning_phase,
|
||||||
settle_and_fetch_placements,
|
settle_and_fetch_placements,
|
||||||
wait_for_instance_gone,
|
wait_for_instance_gone,
|
||||||
wait_for_instance_ready,
|
wait_for_instance_ready,
|
||||||
@@ -332,6 +333,20 @@ def main() -> int:
|
|||||||
if args.dry_run:
|
if args.dry_run:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
settle_deadline = (
|
||||||
|
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Planning phase: checking downloads...")
|
||||||
|
run_planning_phase(
|
||||||
|
client,
|
||||||
|
full_model_id,
|
||||||
|
selected[0],
|
||||||
|
args.danger_delete_downloads,
|
||||||
|
args.timeout,
|
||||||
|
settle_deadline,
|
||||||
|
)
|
||||||
|
|
||||||
all_rows: list[dict[str, Any]] = []
|
all_rows: list[dict[str, Any]] = []
|
||||||
|
|
||||||
for preview in selected:
|
for preview in selected:
|
||||||
|
|||||||
150
bench/harness.py
150
bench/harness.py
@@ -282,6 +282,151 @@ def settle_and_fetch_placements(
|
|||||||
return selected
|
return selected
|
||||||
|
|
||||||
|
|
||||||
|
def run_planning_phase(
|
||||||
|
client: ExoClient,
|
||||||
|
full_model_id: str,
|
||||||
|
preview: dict[str, Any],
|
||||||
|
danger_delete: bool,
|
||||||
|
timeout: float,
|
||||||
|
settle_deadline: float | None,
|
||||||
|
) -> None:
|
||||||
|
"""Check disk space and ensure model is downloaded before benchmarking."""
|
||||||
|
# Get model size from /models
|
||||||
|
models = client.request_json("GET", "/models") or {}
|
||||||
|
model_bytes = 0
|
||||||
|
for m in models.get("data", []):
|
||||||
|
if m.get("hugging_face_id") == full_model_id:
|
||||||
|
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
|
||||||
|
break
|
||||||
|
|
||||||
|
if not model_bytes:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not determine size for {full_model_id}, skipping disk check"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get nodes from preview
|
||||||
|
inner = unwrap_instance(preview["instance"])
|
||||||
|
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
|
||||||
|
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||||
|
|
||||||
|
state = client.request_json("GET", "/state")
|
||||||
|
downloads = state.get("downloads", {})
|
||||||
|
node_disk = state.get("nodeDisk", {})
|
||||||
|
|
||||||
|
for node_id in node_ids:
|
||||||
|
node_downloads = downloads.get(node_id, [])
|
||||||
|
|
||||||
|
# Check if model already downloaded on this node
|
||||||
|
already_downloaded = any(
|
||||||
|
"DownloadCompleted" in p
|
||||||
|
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||||
|
"modelId"
|
||||||
|
]
|
||||||
|
== full_model_id
|
||||||
|
for p in node_downloads
|
||||||
|
)
|
||||||
|
if already_downloaded:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Wait for disk info if settle_deadline is set
|
||||||
|
disk_info = node_disk.get(node_id, {})
|
||||||
|
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||||
|
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
|
||||||
|
remaining = settle_deadline - time.monotonic()
|
||||||
|
logger.info(
|
||||||
|
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
|
||||||
|
)
|
||||||
|
time.sleep(min(backoff, remaining))
|
||||||
|
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||||
|
state = client.request_json("GET", "/state")
|
||||||
|
node_disk = state.get("nodeDisk", {})
|
||||||
|
disk_info = node_disk.get(node_id, {})
|
||||||
|
|
||||||
|
if not disk_info:
|
||||||
|
logger.warning(f"No disk info for {node_id}, skipping space check")
|
||||||
|
continue
|
||||||
|
|
||||||
|
avail = disk_info.get("available", {}).get("inBytes", 0)
|
||||||
|
if avail >= model_bytes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not danger_delete:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
|
||||||
|
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete from smallest to largest
|
||||||
|
completed = [
|
||||||
|
(
|
||||||
|
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||||
|
"modelId"
|
||||||
|
],
|
||||||
|
p["DownloadCompleted"]["totalBytes"]["inBytes"],
|
||||||
|
)
|
||||||
|
for p in node_downloads
|
||||||
|
if "DownloadCompleted" in p
|
||||||
|
]
|
||||||
|
for del_model, size in sorted(completed, key=lambda x: x[1]):
|
||||||
|
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
|
||||||
|
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
|
||||||
|
avail += size
|
||||||
|
if avail >= model_bytes:
|
||||||
|
break
|
||||||
|
|
||||||
|
if avail < model_bytes:
|
||||||
|
raise RuntimeError(f"Could not free enough space on {node_id}")
|
||||||
|
|
||||||
|
# Start downloads (idempotent)
|
||||||
|
for node_id in node_ids:
|
||||||
|
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
|
||||||
|
shard = runner_to_shard[runner_id]
|
||||||
|
client.request_json(
|
||||||
|
"POST",
|
||||||
|
"/download/start",
|
||||||
|
body={
|
||||||
|
"targetNodeId": node_id,
|
||||||
|
"shardMetadata": shard,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logger.info(f"Started download on {node_id}")
|
||||||
|
|
||||||
|
# Wait for downloads
|
||||||
|
start = time.time()
|
||||||
|
while time.time() - start < timeout:
|
||||||
|
state = client.request_json("GET", "/state")
|
||||||
|
downloads = state.get("downloads", {})
|
||||||
|
all_done = True
|
||||||
|
for node_id in node_ids:
|
||||||
|
done = any(
|
||||||
|
"DownloadCompleted" in p
|
||||||
|
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
|
||||||
|
"modelCard"
|
||||||
|
]["modelId"]
|
||||||
|
== full_model_id
|
||||||
|
for p in downloads.get(node_id, [])
|
||||||
|
)
|
||||||
|
failed = [
|
||||||
|
p["DownloadFailed"]["errorMessage"]
|
||||||
|
for p in downloads.get(node_id, [])
|
||||||
|
if "DownloadFailed" in p
|
||||||
|
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
|
||||||
|
"modelId"
|
||||||
|
]
|
||||||
|
== full_model_id
|
||||||
|
]
|
||||||
|
if failed:
|
||||||
|
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
|
||||||
|
if not done:
|
||||||
|
all_done = False
|
||||||
|
if all_done:
|
||||||
|
return
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
raise TimeoutError("Downloads did not complete in time")
|
||||||
|
|
||||||
|
|
||||||
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
||||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
@@ -325,3 +470,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
|||||||
default=0,
|
default=0,
|
||||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||||
)
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--danger-delete-downloads",
|
||||||
|
action="store_true",
|
||||||
|
help="Delete existing models from smallest to largest to make room for benchmark model.",
|
||||||
|
)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from exo.shared.types.openai_responses import (
|
|||||||
ResponseOutputText,
|
ResponseOutputText,
|
||||||
ResponsesRequest,
|
ResponsesRequest,
|
||||||
ResponsesResponse,
|
ResponsesResponse,
|
||||||
|
ResponsesStreamEvent,
|
||||||
ResponseTextDeltaEvent,
|
ResponseTextDeltaEvent,
|
||||||
ResponseTextDoneEvent,
|
ResponseTextDoneEvent,
|
||||||
ResponseUsage,
|
ResponseUsage,
|
||||||
@@ -38,6 +39,11 @@ from exo.shared.types.openai_responses import (
|
|||||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||||
|
|
||||||
|
|
||||||
|
def _format_sse(event: ResponsesStreamEvent) -> str:
|
||||||
|
"""Format a streaming event as an SSE message."""
|
||||||
|
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
|
||||||
|
|
||||||
|
|
||||||
def _extract_content(content: str | list[ResponseContentPart]) -> str:
|
def _extract_content(content: str | list[ResponseContentPart]) -> str:
|
||||||
"""Extract plain text from a content field that may be a string or list of parts."""
|
"""Extract plain text from a content field that may be a string or list of parts."""
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
@@ -219,13 +225,13 @@ async def generate_responses_stream(
|
|||||||
created_event = ResponseCreatedEvent(
|
created_event = ResponseCreatedEvent(
|
||||||
sequence_number=next(seq), response=initial_response
|
sequence_number=next(seq), response=initial_response
|
||||||
)
|
)
|
||||||
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
|
yield _format_sse(created_event)
|
||||||
|
|
||||||
# response.in_progress
|
# response.in_progress
|
||||||
in_progress_event = ResponseInProgressEvent(
|
in_progress_event = ResponseInProgressEvent(
|
||||||
sequence_number=next(seq), response=initial_response
|
sequence_number=next(seq), response=initial_response
|
||||||
)
|
)
|
||||||
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
|
yield _format_sse(in_progress_event)
|
||||||
|
|
||||||
# response.output_item.added
|
# response.output_item.added
|
||||||
initial_item = ResponseMessageItem(
|
initial_item = ResponseMessageItem(
|
||||||
@@ -236,7 +242,7 @@ async def generate_responses_stream(
|
|||||||
item_added = ResponseOutputItemAddedEvent(
|
item_added = ResponseOutputItemAddedEvent(
|
||||||
sequence_number=next(seq), output_index=0, item=initial_item
|
sequence_number=next(seq), output_index=0, item=initial_item
|
||||||
)
|
)
|
||||||
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
|
yield _format_sse(item_added)
|
||||||
|
|
||||||
# response.content_part.added
|
# response.content_part.added
|
||||||
initial_part = ResponseOutputText(text="")
|
initial_part = ResponseOutputText(text="")
|
||||||
@@ -247,7 +253,7 @@ async def generate_responses_stream(
|
|||||||
content_index=0,
|
content_index=0,
|
||||||
part=initial_part,
|
part=initial_part,
|
||||||
)
|
)
|
||||||
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
|
yield _format_sse(part_added)
|
||||||
|
|
||||||
accumulated_text = ""
|
accumulated_text = ""
|
||||||
function_call_items: list[ResponseFunctionCallItem] = []
|
function_call_items: list[ResponseFunctionCallItem] = []
|
||||||
@@ -281,7 +287,7 @@ async def generate_responses_stream(
|
|||||||
output_index=next_output_index,
|
output_index=next_output_index,
|
||||||
item=fc_item,
|
item=fc_item,
|
||||||
)
|
)
|
||||||
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
|
yield _format_sse(fc_added)
|
||||||
|
|
||||||
# response.function_call_arguments.delta
|
# response.function_call_arguments.delta
|
||||||
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
|
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
|
||||||
@@ -290,7 +296,7 @@ async def generate_responses_stream(
|
|||||||
output_index=next_output_index,
|
output_index=next_output_index,
|
||||||
delta=tool.arguments,
|
delta=tool.arguments,
|
||||||
)
|
)
|
||||||
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
|
yield _format_sse(args_delta)
|
||||||
|
|
||||||
# response.function_call_arguments.done
|
# response.function_call_arguments.done
|
||||||
args_done = ResponseFunctionCallArgumentsDoneEvent(
|
args_done = ResponseFunctionCallArgumentsDoneEvent(
|
||||||
@@ -300,7 +306,7 @@ async def generate_responses_stream(
|
|||||||
name=tool.name,
|
name=tool.name,
|
||||||
arguments=tool.arguments,
|
arguments=tool.arguments,
|
||||||
)
|
)
|
||||||
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
|
yield _format_sse(args_done)
|
||||||
|
|
||||||
# response.output_item.done
|
# response.output_item.done
|
||||||
fc_done_item = ResponseFunctionCallItem(
|
fc_done_item = ResponseFunctionCallItem(
|
||||||
@@ -315,7 +321,7 @@ async def generate_responses_stream(
|
|||||||
output_index=next_output_index,
|
output_index=next_output_index,
|
||||||
item=fc_done_item,
|
item=fc_done_item,
|
||||||
)
|
)
|
||||||
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
|
yield _format_sse(fc_item_done)
|
||||||
|
|
||||||
function_call_items.append(fc_done_item)
|
function_call_items.append(fc_done_item)
|
||||||
next_output_index += 1
|
next_output_index += 1
|
||||||
@@ -331,7 +337,7 @@ async def generate_responses_stream(
|
|||||||
content_index=0,
|
content_index=0,
|
||||||
delta=chunk.text,
|
delta=chunk.text,
|
||||||
)
|
)
|
||||||
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
|
yield _format_sse(delta_event)
|
||||||
|
|
||||||
# response.output_text.done
|
# response.output_text.done
|
||||||
text_done = ResponseTextDoneEvent(
|
text_done = ResponseTextDoneEvent(
|
||||||
@@ -341,7 +347,7 @@ async def generate_responses_stream(
|
|||||||
content_index=0,
|
content_index=0,
|
||||||
text=accumulated_text,
|
text=accumulated_text,
|
||||||
)
|
)
|
||||||
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
|
yield _format_sse(text_done)
|
||||||
|
|
||||||
# response.content_part.done
|
# response.content_part.done
|
||||||
final_part = ResponseOutputText(text=accumulated_text)
|
final_part = ResponseOutputText(text=accumulated_text)
|
||||||
@@ -352,7 +358,7 @@ async def generate_responses_stream(
|
|||||||
content_index=0,
|
content_index=0,
|
||||||
part=final_part,
|
part=final_part,
|
||||||
)
|
)
|
||||||
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
|
yield _format_sse(part_done)
|
||||||
|
|
||||||
# response.output_item.done
|
# response.output_item.done
|
||||||
final_message_item = ResponseMessageItem(
|
final_message_item = ResponseMessageItem(
|
||||||
@@ -363,7 +369,7 @@ async def generate_responses_stream(
|
|||||||
item_done = ResponseOutputItemDoneEvent(
|
item_done = ResponseOutputItemDoneEvent(
|
||||||
sequence_number=next(seq), output_index=0, item=final_message_item
|
sequence_number=next(seq), output_index=0, item=final_message_item
|
||||||
)
|
)
|
||||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
yield _format_sse(item_done)
|
||||||
|
|
||||||
# Create usage from usage data if available
|
# Create usage from usage data if available
|
||||||
usage = None
|
usage = None
|
||||||
@@ -388,4 +394,4 @@ async def generate_responses_stream(
|
|||||||
completed_event = ResponseCompletedEvent(
|
completed_event = ResponseCompletedEvent(
|
||||||
sequence_number=next(seq), response=final_response
|
sequence_number=next(seq), response=final_response
|
||||||
)
|
)
|
||||||
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
|
yield _format_sse(completed_event)
|
||||||
|
|||||||
Reference in New Issue
Block a user