Compare commits

..

1 Commits

Author SHA1 Message Date
Evan Quiney
6b907398a4 cancel downloads for deleted instances (#1393)
after deleting an instance, if a given (node_id, model_id) pair doesn't exist in the left over instances, cancel the download of model_id on node_id.
2026-02-05 18:16:43 +00:00
10 changed files with 72 additions and 15 deletions

View File

@@ -20,7 +20,7 @@ sync-clean:
rust-rebuild:
cargo run --bin stub_gen
just sync-clean
uv sync --reinstall-package exo_pyo3_bindings
build-dashboard:
#!/usr/bin/env bash

View File

@@ -19,7 +19,7 @@ dependencies = [
"anyio==4.11.0",
"mlx==0.30.4; sys_platform == 'darwin'",
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
"mlx-lm==0.30.6",
"mlx-lm",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -63,6 +63,7 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }

View File

@@ -16,6 +16,7 @@ from exo.download.download_utils import (
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
@@ -107,6 +108,13 @@ class DownloadCoordinator:
await self._start_download(shard)
case DeleteDownload(model_id=model_id):
await self._delete_download(model_id)
case CancelDownload(model_id=model_id):
await self._cancel_download(model_id)
async def _cancel_download(self, model_id: ModelId) -> None:
if model_id in self.active_downloads and model_id in self.download_status:
logger.info(f"Cancelling download for {model_id}")
self.active_downloads.pop(model_id).cancel()
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id

View File

@@ -105,6 +105,7 @@ class Node:
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
)
er_send, er_recv = channel[ElectionResult]()
@@ -188,6 +189,9 @@ class Node:
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
)
self._tg.start_soon(self.master.run)
elif (

View File

@@ -6,6 +6,7 @@ from loguru import logger
from exo.master.placement import (
add_instance_to_placements,
cancel_unnecessary_downloads,
delete_instance,
get_transition_events,
place_instance,
@@ -16,6 +17,7 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
@@ -66,12 +68,9 @@ class Master:
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
# Receiving indexed events from the forwarder to be applied to state
# Ideally these would be WorkerForwarderEvents but type system says no :(
local_event_receiver: Receiver[ForwarderEvent],
# Send events to the forwarder to be indexed (usually from command processing)
# Ideally these would be MasterForwarderEvents but type system says no :(
global_event_sender: Sender[ForwarderEvent],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.state = State()
self._tg: TaskGroup = anyio.create_task_group()
@@ -81,6 +80,7 @@ class Master:
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
@@ -280,6 +280,14 @@ class Master:
transition_events = get_transition_events(
self.state.instances, placement
)
for cmd in cancel_unnecessary_downloads(
placement, self.state.downloads
):
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id, command=cmd
)
)
generated_events.extend(transition_events)
case PlaceInstance():
placement = place_instance(

View File

@@ -15,14 +15,20 @@ from exo.master.placement_utils import (
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CancelDownload,
CreateInstance,
DeleteInstance,
DownloadCommand,
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -202,3 +208,29 @@ def get_transition_events(
)
return events
def cancel_unnecessary_downloads(
instances: Mapping[InstanceId, Instance],
download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> Sequence[DownloadCommand]:
commands: list[DownloadCommand] = []
currently_downloading = [
(k, v.shard_metadata.model_card.model_id)
for k, vs in download_status.items()
for v in vs
if isinstance(v, (DownloadOngoing))
]
active_models = set(
(
node_id,
instance.shard_assignments.runner_to_shard[runner_id].model_card.model_id,
)
for instance in instances.values()
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
)
for pair in currently_downloading:
if pair not in active_models:
commands.append(CancelDownload(target_node_id=pair[0], model_id=pair[1]))
return commands

View File

@@ -11,6 +11,7 @@ from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
CommandId,
ForwarderCommand,
ForwarderDownloadCommand,
PlaceInstance,
TextGeneration,
)
@@ -47,6 +48,7 @@ async def test_master():
ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[ForwarderEvent]()
fcds, _fcdr = channel[ForwarderDownloadCommand]()
all_events: list[IndexedEvent] = []
@@ -67,6 +69,7 @@ async def test_master():
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,
download_command_sender=fcds,
)
logger.info("run the master")
async with anyio.create_task_group() as tg:

View File

@@ -72,7 +72,12 @@ class DeleteDownload(BaseCommand):
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload
class CancelDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
Command = (

View File

@@ -35,7 +35,7 @@ i=0
for host; do
colour=${colours[i++ % 4]}
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
done

10
uv.lock generated
View File

@@ -415,7 +415,7 @@ requires-dist = [
{ name = "mflux", specifier = "==0.15.4" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.4" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.4" },
{ name = "mlx-lm", specifier = "==0.30.6" },
{ name = "mlx-lm", git = "https://github.com/ml-explore/mlx-lm?branch=main" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "pillow", specifier = ">=11.0,<12.0" },
{ name = "psutil", specifier = ">=7.0.0" },
@@ -1072,8 +1072,8 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
version = "0.30.5"
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#96699e6dadb13b82b28285bb131a0741997d19ae" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
@@ -1083,10 +1083,6 @@ dependencies = [
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/76/cb/815deddc8699b1f694d7e1f9cbed52934c03a8b49432c8add72932bb2f0b/mlx_lm-0.30.6.tar.gz", hash = "sha256:807e042d7040268f1b19190b7eaefd8b2efbff5590a65460974ad4225b91dda1", size = 271733 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451 },
]
[[package]]
name = "mlx-metal"