mirror of https://github.com/exo-explore/exo.git (synced 2026-02-05 19:52:16 -05:00)

Compare commits — 3 commits, alexcheema ... main

Commits: 6b907398a4, 572e647908, e59ebd986d
@@ -5,21 +5,21 @@
[X] Fetching download status of all models on start.
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[X] no mx_barrier in generate.py mlx_generate at the end.
[] no mx_barrier in generate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed: "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all", which became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed: "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all", which became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable the quantized KV cache.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put the cache limit back in utils_mlx.py.
[X] topology.py remove_node removes the connections after checking whether the node id is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections, I guess?
[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up; probably create an issue. The blocker is that the transformers version doesn't support the tokenizer for Glm 4.7; rc-1 does, but we can't upgrade as it breaks other things.)
[] topology.py remove_node removes the connections after checking whether the node id is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections, I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up; probably create an issue. The blocker is that the transformers version doesn't support the tokenizer for Glm 4.7; rc-1 does, but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only catches ValueError. This was failing silently, leading to un-debuggable errors (we had a KeyError that was being swallowed). Changed this to catch Exception instead of ValueError; see the sketch after this list. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor or checks whether this tensor-parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
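A minimal sketch of the broadened error handling described in the try-except item above. Only the ValueError-to-Exception change is taken from the note; the queue, sentinel, and dispatch function here are hypothetical stand-ins, and the real Master._command_processor may be structured differently.

```python
# Hypothetical command-processing loop, for illustration only.
import asyncio

from loguru import logger


async def process_command(command: dict) -> None:
    # Stand-in for the real per-command dispatch; raises to show the failure path.
    raise KeyError(command["missing"])


async def command_processor(commands: asyncio.Queue) -> None:
    while True:
        command = await commands.get()
        if command is None:  # sentinel used by this sketch to stop the loop
            return
        try:
            await process_command(command)
        except Exception:
            # Previously only ValueError was caught, so a KeyError raised during
            # dispatch was swallowed silently and the failure was un-debuggable.
            # Catching Exception keeps the processor alive and logs the traceback.
            logger.exception(f"Failed to process command {command}")


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    await queue.put({"type": "TextGeneration"})
    await queue.put(None)
    await command_processor(queue)


asyncio.run(main())
```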
@@ -1,7 +1,6 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
isLoading,
|
||||
stopGeneration,
|
||||
sendMessage,
|
||||
generateImage,
|
||||
editImage,
|
||||
@@ -606,92 +605,86 @@
|
||||
style="min-height: 28px; max-height: 150px;"
|
||||
></textarea>
|
||||
|
||||
{#if loading}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => stopGeneration()}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-red-900/50 hover:text-red-400 border border-exo-medium-gray/50 hover:border-red-500/50 cursor-pointer"
|
||||
aria-label="Stop generation"
|
||||
>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || loading || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || loading || isEditOnlyWithoutImage
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={shouldShowEditMode
|
||||
? "Edit image"
|
||||
: isImageModel()
|
||||
? "Generate image"
|
||||
: "Send message"}
|
||||
>
|
||||
{#if loading}
|
||||
<span class="inline-flex items-center gap-1 sm:gap-2">
|
||||
<svg
|
||||
class="w-2.5 h-2.5 sm:w-3 sm:h-3"
|
||||
viewBox="0 0 24 24"
|
||||
fill="currentColor"
|
||||
<span
|
||||
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
|
||||
></span>
|
||||
<span class="hidden sm:inline"
|
||||
>{shouldShowEditMode
|
||||
? "EDITING"
|
||||
: isImageModel()
|
||||
? "GENERATING"
|
||||
: "PROCESSING"}</span
|
||||
>
|
||||
<rect x="4" y="4" width="16" height="16" rx="2" />
|
||||
</svg>
|
||||
<span class="hidden sm:inline">STOP</span>
|
||||
<span class="sm:hidden">...</span>
|
||||
</span>
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || isEditOnlyWithoutImage
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={shouldShowEditMode
|
||||
? "Edit image"
|
||||
: isImageModel()
|
||||
? "Generate image"
|
||||
: "Send message"}
|
||||
>
|
||||
{#if shouldShowEditMode}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
</button>
|
||||
{/if}
|
||||
{:else if shouldShowEditMode}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Bottom accent line -->
|
||||
|
||||
@@ -470,7 +470,6 @@ class AppStore {
|
||||
messages = $state<Message[]>([]);
|
||||
currentResponse = $state("");
|
||||
isLoading = $state(false);
|
||||
private currentAbortController: AbortController | null = null;
|
||||
|
||||
// Performance metrics
|
||||
ttftMs = $state<number | null>(null); // Time to first token in ms
|
||||
@@ -1739,11 +1738,9 @@ class AppStore {
|
||||
return;
|
||||
}
|
||||
|
||||
this.currentAbortController = new AbortController();
|
||||
const response = await fetch("/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
signal: this.currentAbortController.signal,
|
||||
body: JSON.stringify({
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
@@ -1857,7 +1854,6 @@ class AppStore {
|
||||
"Unknown error",
|
||||
);
|
||||
} finally {
|
||||
this.currentAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.saveConversationsToStorage();
|
||||
@@ -1989,10 +1985,6 @@ class AppStore {
|
||||
assistantMessageId: string,
|
||||
errorPrefix = "Failed to get response",
|
||||
): void {
|
||||
// Don't show error for user-initiated abort (stop button)
|
||||
if (error instanceof DOMException && error.name === "AbortError") {
|
||||
return;
|
||||
}
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
@@ -2034,17 +2026,6 @@ class AppStore {
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the current generation by aborting the HTTP connection.
|
||||
* This triggers backend cancellation via the mechanism in PR #1276.
|
||||
*/
|
||||
stopGeneration() {
|
||||
if (this.currentAbortController) {
|
||||
this.currentAbortController.abort();
|
||||
this.currentAbortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a message to the LLM and stream the response
|
||||
*/
|
||||
@@ -2192,13 +2173,11 @@ class AppStore {
|
||||
let firstTokenTime: number | null = null;
|
||||
let tokenCount = 0;
|
||||
|
||||
this.currentAbortController = new AbortController();
|
||||
const response = await fetch("/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
signal: this.currentAbortController.signal,
|
||||
body: JSON.stringify({
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
@@ -2346,7 +2325,6 @@ class AppStore {
|
||||
"Failed to get response",
|
||||
);
|
||||
} finally {
|
||||
this.currentAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.saveConversationsToStorage();
|
||||
@@ -2446,13 +2424,11 @@ class AppStore {
|
||||
};
|
||||
}
|
||||
|
||||
this.currentAbortController = new AbortController();
|
||||
const response = await fetch("/v1/images/generations", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
signal: this.currentAbortController.signal,
|
||||
body: JSON.stringify(requestBody),
|
||||
});
|
||||
|
||||
@@ -2601,7 +2577,6 @@ class AppStore {
|
||||
"Failed to generate image",
|
||||
);
|
||||
} finally {
|
||||
this.currentAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
@@ -2730,10 +2705,8 @@ class AppStore {
|
||||
);
|
||||
}
|
||||
|
||||
this.currentAbortController = new AbortController();
|
||||
const apiResponse = await fetch("/v1/images/edits", {
|
||||
method: "POST",
|
||||
signal: this.currentAbortController.signal,
|
||||
body: formData,
|
||||
});
|
||||
|
||||
@@ -2843,7 +2816,6 @@ class AppStore {
|
||||
"Failed to edit image",
|
||||
);
|
||||
} finally {
|
||||
this.currentAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
@@ -2972,7 +2944,6 @@ export const hasStartedChat = () => appStore.hasStartedChat;
|
||||
export const messages = () => appStore.messages;
|
||||
export const currentResponse = () => appStore.currentResponse;
|
||||
export const isLoading = () => appStore.isLoading;
|
||||
export const stopGeneration = () => appStore.stopGeneration();
|
||||
export const ttftMs = () => appStore.ttftMs;
|
||||
export const tps = () => appStore.tps;
|
||||
export const totalTokens = () => appStore.totalTokens;
|
||||
|
||||
@@ -118,9 +118,10 @@
|
||||
{
|
||||
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
|
||||
mlx = pkgs.callPackage ./nix/mlx.nix {
|
||||
metal-toolchain = self'.packages.metal-toolchain;
|
||||
inherit (self'.packages) metal-toolchain;
|
||||
inherit uvLockMlxVersion;
|
||||
};
|
||||
default = self'.packages.exo;
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
justfile
@@ -20,7 +20,7 @@ sync-clean:
|
||||
|
||||
rust-rebuild:
|
||||
cargo run --bin stub_gen
|
||||
just sync-clean
|
||||
uv sync --reinstall-package exo_pyo3_bindings
|
||||
|
||||
build-dashboard:
|
||||
#!/usr/bin/env bash
|
||||
|
||||
@@ -16,6 +16,7 @@ from exo.download.download_utils import (
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
DeleteDownload,
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
@@ -53,11 +54,10 @@ class DownloadCoordinator:
|
||||
# Internal event channel for forwarding (initialized in __post_init__)
|
||||
event_sender: Sender[Event] = field(init=False)
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False)
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
self._tg = anyio.create_task_group()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
@@ -108,6 +108,13 @@ class DownloadCoordinator:
|
||||
await self._start_download(shard)
|
||||
case DeleteDownload(model_id=model_id):
|
||||
await self._delete_download(model_id)
|
||||
case CancelDownload(model_id=model_id):
|
||||
await self._cancel_download(model_id)
|
||||
|
||||
async def _cancel_download(self, model_id: ModelId) -> None:
|
||||
if model_id in self.active_downloads and model_id in self.download_status:
|
||||
logger.info(f"Cancelling download for {model_id}")
|
||||
self.active_downloads.pop(model_id).cancel()
|
||||
|
||||
async def _start_download(self, shard: ShardMetadata) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
@@ -27,7 +27,6 @@ from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.worker.main import Worker
|
||||
|
||||
|
||||
# I marked this as a dataclass as I want trivial constructors.
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
@@ -106,6 +105,7 @@ class Node:
|
||||
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=router.receiver(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
)
|
||||
|
||||
er_send, er_recv = channel[ElectionResult]()
|
||||
@@ -136,7 +136,6 @@ class Node:
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
@@ -148,6 +147,8 @@ class Node:
|
||||
if self.api:
|
||||
tg.start_soon(self.api.run)
|
||||
tg.start_soon(self._elect_loop)
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
|
||||
def shutdown(self):
|
||||
# if this is our second call to shutdown, just sys.exit
|
||||
@@ -188,6 +189,9 @@ class Node:
|
||||
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=self.router.receiver(topics.COMMANDS),
|
||||
download_command_sender=self.router.sender(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
)
|
||||
self._tg.start_soon(self.master.run)
|
||||
elif (
|
||||
|
||||
@@ -176,7 +176,7 @@ async def generate_chat_stream(
|
||||
async def collect_chat_response(
|
||||
command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str]:
|
||||
) -> ChatCompletionResponse:
|
||||
"""Collect all token chunks and return a single ChatCompletionResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
@@ -223,7 +223,7 @@ async def collect_chat_response(
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
@@ -241,5 +241,4 @@ async def collect_chat_response(
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
).model_dump_json()
|
||||
return
|
||||
)
|
||||
|
||||
@@ -123,7 +123,6 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
StartDownload,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TextGeneration,
|
||||
)
|
||||
@@ -530,14 +529,16 @@ class API:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
# TODO: TaskCancelled
|
||||
"""
|
||||
self.command_sender.send_nowait(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
"""
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
command = TaskFinished(finished_command_id=command_id)
|
||||
await self._send(command)
|
||||
if command_id in self._text_generation_queues:
|
||||
del self._text_generation_queues[command_id]
|
||||
|
||||
@@ -632,14 +633,11 @@ class API:
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_chat_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
return await collect_chat_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
)
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionRequest
|
||||
@@ -655,7 +653,8 @@ class API:
|
||||
command = TextGeneration(task_params=task_params)
|
||||
await self._send(command)
|
||||
|
||||
return await self._collect_text_generation_with_stats(command.command_id)
|
||||
response = await self._collect_text_generation_with_stats(command.command_id)
|
||||
return response
|
||||
|
||||
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
|
||||
"""Validate a text model exists and return the resolved model ID.
|
||||
@@ -857,11 +856,6 @@ class API:
|
||||
del image_metadata[key]
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
@@ -943,11 +937,6 @@ class API:
|
||||
|
||||
return (images, stats if capture_stats else None)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
@@ -1331,29 +1320,40 @@ class API:
|
||||
]
|
||||
|
||||
async def run(self):
|
||||
shutdown_ev = anyio.Event()
|
||||
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
print_startup_banner(self.port)
|
||||
tg.start_soon(self.run_api, shutdown_ev)
|
||||
try:
|
||||
await anyio.sleep_forever()
|
||||
finally:
|
||||
with anyio.CancelScope(shield=True):
|
||||
shutdown_ev.set()
|
||||
finally:
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def run_api(self, ev: anyio.Event):
|
||||
cfg = Config()
|
||||
cfg.bind = f"0.0.0.0:{self.port}"
|
||||
cfg.bind = [f"0.0.0.0:{self.port}"]
|
||||
# nb: shared.logging needs updating if any of this changes
|
||||
cfg.accesslog = None
|
||||
cfg.errorlog = "-"
|
||||
cfg.logger_class = InterceptLogger
|
||||
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
print_startup_banner(self.port)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await serve(
|
||||
cast(ASGIFramework, self.app),
|
||||
cfg,
|
||||
shutdown_trigger=lambda: anyio.sleep_forever(),
|
||||
shutdown_trigger=ev.wait,
|
||||
)
|
||||
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
|
||||
@@ -6,6 +6,7 @@ from loguru import logger
|
||||
|
||||
from exo.master.placement import (
|
||||
add_instance_to_placements,
|
||||
cancel_unnecessary_downloads,
|
||||
delete_instance,
|
||||
get_transition_events,
|
||||
place_instance,
|
||||
@@ -16,12 +17,12 @@ from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
SendInputChunk,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
TextGeneration,
|
||||
@@ -37,7 +38,6 @@ from exo.shared.types.events import (
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStatusUpdated,
|
||||
TraceEventData,
|
||||
TracesCollected,
|
||||
TracesMerged,
|
||||
@@ -68,12 +68,9 @@ class Master:
|
||||
session_id: SessionId,
|
||||
*,
|
||||
command_receiver: Receiver[ForwarderCommand],
|
||||
# Receiving indexed events from the forwarder to be applied to state
|
||||
# Ideally these would be WorkerForwarderEvents but type system says no :(
|
||||
local_event_receiver: Receiver[ForwarderEvent],
|
||||
# Send events to the forwarder to be indexed (usually from command processing)
|
||||
# Ideally these would be MasterForwarderEvents but type system says no :(
|
||||
global_event_sender: Sender[ForwarderEvent],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
):
|
||||
self.state = State()
|
||||
self._tg: TaskGroup = anyio.create_task_group()
|
||||
@@ -83,6 +80,7 @@ class Master:
|
||||
self.command_receiver = command_receiver
|
||||
self.local_event_receiver = local_event_receiver
|
||||
self.global_event_sender = global_event_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
send, recv = channel[Event]()
|
||||
self.event_sender: Sender[Event] = send
|
||||
self._loopback_event_receiver: Receiver[Event] = recv
|
||||
@@ -98,16 +96,18 @@ class Master:
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
finally:
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Stopping Master")
|
||||
@@ -278,8 +278,16 @@ class Master:
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
self.state.instances, placement
|
||||
)
|
||||
for cmd in cancel_unnecessary_downloads(
|
||||
placement, self.state.downloads
|
||||
):
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id, command=cmd
|
||||
)
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case PlaceInstance():
|
||||
placement = place_instance(
|
||||
@@ -290,7 +298,7 @@ class Master:
|
||||
self.state.node_network,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case CreateInstance():
|
||||
@@ -300,7 +308,7 @@ class Master:
|
||||
self.state.instances,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
@@ -310,18 +318,6 @@ class Master:
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
case TaskCancelled():
|
||||
if (
|
||||
task_id := self.command_task_mapping.get(
|
||||
command.cancelled_command_id
|
||||
)
|
||||
) is not None:
|
||||
generated_events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task_id,
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
@@ -330,9 +326,10 @@ class Master:
|
||||
]
|
||||
)
|
||||
)
|
||||
self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
if command.finished_command_id in self.command_task_mapping:
|
||||
del self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
for i in range(command.since_idx, len(self._event_log)):
|
||||
|
||||
@@ -15,20 +15,20 @@ from exo.master.placement_utils import (
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
DownloadCommand,
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
@@ -186,7 +186,6 @@ def delete_instance(
|
||||
def get_transition_events(
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
target_instances: Mapping[InstanceId, Instance],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Sequence[Event]:
|
||||
events: list[Event] = []
|
||||
|
||||
@@ -202,18 +201,6 @@ def get_transition_events(
|
||||
# find instances to delete
|
||||
for instance_id in current_instances:
|
||||
if instance_id not in target_instances:
|
||||
for task in tasks.values():
|
||||
if task.instance_id == instance_id and task.task_status in [
|
||||
TaskStatus.Pending,
|
||||
TaskStatus.Running,
|
||||
]:
|
||||
events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
)
|
||||
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
@@ -221,3 +208,29 @@ def get_transition_events(
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def cancel_unnecessary_downloads(
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> Sequence[DownloadCommand]:
|
||||
commands: list[DownloadCommand] = []
|
||||
currently_downloading = [
|
||||
(k, v.shard_metadata.model_card.model_id)
|
||||
for k, vs in download_status.items()
|
||||
for v in vs
|
||||
if isinstance(v, (DownloadOngoing))
|
||||
]
|
||||
active_models = set(
|
||||
(
|
||||
node_id,
|
||||
instance.shard_assignments.runner_to_shard[runner_id].model_card.model_id,
|
||||
)
|
||||
for instance in instances.values()
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
|
||||
)
|
||||
for pair in currently_downloading:
|
||||
if pair not in active_models:
|
||||
commands.append(CancelDownload(target_node_id=pair[0], model_id=pair[1]))
|
||||
|
||||
return commands
|
||||
|
||||
@@ -11,6 +11,7 @@ from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.commands import (
|
||||
CommandId,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
PlaceInstance,
|
||||
TextGeneration,
|
||||
)
|
||||
@@ -47,6 +48,7 @@ async def test_master():
|
||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||
command_sender, co_receiver = channel[ForwarderCommand]()
|
||||
local_event_sender, le_receiver = channel[ForwarderEvent]()
|
||||
fcds, _fcdr = channel[ForwarderDownloadCommand]()
|
||||
|
||||
all_events: list[IndexedEvent] = []
|
||||
|
||||
@@ -67,6 +69,7 @@ async def test_master():
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=co_receiver,
|
||||
download_command_sender=fcds,
|
||||
)
|
||||
logger.info("run the master")
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
|
||||
target_instances = {instance_id: instance}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
|
||||
# assert
|
||||
assert len(events) == 0
|
||||
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
|
||||
target_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
|
||||
# assert
|
||||
assert len(events) == 1
|
||||
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
|
||||
# assert
|
||||
assert len(events) == 1
|
||||
|
||||
@@ -9,6 +9,7 @@ from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
create_task_group,
|
||||
move_on_after,
|
||||
sleep_forever,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -146,18 +147,21 @@ class Router:
|
||||
|
||||
async def run(self):
|
||||
logger.debug("Starting Router")
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
finally:
|
||||
with move_on_after(1, shield=True):
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
|
||||
async def shutdown(self):
|
||||
logger.debug("Shutting down Router")
|
||||
@@ -166,12 +170,12 @@ class Router:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _networking_subscribe(self, topic: str):
|
||||
logger.info(f"Subscribing to {topic}")
|
||||
await self._net.gossipsub_subscribe(topic)
|
||||
logger.info(f"Subscribed to {topic}")
|
||||
|
||||
async def _networking_unsubscribe(self, topic: str):
|
||||
logger.info(f"Unsubscribing from {topic}")
|
||||
await self._net.gossipsub_unsubscribe(topic)
|
||||
logger.info(f"Unsubscribed from {topic}")
|
||||
|
||||
async def _networking_recv(self):
|
||||
while True:
|
||||
|
||||
@@ -86,28 +86,29 @@ class Election:
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Election")
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
tg.start_soon(self._command_counter)
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
tg.start_soon(self._command_counter)
|
||||
|
||||
# And start an election immediately, that instantly resolves
|
||||
candidates: list[ElectionMessage] = []
|
||||
logger.debug("Starting initial campaign")
|
||||
self._candidates = candidates
|
||||
await self._campaign(candidates, campaign_timeout=0.0)
|
||||
logger.debug("Initial campaign finished")
|
||||
|
||||
# Cancel and wait for the last election to end
|
||||
if self._campaign_cancel_scope is not None:
|
||||
logger.debug("Cancelling campaign")
|
||||
self._campaign_cancel_scope.cancel()
|
||||
if self._campaign_done is not None:
|
||||
logger.debug("Waiting for campaign to finish")
|
||||
await self._campaign_done.wait()
|
||||
logger.debug("Campaign cancelled and finished")
|
||||
logger.info("Election finished")
|
||||
# And start an election immediately, that instantly resolves
|
||||
candidates: list[ElectionMessage] = []
|
||||
logger.debug("Starting initial campaign")
|
||||
self._candidates = candidates
|
||||
await self._campaign(candidates, campaign_timeout=0.0)
|
||||
logger.debug("Initial campaign finished")
|
||||
finally:
|
||||
# Cancel and wait for the last election to end
|
||||
if self._campaign_cancel_scope is not None:
|
||||
logger.debug("Cancelling campaign")
|
||||
self._campaign_cancel_scope.cancel()
|
||||
if self._campaign_done is not None:
|
||||
logger.debug("Waiting for campaign to finish")
|
||||
await self._campaign_done.wait()
|
||||
logger.debug("Campaign cancelled and finished")
|
||||
logger.info("Election shutdown")
|
||||
|
||||
async def elect(self, em: ElectionMessage) -> None:
|
||||
logger.debug(f"Electing: {em}")
|
||||
|
||||
@@ -48,10 +48,6 @@ class DeleteInstance(BaseCommand):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class TaskCancelled(BaseCommand):
|
||||
cancelled_command_id: CommandId
|
||||
|
||||
|
||||
class TaskFinished(BaseCommand):
|
||||
finished_command_id: CommandId
|
||||
|
||||
@@ -76,7 +72,12 @@ class DeleteDownload(BaseCommand):
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload
|
||||
class CancelDownload(BaseCommand):
|
||||
target_node_id: NodeId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
|
||||
|
||||
|
||||
Command = (
|
||||
@@ -88,7 +89,6 @@ Command = (
|
||||
| PlaceInstance
|
||||
| CreateInstance
|
||||
| DeleteInstance
|
||||
| TaskCancelled
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
)
|
||||
|
||||
@@ -24,7 +24,6 @@ class TaskStatus(str, Enum):
|
||||
Complete = "Complete"
|
||||
TimedOut = "TimedOut"
|
||||
Failed = "Failed"
|
||||
Cancelled = "Cancelled"
|
||||
|
||||
|
||||
class BaseTask(TaggedModel):
|
||||
@@ -61,11 +60,6 @@ class TextGeneration(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class CancelTask(BaseTask):
|
||||
cancelled_task_id: TaskId
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
class ImageGeneration(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageGenerationTaskParams
|
||||
@@ -93,7 +87,6 @@ Task = (
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| TextGeneration
|
||||
| CancelTask
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| Shutdown
|
||||
|
||||
@@ -194,9 +194,10 @@ class MpReceiver[T]:
|
||||
raise EndOfStream from None
|
||||
return item
|
||||
|
||||
# nb: this function will not cancel particularly well
|
||||
async def receive_async(self) -> T:
|
||||
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
|
||||
return await to_thread.run_sync(
|
||||
self.receive, limiter=CapacityLimiter(1), abandon_on_cancel=True
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._state.closed.is_set():
|
||||
|
||||
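The receive_async change above adds anyio's abandon_on_cancel flag. A small self-contained sketch of the behaviour it buys, assuming anyio >= 4.1 (where the parameter exists); the blocking sleep stands in for the real multiprocessing receive.

```python
# Sketch of abandon_on_cancel: cancelling the awaiting task no longer waits for
# the worker thread to return; the await is abandoned and the thread finishes
# on its own in the background.
import time

import anyio
from anyio import to_thread


def blocking_receive() -> str:
    time.sleep(5)  # stands in for a blocking multiprocessing-pipe read
    return "item"


async def main() -> None:
    started = time.monotonic()
    with anyio.move_on_after(0.5):
        await to_thread.run_sync(blocking_receive, abandon_on_cancel=True)
    # With abandon_on_cancel=True this prints ~0.5s; without it, ~5s.
    print(f"returned after {time.monotonic() - started:.1f}s")


anyio.run(main)
```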
@@ -32,6 +32,7 @@ from exo.worker.engines.mlx.constants import (
|
||||
)
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
mx_barrier,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
@@ -135,6 +136,10 @@ def warmup_inference(
|
||||
|
||||
logger.info("Generated ALL warmup tokens")
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
# At least this version is actively incorrect, as it should use mx_barrier(group)
|
||||
mx_barrier()
|
||||
|
||||
return tokens_generated
|
||||
|
||||
|
||||
@@ -398,3 +403,5 @@ def mlx_generate(
|
||||
# Limit accumulated_text to what's needed for stop sequence detection
|
||||
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
|
||||
accumulated_text = accumulated_text[-max_stop_len:]
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
@@ -67,6 +67,8 @@ Group = mx.distributed.Group
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
|
||||
|
||||
|
||||
# TODO: Test this
|
||||
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
|
||||
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
return Memory.from_float_kb(
|
||||
(model_shard_meta.end_layer - model_shard_meta.start_layer)
|
||||
@@ -84,6 +86,30 @@ class ModelLoadingTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0),
|
||||
stream=mx.default_stream(mx.Device(mx.cpu)),
|
||||
group=group,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def broadcast_from_zero(value: int, group: Group | None = None):
|
||||
if group is None:
|
||||
return value
|
||||
|
||||
if group.rank() == 0:
|
||||
a = mx.array([value], dtype=mx.int32)
|
||||
else:
|
||||
a = mx.array([0], dtype=mx.int32)
|
||||
|
||||
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
|
||||
mx.eval(m)
|
||||
return int(m.item())
|
||||
|
||||
|
||||
class HostList(RootModel[list[str]]):
|
||||
@classmethod
|
||||
def from_hosts(cls, hosts: list[Host]) -> "HostList":
|
||||
@@ -536,23 +562,3 @@ def mlx_cleanup(
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||
if group is None:
|
||||
return bool_
|
||||
num_true = mx.distributed.all_sum(
|
||||
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
mx.eval(num_true)
|
||||
return num_true.item() > 0
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None):
|
||||
if group is None:
|
||||
return
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
)
|
||||
|
||||
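For context, a small usage sketch of the two collective helpers added above (mx_barrier and broadcast_from_zero). It assumes exo and MLX are installed and importable and runs single-process, where mx.distributed.init() returns a size-1 group and the collectives are effectively no-ops; in the runner the group comes from the distributed backend initialisation.

```python
# Single-process usage sketch for the collective helpers in utils_mlx.py.
import mlx.core as mx

from exo.worker.engines.mlx.utils_mlx import broadcast_from_zero, mx_barrier

group = mx.distributed.init()

# Share a value decided on rank 0 with every rank (here there is only rank 0).
check_for_cancel_every = broadcast_from_zero(100, group)

# Block until every rank reaches this point, e.g. before tearing a model down.
mx_barrier(group)

print(group.rank(), check_for_cancel_every)
```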
@@ -32,7 +32,6 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
@@ -99,22 +98,23 @@ class Worker:
|
||||
info_send, info_recv = channel[GatheredInfo]()
|
||||
info_gatherer: InfoGatherer = InfoGatherer(info_send)
|
||||
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
async with create_task_group() as tg:
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
finally:
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
logger.info("Stopping Worker")
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
for runner in self.runners.values():
|
||||
tg.start_soon(runner.shutdown)
|
||||
runner.shutdown()
|
||||
|
||||
async def _forward_info(self, recv: Receiver[GatheredInfo]):
|
||||
with recv as info_stream:
|
||||
@@ -218,22 +218,15 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
runner = self.runners.pop(runner_id)
|
||||
try:
|
||||
with fail_after(3):
|
||||
await runner.start_task(task)
|
||||
await self.runners.pop(runner_id).start_task(task)
|
||||
except TimeoutError:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
finally:
|
||||
await runner.shutdown()
|
||||
case CancelTask(
|
||||
cancelled_task_id=cancelled_task_id, runner_id=runner_id
|
||||
):
|
||||
await self.runners[runner_id].cancel_task(cancelled_task_id)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
@@ -271,18 +264,18 @@ class Worker:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self._start_runner_task(modified_task)
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
case task:
|
||||
await self._start_runner_task(task)
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
|
||||
def shutdown(self):
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _start_runner_task(self, task: Task):
|
||||
if (instance := self.state.instances.get(task.instance_id)) is not None:
|
||||
await self.runners[
|
||||
instance.shard_assignments.node_to_runner[self.node_id]
|
||||
].start_task(task)
|
||||
def _task_to_runner_id(self, task: Task):
|
||||
instance = self.state.instances[task.instance_id]
|
||||
return instance.shard_assignments.node_to_runner[self.node_id]
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
@@ -321,6 +314,8 @@ class Worker:
|
||||
for event in self.out_for_delivery.copy().values():
|
||||
await self.local_event_sender.send(event)
|
||||
|
||||
## Op Executors
|
||||
|
||||
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
|
||||
"""Creates and stores a new AssignedRunner with initial downloading status."""
|
||||
runner = RunnerSupervisor.create(
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
@@ -54,14 +53,13 @@ def plan(
|
||||
) -> Task | None:
|
||||
# Python short circuiting OR logic should evaluate these sequentially.
|
||||
return (
|
||||
_cancel_tasks(runners, tasks)
|
||||
or _kill_runner(runners, all_runners, instances)
|
||||
_kill_runner(runners, all_runners, instances)
|
||||
or _create_runner(node_id, runners, instances)
|
||||
or _model_needs_download(node_id, runners, global_download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
|
||||
)
|
||||
|
||||
|
||||
@@ -272,7 +270,7 @@ def _pending_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions
|
||||
@@ -286,7 +284,7 @@ def _pending_tasks(
|
||||
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
|
||||
cmd_id = task.command_id
|
||||
expected = task.task_params.total_input_chunks
|
||||
received = len(input_chunk_buffer.get(cmd_id, {}))
|
||||
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
|
||||
if received < expected:
|
||||
continue # Wait for all chunks to arrive
|
||||
|
||||
@@ -294,33 +292,16 @@ def _pending_tasks(
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# the task status _should_ be set to completed by the LAST runner
|
||||
# it is currently set by the first
|
||||
# this is definitely a hack
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
):
|
||||
return task
|
||||
|
||||
|
||||
def _cancel_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
if task.task_status != TaskStatus.Cancelled:
|
||||
continue
|
||||
for runner_id, runner in runners.items():
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
if task.task_id in runner.cancelled:
|
||||
continue
|
||||
return CancelTask(
|
||||
instance_id=task.instance_id,
|
||||
cancelled_task_id=task.task_id,
|
||||
runner_id=runner_id,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import loguru
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
@@ -15,7 +15,6 @@ def entrypoint(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
@@ -39,7 +38,7 @@ def entrypoint(
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver, cancel_receiver)
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
except ClosedResourceError:
|
||||
logger.warning("Runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import base64
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
@@ -88,7 +87,6 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
mx_any,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
@@ -113,7 +111,6 @@ def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
):
|
||||
instance, runner_id, shard_metadata = (
|
||||
bound_instance.instance,
|
||||
@@ -128,15 +125,11 @@ def main(
|
||||
time.sleep(timeout)
|
||||
|
||||
setup_start_time = time.time()
|
||||
cancelled_tasks = set[TaskId]()
|
||||
|
||||
# type checker was unhappy with me - splitting these fixed it
|
||||
inference_model: Model | None = None
|
||||
image_model: DistributedImageModel | None = None
|
||||
model: Model | DistributedImageModel | None = None
|
||||
tokenizer = None
|
||||
group = None
|
||||
kv_prefix_cache: KVPrefixCache | None = None
|
||||
check_for_cancel_every: int | None = None
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
@@ -149,7 +142,6 @@ def main(
|
||||
if task.task_id in seen:
|
||||
logger.warning("repeat task - potential error")
|
||||
seen.add(task.task_id)
|
||||
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
@@ -195,7 +187,7 @@ def main(
|
||||
time.sleep(0.5)
|
||||
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
inference_model, tokenizer = load_mlx_items(
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(
|
||||
@@ -207,7 +199,7 @@ def main(
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
image_model = initialize_image_model(bound_instance)
|
||||
model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
@@ -215,6 +207,8 @@ def main(
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
@@ -226,30 +220,15 @@ def main(
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert inference_model
|
||||
assert not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
|
||||
t = time.perf_counter()
|
||||
toks = warmup_inference(
|
||||
model=inference_model,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
check_for_cancel_every = min(
|
||||
math.ceil(toks / (time.perf_counter() - t)), 100
|
||||
)
|
||||
if group is not None:
|
||||
check_for_cancel_every = int(
|
||||
mx.max(
|
||||
mx.distributed.all_gather(
|
||||
mx.array([check_for_cancel_every]), group=group
|
||||
)
|
||||
).item()
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"runner checking for cancellation every {check_for_cancel_every} tokens"
|
||||
)
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
@@ -257,8 +236,8 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert image_model
image = warmup_image_generator(model=image_model)
assert isinstance(model, DistributedImageModel)
image = warmup_image_generator(model=model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
@@ -278,9 +257,9 @@ def main(
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
assert inference_model

assert model and not isinstance(model, DistributedImageModel)
assert tokenizer
assert check_for_cancel_every

try:
_check_for_debug_prompts(task_params)
@@ -290,7 +269,7 @@ def main(

# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=inference_model,
model=model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
@@ -314,11 +293,11 @@ def main(
patch_glm_tokenizer(tokenizer)

# GPT-OSS specific parsing to match other model formats.
elif isinstance(inference_model, GptOssModel):
elif isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)

if tokenizer.has_tool_calling and not isinstance(
inference_model, GptOssModel
model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
@@ -331,18 +310,7 @@ def main(
)

completion_tokens = 0
tokens_since_last_cancel_check = 0
for response in mlx_generator:
tokens_since_last_cancel_check += 1
if tokens_since_last_cancel_check >= check_for_cancel_every:
tokens_since_last_cancel_check = 0
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if mx_any(want_to_cancel, group):
break

match response:
case GenerationResponse():
completion_tokens += 1
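The loop in the @@ -331 hunk only breaks once mx_any(want_to_cancel, group) is true, so all ranks leave the distributed generation loop together instead of one rank abandoning a collective op. mx_any itself is exo's helper and its body is not shown here; a plausible reading, offered only as an assumption:

    import mlx.core as mx

    def mx_any(flag: bool, group=None) -> bool:
        # Distributed OR: true everywhere if any rank wants to cancel,
        # so every rank takes the same branch.
        if group is None:
            return flag
        gathered = mx.distributed.all_gather(mx.array([int(flag)]), group=group)
        return bool(mx.max(gathered).item())
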
@@ -414,7 +382,7 @@ def main(
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert image_model
assert isinstance(model, DistributedImageModel)
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -427,9 +395,7 @@ def main(

try:
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
for response in generate_image(model=model, task=task_params):
is_primary_output = _is_primary_output_node(shard_metadata)

if is_primary_output:
@@ -479,7 +445,7 @@ def main(
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert image_model
assert isinstance(model, DistributedImageModel)
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -492,9 +458,7 @@ def main(

try:
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
for response in generate_image(model=model, task=task_params):
if _is_primary_output_node(shard_metadata):
match response:
case PartialImageResponse():
@@ -560,7 +524,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del inference_model, image_model, tokenizer, group
del model, tokenizer, group
mx.clear_cache()
import gc

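Throughout the runner.py hunks above, one side keeps separate inference_model / image_model locals while the other uses a single model variable that is narrowed per task with isinstance / assert checks. A toy illustration of that narrowing pattern; the classes below are stand-ins, not exo's real model types:

    class TextModel: ...

    class DistributedImageModel: ...

    def warm_up(model: TextModel | DistributedImageModel | None) -> None:
        assert model is not None
        if isinstance(model, DistributedImageModel):
            ...  # image warmup / generation path
        else:
            ...  # text-generation path, tokenizer required
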
@@ -8,10 +8,8 @@ import anyio
from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
to_thread,
)
from anyio.abc import TaskGroup
from loguru import logger

from exo.shared.types.events import (
@@ -49,12 +47,9 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_tg: TaskGroup = field(default_factory=create_task_group, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)

@classmethod
def create(
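The @@ -49 hunk changes which pieces of runtime state RunnerSupervisor declares; the fields themselves all use the same field(default_factory=..., init=False) idiom, so each instance gets fresh mutable state that callers never pass to the constructor. A generic illustration with a toy class, not RunnerSupervisor:

    from dataclasses import dataclass, field

    @dataclass
    class TaskBook:
        pending: dict[str, int] = field(default_factory=dict, init=False)
        completed: set[str] = field(default_factory=set, init=False)
        cancelled: set[str] = field(default_factory=set, init=False)
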
@@ -65,8 +60,8 @@ class RunnerSupervisor:
initialize_timeout: float = 400,
) -> Self:
ev_send, ev_recv = mp_channel[Event]()
# A task is kind of a runner command
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()

runner_process = Process(
target=entrypoint,
@@ -74,7 +69,6 @@ class RunnerSupervisor:
bound_instance,
ev_send,
task_recv,
cancel_recv,
logger,
),
daemon=True,
@@ -89,7 +83,6 @@ class RunnerSupervisor:
initialize_timeout=initialize_timeout,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
)

@@ -97,42 +90,35 @@ class RunnerSupervisor:

async def run(self):
self.runner_process.start()
async with self._tg as tg:
tg.start_soon(self._forward_events)
await self._forward_events()

with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
def shutdown(self):
logger.info("Runner supervisor shutting down")
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self.runner_process.join(1)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
return

self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.close()
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
self.runner_process.join(1)
if not self.runner_process.is_alive():
return

await to_thread.run_sync(self.runner_process.join, 10)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()

# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
self.runner_process.join(1)
if not self.runner_process.is_alive():
return

logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()

await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return

logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)

async def shutdown(self):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
self._tg.cancel_scope.cancel()
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)

async def start_task(self, task: Task):
if task.task_id in self.pending:
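One side of the @@ -97 hunk has a synchronous shutdown() that escalates: close the channels, join briefly, then terminate (SIGTERM), then kill (SIGKILL), logging at each stage. A minimal sketch of that escalation for a plain multiprocessing.Process; the timeouts and messages are illustrative:

    from multiprocessing import Process

    from loguru import logger

    def stop_runner(proc: Process) -> None:
        proc.join(1)
        if not proc.is_alive():
            logger.info("Runner process successfully terminated")
            return
        logger.warning("Runner process didn't shut down, terminating")
        proc.terminate()  # SIGTERM
        proc.join(1)
        if not proc.is_alive():
            return
        logger.critical("Runner process didn't respond to SIGTERM, killing")
        proc.kill()  # SIGKILL
        proc.join(1)
        if not proc.is_alive():
            return
        logger.critical("Runner process didn't respond to SIGKILL. System resources may have leaked")
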
@@ -155,13 +141,6 @@ class RunnerSupervisor:
return
await event.wait()

async def cancel_task(self, task_id: TaskId):
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
self.cancelled.add(task_id)
await self._cancel_sender.send_async(task_id)

async def _forward_events(self):
with self._ev_recv as events:
try:
@@ -226,4 +205,4 @@ class RunnerSupervisor:
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
await self.shutdown()
self.shutdown()

@@ -2,7 +2,6 @@
from collections.abc import Iterable
from typing import Callable

import mlx.core as mx
import pytest

import exo.worker.runner.runner as mlx_runner
@@ -20,7 +19,6 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
TextGeneration,
)
@@ -115,8 +113,6 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mx.distributed, "all_gather", make_nothin(mx.array([1])))
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
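The patch_out_mlx fixture above swaps exo's MLX entry points for nothin / make_nothin stubs. Their definitions are not part of this hunk; a plausible reading, shown only as an assumption so the fixture is easier to follow:

    from typing import Any, Callable

    def nothin(*args: Any, **kwargs: Any) -> None:
        # Accept any call and do nothing.
        return None

    def make_nothin(value: Any) -> Callable[..., Any]:
        # Build a stub that ignores its arguments and returns a fixed value.
        def _stub(*args: Any, **kwargs: Any) -> Any:
            return value
        return _stub
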
@@ -167,7 +163,6 @@ def _run(tasks: Iterable[Task]):
)

task_sender, task_receiver = mp_channel[Task]()
_cancel_sender, cancel_receiver = mp_channel[TaskId]()
event_sender = EventCollector()

with task_sender:
@@ -179,7 +174,7 @@ def _run(tasks: Iterable[Task]):
task_receiver.close = nothin
task_receiver.join = nothin

mlx_runner.main(bound_instance, event_sender, task_receiver, cancel_receiver) # pyright: ignore[reportArgumentType]
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]

return event_sender.events


@@ -22,7 +22,7 @@ echo "Deploying $commit to $# hosts..."
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
done
wait
jobs -pr | xargs -r kill 2>/dev/null || true
@@ -34,21 +34,13 @@ reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
{
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
} <<'EOF'
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q "$1"
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
EOF
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
done

for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52415/models"; do sleep 1; done
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done
wait