Compare commits

..

1 Commit

Author SHA1 Message Date
Evan
17233f48ce check if we have a task before we delete it 2026-02-27 19:17:25 +00:00
10 changed files with 89 additions and 121 deletions

View File

@@ -25,7 +25,6 @@ from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.task_group import TaskGroup
from exo.worker.main import Worker
from exo.worker.runner.runner_opts import RunnerOpts
@dataclass
@@ -41,11 +40,10 @@ class Node:
node_id: NodeId
offline: bool
runner_opts: RunnerOpts
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
@staticmethod
async def create(args: "Args") -> "Node":
@classmethod
async def create(cls, args: "Args") -> Self:
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
@@ -65,28 +63,14 @@ class Node:
logger.info(f"Starting node {node_id}")
if args.fast_synch is True:
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
logger.info("FAST_SYNCH forced OFF")
runner_opts = RunnerOpts(
fast_synch_override=args.fast_synch,
trust_remote_code_override=args.trust_remote_code,
)
if offline := args.offline:
logger.info(
"Running in OFFLINE mode — no internet checks, local models only"
)
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
exo_shard_downloader(offline=offline),
exo_shard_downloader(offline=args.offline),
event_sender=event_router.sender(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
offline=offline,
offline=args.offline,
)
else:
download_coordinator = None
@@ -106,7 +90,6 @@ class Node:
if not args.no_worker:
worker = Worker(
node_id,
runner_opts,
event_receiver=event_router.receiver(),
event_sender=event_router.sender(),
command_sender=router.sender(topics.COMMANDS),
@@ -140,7 +123,7 @@ class Node:
election_result_sender=er_send,
)
return Node(
return cls(
router,
event_router,
download_coordinator,
@@ -151,7 +134,6 @@ class Node:
api,
node_id,
args.offline,
runner_opts,
)
async def run(self):
@@ -256,7 +238,6 @@ class Node:
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
self.runner_opts,
event_receiver=self.event_router.receiver(),
event_sender=self.event_router.sender(),
command_sender=self.router.sender(topics.COMMANDS),
@@ -284,6 +265,17 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
try:
anyio.run(node.run)
@@ -305,11 +297,8 @@ class Args(CamelCaseModel):
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
offline: bool = False
offline: bool = os.getenv("EXO_OFFLINE", "false").lower() == "true"
fast_synch: bool | None = None # None = auto, True = force on, False = force off
trust_remote_code: bool | None = (
None # None = auto, True = force on, False = force off
)
@classmethod
def parse(cls) -> Self:
@@ -376,20 +365,6 @@ class Args(CamelCaseModel):
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
trust_remote_code_group = parser.add_mutually_exclusive_group()
trust_remote_code_group.add_argument(
"--trust-remote-code",
action="store_true",
dest="trust_remote_code",
default=None,
help="Allow all models to execute custom code",
)
trust_remote_code_group.add_argument(
"--never-trust-remote-code",
action="store_false",
dest="trust_remote_code",
help="Deny all models from execute custom code",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -328,17 +328,22 @@ class Master:
task_id=task_id,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
command.finished_command_id
]
else:
logger.warning(
f"Nonexistent command {command.cancelled_command_id} cancelled"
)
)
self.command_task_mapping.pop(
command.finished_command_id, None
)
case TaskFinished():
if (
task_id := self.command_task_mapping.pop(
command.finished_command_id, None
)
) is not None:
generated_events.append(TaskDeleted(task_id=task_id))
else:
logger.warning(
f"Finished command {command.finished_command_id} finished"
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
# rate limit to 1000 at a time

View File

@@ -167,12 +167,10 @@ def load_mlx_items(
group: Group | None,
on_timeout: TimeoutCallback | None,
on_layer_loaded: LayerLoadedCallback | None,
trust_remote_code: bool | None,
) -> tuple[Model, TokenizerWrapper]:
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, lazy=True, strict=False)
# Eval layers one by one for progress reporting
@@ -191,10 +189,12 @@ def load_mlx_items(
mx.eval(model)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model = shard_and_load(
model, tokenizer = shard_and_load(
bound_instance.bound_shard,
group=group,
on_timeout=on_timeout,
@@ -205,14 +205,6 @@ def load_mlx_items(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
tokenizer = load_tokenizer_for_model_id(
bound_instance.bound_shard.model_card.model_id,
model_path,
trust_remote_code=trust_remote_code
if trust_remote_code is not None
else bound_instance.bound_shard.model_card.trust_remote_code,
)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
mx.clear_cache()
@@ -225,8 +217,9 @@ def shard_and_load(
group: Group,
on_timeout: TimeoutCallback | None,
on_layer_loaded: LayerLoadedCallback | None,
) -> nn.Module:
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_card.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
@@ -248,6 +241,8 @@ def shard_and_load(
assert isinstance(model, nn.Module)
tokenizer = get_tokenizer(model_path, shard_metadata)
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
# Estimate timeout based on model size (5x default for large queued workloads)
@@ -286,7 +281,16 @@ def shard_and_load(
# Synchronize processes before generation to avoid timeout
mx_barrier(group)
return model
return model, tokenizer
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(
shard_metadata.model_card.model_id,
model_path,
trust_remote_code=shard_metadata.model_card.trust_remote_code,
)
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:

View File

@@ -1,5 +1,4 @@
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timezone
import anyio
@@ -47,34 +46,38 @@ from exo.utils.info_gatherer.net_profile import check_reachable
from exo.utils.keyed_backoff import KeyedBackoff
from exo.utils.task_group import TaskGroup
from exo.worker.plan import plan
from exo.worker.runner.runner_opts import RunnerOpts
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@dataclass
class Worker:
node_id: NodeId
runner_opts: RunnerOpts
event_receiver: Receiver[IndexedEvent]
event_sender: Sender[Event]
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand]
download_command_sender: Sender[ForwarderDownloadCommand]
state: State = field(init=False, default_factory=State)
runners: dict[RunnerId, RunnerSupervisor] = field(init=False, default_factory=dict)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
_system_id: SystemId = field(init=False, default_factory=SystemId)
def __init__(
self,
node_id: NodeId,
*,
event_receiver: Receiver[IndexedEvent],
event_sender: Sender[Event],
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.node_id: NodeId = node_id
self.event_receiver = event_receiver
self.event_sender = event_sender
self.command_sender = command_sender
self.download_command_sender = download_command_sender
# Buffer for input image chunks (for image editing)
input_chunk_buffer: dict[CommandId, dict[int, str]] = field(
init=False, default_factory=dict
)
input_chunk_counts: dict[CommandId, int] = field(init=False, default_factory=dict)
self.state: State = State()
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = TaskGroup()
_download_backoff: KeyedBackoff[ModelId] = field(
init=False, default_factory=lambda: KeyedBackoff(base=0.5, cap=10.0)
)
self._system_id = SystemId()
# Buffer for input image chunks (for image editing)
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
async def run(self):
logger.info("Starting Worker")
@@ -280,7 +283,6 @@ class Worker:
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(
runner_opts=self.runner_opts,
bound_instance=task.bound_instance,
event_sender=self.event_sender.clone(),
)

View File

@@ -1,5 +1,4 @@
import os
import resource
import loguru
@@ -9,13 +8,10 @@ from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from .runner_opts import RunnerOpts
logger: "loguru.Logger" = loguru.logger
def entrypoint(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
@@ -24,17 +20,12 @@ def entrypoint(
) -> None:
global logger
logger = _logger
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
fast_synch_override = runner_opts.fast_synch_override
if fast_synch_override is not None:
if fast_synch_override:
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
else:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override != "off":
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
@@ -45,7 +36,7 @@ def entrypoint(
else:
from exo.worker.runner.llm_inference.runner import main
main(runner_opts, bound_instance, event_sender, task_receiver, cancel_receiver)
main(bound_instance, event_sender, task_receiver, cancel_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")

View File

@@ -1,4 +1,5 @@
import base64
import resource
import time
from typing import TYPE_CHECKING, Literal
@@ -65,7 +66,6 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
)
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.runner_opts import RunnerOpts
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
@@ -183,12 +183,14 @@ def _send_image_chunk(
def main(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
instance, runner_id, shard_metadata = (
bound_instance.instance,
bound_instance.bound_runner_id,

View File

@@ -1,4 +1,5 @@
import math
import resource
import time
from collections.abc import Generator
from functools import cache
@@ -78,18 +79,19 @@ from exo.worker.engines.mlx.utils_mlx import (
mx_any,
)
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.runner_opts import RunnerOpts
from .tool_parsers import ToolParser, make_mlx_parser
def main(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
instance, runner_id, shard_metadata = (
bound_instance.instance,
bound_instance.bound_runner_id,
@@ -192,7 +194,6 @@ def main(
group,
on_timeout=on_model_load_timeout,
on_layer_loaded=on_layer_loaded,
trust_remote_code=runner_opts.trust_remote_code_override,
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"

View File

@@ -1,7 +0,0 @@
from dataclasses import dataclass
@dataclass
class RunnerOpts:
fast_synch_override: bool | None
trust_remote_code_override: bool | None

View File

@@ -34,7 +34,6 @@ from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.utils.task_group import TaskGroup
from exo.worker.runner.bootstrap import entrypoint
from exo.worker.runner.runner_opts import RunnerOpts
PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5
@@ -63,7 +62,6 @@ class RunnerSupervisor:
def create(
cls,
*,
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: Sender[Event],
initialize_timeout: float = 400,
@@ -75,7 +73,6 @@ class RunnerSupervisor:
runner_process = mp.Process(
target=entrypoint,
args=(
runner_opts,
bound_instance,
ev_send,
task_recv,

View File

@@ -40,7 +40,6 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from exo.worker.runner.runner_opts import RunnerOpts
from ...constants import (
CHAT_COMPLETION_TASK_ID,
@@ -185,7 +184,6 @@ def _run(tasks: Iterable[Task]):
make_nothin(mx.array([1])),
):
mlx_runner.main(
RunnerOpts(None, None),
bound_instance,
event_sender, # pyright: ignore[reportArgumentType]
task_receiver,