Compare commits

...

2 Commits

Author SHA1 Message Date
Alex Cheema
e5c31e50f3 feat: add JACCL SideChannel pipe relay for distributed tensor ops
Implement a named-pipe (FIFO) based relay for JACCL all_gather operations
across the exo control plane, enabling distributed tensor operations
between MlxJaccl runner instances.

Components:
- Base64Bytes type + JacclSideChannelData/Gathered event types
- RunnerSupervisor: FIFO creation, _pipe_relay() async loop that reads
  local data from runner, emits events, waits for gathered result, and
  writes ordered data back
- Bootstrap: opens FIFOs in child process, sets MLX_JACCL_PIPE_IN/OUT
  env vars for C++ SideChannel
- Worker: routes JacclSideChannelGathered events to RunnerSupervisors
- Master: _handle_jaccl_side_channel() accumulates per-runner data and
  emits gathered event when all runners for an instance have contributed
- mx_any() docstring explaining all_sum for GPU deadlock prevention

Extracted from meta-instance branch (#1519) — PR 4 of 5.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 06:05:03 -08:00
Alex Cheema
aa3f106fb9 fix: import ResponsesStreamEvent and DRY up SSE formatting (#1499)
## Summary
- `ResponsesStreamEvent` was defined in `openai_responses.py` as a union
of all 11 streaming event types but never imported or used anywhere in
the codebase
- Import it in the responses adapter and add a `_format_sse(event:
ResponsesStreamEvent) -> str` helper
- Replace 13 hardcoded `f"event: {type}\ndata:
{event.model_dump_json()}\n\n"` strings with `_format_sse()` calls
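
For reference, a minimal sketch of the helper and the SSE frame it emits (`DemoEvent` here is a hypothetical stand-in for one of the 11 real event types in `openai_responses.py`):

```python
from pydantic import BaseModel

class DemoEvent(BaseModel):
    # Hypothetical stand-in for a ResponsesStreamEvent union member.
    type: str = "response.created"
    sequence_number: int = 0

def _format_sse(event: DemoEvent) -> str:
    """Format a streaming event as an SSE message (as in the diff below)."""
    return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"

# Emits:
#   event: response.created
#   data: {"type":"response.created","sequence_number":0}
print(_format_sse(DemoEvent()), end="")
```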

## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed
- [x] `nix fmt` — 0 files changed
- [x] `uv run pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 13:40:24 +00:00
8 changed files with 343 additions and 20 deletions

View File

@@ -31,6 +31,7 @@ from exo.shared.types.openai_responses import (
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponsesStreamEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
@@ -38,6 +39,11 @@ from exo.shared.types.openai_responses import (
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _format_sse(event: ResponsesStreamEvent) -> str:
"""Format a streaming event as an SSE message."""
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
def _extract_content(content: str | list[ResponseContentPart]) -> str:
"""Extract plain text from a content field that may be a string or list of parts."""
if isinstance(content, str):
@@ -219,13 +225,13 @@ async def generate_responses_stream(
created_event = ResponseCreatedEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
yield _format_sse(created_event)
# response.in_progress
in_progress_event = ResponseInProgressEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
yield _format_sse(in_progress_event)
# response.output_item.added
initial_item = ResponseMessageItem(
@@ -236,7 +242,7 @@ async def generate_responses_stream(
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq), output_index=0, item=initial_item
)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
yield _format_sse(item_added)
# response.content_part.added
initial_part = ResponseOutputText(text="")
@@ -247,7 +253,7 @@ async def generate_responses_stream(
content_index=0,
part=initial_part,
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
yield _format_sse(part_added)
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
@@ -281,7 +287,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_item,
)
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
yield _format_sse(fc_added)
# response.function_call_arguments.delta
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -290,7 +296,7 @@ async def generate_responses_stream(
output_index=next_output_index,
delta=tool.arguments,
)
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
yield _format_sse(args_delta)
# response.function_call_arguments.done
args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -300,7 +306,7 @@ async def generate_responses_stream(
name=tool.name,
arguments=tool.arguments,
)
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
yield _format_sse(args_done)
# response.output_item.done
fc_done_item = ResponseFunctionCallItem(
@@ -315,7 +321,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_done_item,
)
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
yield _format_sse(fc_item_done)
function_call_items.append(fc_done_item)
next_output_index += 1
@@ -331,7 +337,7 @@ async def generate_responses_stream(
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
yield _format_sse(delta_event)
# response.output_text.done
text_done = ResponseTextDoneEvent(
@@ -341,7 +347,7 @@ async def generate_responses_stream(
content_index=0,
text=accumulated_text,
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
yield _format_sse(text_done)
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
@@ -352,7 +358,7 @@ async def generate_responses_stream(
content_index=0,
part=final_part,
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
yield _format_sse(part_done)
# response.output_item.done
final_message_item = ResponseMessageItem(
@@ -363,7 +369,7 @@ async def generate_responses_stream(
item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
yield _format_sse(item_done)
# Create usage from usage data if available
usage = None
@@ -388,4 +394,4 @@ async def generate_responses_stream(
completed_event = ResponseCompletedEvent(
sequence_number=next(seq), response=final_response
)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
yield _format_sse(completed_event)

View File

@@ -36,6 +36,8 @@ from exo.shared.types.events import (
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
JacclSideChannelData,
JacclSideChannelGathered,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
@@ -60,6 +62,7 @@ from exo.shared.types.tasks import (
TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer
@@ -94,6 +97,7 @@ class Master:
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {}
async def run(self):
logger.info("Starting Master")
@@ -407,6 +411,11 @@ class Master:
self._event_log.append(event)
await self._send_event(indexed)
# After broadcasting JacclSideChannelData, accumulate and
# emit gathered result when all runners have contributed.
if isinstance(event, JacclSideChannelData):
await self._handle_jaccl_side_channel(event)
async def _loopback_processor(self) -> None:
# this would ideally not be necessary.
# this is WAY less hacky than how I was working around this before
@@ -460,3 +469,42 @@ class Master:
del self._pending_traces[task_id]
if task_id in self._expected_ranks:
del self._expected_ranks[task_id]
async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None:
"""Accumulate SideChannel contributions; when all runners for an instance
have submitted for the same sequence, emit JacclSideChannelGathered."""
iid = event.instance_id
seq = event.sequence
if iid not in self._jaccl_pending:
self._jaccl_pending[iid] = {}
if seq not in self._jaccl_pending[iid]:
self._jaccl_pending[iid][seq] = {}
self._jaccl_pending[iid][seq][event.runner_id] = event.data
instance = self.state.instances.get(iid)
if instance is None:
logger.warning(f"JacclSideChannelData for unknown instance {iid}")
return
expected_runners = set(instance.shard_assignments.runner_to_shard.keys())
submitted = set(self._jaccl_pending[iid][seq].keys())
logger.info(
f"JACCL side channel: instance={iid} seq={seq} "
f"submitted={len(submitted)}/{len(expected_runners)}"
)
if submitted >= expected_runners:
gathered = dict(self._jaccl_pending[iid][seq])
del self._jaccl_pending[iid][seq]
if not self._jaccl_pending[iid]:
del self._jaccl_pending[iid]
await self.event_sender.send(
JacclSideChannelGathered(
instance_id=iid,
sequence=seq,
gathered_data=gathered,
)
)
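
The accumulation above is plain nested-dict bookkeeping; a self-contained sketch of the same flow, with strings standing in for `InstanceId`/`RunnerId` (the IDs and the two-runner setup are made up for illustration):

```python
# Per-(instance, sequence) accumulation, mirroring _handle_jaccl_side_channel.
pending: dict[str, dict[int, dict[str, bytes]]] = {}
expected = {"runner-a", "runner-b"}  # stands in for shard_assignments.runner_to_shard keys

def submit(iid: str, seq: int, runner: str, data: bytes) -> dict[str, bytes] | None:
    """Record one contribution; return the gathered map once every runner is in."""
    pending.setdefault(iid, {}).setdefault(seq, {})[runner] = data
    if set(pending[iid][seq]) >= expected:
        gathered = pending[iid].pop(seq)
        if not pending[iid]:
            del pending[iid]  # drop the empty per-instance entry, as the diff does
        return gathered
    return None

assert submit("inst-0", 0, "runner-a", b"A") is None  # still waiting on runner-b
assert submit("inst-0", 0, "runner-b", b"B") == {"runner-a": b"A", "runner-b": b"B"}
assert pending == {}  # fully cleaned up after the round completes
```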

View File

@@ -12,6 +12,8 @@ from exo.shared.types.events import (
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
JacclSideChannelData,
JacclSideChannelGathered,
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
@@ -68,6 +70,8 @@ def event_apply(event: Event, state: State) -> State:
| PrefillProgress()
| TracesCollected()
| TracesMerged()
| JacclSideChannelData()
| JacclSideChannelGathered()
): # Pass-through events that don't modify state
return state
case InstanceCreated():

View File

@@ -1,7 +1,9 @@
import base64
from collections.abc import Mapping
from datetime import datetime
- from typing import final
+ from typing import Annotated, final
- from pydantic import Field
+ from pydantic import BeforeValidator, Field, PlainSerializer
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
@@ -14,6 +16,28 @@ from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
def _decode_base64_bytes(v: bytes | str) -> bytes:
if isinstance(v, bytes):
return v
return base64.b64decode(v)
def _encode_base64_bytes(v: bytes) -> str:
return base64.b64encode(v).decode("ascii")
Base64Bytes = Annotated[
bytes,
BeforeValidator(_decode_base64_bytes),
PlainSerializer(_encode_base64_bytes, return_type=str),
]
"""bytes that serialize to/from base64 strings in JSON.
Needed because TaggedModel's wrap validator converts JSON→Python validation
context, which breaks strict-mode bytes deserialization from JSON strings.
"""
class EventId(Id):
"""
Newtype around `ID`
@@ -139,6 +163,25 @@ class TracesMerged(BaseEvent):
traces: list[TraceEventData]
@final
class JacclSideChannelData(BaseEvent):
"""A runner's local contribution to a JACCL SideChannel all_gather round."""
instance_id: InstanceId
runner_id: RunnerId
sequence: int
data: Base64Bytes
@final
class JacclSideChannelGathered(BaseEvent):
"""Gathered result of a JACCL SideChannel all_gather round."""
instance_id: InstanceId
sequence: int
gathered_data: Mapping[RunnerId, Base64Bytes]
Event = (
TestEvent
| TaskCreated
@@ -160,6 +203,8 @@ Event = (
| TopologyEdgeDeleted
| TracesCollected
| TracesMerged
| JacclSideChannelData
| JacclSideChannelGathered
)
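
The `Base64Bytes` alias above round-trips transparently through JSON; a runnable sketch (the `Payload` model is hypothetical, the alias itself is copied from the diff):

```python
import base64
from typing import Annotated

from pydantic import BaseModel, BeforeValidator, PlainSerializer

def _decode_base64_bytes(v: bytes | str) -> bytes:
    return v if isinstance(v, bytes) else base64.b64decode(v)

Base64Bytes = Annotated[
    bytes,
    BeforeValidator(_decode_base64_bytes),
    PlainSerializer(lambda v: base64.b64encode(v).decode("ascii"), return_type=str),
]

class Payload(BaseModel):
    data: Base64Bytes

p = Payload.model_validate_json('{"data": "aGVsbG8="}')
assert p.data == b"hello"                            # JSON string -> raw bytes
assert p.model_dump_json() == '{"data":"aGVsbG8="}'  # raw bytes -> base64 string
```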

View File

@@ -643,6 +643,11 @@ def mlx_cleanup(
def mx_any(bool_: bool, group: Group | None) -> bool:
"""Synchronize a boolean across all distributed nodes.
Returns True if any node has bool_=True. Uses all_sum so every
node participates in the collective — preventing GPU deadlocks.
"""
if group is None:
return bool_
num_true = mx.distributed.all_sum(
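
The hunk above is cut off mid-call; a plausible completion of `mx_any`, assuming the usual 0/1 `all_sum` reduction (the exact tail of the function is not shown in this diff, so treat this as a sketch):

```python
import mlx.core as mx
from mlx.core.distributed import Group

def mx_any(bool_: bool, group: Group | None) -> bool:
    """Synchronize a boolean across all distributed nodes.

    Returns True if any node has bool_=True. Uses all_sum so every
    node participates in the collective, preventing GPU deadlocks.
    """
    if group is None:
        return bool_
    # Hypothetical completion: every node contributes 0 or 1, the sum is
    # reduced across the group, and any nonzero total means "any".
    num_true = mx.distributed.all_sum(mx.array(1 if bool_ else 0), group=group)
    return num_true.item() > 0
```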

View File

@@ -24,6 +24,7 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
JacclSideChannelGathered,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
@@ -159,6 +160,15 @@ class Worker:
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
# Dispatch JACCL gathered events to the relevant RunnerSupervisor
if isinstance(event, JacclSideChannelGathered):
for runner in self.runners.values():
if (
runner.bound_instance.instance.instance_id
== event.instance_id
):
runner.notify_gathered(event)
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
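
The dispatch above hands the event to `RunnerSupervisor.notify_gathered` (last file), which completes a sequence-keyed rendezvous with the relay loop; a minimal anyio sketch of that pattern (the function names and bytes payload are illustrative, not the real API):

```python
import anyio

# sequence -> (event, result-or-None), as in RunnerSupervisor._gathered_waiters
waiters: dict[int, tuple[anyio.Event, bytes | None]] = {}

async def await_gathered(seq: int) -> bytes:
    waiter = anyio.Event()
    waiters[seq] = (waiter, None)
    await waiter.wait()  # parked until notify() fires for this sequence
    _, result = waiters.pop(seq)
    assert result is not None
    return result

def notify(seq: int, data: bytes) -> None:
    if seq not in waiters:
        return  # unknown sequence; the real code logs a warning
    waiter, _ = waiters[seq]
    waiters[seq] = (waiter, data)
    waiter.set()

async def main() -> None:
    async with anyio.create_task_group() as tg:
        async def relay_side() -> None:
            print(await await_gathered(0))  # b'gathered'
        tg.start_soon(relay_side)
        await anyio.sleep(0.01)  # let relay_side park first
        notify(0, b"gathered")

anyio.run(main)
```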

View File

@@ -17,6 +17,7 @@ def entrypoint(
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
pipe_fifo_paths: tuple[str, str] | None = None,
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
@@ -30,6 +31,16 @@ def entrypoint(
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
# Open JACCL FIFOs by path and set env vars for C++ SideChannel.
# Named pipes (FIFOs) work across multiprocessing spawn (macOS default).
if pipe_fifo_paths is not None:
fifo_c2p, fifo_p2c = pipe_fifo_paths
# C++ reads gathered data from p2c (PIPE_IN), writes local data to c2p (PIPE_OUT)
pipe_in_fd = os.open(fifo_p2c, os.O_RDONLY)
pipe_out_fd = os.open(fifo_c2p, os.O_WRONLY)
os.environ["MLX_JACCL_PIPE_IN"] = str(pipe_in_fd)
os.environ["MLX_JACCL_PIPE_OUT"] = str(pipe_out_fd)
global logger
logger = _logger
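
A note on the FIFO opens here: `os.open` on a FIFO blocks until the peer end is opened, which is why the supervisor (next file) opens its two ends concurrently in threads. A small demonstration of the blocking handshake (paths are throwaway):

```python
import os
import tempfile
import threading

d = tempfile.mkdtemp(prefix="fifo_demo_")
path = os.path.join(d, "c2p")
os.mkfifo(path)

def reader() -> None:
    fd = os.open(path, os.O_RDONLY)  # blocks here until a writer opens
    print(os.read(fd, 5))            # b'hello'
    os.close(fd)

t = threading.Thread(target=reader)
t.start()
wfd = os.open(path, os.O_WRONLY)     # unblocks the reader's open
os.write(wfd, b"hello")
os.close(wfd)
t.join()
os.unlink(path)
os.rmdir(d)
```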

View File

@@ -1,6 +1,10 @@
import contextlib
import os
import signal
import struct
import tempfile
from dataclasses import dataclass, field
from functools import partial
from multiprocessing import Process
from typing import Self
@@ -14,12 +18,14 @@ from loguru import logger
from exo.shared.types.events import (
Event,
JacclSideChannelData,
JacclSideChannelGathered,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import Task, TaskId, TaskStatus
- from exo.shared.types.worker.instances import BoundInstance
+ from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import (
RunnerConnecting,
RunnerFailed,
@@ -34,6 +40,26 @@ from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.worker.runner.bootstrap import entrypoint
def _pipe_read_exact(fd: int, n: int) -> bytes | None:
"""Read exactly n bytes from a file descriptor. Returns None on EOF."""
data = b""
while len(data) < n:
chunk = os.read(fd, n - len(data))
if not chunk:
return None
data += chunk
return data
def _pipe_write_all(fd: int, data: bytes) -> None:
"""Write all bytes to a file descriptor."""
view = memoryview(data)
while view:
written = os.write(fd, view)
view = view[written:]
PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5
@@ -48,10 +74,19 @@ class RunnerSupervisor:
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_pipe_read_fd: int | None = None # Python reads runner's pipe output
_pipe_write_fd: int | None = None # Python writes gathered data to runner
_child_pipe_fds: tuple[int, int] | None = None # fds to close after fork
_fifo_dir: str | None = None # Temp dir for FIFO files (for cleanup)
_fifo_c2p: str | None = None # FIFO path: C++ writes → Python reads
_fifo_p2c: str | None = None # FIFO path: Python writes → C++ reads
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
_gathered_waiters: dict[
int, tuple[anyio.Event, JacclSideChannelGathered | None]
] = field(default_factory=dict, init=False)
@classmethod
def create(
@@ -65,6 +100,23 @@ class RunnerSupervisor:
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
# For MlxJaccl instances, create named pipes (FIFOs) for SideChannel relay.
# Named pipes work across multiprocessing.Process spawn (macOS default).
# FIFO c2p: C++ writes local data → Python reads it
# FIFO p2c: Python writes gathered data → C++ reads it
fifo_dir: str | None = None
fifo_c2p: str | None = None
fifo_p2c: str | None = None
pipe_fifo_paths: tuple[str, str] | None = None
if isinstance(bound_instance.instance, MlxJacclInstance):
fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_")
fifo_c2p = os.path.join(fifo_dir, "c2p") # C++ → Python
fifo_p2c = os.path.join(fifo_dir, "p2c") # Python → C++
os.mkfifo(fifo_c2p)
os.mkfifo(fifo_p2c)
pipe_fifo_paths = (fifo_c2p, fifo_p2c)
runner_process = Process(
target=entrypoint,
args=(
@@ -73,6 +125,7 @@ class RunnerSupervisor:
task_recv,
cancel_recv,
logger,
pipe_fifo_paths,
),
daemon=True,
)
@@ -88,22 +141,58 @@ class RunnerSupervisor:
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
_fifo_dir=fifo_dir,
_fifo_c2p=fifo_c2p,
_fifo_p2c=fifo_p2c,
)
return self
async def run(self):
self.runner_process.start()
- await self._forward_events()
+ if self._fifo_c2p is not None and self._fifo_p2c is not None:
+ # Open FIFOs from the parent side. These block until the child opens the
+ # other end, so we run both opens concurrently in threads to avoid deadlock.
+ fifo_c2p = self._fifo_c2p
+ fifo_p2c = self._fifo_p2c
+ async def open_read() -> None:
+ self._pipe_read_fd = await to_thread.run_sync(
+ partial(os.open, fifo_c2p, os.O_RDONLY)
+ )
+ async def open_write() -> None:
+ self._pipe_write_fd = await to_thread.run_sync(
+ partial(os.open, fifo_p2c, os.O_WRONLY)
+ )
+ async with anyio.create_task_group() as open_tg:
+ open_tg.start_soon(open_read)
+ open_tg.start_soon(open_write)
+ logger.info(
+ f"JACCL pipe relay: FIFOs opened (read_fd={self._pipe_read_fd}, write_fd={self._pipe_write_fd})"
+ )
+ async with anyio.create_task_group() as tg:
+ tg.start_soon(self._pipe_relay)
+ tg.start_soon(self._forward_events)
+ else:
+ await self._forward_events()
def shutdown(self):
logger.info("Runner supervisor shutting down")
self._ev_recv.close()
self._task_sender.close()
+ try:
+ self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
+ self._cancel_sender.close()
+ except ClosedResourceError:
+ pass
self._event_sender.close()
- self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
- self._cancel_sender.close()
- self.runner_process.join(5)
+ self._close_pipe_fds()
+ self.runner_process.join(1)
if not self.runner_process.is_alive():
logger.info("Runner process successfully terminated")
return
@@ -140,6 +229,7 @@ class RunnerSupervisor:
await event.wait()
async def cancel_task(self, task_id: TaskId):
"""Send a cancellation signal to the runner process."""
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
@@ -181,6 +271,110 @@ class RunnerSupervisor:
for tid in self.pending:
self.pending[tid].set()
def _close_pipe_fds(self) -> None:
if self._pipe_read_fd is not None:
with contextlib.suppress(OSError):
os.close(self._pipe_read_fd)
self._pipe_read_fd = None
if self._pipe_write_fd is not None:
with contextlib.suppress(OSError):
os.close(self._pipe_write_fd)
self._pipe_write_fd = None
if self._child_pipe_fds is not None:
for fd in self._child_pipe_fds:
with contextlib.suppress(OSError):
os.close(fd)
self._child_pipe_fds = None
# Clean up FIFO files
if self._fifo_c2p is not None:
with contextlib.suppress(OSError):
os.unlink(self._fifo_c2p)
self._fifo_c2p = None
if self._fifo_p2c is not None:
with contextlib.suppress(OSError):
os.unlink(self._fifo_p2c)
self._fifo_p2c = None
if self._fifo_dir is not None:
with contextlib.suppress(OSError):
os.rmdir(self._fifo_dir)
self._fifo_dir = None
async def _pipe_relay(self) -> None:
"""Relay JACCL SideChannel all_gather rounds between runner pipes and exo events."""
assert self._pipe_read_fd is not None
assert self._pipe_write_fd is not None
read_fd = self._pipe_read_fd
write_fd = self._pipe_write_fd
sequence = 0
try:
while True:
# 1. Read local data from runner: [uint32 size][size bytes]
header = await to_thread.run_sync(partial(_pipe_read_exact, read_fd, 4))
if header is None:
logger.info("JACCL pipe relay: runner closed pipe (EOF)")
break
data_size: int = struct.unpack("<I", header)[0] # pyright: ignore[reportAny]
local_data = await to_thread.run_sync(
partial(_pipe_read_exact, read_fd, data_size)
)
if local_data is None:
logger.warning("JACCL pipe relay: EOF reading data payload")
break
logger.info(
f"JACCL pipe relay: read {data_size} bytes from runner, seq={sequence}"
)
# 2. Emit JacclSideChannelData event
waiter = anyio.Event()
self._gathered_waiters[sequence] = (waiter, None)
await self._event_sender.send(
JacclSideChannelData(
instance_id=self.bound_instance.instance.instance_id,
runner_id=self.bound_instance.bound_runner_id,
sequence=sequence,
data=local_data,
)
)
# 3. Wait for gathered result
await waiter.wait()
_, gathered_event = self._gathered_waiters.pop(sequence)
assert gathered_event is not None
# 4. Order gathered data by runner rank and concatenate
instance = self.bound_instance.instance
assert isinstance(instance, MlxJacclInstance)
runner_order = list(instance.shard_assignments.runner_to_shard.keys())
ordered_data = b"".join(
gathered_event.gathered_data[rid] for rid in runner_order
)
# 5. Write gathered data to runner: [uint32 total_size][total_size bytes]
total_size = len(ordered_data)
response = struct.pack("<I", total_size) + ordered_data
await to_thread.run_sync(partial(_pipe_write_all, write_fd, response))
logger.info(
f"JACCL pipe relay: wrote {total_size} bytes to runner, seq={sequence}"
)
sequence += 1
except OSError as e:
logger.warning(f"JACCL pipe relay: OS error: {e}")
except Exception as e:
logger.opt(exception=e).error("JACCL pipe relay: unexpected error")
def notify_gathered(self, event: JacclSideChannelGathered) -> None:
"""Called by the worker when a JacclSideChannelGathered event arrives."""
seq = event.sequence
if seq not in self._gathered_waiters:
logger.warning(f"JACCL: received gathered event for unknown sequence {seq}")
return
waiter, _ = self._gathered_waiters[seq]
self._gathered_waiters[seq] = (waiter, event)
waiter.set()
def __del__(self) -> None:
if self.runner_process.is_alive():
logger.warning("RunnerSupervisor was not stopped cleanly.")