Mirror of https://github.com/exo-explore/exo.git
Synced 2026-01-23 13:29:29 -05:00

Compare commits: fix-kv-pre...prioritise (4 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 5de67883c0 |  |
|  | 6dbbe7797b |  |
|  | 9357503c6f |  |
|  | ba19940828 |  |
@@ -2,14 +2,11 @@
This type stub file was generated by pyright.
"""

from typing import Any, Dict, List, Optional, Protocol, Literal, Self

import mlx.nn as nn
import mlx.core as mx
from typing import Any, Dict, List, Literal, Optional, Protocol, Self
from mlx.core import array

"""
This type stub file was generated by pyright.
"""
import mlx.core as mx

class Cache(Protocol):
    keys: mx.array
@@ -35,7 +32,6 @@ def make_prompt_cache(
    ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
    size of ``max_kv_size``
    """
    ...

def save_prompt_cache(
    file_name: str, cache: List[Cache], metadata: Dict[str, str] = ...
@@ -49,7 +45,6 @@ def save_prompt_cache(
        metadata (Dict[str, str]): Optional metadata to save along with model
            state.
    """
    ...

def load_prompt_cache(file_name: str, return_metadata=...) -> array:
    """
@@ -64,15 +59,13 @@ def load_prompt_cache(file_name: str, return_metadata=...) -> array:
        List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
        the metadata if requested.
    """
    ...

def can_trim_prompt_cache(cache: List[Any]) -> bool:
def can_trim_prompt_cache(cache: List[Cache]) -> bool:
    """
    Check if model's cache can be trimmed.
    """
    ...

def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
def trim_prompt_cache(cache: List[Cache], num_tokens: int) -> List[Cache]:
    """
    Trim the model's cache by the given number of tokens.

@@ -86,7 +79,6 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
    Returns:
        (int): The number of tokens that were trimmed.
    """
    ...

def create_attention_mask(
    N: int, offset: int, return_array: bool, window_size: Optional[int]
@@ -115,125 +107,164 @@ class ConcatenateKVCache(_BaseCache):
    KVCache with a larger step size before using this cache.
    """
    def __init__(self) -> None: ...
    def update_and_fetch(self, keys, values): ...
    def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
        ...
    @property
    def state(self): ...
    def state(self): # -> tuple[Any | array | None, Any | array | None]:
        ...
    @state.setter
    def state(self, v): ...
    def is_trimmable(self): ...
    def trim(self, n): ...
    def make_mask(self, *args, **kwargs): ...
    def state(self, v): # -> None:
        ...
    def is_trimmable(self): # -> Literal[True]:
        ...
    def trim(self, n): # -> int:
        ...
    def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
        ...

class QuantizedKVCache(_BaseCache):
    step = ...
    offset: int
    def __init__(self, group_size: int = ..., bits: int = ...) -> None: ...
    def update_and_fetch(self, keys, values): ...
    def update_and_fetch(self, keys, values): # -> Any:
        ...
    @property
    def state(self): ...
    def state(
        self,
    ): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
        ...
    @state.setter
    def state(self, v): ...
    def state(self, v): # -> None:
        ...
    @property
    def meta_state(self): ...
    def meta_state(self): # -> tuple[str, ...]:
        ...
    @meta_state.setter
    def meta_state(self, v): ...
    def is_trimmable(self): ...
    def trim(self, n): ...
    def make_mask(self, *args, **kwargs): ...
    def meta_state(self, v): # -> None:
        ...
    def is_trimmable(self): # -> Literal[True]:
        ...
    def trim(self, n): # -> int:
        ...
    def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
        ...

class KVCache(_BaseCache):
    step = ...
    offset: int
    def __init__(self) -> None: ...
    def update_and_fetch(self, keys, values): ...
    def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
        ...
    @property
    def state(self) -> tuple[array, array]: ...
    def state(
        self,
    ) -> tuple[array, array]: ...
    @state.setter
    def state(self, v) -> None: ...
    def is_trimmable(self): ...
    def trim(self, n): ...
    def is_trimmable(self): # -> Literal[True]:
        ...
    def trim(self, n): # -> int:
        ...
    def to_quantized(
        self, group_size: int = ..., bits: int = ...
    ) -> QuantizedKVCache: ...
    def make_mask(self, *args, **kwargs): ...
    def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
        ...

class RotatingKVCache(_BaseCache):
    step = ...
    offset: int
    def __init__(self, max_size, keep=...) -> None: ...
    def update_and_fetch(self, keys, values): ...
    def update_and_fetch(
        self, keys, values
    ): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
        ...
    @property
    def state(self): ...
    def state(
        self,
    ): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
        ...
    @state.setter
    def state(self, v): ...
    def state(self, v): # -> None:
        ...
    @property
    def meta_state(self): ...
    def meta_state(self): # -> tuple[str, ...]:
        ...
    @meta_state.setter
    def meta_state(self, v): ...
    def is_trimmable(self): ...
    def trim(self, n): ...
    def meta_state(self, v): # -> None:
        ...
    def is_trimmable(self): # -> bool:
        ...
    def trim(self, n): # -> int:
        ...
    def to_quantized(
        self, group_size: int = ..., bits: int = ...
    ) -> QuantizedKVCache: ...
    def make_mask(
        self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
    ): ...
    ): # -> array | Literal['causal'] | None:
        ...

class ArraysCache(_BaseCache):
    def __init__(self, size, left_padding: Optional[List[int]] = ...) -> None: ...
    def __setitem__(self, idx, value): ...
    def __setitem__(self, idx, value): # -> None:
        ...
    def __getitem__(self, idx): ...
    @property
    def state(self): ...
    def state(self): # -> list[Any | array] | list[array]:
        ...
    @state.setter
    def state(self, v): ...
    def filter(self, batch_indices):
    def state(self, v): # -> None:
        ...
    def filter(self, batch_indices): # -> None:
        """
        In-place filter to keep just the given indices in the cache.
        """
        ...

    def extend(self, other):
    def extend(self, other): # -> None:
        """
        In-place extend this cache with the other cache.
        """
        ...

    def make_mask(self, N: int): ...
    def make_mask(self, N: int): # -> array | None:
        ...

class MambaCache(ArraysCache):
    def __init__(self, left_padding: Optional[List[int]] = ...) -> None: ...

class ChunkedKVCache(KVCache):
    def __init__(self, chunk_size) -> None: ...
    def maybe_trim_front(self): ...
    def update_and_fetch(self, keys, values): ...
    def trim(self, n): ...
    def maybe_trim_front(self): # -> None:
        ...
    def update_and_fetch(self, keys, values): # -> tuple[array, array]:
        ...
    def trim(self, n): # -> int:
        ...
    @property
    def meta_state(self): ...
    def meta_state(self): # -> tuple[str, ...]:
        ...
    @meta_state.setter
    def meta_state(self, v): ...
    def meta_state(self, v): # -> None:
        ...

class CacheList(_BaseCache):
    def __init__(self, *caches) -> None: ...
    def __getitem__(self, idx): ...
    def is_trimmable(self): ...
    def is_trimmable(self): # -> bool:
        ...
    def trim(self, n): ...
    @property
    def state(self): ...
    def state(self): # -> list[Any]:
        ...
    @state.setter
    def state(self, v): ...
    def filter(self, batch_indices):
    def state(self, v): # -> None:
        ...
    def filter(self, batch_indices): # -> None:
        """
        In-place filter to keep just the given indices in the cache.
        """
        ...

    def extend(self, other):
    def extend(self, other): # -> None:
        """
        In-place extend this cache with the other cache.
        """
        ...

class BatchKVCache(_BaseCache):
    step = ...
@@ -256,56 +287,71 @@ class BatchKVCache(_BaseCache):
    And ``left_padding`` specifies the amount of padding for each.
    In this case, ``left_padding = [1, 3, 0]``.
    """
    ...

    def update_and_fetch(self, keys, values): ...
    def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
        ...
    @property
    def state(self): ...
    def state(
        self,
    ): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
        ...
    @state.setter
    def state(self, v): ...
    def is_trimmable(self): ...
    def trim(self, n): ...
    def make_mask(self, N: int, return_array: bool = ..., **kwargs): ...
    def filter(self, batch_indices):
    def state(self, v): # -> None:
        ...
    def is_trimmable(self): # -> Literal[True]:
        ...
    def trim(self, n): # -> int | float:
        ...
    def make_mask(self, N: int, return_array: bool = ..., **kwargs): # -> array:
        ...
    def filter(self, batch_indices): # -> None:
        """
        In-place filter to keep just the given indices in the cache.
        """
        ...

    def extend(self, other):
    def extend(self, other): # -> None:
        """
        In-place extend this cache with the other cache.
        """
        ...

class BatchRotatingKVCache(_BaseCache):
    step = ...
    def __init__(self, max_size, left_padding: List[int]) -> None: ...
    def update_and_fetch(self, keys, values): ...
    def update_and_fetch(
        self, keys, values
    ): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
        ...
    @property
    def state(self): ...
    def state(
        self,
    ): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
        ...
    @state.setter
    def state(self, v): ...
    def state(self, v): # -> None:
        ...
    @property
    def meta_state(self): ...
    def meta_state(self): # -> tuple[str, ...]:
        ...
    @meta_state.setter
    def meta_state(self, v): ...
    def is_trimmable(self): ...
    def trim(self, n): ...
    def meta_state(self, v): # -> None:
        ...
    def is_trimmable(self): # -> bool:
        ...
    def trim(self, n): # -> int:
        ...
    def to_quantized(
        self, group_size: int = ..., bits: int = ...
    ) -> QuantizedKVCache: ...
    def make_mask(
        self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
    ): ...
    def filter(self, batch_indices):
    ): # -> array:
        ...
    def filter(self, batch_indices): # -> None:
        """
        In-place filter to keep just the given indices in the cache.
        """
        ...

    def extend(self, other):
    def extend(self, other): # -> None:
        """
        In-place extend this cache with the other cache.
        """
        ...
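The stub narrows `can_trim_prompt_cache`/`trim_prompt_cache` from `List[Any]` to `List[Cache]`. A minimal usage sketch of this prompt-cache API, assuming `mlx_lm` is installed; the model id is a placeholder, and the signatures follow the stub above:

```python
from mlx_lm import load
from mlx_lm.models.cache import (
    can_trim_prompt_cache,
    make_prompt_cache,
    trim_prompt_cache,
)

model, tokenizer = load("mlx-community/placeholder-model-4bit")  # placeholder id
cache = make_prompt_cache(model)  # one cache entry per transformer layer
# ... run prefill/generation so the cache fills up ...
if can_trim_prompt_cache(cache):
    trim_prompt_cache(cache, 16)  # drop the 16 most recent tokens from every layer
```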
@@ -8,10 +8,6 @@ from typing import Any

from transformers import PreTrainedTokenizerFast

"""
This type stub file was generated by pyright.
"""

class StreamingDetokenizer:
    """The streaming detokenizer interface so that we can detokenize one token at a time.

@@ -49,7 +45,6 @@ class StreamingDetokenizer:
    @property
    def last_segment(self):
        """Return the last segment of readable text since last time this property was accessed."""
        ...

class NaiveStreamingDetokenizer(StreamingDetokenizer):
    """NaiveStreamingDetokenizer relies on the underlying tokenizer
@@ -59,11 +54,15 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
    repeatedly detokenize the same tokens until a new line is generated.
    """
    def __init__(self, tokenizer) -> None: ...
    def reset(self): ...
    def add_token(self, token): ...
    def finalize(self): ...
    def reset(self): # -> None:
        ...
    def add_token(self, token): # -> None:
        ...
    def finalize(self): # -> None:
        ...
    @property
    def text(self): ...
    def text(self): # -> str:
        ...

class SPMStreamingDetokenizer(StreamingDetokenizer):
    """A streaming detokenizer for SPM models.
@@ -72,9 +71,12 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
    underscore which results in linear complexity.
    """
    def __init__(self, tokenizer, trim_space=...) -> None: ...
    def reset(self): ...
    def add_token(self, token): ...
    def finalize(self): ...
    def reset(self): # -> None:
        ...
    def add_token(self, token): # -> None:
        ...
    def finalize(self): # -> None:
        ...

class BPEStreamingDetokenizer(StreamingDetokenizer):
    """A streaming detokenizer for OpenAI style BPE models.
@@ -86,13 +88,15 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
    _byte_decoder = ...
    _space_matches = ...
    def __init__(self, tokenizer) -> None: ...
    def reset(self): ...
    def add_token(self, token): ...
    def finalize(self): ...
    @classmethod
    def make_byte_decoder(cls):
        """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
    def reset(self): # -> None:
        ...
    def add_token(self, token): # -> None:
        ...
    def finalize(self): # -> None:
        ...
    @classmethod
    def make_byte_decoder(cls): # -> None:
        """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
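The `make_byte_decoder` stub keeps only the docstring; its real body is elided by pyright. A sketch of the forward byte-to-unicode map the docstring references, adapted from GPT-2's encoder.py (a byte decoder is the inverse of this mapping):

```python
def bytes_to_unicode() -> dict[int, str]:
    # Printable bytes map to themselves; the remaining bytes are shifted into
    # unused code points so every byte gets a visible, reversible character.
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, map(chr, cs)))
```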
class TokenizerWrapper:
    """A wrapper that combines an HF tokenizer and a detokenizer.
@@ -153,10 +157,13 @@ class TokenizerWrapper:

class NewlineTokenizer(PreTrainedTokenizerFast):
    """A tokenizer that replaces newlines with <n> and <n> with new line."""
    def __init__(self, *args, **kwargs) -> None: ...
    def encode(self, text, **kwargs): ...
    def encode(self, text, **kwargs): # -> list[int]:
        ...
    def encode_batch(self, texts, **kwargs): ...
    def decode(self, *args, **kwargs): ...
    def batch_decode(self, *args, **kwargs): ...
    def decode(self, *args, **kwargs): # -> str:
        ...
    def batch_decode(self, *args, **kwargs): # -> list[str]:
        ...

def load(
    model_path: Path,
@@ -169,7 +176,6 @@ def load(
    Note, to use a fast streaming tokenizer, pass a local file path rather than
    a Hugging Face repo ID.
    """
    ...

# Alias for backward compatibility
load_tokenizer = load
@@ -216,6 +216,8 @@ export interface Message {
  attachments?: MessageAttachment[];
  ttftMs?: number; // Time to first token in ms (for assistant messages)
  tps?: number; // Tokens per second (for assistant messages)
  requestType?: "chat" | "image-generation" | "image-editing";
  sourceImageDataUrl?: string; // For image editing regeneration
}

export interface Conversation {
@@ -1270,10 +1272,46 @@ class AppStore {

    if (lastUserIndex === -1) return;

    // Remove any messages after the user message
    this.messages = this.messages.slice(0, lastUserIndex + 1);
    const lastUserMessage = this.messages[lastUserIndex];
    const requestType = lastUserMessage.requestType || "chat";
    const prompt = lastUserMessage.content;

    // Resend the message to get a new response
    // Remove messages after user message (including the user message for image requests
    // since generateImage/editImage will re-add it)
    this.messages = this.messages.slice(0, lastUserIndex);

    switch (requestType) {
      case "image-generation":
        await this.generateImage(prompt);
        break;
      case "image-editing":
        if (lastUserMessage.sourceImageDataUrl) {
          await this.editImage(prompt, lastUserMessage.sourceImageDataUrl);
        } else {
          // Can't regenerate edit without source image - restore user message and show error
          this.messages.push(lastUserMessage);
          const errorMessage = this.addMessage("assistant", "");
          const idx = this.messages.findIndex((m) => m.id === errorMessage.id);
          if (idx !== -1) {
            this.messages[idx].content =
              "Error: Cannot regenerate image edit - source image not found";
          }
          this.updateActiveConversation();
        }
        break;
      case "chat":
      default:
        // Restore the user message for chat regeneration
        this.messages.push(lastUserMessage);
        await this.regenerateChatCompletion();
        break;
    }
  }

  /**
   * Helper method to regenerate a chat completion response
   */
  private async regenerateChatCompletion(): Promise<void> {
    this.isLoading = true;
    this.currentResponse = "";

@@ -1788,6 +1826,7 @@ class AppStore {
      role: "user",
      content: prompt,
      timestamp: Date.now(),
      requestType: "image-generation",
    };
    this.messages.push(userMessage);

@@ -1998,6 +2037,8 @@ class AppStore {
      role: "user",
      content: prompt,
      timestamp: Date.now(),
      requestType: "image-editing",
      sourceImageDataUrl: imageDataUrl,
    };
    this.messages.push(userMessage);

@@ -2187,6 +2228,54 @@ class AppStore {
      this.conversations.find((c) => c.id === this.activeConversationId) || null
    );
  }

  /**
   * Start a download on a specific node
   */
  async startDownload(nodeId: string, shardMetadata: object): Promise<void> {
    try {
      const response = await fetch("/download/start", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({
          targetNodeId: nodeId,
          shardMetadata: shardMetadata,
        }),
      });
      if (!response.ok) {
        const errorText = await response.text();
        throw new Error(
          `Failed to start download: ${response.status} - ${errorText}`,
        );
      }
    } catch (error) {
      console.error("Error starting download:", error);
      throw error;
    }
  }

  /**
   * Delete a downloaded model from a specific node
   */
  async deleteDownload(nodeId: string, modelId: string): Promise<void> {
    try {
      const response = await fetch(
        `/download/${encodeURIComponent(nodeId)}/${encodeURIComponent(modelId)}`,
        {
          method: "DELETE",
        },
      );
      if (!response.ok) {
        const errorText = await response.text();
        throw new Error(
          `Failed to delete download: ${response.status} - ${errorText}`,
        );
      }
    } catch (error) {
      console.error("Error deleting download:", error);
      throw error;
    }
  }
}

export const appStore = new AppStore();
@@ -2292,3 +2381,9 @@ export const setImageGenerationParams = (
) => appStore.setImageGenerationParams(params);
export const resetImageGenerationParams = () =>
  appStore.resetImageGenerationParams();

// Download actions
export const startDownload = (nodeId: string, shardMetadata: object) =>
  appStore.startDownload(nodeId, shardMetadata);
export const deleteDownload = (nodeId: string, modelId: string) =>
  appStore.deleteDownload(nodeId, modelId);

@@ -6,6 +6,8 @@
  type DownloadProgress,
  refreshState,
  lastUpdate as lastUpdateStore,
  startDownload,
  deleteDownload,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";

@@ -28,6 +30,7 @@
  etaMs: number;
  status: "completed" | "downloading";
  files: FileProgress[];
  shardMetadata?: Record<string, unknown>;
};

type NodeEntry = {
@@ -269,6 +272,12 @@
    }
  }

  // Extract shard_metadata for use with download actions
  const shardMetadata = (downloadPayload.shard_metadata ??
    downloadPayload.shardMetadata) as
    | Record<string, unknown>
    | undefined;

  const entry: ModelEntry = {
    modelId,
    prettyName,
@@ -285,6 +294,7 @@
      ? "completed"
      : "downloading",
    files,
    shardMetadata,
  };

  const existing = modelMap.get(modelId);
@@ -469,6 +479,52 @@
  >
    {pct.toFixed(1)}%
  </span>
  {#if model.status !== "completed" && model.shardMetadata}
    <button
      type="button"
      class="text-exo-light-gray hover:text-exo-yellow transition-colors"
      onclick={() =>
        startDownload(node.nodeId, model.shardMetadata!)}
      title="Start download"
    >
      <svg
        class="w-4 h-4"
        viewBox="0 0 20 20"
        fill="none"
        stroke="currentColor"
        stroke-width="2"
      >
        <path
          d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
          stroke-linecap="round"
          stroke-linejoin="round"
        ></path>
      </svg>
    </button>
  {/if}
  {#if model.status === "completed"}
    <button
      type="button"
      class="text-exo-light-gray hover:text-red-400 transition-colors"
      onclick={() =>
        deleteDownload(node.nodeId, model.modelId)}
      title="Delete download"
    >
      <svg
        class="w-4 h-4"
        viewBox="0 0 20 20"
        fill="none"
        stroke="currentColor"
        stroke-width="2"
      >
        <path
          d="M4 6h12M8 6V4h4v2m1 0v10a1 1 0 01-1 1H8a1 1 0 01-1-1V6h6"
          stroke-linecap="round"
          stroke-linejoin="round"
        ></path>
      </svg>
    </button>
  {/if}
  <button
    type="button"
    class="text-exo-light-gray hover:text-exo-yellow transition-colors"
src/exo/download/coordinator.py (new file, 284 lines)
@@ -0,0 +1,284 @@
import asyncio
from dataclasses import dataclass, field
from typing import Iterator

import anyio
from anyio import current_time
from anyio.abc import TaskGroup
from loguru import logger

from exo.download.download_utils import (
    RepoDownloadProgress,
    delete_model,
    map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
    DeleteDownload,
    ForwarderDownloadCommand,
    StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
    Event,
    ForwarderEvent,
    NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
    DownloadCompleted,
    DownloadFailed,
    DownloadOngoing,
    DownloadPending,
    DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel


@dataclass
class DownloadCoordinator:
    node_id: NodeId
    session_id: SessionId
    shard_downloader: ShardDownloader
    download_command_receiver: Receiver[ForwarderDownloadCommand]
    local_event_sender: Sender[ForwarderEvent]
    event_index_counter: Iterator[int]

    # Local state
    download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
    active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)

    # Internal event channel for forwarding (initialized in __post_init__)
    event_sender: Sender[Event] = field(init=False)
    event_receiver: Receiver[Event] = field(init=False)
    _tg: TaskGroup = field(init=False)

    def __post_init__(self) -> None:
        self.event_sender, self.event_receiver = channel[Event]()
        self._tg = anyio.create_task_group()

    async def run(self) -> None:
        logger.info("Starting DownloadCoordinator")
        async with self._tg as tg:
            tg.start_soon(self._command_processor)
            tg.start_soon(self._forward_events)
            tg.start_soon(self._emit_existing_download_progress)

    def shutdown(self) -> None:
        self._tg.cancel_scope.cancel()

    async def _command_processor(self) -> None:
        with self.download_command_receiver as commands:
            async for cmd in commands:
                # Only process commands targeting this node
                if cmd.command.target_node_id != self.node_id:
                    continue

                match cmd.command:
                    case StartDownload(shard_metadata=shard):
                        await self._start_download(shard)
                    case DeleteDownload(model_id=model_id):
                        await self._delete_download(model_id)

    async def _start_download(self, shard: ShardMetadata) -> None:
        model_id = shard.model_card.model_id

        # Check if already downloading or complete
        if model_id in self.download_status:
            status = self.download_status[model_id]
            if isinstance(status, (DownloadOngoing, DownloadCompleted)):
                logger.debug(
                    f"Download for {model_id} already in progress or complete, skipping"
                )
                return

        # Emit pending status
        progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
        self.download_status[model_id] = progress
        await self.event_sender.send(NodeDownloadProgress(download_progress=progress))

        # Check initial status from downloader
        initial_progress = (
            await self.shard_downloader.get_shard_download_status_for_shard(shard)
        )

        if initial_progress.status == "complete":
            completed = DownloadCompleted(
                shard_metadata=shard,
                node_id=self.node_id,
                total_bytes=initial_progress.total_bytes,
            )
            self.download_status[model_id] = completed
            await self.event_sender.send(
                NodeDownloadProgress(download_progress=completed)
            )
            return

        # Start actual download
        self._start_download_task(shard, initial_progress)

    def _start_download_task(
        self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
    ) -> None:
        model_id = shard.model_card.model_id

        # Emit ongoing status
        status = DownloadOngoing(
            node_id=self.node_id,
            shard_metadata=shard,
            download_progress=map_repo_download_progress_to_download_progress_data(
                initial_progress
            ),
        )
        self.download_status[model_id] = status
        self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))

        last_progress_time = 0.0
        throttle_interval_secs = 1.0

        async def download_progress_callback(
            callback_shard: ShardMetadata, progress: RepoDownloadProgress
        ) -> None:
            nonlocal last_progress_time

            if progress.status == "complete":
                completed = DownloadCompleted(
                    shard_metadata=callback_shard,
                    node_id=self.node_id,
                    total_bytes=progress.total_bytes,
                )
                self.download_status[callback_shard.model_card.model_id] = completed
                await self.event_sender.send(
                    NodeDownloadProgress(download_progress=completed)
                )
                # Clean up active download tracking
                if callback_shard.model_card.model_id in self.active_downloads:
                    del self.active_downloads[callback_shard.model_card.model_id]
            elif (
                progress.status == "in_progress"
                and current_time() - last_progress_time > throttle_interval_secs
            ):
                ongoing = DownloadOngoing(
                    node_id=self.node_id,
                    shard_metadata=callback_shard,
                    download_progress=map_repo_download_progress_to_download_progress_data(
                        progress
                    ),
                )
                self.download_status[callback_shard.model_card.model_id] = ongoing
                await self.event_sender.send(
                    NodeDownloadProgress(download_progress=ongoing)
                )
                last_progress_time = current_time()

        self.shard_downloader.on_progress(download_progress_callback)

        async def download_wrapper() -> None:
            try:
                await self.shard_downloader.ensure_shard(shard)
            except Exception as e:
                logger.error(f"Download failed for {model_id}: {e}")
                failed = DownloadFailed(
                    shard_metadata=shard,
                    node_id=self.node_id,
                    error_message=str(e),
                )
                self.download_status[model_id] = failed
                await self.event_sender.send(
                    NodeDownloadProgress(download_progress=failed)
                )
            finally:
                if model_id in self.active_downloads:
                    del self.active_downloads[model_id]

        task = asyncio.create_task(download_wrapper())
        self.active_downloads[model_id] = task

    async def _delete_download(self, model_id: ModelId) -> None:
        # Cancel if active
        if model_id in self.active_downloads:
            logger.info(f"Cancelling active download for {model_id} before deletion")
            self.active_downloads[model_id].cancel()
            del self.active_downloads[model_id]

        # Delete from disk
        logger.info(f"Deleting model files for {model_id}")
        deleted = await delete_model(model_id)

        if deleted:
            logger.info(f"Successfully deleted model {model_id}")
        else:
            logger.warning(f"Model {model_id} was not found on disk")

        # Emit pending status to reset UI state, then remove from local tracking
        if model_id in self.download_status:
            current_status = self.download_status[model_id]
            pending = DownloadPending(
                shard_metadata=current_status.shard_metadata,
                node_id=self.node_id,
            )
            await self.event_sender.send(
                NodeDownloadProgress(download_progress=pending)
            )
            del self.download_status[model_id]

    async def _forward_events(self) -> None:
        with self.event_receiver as events:
            async for event in events:
                idx = next(self.event_index_counter)
                fe = ForwarderEvent(
                    origin_idx=idx,
                    origin=self.node_id,
                    session=self.session_id,
                    event=event,
                )
                logger.debug(
                    f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
                )
                await self.local_event_sender.send(fe)

    async def _emit_existing_download_progress(self) -> None:
        try:
            while True:
                logger.info(
                    "DownloadCoordinator: Fetching and emitting existing download progress..."
                )
                async for (
                    _,
                    progress,
                ) in self.shard_downloader.get_shard_download_status():
                    if progress.status == "complete":
                        status: DownloadProgress = DownloadCompleted(
                            node_id=self.node_id,
                            shard_metadata=progress.shard,
                            total_bytes=progress.total_bytes,
                        )
                    elif progress.status in ["in_progress", "not_started"]:
                        if progress.downloaded_bytes_this_session.in_bytes == 0:
                            status = DownloadPending(
                                node_id=self.node_id, shard_metadata=progress.shard
                            )
                        else:
                            status = DownloadOngoing(
                                node_id=self.node_id,
                                shard_metadata=progress.shard,
                                download_progress=map_repo_download_progress_to_download_progress_data(
                                    progress
                                ),
                            )
                    else:
                        continue

                    self.download_status[progress.shard.model_card.model_id] = status
                    await self.event_sender.send(
                        NodeDownloadProgress(download_progress=status)
                    )
                logger.info(
                    "DownloadCoordinator: Done emitting existing download progress."
                )
                await anyio.sleep(5 * 60)  # 5 minutes
        except Exception as e:
            logger.error(
                f"DownloadCoordinator: Error emitting existing download progress: {e}"
            )
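The progress callback above throttles `in_progress` updates to roughly one per second per download. A stripped-down sketch of that pattern (run inside an anyio event loop; the list append stands in for the event send):

```python
from anyio import current_time  # monotonic clock in seconds, backend-provided

emitted: list[str] = []
last_emit = 0.0
THROTTLE_SECS = 1.0

async def maybe_emit(progress: str) -> None:
    """Drop intermediate updates that arrive inside the throttle window."""
    global last_emit
    if current_time() - last_emit <= THROTTLE_SECS:
        return
    last_emit = current_time()
    emitted.append(progress)  # stands in for event_sender.send(...)
```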
@@ -24,6 +24,13 @@ from pydantic import (
    TypeAdapter,
)

from exo.download.huggingface_utils import (
    filter_repo_objects,
    get_allow_patterns,
    get_auth_headers,
    get_hf_endpoint,
    get_hf_token,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
@@ -35,13 +42,6 @@ from exo.shared.types.worker.downloads import (
    RepoFileDownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
    filter_repo_objects,
    get_allow_patterns,
    get_auth_headers,
    get_hf_endpoint,
    get_hf_token,
)


class HuggingFaceAuthenticationError(Exception):
@@ -5,13 +5,13 @@ from typing import AsyncIterator, Callable

from loguru import logger

from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
    PipelineShardMetadata,
    ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
from exo.worker.download.shard_downloader import ShardDownloader


def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
@@ -5,13 +5,13 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable

from exo.download.download_utils import RepoDownloadProgress
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import (
    PipelineShardMetadata,
    ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress


# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
@@ -1,10 +1,11 @@
import argparse
import itertools
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Self
from typing import Iterator, Self

import anyio
from anyio.abc import TaskGroup
@@ -12,6 +13,8 @@ from loguru import logger
from pydantic import PositiveInt

import exo.routing.topics as topics
from exo.download.coordinator import DownloadCoordinator
from exo.download.impl_shard_downloader import exo_shard_downloader
from exo.master.api import API  # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.router import Router, get_node_id_keypair
@@ -21,7 +24,6 @@ from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
from exo.worker.main import Worker


@@ -29,6 +31,7 @@ from exo.worker.main import Worker
@dataclass
class Node:
    router: Router
    download_coordinator: DownloadCoordinator | None
    worker: Worker | None
    election: Election  # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
    election_result_receiver: Receiver[ElectionResult]
@@ -36,6 +39,7 @@ class Node:
    api: API | None

    node_id: NodeId
    event_index_counter: Iterator[int]
    _tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)

    @classmethod
@@ -49,8 +53,26 @@ class Node:
        await router.register_topic(topics.COMMANDS)
        await router.register_topic(topics.ELECTION_MESSAGES)
        await router.register_topic(topics.CONNECTION_MESSAGES)
        await router.register_topic(topics.DOWNLOAD_COMMANDS)

        logger.info(f"Starting node {node_id}")

        # Create shared event index counter for Worker and DownloadCoordinator
        event_index_counter = itertools.count()

        # Create DownloadCoordinator (unless --no-downloads)
        if not args.no_downloads:
            download_coordinator = DownloadCoordinator(
                node_id,
                session_id,
                exo_shard_downloader(),
                download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
                local_event_sender=router.sender(topics.LOCAL_EVENTS),
                event_index_counter=event_index_counter,
            )
        else:
            download_coordinator = None

        if args.spawn_api:
            api = API(
                node_id,
@@ -58,6 +80,7 @@ class Node:
                port=args.api_port,
                global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
                command_sender=router.sender(topics.COMMANDS),
                download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
                election_receiver=router.receiver(topics.ELECTION_MESSAGES),
            )
        else:
@@ -67,11 +90,12 @@ class Node:
            worker = Worker(
                node_id,
                session_id,
                exo_shard_downloader(),
                connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
                global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
                local_event_sender=router.sender(topics.LOCAL_EVENTS),
                command_sender=router.sender(topics.COMMANDS),
                download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
                event_index_counter=event_index_counter,
            )
        else:
            worker = None
@@ -99,13 +123,25 @@ class Node:
            election_result_sender=er_send,
        )

        return cls(router, worker, election, er_recv, master, api, node_id)
        return cls(
            router,
            download_coordinator,
            worker,
            election,
            er_recv,
            master,
            api,
            node_id,
            event_index_counter,
        )

    async def run(self):
        async with self._tg as tg:
            signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
            tg.start_soon(self.router.run)
            tg.start_soon(self.election.run)
            if self.download_coordinator:
                tg.start_soon(self.download_coordinator.run)
            if self.worker:
                tg.start_soon(self.worker.run)
            if self.master:
@@ -170,13 +206,27 @@ class Node:
            )
            if result.is_new_master:
                await anyio.sleep(0)
            # Fresh counter for new session (buffer expects indices from 0)
            self.event_index_counter = itertools.count()
            if self.download_coordinator:
                self.download_coordinator.shutdown()
                self.download_coordinator = DownloadCoordinator(
                    self.node_id,
                    result.session_id,
                    exo_shard_downloader(),
                    download_command_receiver=self.router.receiver(
                        topics.DOWNLOAD_COMMANDS
                    ),
                    local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
                    event_index_counter=self.event_index_counter,
                )
                self._tg.start_soon(self.download_coordinator.run)
            if self.worker:
                self.worker.shutdown()
                # TODO: add profiling etc to resource monitor
                self.worker = Worker(
                    self.node_id,
                    result.session_id,
                    exo_shard_downloader(),
                    connection_message_receiver=self.router.receiver(
                        topics.CONNECTION_MESSAGES
                    ),
@@ -185,6 +235,10 @@ class Node:
                    ),
                    local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
                    command_sender=self.router.sender(topics.COMMANDS),
                    download_command_sender=self.router.sender(
                        topics.DOWNLOAD_COMMANDS
                    ),
                    event_index_counter=self.event_index_counter,
                )
                self._tg.start_soon(self.worker.run)
            if self.api:
@@ -226,6 +280,7 @@ class Args(CamelCaseModel):
    api_port: PositiveInt = 52415
    tb_only: bool = False
    no_worker: bool = False
    no_downloads: bool = False
    fast_synch: bool | None = None  # None = auto, True = force on, False = force off

    @classmethod
@@ -268,6 +323,11 @@ class Args(CamelCaseModel):
        "--no-worker",
        action="store_true",
    )
    parser.add_argument(
        "--no-downloads",
        action="store_true",
        help="Disable the download coordinator (node won't download models)",
    )
    fast_synch_group = parser.add_mutually_exclusive_group()
    fast_synch_group.add_argument(
        "--fast-synch",
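The `event_index_counter` threaded through both constructors is a single `itertools.count()`; per the diff's own comment ("buffer expects indices from 0"), a fresh counter is created for each session. A tiny illustration of the shared, strictly increasing index stream:

```python
import itertools

counter = itertools.count()
# Worker and DownloadCoordinator draw from the same counter, so their
# ForwarderEvents carry distinct, monotonically increasing origin indices:
assert next(counter) == 0  # e.g. a Worker event
assert next(counter) == 1  # e.g. a DownloadCoordinator event
assert next(counter) == 2
```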
@@ -44,6 +44,7 @@ from exo.shared.types.api import (
    ChatCompletionResponse,
    CreateInstanceParams,
    CreateInstanceResponse,
    DeleteDownloadResponse,
    DeleteInstanceResponse,
    ErrorInfo,
    ErrorResponse,
@@ -61,6 +62,8 @@ from exo.shared.types.api import (
    PlaceInstanceParams,
    PlacementPreview,
    PlacementPreviewResponse,
    StartDownloadParams,
    StartDownloadResponse,
    StreamingChoiceResponse,
    ToolCall,
)
@@ -75,12 +78,16 @@ from exo.shared.types.commands import (
    ChatCompletion,
    Command,
    CreateInstance,
    DeleteDownload,
    DeleteInstance,
    DownloadCommand,
    ForwarderCommand,
    ForwarderDownloadCommand,
    ImageEdits,
    ImageGeneration,
    PlaceInstance,
    SendInputChunk,
    StartDownload,
    TaskFinished,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
@@ -156,12 +163,14 @@ class API:
        # Ideally this would be a MasterForwarderEvent but type system says no :(
        global_event_receiver: Receiver[ForwarderEvent],
        command_sender: Sender[ForwarderCommand],
        download_command_sender: Sender[ForwarderDownloadCommand],
        # This lets us pause the API if an election is running
        election_receiver: Receiver[ElectionMessage],
    ) -> None:
        self.state = State()
        self._event_log: list[Event] = []
        self.command_sender = command_sender
        self.download_command_sender = download_command_sender
        self.global_event_receiver = global_event_receiver
        self.election_receiver = election_receiver
        self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
@@ -260,6 +269,8 @@ class API:
        self.app.get("/images/{image_id}")(self.get_image)
        self.app.get("/state")(lambda: self.state)
        self.app.get("/events")(lambda: self._event_log)
        self.app.post("/download/start")(self.start_download)
        self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)

    async def place_instance(self, payload: PlaceInstanceParams):
        command = PlaceInstance(
@@ -1292,3 +1303,28 @@ class API:
        await self.command_sender.send(
            ForwarderCommand(origin=self.node_id, command=command)
        )

    async def _send_download(self, command: DownloadCommand):
        await self.download_command_sender.send(
            ForwarderDownloadCommand(origin=self.node_id, command=command)
        )

    async def start_download(
        self, payload: StartDownloadParams
    ) -> StartDownloadResponse:
        command = StartDownload(
            target_node_id=payload.target_node_id,
            shard_metadata=payload.shard_metadata,
        )
        await self._send_download(command)
        return StartDownloadResponse(command_id=command.command_id)

    async def delete_download(
        self, node_id: NodeId, model_id: ModelId
    ) -> DeleteDownloadResponse:
        command = DeleteDownload(
            target_node_id=node_id,
            model_id=ModelId(model_id),
        )
        await self._send_download(command)
        return DeleteDownloadResponse(command_id=command.command_id)
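A sketch of driving the two new routes over HTTP. The paths, the default API port 52415, and the camel-case field names (`targetNodeId`, `shardMetadata`) all come from this diff; the node id and shard metadata payload are placeholders:

```python
import httpx

shard_metadata = {}  # a serialized ShardMetadata payload (contents omitted here)
resp = httpx.post(
    "http://localhost:52415/download/start",
    json={"targetNodeId": "node-abc", "shardMetadata": shard_metadata},
)
resp.raise_for_status()
print(resp.json())  # -> {"commandId": ...}

# Deleting mirrors the DELETE route; model ids may contain slashes, which the
# {model_id:path} converter accepts.
httpx.delete("http://localhost:52415/download/node-abc/mlx-community/some-model")
```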
@@ -257,7 +257,13 @@ def _find_ip_prioritised(
    ip_to_type = {
        iface.ip_address: iface.interface_type for iface in other_network.interfaces
    }
    priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "thunderbolt": 3}
    priority = {
        "ethernet": 0,
        "wifi": 1,
        "unknown": 2,
        "maybe_ethernet": 3,
        "thunderbolt": 4,
    }
    return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
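A worked example of the new ranking: ethernet still wins, and `maybe_ethernet` (an enX adapter that could really be Thunderbolt) now sorts below `unknown` but above confirmed `thunderbolt`. The addresses here are illustrative:

```python
ip_to_type = {"10.0.0.5": "wifi", "10.0.0.9": "ethernet", "10.0.0.7": "maybe_ethernet"}
priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "maybe_ethernet": 3, "thunderbolt": 4}
best = min(ip_to_type, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
assert best == "10.0.0.9"  # the ethernet address is preferred
```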
@@ -3,7 +3,7 @@ from enum import Enum

from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.events import (
    ForwarderEvent,
)
@@ -45,3 +45,6 @@ ELECTION_MESSAGES = TypedTopic(
CONNECTION_MESSAGES = TypedTopic(
    "connection_messages", PublishPolicy.Never, ConnectionMessage
)
DOWNLOAD_COMMANDS = TypedTopic(
    "download_commands", PublishPolicy.Always, ForwarderDownloadCommand
)
@@ -621,7 +621,7 @@ class ConfigData(BaseModel):

async def get_config_data(model_id: ModelId) -> ConfigData:
    """Downloads and parses config.json for a model."""
    from exo.worker.download.download_utils import (
    from exo.download.download_utils import (
        download_file_with_retry,
        ensure_models_dir,
    )
@@ -643,11 +643,11 @@ async def get_config_data(model_id: ModelId) -> ConfigData:

async def get_safetensors_size(model_id: ModelId) -> Memory:
    """Gets model size from safetensors index or falls back to HF API."""
    from exo.shared.types.worker.downloads import ModelSafetensorsIndex
    from exo.worker.download.download_utils import (
    from exo.download.download_utils import (
        download_file_with_retry,
        ensure_models_dir,
    )
    from exo.shared.types.worker.downloads import ModelSafetensorsIndex

    target_dir = (await ensure_models_dir()) / model_id.normalize()
    await aios.makedirs(target_dir, exist_ok=True)

@@ -7,10 +7,11 @@ from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault

from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel

FinishReason = Literal[
    "stop", "length", "tool_calls", "content_filter", "function_call", "error"
@@ -352,3 +353,16 @@ class ImageListItem(BaseModel, frozen=True):

class ImageListResponse(BaseModel, frozen=True):
    data: list[ImageListItem]


class StartDownloadParams(CamelCaseModel):
    target_node_id: NodeId
    shard_metadata: ShardMetadata


class StartDownloadResponse(CamelCaseModel):
    command_id: CommandId


class DeleteDownloadResponse(CamelCaseModel):
    command_id: CommandId
@@ -1,6 +1,6 @@
from pydantic import Field

from exo.shared.models.model_cards import ModelCard
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import (
    ChatCompletionTaskParams,
    ImageEditsInternalParams,
@@ -9,7 +9,7 @@ from exo.shared.types.api import (
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel


@@ -62,6 +62,19 @@ class RequestEventLog(BaseCommand):
    since_idx: int


class StartDownload(BaseCommand):
    target_node_id: NodeId
    shard_metadata: ShardMetadata


class DeleteDownload(BaseCommand):
    target_node_id: NodeId
    model_id: ModelId


DownloadCommand = StartDownload | DeleteDownload


Command = (
    TestCommand
    | RequestEventLog
@@ -79,3 +92,8 @@ Command = (
class ForwarderCommand(CamelCaseModel):
    origin: NodeId
    command: Command


class ForwarderDownloadCommand(CamelCaseModel):
    origin: NodeId
    command: DownloadCommand
@@ -1,11 +0,0 @@
"""Shared types for MLX-related functionality."""

from mlx_lm.models.cache import (
    KVCache,
    QuantizedKVCache,
    RotatingKVCache,
)

# Type alias for KV cache - matches make_kv_cache return type
# This list contains one cache entry per transformer layer
KVCacheType = list[KVCache | RotatingKVCache | QuantizedKVCache]
@@ -48,7 +48,7 @@ class SystemPerformanceProfile(CamelCaseModel):
    ecpu_usage: float = 0.0


InterfaceType = Literal["wifi", "ethernet", "thunderbolt", "unknown"]
InterfaceType = Literal["wifi", "ethernet", "maybe_ethernet", "thunderbolt", "unknown"]


class NetworkInterfaceInfo(CamelCaseModel):

@@ -400,7 +400,7 @@ class InfoGatherer:
            return
        old_nics = []
        while True:
            nics = get_network_interfaces()
            nics = await get_network_interfaces()
            if nics != old_nics:
                old_nics = nics
                await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
@@ -1,6 +1,6 @@
import socket
import sys
from subprocess import CalledProcessError, run
from subprocess import CalledProcessError

import psutil
from anyio import run_process
@@ -16,8 +16,7 @@ async def get_friendly_name() -> str:
    """
    hostname = socket.gethostname()

    # TODO: better non mac support
    if sys.platform != "darwin":  # 'darwin' is the platform name for macOS
    if sys.platform != "darwin":
        return hostname

    try:
@@ -28,21 +27,20 @@ async def get_friendly_name() -> str:
    return process.stdout.decode("utf-8", errors="replace").strip() or hostname


def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
async def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
    """Parse networksetup -listallhardwareports to get interface types."""
    if sys.platform != "darwin":
        return {}

    try:
        result = run(
            ["networksetup", "-listallhardwareports"], capture_output=True, text=True
        )
    except Exception:
        result = await run_process(["networksetup", "-listallhardwareports"])
    except CalledProcessError:
        return {}

    types: dict[str, InterfaceType] = {}
    current_type: InterfaceType = "unknown"

    for line in result.stdout.splitlines():
    for line in result.stdout.decode().splitlines():
        if line.startswith("Hardware Port:"):
            port_name = line.split(":", 1)[1].strip()
            if "Wi-Fi" in port_name:
@@ -55,12 +53,15 @@ def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
            current_type = "unknown"
        elif line.startswith("Device:"):
            device = line.split(":", 1)[1].strip()
            # enX is ethernet adapters or thunderbolt - these must be deprioritised
            if device.startswith("en") and device not in ["en0", "en1"]:
                current_type = "maybe_ethernet"
            types[device] = current_type

    return types


def get_network_interfaces() -> list[NetworkInterfaceInfo]:
async def get_network_interfaces() -> list[NetworkInterfaceInfo]:
    """
    Retrieves detailed network interface information on macOS.
    Parses output from 'networksetup -listallhardwareports' and 'ifconfig'
@@ -68,7 +69,7 @@ def get_network_interfaces() -> list[NetworkInterfaceInfo]:
    Returns a list of NetworkInterfaceInfo objects.
    """
    interfaces_info: list[NetworkInterfaceInfo] = []
    interface_types = _get_interface_types_from_networksetup()
    interface_types = await _get_interface_types_from_networksetup()

    for iface, services in psutil.net_if_addrs().items():
        for service in services:
src/exo/utils/keyed_backoff.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import time
from typing import Generic, TypeVar

K = TypeVar("K")


class KeyedBackoff(Generic[K]):
    """Tracks exponential backoff state per key."""

    def __init__(self, base: float = 0.5, cap: float = 10.0):
        self._base = base
        self._cap = cap
        self._attempts: dict[K, int] = {}
        self._last_time: dict[K, float] = {}

    def should_proceed(self, key: K) -> bool:
        """Returns True if enough time has elapsed since last attempt."""
        now = time.monotonic()
        last = self._last_time.get(key, 0.0)
        attempts = self._attempts.get(key, 0)
        delay = min(self._cap, self._base * (2.0**attempts))
        return now - last >= delay

    def record_attempt(self, key: K) -> None:
        """Record that an attempt was made for this key."""
        self._last_time[key] = time.monotonic()
        self._attempts[key] = self._attempts.get(key, 0) + 1

    def reset(self, key: K) -> None:
        """Reset backoff state for a key (e.g., on success)."""
        self._attempts.pop(key, None)
        self._last_time.pop(key, None)
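A minimal usage sketch for `KeyedBackoff`. The import path matches the new file above; the string key type and the random stand-in for a connection attempt are illustrative assumptions:

```python
import random

from exo.utils.keyed_backoff import KeyedBackoff

backoff = KeyedBackoff[str](base=0.5, cap=10.0)

def try_connect(peer_id: str) -> bool:
    """Attempt a connection, respecting this peer's backoff window."""
    if not backoff.should_proceed(peer_id):
        return False  # still inside the window; delays double per failed attempt
    backoff.record_attempt(peer_id)
    if random.random() < 0.5:  # stand-in for a real connection attempt
        backoff.reset(peer_id)  # success clears the backoff state
        return True
    return False  # next allowed attempt waits 0.5s, 1s, 2s, ... capped at 10s
```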
@@ -6,10 +6,10 @@ import mlx.core as mx
from mflux.models.common.config.config import Config
from PIL import Image

from exo.download.download_utils import build_model_path
from exo.shared.types.api import AdvancedImageParams
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
    create_adapter_for_model,
@@ -1,53 +1,39 @@
# type: ignore
# TODO: Fix this file, including types!
from copy import deepcopy
from typing import Callable

import mlx.core as mx
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper

from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
from exo.worker.runner.bootstrap import logger


class KVPrefixCache:
    def __init__(self):
        # Only one prefix cache per runner.
        self.prompts: list[mx.array] = []  # mx array of tokens (ints)
        self.caches: list[KVCacheType] = []

    def clear(self):
        """Clear all cached prompts and caches."""
        self.prompts.clear()
        self.caches.clear()
        self.caches: list[list[_BaseCache]] = []

    def add_kv_cache(
        self, tokenizer: TokenizerWrapper, prompt: str, cache: KVCacheType
        self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
    ):
        tokenized_prompt = encode_prompt(tokenizer, prompt)
        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
        self.prompts.append(tokenized_prompt)
        self.caches.append(deepcopy(cache))
        logger.info(f"KV cache saved: {len(tokenized_prompt)} tokens")

    def get_kv_cache(
        self,
        model: Model,
        tokenizer: TokenizerWrapper,
        sampler: Callable[[mx.array], mx.array],
        prompt: str,
    ) -> tuple[KVCacheType, mx.array]:
        """Get KV cache for prompt, returning remaining tokens to prefill.

        This method finds the best matching cached prefix and returns:
        - A copy of the cache trimmed to the prefix length
        - The remaining tokens that need to be prefilled before generation

        The caller is responsible for prefilling the remaining tokens.

        Returns:
            Tuple of (cache, remaining_tokens) where remaining_tokens are the
            tokens that still need to be prefilled/processed.
        """
        tokenized_prompt = encode_prompt(tokenizer, prompt)
    ) -> list[_BaseCache]:
        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
        max_length = len(tokenized_prompt)

        best_snapshot_index, best_snapshot_length = None, 0
@@ -56,75 +42,63 @@ class KVPrefixCache:
            length = _get_prefix_length(tokenized_prompt, cached_prompt)

            if length == max_length:
                # Exact match - cached prompt starts with our entire prompt
                # Trim cache to prompt length - 1, return last token for stream_generate
                prompt_cache = deepcopy(self.caches[i])
                cached_length = _cache_length(self.caches[i])
                tokens_to_trim = cached_length - (max_length - 1)
                if tokens_to_trim > 0:
                    trim_prompt_cache(prompt_cache, tokens_to_trim)
                logger.info(f"KV cache exact match: {max_length} tokens (instant)")
                return prompt_cache, tokenized_prompt[-1:]
                return self.caches[i]

            if length > best_snapshot_length:
                best_snapshot_index, best_snapshot_length = i, length

        if best_snapshot_index is not None:
            new_tokens = max_length - best_snapshot_length
            logger.info(
                f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
                f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
            )

            prompt_cache = deepcopy(self.caches[best_snapshot_index])

            # Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
            cached_length = _cache_length(self.caches[best_snapshot_index])
            tokens_to_trim = cached_length - best_snapshot_length
            if tokens_to_trim > 0:
                trim_prompt_cache(prompt_cache, tokens_to_trim)

            # Return remaining tokens for caller to prefill
            remaining_tokens = tokenized_prompt[best_snapshot_length:]
            return prompt_cache, remaining_tokens
            trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
            tokenized_prompt = tokenized_prompt[best_snapshot_length:]

        else:
            prompt_cache = make_kv_cache(model)
            if len(self.prompts) == 0:
                logger.info(f"KV cache empty, need to prefill {max_length} tokens")
            else:
                logger.info(
                    f"KV cache no prefix match, need to prefill {max_length} tokens"
                )
            prompt_cache = make_kv_cache(
                model,
                # max_kv_size=MAX_KV_SIZE,
                # keep=KEEP_KV_SIZE
            )

        # Return all tokens for caller to prefill
        return prompt_cache, tokenized_prompt
        prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)

        return prompt_cache

def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
    """Encode a prompt string to token array.

    For chat-templated prompts (which have their own structure markers like
    <|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
    that would corrupt the prompt structure.
    """
    # Chat templates define their own structure - don't add BOS/EOS
    tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
    return mx.array(tokenized_prompt)


def _cache_length(cache: KVCacheType) -> int:
    """Get the number of tokens in a KV cache."""
    # Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
    return max(c.offset for c in cache)
    def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
        add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
            tokenizer.bos_token
        )
        tokenized_prompt = tokenizer.encode(
            prompt, add_special_tokens=add_special_tokens
        )
        return mx.array(tokenized_prompt)

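A quick sketch of the new BOS rule above (values hypothetical, for illustration only):

    # bos_token = "<s>", prompt = "<s>Hello" -> startswith(bos) -> add_special_tokens=False (no duplicate BOS)
    # bos_token = "<s>", prompt = "Hello"    -> add_special_tokens=True (tokenizer prepends BOS once)
    # bos_token = None                       -> add_special_tokens=True (nothing to duplicate)
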
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
    """Find the length of the common prefix between two token arrays."""
    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
    if n == 0:
        return 0

    equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
    equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
    prefix_mask = mx.cumprod(equal)  # stays 1 until first mismatch, then 0 forever
    return int(mx.sum(prefix_mask).item())
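A tiny worked example of the cumprod trick above (arrays are illustrative):

    # prompt[:n]        = [7, 9, 4, 2]
    # cached_prompt[:n] = [7, 9, 5, 2]
    # equal             = [1, 1, 0, 1]
    # cumprod(equal)    = [1, 1, 0, 0]   <- zeros persist after the first mismatch
    # sum               = 2              -> common prefix length is 2
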

def prefill(
    model: Model,
    tokenizer: TokenizerWrapper,
    sampler: Callable[[mx.array], mx.array],
    prompt: mx.array,
    cache: list[_BaseCache],
) -> None:
    for _ in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_tokens=0,
        sampler=sampler,
        prompt_cache=cache,
        prefill_step_size=2048,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
        pass

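To see the trimming arithmetic in the prefix-match path above, a worked example (numbers hypothetical):

    # cached_length = 100, matched prefix = 60, new prompt length = 90
    # tokens_to_trim = 100 - 60 = 40   -> trim_prompt_cache keeps exactly the 60-token shared prefix
    # remaining      = prompt[60:]     -> 30 tokens still to prefill before generation
    # (old API returns those tokens to the caller; new API prefills them via prefill())
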
@@ -1,12 +1,12 @@
import time
from typing import Callable, Generator, cast, get_args
from typing import Any, Callable, Generator, cast, get_args

import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
    BenchChatCompletionTaskParams,
    ChatCompletionMessage,
@@ -14,13 +14,11 @@ from exo.shared.types.api import (
    GenerationStats,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
    GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
@@ -32,59 +30,19 @@ from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())


def prefill(
    model: Model,
    tokenizer: TokenizerWrapper,
    sampler: Callable[[mx.array], mx.array],
    prompt_tokens: mx.array,
    cache: KVCacheType,
) -> float:
    """Prefill the KV cache with prompt tokens.

    This runs the model over the prompt tokens to populate the cache,
    then trims off the extra generated token.

    Returns:
        tokens_per_sec
    """
    num_tokens = len(prompt_tokens)
    if num_tokens == 0:
        return 0.0

    logger.debug(f"Prefilling {num_tokens} tokens...")
    start_time = time.perf_counter()

    def progress_callback(processed: int, total: int) -> None:
        elapsed = time.time() - start_time
        tok_per_sec = processed / elapsed if elapsed > 0 else 0
        logger.debug(
            f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
        )

    # Use max_tokens=1 because max_tokens=0 does not work.
    # We just throw away the generated token - we only care about filling the cache
    for _ in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt_tokens,
        max_tokens=1,
        sampler=sampler,
        prompt_cache=cache,
        prefill_step_size=2048,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
        prompt_progress_callback=progress_callback,
    ):
        break  # Stop after first iteration - cache is now filled
    trim_prompt_cache(cache, 1)

    elapsed = time.perf_counter() - start_time
    tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
    logger.debug(
        f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
        f"({tokens_per_sec:.1f} tok/s)"
    )
    return tokens_per_sec
def maybe_quantize_kv_cache(
    prompt_cache: list[KVCache | Any],
    quantized_kv_start: int,
    kv_group_size: int,
    kv_bits: int | None,
) -> None:
    if kv_bits is None:
        return
    for e, c in enumerate(prompt_cache):
        if (
            hasattr(c, "to_quantized") and c.offset >= quantized_kv_start  # type: ignore
        ):
            prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)

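A minimal sketch of how this helper is meant to be invoked during generation (the call site and threshold value are hypothetical, not taken from this diff):

    # Once a cache layer holds >= 1024 tokens, swap it in place for a
    # quantized equivalent to bound memory growth.
    maybe_quantize_kv_cache(
        prompt_cache=caches,        # list of per-layer KV caches
        quantized_kv_start=1024,    # hypothetical: leave short contexts unquantized
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,            # None disables quantization entirely
    )
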
def warmup_inference(
@@ -162,7 +120,6 @@ def mlx_generate(
    tokenizer: TokenizerWrapper,
    task: ChatCompletionTaskParams,
    prompt: str,
    kv_prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
    # Ensure that generation stats only contains peak memory for this generation
    mx.reset_peak_memory()
@@ -174,16 +131,7 @@ def mlx_generate(
    if task.seed is not None:
        mx.random.seed(task.seed)

    # Do not use the prefix cache if we are trying to do benchmarks.
    if is_bench:
        kv_prefix_cache = None

    # Use prefix cache if available, otherwise create fresh cache
    if kv_prefix_cache is None:
        caches = make_kv_cache(model=model)
        prompt_tokens = encode_prompt(tokenizer, prompt)
    else:
        caches, prompt_tokens = kv_prefix_cache.get_kv_cache(model, tokenizer, prompt)
    caches = make_kv_cache(model=model)

    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
    if is_bench:
@@ -196,19 +144,11 @@ def mlx_generate(
        top_p=task.top_p if task.top_p is not None else 1.0,
    )

    # Prefill cache with all tokens except the last one
    prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)

    # stream_generate starts from the last token
    last_token = prompt_tokens[-1:]

    max_tokens = task.max_tokens or MAX_TOKENS
    generated_text_parts: list[str] = []
    generation_start_time = time.perf_counter()
    for out in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=last_token,
        prompt=prompt,
        max_tokens=max_tokens,
        sampler=sampler,
        logits_processors=logits_processors,
@@ -218,13 +158,12 @@ def mlx_generate(
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
        generated_text_parts.append(out.text)
        logger.info(out.text)

        stats: GenerationStats | None = None
        if out.finish_reason is not None:
            stats = GenerationStats(
                prompt_tps=float(prefill_tps or out.prompt_tps),
                prompt_tps=float(out.prompt_tps),
                generation_tps=float(out.generation_tps),
                prompt_tokens=int(out.prompt_tokens),
                generation_tokens=int(out.generation_tokens),
@@ -246,22 +185,6 @@ def mlx_generate(
            )

        if out.finish_reason is not None:
            # Log generation stats
            generation_elapsed = time.perf_counter() - generation_start_time
            generated_tokens = len(generated_text_parts)
            generation_tps = (
                generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
            )
            logger.debug(
                f"Generation complete: prefill {prompt_tokens} tokens @ "
                f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
                f"{generation_tps:.1f} tok/s"
            )
            # Save cache for future prefix matching (clear first to keep only the last one)
            if kv_prefix_cache is not None:
                kv_prefix_cache.clear()
                full_prompt = prompt + "".join(generated_text_parts)
                kv_prefix_cache.add_kv_cache(tokenizer, full_prompt, caches)
            break

    # TODO: Do we want an mx_barrier?

@@ -41,6 +41,7 @@ import mlx.nn as nn
from mlx_lm.utils import load_model
from pydantic import RootModel

from exo.download.download_utils import build_model_path
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
@@ -55,7 +56,6 @@ from exo.shared.types.worker.shards import (
    ShardMetadata,
    TensorShardMetadata,
)
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
    TimeoutCallback,

@@ -1,8 +1,9 @@
from datetime import datetime, timezone
from random import random
from typing import Iterator

import anyio
from anyio import CancelScope, create_task_group, current_time, fail_after
from anyio import CancelScope, create_task_group, fail_after
from anyio.abc import TaskGroup
from loguru import logger

@@ -10,7 +11,12 @@ from exo.routing.connection_message import ConnectionMessage, ConnectionMessageT
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import ImageEditsInternalParams
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.commands import (
    ForwarderCommand,
    ForwarderDownloadCommand,
    RequestEventLog,
    StartDownload,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
    Event,
@@ -18,7 +24,6 @@ from exo.shared.types.events import (
    ForwarderEvent,
    IndexedEvent,
    InputChunkReceived,
    NodeDownloadProgress,
    NodeGatheredInfo,
    TaskCreated,
    TaskStatusUpdated,
@@ -36,23 +41,12 @@ from exo.shared.types.tasks import (
    TaskStatus,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.downloads import (
    DownloadCompleted,
    DownloadFailed,
    DownloadOngoing,
    DownloadPending,
    DownloadProgress,
)
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
    map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.utils.keyed_backoff import KeyedBackoff
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor

@@ -62,7 +56,6 @@ class Worker:
        self,
        node_id: NodeId,
        session_id: SessionId,
        shard_downloader: ShardDownloader,
        *,
        connection_message_receiver: Receiver[ConnectionMessage],
        global_event_receiver: Receiver[ForwarderEvent],
@@ -70,23 +63,22 @@ class Worker:
        # This is for requesting updates. It doesn't need to be a general command sender right now,
        # but I think it's the correct way to be thinking about commands
        command_sender: Sender[ForwarderCommand],
        download_command_sender: Sender[ForwarderDownloadCommand],
        event_index_counter: Iterator[int],
    ):
        self.node_id: NodeId = node_id
        self.session_id: SessionId = session_id

        self.shard_downloader: ShardDownloader = shard_downloader
        self._pending_downloads: dict[RunnerId, ShardMetadata] = {}

        self.global_event_receiver = global_event_receiver
        self.local_event_sender = local_event_sender
        self.local_event_index = 0
        self.event_index_counter = event_index_counter
        self.command_sender = command_sender
        self.download_command_sender = download_command_sender
        self.connection_message_receiver = connection_message_receiver
        self.event_buffer = OrderedBuffer[Event]()
        self.out_for_delivery: dict[EventId, ForwarderEvent] = {}

        self.state: State = State()
        self.download_status: dict[ModelId, DownloadProgress] = {}
        self.runners: dict[RunnerId, RunnerSupervisor] = {}
        self._tg: TaskGroup = create_task_group()

@@ -101,6 +93,8 @@ class Worker:
        self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
        self.input_chunk_counts: dict[CommandId, int] = {}

        self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)

    async def run(self):
        logger.info("Starting Worker")

@@ -111,7 +105,6 @@ class Worker:
            tg.start_soon(info_gatherer.run)
            tg.start_soon(self._forward_info, info_recv)
            tg.start_soon(self.plan_step)
            tg.start_soon(self._emit_existing_download_progress)
            tg.start_soon(self._connection_message_event_writer)
            tg.start_soon(self._resend_out_for_delivery)
            tg.start_soon(self._event_applier)
@@ -121,6 +114,7 @@ class Worker:
            # Actual shutdown code - waits for all tasks to complete before executing.
            self.local_event_sender.close()
            self.command_sender.close()
            self.download_command_sender.close()
            for runner in self.runners.values():
                runner.shutdown()

@@ -179,11 +173,9 @@ class Worker:
    async def plan_step(self):
        while True:
            await anyio.sleep(0.1)
            # 3. based on the updated state, we plan & execute an operation.
            task: Task | None = plan(
                self.node_id,
                self.runners,
                self.download_status,
                self.state.downloads,
                self.state.instances,
                self.state.runners,
@@ -207,42 +199,26 @@ class Worker:
                        )
                    )
                case DownloadModel(shard_metadata=shard):
                    if shard.model_card.model_id not in self.download_status:
                        progress = DownloadPending(
                            shard_metadata=shard, node_id=self.node_id
                        )
                        self.download_status[shard.model_card.model_id] = progress
                        await self.event_sender.send(
                            NodeDownloadProgress(download_progress=progress)
                        )
                    initial_progress = (
                        await self.shard_downloader.get_shard_download_status_for_shard(
                            shard
                    model_id = shard.model_card.model_id
                    if not self._download_backoff.should_proceed(model_id):
                        continue

                    self._download_backoff.record_attempt(model_id)

                    await self.download_command_sender.send(
                        ForwarderDownloadCommand(
                            origin=self.node_id,
                            command=StartDownload(
                                target_node_id=self.node_id,
                                shard_metadata=shard,
                            ),
                        )
                    )
                    if initial_progress.status == "complete":
                        progress = DownloadCompleted(
                            shard_metadata=shard,
                            node_id=self.node_id,
                            total_bytes=initial_progress.total_bytes,
                    await self.event_sender.send(
                        TaskStatusUpdated(
                            task_id=task.task_id, task_status=TaskStatus.Running
                        )
                        self.download_status[shard.model_card.model_id] = progress
                        await self.event_sender.send(
                            NodeDownloadProgress(download_progress=progress)
                        )
                        await self.event_sender.send(
                            TaskStatusUpdated(
                                task_id=task.task_id,
                                task_status=TaskStatus.Complete,
                            )
                        )
                    else:
                        await self.event_sender.send(
                            TaskStatusUpdated(
                                task_id=task.task_id, task_status=TaskStatus.Running
                            )
                        )
                        self._handle_shard_download_process(task, initial_progress)
                    )
                case Shutdown(runner_id=runner_id):
                    try:
                        with fail_after(3):
@@ -387,104 +363,17 @@ class Worker:
        self._tg.start_soon(runner.run)
        return runner

    def _handle_shard_download_process(
        self,
        task: DownloadModel,
        initial_progress: RepoDownloadProgress,
    ):
        """Manages the shard download process with progress tracking."""
        status = DownloadOngoing(
            node_id=self.node_id,
            shard_metadata=task.shard_metadata,
            download_progress=map_repo_download_progress_to_download_progress_data(
                initial_progress
            ),
        )
        self.download_status[task.shard_metadata.model_card.model_id] = status
        self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))

        last_progress_time = 0.0
        throttle_interval_secs = 1.0

        async def download_progress_callback(
            shard: ShardMetadata, progress: RepoDownloadProgress
        ) -> None:
            nonlocal self
            nonlocal last_progress_time
            if progress.status == "complete":
                status = DownloadCompleted(
                    shard_metadata=shard,
                    node_id=self.node_id,
                    total_bytes=progress.total_bytes,
                )
                self.download_status[shard.model_card.model_id] = status
                await self.event_sender.send(
                    NodeDownloadProgress(download_progress=status)
                )
                await self.event_sender.send(
                    TaskStatusUpdated(
                        task_id=task.task_id, task_status=TaskStatus.Complete
                    )
                )
            elif (
                progress.status == "in_progress"
                and current_time() - last_progress_time > throttle_interval_secs
            ):
                status = DownloadOngoing(
                    node_id=self.node_id,
                    shard_metadata=shard,
                    download_progress=map_repo_download_progress_to_download_progress_data(
                        progress
                    ),
                )
                self.download_status[shard.model_card.model_id] = status
                await self.event_sender.send(
                    NodeDownloadProgress(download_progress=status)
                )
                last_progress_time = current_time()

        self.shard_downloader.on_progress(download_progress_callback)

        async def download_with_error_handling() -> None:
            try:
                await self.shard_downloader.ensure_shard(task.shard_metadata)
            except Exception as e:
                error_message = str(e)
                logger.error(
                    f"Download failed for {task.shard_metadata.model_card.model_id}: {error_message}"
                )
                failed_status = DownloadFailed(
                    node_id=self.node_id,
                    shard_metadata=task.shard_metadata,
                    error_message=error_message,
                )
                self.download_status[task.shard_metadata.model_card.model_id] = (
                    failed_status
                )
                await self.event_sender.send(
                    NodeDownloadProgress(download_progress=failed_status)
                )
                await self.event_sender.send(
                    TaskStatusUpdated(
                        task_id=task.task_id, task_status=TaskStatus.Failed
                    )
                )

        self._tg.start_soon(download_with_error_handling)

    async def _forward_events(self) -> None:
        with self.event_receiver as events:
            async for event in events:
                idx = next(self.event_index_counter)
                fe = ForwarderEvent(
                    origin_idx=self.local_event_index,
                    origin_idx=idx,
                    origin=self.node_id,
                    session=self.session_id,
                    event=event,
                )
                logger.debug(
                    f"Worker published event {self.local_event_index}: {str(event)[:100]}"
                )
                self.local_event_index += 1
                logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
                await self.local_event_sender.send(fe)
                self.out_for_delivery[event.event_id] = fe

@@ -532,42 +421,3 @@ class Worker:
                    await self.event_sender.send(TopologyEdgeDeleted(conn=conn))

            await anyio.sleep(10)

    async def _emit_existing_download_progress(self) -> None:
        try:
            while True:
                logger.debug("Fetching and emitting existing download progress...")
                async for (
                    _,
                    progress,
                ) in self.shard_downloader.get_shard_download_status():
                    if progress.status == "complete":
                        status = DownloadCompleted(
                            node_id=self.node_id,
                            shard_metadata=progress.shard,
                            total_bytes=progress.total_bytes,
                        )
                    elif progress.status in ["in_progress", "not_started"]:
                        if progress.downloaded_bytes_this_session.in_bytes == 0:
                            status = DownloadPending(
                                node_id=self.node_id, shard_metadata=progress.shard
                            )
                        else:
                            status = DownloadOngoing(
                                node_id=self.node_id,
                                shard_metadata=progress.shard,
                                download_progress=map_repo_download_progress_to_download_progress_data(
                                    progress
                                ),
                            )
                    else:
                        continue

                    self.download_status[progress.shard.model_card.model_id] = status
                    await self.event_sender.send(
                        NodeDownloadProgress(download_progress=status)
                    )
                logger.debug("Done emitting existing download progress.")
                await anyio.sleep(5 * 60)  # 5 minutes
        except Exception as e:
            logger.error(f"Error emitting existing download progress: {e}")

@@ -2,7 +2,6 @@

from collections.abc import Mapping, Sequence

from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
    ChatCompletion,
@@ -45,9 +44,6 @@ def plan(
    node_id: NodeId,
    # Runners is expected to be FRESH and so should not come from state
    runners: Mapping[RunnerId, RunnerSupervisor],
    # DL_status is expected to be FRESH and so should not come from state
    download_status: Mapping[ModelId, DownloadProgress],
    # gdls is not expected to be fresh
    global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
    instances: Mapping[InstanceId, Instance],
    all_runners: Mapping[RunnerId, RunnerStatus],  # all global
@@ -59,7 +55,7 @@ def plan(
    return (
        _kill_runner(runners, all_runners, instances)
        or _create_runner(node_id, runners, instances)
        or _model_needs_download(runners, download_status)
        or _model_needs_download(node_id, runners, global_download_status)
        or _init_distributed_backend(runners, all_runners)
        or _load_model(runners, all_runners, global_download_status)
        or _ready_to_warmup(runners, all_runners)
@@ -115,9 +111,15 @@ def _create_runner(


def _model_needs_download(
    node_id: NodeId,
    runners: Mapping[RunnerId, RunnerSupervisor],
    download_status: Mapping[ModelId, DownloadProgress],
    global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> DownloadModel | None:
    local_downloads = global_download_status.get(node_id, [])
    download_status = {
        dp.shard_metadata.model_card.model_id: dp for dp in local_downloads
    }

    for runner in runners.values():
        model_id = runner.bound_instance.bound_shard.model_card.model_id
        if isinstance(runner.status, RunnerIdle) and (

@@ -70,7 +70,6 @@ from exo.worker.engines.image import (
    warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
@@ -104,7 +103,6 @@ def main(
    model: Model | DistributedImageModel | None = None
    tokenizer = None
    group = None
    kv_prefix_cache: KVPrefixCache | None = None

    current_status: RunnerStatus = RunnerIdle()
    logger.info("runner created")
@@ -173,9 +171,6 @@ def main(
                    f"Unknown model task(s): {shard_metadata.model_card.tasks}"
                )

            if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
                kv_prefix_cache = KVPrefixCache()

            current_status = RunnerLoaded()
            logger.info("runner loaded")
        case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -243,7 +238,6 @@ def main(
                tokenizer=tokenizer,
                task=task_params,
                prompt=prompt,
                kv_prefix_cache=kv_prefix_cache,
            )

            # GPT-OSS specific parsing to match other model formats.

@@ -11,12 +11,12 @@ from pathlib import Path

import pytest

from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.download.download_utils import (
from exo.download.download_utils import (
    download_file_with_retry,
    ensure_models_dir,
    fetch_file_list_with_cache,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.engines.mlx.utils_mlx import (
    get_eos_token_ids_for_model,
    load_tokenizer_for_model_id,

@@ -1,5 +1,5 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import ModelId, NodeId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
@@ -45,13 +45,9 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
    instances = {INSTANCE_1_ID: instance}
    all_runners = {RUNNER_1_ID: RunnerIdle()}

    # No entry for this shard -> should trigger DownloadModel
    download_status: dict[ModelId, DownloadProgress] = {}

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status=download_status,
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
@@ -92,14 +88,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
        RUNNER_2_ID: RunnerConnected(),
    }

    # Local node has already marked its shard as downloaded (not actually used by _load_model)
    local_download_status = {
        MODEL_A_ID: DownloadCompleted(
            shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
        )
    }

    # Global view has completed downloads for both nodes
    global_download_status = {
        NODE_A: [
            DownloadCompleted(
@@ -116,7 +104,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status=local_download_status,
        global_download_status=global_download_status,
        instances=instances,
        all_runners=all_runners,
@@ -148,23 +135,19 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
    instances = {INSTANCE_1_ID: instance}
    all_runners = {RUNNER_1_ID: RunnerIdle()}

    # Local status claims the shard is downloaded already
    local_download_status = {
        MODEL_A_ID: DownloadCompleted(
            shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
        )
    }

    # Global view hasn't caught up yet (no completed shards recorded for NODE_A)
    # Global state shows shard is downloaded for NODE_A
    global_download_status: dict[NodeId, list[DownloadProgress]] = {
        NODE_A: [],
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
            )
        ],
        NODE_B: [],
    }

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status=local_download_status,
        global_download_status=global_download_status,
        instances=instances,
        all_runners=all_runners,
@@ -202,12 +185,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
        RUNNER_2_ID: RunnerConnected(),
    }

    # Only NODE_A's shard is recorded as downloaded globally
    local_download_status = {
        MODEL_A_ID: DownloadCompleted(
            shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
        )
    }
    global_download_status = {
        NODE_A: [
            DownloadCompleted(
@@ -220,7 +197,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status=local_download_status,
        global_download_status=global_download_status,
        instances=instances,
        all_runners=all_runners,
@@ -245,7 +221,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status=local_download_status,
        global_download_status=global_download_status,
        instances=instances,
        all_runners=all_runners,

@@ -47,8 +47,7 @@ def test_plan_kills_runner_when_instance_missing():

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        runners=runners,  # type: ignore[arg-type]
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
@@ -87,8 +86,7 @@ def test_plan_kills_runner_when_sibling_failed():

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        runners=runners,  # type: ignore[arg-type]
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
@@ -120,7 +118,6 @@ def test_plan_creates_runner_when_missing_for_node():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
@@ -158,8 +155,7 @@ def test_plan_does_not_create_runner_when_supervisor_already_present():

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        runners=runners,  # type: ignore[arg-type]
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
@@ -189,7 +185,6 @@ def test_plan_does_not_create_runner_for_unassigned_node():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,

@@ -65,7 +65,6 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
@@ -113,7 +112,6 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: [], NODE_B: []},
        instances=instances,
        all_runners=all_runners,
@@ -158,7 +156,6 @@ def test_plan_does_not_forward_tasks_for_other_instances():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
@@ -221,7 +218,6 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: [], NODE_B: []},
        instances=instances,
        all_runners=all_runners,
@@ -261,7 +257,6 @@ def test_plan_returns_none_when_nothing_to_do():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: [], NODE_B: []},
        instances=instances,
        all_runners=all_runners,

@@ -57,7 +57,6 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
    result = plan_mod.plan(
        node_id=NODE_B,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
@@ -99,7 +98,6 @@ def test_plan_starts_warmup_for_rank_zero_after_others_warming():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
@@ -140,7 +138,6 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
    result = plan_mod.plan(
        node_id=NODE_B,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: [], NODE_B: []},
        instances=instances,
        all_runners=all_runners,
@@ -185,7 +182,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
@@ -202,7 +198,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
@@ -246,7 +241,6 @@ def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
    result = plan_mod.plan(
        node_id=NODE_B,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_B: []},
        instances=instances,
        all_runners=all_runners,
@@ -289,7 +283,6 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: [], NODE_B: []},
        instances=instances,
        all_runners=all_runners,
@@ -331,7 +324,6 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: [], NODE_B: []},
        instances=instances,
        all_runners=all_runners,

@@ -11,6 +11,10 @@ from hypercorn.asyncio import serve  # pyright: ignore[reportUnknownVariableType
from loguru import logger
from pydantic import BaseModel

from exo.download.impl_shard_downloader import (
    build_full_shard,
    exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
@@ -36,10 +40,6 @@ from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.download.impl_shard_downloader import (
    build_full_shard,
    exo_shard_downloader,
)
from exo.worker.runner.bootstrap import entrypoint