Compare commits

..

1 Commits

Author SHA1 Message Date
Evan
b0da9dd56b runner opts 2026-02-26 17:51:31 +00:00
19 changed files with 790 additions and 1016 deletions

View File

@@ -25,6 +25,7 @@ from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.task_group import TaskGroup
from exo.worker.main import Worker
from exo.worker.runner.runner_opts import RunnerOpts
@dataclass
@@ -40,10 +41,11 @@ class Node:
node_id: NodeId
offline: bool
runner_opts: RunnerOpts
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
@classmethod
async def create(cls, args: "Args") -> Self:
@staticmethod
async def create(args: "Args") -> "Node":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
@@ -63,14 +65,28 @@ class Node:
logger.info(f"Starting node {node_id}")
if args.fast_synch is True:
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
logger.info("FAST_SYNCH forced OFF")
runner_opts = RunnerOpts(
fast_synch_override=args.fast_synch,
trust_remote_code_override=args.trust_remote_code,
)
if offline := args.offline:
logger.info(
"Running in OFFLINE mode — no internet checks, local models only"
)
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
exo_shard_downloader(offline=args.offline),
exo_shard_downloader(offline=offline),
event_sender=event_router.sender(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
offline=args.offline,
offline=offline,
)
else:
download_coordinator = None
@@ -90,6 +106,7 @@ class Node:
if not args.no_worker:
worker = Worker(
node_id,
runner_opts,
event_receiver=event_router.receiver(),
event_sender=event_router.sender(),
command_sender=router.sender(topics.COMMANDS),
@@ -123,7 +140,7 @@ class Node:
election_result_sender=er_send,
)
return cls(
return Node(
router,
event_router,
download_coordinator,
@@ -134,6 +151,7 @@ class Node:
api,
node_id,
args.offline,
runner_opts,
)
async def run(self):
@@ -238,6 +256,7 @@ class Node:
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
self.runner_opts,
event_receiver=self.event_router.receiver(),
event_sender=self.event_router.sender(),
command_sender=self.router.sender(topics.COMMANDS),
@@ -265,17 +284,6 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
try:
anyio.run(node.run)
@@ -297,8 +305,11 @@ class Args(CamelCaseModel):
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
offline: bool = os.getenv("EXO_OFFLINE", "false").lower() == "true"
offline: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
trust_remote_code: bool | None = (
None # None = auto, True = force on, False = force off
)
@classmethod
def parse(cls) -> Self:
@@ -365,6 +376,20 @@ class Args(CamelCaseModel):
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
trust_remote_code_group = parser.add_mutually_exclusive_group()
trust_remote_code_group.add_argument(
"--trust-remote-code",
action="store_true",
dest="trust_remote_code",
default=None,
help="Allow all models to execute custom code",
)
trust_remote_code_group.add_argument(
"--never-trust-remote-code",
action="store_false",
dest="trust_remote_code",
help="Deny all models from executing custom code",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -258,6 +258,6 @@ def get_node_id_keypair(
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate()
keypair = Keypair.generate_ed25519()
f.write(keypair.to_bytes())
return keypair

View File

@@ -1,6 +1,5 @@
import contextlib
import multiprocessing as mp
from collections.abc import Generator
from dataclasses import dataclass, field
from math import inf
from multiprocessing.synchronize import Event
@@ -283,54 +282,6 @@ class MpReceiver[T]:
return d
class NonBlockingGenerator[T](Generator[T | None, None, None]):
def __init__(self, source: MpReceiver[T] | Generator[T | None, None, None]) -> None:
self._receiver: MpReceiver[T] | None = None
self._inner: Generator[T | None, None, None] | None = None
if isinstance(source, MpReceiver):
self._receiver = source
else:
self._inner = source
self._exhausted = False
def send(self, value: None, /) -> T | None:
if self._exhausted:
raise StopIteration
if self._inner is not None:
try:
return next(self._inner)
except (StopIteration, ClosedResourceError):
self._exhausted = True
raise StopIteration from None
assert self._receiver is not None
try:
return self._receiver.receive_nowait()
except WouldBlock:
return None
except (EndOfStream, ClosedResourceError):
self._exhausted = True
raise StopIteration from None
def throw(
self,
typ: type[BaseException] | BaseException,
val: BaseException | object = None,
tb: TracebackType | None = None,
/,
) -> T | None:
raise StopIteration
@property
def is_exhausted(self) -> bool:
return self._exhausted
def try_receive(self) -> T | None:
try:
return next(self)
except StopIteration:
return None
class channel[T]: # noqa: N801
"""Create a pair of asynchronous channels for communicating within the same process"""

View File

@@ -437,7 +437,6 @@ def mlx_generate(
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None = None,
distributed_prompt_progress_callback: Callable[[], None] | None = None,
on_generation_token: Callable[[], None] | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -645,9 +644,6 @@ def mlx_generate(
full_prompt_tokens, caches, cache_snapshots
)
if on_generation_token is not None:
on_generation_token()
yield GenerationResponse(
text=text,
token=out.token,

View File

@@ -167,10 +167,12 @@ def load_mlx_items(
group: Group | None,
on_timeout: TimeoutCallback | None,
on_layer_loaded: LayerLoadedCallback | None,
trust_remote_code: bool | None,
) -> tuple[Model, TokenizerWrapper]:
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, lazy=True, strict=False)
# Eval layers one by one for progress reporting
@@ -189,12 +191,10 @@ def load_mlx_items(
mx.eval(model)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
model = shard_and_load(
bound_instance.bound_shard,
group=group,
on_timeout=on_timeout,
@@ -205,6 +205,14 @@ def load_mlx_items(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
tokenizer = load_tokenizer_for_model_id(
bound_instance.bound_shard.model_card.model_id,
model_path,
trust_remote_code=trust_remote_code
if trust_remote_code is not None
else bound_instance.bound_shard.model_card.trust_remote_code,
)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
mx.clear_cache()
@@ -217,9 +225,8 @@ def shard_and_load(
group: Group,
on_timeout: TimeoutCallback | None,
on_layer_loaded: LayerLoadedCallback | None,
) -> tuple[nn.Module, TokenizerWrapper]:
) -> nn.Module:
model_path = build_model_path(shard_metadata.model_card.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
@@ -241,8 +248,6 @@ def shard_and_load(
assert isinstance(model, nn.Module)
tokenizer = get_tokenizer(model_path, shard_metadata)
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
# Estimate timeout based on model size (5x default for large queued workloads)
@@ -281,16 +286,7 @@ def shard_and_load(
# Synchronize processes before generation to avoid timeout
mx_barrier(group)
return model, tokenizer
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(
shard_metadata.model_card.model_id,
model_path,
trust_remote_code=shard_metadata.model_card.trust_remote_code,
)
return model
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:

View File

@@ -1,4 +1,5 @@
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timezone
import anyio
@@ -46,38 +47,34 @@ from exo.utils.info_gatherer.net_profile import check_reachable
from exo.utils.keyed_backoff import KeyedBackoff
from exo.utils.task_group import TaskGroup
from exo.worker.plan import plan
from exo.worker.runner.runner_opts import RunnerOpts
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@dataclass
class Worker:
def __init__(
self,
node_id: NodeId,
*,
event_receiver: Receiver[IndexedEvent],
event_sender: Sender[Event],
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.node_id: NodeId = node_id
self.event_receiver = event_receiver
self.event_sender = event_sender
self.command_sender = command_sender
self.download_command_sender = download_command_sender
node_id: NodeId
runner_opts: RunnerOpts
event_receiver: Receiver[IndexedEvent]
event_sender: Sender[Event]
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand]
download_command_sender: Sender[ForwarderDownloadCommand]
state: State = field(init=False, default_factory=State)
runners: dict[RunnerId, RunnerSupervisor] = field(init=False, default_factory=dict)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
_system_id: SystemId = field(init=False, default_factory=SystemId)
self.state: State = State()
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = TaskGroup()
# Buffer for input image chunks (for image editing)
input_chunk_buffer: dict[CommandId, dict[int, str]] = field(
init=False, default_factory=dict
)
input_chunk_counts: dict[CommandId, int] = field(init=False, default_factory=dict)
self._system_id = SystemId()
# Buffer for input image chunks (for image editing)
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
_download_backoff: KeyedBackoff[ModelId] = field(
init=False, default_factory=lambda: KeyedBackoff(base=0.5, cap=10.0)
)
async def run(self):
logger.info("Starting Worker")
@@ -283,6 +280,7 @@ class Worker:
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(
runner_opts=self.runner_opts,
bound_instance=task.bound_instance,
event_sender=self.event_sender.clone(),
)

View File

@@ -297,10 +297,10 @@ def _pending_tasks(
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
if task.task_id in runner.completed or task.task_id in runner.pending:
if task.task_id in runner.completed:
continue
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):

View File

@@ -1,4 +1,5 @@
import os
import resource
import loguru
@@ -8,10 +9,13 @@ from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from .runner_opts import RunnerOpts
logger: "loguru.Logger" = loguru.logger
def entrypoint(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
@@ -20,12 +24,17 @@ def entrypoint(
) -> None:
global logger
logger = _logger
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override != "off":
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
fast_synch_override = runner_opts.fast_synch_override
if fast_synch_override is not None:
if fast_synch_override:
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
@@ -33,15 +42,10 @@ def entrypoint(
try:
if bound_instance.is_image_model:
from exo.worker.runner.image_models.runner import main
main(bound_instance, event_sender, task_receiver, cancel_receiver)
else:
from exo.worker.runner.llm_inference.runner import Runner
from exo.worker.runner.llm_inference.runner import main
runner = Runner(
bound_instance, event_sender, task_receiver, cancel_receiver
)
runner.main()
main(runner_opts, bound_instance, event_sender, task_receiver, cancel_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")

View File

@@ -1,5 +1,4 @@
import base64
import resource
import time
from typing import TYPE_CHECKING, Literal
@@ -66,6 +65,7 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
)
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.runner_opts import RunnerOpts
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
@@ -183,14 +183,12 @@ def _send_image_chunk(
def main(
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
instance, runner_id, shard_metadata = (
bound_instance.instance,
bound_instance.bound_runner_id,

View File

@@ -1,178 +0,0 @@
from collections import deque
from collections.abc import Generator
from dataclasses import dataclass, field
import mlx.core as mx
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.chunks import ErrorChunk, PrefillProgressChunk
from exo.shared.types.common import ModelId
from exo.shared.types.events import ChunkGenerated, Event
from exo.shared.types.mlx import Model
from exo.shared.types.tasks import TaskId, TextGeneration
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import PrefillCancelled, mlx_generate
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
mx_any,
)
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"


def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
    """Trigger artificial failure modes when magic debug markers appear in the prompt.

    Used by tests to exercise runner crash, OOM, and timeout handling.
    """
    import time

    from exo.worker.engines.mlx.utils_mlx import mlx_force_oom

    if not task_params.input:
        return
    first_content = task_params.input[0].content
    if not first_content:
        return
    if EXO_RUNNER_MUST_FAIL in first_content:
        raise Exception("Artificial runner exception - for testing purposes only.")
    if EXO_RUNNER_MUST_OOM in first_content:
        mlx_force_oom()
    if EXO_RUNNER_MUST_TIMEOUT in first_content:
        time.sleep(100)
@dataclass(eq=False)
class BatchGenerator:
    """Serializes text-generation tasks onto a single loaded model.

    Tasks are queued via ``submit`` and executed one at a time; ``step``
    advances the active generator by one response and forwards it to the
    task's sender. Cancellation is polled during prefill and every
    ``check_for_cancel_every`` decode tokens.
    """

    model: Model
    tokenizer: TokenizerWrapper
    # None when running single-device; the distributed group otherwise.
    group: mx.distributed.Group | None
    kv_prefix_cache: KVPrefixCache | None
    model_id: ModelId
    # Only device rank 0 emits chunk events (see _send_error / on_prefill_progress).
    device_rank: int
    cancel_receiver: MpReceiver[TaskId]
    cancelled_tasks: set[TaskId]
    event_sender: MpSender[Event]
    # Decode-phase cancellation is checked once every N generated tokens.
    check_for_cancel_every: int
    # FIFO of tasks waiting for the active generator to finish.
    _queue: deque[tuple[TextGeneration, MpSender[GenerationResponse]]] = field(
        default_factory=deque, init=False
    )
    # (task, sender, generator) currently being stepped, if any.
    _active: (
        tuple[
            TextGeneration,
            MpSender[GenerationResponse],
            Generator[GenerationResponse],
        ]
        | None
    ) = field(default=None, init=False)

    def submit(
        self,
        task: TextGeneration,
        sender: MpSender[GenerationResponse],
    ) -> None:
        """Enqueue a task; start it immediately if nothing is running."""
        self._queue.append((task, sender))
        if self._active is None:
            self._start_next()

    def step(self) -> None:
        """Advance the active generator by one response.

        On normal exhaustion (or prefill cancellation) the sender is closed
        and the next queued task is started. Any other exception is reported
        as an error chunk and re-raised.
        """
        if self._active is None:
            if self._queue:
                self._start_next()
            else:
                return
        # Defensive re-check: _start_next raising would leave _active unset.
        if self._active is None:
            return
        task, sender, gen = self._active
        try:
            response = next(gen)
            sender.send(response)
        except (StopIteration, PrefillCancelled):
            sender.close()
            self._active = None
            if self._queue:
                self._start_next()
        except Exception as e:
            self._send_error(task, e)
            sender.close()
            self._active = None
            raise

    def _start_next(self) -> None:
        """Pop the next queued task and make it the active generator.

        If building the generator fails, the error is reported, the sender
        closed, and the exception re-raised.
        """
        task, sender = self._queue.popleft()
        try:
            gen = self._build_generator(task)
        except Exception as e:
            self._send_error(task, e)
            sender.close()
            raise
        self._active = (task, sender, gen)

    def _send_error(self, task: TextGeneration, e: Exception) -> None:
        """Report a task failure as an ErrorChunk; only device rank 0 emits events."""
        if self.device_rank == 0:
            self.event_sender.send(
                ChunkGenerated(
                    command_id=task.command_id,
                    chunk=ErrorChunk(
                        model=self.model_id,
                        finish_reason="error",
                        error_message=str(e),
                    ),
                )
            )

    def _build_generator(self, task: TextGeneration) -> Generator[GenerationResponse]:
        """Create the mlx_generate stream for a task, wiring progress and cancel callbacks."""
        _check_for_debug_prompts(task.task_params)
        prompt = apply_chat_template(self.tokenizer, task.task_params)

        def on_prefill_progress(processed: int, total: int) -> None:
            # Stream prefill progress to the client (rank 0 only).
            if self.device_rank == 0:
                self.event_sender.send(
                    ChunkGenerated(
                        command_id=task.command_id,
                        chunk=PrefillProgressChunk(
                            model=self.model_id,
                            processed_tokens=processed,
                            total_tokens=total,
                        ),
                    )
                )

        def distributed_prompt_progress_callback() -> None:
            # Prefill-phase cancel check; mx_any agrees the decision across
            # the group so all ranks cancel together.
            self.cancelled_tasks.update(self.cancel_receiver.collect())
            want_to_cancel = (task.task_id in self.cancelled_tasks) or (
                TaskId("CANCEL_CURRENT_TASK") in self.cancelled_tasks
            )
            if mx_any(want_to_cancel, self.group):
                raise PrefillCancelled()

        # Start at the threshold so the very first generated token triggers a check.
        tokens_since_cancel_check = self.check_for_cancel_every

        def on_generation_token() -> None:
            # Decode-phase cancel check, throttled to every check_for_cancel_every tokens.
            nonlocal tokens_since_cancel_check
            tokens_since_cancel_check += 1
            if tokens_since_cancel_check >= self.check_for_cancel_every:
                tokens_since_cancel_check = 0
                self.cancelled_tasks.update(self.cancel_receiver.collect())
                want_to_cancel = (task.task_id in self.cancelled_tasks) or (
                    TaskId("CANCEL_CURRENT_TASK") in self.cancelled_tasks
                )
                if mx_any(want_to_cancel, self.group):
                    raise PrefillCancelled()

        return mlx_generate(
            model=self.model,
            tokenizer=self.tokenizer,
            task=task.task_params,
            prompt=prompt,
            kv_prefix_cache=self.kv_prefix_cache,
            on_prefill_progress=on_prefill_progress,
            distributed_prompt_progress_callback=distributed_prompt_progress_callback,
            on_generation_token=on_generation_token,
            group=self.group,
        )

View File

@@ -1,341 +0,0 @@
from collections.abc import Generator
from functools import cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
HarmonyError, # pyright: ignore[reportUnknownVariableType]
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.types.api import ToolCallItem
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.llm_inference.tool_parsers import ToolParser
@cache
def get_gpt_oss_encoding():
    """Return the Harmony GPT-OSS encoding, loading it once and memoizing."""
    return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def parse_gpt_oss(
    responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
    """Incrementally parse a GPT-OSS (Harmony) token stream.

    Feeds each token through a Harmony StreamableParser, marking
    "analysis"-channel deltas as thinking, and folding a tool-call
    recipient plus its accumulated argument deltas into a single
    ToolCallResponse once the recipient changes. ``None`` items are
    passed through unchanged.
    """
    encoding = get_gpt_oss_encoding()
    stream = StreamableParser(encoding, role=Role.ASSISTANT)
    thinking = False
    current_tool_name: str | None = None
    tool_arg_parts: list[str] = []
    for response in responses:
        if response is None:
            yield None
            continue
        try:
            stream.process(response.token)
        except HarmonyError:
            # Parser state is unrecoverable — stop the stream early.
            logger.error("Encountered critical Harmony Error, returning early")
            return
        delta = stream.last_content_delta
        ch = stream.current_channel
        recipient = stream.current_recipient
        # Debug: log every token with state
        logger.debug(
            f"parse_gpt_oss token={response.token} text={response.text!r} "
            f"recipient={recipient!r} ch={ch!r} delta={delta!r} "
            f"state={stream.state} current_tool={current_tool_name!r}"
        )
        if recipient != current_tool_name:
            # Recipient changed: flush the completed tool call, if any.
            if current_tool_name is not None:
                prefix = "functions."
                if current_tool_name.startswith(prefix):
                    current_tool_name = current_tool_name[len(prefix) :]
                logger.info(
                    f"parse_gpt_oss yielding tool call: name={current_tool_name!r}"
                )
                yield ToolCallResponse(
                    tool_calls=[
                        ToolCallItem(
                            name=current_tool_name,
                            arguments="".join(tool_arg_parts).strip(),
                        )
                    ],
                    usage=response.usage,
                )
                tool_arg_parts = []
            current_tool_name = recipient
        # If inside a tool call, accumulate arguments
        if current_tool_name is not None:
            if delta:
                tool_arg_parts.append(delta)
            continue
        if ch == "analysis" and not thinking:
            thinking = True
        if ch != "analysis" and thinking:
            thinking = False
        if delta:
            yield response.model_copy(update={"text": delta, "is_thinking": thinking})
        if response.finish_reason is not None:
            # Always propagate the terminal token so the chunk stream can finish.
            yield response
def parse_deepseek_v32(
    responses: Generator[GenerationResponse | None],
) -> Generator[GenerationResponse | ToolCallResponse | None]:
    """Parse DeepSeek V3.2 DSML tool calls from the generation stream.

    Uses accumulated-text matching (not per-token marker checks) because
    DSML markers like <DSMLfunction_calls> may span multiple tokens.
    Also handles <think>...</think> blocks for thinking mode.
    """
    from exo.worker.engines.mlx.dsml_encoding import (
        THINKING_END,
        THINKING_START,
        TOOL_CALLS_END,
        TOOL_CALLS_START,
        parse_dsml_output,
    )

    accumulated = ""
    in_tool_call = False
    thinking = False
    # Tokens buffered while we detect the start of a DSML block
    pending_buffer: list[GenerationResponse] = []
    # Text accumulated during a tool call block
    tool_call_text = ""
    for response in responses:
        if response is None:
            yield None
            continue
        # ── Handle thinking tags ──
        if not thinking and THINKING_START in response.text:
            thinking = True
            # Yield any text before the <think> tag
            before = response.text[: response.text.index(THINKING_START)]
            if before:
                yield response.model_copy(update={"text": before})
            continue
        if thinking and THINKING_END in response.text:
            thinking = False
            # Yield any text after the </think> tag
            after = response.text[
                response.text.index(THINKING_END) + len(THINKING_END) :
            ]
            if after:
                yield response.model_copy(update={"text": after, "is_thinking": False})
            continue
        if thinking:
            yield response.model_copy(update={"is_thinking": True})
            continue
        # ── Handle tool call accumulation ──
        if in_tool_call:
            tool_call_text += response.text
            if TOOL_CALLS_END in tool_call_text:
                # Parse the accumulated DSML block
                parsed = parse_dsml_output(tool_call_text)
                if parsed is not None:
                    logger.info(f"parsed DSML tool calls: {parsed}")
                    yield ToolCallResponse(
                        tool_calls=parsed,
                        usage=response.usage,
                        stats=response.stats,
                    )
                else:
                    # Unparseable block: fall back to emitting the raw text.
                    logger.warning(
                        f"DSML tool call parsing failed for: {tool_call_text}"
                    )
                    yield response.model_copy(update={"text": tool_call_text})
                in_tool_call = False
                tool_call_text = ""
                continue
            # EOS reached before end marker — yield buffered text as-is
            if response.finish_reason is not None:
                logger.info("DSML tool call parsing interrupted by EOS")
                yield response.model_copy(update={"text": tool_call_text})
                in_tool_call = False
                tool_call_text = ""
            continue
        # ── Detect start of tool call block ──
        accumulated += response.text
        if TOOL_CALLS_START in accumulated:
            # The start marker might be split across pending_buffer + current token
            start_idx = accumulated.index(TOOL_CALLS_START)
            # Yield any pending tokens that are purely before the marker
            pre_text = accumulated[:start_idx]
            if pre_text:
                # Flush pending buffer tokens that contributed text before the marker
                for buf_resp in pending_buffer:
                    if pre_text:
                        chunk = buf_resp.text
                        if len(chunk) <= len(pre_text):
                            yield buf_resp
                            pre_text = pre_text[len(chunk) :]
                        else:
                            yield buf_resp.model_copy(update={"text": pre_text})
                            pre_text = ""
            pending_buffer = []
            tool_call_text = accumulated[start_idx:]
            accumulated = ""
            # Check if the end marker is already present (entire tool call in one token)
            if TOOL_CALLS_END in tool_call_text:
                parsed = parse_dsml_output(tool_call_text)
                if parsed is not None:
                    logger.info(f"parsed DSML tool calls: {parsed}")
                    yield ToolCallResponse(
                        tool_calls=parsed,
                        usage=response.usage,
                        stats=response.stats,
                    )
                else:
                    logger.warning(
                        f"DSML tool call parsing failed for: {tool_call_text}"
                    )
                    yield response.model_copy(update={"text": tool_call_text})
                tool_call_text = ""
            else:
                in_tool_call = True
            continue
        # Check if accumulated text might be the start of a DSML marker
        # Buffer tokens if we see a partial match at the end
        if _could_be_dsml_prefix(accumulated):
            pending_buffer.append(response)
            continue
        # No partial match — flush all pending tokens and the current one
        for buf_resp in pending_buffer:
            yield buf_resp
        pending_buffer = []
        accumulated = ""
        yield response
    # Flush any remaining pending buffer at generator end
    for buf_resp in pending_buffer:
        yield buf_resp
def _could_be_dsml_prefix(text: str) -> bool:
    """Return True if *text* could end with the beginning of a DSML marker.

    A suffix of *text* that is a prefix of TOOL_CALLS_START means a tool-call
    marker may be arriving split across tokens, so the caller should buffer
    tokens until the match is resolved either way.
    """
    from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START

    # Only the trailing len(marker) characters of text can overlap the marker.
    marker = TOOL_CALLS_START
    tail = text[-len(marker) :] if len(text) > len(marker) else text
    return any(marker.startswith(tail[start:]) for start in range(len(tail)))
def parse_thinking_models(
    responses: Generator[GenerationResponse | None],
    tokenizer: TokenizerWrapper,
    starts_in_thinking: bool = True,
) -> Generator[GenerationResponse | None]:
    """Tag streamed tokens with an ``is_thinking`` flag.

    Think-tag tokens themselves are swallowed (they only flip the flag),
    except that a think tag carrying a finish_reason is re-emitted with
    empty text so the chunk stream always terminates.
    """
    in_thinking = starts_in_thinking
    for item in responses:
        if item is None:
            yield None
            continue
        if isinstance(item, ToolCallResponse):
            # Tool calls carry no thinking state; pass them through untouched.
            yield item
            continue
        start_tag, end_tag = tokenizer.think_start, tokenizer.think_end
        matches_tag = (end_tag is not None and item.text == end_tag) or (
            start_tag is not None and item.text == start_tag
        )
        if not matches_tag:
            yield item.model_copy(update={"is_thinking": in_thinking})
            continue
        in_thinking = item.text != end_tag
        # Never swallow finish_reason — the chunk stream needs it to terminate.
        if item.finish_reason is not None:
            yield item.model_copy(update={"text": "", "is_thinking": False})
def parse_tool_calls(
    responses: Generator[GenerationResponse | None], tool_parser: ToolParser
) -> Generator[GenerationResponse | ToolCallResponse | None]:
    """Fold tool-call text spans into ToolCallResponse items.

    Buffers token text between the parser's start and end markers, then
    parses the span with ``tool_parser``. If EOS arrives mid-call, the
    partial buffered text is yielded back as plain text. ``None`` items
    pass through unchanged.
    """
    in_tool_call = False
    tool_call_text_parts: list[str] = []
    for response in responses:
        if response is None:
            yield None
            continue
        if not in_tool_call and response.text.startswith(tool_parser.start_parsing):
            in_tool_call = True
        if in_tool_call:
            tool_call_text_parts.append(response.text)
            if response.text.endswith(tool_parser.end_parsing):
                # parse the actual tool calls from the tool call text
                parsed = tool_parser.parse_tool_calls(
                    "".join(tool_call_text_parts).strip()
                )
                logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
                if parsed is not None:
                    yield ToolCallResponse(
                        tool_calls=parsed, usage=response.usage, stats=response.stats
                    )
                else:
                    # Parsing failed: emit the raw buffered text instead.
                    logger.warning(
                        f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
                    )
                    response.text = "".join(tool_call_text_parts)
                    yield response
                in_tool_call = False
                tool_call_text_parts = []
                continue
            if response.finish_reason is not None:
                # EOS arrived before the end marker: surface the partial call as text.
                logger.info(
                    "tool call parsing interrupted, yield partial tool call as text"
                )
                response = response.model_copy(
                    update={
                        "text": "".join(tool_call_text_parts),
                        "token": 0,
                    }
                )
                yield response
        else:
            # fallthrough
            yield response

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,7 @@
from dataclasses import dataclass
@dataclass
class RunnerOpts:
    """Per-runner options threaded from the node's CLI args into runner subprocesses."""

    # None = auto (use default), True = force MLX fast synch on, False = force it off.
    fast_synch_override: bool | None
    # None = defer to the model card's trust_remote_code setting,
    # True = always trust, False = never trust.
    trust_remote_code_override: bool | None

View File

@@ -34,6 +34,7 @@ from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.utils.task_group import TaskGroup
from exo.worker.runner.bootstrap import entrypoint
from exo.worker.runner.runner_opts import RunnerOpts
PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5
@@ -62,6 +63,7 @@ class RunnerSupervisor:
def create(
cls,
*,
runner_opts: RunnerOpts,
bound_instance: BoundInstance,
event_sender: Sender[Event],
initialize_timeout: float = 400,
@@ -73,6 +75,7 @@ class RunnerSupervisor:
runner_process = mp.Process(
target=entrypoint,
args=(
runner_opts,
bound_instance,
ev_send,
task_recv,
@@ -172,7 +175,7 @@ class RunnerSupervisor:
if isinstance(event, RunnerStatusUpdated):
self.status = event.runner_status
if isinstance(event, TaskAcknowledged):
self.pending[event.task_id].set()
self.pending.pop(event.task_id).set()
continue
if (
isinstance(event, TaskStatusUpdated)
@@ -190,7 +193,6 @@ class RunnerSupervisor:
),
)
self.completed.add(event.task_id)
self.pending.pop(event.task_id, None)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)

View File

@@ -20,7 +20,6 @@ class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
pending: dict[TaskId, object] = field(default_factory=dict)
class OtherTask(BaseTask):

View File

@@ -19,7 +19,7 @@ from exo.worker.engines.mlx.dsml_encoding import (
encode_messages,
parse_dsml_output,
)
from exo.worker.runner.llm_inference.model_output_parsers import parse_deepseek_v32
from exo.worker.runner.llm_inference.runner import parse_deepseek_v32
# ── Shared fixtures ──────────────────────────────────────────────

View File

@@ -6,7 +6,6 @@ from typing import Callable
import mlx.core as mx
import pytest
import exo.worker.runner.llm_inference.batch_generator as mlx_batch_generator
import exo.worker.runner.llm_inference.runner as mlx_runner
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
@@ -41,6 +40,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from exo.worker.runner.runner_opts import RunnerOpts
from ...constants import (
CHAT_COMPLETION_TASK_ID,
@@ -116,20 +116,17 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_batch_generator, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_batch_generator, "mx_any", make_nothin(False))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
monkeypatch.setattr(
mlx_batch_generator, "apply_chat_template", make_nothin("test prompt")
)
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
monkeypatch.setattr(mlx_batch_generator, "mlx_generate", fake_generate)
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
# Use a fake event_sender to remove test flakiness.
@@ -187,13 +184,13 @@ def _run(tasks: Iterable[Task]):
"exo.worker.runner.llm_inference.runner.mx.distributed.all_gather",
make_nothin(mx.array([1])),
):
runner = mlx_runner.Runner(
mlx_runner.main(
RunnerOpts(None, None),
bound_instance,
event_sender, # pyright: ignore[reportArgumentType]
task_receiver,
cancel_receiver,
)
runner.main()
return event_sender.events

View File

@@ -4,7 +4,7 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
ToolCallResponse,
)
from exo.worker.runner.llm_inference.model_output_parsers import parse_gpt_oss
from exo.worker.runner.llm_inference.runner import parse_gpt_oss
# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
# These are stable since they come from the model's vocabulary.
@@ -107,7 +107,7 @@ def _collect(
def _gen() -> Generator[GenerationResponse, None, None]:
yield from _make_gen_responses(tokens)
return list(x for x in parse_gpt_oss(_gen()) if x is not None)
return list(parse_gpt_oss(_gen()))
def _get_tool_call(

View File

@@ -4,7 +4,7 @@ from collections.abc import Generator
from typing import Any
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.llm_inference.model_output_parsers import parse_tool_calls
from exo.worker.runner.llm_inference.runner import parse_tool_calls
from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser