mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-23 13:29:29 -05:00
Compare commits
1 Commits
ciaran/fix
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6dbbe7797b |
@@ -2228,6 +2228,54 @@ class AppStore {
|
||||
this.conversations.find((c) => c.id === this.activeConversationId) || null
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a download on a specific node
|
||||
*/
|
||||
async startDownload(nodeId: string, shardMetadata: object): Promise<void> {
|
||||
try {
|
||||
const response = await fetch("/download/start", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
targetNodeId: nodeId,
|
||||
shardMetadata: shardMetadata,
|
||||
}),
|
||||
});
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`Failed to start download: ${response.status} - ${errorText}`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error starting download:", error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a downloaded model from a specific node
|
||||
*/
|
||||
async deleteDownload(nodeId: string, modelId: string): Promise<void> {
|
||||
try {
|
||||
const response = await fetch(
|
||||
`/download/${encodeURIComponent(nodeId)}/${encodeURIComponent(modelId)}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
},
|
||||
);
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`Failed to delete download: ${response.status} - ${errorText}`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error deleting download:", error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const appStore = new AppStore();
|
||||
@@ -2333,3 +2381,9 @@ export const setImageGenerationParams = (
|
||||
) => appStore.setImageGenerationParams(params);
|
||||
export const resetImageGenerationParams = () =>
|
||||
appStore.resetImageGenerationParams();
|
||||
|
||||
// Download actions
|
||||
export const startDownload = (nodeId: string, shardMetadata: object) =>
|
||||
appStore.startDownload(nodeId, shardMetadata);
|
||||
export const deleteDownload = (nodeId: string, modelId: string) =>
|
||||
appStore.deleteDownload(nodeId, modelId);
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
type DownloadProgress,
|
||||
refreshState,
|
||||
lastUpdate as lastUpdateStore,
|
||||
startDownload,
|
||||
deleteDownload,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
|
||||
@@ -28,6 +30,7 @@
|
||||
etaMs: number;
|
||||
status: "completed" | "downloading";
|
||||
files: FileProgress[];
|
||||
shardMetadata?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
type NodeEntry = {
|
||||
@@ -269,6 +272,12 @@
|
||||
}
|
||||
}
|
||||
|
||||
// Extract shard_metadata for use with download actions
|
||||
const shardMetadata = (downloadPayload.shard_metadata ??
|
||||
downloadPayload.shardMetadata) as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
|
||||
const entry: ModelEntry = {
|
||||
modelId,
|
||||
prettyName,
|
||||
@@ -285,6 +294,7 @@
|
||||
? "completed"
|
||||
: "downloading",
|
||||
files,
|
||||
shardMetadata,
|
||||
};
|
||||
|
||||
const existing = modelMap.get(modelId);
|
||||
@@ -469,6 +479,52 @@
|
||||
>
|
||||
{pct.toFixed(1)}%
|
||||
</span>
|
||||
{#if model.status !== "completed" && model.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
|
||||
onclick={() =>
|
||||
startDownload(node.nodeId, model.shardMetadata!)}
|
||||
title="Start download"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
></path>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
{#if model.status === "completed"}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray hover:text-red-400 transition-colors"
|
||||
onclick={() =>
|
||||
deleteDownload(node.nodeId, model.modelId)}
|
||||
title="Delete download"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M4 6h12M8 6V4h4v2m1 0v10a1 1 0 01-1 1H8a1 1 0 01-1-1V6h6"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
></path>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
|
||||
|
||||
@@ -80,6 +80,7 @@ class Node:
|
||||
port=args.api_port,
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -44,6 +44,7 @@ from exo.shared.types.api import (
|
||||
ChatCompletionResponse,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteDownloadResponse,
|
||||
DeleteInstanceResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
@@ -61,6 +62,8 @@ from exo.shared.types.api import (
|
||||
PlaceInstanceParams,
|
||||
PlacementPreview,
|
||||
PlacementPreviewResponse,
|
||||
StartDownloadParams,
|
||||
StartDownloadResponse,
|
||||
StreamingChoiceResponse,
|
||||
ToolCall,
|
||||
)
|
||||
@@ -75,12 +78,16 @@ from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
CreateInstance,
|
||||
DeleteDownload,
|
||||
DeleteInstance,
|
||||
DownloadCommand,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
StartDownload,
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
@@ -156,12 +163,14 @@ class API:
|
||||
# Ideally this would be a MasterForwarderEvent but type system says no :(
|
||||
global_event_receiver: Receiver[ForwarderEvent],
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
# This lets us pause the API if an election is running
|
||||
election_receiver: Receiver[ElectionMessage],
|
||||
) -> None:
|
||||
self.state = State()
|
||||
self._event_log: list[Event] = []
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.election_receiver = election_receiver
|
||||
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
|
||||
@@ -260,6 +269,8 @@ class API:
|
||||
self.app.get("/images/{image_id}")(self.get_image)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
command = PlaceInstance(
|
||||
@@ -1292,3 +1303,28 @@ class API:
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
|
||||
async def _send_download(self, command: DownloadCommand):
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
|
||||
async def start_download(
|
||||
self, payload: StartDownloadParams
|
||||
) -> StartDownloadResponse:
|
||||
command = StartDownload(
|
||||
target_node_id=payload.target_node_id,
|
||||
shard_metadata=payload.shard_metadata,
|
||||
)
|
||||
await self._send_download(command)
|
||||
return StartDownloadResponse(command_id=command.command_id)
|
||||
|
||||
async def delete_download(
|
||||
self, node_id: NodeId, model_id: ModelId
|
||||
) -> DeleteDownloadResponse:
|
||||
command = DeleteDownload(
|
||||
target_node_id=node_id,
|
||||
model_id=ModelId(model_id),
|
||||
)
|
||||
await self._send_download(command)
|
||||
return DeleteDownloadResponse(command_id=command.command_id)
|
||||
|
||||
@@ -7,10 +7,11 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
FinishReason = Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
|
||||
@@ -352,3 +353,16 @@ class ImageListItem(BaseModel, frozen=True):
|
||||
|
||||
class ImageListResponse(BaseModel, frozen=True):
|
||||
data: list[ImageListItem]
|
||||
|
||||
|
||||
class StartDownloadParams(CamelCaseModel):
|
||||
target_node_id: NodeId
|
||||
shard_metadata: ShardMetadata
|
||||
|
||||
|
||||
class StartDownloadResponse(CamelCaseModel):
|
||||
command_id: CommandId
|
||||
|
||||
|
||||
class DeleteDownloadResponse(CamelCaseModel):
|
||||
command_id: CommandId
|
||||
|
||||
@@ -140,7 +140,6 @@ class DistributedImageModel:
|
||||
width=width,
|
||||
image_path=image_path,
|
||||
model_config=self._adapter.model.model_config, # pyright: ignore[reportAny]
|
||||
guidance=guidance_override if guidance_override is not None else 4.0,
|
||||
)
|
||||
|
||||
num_sync_steps = self._config.get_num_sync_steps(steps)
|
||||
|
||||
Reference in New Issue
Block a user