Compare commits

...

5 Commits

Author SHA1 Message Date
ciaranbor
dd7064d022 Set max num_sync_steps to 100 2026-02-05 18:35:42 +00:00
ciaranbor
817659bab2 Expose num_sync_steps in advanced params 2026-02-05 18:18:58 +00:00
ciaranbor
8f2435129d Configure num_sync_steps directly 2026-02-05 18:03:03 +00:00
Evan Quiney
572e647908 better cancellation (#1388)
a lot of our cleanup logic wasn't running leading to bad shutdown states

## changes
- added `try: except` blocks around most task groups
- made the runner shutdown code synchronous
- abandon the MpReceiver's recv_async thread on cancellation
- this only occurs during runner shutdown, the queue closing from the
other end should terminate the mp.Queue, cleaning up the thread in its
own time. i could try other methods if this is not sufficient.

## outcome
ctrl-c just works now! minus the tokio panic of course :) no more
hypercorn lifespan errors though!
2026-02-05 15:22:33 +00:00
Evan Quiney
e59ebd986d set exo as the nix default package (#1391)
!!!
2026-02-05 15:15:52 +00:00
18 changed files with 197 additions and 137 deletions

View File

@@ -148,6 +148,15 @@
setImageGenerationParams({ guidance: null });
}
function handleNumSyncStepsChange(event: Event) {
const value = parseInt((event.target as HTMLInputElement).value, 10);
setImageGenerationParams({ numSyncSteps: value });
}
function clearNumSyncSteps() {
setImageGenerationParams({ numSyncSteps: null });
}
function handleReset() {
resetImageGenerationParams();
showAdvanced = false;
@@ -157,7 +166,8 @@
params.seed !== null ||
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null && params.negativePrompt.trim() !== ""),
(params.negativePrompt !== null && params.negativePrompt.trim() !== "") ||
params.numSyncSteps !== null,
);
</script>
@@ -578,7 +588,50 @@
</div>
</div>
<!-- Row 3: Negative Prompt -->
<!-- Row 3: Sync Steps -->
<div class="flex items-center gap-1.5">
<span
class="text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap"
>SYNC STEPS:</span
>
<div class="flex items-center gap-2 flex-1 max-w-xs">
<input
type="range"
min="1"
max="100"
value={params.numSyncSteps ?? 1}
oninput={handleNumSyncStepsChange}
class="flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow"
/>
<span class="text-xs font-mono text-exo-yellow w-8 text-right">
{params.numSyncSteps ?? "--"}
</span>
{#if params.numSyncSteps !== null}
<button
type="button"
onclick={clearNumSyncSteps}
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
title="Clear"
>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
{/if}
</div>
</div>
<!-- Row 4: Negative Prompt -->
<div class="flex flex-col gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>NEGATIVE PROMPT:</span

View File

@@ -298,6 +298,7 @@ export interface ImageGenerationParams {
numInferenceSteps: number | null;
guidance: number | null;
negativePrompt: string | null;
numSyncSteps: number | null;
// Edit mode params
inputFidelity: "low" | "high";
}
@@ -319,6 +320,7 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
numInferenceSteps: null,
guidance: null,
negativePrompt: null,
numSyncSteps: null,
inputFidelity: "low",
};
@@ -2396,7 +2398,9 @@ class AppStore {
params.seed !== null ||
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "");
(params.negativePrompt !== null &&
params.negativePrompt.trim() !== "") ||
params.numSyncSteps !== null;
const requestBody: Record<string, unknown> = {
model,
@@ -2421,6 +2425,9 @@ class AppStore {
params.negativePrompt.trim() !== "" && {
negative_prompt: params.negativePrompt,
}),
...(params.numSyncSteps !== null && {
num_sync_steps: params.numSyncSteps,
}),
};
}
@@ -2670,29 +2677,19 @@ class AppStore {
formData.append("input_fidelity", params.inputFidelity);
// Advanced params
if (params.seed !== null) {
formData.append(
"advanced_params",
JSON.stringify({
seed: params.seed,
...(params.numInferenceSteps !== null && {
num_inference_steps: params.numInferenceSteps,
}),
...(params.guidance !== null && { guidance: params.guidance }),
...(params.negativePrompt !== null &&
params.negativePrompt.trim() !== "" && {
negative_prompt: params.negativePrompt,
}),
}),
);
} else if (
const hasAdvancedParams =
params.seed !== null ||
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "")
) {
(params.negativePrompt !== null &&
params.negativePrompt.trim() !== "") ||
params.numSyncSteps !== null;
if (hasAdvancedParams) {
formData.append(
"advanced_params",
JSON.stringify({
...(params.seed !== null && { seed: params.seed }),
...(params.numInferenceSteps !== null && {
num_inference_steps: params.numInferenceSteps,
}),
@@ -2701,6 +2698,9 @@ class AppStore {
params.negativePrompt.trim() !== "" && {
negative_prompt: params.negativePrompt,
}),
...(params.numSyncSteps !== null && {
num_sync_steps: params.numSyncSteps,
}),
}),
);
}

View File

@@ -118,9 +118,10 @@
{
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
mlx = pkgs.callPackage ./nix/mlx.nix {
metal-toolchain = self'.packages.metal-toolchain;
inherit (self'.packages) metal-toolchain;
inherit uvLockMlxVersion;
};
default = self'.packages.exo;
}
);

View File

@@ -53,11 +53,10 @@ class DownloadCoordinator:
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
self._tg = anyio.create_task_group()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")

View File

@@ -27,7 +27,6 @@ from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.main import Worker
# I marked this as a dataclass as I want trivial constructors.
@dataclass
class Node:
router: Router
@@ -136,7 +135,6 @@ class Node:
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
@@ -148,6 +146,8 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
def shutdown(self):
# if this is our second call to shutdown, just sys.exit

View File

@@ -1320,29 +1320,40 @@ class API:
]
async def run(self):
shutdown_ev = anyio.Event()
try:
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
tg.start_soon(self.run_api, shutdown_ev)
try:
await anyio.sleep_forever()
finally:
with anyio.CancelScope(shield=True):
shutdown_ev.set()
finally:
self.command_sender.close()
self.global_event_receiver.close()
async def run_api(self, ev: anyio.Event):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"
cfg.bind = [f"0.0.0.0:{self.port}"]
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = None
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
with anyio.CancelScope(shield=True):
await serve(
cast(ASGIFramework, self.app),
cfg,
shutdown_trigger=lambda: anyio.sleep_forever(),
shutdown_trigger=ev.wait,
)
self.command_sender.close()
self.global_event_receiver.close()
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:

