Compare commits

...

1 Commit

Author  SHA1        Message      Date
Evan    b0da9dd56b  runner opts  2026-02-26 17:51:31 +00:00
9 changed files with 111 additions and 74 deletions

View File

@@ -25,6 +25,7 @@ from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.task_group import TaskGroup
from exo.worker.main import Worker
from exo.worker.runner.runner_opts import RunnerOpts
@dataclass
@@ -40,10 +41,11 @@ class Node:
node_id: NodeId
offline: bool
runner_opts: RunnerOpts
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
@classmethod
async def create(cls, args: "Args") -> Self:
@staticmethod
async def create(args: "Args") -> "Node":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
@@ -63,14 +65,28 @@ class Node:
logger.info(f"Starting node {node_id}")
if args.fast_synch is True:
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
logger.info("FAST_SYNCH forced OFF")
runner_opts = RunnerOpts(
fast_synch_override=args.fast_synch,
trust_remote_code_override=args.trust_remote_code,
)
if offline := args.offline:
logger.info(
"Running in OFFLINE mode — no internet checks, local models only"
)
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
exo_shard_downloader(offline=args.offline),
exo_shard_downloader(offline=offline),
event_sender=event_router.sender(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
offline=args.offline,
offline=offline,
)
else:
download_coordinator = None
@@ -90,6 +106,7 @@ class Node:
if not args.no_worker:
worker = Worker(
node_id,
runner_opts,
event_receiver=event_router.receiver(),
event_sender=event_router.sender(),
command_sender=router.sender(topics.COMMANDS),
@@ -123,7 +140,7 @@ class Node:
election_result_sender=er_send,
)
return cls(
return Node(
router,
event_router,
download_coordinator,
@@ -134,6 +151,7 @@ class Node:
api,
node_id,
args.offline,
runner_opts,
)
async def run(self):
@@ -238,6 +256,7 @@ class Node:
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
self.runner_opts,
event_receiver=self.event_router.receiver(),
event_sender=self.event_router.sender(),
command_sender=self.router.sender(topics.COMMANDS),
@@ -265,17 +284,6 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
try:
anyio.run(node.run)
@@ -297,8 +305,11 @@ class Args(CamelCaseModel):
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
offline: bool = os.getenv("EXO_OFFLINE", "false").lower() == "true"
offline: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
trust_remote_code: bool | None = (
None # None = auto, True = force on, False = force off
)
@classmethod
def parse(cls) -> Self:
@@ -365,6 +376,20 @@ class Args(CamelCaseModel):
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
trust_remote_code_group = parser.add_mutually_exclusive_group()
trust_remote_code_group.add_argument(
"--trust-remote-code",
action="store_true",
dest="trust_remote_code",
default=None,
help="Allow all models to execute custom code",
)
trust_remote_code_group.add_argument(
"--never-trust-remote-code",
action="store_false",
dest="trust_remote_code",
help="Deny all models from execute custom code",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
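
The --trust-remote-code / --never-trust-remote-code pair added here keeps the option tri-state: an argparse mutually exclusive group where both flags write to the same dest and the default stays None (auto). A minimal standalone sketch of that pattern, assuming nothing beyond the standard library:

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--trust-remote-code",
    action="store_true",
    dest="trust_remote_code",
    default=None,  # no flag given -> None, i.e. defer to per-model metadata
    help="Allow all models to execute custom code",
)
group.add_argument(
    "--never-trust-remote-code",
    action="store_false",
    dest="trust_remote_code",
    help="Deny all models from executing custom code",
)

print(parser.parse_args([]).trust_remote_code)                              # None (auto)
print(parser.parse_args(["--trust-remote-code"]).trust_remote_code)         # True (force on)
print(parser.parse_args(["--never-trust-remote-code"]).trust_remote_code)   # False (force off)
```

This is what lets RunnerOpts carry plain `bool | None` overrides instead of separate "set" and "value" flags.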

View File

@@ -167,10 +167,12 @@ def load_mlx_items(
group: Group | None,
on_timeout: TimeoutCallback | None,
on_layer_loaded: LayerLoadedCallback | None,
trust_remote_code: bool | None,
) -> tuple[Model, TokenizerWrapper]:
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, lazy=True, strict=False)
# Eval layers one by one for progress reporting
@@ -189,12 +191,10 @@ def load_mlx_items(
mx.eval(model)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
model = shard_and_load(
bound_instance.bound_shard,
group=group,
on_timeout=on_timeout,
@@ -205,6 +205,14 @@ def load_mlx_items(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
tokenizer = load_tokenizer_for_model_id(
bound_instance.bound_shard.model_card.model_id,
model_path,
trust_remote_code=trust_remote_code
if trust_remote_code is not None
else bound_instance.bound_shard.model_card.trust_remote_code,
)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
mx.clear_cache()
@@ -217,9 +225,8 @@ def shard_and_load(
group: Group,
on_timeout: TimeoutCallback | None,
on_layer_loaded: LayerLoadedCallback | None,
) -> tuple[nn.Module, TokenizerWrapper]:
) -> nn.Module:
model_path = build_model_path(shard_metadata.model_card.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
@@ -241,8 +248,6 @@ def shard_and_load(
assert isinstance(model, nn.Module)
tokenizer = get_tokenizer(model_path, shard_metadata)
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
# Estimate timeout based on model size (5x default for large queued workloads)
@@ -281,16 +286,7 @@ def shard_and_load(
# Synchronize processes before generation to avoid timeout
mx_barrier(group)
return model, tokenizer
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(
shard_metadata.model_card.model_id,
model_path,
trust_remote_code=shard_metadata.model_card.trust_remote_code,
)
return model
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
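
With get_tokenizer removed, the tokenizer is loaded through load_tokenizer_for_model_id, resolving trust_remote_code as "override wins, model card otherwise". A small sketch of that resolution rule; the ModelCard stand-in and helper name here are illustrative, not the project's types:

```python
from dataclasses import dataclass


@dataclass
class ModelCard:  # simplified stand-in for the real model card
    trust_remote_code: bool = False


def resolve_trust_remote_code(override: bool | None, card: ModelCard) -> bool:
    """None means auto: defer to the model card; True/False force the value."""
    return override if override is not None else card.trust_remote_code


assert resolve_trust_remote_code(None, ModelCard(trust_remote_code=True)) is True
assert resolve_trust_remote_code(False, ModelCard(trust_remote_code=True)) is False
assert resolve_trust_remote_code(True, ModelCard()) is True
```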

View File

@@ -1,4 +1,5 @@
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timezone
import anyio
@@ -46,38 +47,34 @@ from exo.utils.info_gatherer.net_profile import check_reachable
from exo.utils.keyed_backoff import KeyedBackoff
from exo.utils.task_group import TaskGroup
from exo.worker.plan import plan
from exo.worker.runner.runner_opts import RunnerOpts
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@dataclass
class Worker:
def __init__(
self,
node_id: NodeId,
*,
event_receiver: Receiver[IndexedEvent],
event_sender: Sender[Event],
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.node_id: NodeId = node_id
self.event_receiver = event_receiver
self.event_sender = event_sender
self.command_sender = command_sender
self.download_command_sender = download_command_sender
node_id: NodeId
runner_opts: RunnerOpts
event_receiver: Receiver[IndexedEvent]
event_sender: Sender[Event]
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand]
download_command_sender: Sender[ForwarderDownloadCommand]
state: State = field(init=False, default_factory=State)
runners: dict[RunnerId, RunnerSupervisor] = field(init=False, default_factory=dict)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
_system_id: SystemId = field(init=False, default_factory=SystemId)
self.state: State = State()
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = TaskGroup()
# Buffer for input image chunks (for image editing)
input_chunk_buffer: dict[CommandId, dict[int, str]] = field(
init=False, default_factory=dict
)
input_chunk_counts: dict[CommandId, int] = field(init=False, default_factory=dict)
self._system_id = SystemId()
# Buffer for input image chunks (for image editing)
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
_download_backoff: KeyedBackoff[ModelId] = field(
init=False, default_factory=lambda: KeyedBackoff(base=0.5, cap=10.0)
)
async def run(self):
logger.info("Starting Worker")
@@ -283,6 +280,7 @@ class Worker:
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(
runner_opts=self.runner_opts,
bound_instance=task.bound_instance,
event_sender=self.event_sender.clone(),
)
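
Worker switches here from a hand-written __init__ to a @dataclass, with internal state declared as field(init=False, default_factory=...) so it never appears in the constructor but is still rebuilt per instance. A reduced sketch of the pattern with placeholder field types:

```python
from dataclasses import dataclass, field


@dataclass
class Example:
    # constructor arguments, in declaration order
    node_id: str
    runner_opts: object
    # internal state: excluded from __init__, fresh container per instance
    runners: dict[str, object] = field(init=False, default_factory=dict)
    input_chunk_buffer: dict[str, dict[int, str]] = field(init=False, default_factory=dict)


w = Example("node-1", object())          # only the init=True fields are passed
print(w.runners, w.input_chunk_buffer)   # {} {}
```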

View File

@@ -1,4 +1,5 @@
import os
import resource
import loguru
@@ -8,10 +9,13 @@ from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from .runner_opts import RunnerOpts
logger: "loguru.Logger" = loguru.logger
def entrypoint(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
@@ -20,12 +24,17 @@ def entrypoint(
) -> None:
global logger
logger = _logger
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override != "off":
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
fast_synch_override = runner_opts.fast_synch_override
if fast_synch_override is not None:
if fast_synch_override:
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
@@ -36,7 +45,7 @@ def entrypoint(
else:
from exo.worker.runner.llm_inference.runner import main
main(bound_instance, event_sender, task_receiver, cancel_receiver)
main(runner_opts, bound_instance, event_sender, task_receiver, cancel_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
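
The entrypoint now derives MLX_METAL_FAST_SYNCH from runner_opts.fast_synch_override instead of reading EXO_FAST_SYNCH from its own environment; None (auto) still leaves fast synch on, matching the previous default. A standalone sketch of that mapping; the function name is illustrative:

```python
import os


def apply_fast_synch(override: bool | None) -> None:
    """Map the tri-state override onto MLX_METAL_FAST_SYNCH; None defaults to on."""
    if override is not None:
        os.environ["MLX_METAL_FAST_SYNCH"] = "1" if override else "0"
    else:
        os.environ["MLX_METAL_FAST_SYNCH"] = "1"


apply_fast_synch(None)
print(os.environ["MLX_METAL_FAST_SYNCH"])   # "1" -- auto keeps fast synch on
apply_fast_synch(False)
print(os.environ["MLX_METAL_FAST_SYNCH"])   # "0" -- forced off
```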

View File

@@ -1,5 +1,4 @@
import base64
import resource
import time
from typing import TYPE_CHECKING, Literal
@@ -66,6 +65,7 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
)
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.runner_opts import RunnerOpts
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
@@ -183,14 +183,12 @@ def _send_image_chunk(
def main(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
instance, runner_id, shard_metadata = (
bound_instance.instance,
bound_instance.bound_runner_id,
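
The RLIMIT_NOFILE bump removed from this runner (and from the LLM runner below) now happens once in the bootstrap entrypoint. The pattern raises the soft open-files limit to at least 2048 while never exceeding the hard limit, as this Unix-only sketch shows:

```python
import resource

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
# raise the soft limit to at least 2048, but never above the hard limit
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
print(resource.getrlimit(resource.RLIMIT_NOFILE))
```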

View File

@@ -1,5 +1,4 @@
import math
import resource
import time
from collections.abc import Generator
from functools import cache
@@ -79,19 +78,18 @@ from exo.worker.engines.mlx.utils_mlx import (
mx_any,
)
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.runner_opts import RunnerOpts
from .tool_parsers import ToolParser, make_mlx_parser
def main(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
instance, runner_id, shard_metadata = (
bound_instance.instance,
bound_instance.bound_runner_id,
@@ -194,6 +192,7 @@ def main(
group,
on_timeout=on_model_load_timeout,
on_layer_loaded=on_layer_loaded,
trust_remote_code=runner_opts.trust_remote_code_override,
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"

View File

@@ -0,0 +1,7 @@
from dataclasses import dataclass
@dataclass
class RunnerOpts:
fast_synch_override: bool | None
trust_remote_code_override: bool | None

View File

@@ -34,6 +34,7 @@ from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.utils.task_group import TaskGroup
from exo.worker.runner.bootstrap import entrypoint
from exo.worker.runner.runner_opts import RunnerOpts
PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5
@@ -62,6 +63,7 @@ class RunnerSupervisor:
def create(
cls,
*,
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: Sender[Event],
initialize_timeout: float = 400,
@@ -73,6 +75,7 @@ class RunnerSupervisor:
runner_process = mp.Process(
target=entrypoint,
args=(
runner_opts,
bound_instance,
ev_send,
task_recv,
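
RunnerSupervisor passes runner_opts as the first positional argument in the mp.Process args tuple, so it crosses the process boundary by pickling; a plain dataclass of primitives survives that trip unchanged. A reduced sketch of the wiring, assuming a stdlib queue in place of the project's channel types:

```python
import multiprocessing as mp
from dataclasses import dataclass


@dataclass
class RunnerOpts:
    fast_synch_override: bool | None
    trust_remote_code_override: bool | None


def entrypoint(runner_opts: RunnerOpts, results) -> None:
    # the child process receives an ordinary, pickled copy of the dataclass
    results.put(runner_opts.trust_remote_code_override)


if __name__ == "__main__":
    results = mp.Queue()
    proc = mp.Process(target=entrypoint, args=(RunnerOpts(None, True), results))
    proc.start()
    print(results.get())  # True
    proc.join()
```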

View File

@@ -40,6 +40,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from exo.worker.runner.runner_opts import RunnerOpts
from ...constants import (
CHAT_COMPLETION_TASK_ID,
@@ -184,6 +185,7 @@ def _run(tasks: Iterable[Task]):
make_nothin(mx.array([1])),
):
mlx_runner.main(
RunnerOpts(None, None),
bound_instance,
event_sender, # pyright: ignore[reportArgumentType]
task_receiver,