mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-27 11:46:14 -05:00
Compare commits
1 Commits
main
...
runner-opt
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b0da9dd56b |
@@ -25,6 +25,7 @@ from exo.utils.channels import Receiver, channel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.utils.task_group import TaskGroup
|
||||
from exo.worker.main import Worker
|
||||
from exo.worker.runner.runner_opts import RunnerOpts
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -40,10 +41,11 @@ class Node:
|
||||
|
||||
node_id: NodeId
|
||||
offline: bool
|
||||
runner_opts: RunnerOpts
|
||||
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, args: "Args") -> Self:
|
||||
@staticmethod
|
||||
async def create(args: "Args") -> "Node":
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_node_id())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
@@ -63,14 +65,28 @@ class Node:
|
||||
|
||||
logger.info(f"Starting node {node_id}")
|
||||
|
||||
if args.fast_synch is True:
|
||||
logger.info("FAST_SYNCH forced ON")
|
||||
elif args.fast_synch is False:
|
||||
logger.info("FAST_SYNCH forced OFF")
|
||||
runner_opts = RunnerOpts(
|
||||
fast_synch_override=args.fast_synch,
|
||||
trust_remote_code_override=args.trust_remote_code,
|
||||
)
|
||||
|
||||
if offline := args.offline:
|
||||
logger.info(
|
||||
"Running in OFFLINE mode — no internet checks, local models only"
|
||||
)
|
||||
|
||||
# Create DownloadCoordinator (unless --no-downloads)
|
||||
if not args.no_downloads:
|
||||
download_coordinator = DownloadCoordinator(
|
||||
node_id,
|
||||
exo_shard_downloader(offline=args.offline),
|
||||
exo_shard_downloader(offline=offline),
|
||||
event_sender=event_router.sender(),
|
||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||
offline=args.offline,
|
||||
offline=offline,
|
||||
)
|
||||
else:
|
||||
download_coordinator = None
|
||||
@@ -90,6 +106,7 @@ class Node:
|
||||
if not args.no_worker:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
runner_opts,
|
||||
event_receiver=event_router.receiver(),
|
||||
event_sender=event_router.sender(),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
@@ -123,7 +140,7 @@ class Node:
|
||||
election_result_sender=er_send,
|
||||
)
|
||||
|
||||
return cls(
|
||||
return Node(
|
||||
router,
|
||||
event_router,
|
||||
download_coordinator,
|
||||
@@ -134,6 +151,7 @@ class Node:
|
||||
api,
|
||||
node_id,
|
||||
args.offline,
|
||||
runner_opts,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
@@ -238,6 +256,7 @@ class Node:
|
||||
# TODO: add profiling etc to resource monitor
|
||||
self.worker = Worker(
|
||||
self.node_id,
|
||||
self.runner_opts,
|
||||
event_receiver=self.event_router.receiver(),
|
||||
event_sender=self.event_router.sender(),
|
||||
command_sender=self.router.sender(topics.COMMANDS),
|
||||
@@ -265,17 +284,6 @@ def main():
|
||||
logger.info("Starting EXO")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
|
||||
if args.offline:
|
||||
logger.info("Running in OFFLINE mode — no internet checks, local models only")
|
||||
|
||||
# Set FAST_SYNCH override env var for runner subprocesses
|
||||
if args.fast_synch is True:
|
||||
os.environ["EXO_FAST_SYNCH"] = "on"
|
||||
logger.info("FAST_SYNCH forced ON")
|
||||
elif args.fast_synch is False:
|
||||
os.environ["EXO_FAST_SYNCH"] = "off"
|
||||
logger.info("FAST_SYNCH forced OFF")
|
||||
|
||||
node = anyio.run(Node.create, args)
|
||||
try:
|
||||
anyio.run(node.run)
|
||||
@@ -297,8 +305,11 @@ class Args(CamelCaseModel):
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
no_downloads: bool = False
|
||||
offline: bool = os.getenv("EXO_OFFLINE", "false").lower() == "true"
|
||||
offline: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
trust_remote_code: bool | None = (
|
||||
None # None = auto, True = force on, False = force off
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
@@ -365,6 +376,20 @@ class Args(CamelCaseModel):
|
||||
dest="fast_synch",
|
||||
help="Force MLX FAST_SYNCH off",
|
||||
)
|
||||
trust_remote_code_group = parser.add_mutually_exclusive_group()
|
||||
trust_remote_code_group.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
dest="trust_remote_code",
|
||||
default=None,
|
||||
help="Allow all models to execute custom code",
|
||||
)
|
||||
trust_remote_code_group.add_argument(
|
||||
"--never-trust-remote-code",
|
||||
action="store_false",
|
||||
dest="trust_remote_code",
|
||||
help="Deny all models from execute custom code",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
|
||||
|
||||
@@ -167,10 +167,12 @@ def load_mlx_items(
|
||||
group: Group | None,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
trust_remote_code: bool | None,
|
||||
) -> tuple[Model, TokenizerWrapper]:
|
||||
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
|
||||
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
|
||||
start_time = time.perf_counter()
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
# Eval layers one by one for progress reporting
|
||||
@@ -189,12 +191,10 @@ def load_mlx_items(
|
||||
mx.eval(model)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(
|
||||
model = shard_and_load(
|
||||
bound_instance.bound_shard,
|
||||
group=group,
|
||||
on_timeout=on_timeout,
|
||||
@@ -205,6 +205,14 @@ def load_mlx_items(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
)
|
||||
|
||||
tokenizer = load_tokenizer_for_model_id(
|
||||
bound_instance.bound_shard.model_card.model_id,
|
||||
model_path,
|
||||
trust_remote_code=trust_remote_code
|
||||
if trust_remote_code is not None
|
||||
else bound_instance.bound_shard.model_card.trust_remote_code,
|
||||
)
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
|
||||
mx.clear_cache()
|
||||
@@ -217,9 +225,8 @@ def shard_and_load(
|
||||
group: Group,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
on_layer_loaded: LayerLoadedCallback | None,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
) -> nn.Module:
|
||||
model_path = build_model_path(shard_metadata.model_card.model_id)
|
||||
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
logger.debug(model)
|
||||
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
|
||||
@@ -241,8 +248,6 @@ def shard_and_load(
|
||||
|
||||
assert isinstance(model, nn.Module)
|
||||
|
||||
tokenizer = get_tokenizer(model_path, shard_metadata)
|
||||
|
||||
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
|
||||
|
||||
# Estimate timeout based on model size (5x default for large queued workloads)
|
||||
@@ -281,16 +286,7 @@ def shard_and_load(
|
||||
# Synchronize processes before generation to avoid timeout
|
||||
mx_barrier(group)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
|
||||
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
|
||||
return load_tokenizer_for_model_id(
|
||||
shard_metadata.model_card.model_id,
|
||||
model_path,
|
||||
trust_remote_code=shard_metadata.model_card.trust_remote_code,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import anyio
|
||||
@@ -46,38 +47,34 @@ from exo.utils.info_gatherer.net_profile import check_reachable
|
||||
from exo.utils.keyed_backoff import KeyedBackoff
|
||||
from exo.utils.task_group import TaskGroup
|
||||
from exo.worker.plan import plan
|
||||
from exo.worker.runner.runner_opts import RunnerOpts
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
|
||||
@dataclass
|
||||
class Worker:
|
||||
def __init__(
|
||||
self,
|
||||
node_id: NodeId,
|
||||
*,
|
||||
event_receiver: Receiver[IndexedEvent],
|
||||
event_sender: Sender[Event],
|
||||
# This is for requesting updates. It doesn't need to be a general command sender right now,
|
||||
# but I think it's the correct way to be thinking about commands
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
):
|
||||
self.node_id: NodeId = node_id
|
||||
self.event_receiver = event_receiver
|
||||
self.event_sender = event_sender
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
node_id: NodeId
|
||||
runner_opts: RunnerOpts
|
||||
event_receiver: Receiver[IndexedEvent]
|
||||
event_sender: Sender[Event]
|
||||
# This is for requesting updates. It doesn't need to be a general command sender right now,
|
||||
# but I think it's the correct way to be thinking about commands
|
||||
command_sender: Sender[ForwarderCommand]
|
||||
download_command_sender: Sender[ForwarderDownloadCommand]
|
||||
state: State = field(init=False, default_factory=State)
|
||||
runners: dict[RunnerId, RunnerSupervisor] = field(init=False, default_factory=dict)
|
||||
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
|
||||
_system_id: SystemId = field(init=False, default_factory=SystemId)
|
||||
|
||||
self.state: State = State()
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
self._tg: TaskGroup = TaskGroup()
|
||||
# Buffer for input image chunks (for image editing)
|
||||
input_chunk_buffer: dict[CommandId, dict[int, str]] = field(
|
||||
init=False, default_factory=dict
|
||||
)
|
||||
input_chunk_counts: dict[CommandId, int] = field(init=False, default_factory=dict)
|
||||
|
||||
self._system_id = SystemId()
|
||||
|
||||
# Buffer for input image chunks (for image editing)
|
||||
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
|
||||
self.input_chunk_counts: dict[CommandId, int] = {}
|
||||
|
||||
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
|
||||
_download_backoff: KeyedBackoff[ModelId] = field(
|
||||
init=False, default_factory=lambda: KeyedBackoff(base=0.5, cap=10.0)
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Worker")
|
||||
@@ -283,6 +280,7 @@ class Worker:
|
||||
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
|
||||
"""Creates and stores a new AssignedRunner with initial downloading status."""
|
||||
runner = RunnerSupervisor.create(
|
||||
runner_opts=self.runner_opts,
|
||||
bound_instance=task.bound_instance,
|
||||
event_sender=self.event_sender.clone(),
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import resource
|
||||
|
||||
import loguru
|
||||
|
||||
@@ -8,10 +9,13 @@ from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
|
||||
from .runner_opts import RunnerOpts
|
||||
|
||||
logger: "loguru.Logger" = loguru.logger
|
||||
|
||||
|
||||
def entrypoint(
|
||||
runner_opts: RunnerOpts,
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
@@ -20,12 +24,17 @@ def entrypoint(
|
||||
) -> None:
|
||||
global logger
|
||||
logger = _logger
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
|
||||
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
if fast_synch_override != "off":
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
fast_synch_override = runner_opts.fast_synch_override
|
||||
if fast_synch_override is not None:
|
||||
if fast_synch_override:
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
else:
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
||||
else:
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
|
||||
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
|
||||
|
||||
@@ -36,7 +45,7 @@ def entrypoint(
|
||||
else:
|
||||
from exo.worker.runner.llm_inference.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver, cancel_receiver)
|
||||
main(runner_opts, bound_instance, event_sender, task_receiver, cancel_receiver)
|
||||
|
||||
except ClosedResourceError:
|
||||
logger.warning("Runner communication closed unexpectedly")
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import base64
|
||||
import resource
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
@@ -66,6 +65,7 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
from exo.worker.runner.runner_opts import RunnerOpts
|
||||
|
||||
|
||||
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
|
||||
@@ -183,14 +183,12 @@ def _send_image_chunk(
|
||||
|
||||
|
||||
def main(
|
||||
runner_opts: RunnerOpts,
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
):
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
|
||||
|
||||
instance, runner_id, shard_metadata = (
|
||||
bound_instance.instance,
|
||||
bound_instance.bound_runner_id,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import math
|
||||
import resource
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
@@ -79,19 +78,18 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
mx_any,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
from exo.worker.runner.runner_opts import RunnerOpts
|
||||
|
||||
from .tool_parsers import ToolParser, make_mlx_parser
|
||||
|
||||
|
||||
def main(
|
||||
runner_opts: RunnerOpts,
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
):
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
|
||||
|
||||
instance, runner_id, shard_metadata = (
|
||||
bound_instance.instance,
|
||||
bound_instance.bound_runner_id,
|
||||
@@ -194,6 +192,7 @@ def main(
|
||||
group,
|
||||
on_timeout=on_model_load_timeout,
|
||||
on_layer_loaded=on_layer_loaded,
|
||||
trust_remote_code=runner_opts.trust_remote_code_override,
|
||||
)
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
|
||||
|
||||
7
src/exo/worker/runner/runner_opts.py
Normal file
7
src/exo/worker/runner/runner_opts.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunnerOpts:
|
||||
fast_synch_override: bool | None
|
||||
trust_remote_code_override: bool | None
|
||||
@@ -34,6 +34,7 @@ from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
|
||||
from exo.utils.task_group import TaskGroup
|
||||
from exo.worker.runner.bootstrap import entrypoint
|
||||
from exo.worker.runner.runner_opts import RunnerOpts
|
||||
|
||||
PREFILL_TIMEOUT_SECONDS = 60
|
||||
DECODE_TIMEOUT_SECONDS = 5
|
||||
@@ -62,6 +63,7 @@ class RunnerSupervisor:
|
||||
def create(
|
||||
cls,
|
||||
*,
|
||||
runner_opts: RunnerOpts,
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: Sender[Event],
|
||||
initialize_timeout: float = 400,
|
||||
@@ -73,6 +75,7 @@ class RunnerSupervisor:
|
||||
runner_process = mp.Process(
|
||||
target=entrypoint,
|
||||
args=(
|
||||
runner_opts,
|
||||
bound_instance,
|
||||
ev_send,
|
||||
task_recv,
|
||||
|
||||
@@ -40,6 +40,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import mp_channel
|
||||
from exo.worker.runner.runner_opts import RunnerOpts
|
||||
|
||||
from ...constants import (
|
||||
CHAT_COMPLETION_TASK_ID,
|
||||
@@ -184,6 +185,7 @@ def _run(tasks: Iterable[Task]):
|
||||
make_nothin(mx.array([1])),
|
||||
):
|
||||
mlx_runner.main(
|
||||
RunnerOpts(None, None),
|
||||
bound_instance,
|
||||
event_sender, # pyright: ignore[reportArgumentType]
|
||||
task_receiver,
|
||||
|
||||
Reference in New Issue
Block a user