View File

@@ -96,16 +96,18 @@ class Master:
async def run(self):
logger.info("Starting Master")
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
try:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
finally:
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")

View File

@@ -9,6 +9,7 @@ from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
move_on_after,
sleep_forever,
)
from anyio.abc import TaskGroup
@@ -146,18 +147,21 @@ class Router:
async def run(self):
logger.debug("Starting Router")
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
try:
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
finally:
with move_on_after(1, shield=True):
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
async def shutdown(self):
logger.debug("Shutting down Router")
@@ -166,12 +170,12 @@ class Router:
self._tg.cancel_scope.cancel()
async def _networking_subscribe(self, topic: str):
logger.info(f"Subscribing to {topic}")
await self._net.gossipsub_subscribe(topic)
logger.info(f"Subscribed to {topic}")
async def _networking_unsubscribe(self, topic: str):
logger.info(f"Unsubscribing from {topic}")
await self._net.gossipsub_unsubscribe(topic)
logger.info(f"Unsubscribed from {topic}")
async def _networking_recv(self):
while True:

View File

@@ -86,28 +86,29 @@ class Election:
async def run(self):
logger.info("Starting Election")
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
try:
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
# And start an election immediately, that instantly resolves
candidates: list[ElectionMessage] = []
logger.debug("Starting initial campaign")
self._candidates = candidates
await self._campaign(candidates, campaign_timeout=0.0)
logger.debug("Initial campaign finished")
# Cancel and wait for the last election to end
if self._campaign_cancel_scope is not None:
logger.debug("Cancelling campaign")
self._campaign_cancel_scope.cancel()
if self._campaign_done is not None:
logger.debug("Waiting for campaign to finish")
await self._campaign_done.wait()
logger.debug("Campaign cancelled and finished")
logger.info("Election finished")
# And start an election immediately, that instantly resolves
candidates: list[ElectionMessage] = []
logger.debug("Starting initial campaign")
self._candidates = candidates
await self._campaign(candidates, campaign_timeout=0.0)
logger.debug("Initial campaign finished")
finally:
# Cancel and wait for the last election to end
if self._campaign_cancel_scope is not None:
logger.debug("Cancelling campaign")
self._campaign_cancel_scope.cancel()
if self._campaign_done is not None:
logger.debug("Waiting for campaign to finish")
await self._campaign_done.wait()
logger.debug("Campaign cancelled and finished")
logger.info("Election shutdown")
async def elect(self, em: ElectionMessage) -> None:
logger.debug(f"Electing: {em}")

