mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-06 04:01:16 -05:00
Compare commits
9 Commits
runner-can
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2f47802c8 | ||
|
|
a067fd0753 | ||
|
|
333119ec2f | ||
|
|
79661ea3a3 | ||
|
|
196f4fb35f | ||
|
|
08a4ca59c6 | ||
|
|
45ed3903f3 | ||
|
|
6657d8600f | ||
|
|
bf06580e0e |
@@ -1,6 +1,7 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
isLoading,
|
||||
stopGeneration,
|
||||
sendMessage,
|
||||
generateImage,
|
||||
editImage,
|
||||
@@ -605,86 +606,92 @@
|
||||
style="min-height: 28px; max-height: 150px;"
|
||||
></textarea>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || loading || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || loading || isEditOnlyWithoutImage
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={shouldShowEditMode
|
||||
? "Edit image"
|
||||
: isImageModel()
|
||||
? "Generate image"
|
||||
: "Send message"}
|
||||
>
|
||||
{#if loading}
|
||||
{#if loading}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => stopGeneration()}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-red-900/50 hover:text-red-400 border border-exo-medium-gray/50 hover:border-red-500/50 cursor-pointer"
|
||||
aria-label="Stop generation"
|
||||
>
|
||||
<span class="inline-flex items-center gap-1 sm:gap-2">
|
||||
<span
|
||||
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
|
||||
></span>
|
||||
<span class="hidden sm:inline"
|
||||
>{shouldShowEditMode
|
||||
? "EDITING"
|
||||
: isImageModel()
|
||||
? "GENERATING"
|
||||
: "PROCESSING"}</span
|
||||
>
|
||||
<span class="sm:hidden">...</span>
|
||||
</span>
|
||||
{:else if shouldShowEditMode}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
class="w-2.5 h-2.5 sm:w-3 sm:h-3"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
fill="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
<rect x="4" y="4" width="16" height="16" rx="2" />
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
<span class="hidden sm:inline">STOP</span>
|
||||
</span>
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
</button>
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || isEditOnlyWithoutImage
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={shouldShowEditMode
|
||||
? "Edit image"
|
||||
: isImageModel()
|
||||
? "Generate image"
|
||||
: "Send message"}
|
||||
>
|
||||
{#if shouldShowEditMode}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Bottom accent line -->
|
||||
|
||||
@@ -470,6 +470,7 @@ class AppStore {
|
||||
messages = $state<Message[]>([]);
|
||||
currentResponse = $state("");
|
||||
isLoading = $state(false);
|
||||
private currentAbortController: AbortController | null = null;
|
||||
|
||||
// Performance metrics
|
||||
ttftMs = $state<number | null>(null); // Time to first token in ms
|
||||
@@ -1738,9 +1739,11 @@ class AppStore {
|
||||
return;
|
||||
}
|
||||
|
||||
this.currentAbortController = new AbortController();
|
||||
const response = await fetch("/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
signal: this.currentAbortController.signal,
|
||||
body: JSON.stringify({
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
@@ -1854,6 +1857,7 @@ class AppStore {
|
||||
"Unknown error",
|
||||
);
|
||||
} finally {
|
||||
this.currentAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.saveConversationsToStorage();
|
||||
@@ -1985,6 +1989,10 @@ class AppStore {
|
||||
assistantMessageId: string,
|
||||
errorPrefix = "Failed to get response",
|
||||
): void {
|
||||
// Don't show error for user-initiated abort (stop button)
|
||||
if (error instanceof DOMException && error.name === "AbortError") {
|
||||
return;
|
||||
}
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
@@ -2026,6 +2034,17 @@ class AppStore {
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the current generation by aborting the HTTP connection.
|
||||
* This triggers backend cancellation via the mechanism in PR #1276.
|
||||
*/
|
||||
stopGeneration() {
|
||||
if (this.currentAbortController) {
|
||||
this.currentAbortController.abort();
|
||||
this.currentAbortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a message to the LLM and stream the response
|
||||
*/
|
||||
@@ -2173,11 +2192,13 @@ class AppStore {
|
||||
let firstTokenTime: number | null = null;
|
||||
let tokenCount = 0;
|
||||
|
||||
this.currentAbortController = new AbortController();
|
||||
const response = await fetch("/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
signal: this.currentAbortController.signal,
|
||||
body: JSON.stringify({
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
@@ -2325,6 +2346,7 @@ class AppStore {
|
||||
"Failed to get response",
|
||||
);
|
||||
} finally {
|
||||
this.currentAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.saveConversationsToStorage();
|
||||
@@ -2424,11 +2446,13 @@ class AppStore {
|
||||
};
|
||||
}
|
||||
|
||||
this.currentAbortController = new AbortController();
|
||||
const response = await fetch("/v1/images/generations", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
signal: this.currentAbortController.signal,
|
||||
body: JSON.stringify(requestBody),
|
||||
});
|
||||
|
||||
@@ -2577,6 +2601,7 @@ class AppStore {
|
||||
"Failed to generate image",
|
||||
);
|
||||
} finally {
|
||||
this.currentAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
@@ -2705,8 +2730,10 @@ class AppStore {
|
||||
);
|
||||
}
|
||||
|
||||
this.currentAbortController = new AbortController();
|
||||
const apiResponse = await fetch("/v1/images/edits", {
|
||||
method: "POST",
|
||||
signal: this.currentAbortController.signal,
|
||||
body: formData,
|
||||
});
|
||||
|
||||
@@ -2816,6 +2843,7 @@ class AppStore {
|
||||
"Failed to edit image",
|
||||
);
|
||||
} finally {
|
||||
this.currentAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
@@ -2944,6 +2972,7 @@ export const hasStartedChat = () => appStore.hasStartedChat;
|
||||
export const messages = () => appStore.messages;
|
||||
export const currentResponse = () => appStore.currentResponse;
|
||||
export const isLoading = () => appStore.isLoading;
|
||||
export const stopGeneration = () => appStore.stopGeneration();
|
||||
export const ttftMs = () => appStore.ttftMs;
|
||||
export const tps = () => appStore.tps;
|
||||
export const totalTokens = () => appStore.totalTokens;
|
||||
|
||||
@@ -118,10 +118,9 @@
|
||||
{
|
||||
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
|
||||
mlx = pkgs.callPackage ./nix/mlx.nix {
|
||||
inherit (self'.packages) metal-toolchain;
|
||||
metal-toolchain = self'.packages.metal-toolchain;
|
||||
inherit uvLockMlxVersion;
|
||||
};
|
||||
default = self'.packages.exo;
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
@@ -53,10 +53,11 @@ class DownloadCoordinator:
|
||||
# Internal event channel for forwarding (initialized in __post_init__)
|
||||
event_sender: Sender[Event] = field(init=False)
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
_tg: TaskGroup = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
self._tg = anyio.create_task_group()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
|
||||
@@ -27,6 +27,7 @@ from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.worker.main import Worker
|
||||
|
||||
|
||||
# I marked this as a dataclass as I want trivial constructors.
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
@@ -135,6 +136,7 @@ class Node:
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
@@ -146,8 +148,6 @@ class Node:
|
||||
if self.api:
|
||||
tg.start_soon(self.api.run)
|
||||
tg.start_soon(self._elect_loop)
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
|
||||
def shutdown(self):
|
||||
# if this is our second call to shutdown, just sys.exit
|
||||
|
||||
@@ -1331,40 +1331,29 @@ class API:
|
||||
]
|
||||
|
||||
async def run(self):
|
||||
shutdown_ev = anyio.Event()
|
||||
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
print_startup_banner(self.port)
|
||||
tg.start_soon(self.run_api, shutdown_ev)
|
||||
try:
|
||||
await anyio.sleep_forever()
|
||||
finally:
|
||||
with anyio.CancelScope(shield=True):
|
||||
shutdown_ev.set()
|
||||
finally:
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def run_api(self, ev: anyio.Event):
|
||||
cfg = Config()
|
||||
cfg.bind = [f"0.0.0.0:{self.port}"]
|
||||
cfg.bind = f"0.0.0.0:{self.port}"
|
||||
# nb: shared.logging needs updating if any of this changes
|
||||
cfg.accesslog = None
|
||||
cfg.errorlog = "-"
|
||||
cfg.logger_class = InterceptLogger
|
||||
with anyio.CancelScope(shield=True):
|
||||
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
print_startup_banner(self.port)
|
||||
await serve(
|
||||
cast(ASGIFramework, self.app),
|
||||
cfg,
|
||||
shutdown_trigger=ev.wait,
|
||||
shutdown_trigger=lambda: anyio.sleep_forever(),
|
||||
)
|
||||
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
|
||||
@@ -98,18 +98,16 @@ class Master:
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
finally:
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Stopping Master")
|
||||
|
||||
@@ -9,7 +9,6 @@ from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
create_task_group,
|
||||
move_on_after,
|
||||
sleep_forever,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -147,21 +146,18 @@ class Router:
|
||||
|
||||
async def run(self):
|
||||
logger.debug("Starting Router")
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
finally:
|
||||
with move_on_after(1, shield=True):
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
|
||||
async def shutdown(self):
|
||||
logger.debug("Shutting down Router")
|
||||
@@ -170,12 +166,12 @@ class Router:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _networking_subscribe(self, topic: str):
|
||||
logger.info(f"Subscribing to {topic}")
|
||||
await self._net.gossipsub_subscribe(topic)
|
||||
logger.info(f"Subscribed to {topic}")
|
||||
|
||||
async def _networking_unsubscribe(self, topic: str):
|
||||
logger.info(f"Unsubscribing from {topic}")
|
||||
await self._net.gossipsub_unsubscribe(topic)
|
||||
logger.info(f"Unsubscribed from {topic}")
|
||||
|
||||
async def _networking_recv(self):
|
||||
while True:
|
||||
|
||||
@@ -86,29 +86,28 @@ class Election:
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Election")
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
tg.start_soon(self._command_counter)
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
tg.start_soon(self._command_counter)
|
||||
|
||||
# And start an election immediately, that instantly resolves
|
||||
candidates: list[ElectionMessage] = []
|
||||
logger.debug("Starting initial campaign")
|
||||
self._candidates = candidates
|
||||
await self._campaign(candidates, campaign_timeout=0.0)
|
||||
logger.debug("Initial campaign finished")
|
||||
finally:
|
||||
# Cancel and wait for the last election to end
|
||||
if self._campaign_cancel_scope is not None:
|
||||
logger.debug("Cancelling campaign")
|
||||
self._campaign_cancel_scope.cancel()
|
||||
if self._campaign_done is not None:
|
||||
logger.debug("Waiting for campaign to finish")
|
||||
await self._campaign_done.wait()
|
||||
logger.debug("Campaign cancelled and finished")
|
||||
logger.info("Election shutdown")
|
||||
# And start an election immediately, that instantly resolves
|
||||
candidates: list[ElectionMessage] = []
|
||||
logger.debug("Starting initial campaign")
|
||||
self._candidates = candidates
|
||||
await self._campaign(candidates, campaign_timeout=0.0)
|
||||
logger.debug("Initial campaign finished")
|
||||
|
||||
# Cancel and wait for the last election to end
|
||||
if self._campaign_cancel_scope is not None:
|
||||
logger.debug("Cancelling campaign")
|
||||
self._campaign_cancel_scope.cancel()
|
||||
if self._campaign_done is not None:
|
||||
logger.debug("Waiting for campaign to finish")
|
||||
await self._campaign_done.wait()
|
||||
logger.debug("Campaign cancelled and finished")
|
||||
logger.info("Election finished")
|
||||
|
||||
async def elect(self, em: ElectionMessage) -> None:
|
||||
logger.debug(f"Electing: {em}")
|
||||
|
||||
@@ -125,9 +125,7 @@ class MpSender[T]:
|
||||
self._state.buffer.put(item, block=True)
|
||||
|
||||
async def send_async(self, item: T) -> None:
|
||||
await to_thread.run_sync(
|
||||
self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True
|
||||
)
|
||||
await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1))
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._state.closed.is_set():
|
||||
@@ -196,10 +194,9 @@ class MpReceiver[T]:
|
||||
raise EndOfStream from None
|
||||
return item
|
||||
|
||||
# nb: this function will not cancel particularly well
|
||||
async def receive_async(self) -> T:
|
||||
return await to_thread.run_sync(
|
||||
self.receive, limiter=CapacityLimiter(1), abandon_on_cancel=True
|
||||
)
|
||||
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._state.closed.is_set():
|
||||
|
||||
@@ -99,23 +99,22 @@ class Worker:
|
||||
info_send, info_recv = channel[GatheredInfo]()
|
||||
info_gatherer: InfoGatherer = InfoGatherer(info_send)
|
||||
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
finally:
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
logger.info("Stopping Worker")
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
async with create_task_group() as tg:
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
tg.start_soon(runner.shutdown)
|
||||
|
||||
async def _forward_info(self, recv: Receiver[GatheredInfo]):
|
||||
with recv as info_stream:
|
||||
@@ -230,7 +229,7 @@ class Worker:
|
||||
)
|
||||
)
|
||||
finally:
|
||||
runner.shutdown()
|
||||
await runner.shutdown()
|
||||
case CancelTask(
|
||||
cancelled_task_id=cancelled_task_id, runner_id=runner_id
|
||||
):
|
||||
|
||||
@@ -8,8 +8,10 @@ import anyio
|
||||
from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
create_task_group,
|
||||
to_thread,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.events import (
|
||||
@@ -48,6 +50,7 @@ class RunnerSupervisor:
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_cancel_sender: MpSender[TaskId]
|
||||
_tg: TaskGroup = field(default_factory=create_task_group, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
@@ -94,29 +97,42 @@ class RunnerSupervisor:
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
await self._forward_events()
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._forward_events)
|
||||
|
||||
def shutdown(self):
|
||||
logger.info("Runner supervisor shutting down")
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
return
|
||||
with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
|
||||
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
|
||||
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.close()
|
||||
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
await to_thread.run_sync(self.runner_process.join, 10)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
async def shutdown(self):
|
||||
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.pending:
|
||||
@@ -144,11 +160,7 @@ class RunnerSupervisor:
|
||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||
return
|
||||
self.cancelled.add(task_id)
|
||||
with anyio.move_on_after(0.5) as scope:
|
||||
await self._cancel_sender.send_async(task_id)
|
||||
if scope.cancel_called:
|
||||
logger.error("RunnerSupervisor cancel pipe blocked")
|
||||
await self._check_runner(TimeoutError("cancel pipe blocked"))
|
||||
await self._cancel_sender.send_async(task_id)
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
@@ -214,4 +226,4 @@ class RunnerSupervisor:
|
||||
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
|
||||
)
|
||||
)
|
||||
self.shutdown()
|
||||
await self.shutdown()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# Check tasks are complete before runner is ever ready.
|
||||
import unittest.mock
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable
|
||||
|
||||
@@ -116,6 +115,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
monkeypatch.setattr(mx.distributed, "all_gather", make_nothin(mx.array([1])))
|
||||
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
|
||||
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
||||
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
|
||||
@@ -178,11 +178,8 @@ def _run(tasks: Iterable[Task]):
|
||||
# this is some c++ nonsense
|
||||
task_receiver.close = nothin
|
||||
task_receiver.join = nothin
|
||||
with unittest.mock.patch(
|
||||
"exo.worker.runner.runner.mx.distributed.all_gather", make_nothin(mx.array([1]))
|
||||
):
|
||||
|
||||
mlx_runner.main(bound_instance, event_sender, task_receiver, cancel_receiver) # pyright: ignore[reportArgumentType]
|
||||
mlx_runner.main(bound_instance, event_sender, task_receiver, cancel_receiver) # pyright: ignore[reportArgumentType]
|
||||
|
||||
return event_sender.events
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ echo "Deploying $commit to $# hosts..."
|
||||
hosts=("$@")
|
||||
cleanup() {
|
||||
for host in "${hosts[@]}"; do
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
|
||||
done
|
||||
wait
|
||||
jobs -pr | xargs -r kill 2>/dev/null || true
|
||||
@@ -34,13 +34,21 @@ reset=$'\e[0m'
|
||||
i=0
|
||||
for host; do
|
||||
colour=${colours[i++ % 4]}
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
|
||||
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
|
||||
{
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
|
||||
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
|
||||
} <<'EOF'
|
||||
set -euo pipefail
|
||||
cd exo
|
||||
git fetch -q origin
|
||||
git checkout -q "$1"
|
||||
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
|
||||
EOF
|
||||
done
|
||||
|
||||
for host; do
|
||||
echo "Waiting for $host..."
|
||||
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
|
||||
until curl -sf "http://$host:52415/models"; do sleep 1; done
|
||||
done
|
||||
wait
|
||||
|
||||
Reference in New Issue
Block a user