mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-27 07:20:14 -05:00
Compare commits
7 Commits
rust-explo
...
ciaran/ima
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
409fa80600 | ||
|
|
5a94c21daa | ||
|
|
56ec049321 | ||
|
|
b477f88ace | ||
|
|
4ea6e32f7b | ||
|
|
49c5345e93 | ||
|
|
ea593075d7 |
24
Cargo.lock
generated
24
Cargo.lock
generated
@@ -514,20 +514,6 @@ version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d"
|
||||
|
||||
[[package]]
|
||||
name = "cluster_membership"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"futures-lite",
|
||||
"futures-timer",
|
||||
"libp2p",
|
||||
"log",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.4"
|
||||
@@ -1012,7 +998,6 @@ dependencies = [
|
||||
name = "exo_pyo3_bindings"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"cluster_membership",
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"env_logger",
|
||||
@@ -1045,12 +1030,6 @@ dependencies = [
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||
|
||||
[[package]]
|
||||
name = "ff"
|
||||
version = "0.13.1"
|
||||
@@ -1159,10 +1138,7 @@ version = "2.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
|
||||
dependencies = [
|
||||
"fastrand",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"parking",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ members = [
|
||||
"rust/networking",
|
||||
"rust/exo_pyo3_bindings",
|
||||
"rust/util",
|
||||
"rust/cluster_membership",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -26,7 +25,6 @@ opt-level = 3
|
||||
## Crate members as common dependencies
|
||||
networking = { path = "rust/networking" }
|
||||
util = { path = "rust/util" }
|
||||
cluster_membership = { path = "rust/cluster_membership" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
syn = "2.0"
|
||||
@@ -64,7 +62,6 @@ frunk-enum-core = "0.3"
|
||||
# Async dependencies
|
||||
tokio = "1.46"
|
||||
futures = "0.3"
|
||||
futures-lite = "2.6.1"
|
||||
futures-util = "0.3"
|
||||
futures-timer = "3.0"
|
||||
|
||||
|
||||
@@ -5,18 +5,18 @@
|
||||
[X] Fetching download status of all models on start
|
||||
[X] Deduplication of tasks in plan_step.
|
||||
[X] resolve_allow_patterns should just be wildcard now.
|
||||
[] no mx_barrier in genreate.py mlx_generate at the end.
|
||||
[X] no mx_barrier in genreate.py mlx_generate at the end.
|
||||
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
|
||||
[] GPTOSS support dropped in auto_parallel.py.
|
||||
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[X] GPTOSS support dropped in auto_parallel.py.
|
||||
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
|
||||
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
|
||||
[] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[X] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[X] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[X] We put cache limit back in utils_mlx.py.
|
||||
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
|
||||
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
|
||||
|
||||
@@ -31,35 +31,6 @@ enum NetworkSetupHelper {
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
|
||||
networksetup -listlocations | grep -q exo || {
|
||||
networksetup -createlocation exo
|
||||
}
|
||||
|
||||
networksetup -switchtolocation exo
|
||||
networksetup -listallhardwareports \\
|
||||
| awk -F': ' '/Hardware Port: / {print $2}' \\
|
||||
| while IFS=":" read -r name; do
|
||||
case "$name" in
|
||||
"Ethernet Adapter"*)
|
||||
;;
|
||||
"Thunderbolt Bridge")
|
||||
;;
|
||||
"Thunderbolt "*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "EXO $name" \\
|
||||
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
networksetup -setdhcp "EXO $name"
|
||||
;;
|
||||
*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "$name" \\
|
||||
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
|
||||
@@ -3,28 +3,12 @@
|
||||
perSystem =
|
||||
{ pkgs, lib, ... }:
|
||||
let
|
||||
# Filter source to ONLY include package.json and package-lock.json
|
||||
# This ensures prettier-svelte only rebuilds when lockfiles change
|
||||
dashboardLockfileSrc = lib.cleanSourceWith {
|
||||
src = inputs.self;
|
||||
filter =
|
||||
path: type:
|
||||
let
|
||||
baseName = builtins.baseNameOf path;
|
||||
isDashboardDir = baseName == "dashboard" && type == "directory";
|
||||
isPackageFile =
|
||||
(lib.hasInfix "/dashboard/" path || lib.hasSuffix "/dashboard" (builtins.dirOf path))
|
||||
&& (baseName == "package.json" || baseName == "package-lock.json");
|
||||
in
|
||||
isDashboardDir || isPackageFile;
|
||||
};
|
||||
|
||||
# Stub source with lockfiles and minimal files for build to succeed
|
||||
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
|
||||
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
|
||||
mkdir -p $out
|
||||
cp ${dashboardLockfileSrc}/dashboard/package.json $out/
|
||||
cp ${dashboardLockfileSrc}/dashboard/package-lock.json $out/
|
||||
cp ${inputs.self}/dashboard/package.json $out/
|
||||
cp ${inputs.self}/dashboard/package-lock.json $out/
|
||||
# Minimal files so vite build succeeds (produces empty output)
|
||||
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
|
||||
mkdir -p $out/src
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
ttftMs,
|
||||
tps,
|
||||
totalTokens,
|
||||
cancelRequest,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import ChatAttachments from "./ChatAttachments.svelte";
|
||||
import ImageParamsPanel from "./ImageParamsPanel.svelte";
|
||||
@@ -605,37 +606,15 @@
|
||||
style="min-height: 28px; max-height: 150px;"
|
||||
></textarea>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || loading || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || loading || isEditOnlyWithoutImage
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={shouldShowEditMode
|
||||
? "Edit image"
|
||||
: isImageModel()
|
||||
? "Generate image"
|
||||
: "Send message"}
|
||||
>
|
||||
{#if loading}
|
||||
{#if loading}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => cancelRequest()}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/50 text-exo-light-gray border border-exo-medium-gray/50 hover:border-red-500/50 hover:text-red-400 cursor-pointer"
|
||||
>
|
||||
<span class="inline-flex items-center gap-1 sm:gap-2">
|
||||
<span
|
||||
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
|
||||
></span>
|
||||
<span class="hidden sm:inline"
|
||||
>{shouldShowEditMode
|
||||
? "EDITING"
|
||||
: isImageModel()
|
||||
? "GENERATING"
|
||||
: "PROCESSING"}</span
|
||||
>
|
||||
<span class="sm:hidden">...</span>
|
||||
</span>
|
||||
{:else if shouldShowEditMode}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
class="w-3 h-3"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
@@ -644,47 +623,81 @@
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
<span class="hidden sm:inline">CANCEL</span>
|
||||
<span class="sm:hidden">X</span>
|
||||
</span>
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
</button>
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || isEditOnlyWithoutImage
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={shouldShowEditMode
|
||||
? "Edit image"
|
||||
: isImageModel()
|
||||
? "Generate image"
|
||||
: "Send message"}
|
||||
>
|
||||
{#if shouldShowEditMode}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Bottom accent line -->
|
||||
|
||||
@@ -464,6 +464,7 @@ class AppStore {
|
||||
private previewsInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private lastConversationPersistTs = 0;
|
||||
private previousNodeIds: Set<string> = new Set();
|
||||
private activeAbortController: AbortController | null = null;
|
||||
|
||||
constructor() {
|
||||
if (browser) {
|
||||
@@ -1746,6 +1747,9 @@ class AppStore {
|
||||
const targetConversationId = this.activeConversationId;
|
||||
if (!targetConversationId) return;
|
||||
|
||||
this.activeAbortController = new AbortController();
|
||||
const signal = this.activeAbortController.signal;
|
||||
|
||||
this.isLoading = true;
|
||||
this.currentResponse = "";
|
||||
this.ttftMs = null;
|
||||
@@ -1880,6 +1884,7 @@ class AppStore {
|
||||
temperature: 0.7,
|
||||
stream: true,
|
||||
}),
|
||||
signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -1975,6 +1980,9 @@ class AppStore {
|
||||
this.persistConversation(targetConversationId);
|
||||
}
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
return;
|
||||
}
|
||||
console.error("Error sending message:", error);
|
||||
this.handleStreamingError(
|
||||
error,
|
||||
@@ -1983,6 +1991,7 @@ class AppStore {
|
||||
"Failed to get response",
|
||||
);
|
||||
} finally {
|
||||
this.activeAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.saveConversationsToStorage();
|
||||
@@ -2003,6 +2012,9 @@ class AppStore {
|
||||
const targetConversationId = this.activeConversationId;
|
||||
if (!targetConversationId) return;
|
||||
|
||||
this.activeAbortController = new AbortController();
|
||||
const signal = this.activeAbortController.signal;
|
||||
|
||||
this.isLoading = true;
|
||||
this.currentResponse = "";
|
||||
|
||||
@@ -2088,6 +2100,7 @@ class AppStore {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -2197,6 +2210,19 @@ class AppStore {
|
||||
},
|
||||
);
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
// Clean up the "Generating image..." message on cancellation
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "Cancelled";
|
||||
msg.attachments = [];
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
return;
|
||||
}
|
||||
console.error("Error generating image:", error);
|
||||
this.handleStreamingError(
|
||||
error,
|
||||
@@ -2205,6 +2231,7 @@ class AppStore {
|
||||
"Failed to generate image",
|
||||
);
|
||||
} finally {
|
||||
this.activeAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
@@ -2228,6 +2255,9 @@ class AppStore {
|
||||
const targetConversationId = this.activeConversationId;
|
||||
if (!targetConversationId) return;
|
||||
|
||||
this.activeAbortController = new AbortController();
|
||||
const signal = this.activeAbortController.signal;
|
||||
|
||||
this.isLoading = true;
|
||||
this.currentResponse = "";
|
||||
|
||||
@@ -2336,6 +2366,7 @@ class AppStore {
|
||||
const apiResponse = await fetch("/v1/images/edits", {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
signal,
|
||||
});
|
||||
|
||||
if (!apiResponse.ok) {
|
||||
@@ -2407,6 +2438,19 @@ class AppStore {
|
||||
},
|
||||
);
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
// Clean up the "Editing image..." message on cancellation
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "cancelled";
|
||||
msg.attachments = [];
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
return;
|
||||
}
|
||||
console.error("Error editing image:", error);
|
||||
this.handleStreamingError(
|
||||
error,
|
||||
@@ -2415,11 +2459,24 @@ class AppStore {
|
||||
"Failed to edit image",
|
||||
);
|
||||
} finally {
|
||||
this.activeAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cancel an in-flight request by aborting the active fetch
|
||||
*/
|
||||
cancelRequest(): void {
|
||||
if (this.activeAbortController) {
|
||||
this.activeAbortController.abort();
|
||||
this.activeAbortController = null;
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear current chat and go back to welcome state
|
||||
*/
|
||||
@@ -2556,6 +2613,7 @@ export const editMessage = (messageId: string, newContent: string) =>
|
||||
export const editAndRegenerate = (messageId: string, newContent: string) =>
|
||||
appStore.editAndRegenerate(messageId, newContent);
|
||||
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
|
||||
export const cancelRequest = () => appStore.cancelRequest();
|
||||
|
||||
// Conversation actions
|
||||
export const conversations = () => appStore.conversations;
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
[package]
|
||||
name = "cluster_membership"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
publish = false
|
||||
|
||||
[dependencies]
|
||||
# util
|
||||
anyhow.workspace = true
|
||||
log.workspace = true
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures-timer = { workspace = true }
|
||||
futures-lite = "2.6.1"
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
async-trait = "0.1.89"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
@@ -1,30 +0,0 @@
|
||||
use cluster_membership::Peer;
|
||||
use libp2p::identity::ed25519::SecretKey;
|
||||
use tokio::io::{self, AsyncBufReadExt};
|
||||
use tracing_subscriber::{EnvFilter, filter::LevelFilter};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
|
||||
.try_init();
|
||||
|
||||
let (mut peer, send, mut recv) =
|
||||
Peer::new(SecretKey::generate(), "hello".to_string()).expect("peer should always build");
|
||||
|
||||
let ch = peer.subscribe("chatroom".to_string());
|
||||
let jh = tokio::spawn(async move { peer.run().await });
|
||||
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
loop {
|
||||
tokio::select! {
|
||||
Ok(Some(line)) = stdin.next_line() => {send.send((ch.clone(), line.into_bytes())).await.expect("example");}
|
||||
Some(r) = recv.recv() => match r {
|
||||
Ok((_, id, line)) => println!("{:?}:{:?}", id, String::from_utf8_lossy(&line)),
|
||||
Err(e) => eprintln!("{e:?}"),
|
||||
},
|
||||
else => break
|
||||
}
|
||||
}
|
||||
jh.await.expect("task failure");
|
||||
}
|
||||
@@ -1,227 +0,0 @@
|
||||
use libp2p::{
|
||||
Multiaddr, PeerId, Swarm, SwarmBuilder,
|
||||
futures::StreamExt,
|
||||
gossipsub::{self, PublishError, Sha256Topic, TopicHash},
|
||||
identify,
|
||||
identity::{Keypair, ed25519},
|
||||
mdns,
|
||||
swarm::{NetworkBehaviour, SwarmEvent, dial_opts::DialOpts},
|
||||
};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokio::{select, sync::mpsc};
|
||||
|
||||
const DEFAULT_BUFFER_SIZE: usize = 10;
|
||||
const MDNS_IGNORE_DURATION_SECS: u64 = 30;
|
||||
|
||||
impl Peer {
|
||||
pub fn new(
|
||||
identity: ed25519::SecretKey,
|
||||
namespace: String,
|
||||
) -> anyhow::Result<(
|
||||
Self,
|
||||
mpsc::Sender<(TopicHash, Vec<u8>)>,
|
||||
mpsc::Receiver<Result<(TopicHash, PeerId, Vec<u8>), PublishError>>,
|
||||
)> {
|
||||
let mut id_bytes = identity.as_ref().to_vec();
|
||||
|
||||
let mut swarm =
|
||||
SwarmBuilder::with_existing_identity(Keypair::ed25519_from_bytes(&mut id_bytes)?)
|
||||
.with_tokio()
|
||||
.with_quic()
|
||||
// TODO(evan): .with_bandwidth_metrics();
|
||||
.with_behaviour(|kp| Behaviour::new(kp, namespace.clone()))?
|
||||
.build();
|
||||
|
||||
swarm.listen_on("/ip6/::/udp/0/quic-v1".parse()?)?;
|
||||
swarm.listen_on("/ip4/0.0.0.0/udp/0/quic-v1".parse()?)?;
|
||||
let (to_swarm, from_client) = mpsc::channel(DEFAULT_BUFFER_SIZE);
|
||||
let (to_client, from_swarm) = mpsc::channel(DEFAULT_BUFFER_SIZE);
|
||||
Ok((
|
||||
Self {
|
||||
swarm,
|
||||
namespace,
|
||||
denied: HashMap::new(),
|
||||
from_client,
|
||||
to_client,
|
||||
},
|
||||
to_swarm,
|
||||
from_swarm,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn subscribe(&mut self, topic: String) -> TopicHash {
|
||||
let topic = Sha256Topic::new(topic);
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&topic)
|
||||
.expect("topic filtered");
|
||||
topic.hash()
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) {
|
||||
loop {
|
||||
select! {
|
||||
ev = self.swarm.select_next_some() => {
|
||||
let Ok(()) = self.handle_swarm_event(ev).await else {
|
||||
return
|
||||
};
|
||||
},
|
||||
Some(msg) = self.from_client.recv() => {
|
||||
if let Err(e) = self.swarm.behaviour_mut().gossipsub.publish(msg.0, msg.1) {
|
||||
let Ok(()) = self.to_client.send(Err(e)).await else {
|
||||
return
|
||||
};
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_swarm_event(&mut self, event: SwarmEvent<BehaviourEvent>) -> Result<(), ()> {
|
||||
let SwarmEvent::Behaviour(event) = event else {
|
||||
if let SwarmEvent::NewListenAddr {
|
||||
listener_id: _,
|
||||
address,
|
||||
} = event
|
||||
{
|
||||
log::info!("new listen address {address}")
|
||||
}
|
||||
return Ok(());
|
||||
};
|
||||
match event {
|
||||
BehaviourEvent::Mdns(mdns_event) => match mdns_event {
|
||||
mdns::Event::Discovered(vec) => {
|
||||
// Dial everyone
|
||||
let mut addrs = HashMap::<PeerId, Vec<Multiaddr>>::new();
|
||||
vec.into_iter()
|
||||
.filter(|(peer_id, _)| {
|
||||
self.denied.get(peer_id).is_none_or(|t| {
|
||||
t.elapsed() > Duration::from_secs(MDNS_IGNORE_DURATION_SECS)
|
||||
})
|
||||
})
|
||||
.for_each(|(peer_id, addr)| addrs.entry(peer_id).or_default().push(addr));
|
||||
addrs.into_iter().for_each(|(peer_id, addrs)| {
|
||||
let _ = self
|
||||
.swarm
|
||||
.dial(DialOpts::peer_id(peer_id).addresses(addrs).build());
|
||||
});
|
||||
}
|
||||
mdns::Event::Expired(vec) => {
|
||||
vec.iter().for_each(|(peer_id, _)| {
|
||||
log::debug!("{peer_id} no longer reachable on mDNS");
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.remove_explicit_peer(peer_id);
|
||||
});
|
||||
}
|
||||
},
|
||||
BehaviourEvent::Identify(identify::Event::Received {
|
||||
connection_id: _,
|
||||
peer_id,
|
||||
info,
|
||||
}) => {
|
||||
if info
|
||||
.protocols
|
||||
.iter()
|
||||
.any(|p| p.as_ref().contains(&self.namespace))
|
||||
{
|
||||
self.passed_namespace(peer_id);
|
||||
} else {
|
||||
self.failed_namespace(peer_id);
|
||||
}
|
||||
}
|
||||
BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: _,
|
||||
message_id: _,
|
||||
message:
|
||||
gossipsub::Message {
|
||||
topic,
|
||||
data,
|
||||
source: Some(source_peer),
|
||||
..
|
||||
},
|
||||
}) => {
|
||||
let Ok(()) = self.to_client.send(Ok((topic, source_peer, data))).await else {
|
||||
return Err(());
|
||||
};
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn passed_namespace(&mut self, peer: PeerId) {
|
||||
log::info!("new peer {peer:?}");
|
||||
self.denied.remove(&peer);
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.remove_blacklisted_peer(&peer);
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.add_explicit_peer(&peer);
|
||||
}
|
||||
|
||||
fn failed_namespace(&mut self, peer: PeerId) {
|
||||
log::debug!("{peer} failed handshake");
|
||||
self.denied.insert(peer, Instant::now());
|
||||
self.swarm.behaviour_mut().gossipsub.blacklist_peer(&peer);
|
||||
// we don't care if disconnect fails
|
||||
let _ = self.swarm.disconnect_peer_id(peer);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Peer {
|
||||
pub swarm: Swarm<Behaviour>,
|
||||
denied: HashMap<PeerId, Instant>,
|
||||
namespace: String,
|
||||
to_client: mpsc::Sender<Result<(TopicHash, PeerId, Vec<u8>), PublishError>>,
|
||||
from_client: mpsc::Receiver<(TopicHash, Vec<u8>)>,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn foo() {
|
||||
fn bar<T: Send>(t: T) {}
|
||||
let p: Peer = unimplemented!();
|
||||
bar(p);
|
||||
}
|
||||
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
pub gossipsub: gossipsub::Behaviour,
|
||||
identify: identify::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
fn new(kp: &Keypair, namespace: String) -> Self {
|
||||
let mdns = mdns::tokio::Behaviour::new(Default::default(), kp.public().to_peer_id())
|
||||
.expect("implementation is infallible");
|
||||
let gossipsub = gossipsub::Behaviour::new(
|
||||
gossipsub::MessageAuthenticity::Signed(kp.clone()),
|
||||
gossipsub::ConfigBuilder::default()
|
||||
.max_transmit_size(1024 * 1024)
|
||||
.protocol_id_prefix(format!("/exo/gossip/{namespace}/v1"))
|
||||
.build()
|
||||
.expect("fixed gossipsub config should always build"),
|
||||
)
|
||||
.expect("fixed gossipsub init should always build");
|
||||
|
||||
let identify = identify::Behaviour::new(
|
||||
identify::Config::new_with_signed_peer_record(format!("/exo/identity/v1"), kp)
|
||||
.with_push_listen_addr_updates(true),
|
||||
);
|
||||
|
||||
Behaviour {
|
||||
mdns,
|
||||
gossipsub,
|
||||
identify,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -22,7 +22,6 @@ doc = false
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
cluster_membership.workspace = true
|
||||
networking = { workspace = true }
|
||||
|
||||
# interop
|
||||
|
||||
@@ -6,41 +6,3 @@
|
||||
|
||||
pub mod ident;
|
||||
pub mod multiaddr;
|
||||
|
||||
use std::sync::Mutex;
|
||||
|
||||
use cluster_membership::Peer;
|
||||
use libp2p::identity::ed25519::Keypair;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass]
|
||||
#[derive(Clone)]
|
||||
pub struct PyKeypair(Keypair);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyKeypair {
|
||||
#[staticmethod]
|
||||
fn generate() -> Self {
|
||||
Self(Keypair::generate())
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass]
|
||||
pub struct PyPeer(Mutex<Peer>);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyPeer {
|
||||
#[staticmethod]
|
||||
fn init(kp: PyKeypair, namespace: String) -> PyResult<Self> {
|
||||
Ok(PyPeer(Mutex::new(
|
||||
Peer::new(kp.0.secret(), namespace)
|
||||
.map_err(|e| e.pyerr())?
|
||||
.0,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,6 +88,7 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
StartDownload,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
@@ -508,16 +509,14 @@ class API:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# TODO: TaskCancelled
|
||||
"""
|
||||
self.command_sender.send_nowait(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
"""
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
command = TaskFinished(finished_command_id=command_id)
|
||||
await self._send(command)
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
if command_id in self._chat_completion_queues:
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
@@ -901,6 +900,11 @@ class API:
|
||||
del image_metadata[key]
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
@@ -982,6 +986,11 @@ class API:
|
||||
|
||||
return (images, stats if capture_stats else None)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
|
||||
@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
SendInputChunk,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
)
|
||||
@@ -35,6 +36,7 @@ from exo.shared.types.events import (
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
@@ -246,7 +248,7 @@ class Master:
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case PlaceInstance():
|
||||
@@ -258,7 +260,7 @@ class Master:
|
||||
self.state.node_network,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case CreateInstance():
|
||||
@@ -268,7 +270,7 @@ class Master:
|
||||
self.state.instances,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
@@ -278,6 +280,18 @@ class Master:
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
case TaskCancelled():
|
||||
if (
|
||||
task_id := self.command_task_mapping.get(
|
||||
command.cancelled_command_id
|
||||
)
|
||||
) is not None:
|
||||
generated_events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task_id,
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
@@ -286,10 +300,9 @@ class Master:
|
||||
]
|
||||
)
|
||||
)
|
||||
if command.finished_command_id in self.command_task_mapping:
|
||||
del self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
for i in range(command.since_idx, len(self._event_log)):
|
||||
|
||||
@@ -20,9 +20,15 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
@@ -180,6 +186,7 @@ def delete_instance(
|
||||
def get_transition_events(
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
target_instances: Mapping[InstanceId, Instance],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Sequence[Event]:
|
||||
events: list[Event] = []
|
||||
|
||||
@@ -195,6 +202,18 @@ def get_transition_events(
|
||||
# find instances to delete
|
||||
for instance_id in current_instances:
|
||||
if instance_id not in target_instances:
|
||||
for task in tasks.values():
|
||||
if task.instance_id == instance_id and task.task_status in [
|
||||
TaskStatus.Pending,
|
||||
TaskStatus.Running,
|
||||
]:
|
||||
events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
)
|
||||
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
|
||||
@@ -48,6 +48,10 @@ class DeleteInstance(BaseCommand):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class TaskCancelled(BaseCommand):
|
||||
cancelled_command_id: CommandId
|
||||
|
||||
|
||||
class TaskFinished(BaseCommand):
|
||||
finished_command_id: CommandId
|
||||
|
||||
@@ -84,6 +88,7 @@ Command = (
|
||||
| PlaceInstance
|
||||
| CreateInstance
|
||||
| DeleteInstance
|
||||
| TaskCancelled
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ class TaskStatus(str, Enum):
|
||||
Complete = "Complete"
|
||||
TimedOut = "TimedOut"
|
||||
Failed = "Failed"
|
||||
Cancelled = "Cancelled"
|
||||
|
||||
|
||||
class BaseTask(TaggedModel):
|
||||
@@ -60,6 +61,10 @@ class ChatCompletion(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class CancelTask(BaseTask):
|
||||
cancelled_task_id: TaskId
|
||||
|
||||
|
||||
class ImageGeneration(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageGenerationTaskParams
|
||||
@@ -87,6 +92,7 @@ Task = (
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| ChatCompletion
|
||||
| CancelTask
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| Shutdown
|
||||
|
||||
@@ -349,8 +349,13 @@ class InfoGatherer:
|
||||
async def _monitor_misc(self):
|
||||
if self.misc_poll_interval is None:
|
||||
return
|
||||
prev = await MiscData.gather()
|
||||
await self.info_sender.send(prev)
|
||||
while True:
|
||||
await self.info_sender.send(await MiscData.gather())
|
||||
curr = await MiscData.gather()
|
||||
if prev != curr:
|
||||
prev = curr
|
||||
await self.info_sender.send(curr)
|
||||
await anyio.sleep(self.misc_poll_interval)
|
||||
|
||||
async def _monitor_system_profiler_thunderbolt_data(self):
|
||||
@@ -360,12 +365,15 @@ class InfoGatherer:
|
||||
if iface_map is None:
|
||||
return
|
||||
|
||||
old_idents = []
|
||||
while True:
|
||||
data = await ThunderboltConnectivity.gather()
|
||||
assert data is not None
|
||||
|
||||
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
|
||||
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
|
||||
if idents != old_idents:
|
||||
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
|
||||
old_idents = idents
|
||||
|
||||
conns = [it for i in data if (it := i.conn()) is not None]
|
||||
await self.info_sender.send(MacThunderboltConnections(conns=conns))
|
||||
@@ -390,17 +398,22 @@ class InfoGatherer:
|
||||
async def _watch_system_info(self):
|
||||
if self.interface_watcher_interval is None:
|
||||
return
|
||||
old_nics = []
|
||||
while True:
|
||||
nics = await get_network_interfaces()
|
||||
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
|
||||
if nics != old_nics:
|
||||
old_nics = nics
|
||||
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
|
||||
await anyio.sleep(self.interface_watcher_interval)
|
||||
|
||||
async def _monitor_thunderbolt_bridge_status(self):
|
||||
if self.thunderbolt_bridge_poll_interval is None:
|
||||
return
|
||||
prev: ThunderboltBridgeInfo | None = None
|
||||
while True:
|
||||
curr = await ThunderboltBridgeInfo.gather()
|
||||
if curr is not None:
|
||||
if curr is not None and prev != curr:
|
||||
prev = curr
|
||||
await self.info_sender.send(curr)
|
||||
await anyio.sleep(self.thunderbolt_bridge_poll_interval)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
@@ -109,6 +109,7 @@ class DistributedImageModel:
|
||||
image_path: Path | None = None,
|
||||
partial_images: int = 0,
|
||||
advanced_params: AdvancedImageParams | None = None,
|
||||
cancel_checker: Callable[[], bool] | None = None,
|
||||
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
|
||||
if (
|
||||
advanced_params is not None
|
||||
@@ -153,6 +154,7 @@ class DistributedImageModel:
|
||||
guidance_override=guidance_override,
|
||||
negative_prompt=negative_prompt,
|
||||
num_sync_steps=num_sync_steps,
|
||||
cancel_checker=cancel_checker,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (GeneratedImage, partial_index, total_partials)
|
||||
|
||||
@@ -3,6 +3,7 @@ import io
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Generator, Literal
|
||||
|
||||
@@ -68,12 +69,18 @@ def warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:
|
||||
def generate_image(
|
||||
model: DistributedImageModel,
|
||||
task: ImageGenerationTaskParams | ImageEditsInternalParams,
|
||||
cancel_checker: Callable[[], bool] | None = None,
|
||||
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
|
||||
"""Generate image(s), optionally yielding partial results.
|
||||
|
||||
When partial_images > 0 or stream=True, yields PartialImageResponse for
|
||||
intermediate images, then ImageGenerationResponse for the final image.
|
||||
|
||||
Args:
|
||||
model: The distributed image model to use for generation.
|
||||
task: The task parameters for image generation or editing.
|
||||
cancel_checker: Optional callback to check if generation should be cancelled.
|
||||
|
||||
Yields:
|
||||
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
|
||||
ImageGenerationResponse for final complete images
|
||||
@@ -123,6 +130,7 @@ def generate_image(
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
advanced_params=advanced_params,
|
||||
cancel_checker=cancel_checker,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Callable
|
||||
from math import ceil
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -94,6 +95,8 @@ class DiffusionRunner:
|
||||
self.total_layers = config.total_blocks
|
||||
|
||||
self._guidance_override: float | None = None
|
||||
self._cancel_checker: Callable[[], bool] | None = None
|
||||
self._cancelling = False
|
||||
|
||||
self._compute_assigned_blocks()
|
||||
|
||||
@@ -148,6 +151,54 @@ class DiffusionRunner:
|
||||
return self._guidance_override
|
||||
return self.config.guidance_scale
|
||||
|
||||
def _check_cancellation(self) -> bool:
|
||||
if self._cancelling:
|
||||
return True
|
||||
if (
|
||||
self.is_first_stage
|
||||
and self._cancel_checker is not None
|
||||
and self._cancel_checker()
|
||||
):
|
||||
self._cancelling = True
|
||||
return self._cancelling
|
||||
|
||||
def _is_sentinel(self, tensor: mx.array) -> bool:
|
||||
return bool(mx.all(mx.isnan(tensor)).item())
|
||||
|
||||
def _make_sentinel_like(self, tensor: mx.array) -> mx.array:
|
||||
return mx.full(tensor.shape, float("nan"), dtype=tensor.dtype)
|
||||
|
||||
def _recv(
|
||||
self,
|
||||
shape: tuple[int, ...],
|
||||
dtype: mx.Dtype,
|
||||
src: int,
|
||||
) -> mx.array:
|
||||
"""Receive data and check for cancellation sentinel."""
|
||||
data = mx.distributed.recv(shape, dtype, src, group=self.group)
|
||||
mx.eval(data)
|
||||
if self._is_sentinel(data):
|
||||
self._cancelling = True
|
||||
return data
|
||||
|
||||
def _recv_like(self, template: mx.array, src: int) -> mx.array:
|
||||
"""Receive data matching template and check for cancellation sentinel."""
|
||||
data = mx.distributed.recv_like(template, src=src, group=self.group)
|
||||
mx.eval(data)
|
||||
if self._is_sentinel(data):
|
||||
self._cancelling = True
|
||||
return data
|
||||
|
||||
def _send(self, data: mx.array, dst: int) -> mx.array:
|
||||
"""Send data, or sentinel if cancelling."""
|
||||
|
||||
if self._cancelling:
|
||||
data = self._make_sentinel_like(data)
|
||||
|
||||
result = mx.distributed.send(data, dst, group=self.group)
|
||||
mx.async_eval(result)
|
||||
return result
|
||||
|
||||
def _ensure_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
@@ -244,6 +295,7 @@ class DiffusionRunner:
|
||||
guidance_override: float | None = None,
|
||||
negative_prompt: str | None = None,
|
||||
num_sync_steps: int = 1,
|
||||
cancel_checker: Callable[[], bool] | None = None,
|
||||
):
|
||||
"""Primary entry point for image generation.
|
||||
|
||||
@@ -255,17 +307,21 @@ class DiffusionRunner:
|
||||
5. Decode to image
|
||||
|
||||
Args:
|
||||
settings: Generation config (steps, height, width)
|
||||
runtime_config: Runtime configuration (steps, height, width)
|
||||
prompt: Text prompt
|
||||
seed: Random seed
|
||||
partial_images: Number of intermediate images to yield (0 for none)
|
||||
guidance_override: Optional override for guidance scale (CFG)
|
||||
negative_prompt: Optional negative prompt for CFG
|
||||
num_sync_steps: Number of synchronous pipeline steps
|
||||
cancel_checker: Optional callback to check for cancellation
|
||||
|
||||
Yields:
|
||||
Partial images as (GeneratedImage, partial_index, total_partials) tuples
|
||||
Final GeneratedImage
|
||||
"""
|
||||
self._guidance_override = guidance_override
|
||||
self._cancel_checker = cancel_checker
|
||||
latents = self.adapter.create_latents(seed, runtime_config)
|
||||
prompt_data = self.adapter.encode_prompt(prompt, negative_prompt)
|
||||
|
||||
@@ -307,7 +363,7 @@ class DiffusionRunner:
|
||||
except StopIteration as e:
|
||||
latents = e.value # pyright: ignore[reportAny]
|
||||
|
||||
if self.is_last_stage:
|
||||
if self.is_last_stage and not self._cancelling:
|
||||
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt) # pyright: ignore[reportAny]
|
||||
|
||||
def _run_diffusion_loop(
|
||||
@@ -323,6 +379,7 @@ class DiffusionRunner:
|
||||
if capture_steps is None:
|
||||
capture_steps = set()
|
||||
|
||||
self._cancelling = False
|
||||
self._reset_all_caches()
|
||||
|
||||
time_steps = tqdm(range(runtime_config.num_inference_steps))
|
||||
@@ -345,9 +402,13 @@ class DiffusionRunner:
|
||||
num_sync_steps=num_sync_steps,
|
||||
)
|
||||
|
||||
if self._cancelling:
|
||||
break
|
||||
|
||||
ctx.in_loop( # pyright: ignore[reportAny]
|
||||
t=t,
|
||||
latents=latents,
|
||||
time_steps=time_steps,
|
||||
)
|
||||
|
||||
mx.eval(latents)
|
||||
@@ -356,7 +417,7 @@ class DiffusionRunner:
|
||||
yield (latents, t)
|
||||
|
||||
except KeyboardInterrupt: # noqa: PERF203
|
||||
ctx.interruption(t=t, latents=latents) # pyright: ignore[reportAny]
|
||||
ctx.interruption(t=t, latents=latents, time_steps=time_steps) # pyright: ignore[reportAny]
|
||||
raise StopImageGenerationException(
|
||||
f"Stopping image generation at step {t + 1}/{len(time_steps)}"
|
||||
) from None
|
||||
@@ -566,6 +627,8 @@ class DiffusionRunner:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(encoder_hidden_states_mask)
|
||||
|
||||
self._check_cancellation()
|
||||
|
||||
encoder_hidden_states: mx.array | None = None
|
||||
if self.is_first_stage:
|
||||
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
@@ -585,19 +648,12 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
hidden_states = self._recv(
|
||||
(batch_size, num_img_tokens, hidden_dim), dtype, self.prev_rank
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
encoder_hidden_states = self._recv(
|
||||
(batch_size, text_seq_len, hidden_dim), dtype, self.prev_rank
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
assert self.joint_block_wrappers is not None
|
||||
assert encoder_hidden_states is not None
|
||||
@@ -619,30 +675,20 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
hidden_states = concatenated
|
||||
else:
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
concatenated = self._send(concatenated, self.next_rank)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
hidden_states = self._send(hidden_states, self.next_rank)
|
||||
encoder_hidden_states = self._send(encoder_hidden_states, self.next_rank)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv(
|
||||
hidden_states = self._recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
for wrapper in self.single_block_wrappers:
|
||||
@@ -654,10 +700,7 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
hidden_states = self._send(hidden_states, self.next_rank)
|
||||
|
||||
hidden_states = hidden_states[:, text_seq_len:, ...]
|
||||
|
||||
@@ -741,14 +784,13 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_first_stage:
|
||||
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
|
||||
mx.async_eval(hidden_states)
|
||||
hidden_states = self._send(hidden_states, 0)
|
||||
|
||||
elif self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
hidden_states = self._recv_like(prev_latents, src=self.world_size - 1)
|
||||
|
||||
if self._cancelling:
|
||||
return prev_latents
|
||||
|
||||
else:
|
||||
hidden_states = prev_latents
|
||||
@@ -808,10 +850,9 @@ class DiffusionRunner:
|
||||
and not self.is_last_stage
|
||||
and not is_first_async_step
|
||||
):
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
patch = self._recv_like(patch, src=self.prev_rank)
|
||||
|
||||
self._check_cancellation()
|
||||
|
||||
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
|
||||
|
||||
@@ -842,10 +883,19 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_first_stage and t != config.num_inference_steps - 1:
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
patch_latents[patch_idx] = self._send(
|
||||
patch_latents[patch_idx], self.next_rank
|
||||
)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
|
||||
# Drain final rank patch sends if cancelling
|
||||
if (
|
||||
self._cancelling
|
||||
and self.is_first_stage
|
||||
and not self.is_last_stage
|
||||
and t != config.num_inference_steps - 1
|
||||
):
|
||||
for patch_idx in range(len(patch_latents)):
|
||||
_ = self._recv_like(patch_latents[patch_idx], src=self.prev_rank)
|
||||
|
||||
return mx.concatenate(patch_latents, axis=1)
|
||||
|
||||
@@ -884,22 +934,16 @@ class DiffusionRunner:
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
patch = self._recv(
|
||||
(batch_size, patch_len, hidden_dim), patch.dtype, self.prev_rank
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
encoder_hidden_states = self._recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
|
||||
if self.is_first_stage:
|
||||
patch, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
@@ -924,32 +968,25 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
patch = patch_concat
|
||||
else:
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
patch_concat = self._send(patch_concat, self.next_rank)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
patch = self._send(patch, self.next_rank)
|
||||
|
||||
if patch_idx == 0:
|
||||
assert encoder_hidden_states is not None
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
encoder_hidden_states = self._send(
|
||||
encoder_hidden_states, self.next_rank
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
patch = self._recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
for wrapper in self.single_block_wrappers:
|
||||
@@ -961,8 +998,7 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
patch = self._send(patch, self.next_rank)
|
||||
|
||||
noise: mx.array | None = None
|
||||
if self.is_last_stage:
|
||||
|
||||
@@ -23,7 +23,6 @@ from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
make_kv_cache,
|
||||
mx_barrier,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
@@ -90,10 +89,6 @@ def warmup_inference(
|
||||
|
||||
logger.info("Generated ALL warmup tokens")
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
# At least this version is actively incorrect, as it should use mx_barrier(group)
|
||||
mx_barrier()
|
||||
|
||||
return tokens_generated
|
||||
|
||||
|
||||
@@ -186,5 +181,3 @@ def mlx_generate(
|
||||
|
||||
if out.finish_reason is not None:
|
||||
break
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
@@ -70,8 +70,6 @@ Group = mx.distributed.Group
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
|
||||
|
||||
|
||||
# TODO: Test this
|
||||
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
|
||||
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
return Memory.from_float_kb(
|
||||
(model_shard_meta.end_layer - model_shard_meta.start_layer)
|
||||
@@ -89,30 +87,6 @@ class ModelLoadingTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0),
|
||||
stream=mx.default_stream(mx.Device(mx.cpu)),
|
||||
group=group,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def broadcast_from_zero(value: int, group: Group | None = None):
|
||||
if group is None:
|
||||
return value
|
||||
|
||||
if group.rank() == 0:
|
||||
a = mx.array([value], dtype=mx.int32)
|
||||
else:
|
||||
a = mx.array([0], dtype=mx.int32)
|
||||
|
||||
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
|
||||
mx.eval(m)
|
||||
return int(m.item())
|
||||
|
||||
|
||||
class HostList(RootModel[list[str]]):
|
||||
@classmethod
|
||||
def from_hosts(cls, hosts: list[Host]) -> "HostList":
|
||||
@@ -536,3 +510,23 @@ def mlx_cleanup(
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||
if group is None:
|
||||
return bool_
|
||||
num_true = mx.distributed.all_sum(
|
||||
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
mx.eval(num_true)
|
||||
return num_true.item() > 0
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None):
|
||||
if group is None:
|
||||
return
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
)
|
||||
|
||||
@@ -33,6 +33,7 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
@@ -115,8 +116,9 @@ class Worker:
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
async with create_task_group() as tg:
|
||||
for runner in self.runners.values():
|
||||
tg.start_soon(runner.shutdown)
|
||||
|
||||
async def _forward_info(self, recv: Receiver[GatheredInfo]):
|
||||
with recv as info_stream:
|
||||
@@ -220,15 +222,22 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
runner = self.runners.pop(runner_id)
|
||||
try:
|
||||
with fail_after(3):
|
||||
await self.runners.pop(runner_id).start_task(task)
|
||||
await runner.start_task(task)
|
||||
except TimeoutError:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
finally:
|
||||
await runner.shutdown()
|
||||
case CancelTask(cancelled_task_id=cancelled_task_id):
|
||||
await self.runners[self._task_to_runner_id(task)].cancel_task(
|
||||
cancelled_task_id
|
||||
)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
@@ -351,8 +360,6 @@ class Worker:
|
||||
for event in self.out_for_delivery.copy().values():
|
||||
await self.local_event_sender.send(event)
|
||||
|
||||
## Op Executors
|
||||
|
||||
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
|
||||
"""Creates and stores a new AssignedRunner with initial downloading status."""
|
||||
runner = RunnerSupervisor.create(
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
@@ -59,7 +60,8 @@ def plan(
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
|
||||
or _cancel_tasks(runners, tasks)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
|
||||
)
|
||||
|
||||
|
||||
@@ -270,7 +272,7 @@ def _pending_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions
|
||||
@@ -284,7 +286,7 @@ def _pending_tasks(
|
||||
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
|
||||
cmd_id = task.command_id
|
||||
expected = task.task_params.total_input_chunks
|
||||
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
|
||||
received = len(input_chunk_buffer.get(cmd_id, {}))
|
||||
if received < expected:
|
||||
continue # Wait for all chunks to arrive
|
||||
|
||||
@@ -292,16 +294,31 @@ def _pending_tasks(
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
# the task status _should_ be set to completed by the LAST runner
|
||||
# it is currently set by the first
|
||||
# this is definitely a hack
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
):
|
||||
return task
|
||||
|
||||
|
||||
def _cancel_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
if task.task_status != TaskStatus.Cancelled:
|
||||
continue
|
||||
for runner in runners.values():
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
if task.task_id in runner.cancelled:
|
||||
continue
|
||||
return CancelTask(
|
||||
instance_id=task.instance_id, cancelled_task_id=task.task_id
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import loguru
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
@@ -15,6 +15,7 @@ def entrypoint(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
@@ -38,7 +39,7 @@ def entrypoint(
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
main(bound_instance, event_sender, task_receiver, cancel_receiver)
|
||||
except ClosedResourceError:
|
||||
logger.warning("Runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
|
||||
@@ -37,6 +37,7 @@ from exo.shared.types.tasks import (
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
@@ -77,6 +78,7 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
mx_any,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
@@ -85,6 +87,7 @@ def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
):
|
||||
instance, runner_id, shard_metadata = (
|
||||
bound_instance.instance,
|
||||
@@ -99,8 +102,11 @@ def main(
|
||||
time.sleep(timeout)
|
||||
|
||||
setup_start_time = time.time()
|
||||
cancelled_tasks = set[TaskId]()
|
||||
|
||||
model: Model | DistributedImageModel | None = None
|
||||
# type checker was unhappy with me - splitting these fixed it
|
||||
inference_model: Model | None = None
|
||||
image_model: DistributedImageModel | None = None
|
||||
tokenizer = None
|
||||
group = None
|
||||
|
||||
@@ -111,6 +117,7 @@ def main(
|
||||
)
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
@@ -155,7 +162,7 @@ def main(
|
||||
time.sleep(0.5)
|
||||
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
model, tokenizer = load_mlx_items(
|
||||
inference_model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(
|
||||
@@ -165,7 +172,7 @@ def main(
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
model = initialize_image_model(bound_instance)
|
||||
image_model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
@@ -174,8 +181,6 @@ def main(
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
@@ -186,11 +191,11 @@ def main(
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert not isinstance(model, DistributedImageModel)
|
||||
assert inference_model
|
||||
assert tokenizer
|
||||
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
model=inference_model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
@@ -202,8 +207,8 @@ def main(
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
image = warmup_image_generator(model=model)
|
||||
assert image_model
|
||||
image = warmup_image_generator(model=image_model)
|
||||
if image is not None:
|
||||
logger.info(f"warmed up by generating {image.size} image")
|
||||
else:
|
||||
@@ -222,7 +227,7 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
assert model and not isinstance(model, DistributedImageModel)
|
||||
assert inference_model
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
|
||||
@@ -234,7 +239,7 @@ def main(
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
model=inference_model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
@@ -257,11 +262,11 @@ def main(
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
elif isinstance(model, GptOssModel):
|
||||
elif isinstance(inference_model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
if tokenizer.has_tool_calling and not isinstance(
|
||||
model, GptOssModel
|
||||
inference_model, GptOssModel
|
||||
):
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
@@ -273,7 +278,19 @@ def main(
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
cancel_every = 5
|
||||
tokens_since_last_cancel_check = 0
|
||||
for response in mlx_generator:
|
||||
tokens_since_last_cancel_check += 1
|
||||
if tokens_since_last_cancel_check >= cancel_every:
|
||||
tokens_since_last_cancel_check = 0
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
want_to_cancel = (task.task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
if mx_any(want_to_cancel, group):
|
||||
break
|
||||
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if (
|
||||
@@ -337,72 +354,16 @@ def main(
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ImageGeneration(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image generation request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Generate images using the image generation backend
|
||||
# Track image_index for final images only
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||
)
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
logger.info("sending final ImageChunk")
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
raise
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
case ImageGeneration() | ImageEdits() if isinstance(
|
||||
current_status, RunnerReady
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||
assert image_model
|
||||
task_name = (
|
||||
"image generation"
|
||||
if isinstance(task, ImageGeneration)
|
||||
else "image edits"
|
||||
)
|
||||
logger.info(f"received {task_name} request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
@@ -412,39 +373,19 @@ def main(
|
||||
)
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||
)
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
logger.info("sending final ImageChunk")
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
_run_image_task(
|
||||
task=task,
|
||||
image_model=image_model,
|
||||
shard_metadata=shard_metadata,
|
||||
event_sender=event_sender,
|
||||
cancel_receiver=cancel_receiver,
|
||||
cancelled_tasks=cancelled_tasks,
|
||||
)
|
||||
except Exception as e:
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
command_id=task.command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
finish_reason="error",
|
||||
@@ -476,7 +417,7 @@ def main(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del model, tokenizer, group
|
||||
del inference_model, image_model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
@@ -585,6 +526,54 @@ def parse_thinking_models(
|
||||
yield response
|
||||
|
||||
|
||||
def _run_image_task(
|
||||
task: ImageGeneration | ImageEdits,
|
||||
image_model: DistributedImageModel,
|
||||
shard_metadata: ShardMetadata,
|
||||
event_sender: MpSender[Event],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
cancelled_tasks: set[TaskId],
|
||||
) -> None:
|
||||
task_id = task.task_id
|
||||
command_id = task.command_id
|
||||
|
||||
def check_cancelled(task_id: TaskId = task_id) -> bool:
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
return (task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
|
||||
image_index = 0
|
||||
for response in generate_image(
|
||||
model=image_model,
|
||||
task=task.task_params,
|
||||
cancel_checker=check_cancelled,
|
||||
):
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||
)
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
logger.info("sending final ImageChunk")
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
|
||||
|
||||
def _send_image_chunk(
|
||||
encoded_data: str,
|
||||
command_id: CommandId,
|
||||
|
||||
@@ -49,10 +49,12 @@ class RunnerSupervisor:
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_tg: TaskGroup | None = field(default=None, init=False)
|
||||
_cancel_sender: MpSender[TaskId]
|
||||
_tg: TaskGroup = field(default_factory=create_task_group, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -63,8 +65,8 @@ class RunnerSupervisor:
|
||||
initialize_timeout: float = 400,
|
||||
) -> Self:
|
||||
ev_send, ev_recv = mp_channel[Event]()
|
||||
# A task is kind of a runner command
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||
|
||||
runner_process = Process(
|
||||
target=entrypoint,
|
||||
@@ -72,6 +74,7 @@ class RunnerSupervisor:
|
||||
bound_instance,
|
||||
ev_send,
|
||||
task_recv,
|
||||
cancel_recv,
|
||||
logger,
|
||||
),
|
||||
daemon=True,
|
||||
@@ -86,6 +89,7 @@ class RunnerSupervisor:
|
||||
initialize_timeout=initialize_timeout,
|
||||
_ev_recv=ev_recv,
|
||||
_task_sender=task_sender,
|
||||
_cancel_sender=cancel_sender,
|
||||
_event_sender=event_sender,
|
||||
)
|
||||
|
||||
@@ -93,37 +97,41 @@ class RunnerSupervisor:
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._forward_events)
|
||||
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
await to_thread.run_sync(self.runner_process.join, 30)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
|
||||
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
|
||||
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.close()
|
||||
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
await to_thread.run_sync(self.runner_process.join, 10)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
|
||||
def shutdown(self):
|
||||
assert self._tg
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
async def shutdown(self):
|
||||
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
@@ -131,6 +139,7 @@ class RunnerSupervisor:
|
||||
logger.info(
|
||||
f"Skipping invalid task {task} as it has already been completed"
|
||||
)
|
||||
return
|
||||
logger.info(f"Starting task {task}")
|
||||
event = anyio.Event()
|
||||
self.pending[task.task_id] = event
|
||||
@@ -140,7 +149,13 @@ class RunnerSupervisor:
|
||||
logger.warning(f"Task {task} dropped, runner closed communication.")
|
||||
return
|
||||
await event.wait()
|
||||
logger.info(f"Finished task {task}")
|
||||
|
||||
async def cancel_task(self, task_id: TaskId):
|
||||
if task_id in self.completed:
|
||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||
return
|
||||
self.cancelled.add(task_id)
|
||||
await self._cancel_sender.send_async(task_id)
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
@@ -206,4 +221,4 @@ class RunnerSupervisor:
|
||||
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
|
||||
)
|
||||
)
|
||||
self.shutdown()
|
||||
await self.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user