View File

@@ -272,6 +272,7 @@ class AdvancedImageParams(BaseModel):
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
guidance: Annotated[float, Field(ge=1.0, le=20.0)] | None = None
negative_prompt: str | None = None
num_sync_steps: Annotated[int, Field(ge=1, le=100)] | None = None
class ImageGenerationTaskParams(BaseModel):

View File

@@ -194,9 +194,10 @@ class MpReceiver[T]:
raise EndOfStream from None
return item
# nb: this function will not cancel particularly well
async def receive_async(self) -> T:
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
return await to_thread.run_sync(
self.receive, limiter=CapacityLimiter(1), abandon_on_cancel=True
)
def close(self) -> None:
if not self._state.closed.is_set():

View File

@@ -1,5 +1,4 @@
from enum import Enum
from math import ceil
from pydantic import BaseModel
@@ -23,7 +22,7 @@ class ImageModelConfig(BaseModel):
block_configs: tuple[TransformerBlockConfig, ...]
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
num_sync_steps_factor: float # Fraction of steps for sync phase
num_sync_steps: int # Number of sync steps for distributed inference
guidance_scale: float | None = None # None or <= 1.0 disables CFG
@@ -45,6 +44,3 @@ class ImageModelConfig(BaseModel):
def get_steps_for_quality(self, quality: str) -> int:
return self.default_steps[quality]
def get_num_sync_steps(self, steps: int) -> int:
return ceil(steps * self.num_sync_steps_factor)

View File

@@ -150,7 +150,10 @@ class DistributedImageModel:
guidance=guidance_override if guidance_override is not None else 4.0,
)
num_sync_steps = self._config.get_num_sync_steps(steps)
if advanced_params is not None and advanced_params.num_sync_steps is not None:
num_sync_steps = advanced_params.num_sync_steps
else:
num_sync_steps = self._config.num_sync_steps
for result in self._runner.generate_image(
runtime_config=config,

View File

@@ -15,7 +15,7 @@ FLUX_SCHNELL_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 1, "medium": 2, "high": 4},
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
num_sync_steps=1,
)
@@ -30,5 +30,5 @@ FLUX_DEV_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
num_sync_steps=4,
)

View File

@@ -12,7 +12,7 @@ QWEN_IMAGE_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.25,
num_sync_steps=7,
guidance_scale=3.5, # Set to None or < 1.0 to disable CFG
)
@@ -24,6 +24,6 @@ QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.25,
num_sync_steps=7,
guidance_scale=3.5,
)

View File

@@ -98,21 +98,23 @@ class Worker:
info_send, info_recv = channel[GatheredInfo]()
info_gatherer: InfoGatherer = InfoGatherer(info_send)
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
try:
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
finally:
# Actual shutdown code - waits for all tasks to complete before executing.
logger.info("Stopping Worker")
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:

View File

@@ -8,10 +8,8 @@ import anyio
from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
to_thread,
)
from anyio.abc import TaskGroup
from loguru import logger
from exo.shared.types.events import (
@@ -49,7 +47,6 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
@@ -93,28 +90,29 @@ class RunnerSupervisor:
async def run(self):
self.runner_process.start()
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._forward_events)
await self._forward_events()
def shutdown(self):
logger.info("Runner supervisor shutting down")
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
await to_thread.run_sync(self.runner_process.join, 30)
self.runner_process.join(1)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
return
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
self.runner_process.join(1)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
await to_thread.run_sync(self.runner_process.join, 5)
self.runner_process.join(1)
if not self.runner_process.is_alive():
return
@@ -122,10 +120,6 @@ class RunnerSupervisor:
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
def shutdown(self):
assert self._tg
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.pending:
logger.warning(

View File

@@ -22,7 +22,7 @@ echo "Deploying $commit to $# hosts..."
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
done
wait
jobs -pr | xargs -r kill 2>/dev/null || true
@@ -34,21 +34,13 @@ reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
{
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
} <<'EOF'
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q "$1"
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
EOF
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
done
for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52415/models"; do sleep 1; done
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done
wait