Compare commits


2 Commits

Alex Cheema
21c363e997 fix: move suppress(ClosedResourceError) inside runner.shutdown() per review
Move the ClosedResourceError suppression from the two call sites in
worker/main.py into RunnerSupervisor.shutdown() itself, so each
close/send on already-closed channels is individually guarded.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 09:01:54 -08:00
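
A minimal sketch of the pattern this commit centralizes, assuming anyio-style
channels whose close()/send() raise ClosedResourceError once the resource is
already closed (the Supervisor class and its channel list here are
hypothetical, not the exo code):

    import contextlib

    from anyio import ClosedResourceError

    class Supervisor:
        def __init__(self, channels: list) -> None:
            self.channels = channels

        def shutdown(self) -> None:
            # Guard each operation individually: one already-closed channel
            # must not abort cleanup of the remaining ones, which is why a
            # single suppress() around the whole block would be wrong.
            for ch in self.channels:
                with contextlib.suppress(ClosedResourceError):
                    ch.close()

Suppressing per call, rather than wrapping the whole body once, ensures every
channel still gets its close attempt even after the first one fails.
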
Alex Cheema
b1c0e3116d fix: misc bug fixes (spawn force, download restart, shutdown guard)
Three independent fixes extracted from meta-instance branch (#1519):

- Use force=True for mp.set_start_method("spawn") to prevent errors
  when the start method was already set by another initialization path
- Detect already-complete downloads on restart instead of reporting them
  as DownloadPending (checks downloaded_bytes >= total_bytes)
- Guard runner.shutdown() with contextlib.suppress(ClosedResourceError)
  to handle already-closed resources during worker teardown

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 05:37:11 -08:00
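
The second fix is the least obvious of the three: after a worker restart,
downloaded_bytes_this_session is zero even for a model that finished
downloading in a previous session, so completion has to be inferred from the
cumulative totals. A condensed sketch of the decision (field names follow the
diff below; the classify() wrapper and its string return values are
illustrative only):

    def classify(progress) -> str:
        # Chained comparison: only report completion when total_bytes is
        # actually known (> 0) AND the cumulative count has reached it.
        # Without the "> 0" guard, an unpopulated total of 0 bytes would
        # classify every fresh download as already complete.
        if progress.downloaded_bytes.in_bytes >= progress.total_bytes.in_bytes > 0:
            return "DownloadCompleted"
        if progress.downloaded_bytes_this_session.in_bytes == 0:
            return "DownloadPending"
        return "DownloadInProgress"
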
7 changed files with 35 additions and 207 deletions


@@ -20,7 +20,6 @@ from harness import (
     instance_id_from_instance,
     nodes_used_in_instance,
     resolve_model_short_id,
-    run_planning_phase,
     settle_and_fetch_placements,
     wait_for_instance_gone,
     wait_for_instance_ready,
@@ -963,21 +962,6 @@ Examples:
     selected.sort(key=_placement_sort_key)
     preview = selected[0]
-    settle_deadline = (
-        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
-    )
-    print("Planning phase: checking downloads...", file=log)
-    run_planning_phase(
-        exo,
-        full_model_id,
-        preview,
-        args.danger_delete_downloads,
-        args.timeout,
-        settle_deadline,
-    )
     instance = preview["instance"]
     instance_id = instance_id_from_instance(instance)
     sharding = str(preview["sharding"])


@@ -35,7 +35,6 @@ from harness import (
     instance_id_from_instance,
     nodes_used_in_instance,
     resolve_model_short_id,
-    run_planning_phase,
     settle_and_fetch_placements,
     wait_for_instance_gone,
     wait_for_instance_ready,
@@ -333,20 +332,6 @@ def main() -> int:
     if args.dry_run:
         return 0
-    settle_deadline = (
-        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
-    )
-    logger.info("Planning phase: checking downloads...")
-    run_planning_phase(
-        client,
-        full_model_id,
-        selected[0],
-        args.danger_delete_downloads,
-        args.timeout,
-        settle_deadline,
-    )
     all_rows: list[dict[str, Any]] = []
     for preview in selected:


@@ -282,151 +282,6 @@ def settle_and_fetch_placements(
     return selected


-def run_planning_phase(
-    client: ExoClient,
-    full_model_id: str,
-    preview: dict[str, Any],
-    danger_delete: bool,
-    timeout: float,
-    settle_deadline: float | None,
-) -> None:
-    """Check disk space and ensure model is downloaded before benchmarking."""
-    # Get model size from /models
-    models = client.request_json("GET", "/models") or {}
-    model_bytes = 0
-    for m in models.get("data", []):
-        if m.get("hugging_face_id") == full_model_id:
-            model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
-            break
-    if not model_bytes:
-        logger.warning(
-            f"Could not determine size for {full_model_id}, skipping disk check"
-        )
-        return
-    # Get nodes from preview
-    inner = unwrap_instance(preview["instance"])
-    node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
-    runner_to_shard = inner["shardAssignments"]["runnerToShard"]
-    state = client.request_json("GET", "/state")
-    downloads = state.get("downloads", {})
-    node_disk = state.get("nodeDisk", {})
-    for node_id in node_ids:
-        node_downloads = downloads.get(node_id, [])
-        # Check if model already downloaded on this node
-        already_downloaded = any(
-            "DownloadCompleted" in p
-            and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
-                "modelId"
-            ]
-            == full_model_id
-            for p in node_downloads
-        )
-        if already_downloaded:
-            continue
-        # Wait for disk info if settle_deadline is set
-        disk_info = node_disk.get(node_id, {})
-        backoff = _SETTLE_INITIAL_BACKOFF_S
-        while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
-            remaining = settle_deadline - time.monotonic()
-            logger.info(
-                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
-            )
-            time.sleep(min(backoff, remaining))
-            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
-            state = client.request_json("GET", "/state")
-            node_disk = state.get("nodeDisk", {})
-            disk_info = node_disk.get(node_id, {})
-        if not disk_info:
-            logger.warning(f"No disk info for {node_id}, skipping space check")
-            continue
-        avail = disk_info.get("available", {}).get("inBytes", 0)
-        if avail >= model_bytes:
-            continue
-        if not danger_delete:
-            raise RuntimeError(
-                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
-                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
-            )
-        # Delete from smallest to largest
-        completed = [
-            (
-                unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
-                    "modelId"
-                ],
-                p["DownloadCompleted"]["totalBytes"]["inBytes"],
-            )
-            for p in node_downloads
-            if "DownloadCompleted" in p
-        ]
-        for del_model, size in sorted(completed, key=lambda x: x[1]):
-            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
-            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
-            avail += size
-            if avail >= model_bytes:
-                break
-        if avail < model_bytes:
-            raise RuntimeError(f"Could not free enough space on {node_id}")
-    # Start downloads (idempotent)
-    for node_id in node_ids:
-        runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
-        shard = runner_to_shard[runner_id]
-        client.request_json(
-            "POST",
-            "/download/start",
-            body={
-                "targetNodeId": node_id,
-                "shardMetadata": shard,
-            },
-        )
-        logger.info(f"Started download on {node_id}")
-    # Wait for downloads
-    start = time.time()
-    while time.time() - start < timeout:
-        state = client.request_json("GET", "/state")
-        downloads = state.get("downloads", {})
-        all_done = True
-        for node_id in node_ids:
-            done = any(
-                "DownloadCompleted" in p
-                and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
-                    "modelCard"
-                ]["modelId"]
-                == full_model_id
-                for p in downloads.get(node_id, [])
-            )
-            failed = [
-                p["DownloadFailed"]["errorMessage"]
-                for p in downloads.get(node_id, [])
-                if "DownloadFailed" in p
-                and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
-                    "modelId"
-                ]
-                == full_model_id
-            ]
-            if failed:
-                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
-            if not done:
-                all_done = False
-        if all_done:
-            return
-        time.sleep(1)
-    raise TimeoutError("Downloads did not complete in time")


 def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
     ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
     ap.add_argument(
@@ -470,8 +325,3 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
         default=0,
         help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
     )
-    ap.add_argument(
-        "--danger-delete-downloads",
-        action="store_true",
-        help="Delete existing models from smallest to largest to make room for benchmark model.",
-    )


@@ -338,7 +338,17 @@ class DownloadCoordinator:
                    ),
                )
            elif progress.status in ["in_progress", "not_started"]:
-                if progress.downloaded_bytes_this_session.in_bytes == 0:
+                if (
+                    progress.downloaded_bytes.in_bytes
+                    >= progress.total_bytes.in_bytes
+                    > 0
+                ):
+                    status = DownloadCompleted(
+                        node_id=self.node_id,
+                        shard_metadata=progress.shard,
+                        total_bytes=progress.total_bytes,
+                    )
+                elif progress.downloaded_bytes_this_session.in_bytes == 0:
                    status = DownloadPending(
                        node_id=self.node_id,
                        shard_metadata=progress.shard,


@@ -258,7 +258,7 @@ def main():
     target = min(max(soft, 65535), hard)
     resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
-    mp.set_start_method("spawn")
+    mp.set_start_method("spawn", force=True)
     # TODO: Refactor the current verbosity system
     logger_setup(EXO_LOG, args.verbosity)
     logger.info("Starting EXO")


@@ -31,7 +31,6 @@ from exo.shared.types.openai_responses import (
     ResponseOutputText,
     ResponsesRequest,
     ResponsesResponse,
-    ResponsesStreamEvent,
     ResponseTextDeltaEvent,
     ResponseTextDoneEvent,
     ResponseUsage,
@@ -39,11 +38,6 @@ from exo.shared.types.openai_responses import (
 from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams


-def _format_sse(event: ResponsesStreamEvent) -> str:
-    """Format a streaming event as an SSE message."""
-    return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"


 def _extract_content(content: str | list[ResponseContentPart]) -> str:
     """Extract plain text from a content field that may be a string or list of parts."""
     if isinstance(content, str):
@@ -225,13 +219,13 @@ async def generate_responses_stream(
     created_event = ResponseCreatedEvent(
         sequence_number=next(seq), response=initial_response
     )
-    yield _format_sse(created_event)
+    yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"

     # response.in_progress
     in_progress_event = ResponseInProgressEvent(
         sequence_number=next(seq), response=initial_response
     )
-    yield _format_sse(in_progress_event)
+    yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"

     # response.output_item.added
     initial_item = ResponseMessageItem(
@@ -242,7 +236,7 @@ async def generate_responses_stream(
     item_added = ResponseOutputItemAddedEvent(
         sequence_number=next(seq), output_index=0, item=initial_item
     )
-    yield _format_sse(item_added)
+    yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"

     # response.content_part.added
     initial_part = ResponseOutputText(text="")
@@ -253,7 +247,7 @@ async def generate_responses_stream(
         content_index=0,
         part=initial_part,
     )
-    yield _format_sse(part_added)
+    yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"

     accumulated_text = ""
     function_call_items: list[ResponseFunctionCallItem] = []
@@ -287,7 +281,7 @@ async def generate_responses_stream(
                 output_index=next_output_index,
                 item=fc_item,
             )
-            yield _format_sse(fc_added)
+            yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"

             # response.function_call_arguments.delta
             args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -296,7 +290,7 @@ async def generate_responses_stream(
                 output_index=next_output_index,
                 delta=tool.arguments,
             )
-            yield _format_sse(args_delta)
+            yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"

             # response.function_call_arguments.done
             args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -306,7 +300,7 @@ async def generate_responses_stream(
                 name=tool.name,
                 arguments=tool.arguments,
             )
-            yield _format_sse(args_done)
+            yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"

             # response.output_item.done
             fc_done_item = ResponseFunctionCallItem(
@@ -321,7 +315,7 @@ async def generate_responses_stream(
                 output_index=next_output_index,
                 item=fc_done_item,
             )
-            yield _format_sse(fc_item_done)
+            yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"

             function_call_items.append(fc_done_item)
             next_output_index += 1
@@ -337,7 +331,7 @@ async def generate_responses_stream(
                 content_index=0,
                 delta=chunk.text,
             )
-            yield _format_sse(delta_event)
+            yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"

     # response.output_text.done
     text_done = ResponseTextDoneEvent(
@@ -347,7 +341,7 @@ async def generate_responses_stream(
         content_index=0,
         text=accumulated_text,
     )
-    yield _format_sse(text_done)
+    yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"

     # response.content_part.done
     final_part = ResponseOutputText(text=accumulated_text)
@@ -358,7 +352,7 @@ async def generate_responses_stream(
         content_index=0,
         part=final_part,
     )
-    yield _format_sse(part_done)
+    yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"

     # response.output_item.done
     final_message_item = ResponseMessageItem(
@@ -369,7 +363,7 @@ async def generate_responses_stream(
     item_done = ResponseOutputItemDoneEvent(
         sequence_number=next(seq), output_index=0, item=final_message_item
     )
-    yield _format_sse(item_done)
+    yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"

     # Create usage from usage data if available
     usage = None
@@ -394,4 +388,4 @@ async def generate_responses_stream(
     completed_event = ResponseCompletedEvent(
         sequence_number=next(seq), response=final_response
     )
-    yield _format_sse(completed_event)
+    yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
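
Each inlined yield now hard-codes the event name that _format_sse previously
read from event.type, but the wire format is unchanged: every yield emits one
complete server-sent event. A sketch with an illustrative payload:

    created_event_json = '{"type": "response.created", "sequence_number": 0}'
    sse_message = f"event: response.created\ndata: {created_event_json}\n\n"
    # On the wire:
    #   event: response.created
    #   data: {"type": "response.created", "sequence_number": 0}
    #   <blank line terminates the event>
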


@@ -98,11 +98,16 @@ class RunnerSupervisor:
     def shutdown(self):
         logger.info("Runner supervisor shutting down")
-        self._ev_recv.close()
-        self._task_sender.close()
-        self._event_sender.close()
-        self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
-        self._cancel_sender.close()
+        with contextlib.suppress(ClosedResourceError):
+            self._ev_recv.close()
+        with contextlib.suppress(ClosedResourceError):
+            self._task_sender.close()
+        with contextlib.suppress(ClosedResourceError):
+            self._event_sender.close()
+        with contextlib.suppress(ClosedResourceError):
+            self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
+        with contextlib.suppress(ClosedResourceError):
+            self._cancel_sender.close()
         self.runner_process.join(5)
         if not self.runner_process.is_alive():
             logger.info("Runner process succesfully terminated")