Compare commits

...

1 Commits

Author SHA1 Message Date
Evan Quiney
6b907398a4 cancel downloads for deleted instances (#1393)
after deleting an instance, if a given (node_id, model_id) pair doesn't exist in the left over instances, cancel the download of model_id on node_id.
2026-02-05 18:16:43 +00:00
8 changed files with 67 additions and 7 deletions

View File

@@ -20,7 +20,7 @@ sync-clean:
rust-rebuild:
cargo run --bin stub_gen
just sync-clean
uv sync --reinstall-package exo_pyo3_bindings
build-dashboard:
#!/usr/bin/env bash

View File

@@ -16,6 +16,7 @@ from exo.download.download_utils import (
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
@@ -107,6 +108,13 @@ class DownloadCoordinator:
await self._start_download(shard)
case DeleteDownload(model_id=model_id):
await self._delete_download(model_id)
case CancelDownload(model_id=model_id):
await self._cancel_download(model_id)
async def _cancel_download(self, model_id: ModelId) -> None:
if model_id in self.active_downloads and model_id in self.download_status:
logger.info(f"Cancelling download for {model_id}")
self.active_downloads.pop(model_id).cancel()
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id

View File

@@ -105,6 +105,7 @@ class Node:
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
)
er_send, er_recv = channel[ElectionResult]()
@@ -188,6 +189,9 @@ class Node:
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
)
self._tg.start_soon(self.master.run)
elif (

View File

@@ -6,6 +6,7 @@ from loguru import logger
from exo.master.placement import (
add_instance_to_placements,
cancel_unnecessary_downloads,
delete_instance,
get_transition_events,
place_instance,
@@ -16,6 +17,7 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
@@ -66,12 +68,9 @@ class Master:
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
# Receiving indexed events from the forwarder to be applied to state
# Ideally these would be WorkerForwarderEvents but type system says no :(
local_event_receiver: Receiver[ForwarderEvent],
# Send events to the forwarder to be indexed (usually from command processing)
# Ideally these would be MasterForwarderEvents but type system says no :(
global_event_sender: Sender[ForwarderEvent],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.state = State()
self._tg: TaskGroup = anyio.create_task_group()
@@ -81,6 +80,7 @@ class Master:
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
@@ -280,6 +280,14 @@ class Master:
transition_events = get_transition_events(
self.state.instances, placement
)
for cmd in cancel_unnecessary_downloads(
placement, self.state.downloads
):
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id, command=cmd
)
)
generated_events.extend(transition_events)
case PlaceInstance():
placement = place_instance(

View File

@@ -15,14 +15,20 @@ from exo.master.placement_utils import (
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CancelDownload,
CreateInstance,
DeleteInstance,
DownloadCommand,
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -202,3 +208,29 @@ def get_transition_events(
)
return events
def cancel_unnecessary_downloads(
instances: Mapping[InstanceId, Instance],
download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> Sequence[DownloadCommand]:
commands: list[DownloadCommand] = []
currently_downloading = [
(k, v.shard_metadata.model_card.model_id)
for k, vs in download_status.items()
for v in vs
if isinstance(v, (DownloadOngoing))
]
active_models = set(
(
node_id,
instance.shard_assignments.runner_to_shard[runner_id].model_card.model_id,
)
for instance in instances.values()
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
)
for pair in currently_downloading:
if pair not in active_models:
commands.append(CancelDownload(target_node_id=pair[0], model_id=pair[1]))
return commands

View File

@@ -11,6 +11,7 @@ from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
CommandId,
ForwarderCommand,
ForwarderDownloadCommand,
PlaceInstance,
TextGeneration,
)
@@ -47,6 +48,7 @@ async def test_master():
ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[ForwarderEvent]()
fcds, _fcdr = channel[ForwarderDownloadCommand]()
all_events: list[IndexedEvent] = []
@@ -67,6 +69,7 @@ async def test_master():
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,
download_command_sender=fcds,
)
logger.info("run the master")
async with anyio.create_task_group() as tg:

View File

@@ -72,7 +72,12 @@ class DeleteDownload(BaseCommand):
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload
class CancelDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
Command = (

View File

@@ -35,7 +35,7 @@ i=0
for host; do
colour=${colours[i++ % 4]}
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
done