Compare commits


1 Commit

Author: Alex Cheema
SHA1: 526cd9f333
Date: 2026-02-17 10:13:24 -08:00

fix partial download progress showing 0% on restart

On restart, _emit_existing_download_progress() checked
downloaded_bytes_this_session to decide if a download was pending.
Since this field is always 0 in a new session, partially downloaded
models were reported as DownloadPending (0%) instead of DownloadOngoing
with their actual progress. Check downloaded_bytes (actual data on
disk) instead.

Closes #1042

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
13 changed files with 56 additions and 66 deletions
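For context, here is a minimal sketch of the status decision the commit message describes. The Progress and classify names and the plain-string statuses are hypothetical stand-ins, not exo's actual DownloadCoordinator types; the real change is in the first hunk below.

from dataclasses import dataclass

@dataclass
class Progress:
    total_bytes: int
    downloaded_bytes: int               # bytes actually present on disk
    downloaded_bytes_this_session: int  # always 0 right after a restart

def classify(progress: Progress) -> str:
    # Buggy behaviour: this check used downloaded_bytes_this_session, which
    # is always 0 in a fresh session, so a partial download looked pending (0%).
    if progress.downloaded_bytes == 0:
        return "DownloadPending"
    return "DownloadOngoing"

# A model with 40 of 100 bytes already on disk, restarted (0 bytes this
# session), is now reported as ongoing rather than pending.
print(classify(Progress(total_bytes=100, downloaded_bytes=40, downloaded_bytes_this_session=0)))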

View File

@@ -324,7 +324,7 @@ class DownloadCoordinator:
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
- elif progress.downloaded_bytes_this_session.in_bytes == 0:
+ elif progress.downloaded_bytes.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id,
shard_metadata=progress.shard,

View File

@@ -603,10 +603,10 @@ class API:
break
except anyio.get_cancelled_exc_class():
- command = TaskCancelled(cancelled_command_id=command_id)
+ cancel_command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
- ForwarderCommand(origin=self.node_id, command=command)
+ ForwarderCommand(origin=self.node_id, command=cancel_command)
)
raise
finally:
@@ -946,10 +946,10 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
- command = TaskCancelled(cancelled_command_id=command_id)
+ cancel_command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
- ForwarderCommand(origin=self.node_id, command=command)
+ ForwarderCommand(origin=self.node_id, command=cancel_command)
)
raise
finally:
@@ -1032,10 +1032,10 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
- command = TaskCancelled(cancelled_command_id=command_id)
+ cancel_command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
- ForwarderCommand(origin=self.node_id, command=command)
+ ForwarderCommand(origin=self.node_id, command=cancel_command)
)
raise
finally:

View File

@@ -417,16 +417,19 @@ class Master:
)
case TaskCancelled():
if (
- task_id := self.command_task_mapping.get(
- command.cancelled_command_id
- )
- ) is not None:
+ command.cancelled_command_id
+ in self.command_task_mapping
+ ):
generated_events.append(
- TaskStatusUpdated(
- task_status=TaskStatus.Cancelled,
- task_id=task_id,
+ TaskDeleted(
+ task_id=self.command_task_mapping[
+ command.cancelled_command_id
+ ]
)
)
+ del self.command_task_mapping[
+ command.cancelled_command_id
+ ]
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -435,9 +438,10 @@ class Master:
]
)
)
- self.command_task_mapping.pop(
- command.finished_command_id, None
- )
+ if command.finished_command_id in self.command_task_mapping:
+ del self.command_task_mapping[
+ command.finished_command_id
+ ]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
# rate limit to 1000 at a time

View File

@@ -200,7 +200,7 @@ def try_place_for_meta_instance(
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
- tasks: Mapping[TaskId, Task] | None = None,
+ tasks: Mapping[TaskId, Task],
) -> PlacementResult:
"""Try to place an instance satisfying the meta-instance constraints.
@@ -233,7 +233,7 @@ def try_place_for_meta_instance(
)
return PlacementResult(
events=list(
- get_transition_events(current_instances, target_instances, tasks or {})
+ get_transition_events(current_instances, target_instances, tasks)
),
error=None,
)

View File

@@ -105,7 +105,6 @@ Command = (
| TaskCancelled
| CreateMetaInstance
| DeleteMetaInstance
- | TaskCancelled
| TaskFinished
| SendInputChunk
)

View File

@@ -61,7 +61,7 @@ class TextGeneration(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
- class CancelTask(BaseTask):
+ class CancelTask(BaseTask): # emitted by Worker when master cancels a task
cancelled_task_id: TaskId
runner_id: RunnerId

View File

@@ -125,9 +125,7 @@ class MpSender[T]:
self._state.buffer.put(item, block=True)
async def send_async(self, item: T) -> None:
- await to_thread.run_sync(
- self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True
- )
+ await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1))
def close(self) -> None:
if not self._state.closed.is_set():

View File

@@ -1,11 +1,10 @@
- import contextlib
from collections import defaultdict
from datetime import datetime, timezone
from random import random
from typing import Iterator
import anyio
- from anyio import CancelScope, ClosedResourceError, create_task_group, fail_after
+ from anyio import CancelScope, create_task_group, fail_after
from anyio.abc import TaskGroup
from loguru import logger
@@ -35,7 +34,6 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
- CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -118,8 +116,7 @@ class Worker:
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
- with contextlib.suppress(ClosedResourceError):
- runner.shutdown()
+ runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
@@ -237,23 +234,15 @@ class Worker:
)
)
case Shutdown(runner_id=runner_id):
- runner = self.runners.pop(runner_id)
try:
with fail_after(3):
- await runner.start_task(task)
+ await self.runners.pop(runner_id).start_task(task)
except TimeoutError:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
- finally:
- with contextlib.suppress(ClosedResourceError):
- runner.shutdown()
- case CancelTask(
- cancelled_task_id=cancelled_task_id, runner_id=runner_id
- ):
- await self.runners[runner_id].cancel_task(cancelled_task_id)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -291,18 +280,18 @@ class Worker:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
- await self._start_runner_task(modified_task)
+ await self.runners[self._task_to_runner_id(task)].start_task(
+ modified_task
+ )
case task:
- await self._start_runner_task(task)
+ await self.runners[self._task_to_runner_id(task)].start_task(task)
def shutdown(self):
self._tg.cancel_scope.cancel()
- async def _start_runner_task(self, task: Task):
- if (instance := self.state.instances.get(task.instance_id)) is not None:
- await self.runners[
- instance.shard_assignments.node_to_runner[self.node_id]
- ].start_task(task)
+ def _task_to_runner_id(self, task: Task):
+ instance = self.state.instances[task.instance_id]
+ return instance.shard_assignments.node_to_runner[self.node_id]
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.

View File

@@ -328,7 +328,8 @@ def _pending_tasks(
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
- ) -> Task | None:
+ ) -> CancelTask | None:
+ """Find a cancelled task that hasn't been sent to the runner yet."""
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue

View File

@@ -67,7 +67,9 @@ def entrypoint(
try:
event_sender.close()
task_receiver.close()
+ cancel_receiver.close()
finally:
event_sender.join()
task_receiver.join()
+ cancel_receiver.join()
logger.info("bye from the runner")

View File

@@ -243,7 +243,7 @@ def main(
assert inference_model
assert tokenizer
- t = time.monotonic()
+ t = time.perf_counter()
toks = warmup_inference(
model=inference_model,
tokenizer=tokenizer,
@@ -251,7 +251,7 @@ def main(
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
- math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
+ math.ceil(toks / max(time.perf_counter() - t, 0.001)), 100
)
if group is not None:
check_for_cancel_every = int(

View File

@@ -72,8 +72,8 @@ class RunnerSupervisor:
initialize_timeout: float
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
- _event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
+ _event_sender: Sender[Event]
_pipe_read_fd: int | None = None # Python reads runner's pipe output
_pipe_write_fd: int | None = None # Python writes gathered data to runner
_child_pipe_fds: tuple[int, int] | None = None # fds to close after fork
@@ -185,11 +185,8 @@ class RunnerSupervisor:
logger.info("Runner supervisor shutting down")
self._ev_recv.close()
self._task_sender.close()
- try:
- self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
- self._cancel_sender.close()
- except ClosedResourceError:
- pass
+ self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
+ self._cancel_sender.close()
self._event_sender.close()
self._close_pipe_fds()
self.runner_process.join(1)

View File

@@ -1,9 +1,7 @@
# Check tasks are complete before runner is ever ready.
- import unittest.mock
from collections.abc import Iterable
from typing import Callable
- import mlx.core as mx
import pytest
import exo.worker.runner.runner as mlx_runner
@@ -117,6 +115,12 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
+ # Mock mx.distributed.all_gather so MockGroup doesn't hit real MLX C++ bindings.
+ def _mock_all_gather(x: object, **_kw: object) -> object:
+ return x
+ monkeypatch.setattr(mlx_runner.mx.distributed, "all_gather", _mock_all_gather)
+ # Mock apply_chat_template since we're using a fake tokenizer (integer 1).
+ # Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
@@ -181,16 +185,12 @@ def _run(tasks: Iterable[Task]):
cancel_receiver.close = nothin
cancel_receiver.join = nothin
- with unittest.mock.patch(
- "exo.worker.runner.runner.mx.distributed.all_gather",
- make_nothin(mx.array([1])),
- ):
- mlx_runner.main(
- bound_instance,
- event_sender, # pyright: ignore[reportArgumentType]
- task_receiver,
- cancel_receiver,
- )
+ mlx_runner.main(
+ bound_instance,
+ event_sender, # pyright: ignore[reportArgumentType]
+ task_receiver,
+ cancel_receiver,
+ )
return event_sender.events