mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-23 13:29:29 -05:00
Compare commits
1 Commits
ciaran/fix
...
state-comp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ee0bce898 |
@@ -3,45 +3,6 @@
|
||||
perSystem =
|
||||
{ pkgs, lib, ... }:
|
||||
let
|
||||
# Stub source with lockfiles and minimal files for build to succeed
|
||||
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
|
||||
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
|
||||
mkdir -p $out
|
||||
cp ${inputs.self}/dashboard/package.json $out/
|
||||
cp ${inputs.self}/dashboard/package-lock.json $out/
|
||||
# Minimal files so vite build succeeds (produces empty output)
|
||||
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
|
||||
mkdir -p $out/src
|
||||
touch $out/src/app.html
|
||||
'';
|
||||
|
||||
# Deps-only build using stub source (for prettier-svelte)
|
||||
# Only rebuilds when package.json or package-lock.json change
|
||||
dashboardDeps = inputs.dream2nix.lib.evalModules {
|
||||
packageSets.nixpkgs = pkgs;
|
||||
modules = [
|
||||
./dashboard.nix
|
||||
{
|
||||
paths.projectRoot = inputs.self;
|
||||
paths.projectRootFile = "flake.nix";
|
||||
paths.package = inputs.self + "/dashboard";
|
||||
}
|
||||
{
|
||||
deps.dashboardSrc = lib.mkForce dashboardStubSrc;
|
||||
}
|
||||
# Override build phases to skip the actual build - just need node_modules
|
||||
{
|
||||
mkDerivation = {
|
||||
buildPhase = lib.mkForce "true";
|
||||
installPhase = lib.mkForce ''
|
||||
runHook preInstall
|
||||
runHook postInstall
|
||||
'';
|
||||
};
|
||||
}
|
||||
];
|
||||
};
|
||||
|
||||
# Filter source to only include dashboard directory
|
||||
dashboardSrc = lib.cleanSourceWith {
|
||||
src = inputs.self;
|
||||
@@ -81,12 +42,11 @@
|
||||
'';
|
||||
|
||||
# Prettier with svelte plugin for treefmt
|
||||
# Uses dashboardDeps instead of dashboardFull to avoid rebuilding on source changes
|
||||
packages.prettier-svelte = pkgs.writeShellScriptBin "prettier-svelte" ''
|
||||
export NODE_PATH="${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules"
|
||||
export NODE_PATH="${dashboardFull}/lib/node_modules/exo-dashboard/node_modules"
|
||||
exec ${pkgs.nodejs}/bin/node \
|
||||
${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
|
||||
--plugin "${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
|
||||
${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
|
||||
--plugin "${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
|
||||
"$@"
|
||||
'';
|
||||
};
|
||||
|
||||
@@ -216,8 +216,6 @@ export interface Message {
|
||||
attachments?: MessageAttachment[];
|
||||
ttftMs?: number; // Time to first token in ms (for assistant messages)
|
||||
tps?: number; // Tokens per second (for assistant messages)
|
||||
requestType?: "chat" | "image-generation" | "image-editing";
|
||||
sourceImageDataUrl?: string; // For image editing regeneration
|
||||
}
|
||||
|
||||
export interface Conversation {
|
||||
@@ -1272,46 +1270,10 @@ class AppStore {
|
||||
|
||||
if (lastUserIndex === -1) return;
|
||||
|
||||
const lastUserMessage = this.messages[lastUserIndex];
|
||||
const requestType = lastUserMessage.requestType || "chat";
|
||||
const prompt = lastUserMessage.content;
|
||||
// Remove any messages after the user message
|
||||
this.messages = this.messages.slice(0, lastUserIndex + 1);
|
||||
|
||||
// Remove messages after user message (including the user message for image requests
|
||||
// since generateImage/editImage will re-add it)
|
||||
this.messages = this.messages.slice(0, lastUserIndex);
|
||||
|
||||
switch (requestType) {
|
||||
case "image-generation":
|
||||
await this.generateImage(prompt);
|
||||
break;
|
||||
case "image-editing":
|
||||
if (lastUserMessage.sourceImageDataUrl) {
|
||||
await this.editImage(prompt, lastUserMessage.sourceImageDataUrl);
|
||||
} else {
|
||||
// Can't regenerate edit without source image - restore user message and show error
|
||||
this.messages.push(lastUserMessage);
|
||||
const errorMessage = this.addMessage("assistant", "");
|
||||
const idx = this.messages.findIndex((m) => m.id === errorMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content =
|
||||
"Error: Cannot regenerate image edit - source image not found";
|
||||
}
|
||||
this.updateActiveConversation();
|
||||
}
|
||||
break;
|
||||
case "chat":
|
||||
default:
|
||||
// Restore the user message for chat regeneration
|
||||
this.messages.push(lastUserMessage);
|
||||
await this.regenerateChatCompletion();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper method to regenerate a chat completion response
|
||||
*/
|
||||
private async regenerateChatCompletion(): Promise<void> {
|
||||
// Resend the message to get a new response
|
||||
this.isLoading = true;
|
||||
this.currentResponse = "";
|
||||
|
||||
@@ -1826,7 +1788,6 @@ class AppStore {
|
||||
role: "user",
|
||||
content: prompt,
|
||||
timestamp: Date.now(),
|
||||
requestType: "image-generation",
|
||||
};
|
||||
this.messages.push(userMessage);
|
||||
|
||||
@@ -2037,8 +1998,6 @@ class AppStore {
|
||||
role: "user",
|
||||
content: prompt,
|
||||
timestamp: Date.now(),
|
||||
requestType: "image-editing",
|
||||
sourceImageDataUrl: imageDataUrl,
|
||||
};
|
||||
this.messages.push(userMessage);
|
||||
|
||||
|
||||
@@ -1,284 +0,0 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterator
|
||||
|
||||
import anyio
|
||||
from anyio import current_time
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.download.download_utils import (
|
||||
RepoDownloadProgress,
|
||||
delete_model,
|
||||
map_repo_download_progress_to_download_progress_data,
|
||||
)
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.commands import (
|
||||
DeleteDownload,
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
NodeDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadFailed,
|
||||
DownloadOngoing,
|
||||
DownloadPending,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadCoordinator:
|
||||
node_id: NodeId
|
||||
session_id: SessionId
|
||||
shard_downloader: ShardDownloader
|
||||
download_command_receiver: Receiver[ForwarderDownloadCommand]
|
||||
local_event_sender: Sender[ForwarderEvent]
|
||||
event_index_counter: Iterator[int]
|
||||
|
||||
# Local state
|
||||
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
||||
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
|
||||
|
||||
# Internal event channel for forwarding (initialized in __post_init__)
|
||||
event_sender: Sender[Event] = field(init=False)
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
self._tg = anyio.create_task_group()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _command_processor(self) -> None:
|
||||
with self.download_command_receiver as commands:
|
||||
async for cmd in commands:
|
||||
# Only process commands targeting this node
|
||||
if cmd.command.target_node_id != self.node_id:
|
||||
continue
|
||||
|
||||
match cmd.command:
|
||||
case StartDownload(shard_metadata=shard):
|
||||
await self._start_download(shard)
|
||||
case DeleteDownload(model_id=model_id):
|
||||
await self._delete_download(model_id)
|
||||
|
||||
async def _start_download(self, shard: ShardMetadata) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
# Check if already downloading or complete
|
||||
if model_id in self.download_status:
|
||||
status = self.download_status[model_id]
|
||||
if isinstance(status, (DownloadOngoing, DownloadCompleted)):
|
||||
logger.debug(
|
||||
f"Download for {model_id} already in progress or complete, skipping"
|
||||
)
|
||||
return
|
||||
|
||||
# Emit pending status
|
||||
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
|
||||
self.download_status[model_id] = progress
|
||||
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
|
||||
|
||||
# Check initial status from downloader
|
||||
initial_progress = (
|
||||
await self.shard_downloader.get_shard_download_status_for_shard(shard)
|
||||
)
|
||||
|
||||
if initial_progress.status == "complete":
|
||||
completed = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
)
|
||||
self.download_status[model_id] = completed
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=completed)
|
||||
)
|
||||
return
|
||||
|
||||
# Start actual download
|
||||
self._start_download_task(shard, initial_progress)
|
||||
|
||||
def _start_download_task(
|
||||
self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
# Emit ongoing status
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
initial_progress
|
||||
),
|
||||
)
|
||||
self.download_status[model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
|
||||
last_progress_time = 0.0
|
||||
throttle_interval_secs = 1.0
|
||||
|
||||
async def download_progress_callback(
|
||||
callback_shard: ShardMetadata, progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
nonlocal last_progress_time
|
||||
|
||||
if progress.status == "complete":
|
||||
completed = DownloadCompleted(
|
||||
shard_metadata=callback_shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
self.download_status[callback_shard.model_card.model_id] = completed
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=completed)
|
||||
)
|
||||
# Clean up active download tracking
|
||||
if callback_shard.model_card.model_id in self.active_downloads:
|
||||
del self.active_downloads[callback_shard.model_card.model_id]
|
||||
elif (
|
||||
progress.status == "in_progress"
|
||||
and current_time() - last_progress_time > throttle_interval_secs
|
||||
):
|
||||
ongoing = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=callback_shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
self.download_status[callback_shard.model_card.model_id] = ongoing
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=ongoing)
|
||||
)
|
||||
last_progress_time = current_time()
|
||||
|
||||
self.shard_downloader.on_progress(download_progress_callback)
|
||||
|
||||
async def download_wrapper() -> None:
|
||||
try:
|
||||
await self.shard_downloader.ensure_shard(shard)
|
||||
except Exception as e:
|
||||
logger.error(f"Download failed for {model_id}: {e}")
|
||||
failed = DownloadFailed(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
error_message=str(e),
|
||||
)
|
||||
self.download_status[model_id] = failed
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=failed)
|
||||
)
|
||||
finally:
|
||||
if model_id in self.active_downloads:
|
||||
del self.active_downloads[model_id]
|
||||
|
||||
task = asyncio.create_task(download_wrapper())
|
||||
self.active_downloads[model_id] = task
|
||||
|
||||
async def _delete_download(self, model_id: ModelId) -> None:
|
||||
# Cancel if active
|
||||
if model_id in self.active_downloads:
|
||||
logger.info(f"Cancelling active download for {model_id} before deletion")
|
||||
self.active_downloads[model_id].cancel()
|
||||
del self.active_downloads[model_id]
|
||||
|
||||
# Delete from disk
|
||||
logger.info(f"Deleting model files for {model_id}")
|
||||
deleted = await delete_model(model_id)
|
||||
|
||||
if deleted:
|
||||
logger.info(f"Successfully deleted model {model_id}")
|
||||
else:
|
||||
logger.warning(f"Model {model_id} was not found on disk")
|
||||
|
||||
# Emit pending status to reset UI state, then remove from local tracking
|
||||
if model_id in self.download_status:
|
||||
current_status = self.download_status[model_id]
|
||||
pending = DownloadPending(
|
||||
shard_metadata=current_status.shard_metadata,
|
||||
node_id=self.node_id,
|
||||
)
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=pending)
|
||||
)
|
||||
del self.download_status[model_id]
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
idx = next(self.event_index_counter)
|
||||
fe = ForwarderEvent(
|
||||
origin_idx=idx,
|
||||
origin=self.node_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
logger.debug(
|
||||
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
|
||||
)
|
||||
await self.local_event_sender.send(fe)
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
logger.info(
|
||||
"DownloadCoordinator: Fetching and emitting existing download progress..."
|
||||
)
|
||||
async for (
|
||||
_,
|
||||
progress,
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
if progress.status == "complete":
|
||||
status: DownloadProgress = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
self.download_status[progress.shard.model_card.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
logger.info(
|
||||
"DownloadCoordinator: Done emitting existing download progress."
|
||||
)
|
||||
await anyio.sleep(5 * 60) # 5 minutes
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"DownloadCoordinator: Error emitting existing download progress: {e}"
|
||||
)
|
||||
@@ -1,11 +1,10 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import resource
|
||||
import signal
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterator, Self
|
||||
from typing import Self
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -13,8 +12,6 @@ from loguru import logger
|
||||
from pydantic import PositiveInt
|
||||
|
||||
import exo.routing.topics as topics
|
||||
from exo.download.coordinator import DownloadCoordinator
|
||||
from exo.download.impl_shard_downloader import exo_shard_downloader
|
||||
from exo.master.api import API # TODO: should API be in master?
|
||||
from exo.master.main import Master
|
||||
from exo.routing.router import Router, get_node_id_keypair
|
||||
@@ -24,6 +21,7 @@ from exo.shared.logging import logger_cleanup, logger_setup
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.utils.channels import Receiver, channel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
|
||||
from exo.worker.main import Worker
|
||||
|
||||
|
||||
@@ -31,7 +29,6 @@ from exo.worker.main import Worker
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
download_coordinator: DownloadCoordinator | None
|
||||
worker: Worker | None
|
||||
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
|
||||
election_result_receiver: Receiver[ElectionResult]
|
||||
@@ -39,7 +36,6 @@ class Node:
|
||||
api: API | None
|
||||
|
||||
node_id: NodeId
|
||||
event_index_counter: Iterator[int]
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
@classmethod
|
||||
@@ -53,26 +49,9 @@ class Node:
|
||||
await router.register_topic(topics.COMMANDS)
|
||||
await router.register_topic(topics.ELECTION_MESSAGES)
|
||||
await router.register_topic(topics.CONNECTION_MESSAGES)
|
||||
await router.register_topic(topics.DOWNLOAD_COMMANDS)
|
||||
await router.register_topic(topics.STATE_CATCHUP)
|
||||
|
||||
logger.info(f"Starting node {node_id}")
|
||||
|
||||
# Create shared event index counter for Worker and DownloadCoordinator
|
||||
event_index_counter = itertools.count()
|
||||
|
||||
# Create DownloadCoordinator (unless --no-downloads)
|
||||
if not args.no_downloads:
|
||||
download_coordinator = DownloadCoordinator(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=event_index_counter,
|
||||
)
|
||||
else:
|
||||
download_coordinator = None
|
||||
|
||||
if args.spawn_api:
|
||||
api = API(
|
||||
node_id,
|
||||
@@ -81,6 +60,7 @@ class Node:
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
|
||||
state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
|
||||
)
|
||||
else:
|
||||
api = None
|
||||
@@ -89,12 +69,12 @@ class Node:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
event_index_counter=event_index_counter,
|
||||
state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
|
||||
)
|
||||
else:
|
||||
worker = None
|
||||
@@ -106,6 +86,7 @@ class Node:
|
||||
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=router.receiver(topics.COMMANDS),
|
||||
state_catchup_sender=router.sender(topics.STATE_CATCHUP),
|
||||
)
|
||||
|
||||
er_send, er_recv = channel[ElectionResult]()
|
||||
@@ -122,25 +103,13 @@ class Node:
|
||||
election_result_sender=er_send,
|
||||
)
|
||||
|
||||
return cls(
|
||||
router,
|
||||
download_coordinator,
|
||||
worker,
|
||||
election,
|
||||
er_recv,
|
||||
master,
|
||||
api,
|
||||
node_id,
|
||||
event_index_counter,
|
||||
)
|
||||
return cls(router, worker, election, er_recv, master, api, node_id)
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
tg.start_soon(self.download_coordinator.run)
|
||||
if self.worker:
|
||||
tg.start_soon(self.worker.run)
|
||||
if self.master:
|
||||
@@ -188,6 +157,7 @@ class Node:
|
||||
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=self.router.receiver(topics.COMMANDS),
|
||||
state_catchup_sender=self.router.sender(topics.STATE_CATCHUP),
|
||||
)
|
||||
self._tg.start_soon(self.master.run)
|
||||
elif (
|
||||
@@ -205,27 +175,13 @@ class Node:
|
||||
)
|
||||
if result.is_new_master:
|
||||
await anyio.sleep(0)
|
||||
# Fresh counter for new session (buffer expects indices from 0)
|
||||
self.event_index_counter = itertools.count()
|
||||
if self.download_coordinator:
|
||||
self.download_coordinator.shutdown()
|
||||
self.download_coordinator = DownloadCoordinator(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
exo_shard_downloader(),
|
||||
download_command_receiver=self.router.receiver(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=self.event_index_counter,
|
||||
)
|
||||
self._tg.start_soon(self.download_coordinator.run)
|
||||
if self.worker:
|
||||
self.worker.shutdown()
|
||||
# TODO: add profiling etc to resource monitor
|
||||
self.worker = Worker(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
exo_shard_downloader(),
|
||||
connection_message_receiver=self.router.receiver(
|
||||
topics.CONNECTION_MESSAGES
|
||||
),
|
||||
@@ -234,10 +190,9 @@ class Node:
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=self.router.sender(topics.COMMANDS),
|
||||
download_command_sender=self.router.sender(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
state_catchup_receiver=self.router.receiver(
|
||||
topics.STATE_CATCHUP
|
||||
),
|
||||
event_index_counter=self.event_index_counter,
|
||||
)
|
||||
self._tg.start_soon(self.worker.run)
|
||||
if self.api:
|
||||
@@ -279,7 +234,6 @@ class Args(CamelCaseModel):
|
||||
api_port: PositiveInt = 52415
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
no_downloads: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
|
||||
@classmethod
|
||||
@@ -322,11 +276,6 @@ class Args(CamelCaseModel):
|
||||
"--no-worker",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-downloads",
|
||||
action="store_true",
|
||||
help="Disable the download coordinator (node won't download models)",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
|
||||
@@ -158,12 +158,14 @@ class API:
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
# This lets us pause the API if an election is running
|
||||
election_receiver: Receiver[ElectionMessage],
|
||||
state_catchup_receiver: Receiver[State],
|
||||
) -> None:
|
||||
self.state = State()
|
||||
self._event_log: list[Event] = []
|
||||
self.command_sender = command_sender
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.election_receiver = election_receiver
|
||||
self.state_catchup_receiver = state_catchup_receiver
|
||||
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
@@ -1231,6 +1233,7 @@ class API:
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
tg.start_soon(self._state_catchup)
|
||||
print_startup_banner(self.port)
|
||||
await serve(
|
||||
cast(ASGIFramework, self.app),
|
||||
@@ -1241,6 +1244,22 @@ class API:
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def _state_catchup(self):
|
||||
with self.state_catchup_receiver as states:
|
||||
async for state in states:
|
||||
if (
|
||||
self.state.last_event_applied_idx == -1
|
||||
and state.last_event_applied_idx > self.state.last_event_applied_idx
|
||||
):
|
||||
logger.info(
|
||||
f"API catching up state to idx {state.last_event_applied_idx}"
|
||||
)
|
||||
self.event_buffer.store = {}
|
||||
self.event_buffer.next_idx_to_release = (
|
||||
state.last_event_applied_idx + 1
|
||||
)
|
||||
self.state = state
|
||||
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
|
||||
@@ -68,6 +68,8 @@ class Master:
|
||||
# Send events to the forwarder to be indexed (usually from command processing)
|
||||
# Ideally these would be MasterForwarderEvents but type system says no :(
|
||||
global_event_sender: Sender[ForwarderEvent],
|
||||
# not a fan but - send the entire state to a node so it can catchup without the whole event log.
|
||||
state_catchup_sender: Sender[State],
|
||||
):
|
||||
self.state = State()
|
||||
self._tg: TaskGroup = anyio.create_task_group()
|
||||
@@ -77,6 +79,7 @@ class Master:
|
||||
self.command_receiver = command_receiver
|
||||
self.local_event_receiver = local_event_receiver
|
||||
self.global_event_sender = global_event_sender
|
||||
self.state_catchup_sender = state_catchup_sender
|
||||
send, recv = channel[Event]()
|
||||
self.event_sender: Sender[Event] = send
|
||||
self._loopback_event_receiver: Receiver[Event] = recv
|
||||
@@ -84,7 +87,6 @@ class Master:
|
||||
local_event_receiver.clone_sender()
|
||||
)
|
||||
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
|
||||
# TODO: not have this
|
||||
self._event_log: list[Event] = []
|
||||
|
||||
async def run(self):
|
||||
@@ -291,11 +293,17 @@ class Master:
|
||||
command.finished_command_id
|
||||
]
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
for i in range(command.since_idx, len(self._event_log)):
|
||||
await self._send_event(
|
||||
IndexedEvent(idx=i, event=self._event_log[i])
|
||||
if command.since_idx == 0:
|
||||
# This is an optimization, and should not be relied upon in theory.
|
||||
logger.info(
|
||||
f"Master sending catchup state for index {self.state.last_event_applied_idx}"
|
||||
)
|
||||
await self.state_catchup_sender.send(self.state)
|
||||
else:
|
||||
for i in range(command.since_idx, len(self._event_log)):
|
||||
await self._send_event(
|
||||
IndexedEvent(idx=i, event=self._event_log[i])
|
||||
)
|
||||
for event in generated_events:
|
||||
await self.event_sender.send(event)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -27,6 +27,7 @@ from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
|
||||
from exo.shared.types.tasks import TaskStatus
|
||||
from exo.shared.types.worker.instances import (
|
||||
@@ -47,6 +48,7 @@ async def test_master():
|
||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||
command_sender, co_receiver = channel[ForwarderCommand]()
|
||||
local_event_sender, le_receiver = channel[ForwarderEvent]()
|
||||
st_s, _st_r = channel[State]()
|
||||
|
||||
all_events: list[IndexedEvent] = []
|
||||
|
||||
@@ -67,6 +69,7 @@ async def test_master():
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=co_receiver,
|
||||
state_catchup_sender=st_s,
|
||||
)
|
||||
logger.info("run the master")
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
@@ -3,10 +3,11 @@ from enum import Enum
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
|
||||
from exo.shared.types.commands import ForwarderCommand
|
||||
from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
@@ -45,6 +46,4 @@ ELECTION_MESSAGES = TypedTopic(
|
||||
CONNECTION_MESSAGES = TypedTopic(
|
||||
"connection_messages", PublishPolicy.Never, ConnectionMessage
|
||||
)
|
||||
DOWNLOAD_COMMANDS = TypedTopic(
|
||||
"download_commands", PublishPolicy.Always, ForwarderDownloadCommand
|
||||
)
|
||||
STATE_CATCHUP = TypedTopic("state_catchup", PublishPolicy.Always, State)
|
||||
|
||||
@@ -621,7 +621,7 @@ class ConfigData(BaseModel):
|
||||
|
||||
async def get_config_data(model_id: ModelId) -> ConfigData:
|
||||
"""Downloads and parses config.json for a model."""
|
||||
from exo.download.download_utils import (
|
||||
from exo.worker.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
)
|
||||
@@ -643,11 +643,11 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
|
||||
|
||||
async def get_safetensors_size(model_id: ModelId) -> Memory:
|
||||
"""Gets model size from safetensors index or falls back to HF API."""
|
||||
from exo.download.download_utils import (
|
||||
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
|
||||
from exo.worker.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
|
||||
|
||||
target_dir = (await ensure_models_dir()) / model_id.normalize()
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
@@ -9,7 +9,7 @@ from exo.shared.types.api import (
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
@@ -62,19 +62,6 @@ class RequestEventLog(BaseCommand):
|
||||
since_idx: int
|
||||
|
||||
|
||||
class StartDownload(BaseCommand):
|
||||
target_node_id: NodeId
|
||||
shard_metadata: ShardMetadata
|
||||
|
||||
|
||||
class DeleteDownload(BaseCommand):
|
||||
target_node_id: NodeId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload
|
||||
|
||||
|
||||
Command = (
|
||||
TestCommand
|
||||
| RequestEventLog
|
||||
@@ -92,8 +79,3 @@ Command = (
|
||||
class ForwarderCommand(CamelCaseModel):
|
||||
origin: NodeId
|
||||
command: Command
|
||||
|
||||
|
||||
class ForwarderDownloadCommand(CamelCaseModel):
|
||||
origin: NodeId
|
||||
command: DownloadCommand
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
import time
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
K = TypeVar("K")
|
||||
|
||||
|
||||
class KeyedBackoff(Generic[K]):
|
||||
"""Tracks exponential backoff state per key."""
|
||||
|
||||
def __init__(self, base: float = 0.5, cap: float = 10.0):
|
||||
self._base = base
|
||||
self._cap = cap
|
||||
self._attempts: dict[K, int] = {}
|
||||
self._last_time: dict[K, float] = {}
|
||||
|
||||
def should_proceed(self, key: K) -> bool:
|
||||
"""Returns True if enough time has elapsed since last attempt."""
|
||||
now = time.monotonic()
|
||||
last = self._last_time.get(key, 0.0)
|
||||
attempts = self._attempts.get(key, 0)
|
||||
delay = min(self._cap, self._base * (2.0**attempts))
|
||||
return now - last >= delay
|
||||
|
||||
def record_attempt(self, key: K) -> None:
|
||||
"""Record that an attempt was made for this key."""
|
||||
self._last_time[key] = time.monotonic()
|
||||
self._attempts[key] = self._attempts.get(key, 0) + 1
|
||||
|
||||
def reset(self, key: K) -> None:
|
||||
"""Reset backoff state for a key (e.g., on success)."""
|
||||
self._attempts.pop(key, None)
|
||||
self._last_time.pop(key, None)
|
||||
@@ -24,13 +24,6 @@ from pydantic import (
|
||||
TypeAdapter,
|
||||
)
|
||||
|
||||
from exo.download.huggingface_utils import (
|
||||
filter_repo_objects,
|
||||
get_allow_patterns,
|
||||
get_auth_headers,
|
||||
get_hf_endpoint,
|
||||
get_hf_token,
|
||||
)
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -42,6 +35,13 @@ from exo.shared.types.worker.downloads import (
|
||||
RepoFileDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.download.huggingface_utils import (
|
||||
filter_repo_objects,
|
||||
get_allow_patterns,
|
||||
get_auth_headers,
|
||||
get_hf_endpoint,
|
||||
get_hf_token,
|
||||
)
|
||||
|
||||
|
||||
class HuggingFaceAuthenticationError(Exception):
|
||||
@@ -5,13 +5,13 @@ from typing import AsyncIterator, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.download.download_utils import RepoDownloadProgress, download_shard
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
)
|
||||
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
|
||||
from exo.worker.download.shard_downloader import ShardDownloader
|
||||
|
||||
|
||||
def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
@@ -5,13 +5,13 @@ from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable
|
||||
|
||||
from exo.download.download_utils import RepoDownloadProgress
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
)
|
||||
from exo.worker.download.download_utils import RepoDownloadProgress
|
||||
|
||||
|
||||
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
|
||||
@@ -6,10 +6,10 @@ import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
from PIL import Image
|
||||
|
||||
from exo.download.download_utils import build_model_path
|
||||
from exo.shared.types.api import AdvancedImageParams
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.download.download_utils import build_model_path
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models import (
|
||||
create_adapter_for_model,
|
||||
@@ -140,7 +140,6 @@ class DistributedImageModel:
|
||||
width=width,
|
||||
image_path=image_path,
|
||||
model_config=self._adapter.model.model_config, # pyright: ignore[reportAny]
|
||||
guidance=guidance_override if guidance_override is not None else 4.0,
|
||||
)
|
||||
|
||||
num_sync_steps = self._config.get_num_sync_steps(steps)
|
||||
|
||||
@@ -41,7 +41,6 @@ import mlx.nn as nn
|
||||
from mlx_lm.utils import load_model
|
||||
from pydantic import RootModel
|
||||
|
||||
from exo.download.download_utils import build_model_path
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.common import Host
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -56,6 +55,7 @@ from exo.shared.types.worker.shards import (
|
||||
ShardMetadata,
|
||||
TensorShardMetadata,
|
||||
)
|
||||
from exo.worker.download.download_utils import build_model_path
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
TimeoutCallback,
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from datetime import datetime, timezone
|
||||
from random import random
|
||||
from typing import Iterator
|
||||
|
||||
import anyio
|
||||
from anyio import CancelScope, create_task_group, fail_after
|
||||
from anyio import CancelScope, create_task_group, current_time, fail_after
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
@@ -11,12 +10,7 @@ from exo.routing.connection_message import ConnectionMessage, ConnectionMessageT
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import ImageEditsInternalParams
|
||||
from exo.shared.types.commands import (
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
RequestEventLog,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
@@ -24,6 +18,7 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
TaskCreated,
|
||||
TaskStatusUpdated,
|
||||
@@ -41,12 +36,23 @@ from exo.shared.types.tasks import (
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadFailed,
|
||||
DownloadOngoing,
|
||||
DownloadPending,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.utils.info_gatherer.net_profile import check_reachable
|
||||
from exo.utils.keyed_backoff import KeyedBackoff
|
||||
from exo.worker.download.download_utils import (
|
||||
map_repo_download_progress_to_download_progress_data,
|
||||
)
|
||||
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
|
||||
from exo.worker.plan import plan
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
@@ -56,29 +62,31 @@ class Worker:
|
||||
self,
|
||||
node_id: NodeId,
|
||||
session_id: SessionId,
|
||||
shard_downloader: ShardDownloader,
|
||||
*,
|
||||
connection_message_receiver: Receiver[ConnectionMessage],
|
||||
global_event_receiver: Receiver[ForwarderEvent],
|
||||
local_event_sender: Sender[ForwarderEvent],
|
||||
# This is for requesting updates. It doesn't need to be a general command sender right now,
|
||||
# but I think it's the correct way to be thinking about commands
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
event_index_counter: Iterator[int],
|
||||
state_catchup_receiver: Receiver[State],
|
||||
):
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
|
||||
self.shard_downloader: ShardDownloader = shard_downloader
|
||||
self._pending_downloads: dict[RunnerId, ShardMetadata] = {}
|
||||
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.local_event_sender = local_event_sender
|
||||
self.event_index_counter = event_index_counter
|
||||
self.state_catchup_receiver = state_catchup_receiver
|
||||
self.local_event_index = 0
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
self.connection_message_receiver = connection_message_receiver
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
|
||||
|
||||
self.state: State = State()
|
||||
self.download_status: dict[ModelId, DownloadProgress] = {}
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
self._tg: TaskGroup = create_task_group()
|
||||
|
||||
@@ -93,8 +101,6 @@ class Worker:
|
||||
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
|
||||
self.input_chunk_counts: dict[CommandId, int] = {}
|
||||
|
||||
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Worker")
|
||||
|
||||
@@ -105,16 +111,17 @@ class Worker:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._connection_message_event_writer)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
tg.start_soon(self._check_catchup_state)
|
||||
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
|
||||
@@ -129,6 +136,22 @@ class Worker:
|
||||
)
|
||||
)
|
||||
|
||||
async def _check_catchup_state(self):
|
||||
with self.state_catchup_receiver as states:
|
||||
async for state in states:
|
||||
if (
|
||||
self.state.last_event_applied_idx == -1
|
||||
and state.last_event_applied_idx > self.state.last_event_applied_idx
|
||||
):
|
||||
logger.info(
|
||||
f"Worker catching up state to idx {state.last_event_applied_idx}"
|
||||
)
|
||||
self.event_buffer.store = {}
|
||||
self.event_buffer.next_idx_to_release = (
|
||||
state.last_event_applied_idx + 1
|
||||
)
|
||||
self.state = state
|
||||
|
||||
async def _event_applier(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
@@ -173,9 +196,11 @@ class Worker:
|
||||
async def plan_step(self):
|
||||
while True:
|
||||
await anyio.sleep(0.1)
|
||||
# 3. based on the updated state, we plan & execute an operation.
|
||||
task: Task | None = plan(
|
||||
self.node_id,
|
||||
self.runners,
|
||||
self.download_status,
|
||||
self.state.downloads,
|
||||
self.state.instances,
|
||||
self.state.runners,
|
||||
@@ -199,26 +224,42 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case DownloadModel(shard_metadata=shard):
|
||||
model_id = shard.model_card.model_id
|
||||
if not self._download_backoff.should_proceed(model_id):
|
||||
continue
|
||||
|
||||
self._download_backoff.record_attempt(model_id)
|
||||
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id,
|
||||
command=StartDownload(
|
||||
target_node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
),
|
||||
if shard.model_card.model_id not in self.download_status:
|
||||
progress = DownloadPending(
|
||||
shard_metadata=shard, node_id=self.node_id
|
||||
)
|
||||
self.download_status[shard.model_card.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
initial_progress = (
|
||||
await self.shard_downloader.get_shard_download_status_for_shard(
|
||||
shard
|
||||
)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Running
|
||||
if initial_progress.status == "complete":
|
||||
progress = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
)
|
||||
)
|
||||
self.download_status[shard.model_card.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id,
|
||||
task_status=TaskStatus.Complete,
|
||||
)
|
||||
)
|
||||
else:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Running
|
||||
)
|
||||
)
|
||||
self._handle_shard_download_process(task, initial_progress)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
try:
|
||||
with fail_after(3):
|
||||
@@ -318,10 +359,7 @@ class Worker:
|
||||
# We request all events after (and including) the missing index.
|
||||
# This function is started whenever we receive an event that is out of sequence.
|
||||
# It is cancelled as soon as we receiver an event that is in sequence.
|
||||
|
||||
if since_idx < 0:
|
||||
logger.warning(f"Negative value encountered for nack request {since_idx=}")
|
||||
since_idx = 0
|
||||
assert since_idx >= 0
|
||||
|
||||
with CancelScope() as scope:
|
||||
self._nack_cancel_scope = scope
|
||||
@@ -363,17 +401,104 @@ class Worker:
|
||||
self._tg.start_soon(runner.run)
|
||||
return runner
|
||||
|
||||
def _handle_shard_download_process(
|
||||
self,
|
||||
task: DownloadModel,
|
||||
initial_progress: RepoDownloadProgress,
|
||||
):
|
||||
"""Manages the shard download process with progress tracking."""
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=task.shard_metadata,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
initial_progress
|
||||
),
|
||||
)
|
||||
self.download_status[task.shard_metadata.model_card.model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
|
||||
last_progress_time = 0.0
|
||||
throttle_interval_secs = 1.0
|
||||
|
||||
async def download_progress_callback(
|
||||
shard: ShardMetadata, progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
nonlocal self
|
||||
nonlocal last_progress_time
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
self.download_status[shard.model_card.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
elif (
|
||||
progress.status == "in_progress"
|
||||
and current_time() - last_progress_time > throttle_interval_secs
|
||||
):
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
self.download_status[shard.model_card.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
last_progress_time = current_time()
|
||||
|
||||
self.shard_downloader.on_progress(download_progress_callback)
|
||||
|
||||
async def download_with_error_handling() -> None:
|
||||
try:
|
||||
await self.shard_downloader.ensure_shard(task.shard_metadata)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(
|
||||
f"Download failed for {task.shard_metadata.model_card.model_id}: {error_message}"
|
||||
)
|
||||
failed_status = DownloadFailed(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=task.shard_metadata,
|
||||
error_message=error_message,
|
||||
)
|
||||
self.download_status[task.shard_metadata.model_card.model_id] = (
|
||||
failed_status
|
||||
)
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=failed_status)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Failed
|
||||
)
|
||||
)
|
||||
|
||||
self._tg.start_soon(download_with_error_handling)
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
idx = next(self.event_index_counter)
|
||||
fe = ForwarderEvent(
|
||||
origin_idx=idx,
|
||||
origin_idx=self.local_event_index,
|
||||
origin=self.node_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
|
||||
logger.debug(
|
||||
f"Worker published event {self.local_event_index}: {str(event)[:100]}"
|
||||
)
|
||||
self.local_event_index += 1
|
||||
await self.local_event_sender.send(fe)
|
||||
self.out_for_delivery[event.event_id] = fe
|
||||
|
||||
@@ -421,3 +546,42 @@ class Worker:
|
||||
await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
|
||||
|
||||
await anyio.sleep(10)
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
logger.debug("Fetching and emitting existing download progress...")
|
||||
async for (
|
||||
_,
|
||||
progress,
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
self.download_status[progress.shard.model_card.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
logger.debug("Done emitting existing download progress.")
|
||||
await anyio.sleep(5 * 60) # 5 minutes
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting existing download progress: {e}")
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
@@ -44,6 +45,9 @@ def plan(
|
||||
node_id: NodeId,
|
||||
# Runners is expected to be FRESH and so should not come from state
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
# DL_status is expected to be FRESH and so should not come from state
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
# gdls is not expected to be fresh
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus], # all global
|
||||
@@ -55,7 +59,7 @@ def plan(
|
||||
return (
|
||||
_kill_runner(runners, all_runners, instances)
|
||||
or _create_runner(node_id, runners, instances)
|
||||
or _model_needs_download(node_id, runners, global_download_status)
|
||||
or _model_needs_download(runners, download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
@@ -111,15 +115,9 @@ def _create_runner(
|
||||
|
||||
|
||||
def _model_needs_download(
|
||||
node_id: NodeId,
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
) -> DownloadModel | None:
|
||||
local_downloads = global_download_status.get(node_id, [])
|
||||
download_status = {
|
||||
dp.shard_metadata.model_card.model_id: dp for dp in local_downloads
|
||||
}
|
||||
|
||||
for runner in runners.values():
|
||||
model_id = runner.bound_instance.bound_shard.model_card.model_id
|
||||
if isinstance(runner.status, RunnerIdle) and (
|
||||
|
||||
@@ -11,12 +11,12 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.download.download_utils import (
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.worker.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
get_eos_token_ids_for_model,
|
||||
load_tokenizer_for_model_id,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import exo.worker.plan as plan_mod
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.common import ModelId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import LoadModel
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
|
||||
@@ -45,9 +45,13 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {RUNNER_1_ID: RunnerIdle()}
|
||||
|
||||
# No entry for this shard -> should trigger DownloadModel
|
||||
download_status: dict[ModelId, DownloadProgress] = {}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=download_status,
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -88,6 +92,14 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
|
||||
RUNNER_2_ID: RunnerConnected(),
|
||||
}
|
||||
|
||||
# Local node has already marked its shard as downloaded (not actually used by _load_model)
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
|
||||
# Global view has completed downloads for both nodes
|
||||
global_download_status = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
@@ -104,6 +116,7 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -135,19 +148,23 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {RUNNER_1_ID: RunnerIdle()}
|
||||
|
||||
# Global state shows shard is downloaded for NODE_A
|
||||
# Local status claims the shard is downloaded already
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
|
||||
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
|
||||
global_download_status: dict[NodeId, list[DownloadProgress]] = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_A: [],
|
||||
NODE_B: [],
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -185,6 +202,12 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
RUNNER_2_ID: RunnerConnected(),
|
||||
}
|
||||
|
||||
# Only NODE_A's shard is recorded as downloaded globally
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
global_download_status = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
@@ -197,6 +220,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -221,6 +245,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
|
||||
@@ -47,7 +47,8 @@ def test_plan_kills_runner_when_instance_missing():
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore[arg-type]
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -86,7 +87,8 @@ def test_plan_kills_runner_when_sibling_failed():
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore[arg-type]
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -118,6 +120,7 @@ def test_plan_creates_runner_when_missing_for_node():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners,
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -155,7 +158,8 @@ def test_plan_does_not_create_runner_when_supervisor_already_present():
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore[arg-type]
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -185,6 +189,7 @@ def test_plan_does_not_create_runner_for_unassigned_node():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
|
||||
@@ -65,6 +65,7 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -112,6 +113,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -156,6 +158,7 @@ def test_plan_does_not_forward_tasks_for_other_instances():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -218,6 +221,7 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -257,6 +261,7 @@ def test_plan_returns_none_when_nothing_to_do():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
|
||||
@@ -57,6 +57,7 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -98,6 +99,7 @@ def test_plan_starts_warmup_for_rank_zero_after_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -138,6 +140,7 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -182,6 +185,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -198,6 +202,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -241,6 +246,7 @@ def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -283,6 +289,7 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -324,6 +331,7 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
|
||||
@@ -11,10 +11,6 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.download.impl_shard_downloader import (
|
||||
build_full_shard,
|
||||
exo_shard_downloader,
|
||||
)
|
||||
from exo.shared.logging import InterceptLogger, logger_setup
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
|
||||
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
@@ -40,6 +36,10 @@ from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.worker.download.impl_shard_downloader import (
|
||||
build_full_shard,
|
||||
exo_shard_downloader,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import entrypoint
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user