mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-20 07:46:42 -05:00
Compare commits
2 Commits
meta-insta
...
meta-insta
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
21c363e997 | ||
|
|
b1c0e3116d |
@@ -338,7 +338,17 @@ class DownloadCoordinator:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif progress.status in ["in_progress", "not_started"]:
|
elif progress.status in ["in_progress", "not_started"]:
|
||||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
if (
|
||||||
|
progress.downloaded_bytes.in_bytes
|
||||||
|
>= progress.total_bytes.in_bytes
|
||||||
|
> 0
|
||||||
|
):
|
||||||
|
status = DownloadCompleted(
|
||||||
|
node_id=self.node_id,
|
||||||
|
shard_metadata=progress.shard,
|
||||||
|
total_bytes=progress.total_bytes,
|
||||||
|
)
|
||||||
|
elif progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||||
status = DownloadPending(
|
status = DownloadPending(
|
||||||
node_id=self.node_id,
|
node_id=self.node_id,
|
||||||
shard_metadata=progress.shard,
|
shard_metadata=progress.shard,
|
||||||
|
|||||||
@@ -258,7 +258,7 @@ def main():
|
|||||||
target = min(max(soft, 65535), hard)
|
target = min(max(soft, 65535), hard)
|
||||||
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
|
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
|
||||||
|
|
||||||
mp.set_start_method("spawn")
|
mp.set_start_method("spawn", force=True)
|
||||||
# TODO: Refactor the current verbosity system
|
# TODO: Refactor the current verbosity system
|
||||||
logger_setup(EXO_LOG, args.verbosity)
|
logger_setup(EXO_LOG, args.verbosity)
|
||||||
logger.info("Starting EXO")
|
logger.info("Starting EXO")
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ from exo.shared.types.openai_responses import (
|
|||||||
ResponseOutputText,
|
ResponseOutputText,
|
||||||
ResponsesRequest,
|
ResponsesRequest,
|
||||||
ResponsesResponse,
|
ResponsesResponse,
|
||||||
ResponsesStreamEvent,
|
|
||||||
ResponseTextDeltaEvent,
|
ResponseTextDeltaEvent,
|
||||||
ResponseTextDoneEvent,
|
ResponseTextDoneEvent,
|
||||||
ResponseUsage,
|
ResponseUsage,
|
||||||
@@ -39,11 +38,6 @@ from exo.shared.types.openai_responses import (
|
|||||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||||
|
|
||||||
|
|
||||||
def _format_sse(event: ResponsesStreamEvent) -> str:
|
|
||||||
"""Format a streaming event as an SSE message."""
|
|
||||||
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_content(content: str | list[ResponseContentPart]) -> str:
|
def _extract_content(content: str | list[ResponseContentPart]) -> str:
|
||||||
"""Extract plain text from a content field that may be a string or list of parts."""
|
"""Extract plain text from a content field that may be a string or list of parts."""
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
@@ -225,13 +219,13 @@ async def generate_responses_stream(
|
|||||||
created_event = ResponseCreatedEvent(
|
created_event = ResponseCreatedEvent(
|
||||||
sequence_number=next(seq), response=initial_response
|
sequence_number=next(seq), response=initial_response
|
||||||
)
|
)
|
||||||
yield _format_sse(created_event)
|
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# response.in_progress
|
# response.in_progress
|
||||||
in_progress_event = ResponseInProgressEvent(
|
in_progress_event = ResponseInProgressEvent(
|
||||||
sequence_number=next(seq), response=initial_response
|
sequence_number=next(seq), response=initial_response
|
||||||
)
|
)
|
||||||
yield _format_sse(in_progress_event)
|
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# response.output_item.added
|
# response.output_item.added
|
||||||
initial_item = ResponseMessageItem(
|
initial_item = ResponseMessageItem(
|
||||||
@@ -242,7 +236,7 @@ async def generate_responses_stream(
|
|||||||
item_added = ResponseOutputItemAddedEvent(
|
item_added = ResponseOutputItemAddedEvent(
|
||||||
sequence_number=next(seq), output_index=0, item=initial_item
|
sequence_number=next(seq), output_index=0, item=initial_item
|
||||||
)
|
)
|
||||||
yield _format_sse(item_added)
|
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# response.content_part.added
|
# response.content_part.added
|
||||||
initial_part = ResponseOutputText(text="")
|
initial_part = ResponseOutputText(text="")
|
||||||
@@ -253,7 +247,7 @@ async def generate_responses_stream(
|
|||||||
content_index=0,
|
content_index=0,
|
||||||
part=initial_part,
|
part=initial_part,
|
||||||
)
|
)
|
||||||
yield _format_sse(part_added)
|
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
|
||||||
|
|
||||||
accumulated_text = ""
|
accumulated_text = ""
|
||||||
function_call_items: list[ResponseFunctionCallItem] = []
|
function_call_items: list[ResponseFunctionCallItem] = []
|
||||||
@@ -287,7 +281,7 @@ async def generate_responses_stream(
|
|||||||
output_index=next_output_index,
|
output_index=next_output_index,
|
||||||
item=fc_item,
|
item=fc_item,
|
||||||
)
|
)
|
||||||
yield _format_sse(fc_added)
|
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# response.function_call_arguments.delta
|
# response.function_call_arguments.delta
|
||||||
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
|
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
|
||||||
@@ -296,7 +290,7 @@ async def generate_responses_stream(
|
|||||||
output_index=next_output_index,
|
output_index=next_output_index,
|
||||||
delta=tool.arguments,
|
delta=tool.arguments,
|
||||||
)
|
)
|
||||||
yield _format_sse(args_delta)
|
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# response.function_call_arguments.done
|
# response.function_call_arguments.done
|
||||||
args_done = ResponseFunctionCallArgumentsDoneEvent(
|
args_done = ResponseFunctionCallArgumentsDoneEvent(
|
||||||
@@ -306,7 +300,7 @@ async def generate_responses_stream(
|
|||||||
name=tool.name,
|
name=tool.name,
|
||||||
arguments=tool.arguments,
|
arguments=tool.arguments,
|
||||||
)
|
)
|
||||||
yield _format_sse(args_done)
|
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# response.output_item.done
|
# response.output_item.done
|
||||||
fc_done_item = ResponseFunctionCallItem(
|
fc_done_item = ResponseFunctionCallItem(
|
||||||
@@ -321,7 +315,7 @@ async def generate_responses_stream(
|
|||||||
output_index=next_output_index,
|
output_index=next_output_index,
|
||||||
item=fc_done_item,
|
item=fc_done_item,
|
||||||
)
|
)
|
||||||
yield _format_sse(fc_item_done)
|
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
|
||||||
|
|
||||||
function_call_items.append(fc_done_item)
|
function_call_items.append(fc_done_item)
|
||||||
next_output_index += 1
|
next_output_index += 1
|
||||||
@@ -337,7 +331,7 @@ async def generate_responses_stream(
|
|||||||
content_index=0,
|
content_index=0,
|
||||||
delta=chunk.text,
|
delta=chunk.text,
|
||||||
)
|
)
|
||||||
yield _format_sse(delta_event)
|
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# response.output_text.done
|
# response.output_text.done
|
||||||
text_done = ResponseTextDoneEvent(
|
text_done = ResponseTextDoneEvent(
|
||||||
@@ -347,7 +341,7 @@ async def generate_responses_stream(
|
|||||||
content_index=0,
|
content_index=0,
|
||||||
text=accumulated_text,
|
text=accumulated_text,
|
||||||
)
|
)
|
||||||
yield _format_sse(text_done)
|
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# response.content_part.done
|
# response.content_part.done
|
||||||
final_part = ResponseOutputText(text=accumulated_text)
|
final_part = ResponseOutputText(text=accumulated_text)
|
||||||
@@ -358,7 +352,7 @@ async def generate_responses_stream(
|
|||||||
content_index=0,
|
content_index=0,
|
||||||
part=final_part,
|
part=final_part,
|
||||||
)
|
)
|
||||||
yield _format_sse(part_done)
|
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# response.output_item.done
|
# response.output_item.done
|
||||||
final_message_item = ResponseMessageItem(
|
final_message_item = ResponseMessageItem(
|
||||||
@@ -369,7 +363,7 @@ async def generate_responses_stream(
|
|||||||
item_done = ResponseOutputItemDoneEvent(
|
item_done = ResponseOutputItemDoneEvent(
|
||||||
sequence_number=next(seq), output_index=0, item=final_message_item
|
sequence_number=next(seq), output_index=0, item=final_message_item
|
||||||
)
|
)
|
||||||
yield _format_sse(item_done)
|
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# Create usage from usage data if available
|
# Create usage from usage data if available
|
||||||
usage = None
|
usage = None
|
||||||
@@ -394,4 +388,4 @@ async def generate_responses_stream(
|
|||||||
completed_event = ResponseCompletedEvent(
|
completed_event = ResponseCompletedEvent(
|
||||||
sequence_number=next(seq), response=final_response
|
sequence_number=next(seq), response=final_response
|
||||||
)
|
)
|
||||||
yield _format_sse(completed_event)
|
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
|
||||||
|
|||||||
@@ -36,8 +36,6 @@ from exo.shared.types.events import (
|
|||||||
IndexedEvent,
|
IndexedEvent,
|
||||||
InputChunkReceived,
|
InputChunkReceived,
|
||||||
InstanceDeleted,
|
InstanceDeleted,
|
||||||
JacclSideChannelData,
|
|
||||||
JacclSideChannelGathered,
|
|
||||||
NodeGatheredInfo,
|
NodeGatheredInfo,
|
||||||
NodeTimedOut,
|
NodeTimedOut,
|
||||||
TaskCreated,
|
TaskCreated,
|
||||||
@@ -62,7 +60,6 @@ from exo.shared.types.tasks import (
|
|||||||
TextGeneration as TextGenerationTask,
|
TextGeneration as TextGenerationTask,
|
||||||
)
|
)
|
||||||
from exo.shared.types.worker.instances import InstanceId
|
from exo.shared.types.worker.instances import InstanceId
|
||||||
from exo.shared.types.worker.runners import RunnerId
|
|
||||||
from exo.utils.channels import Receiver, Sender, channel
|
from exo.utils.channels import Receiver, Sender, channel
|
||||||
from exo.utils.event_buffer import MultiSourceBuffer
|
from exo.utils.event_buffer import MultiSourceBuffer
|
||||||
|
|
||||||
@@ -97,7 +94,6 @@ class Master:
|
|||||||
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
|
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
|
||||||
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
||||||
self._expected_ranks: dict[TaskId, set[int]] = {}
|
self._expected_ranks: dict[TaskId, set[int]] = {}
|
||||||
self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {}
|
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
logger.info("Starting Master")
|
logger.info("Starting Master")
|
||||||
@@ -411,11 +407,6 @@ class Master:
|
|||||||
self._event_log.append(event)
|
self._event_log.append(event)
|
||||||
await self._send_event(indexed)
|
await self._send_event(indexed)
|
||||||
|
|
||||||
# After broadcasting JacclSideChannelData, accumulate and
|
|
||||||
# emit gathered result when all runners have contributed.
|
|
||||||
if isinstance(event, JacclSideChannelData):
|
|
||||||
await self._handle_jaccl_side_channel(event)
|
|
||||||
|
|
||||||
async def _loopback_processor(self) -> None:
|
async def _loopback_processor(self) -> None:
|
||||||
# this would ideally not be necessary.
|
# this would ideally not be necessary.
|
||||||
# this is WAY less hacky than how I was working around this before
|
# this is WAY less hacky than how I was working around this before
|
||||||
@@ -469,42 +460,3 @@ class Master:
|
|||||||
del self._pending_traces[task_id]
|
del self._pending_traces[task_id]
|
||||||
if task_id in self._expected_ranks:
|
if task_id in self._expected_ranks:
|
||||||
del self._expected_ranks[task_id]
|
del self._expected_ranks[task_id]
|
||||||
|
|
||||||
async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None:
|
|
||||||
"""Accumulate SideChannel contributions; when all runners for an instance
|
|
||||||
have submitted for the same sequence, emit JacclSideChannelGathered."""
|
|
||||||
iid = event.instance_id
|
|
||||||
seq = event.sequence
|
|
||||||
|
|
||||||
if iid not in self._jaccl_pending:
|
|
||||||
self._jaccl_pending[iid] = {}
|
|
||||||
if seq not in self._jaccl_pending[iid]:
|
|
||||||
self._jaccl_pending[iid][seq] = {}
|
|
||||||
self._jaccl_pending[iid][seq][event.runner_id] = event.data
|
|
||||||
|
|
||||||
instance = self.state.instances.get(iid)
|
|
||||||
if instance is None:
|
|
||||||
logger.warning(f"JacclSideChannelData for unknown instance {iid}")
|
|
||||||
return
|
|
||||||
|
|
||||||
expected_runners = set(instance.shard_assignments.runner_to_shard.keys())
|
|
||||||
submitted = set(self._jaccl_pending[iid][seq].keys())
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"JACCL side channel: instance={iid} seq={seq} "
|
|
||||||
f"submitted={len(submitted)}/{len(expected_runners)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if submitted >= expected_runners:
|
|
||||||
gathered = dict(self._jaccl_pending[iid][seq])
|
|
||||||
del self._jaccl_pending[iid][seq]
|
|
||||||
if not self._jaccl_pending[iid]:
|
|
||||||
del self._jaccl_pending[iid]
|
|
||||||
|
|
||||||
await self.event_sender.send(
|
|
||||||
JacclSideChannelGathered(
|
|
||||||
instance_id=iid,
|
|
||||||
sequence=seq,
|
|
||||||
gathered_data=gathered,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ from exo.shared.types.events import (
|
|||||||
InputChunkReceived,
|
InputChunkReceived,
|
||||||
InstanceCreated,
|
InstanceCreated,
|
||||||
InstanceDeleted,
|
InstanceDeleted,
|
||||||
JacclSideChannelData,
|
|
||||||
JacclSideChannelGathered,
|
|
||||||
NodeDownloadProgress,
|
NodeDownloadProgress,
|
||||||
NodeGatheredInfo,
|
NodeGatheredInfo,
|
||||||
NodeTimedOut,
|
NodeTimedOut,
|
||||||
@@ -70,8 +68,6 @@ def event_apply(event: Event, state: State) -> State:
|
|||||||
| PrefillProgress()
|
| PrefillProgress()
|
||||||
| TracesCollected()
|
| TracesCollected()
|
||||||
| TracesMerged()
|
| TracesMerged()
|
||||||
| JacclSideChannelData()
|
|
||||||
| JacclSideChannelGathered()
|
|
||||||
): # Pass-through events that don't modify state
|
): # Pass-through events that don't modify state
|
||||||
return state
|
return state
|
||||||
case InstanceCreated():
|
case InstanceCreated():
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
import base64
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Annotated, final
|
from typing import final
|
||||||
|
|
||||||
from pydantic import BeforeValidator, Field, PlainSerializer
|
from pydantic import Field
|
||||||
|
|
||||||
from exo.shared.topology import Connection
|
from exo.shared.topology import Connection
|
||||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||||
@@ -16,28 +14,6 @@ from exo.utils.info_gatherer.info_gatherer import GatheredInfo
|
|||||||
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
|
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
|
||||||
|
|
||||||
|
|
||||||
def _decode_base64_bytes(v: bytes | str) -> bytes:
|
|
||||||
if isinstance(v, bytes):
|
|
||||||
return v
|
|
||||||
return base64.b64decode(v)
|
|
||||||
|
|
||||||
|
|
||||||
def _encode_base64_bytes(v: bytes) -> str:
|
|
||||||
return base64.b64encode(v).decode("ascii")
|
|
||||||
|
|
||||||
|
|
||||||
Base64Bytes = Annotated[
|
|
||||||
bytes,
|
|
||||||
BeforeValidator(_decode_base64_bytes),
|
|
||||||
PlainSerializer(_encode_base64_bytes, return_type=str),
|
|
||||||
]
|
|
||||||
"""bytes that serialize to/from base64 strings in JSON.
|
|
||||||
|
|
||||||
Needed because TaggedModel's wrap validator converts JSON→Python validation
|
|
||||||
context, which breaks strict-mode bytes deserialization from JSON strings.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class EventId(Id):
|
class EventId(Id):
|
||||||
"""
|
"""
|
||||||
Newtype around `ID`
|
Newtype around `ID`
|
||||||
@@ -163,25 +139,6 @@ class TracesMerged(BaseEvent):
|
|||||||
traces: list[TraceEventData]
|
traces: list[TraceEventData]
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class JacclSideChannelData(BaseEvent):
|
|
||||||
"""A runner's local contribution to a JACCL SideChannel all_gather round."""
|
|
||||||
|
|
||||||
instance_id: InstanceId
|
|
||||||
runner_id: RunnerId
|
|
||||||
sequence: int
|
|
||||||
data: Base64Bytes
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class JacclSideChannelGathered(BaseEvent):
|
|
||||||
"""Gathered result of a JACCL SideChannel all_gather round."""
|
|
||||||
|
|
||||||
instance_id: InstanceId
|
|
||||||
sequence: int
|
|
||||||
gathered_data: Mapping[RunnerId, Base64Bytes]
|
|
||||||
|
|
||||||
|
|
||||||
Event = (
|
Event = (
|
||||||
TestEvent
|
TestEvent
|
||||||
| TaskCreated
|
| TaskCreated
|
||||||
@@ -203,8 +160,6 @@ Event = (
|
|||||||
| TopologyEdgeDeleted
|
| TopologyEdgeDeleted
|
||||||
| TracesCollected
|
| TracesCollected
|
||||||
| TracesMerged
|
| TracesMerged
|
||||||
| JacclSideChannelData
|
|
||||||
| JacclSideChannelGathered
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -643,11 +643,6 @@ def mlx_cleanup(
|
|||||||
|
|
||||||
|
|
||||||
def mx_any(bool_: bool, group: Group | None) -> bool:
|
def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||||
"""Synchronize a boolean across all distributed nodes.
|
|
||||||
|
|
||||||
Returns True if any node has bool_=True. Uses all_sum so every
|
|
||||||
node participates in the collective — preventing GPU deadlocks.
|
|
||||||
"""
|
|
||||||
if group is None:
|
if group is None:
|
||||||
return bool_
|
return bool_
|
||||||
num_true = mx.distributed.all_sum(
|
num_true = mx.distributed.all_sum(
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from exo.shared.types.events import (
|
|||||||
ForwarderEvent,
|
ForwarderEvent,
|
||||||
IndexedEvent,
|
IndexedEvent,
|
||||||
InputChunkReceived,
|
InputChunkReceived,
|
||||||
JacclSideChannelGathered,
|
|
||||||
NodeGatheredInfo,
|
NodeGatheredInfo,
|
||||||
TaskCreated,
|
TaskCreated,
|
||||||
TaskStatusUpdated,
|
TaskStatusUpdated,
|
||||||
@@ -160,15 +159,6 @@ class Worker:
|
|||||||
for idx, event in indexed_events:
|
for idx, event in indexed_events:
|
||||||
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
|
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
|
||||||
|
|
||||||
# Dispatch JACCL gathered events to the relevant RunnerSupervisor
|
|
||||||
if isinstance(event, JacclSideChannelGathered):
|
|
||||||
for runner in self.runners.values():
|
|
||||||
if (
|
|
||||||
runner.bound_instance.instance.instance_id
|
|
||||||
== event.instance_id
|
|
||||||
):
|
|
||||||
runner.notify_gathered(event)
|
|
||||||
|
|
||||||
# Buffer input image chunks for image editing
|
# Buffer input image chunks for image editing
|
||||||
if isinstance(event, InputChunkReceived):
|
if isinstance(event, InputChunkReceived):
|
||||||
cmd_id = event.command_id
|
cmd_id = event.command_id
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ def entrypoint(
|
|||||||
task_receiver: MpReceiver[Task],
|
task_receiver: MpReceiver[Task],
|
||||||
cancel_receiver: MpReceiver[TaskId],
|
cancel_receiver: MpReceiver[TaskId],
|
||||||
_logger: "loguru.Logger",
|
_logger: "loguru.Logger",
|
||||||
pipe_fifo_paths: tuple[str, str] | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||||
if fast_synch_override == "on" or (
|
if fast_synch_override == "on" or (
|
||||||
@@ -31,16 +30,6 @@ def entrypoint(
|
|||||||
else:
|
else:
|
||||||
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
||||||
|
|
||||||
# Open JACCL FIFOs by path and set env vars for C++ SideChannel.
|
|
||||||
# Named pipes (FIFOs) work across multiprocessing spawn (macOS default).
|
|
||||||
if pipe_fifo_paths is not None:
|
|
||||||
fifo_c2p, fifo_p2c = pipe_fifo_paths
|
|
||||||
# C++ reads gathered data from p2c (PIPE_IN), writes local data to c2p (PIPE_OUT)
|
|
||||||
pipe_in_fd = os.open(fifo_p2c, os.O_RDONLY)
|
|
||||||
pipe_out_fd = os.open(fifo_c2p, os.O_WRONLY)
|
|
||||||
os.environ["MLX_JACCL_PIPE_IN"] = str(pipe_in_fd)
|
|
||||||
os.environ["MLX_JACCL_PIPE_OUT"] = str(pipe_out_fd)
|
|
||||||
|
|
||||||
global logger
|
global logger
|
||||||
logger = _logger
|
logger = _logger
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,6 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import os
|
|
||||||
import signal
|
import signal
|
||||||
import struct
|
|
||||||
import tempfile
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import partial
|
|
||||||
from multiprocessing import Process
|
from multiprocessing import Process
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
@@ -18,14 +14,12 @@ from loguru import logger
|
|||||||
|
|
||||||
from exo.shared.types.events import (
|
from exo.shared.types.events import (
|
||||||
Event,
|
Event,
|
||||||
JacclSideChannelData,
|
|
||||||
JacclSideChannelGathered,
|
|
||||||
RunnerStatusUpdated,
|
RunnerStatusUpdated,
|
||||||
TaskAcknowledged,
|
TaskAcknowledged,
|
||||||
TaskStatusUpdated,
|
TaskStatusUpdated,
|
||||||
)
|
)
|
||||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
from exo.shared.types.worker.instances import BoundInstance
|
||||||
from exo.shared.types.worker.runners import (
|
from exo.shared.types.worker.runners import (
|
||||||
RunnerConnecting,
|
RunnerConnecting,
|
||||||
RunnerFailed,
|
RunnerFailed,
|
||||||
@@ -40,26 +34,6 @@ from exo.shared.types.worker.shards import ShardMetadata
|
|||||||
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
|
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
|
||||||
from exo.worker.runner.bootstrap import entrypoint
|
from exo.worker.runner.bootstrap import entrypoint
|
||||||
|
|
||||||
|
|
||||||
def _pipe_read_exact(fd: int, n: int) -> bytes | None:
|
|
||||||
"""Read exactly n bytes from a file descriptor. Returns None on EOF."""
|
|
||||||
data = b""
|
|
||||||
while len(data) < n:
|
|
||||||
chunk = os.read(fd, n - len(data))
|
|
||||||
if not chunk:
|
|
||||||
return None
|
|
||||||
data += chunk
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def _pipe_write_all(fd: int, data: bytes) -> None:
|
|
||||||
"""Write all bytes to a file descriptor."""
|
|
||||||
view = memoryview(data)
|
|
||||||
while view:
|
|
||||||
written = os.write(fd, view)
|
|
||||||
view = view[written:]
|
|
||||||
|
|
||||||
|
|
||||||
PREFILL_TIMEOUT_SECONDS = 60
|
PREFILL_TIMEOUT_SECONDS = 60
|
||||||
DECODE_TIMEOUT_SECONDS = 5
|
DECODE_TIMEOUT_SECONDS = 5
|
||||||
|
|
||||||
@@ -74,19 +48,10 @@ class RunnerSupervisor:
|
|||||||
_task_sender: MpSender[Task]
|
_task_sender: MpSender[Task]
|
||||||
_event_sender: Sender[Event]
|
_event_sender: Sender[Event]
|
||||||
_cancel_sender: MpSender[TaskId]
|
_cancel_sender: MpSender[TaskId]
|
||||||
_pipe_read_fd: int | None = None # Python reads runner's pipe output
|
|
||||||
_pipe_write_fd: int | None = None # Python writes gathered data to runner
|
|
||||||
_child_pipe_fds: tuple[int, int] | None = None # fds to close after fork
|
|
||||||
_fifo_dir: str | None = None # Temp dir for FIFO files (for cleanup)
|
|
||||||
_fifo_c2p: str | None = None # FIFO path: C++ writes → Python reads
|
|
||||||
_fifo_p2c: str | None = None # FIFO path: Python writes → C++ reads
|
|
||||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||||
_gathered_waiters: dict[
|
|
||||||
int, tuple[anyio.Event, JacclSideChannelGathered | None]
|
|
||||||
] = field(default_factory=dict, init=False)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
@@ -100,23 +65,6 @@ class RunnerSupervisor:
|
|||||||
task_sender, task_recv = mp_channel[Task]()
|
task_sender, task_recv = mp_channel[Task]()
|
||||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||||
|
|
||||||
# For MlxJaccl instances, create named pipes (FIFOs) for SideChannel relay.
|
|
||||||
# Named pipes work across multiprocessing.Process spawn (macOS default).
|
|
||||||
# FIFO c2p: C++ writes local data → Python reads it
|
|
||||||
# FIFO p2c: Python writes gathered data → C++ reads it
|
|
||||||
fifo_dir: str | None = None
|
|
||||||
fifo_c2p: str | None = None
|
|
||||||
fifo_p2c: str | None = None
|
|
||||||
pipe_fifo_paths: tuple[str, str] | None = None
|
|
||||||
|
|
||||||
if isinstance(bound_instance.instance, MlxJacclInstance):
|
|
||||||
fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_")
|
|
||||||
fifo_c2p = os.path.join(fifo_dir, "c2p") # C++ → Python
|
|
||||||
fifo_p2c = os.path.join(fifo_dir, "p2c") # Python → C++
|
|
||||||
os.mkfifo(fifo_c2p)
|
|
||||||
os.mkfifo(fifo_p2c)
|
|
||||||
pipe_fifo_paths = (fifo_c2p, fifo_p2c)
|
|
||||||
|
|
||||||
runner_process = Process(
|
runner_process = Process(
|
||||||
target=entrypoint,
|
target=entrypoint,
|
||||||
args=(
|
args=(
|
||||||
@@ -125,7 +73,6 @@ class RunnerSupervisor:
|
|||||||
task_recv,
|
task_recv,
|
||||||
cancel_recv,
|
cancel_recv,
|
||||||
logger,
|
logger,
|
||||||
pipe_fifo_paths,
|
|
||||||
),
|
),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
@@ -141,58 +88,27 @@ class RunnerSupervisor:
|
|||||||
_task_sender=task_sender,
|
_task_sender=task_sender,
|
||||||
_cancel_sender=cancel_sender,
|
_cancel_sender=cancel_sender,
|
||||||
_event_sender=event_sender,
|
_event_sender=event_sender,
|
||||||
_fifo_dir=fifo_dir,
|
|
||||||
_fifo_c2p=fifo_c2p,
|
|
||||||
_fifo_p2c=fifo_p2c,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
self.runner_process.start()
|
self.runner_process.start()
|
||||||
|
await self._forward_events()
|
||||||
if self._fifo_c2p is not None and self._fifo_p2c is not None:
|
|
||||||
# Open FIFOs from parent side. These block until child opens the other end,
|
|
||||||
# so we run them in threads concurrently to avoid deadlock.
|
|
||||||
fifo_c2p = self._fifo_c2p
|
|
||||||
fifo_p2c = self._fifo_p2c
|
|
||||||
|
|
||||||
async def open_read() -> None:
|
|
||||||
self._pipe_read_fd = await to_thread.run_sync(
|
|
||||||
partial(os.open, fifo_c2p, os.O_RDONLY)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def open_write() -> None:
|
|
||||||
self._pipe_write_fd = await to_thread.run_sync(
|
|
||||||
partial(os.open, fifo_p2c, os.O_WRONLY)
|
|
||||||
)
|
|
||||||
|
|
||||||
async with anyio.create_task_group() as open_tg:
|
|
||||||
open_tg.start_soon(open_read)
|
|
||||||
open_tg.start_soon(open_write)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"JACCL pipe relay: FIFOs opened (read_fd={self._pipe_read_fd}, write_fd={self._pipe_write_fd})"
|
|
||||||
)
|
|
||||||
|
|
||||||
async with anyio.create_task_group() as tg:
|
|
||||||
tg.start_soon(self._pipe_relay)
|
|
||||||
tg.start_soon(self._forward_events)
|
|
||||||
else:
|
|
||||||
await self._forward_events()
|
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
logger.info("Runner supervisor shutting down")
|
logger.info("Runner supervisor shutting down")
|
||||||
self._ev_recv.close()
|
with contextlib.suppress(ClosedResourceError):
|
||||||
self._task_sender.close()
|
self._ev_recv.close()
|
||||||
try:
|
with contextlib.suppress(ClosedResourceError):
|
||||||
|
self._task_sender.close()
|
||||||
|
with contextlib.suppress(ClosedResourceError):
|
||||||
|
self._event_sender.close()
|
||||||
|
with contextlib.suppress(ClosedResourceError):
|
||||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||||
|
with contextlib.suppress(ClosedResourceError):
|
||||||
self._cancel_sender.close()
|
self._cancel_sender.close()
|
||||||
except ClosedResourceError:
|
self.runner_process.join(5)
|
||||||
pass
|
|
||||||
self._event_sender.close()
|
|
||||||
self._close_pipe_fds()
|
|
||||||
self.runner_process.join(1)
|
|
||||||
if not self.runner_process.is_alive():
|
if not self.runner_process.is_alive():
|
||||||
logger.info("Runner process succesfully terminated")
|
logger.info("Runner process succesfully terminated")
|
||||||
return
|
return
|
||||||
@@ -229,7 +145,6 @@ class RunnerSupervisor:
|
|||||||
await event.wait()
|
await event.wait()
|
||||||
|
|
||||||
async def cancel_task(self, task_id: TaskId):
|
async def cancel_task(self, task_id: TaskId):
|
||||||
"""Send a cancellation signal to the runner process."""
|
|
||||||
if task_id in self.completed:
|
if task_id in self.completed:
|
||||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||||
return
|
return
|
||||||
@@ -271,110 +186,6 @@ class RunnerSupervisor:
|
|||||||
for tid in self.pending:
|
for tid in self.pending:
|
||||||
self.pending[tid].set()
|
self.pending[tid].set()
|
||||||
|
|
||||||
def _close_pipe_fds(self) -> None:
|
|
||||||
if self._pipe_read_fd is not None:
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
os.close(self._pipe_read_fd)
|
|
||||||
self._pipe_read_fd = None
|
|
||||||
if self._pipe_write_fd is not None:
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
os.close(self._pipe_write_fd)
|
|
||||||
self._pipe_write_fd = None
|
|
||||||
if self._child_pipe_fds is not None:
|
|
||||||
for fd in self._child_pipe_fds:
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
os.close(fd)
|
|
||||||
self._child_pipe_fds = None
|
|
||||||
# Clean up FIFO files
|
|
||||||
if self._fifo_c2p is not None:
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
os.unlink(self._fifo_c2p)
|
|
||||||
self._fifo_c2p = None
|
|
||||||
if self._fifo_p2c is not None:
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
os.unlink(self._fifo_p2c)
|
|
||||||
self._fifo_p2c = None
|
|
||||||
if self._fifo_dir is not None:
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
os.rmdir(self._fifo_dir)
|
|
||||||
self._fifo_dir = None
|
|
||||||
|
|
||||||
async def _pipe_relay(self) -> None:
|
|
||||||
"""Relay JACCL SideChannel all_gather rounds between runner pipes and exo events."""
|
|
||||||
assert self._pipe_read_fd is not None
|
|
||||||
assert self._pipe_write_fd is not None
|
|
||||||
read_fd = self._pipe_read_fd
|
|
||||||
write_fd = self._pipe_write_fd
|
|
||||||
sequence = 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
# 1. Read local data from runner: [uint32 size][size bytes]
|
|
||||||
header = await to_thread.run_sync(partial(_pipe_read_exact, read_fd, 4))
|
|
||||||
if header is None:
|
|
||||||
logger.info("JACCL pipe relay: runner closed pipe (EOF)")
|
|
||||||
break
|
|
||||||
data_size: int = struct.unpack("<I", header)[0] # pyright: ignore[reportAny]
|
|
||||||
local_data = await to_thread.run_sync(
|
|
||||||
partial(_pipe_read_exact, read_fd, data_size)
|
|
||||||
)
|
|
||||||
if local_data is None:
|
|
||||||
logger.warning("JACCL pipe relay: EOF reading data payload")
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"JACCL pipe relay: read {data_size} bytes from runner, seq={sequence}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Emit JacclSideChannelData event
|
|
||||||
waiter = anyio.Event()
|
|
||||||
self._gathered_waiters[sequence] = (waiter, None)
|
|
||||||
await self._event_sender.send(
|
|
||||||
JacclSideChannelData(
|
|
||||||
instance_id=self.bound_instance.instance.instance_id,
|
|
||||||
runner_id=self.bound_instance.bound_runner_id,
|
|
||||||
sequence=sequence,
|
|
||||||
data=local_data,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Wait for gathered result
|
|
||||||
await waiter.wait()
|
|
||||||
_, gathered_event = self._gathered_waiters.pop(sequence)
|
|
||||||
assert gathered_event is not None
|
|
||||||
|
|
||||||
# 4. Order gathered data by runner rank and concatenate
|
|
||||||
instance = self.bound_instance.instance
|
|
||||||
assert isinstance(instance, MlxJacclInstance)
|
|
||||||
runner_order = list(instance.shard_assignments.runner_to_shard.keys())
|
|
||||||
ordered_data = b"".join(
|
|
||||||
gathered_event.gathered_data[rid] for rid in runner_order
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. Write gathered data to runner: [uint32 total_size][total_size bytes]
|
|
||||||
total_size = len(ordered_data)
|
|
||||||
response = struct.pack("<I", total_size) + ordered_data
|
|
||||||
await to_thread.run_sync(partial(_pipe_write_all, write_fd, response))
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"JACCL pipe relay: wrote {total_size} bytes to runner, seq={sequence}"
|
|
||||||
)
|
|
||||||
sequence += 1
|
|
||||||
except OSError as e:
|
|
||||||
logger.warning(f"JACCL pipe relay: OS error: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(exception=e).error("JACCL pipe relay: unexpected error")
|
|
||||||
|
|
||||||
def notify_gathered(self, event: JacclSideChannelGathered) -> None:
|
|
||||||
"""Called by the worker when a JacclSideChannelGathered event arrives."""
|
|
||||||
seq = event.sequence
|
|
||||||
if seq not in self._gathered_waiters:
|
|
||||||
logger.warning(f"JACCL: received gathered event for unknown sequence {seq}")
|
|
||||||
return
|
|
||||||
waiter, _ = self._gathered_waiters[seq]
|
|
||||||
self._gathered_waiters[seq] = (waiter, event)
|
|
||||||
waiter.set()
|
|
||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
if self.runner_process.is_alive():
|
if self.runner_process.is_alive():
|
||||||
logger.warning("RunnerSupervisor was not stopped cleanly.")
|
logger.warning("RunnerSupervisor was not stopped cleanly.")
|
||||||
|
|||||||
Reference in New Issue
Block a user