Compare commits

...

3 Commits

Author SHA1 Message Date
Alex Cheema
cd497c3696 Deduplicate tasks in plan_step 2025-12-30 17:28:19 +00:00
Alex Cheema
16e2bfd3b3 log EXO_LIBP2P_NAMESPACE on start 2025-12-30 04:08:47 +00:00
Alex Cheema
ade3ee7ec5 fix warmup order. should be rank!=0 then rank=0 2025-12-30 03:29:34 +00:00
3 changed files with 12 additions and 5 deletions

View File

@@ -1,5 +1,6 @@
import argparse
import multiprocessing as mp
import os
import signal
from dataclasses import dataclass, field
from typing import Self
@@ -194,6 +195,7 @@ def main():
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
node = anyio.run(Node.create, args)
anyio.run(node.run)

View File

@@ -10,6 +10,7 @@ from exo.routing.connection_message import ConnectionMessage, ConnectionMessageT
from exo.shared.apply import apply
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.tasks import TaskId
from exo.shared.types.events import (
Event,
EventId,
@@ -172,6 +173,7 @@ class Worker:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
async def plan_step(self):
seen_tasks: set[TaskId] = set()
while True:
await anyio.sleep(0.1)
# 3. based on the updated state, we plan & execute an operation.
@@ -186,6 +188,10 @@ class Worker:
)
if task is None:
continue
if task.task_id in seen_tasks:
logger.warning("Worker tried to plan a duplicate task")
continue
seen_tasks.add(task.task_id)
logger.info(f"Worker plan: {task.__class__.__name__}")
assert task.task_status
await self.event_sender.send(TaskCreated(task_id=task.task_id, task=task))

View File

@@ -235,9 +235,8 @@ def _ready_to_warmup(
assert device_rank < world_size
assert device_rank >= 0
# TODO: Ensure these align with MLX distributed's expectations.
# Rank < n-1
accepting_ranks_ready = device_rank < world_size - 1 and all(
# Rank != 0
accepting_ranks_ready = device_rank > 0 and all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerLoaded, RunnerWarmingUp),
@@ -245,8 +244,8 @@ def _ready_to_warmup(
for global_runner_id in shard_assignments.runner_to_shard
)
# Rank = n-1
connecting_rank_ready = device_rank == world_size - 1 and all(
# Rank = 0
connecting_rank_ready = device_rank == 0 and all(
isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp)
for global_runner_id in shard_assignments.runner_to_shard
if global_runner_id != runner_id