mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-07 04:32:28 -05:00
Compare commits
4 Commits
rust-explo
...
only-time-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b694552402 | ||
|
|
9f502793c1 | ||
|
|
c8371349d5 | ||
|
|
6b907398a4 |
2
justfile
2
justfile
@@ -20,7 +20,7 @@ sync-clean:
|
||||
|
||||
rust-rebuild:
|
||||
cargo run --bin stub_gen
|
||||
just sync-clean
|
||||
uv sync --reinstall-package exo_pyo3_bindings
|
||||
|
||||
build-dashboard:
|
||||
#!/usr/bin/env bash
|
||||
|
||||
@@ -31,8 +31,6 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
exo-master = "exo.master.main:main"
|
||||
exo-worker = "exo.worker.main:main"
|
||||
exo = "exo.main:main"
|
||||
|
||||
# dependencies only required for development
|
||||
|
||||
@@ -59,6 +59,16 @@
|
||||
}
|
||||
);
|
||||
|
||||
mkPythonScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
runtimeInputs = [ exoVenv ];
|
||||
runtimeEnv = {
|
||||
EXO_DASHBOARD_DIR = self'.packages.dashboard;
|
||||
EXO_RESOURCES_DIR = inputs.self + /resources;
|
||||
};
|
||||
text = ''python ${path}'';
|
||||
};
|
||||
|
||||
exoPackage = pkgs.runCommand "exo"
|
||||
{
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
@@ -66,13 +76,11 @@
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
|
||||
# Create wrapper scripts
|
||||
for script in exo exo-master exo-worker; do
|
||||
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
|
||||
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
|
||||
--set EXO_RESOURCES_DIR ${inputs.self + "/resources"} \
|
||||
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
|
||||
done
|
||||
# Create wrapper script
|
||||
makeWrapper ${exoVenv}/bin/exo $out/bin/exo \
|
||||
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
|
||||
--set EXO_RESOURCES_DIR ${inputs.self + /resources} \
|
||||
${lib.optionalString pkgs.stdenv.hostPlatform.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
|
||||
'';
|
||||
in
|
||||
{
|
||||
@@ -81,13 +89,15 @@
|
||||
exo = exoPackage;
|
||||
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
|
||||
exo-test-env = testVenv;
|
||||
exo-bench = mkPythonScript "exo-bench" (inputs.self + /bench/exo_bench.py);
|
||||
exo-distributed-test = mkPythonScript "exo-distributed-test" (inputs.self + /tests/headless_runner.py);
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Ruff linting (works on all platforms)
|
||||
lint = pkgs.runCommand "ruff-lint" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
|
||||
@@ -16,6 +16,7 @@ from exo.download.download_utils import (
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
DeleteDownload,
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
@@ -107,6 +108,13 @@ class DownloadCoordinator:
|
||||
await self._start_download(shard)
|
||||
case DeleteDownload(model_id=model_id):
|
||||
await self._delete_download(model_id)
|
||||
case CancelDownload(model_id=model_id):
|
||||
await self._cancel_download(model_id)
|
||||
|
||||
async def _cancel_download(self, model_id: ModelId) -> None:
|
||||
if model_id in self.active_downloads and model_id in self.download_status:
|
||||
logger.info(f"Cancelling download for {model_id}")
|
||||
self.active_downloads.pop(model_id).cancel()
|
||||
|
||||
async def _start_download(self, shard: ShardMetadata) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
@@ -378,10 +378,14 @@ async def download_file_with_retry(
|
||||
logger.error(traceback.format_exc())
|
||||
await asyncio.sleep(2.0**attempt)
|
||||
except Exception as e:
|
||||
on_connection_lost()
|
||||
if attempt == n_attempts - 1:
|
||||
on_connection_lost()
|
||||
raise e
|
||||
break
|
||||
logger.error(
|
||||
f"Download error on attempt {attempt + 1}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
await asyncio.sleep(2.0**attempt)
|
||||
raise Exception(
|
||||
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
|
||||
@@ -105,6 +105,7 @@ class Node:
|
||||
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=router.receiver(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
)
|
||||
|
||||
er_send, er_recv = channel[ElectionResult]()
|
||||
@@ -188,6 +189,9 @@ class Node:
|
||||
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=self.router.receiver(topics.COMMANDS),
|
||||
download_command_sender=self.router.sender(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
)
|
||||
self._tg.start_soon(self.master.run)
|
||||
elif (
|
||||
|
||||
@@ -6,6 +6,7 @@ from loguru import logger
|
||||
|
||||
from exo.master.placement import (
|
||||
add_instance_to_placements,
|
||||
cancel_unnecessary_downloads,
|
||||
delete_instance,
|
||||
get_transition_events,
|
||||
place_instance,
|
||||
@@ -16,6 +17,7 @@ from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
@@ -66,12 +68,9 @@ class Master:
|
||||
session_id: SessionId,
|
||||
*,
|
||||
command_receiver: Receiver[ForwarderCommand],
|
||||
# Receiving indexed events from the forwarder to be applied to state
|
||||
# Ideally these would be WorkerForwarderEvents but type system says no :(
|
||||
local_event_receiver: Receiver[ForwarderEvent],
|
||||
# Send events to the forwarder to be indexed (usually from command processing)
|
||||
# Ideally these would be MasterForwarderEvents but type system says no :(
|
||||
global_event_sender: Sender[ForwarderEvent],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
):
|
||||
self.state = State()
|
||||
self._tg: TaskGroup = anyio.create_task_group()
|
||||
@@ -81,6 +80,7 @@ class Master:
|
||||
self.command_receiver = command_receiver
|
||||
self.local_event_receiver = local_event_receiver
|
||||
self.global_event_sender = global_event_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
send, recv = channel[Event]()
|
||||
self.event_sender: Sender[Event] = send
|
||||
self._loopback_event_receiver: Receiver[Event] = recv
|
||||
@@ -280,6 +280,14 @@ class Master:
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
)
|
||||
for cmd in cancel_unnecessary_downloads(
|
||||
placement, self.state.downloads
|
||||
):
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id, command=cmd
|
||||
)
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case PlaceInstance():
|
||||
placement = place_instance(
|
||||
|
||||
@@ -15,14 +15,20 @@ from exo.master.placement_utils import (
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
DownloadCommand,
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
@@ -202,3 +208,29 @@ def get_transition_events(
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def cancel_unnecessary_downloads(
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> Sequence[DownloadCommand]:
|
||||
commands: list[DownloadCommand] = []
|
||||
currently_downloading = [
|
||||
(k, v.shard_metadata.model_card.model_id)
|
||||
for k, vs in download_status.items()
|
||||
for v in vs
|
||||
if isinstance(v, (DownloadOngoing))
|
||||
]
|
||||
active_models = set(
|
||||
(
|
||||
node_id,
|
||||
instance.shard_assignments.runner_to_shard[runner_id].model_card.model_id,
|
||||
)
|
||||
for instance in instances.values()
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
|
||||
)
|
||||
for pair in currently_downloading:
|
||||
if pair not in active_models:
|
||||
commands.append(CancelDownload(target_node_id=pair[0], model_id=pair[1]))
|
||||
|
||||
return commands
|
||||
|
||||
@@ -11,6 +11,7 @@ from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.commands import (
|
||||
CommandId,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
PlaceInstance,
|
||||
TextGeneration,
|
||||
)
|
||||
@@ -47,6 +48,7 @@ async def test_master():
|
||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||
command_sender, co_receiver = channel[ForwarderCommand]()
|
||||
local_event_sender, le_receiver = channel[ForwarderEvent]()
|
||||
fcds, _fcdr = channel[ForwarderDownloadCommand]()
|
||||
|
||||
all_events: list[IndexedEvent] = []
|
||||
|
||||
@@ -67,6 +69,7 @@ async def test_master():
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=co_receiver,
|
||||
download_command_sender=fcds,
|
||||
)
|
||||
logger.info("run the master")
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
@@ -208,58 +208,19 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
|
||||
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
topology = copy.deepcopy(state.topology)
|
||||
topology.remove_node(event.node_id)
|
||||
last_seen = {
|
||||
key: value for key, value in state.last_seen.items() if key != event.node_id
|
||||
}
|
||||
downloads = {
|
||||
key: value for key, value in state.downloads.items() if key != event.node_id
|
||||
}
|
||||
# Clean up all granular node mappings
|
||||
node_identities = {
|
||||
key: value
|
||||
for key, value in state.node_identities.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
node_memory = {
|
||||
key: value for key, value in state.node_memory.items() if key != event.node_id
|
||||
}
|
||||
node_system = {
|
||||
key: value for key, value in state.node_system.items() if key != event.node_id
|
||||
}
|
||||
node_network = {
|
||||
key: value for key, value in state.node_network.items() if key != event.node_id
|
||||
}
|
||||
node_thunderbolt = {
|
||||
key: value
|
||||
for key, value in state.node_thunderbolt.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
node_thunderbolt_bridge = {
|
||||
key: value
|
||||
for key, value in state.node_thunderbolt_bridge.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
# Only recompute cycles if the leaving node had TB bridge enabled
|
||||
leaving_node_status = state.node_thunderbolt_bridge.get(event.node_id)
|
||||
leaving_node_had_tb_enabled = (
|
||||
leaving_node_status is not None and leaving_node_status.enabled
|
||||
)
|
||||
thunderbolt_bridge_cycles = (
|
||||
topology.get_thunderbolt_bridge_cycles(node_thunderbolt_bridge, node_network)
|
||||
topology.get_thunderbolt_bridge_cycles(state.node_thunderbolt_bridge, state.node_network)
|
||||
if leaving_node_had_tb_enabled
|
||||
else [list(cycle) for cycle in state.thunderbolt_bridge_cycles]
|
||||
)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"downloads": downloads,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
"node_identities": node_identities,
|
||||
"node_memory": node_memory,
|
||||
"node_system": node_system,
|
||||
"node_network": node_network,
|
||||
"node_thunderbolt": node_thunderbolt,
|
||||
"node_thunderbolt_bridge": node_thunderbolt_bridge,
|
||||
"thunderbolt_bridge_cycles": thunderbolt_bridge_cycles,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -72,7 +72,12 @@ class DeleteDownload(BaseCommand):
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload
|
||||
class CancelDownload(BaseCommand):
|
||||
target_node_id: NodeId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
|
||||
|
||||
|
||||
Command = (
|
||||
|
||||
@@ -35,7 +35,7 @@ i=0
|
||||
for host; do
|
||||
colour=${colours[i++ % 4]}
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
|
||||
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
|
||||
done
|
||||
|
||||
|
||||
Reference in New Issue
Block a user