Mirror of https://github.com/exo-explore/exo.git
Synced 2026-02-26 03:06:05 -05:00

Compare commits: 11 commits on branch remove-pyt… by alexcheema

| Author | SHA1 |
|---|---|
| alexcheema | c7b7523bb4 |
| alexcheema | 3ee1076d44 |
| alexcheema | c2f47802c8 |
| alexcheema | a067fd0753 |
| alexcheema | 333119ec2f |
| alexcheema | 79661ea3a3 |
| alexcheema | 196f4fb35f |
| alexcheema | 08a4ca59c6 |
| alexcheema | 45ed3903f3 |
| alexcheema | 6657d8600f |
| alexcheema | bf06580e0e |
@@ -5,21 +5,21 @@
 [X] Fetching download status of all models on start
 [X] Deduplication of tasks in plan_step.
 [X] resolve_allow_patterns should just be wildcard now.
-[] no mx_barrier in generate.py mlx_generate at the end.
+[X] no mx_barrier in generate.py mlx_generate at the end.
 [] cache assertion not needed in auto_parallel.py PipelineLastLayer.
-[] GPTOSS support dropped in auto_parallel.py.
+[X] GPTOSS support dropped in auto_parallel.py.
-[] sharding changed: "all-to-sharded" became _all_to_sharded in auto_parallel.py.
+[X] sharding changed: "all-to-sharded" became _all_to_sharded in auto_parallel.py.
-[] same as above: "sharded-to-all" became _sharded_to_all in auto_parallel.py.
+[X] same as above: "sharded-to-all" became _sharded_to_all in auto_parallel.py.
-[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
+[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
 [] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
 [X] KV_CACHE_BITS should be None to disable quantized KV cache.
-[] Dropped _set_nofile_limit in utils_mlx.py.
+[X] Dropped _set_nofile_limit in utils_mlx.py.
-[] We have group optional in load_mlx_items in utils_mlx.py.
+[X] We have group optional in load_mlx_items in utils_mlx.py.
-[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
+[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
-[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
+[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
 [X] We put the cache limit back in utils_mlx.py.
-[] topology.py remove_node removes the connections after checking if the node id is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections, I guess?
+[X] topology.py remove_node removes the connections after checking if the node id is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections, I guess?
-[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up; probably create an issue. The blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does, but we can't upgrade as it breaks other things.)
+[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up; probably create an issue. The blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does, but we can't upgrade as it breaks other things.)
 [] try-except in _command_processor only catches ValueError. This was silently failing, leading to un-debuggable errors (we had a KeyError that was happening). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
 [X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and checks if this tensor parallel number of nodes is supported by the model's tensor dimensions.
 [X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
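The try-except item above is worth a concrete illustration: a handler loop that only catches `ValueError` lets a stray `KeyError` escape (or get swallowed upstream) with no useful trace. A minimal sketch of the failure mode and the fix, with invented names — this is not the actual exo loop:

```python
import logging

logger = logging.getLogger("command_processor")

def process_commands(commands: list[dict], handlers: dict) -> None:
    for command in commands:
        try:
            # A missing handler raises KeyError, which `except ValueError`
            # would not catch: the error escapes with no useful log line.
            handlers[command["kind"]](command)
        except Exception:
            # Catching Exception keeps the loop alive and records a full
            # traceback, which is what makes a surprise KeyError debuggable.
            logger.exception("command failed: %r", command)
```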

@@ -1,6 +1,7 @@
 <script lang="ts">
 import {
   isLoading,
+  stopGeneration,
   sendMessage,
   generateImage,
   editImage,
@@ -648,86 +649,92 @@
 style="min-height: 28px; max-height: 150px;"
 ></textarea>

-<button
-  type="submit"
-  disabled={!canSend || loading || isEditOnlyWithoutImage}
-  class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
-    {!canSend || loading || isEditOnlyWithoutImage
-      ? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
-      : 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
-  aria-label={shouldShowEditMode
-    ? "Edit image"
-    : isImageModel()
-      ? "Generate image"
-      : "Send message"}
->
-  {#if loading}
-    <span class="inline-flex items-center gap-1 sm:gap-2">
-      <span
-        class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
-      ></span>
-      <span class="hidden sm:inline"
-        >{shouldShowEditMode
-          ? "EDITING"
-          : isImageModel()
-            ? "GENERATING"
-            : "PROCESSING"}</span
-      >
-      <span class="sm:hidden">...</span>
-    </span>
-  {:else if shouldShowEditMode}
-    <span class="inline-flex items-center gap-1.5">
-      <svg
-        class="w-3.5 h-3.5"
-        fill="none"
-        viewBox="0 0 24 24"
-        stroke="currentColor"
-        stroke-width="2"
-      >
-        <path
-          stroke-linecap="round"
-          stroke-linejoin="round"
-          d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
-        />
-      </svg>
-      <span>EDIT</span>
-    </span>
-  {:else if isEditOnlyWithoutImage}
-    <span class="inline-flex items-center gap-1.5">
-      <svg
-        class="w-3.5 h-3.5"
-        fill="none"
-        viewBox="0 0 24 24"
-        stroke="currentColor"
-        stroke-width="2"
-      >
-        <path
-          stroke-linecap="round"
-          stroke-linejoin="round"
-          d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
-        />
-      </svg>
-      <span>EDIT</span>
-    </span>
-  {:else if isImageModel()}
-    <span class="inline-flex items-center gap-1.5">
-      <svg
-        class="w-3.5 h-3.5"
-        fill="none"
-        viewBox="0 0 24 24"
-        stroke="currentColor"
-        stroke-width="2"
-      >
-        <rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
-        <circle cx="8.5" cy="8.5" r="1.5" />
-        <polyline points="21 15 16 10 5 21" />
-      </svg>
-      <span>GENERATE</span>
-    </span>
-  {:else}
-    SEND
-  {/if}
-</button>
+{#if loading}
+  <button
+    type="button"
+    onclick={() => stopGeneration()}
+    class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-red-900/50 hover:text-red-400 border border-exo-medium-gray/50 hover:border-red-500/50 cursor-pointer"
+    aria-label="Stop generation"
+  >
+    <span class="inline-flex items-center gap-1 sm:gap-2">
+      <svg
+        class="w-2.5 h-2.5 sm:w-3 sm:h-3"
+        viewBox="0 0 24 24"
+        fill="currentColor"
+      >
+        <rect x="4" y="4" width="16" height="16" rx="2" />
+      </svg>
+      <span class="hidden sm:inline">STOP</span>
+    </span>
+  </button>
+{:else}
+  <button
+    type="submit"
+    disabled={!canSend || isEditOnlyWithoutImage}
+    class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
+      {!canSend || isEditOnlyWithoutImage
+        ? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
+        : 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
+    aria-label={shouldShowEditMode
+      ? "Edit image"
+      : isImageModel()
+        ? "Generate image"
+        : "Send message"}
+  >
+    {#if shouldShowEditMode}
+      <span class="inline-flex items-center gap-1.5">
+        <svg
+          class="w-3.5 h-3.5"
+          fill="none"
+          viewBox="0 0 24 24"
+          stroke="currentColor"
+          stroke-width="2"
+        >
+          <path
+            stroke-linecap="round"
+            stroke-linejoin="round"
+            d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
+          />
+        </svg>
+        <span>EDIT</span>
+      </span>
+    {:else if isEditOnlyWithoutImage}
+      <span class="inline-flex items-center gap-1.5">
+        <svg
+          class="w-3.5 h-3.5"
+          fill="none"
+          viewBox="0 0 24 24"
+          stroke="currentColor"
+          stroke-width="2"
+        >
+          <path
+            stroke-linecap="round"
+            stroke-linejoin="round"
+            d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
+          />
+        </svg>
+        <span>EDIT</span>
+      </span>
+    {:else if isImageModel()}
+      <span class="inline-flex items-center gap-1.5">
+        <svg
+          class="w-3.5 h-3.5"
+          fill="none"
+          viewBox="0 0 24 24"
+          stroke="currentColor"
+          stroke-width="2"
+        >
+          <rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
+          <circle cx="8.5" cy="8.5" r="1.5" />
+          <polyline points="21 15 16 10 5 21" />
+        </svg>
+        <span>GENERATE</span>
+      </span>
+    {:else}
+      SEND
+    {/if}
+  </button>
+{/if}
 </div>

 <!-- Bottom accent line -->

@@ -514,6 +514,7 @@ class AppStore {
   messages = $state<Message[]>([]);
   currentResponse = $state("");
   isLoading = $state(false);
+  private currentAbortController: AbortController | null = null;

   // Performance metrics
   ttftMs = $state<number | null>(null); // Time to first token in ms
@@ -1814,9 +1815,11 @@ class AppStore {
       return;
     }

+    this.currentAbortController = new AbortController();
     const response = await fetch("/v1/chat/completions", {
       method: "POST",
       headers: { "Content-Type": "application/json" },
+      signal: this.currentAbortController.signal,
       body: JSON.stringify({
         model: modelToUse,
         messages: apiMessages,
@@ -1930,6 +1933,7 @@ class AppStore {
         "Unknown error",
       );
     } finally {
+      this.currentAbortController = null;
       this.isLoading = false;
       this.currentResponse = "";
       this.saveConversationsToStorage();
@@ -2066,6 +2070,10 @@ class AppStore {
     assistantMessageId: string,
     errorPrefix = "Failed to get response",
   ): void {
+    // Don't show error for user-initiated abort (stop button)
+    if (error instanceof DOMException && error.name === "AbortError") {
+      return;
+    }
     if (this.conversationExists(targetConversationId)) {
       this.updateConversationMessage(
         targetConversationId,
@@ -2107,6 +2115,17 @@ class AppStore {
     return null;
   }

+  /**
+   * Stop the current generation by aborting the HTTP connection.
+   * This triggers backend cancellation via the mechanism in PR #1276.
+   */
+  stopGeneration() {
+    if (this.currentAbortController) {
+      this.currentAbortController.abort();
+      this.currentAbortController = null;
+    }
+  }
+
   /**
    * Send a message to the LLM and stream the response
    */
@@ -2255,11 +2274,13 @@ class AppStore {
     let firstTokenTime: number | null = null;
     let tokenCount = 0;

+    this.currentAbortController = new AbortController();
     const response = await fetch("/v1/chat/completions", {
       method: "POST",
       headers: {
         "Content-Type": "application/json",
       },
+      signal: this.currentAbortController.signal,
       body: JSON.stringify({
         model: modelToUse,
         messages: apiMessages,
@@ -2410,6 +2431,7 @@ class AppStore {
         "Failed to get response",
       );
     } finally {
+      this.currentAbortController = null;
       this.isLoading = false;
       this.currentResponse = "";
       this.saveConversationsToStorage();
@@ -2514,11 +2536,13 @@ class AppStore {
       };
     }

+    this.currentAbortController = new AbortController();
     const response = await fetch("/v1/images/generations", {
       method: "POST",
       headers: {
         "Content-Type": "application/json",
       },
+      signal: this.currentAbortController.signal,
       body: JSON.stringify(requestBody),
     });

@@ -2667,6 +2691,7 @@ class AppStore {
         "Failed to generate image",
       );
     } finally {
+      this.currentAbortController = null;
       this.isLoading = false;
       this.saveConversationsToStorage();
     }
@@ -2788,8 +2813,10 @@ class AppStore {
       );
     }

+    this.currentAbortController = new AbortController();
     const apiResponse = await fetch("/v1/images/edits", {
       method: "POST",
+      signal: this.currentAbortController.signal,
       body: formData,
     });

@@ -2899,6 +2926,7 @@ class AppStore {
         "Failed to edit image",
      );
     } finally {
+      this.currentAbortController = null;
       this.isLoading = false;
       this.saveConversationsToStorage();
     }
@@ -3039,6 +3067,7 @@ export const hasStartedChat = () => appStore.hasStartedChat;
 export const messages = () => appStore.messages;
 export const currentResponse = () => appStore.currentResponse;
 export const isLoading = () => appStore.isLoading;
+export const stopGeneration = () => appStore.stopGeneration();
 export const ttftMs = () => appStore.ttftMs;
 export const tps = () => appStore.tps;
 export const totalTokens = () => appStore.totalTokens;
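The store now threads an `AbortController` signal into every `fetch`, so pressing stop closes the HTTP connection and lets the backend cancel the task. The same disconnect-driven cancellation can be exercised from a Python client; a sketch with an assumed base URL and request shape (not code from this repo):

```python
import httpx

def stream_then_disconnect(base_url: str, model: str, prompt: str, max_lines: int = 5) -> None:
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,  # assumed OpenAI-compatible request shape
    }
    with httpx.Client(base_url=base_url, timeout=None) as client:
        with client.stream("POST", "/v1/chat/completions", json=payload) as response:
            for i, line in enumerate(response.iter_lines()):
                print(line)
                if i + 1 >= max_lines:
                    # Leaving the `with` block closes the connection mid-stream;
                    # the server observes the disconnect as a cancellation,
                    # mirroring what AbortController.abort() does in the browser.
                    break
```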

@@ -132,7 +132,7 @@ markers = [
 env = [
   "EXO_TESTS=1"
 ]
-addopts = "-m 'not slow'"
+addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
 filterwarnings = [
   "ignore:builtin type Swig:DeprecationWarning",
 ]

@@ -176,7 +176,7 @@ async def generate_chat_stream(
 async def collect_chat_response(
     command_id: CommandId,
     chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
-) -> ChatCompletionResponse:
+) -> AsyncGenerator[str]:
     """Collect all token chunks and return a single ChatCompletionResponse."""
     text_parts: list[str] = []
     tool_calls: list[ToolCall] = []
@@ -223,7 +223,7 @@ async def collect_chat_response(
     combined_text = "".join(text_parts)
     assert model is not None

-    return ChatCompletionResponse(
+    yield ChatCompletionResponse(
         id=command_id,
         created=int(time.time()),
         model=model,
@@ -241,4 +241,5 @@
                 finish_reason=finish_reason,
             )
         ],
-    )
+    ).model_dump_json()
+    return
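Rewriting `collect_chat_response` as an async generator (yield the JSON once, then return) is what lets the non-streaming path be served through `StreamingResponse` in the hunks below, so a client disconnect cancels the request's task instead of the server finishing work nobody will read. A distilled, self-contained sketch of that shape (Starlette-style, illustrative names):

```python
from collections.abc import AsyncGenerator

from starlette.responses import StreamingResponse

async def collect_response(chunks: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
    # Accumulate every chunk, then yield exactly one JSON string. Because the
    # collector is itself an async generator driven by the response, the
    # framework cancels it at this `async for` if the client goes away.
    parts: list[str] = []
    async for chunk in chunks:
        parts.append(chunk)
    yield "".join(parts)

def one_shot_response(chunks: AsyncGenerator[str, None]) -> StreamingResponse:
    return StreamingResponse(collect_response(chunks), media_type="application/json")
```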

@@ -125,6 +125,7 @@ from exo.shared.types.commands import (
     PlaceInstance,
     SendInputChunk,
     StartDownload,
+    TaskCancelled,
     TaskFinished,
     TextGeneration,
 )
@@ -540,16 +541,14 @@ class API:
                     break

         except anyio.get_cancelled_exc_class():
-            # TODO: TaskCancelled
-            """
-            self.command_sender.send_nowait(
+            command = TaskCancelled(cancelled_command_id=command_id)
+            with anyio.CancelScope(shield=True):
+                await self.command_sender.send(
                     ForwarderCommand(origin=self.node_id, command=command)
                 )
-            """
             raise
         finally:
-            command = TaskFinished(finished_command_id=command_id)
-            await self._send(command)
+            await self._send(TaskFinished(finished_command_id=command_id))
             if command_id in self._text_generation_queues:
                 del self._text_generation_queues[command_id]
@@ -644,11 +643,14 @@ class API:
                     "X-Accel-Buffering": "no",
                 },
             )
-        return await collect_chat_response(
-            command.command_id,
-            self._token_chunk_stream(command.command_id),
-        )
+        else:
+            return StreamingResponse(
+                collect_chat_response(
+                    command.command_id,
+                    self._token_chunk_stream(command.command_id),
+                ),
+                media_type="application/json",
+            )

     async def bench_chat_completions(
         self, payload: BenchChatCompletionRequest
@@ -664,8 +666,7 @@ class API:
         command = TextGeneration(task_params=task_params)
         await self._send(command)

-        response = await self._collect_text_generation_with_stats(command.command_id)
-        return response
+        return await self._collect_text_generation_with_stats(command.command_id)

     async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
         """Validate a text model exists and return the resolved model ID.
@@ -883,6 +884,11 @@ class API:
                     del image_metadata[key]

             except anyio.get_cancelled_exc_class():
+                command = TaskCancelled(cancelled_command_id=command_id)
+                with anyio.CancelScope(shield=True):
+                    await self.command_sender.send(
+                        ForwarderCommand(origin=self.node_id, command=command)
+                    )
                 raise
             finally:
                 await self._send(TaskFinished(finished_command_id=command_id))
@@ -964,6 +970,11 @@ class API:

                 return (images, stats if capture_stats else None)
             except anyio.get_cancelled_exc_class():
+                command = TaskCancelled(cancelled_command_id=command_id)
+                with anyio.CancelScope(shield=True):
+                    await self.command_sender.send(
+                        ForwarderCommand(origin=self.node_id, command=command)
+                    )
                 raise
             finally:
                 await self._send(TaskFinished(finished_command_id=command_id))
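The pattern added in these hunks — catch the cancellation exception, do one final `await` inside a shielded scope, re-raise — is the standard anyio idiom for cleanup that must itself await after the surrounding scope has been cancelled. A generic sketch (names are placeholders, not the exo types):

```python
import anyio

async def run_with_cancel_notice(work, notify, command_id: str) -> None:
    """Run `work()`; if we are cancelled mid-flight, still tell the master."""
    try:
        await work()
    except anyio.get_cancelled_exc_class():
        # The enclosing scope is already cancelled, so a bare `await` here
        # would be aborted immediately; shielding lets this one send complete.
        with anyio.CancelScope(shield=True):
            await notify({"type": "TaskCancelled", "command_id": command_id})
        raise  # cancellation must always be re-raised
    finally:
        with anyio.CancelScope(shield=True):
            await notify({"type": "TaskFinished", "command_id": command_id})
```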

@@ -24,6 +24,7 @@ from exo.shared.types.commands import (
     PlaceInstance,
     RequestEventLog,
     SendInputChunk,
+    TaskCancelled,
     TaskFinished,
     TestCommand,
     TextGeneration,
@@ -39,6 +40,7 @@ from exo.shared.types.events import (
     NodeTimedOut,
     TaskCreated,
     TaskDeleted,
+    TaskStatusUpdated,
     TraceEventData,
     TracesCollected,
     TracesMerged,
@@ -279,7 +281,7 @@ class Master:
             case DeleteInstance():
                 placement = delete_instance(command, self.state.instances)
                 transition_events = get_transition_events(
-                    self.state.instances, placement
+                    self.state.instances, placement, self.state.tasks
                 )
                 for cmd in cancel_unnecessary_downloads(
                     placement, self.state.downloads
@@ -299,7 +301,7 @@ class Master:
                     self.state.node_network,
                 )
                 transition_events = get_transition_events(
-                    self.state.instances, placement
+                    self.state.instances, placement, self.state.tasks
                 )
                 generated_events.extend(transition_events)
             case CreateInstance():
@@ -309,7 +311,7 @@ class Master:
                     self.state.instances,
                 )
                 transition_events = get_transition_events(
-                    self.state.instances, placement
+                    self.state.instances, placement, self.state.tasks
                 )
                 generated_events.extend(transition_events)
             case SendInputChunk(chunk=chunk):
@@ -319,6 +321,18 @@ class Master:
                         chunk=chunk,
                     )
                 )
+            case TaskCancelled():
+                if (
+                    task_id := self.command_task_mapping.get(
+                        command.cancelled_command_id
+                    )
+                ) is not None:
+                    generated_events.append(
+                        TaskStatusUpdated(
+                            task_status=TaskStatus.Cancelled,
+                            task_id=task_id,
+                        )
+                    )
             case TaskFinished():
                 generated_events.append(
                     TaskDeleted(
@@ -327,10 +341,9 @@ class Master:
                         ]
                     )
                 )
-                if command.finished_command_id in self.command_task_mapping:
-                    del self.command_task_mapping[
-                        command.finished_command_id
-                    ]
+                self.command_task_mapping.pop(
+                    command.finished_command_id, None
+                )
             case RequestEventLog():
                 # We should just be able to send everything, since other buffers will ignore old messages
                 # rate limit to 1000 at a time

@@ -22,9 +22,15 @@ from exo.shared.types.commands import (
     PlaceInstance,
 )
 from exo.shared.types.common import NodeId
-from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
+from exo.shared.types.events import (
+    Event,
+    InstanceCreated,
+    InstanceDeleted,
+    TaskStatusUpdated,
+)
 from exo.shared.types.memory import Memory
 from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
+from exo.shared.types.tasks import Task, TaskId, TaskStatus
 from exo.shared.types.worker.downloads import (
     DownloadOngoing,
     DownloadProgress,
@@ -186,6 +192,7 @@ def delete_instance(
 def get_transition_events(
     current_instances: Mapping[InstanceId, Instance],
     target_instances: Mapping[InstanceId, Instance],
+    tasks: Mapping[TaskId, Task],
 ) -> Sequence[Event]:
     events: list[Event] = []

@@ -201,6 +208,18 @@ def get_transition_events(
     # find instances to delete
     for instance_id in current_instances:
         if instance_id not in target_instances:
+            for task in tasks.values():
+                if task.instance_id == instance_id and task.task_status in [
+                    TaskStatus.Pending,
+                    TaskStatus.Running,
+                ]:
+                    events.append(
+                        TaskStatusUpdated(
+                            task_status=TaskStatus.Cancelled,
+                            task_id=task.task_id,
+                        )
+                    )
+
             events.append(
                 InstanceDeleted(
                     instance_id=instance_id,

@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
     target_instances = {instance_id: instance}

     # act
-    events = get_transition_events(current_instances, target_instances)
+    events = get_transition_events(current_instances, target_instances, {})

     # assert
     assert len(events) == 0
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
     target_instances: dict[InstanceId, Instance] = {instance_id: instance}

     # act
-    events = get_transition_events(current_instances, target_instances)
+    events = get_transition_events(current_instances, target_instances, {})

     # assert
     assert len(events) == 1
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
     target_instances: dict[InstanceId, Instance] = {}

     # act
-    events = get_transition_events(current_instances, target_instances)
+    events = get_transition_events(current_instances, target_instances, {})

     # assert
     assert len(events) == 1

@@ -48,6 +48,10 @@ class DeleteInstance(BaseCommand):
     instance_id: InstanceId


+class TaskCancelled(BaseCommand):
+    cancelled_command_id: CommandId
+
+
 class TaskFinished(BaseCommand):
     finished_command_id: CommandId

@@ -89,6 +93,7 @@ Command = (
     | PlaceInstance
     | CreateInstance
     | DeleteInstance
+    | TaskCancelled
     | TaskFinished
     | SendInputChunk
 )

@@ -24,6 +24,7 @@ class TaskStatus(str, Enum):
     Complete = "Complete"
     TimedOut = "TimedOut"
     Failed = "Failed"
+    Cancelled = "Cancelled"


 class BaseTask(TaggedModel):
@@ -60,6 +61,11 @@ class TextGeneration(BaseTask):  # emitted by Master
     error_message: str | None = Field(default=None)


+class CancelTask(BaseTask):
+    cancelled_task_id: TaskId
+    runner_id: RunnerId
+
+
 class ImageGeneration(BaseTask):  # emitted by Master
     command_id: CommandId
     task_params: ImageGenerationTaskParams
@@ -87,6 +93,7 @@ Task = (
     | LoadModel
     | StartWarmup
     | TextGeneration
+    | CancelTask
     | ImageGeneration
     | ImageEdits
     | Shutdown

@@ -64,8 +64,6 @@ from exo.worker.runner.bootstrap import logger
 Group = mx.distributed.Group


-# TODO: Test this
-# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
 def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
     return Memory.from_float_kb(
         (model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -83,30 +81,6 @@ class ModelLoadingTimeoutError(Exception):
     pass


-def mx_barrier(group: Group | None = None):
-    mx.eval(
-        mx.distributed.all_sum(
-            mx.array(1.0),
-            stream=mx.default_stream(mx.Device(mx.cpu)),
-            group=group,
-        )
-    )
-
-
-def broadcast_from_zero(value: int, group: Group | None = None):
-    if group is None:
-        return value
-
-    if group.rank() == 0:
-        a = mx.array([value], dtype=mx.int32)
-    else:
-        a = mx.array([0], dtype=mx.int32)
-
-    m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
-    mx.eval(m)
-    return int(m.item())
-
-
 class HostList(RootModel[list[str]]):
     @classmethod
     def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -591,3 +565,23 @@ def mlx_cleanup(
     import gc

     gc.collect()
+
+
+def mx_any(bool_: bool, group: Group | None) -> bool:
+    if group is None:
+        return bool_
+    num_true = mx.distributed.all_sum(
+        mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
+    )
+    mx.eval(num_true)
+    return num_true.item() > 0
+
+
+def mx_barrier(group: Group | None):
+    if group is None:
+        return
+    mx.eval(
+        mx.distributed.all_sum(
+            mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
+        )
+    )
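`mx_any` builds a boolean OR out of `all_sum`: each rank contributes its flag as 0 or 1, and a positive all-reduced total means at least one rank set it, so every rank derives the same boolean and takes the same branch. For example, three ranks voting False, True, False sum to 1, and all three see `True`. A standalone sketch of just that arithmetic (no distributed group needed to run it):

```python
# Collective OR via summation, the same trick mx_any uses with
# mx.distributed.all_sum: the all-reduced sum is identical on every rank,
# so the derived boolean is too.
def collective_any(local_flags: list[bool]) -> bool:
    total = sum(int(flag) for flag in local_flags)  # what all_sum computes
    return total > 0

assert collective_any([False, True, False]) is True    # one rank wants to cancel
assert collective_any([False, False, False]) is False  # nobody does
```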

@@ -33,6 +33,7 @@ from exo.shared.types.events import (
 from exo.shared.types.multiaddr import Multiaddr
 from exo.shared.types.state import State
 from exo.shared.types.tasks import (
+    CancelTask,
     CreateRunner,
     DownloadModel,
     ImageEdits,
@@ -224,15 +225,22 @@ class Worker:
                     )
                 )
             case Shutdown(runner_id=runner_id):
+                runner = self.runners.pop(runner_id)
                 try:
                     with fail_after(3):
-                        await self.runners.pop(runner_id).start_task(task)
+                        await runner.start_task(task)
                 except TimeoutError:
                     await self.event_sender.send(
                         TaskStatusUpdated(
                             task_id=task.task_id, task_status=TaskStatus.TimedOut
                         )
                     )
+                finally:
+                    runner.shutdown()
+            case CancelTask(
+                cancelled_task_id=cancelled_task_id, runner_id=runner_id
+            ):
+                await self.runners[runner_id].cancel_task(cancelled_task_id)
             case ImageEdits() if task.task_params.total_input_chunks > 0:
                 # Assemble image from chunks and inject into task
                 cmd_id = task.command_id
@@ -270,18 +278,18 @@ class Worker:
                     del self.input_chunk_buffer[cmd_id]
                 if cmd_id in self.input_chunk_counts:
                     del self.input_chunk_counts[cmd_id]
-                await self.runners[self._task_to_runner_id(task)].start_task(
-                    modified_task
-                )
+                await self._start_runner_task(modified_task)
             case task:
-                await self.runners[self._task_to_runner_id(task)].start_task(task)
+                await self._start_runner_task(task)

     def shutdown(self):
         self._tg.cancel_scope.cancel()

-    def _task_to_runner_id(self, task: Task):
-        instance = self.state.instances[task.instance_id]
-        return instance.shard_assignments.node_to_runner[self.node_id]
+    async def _start_runner_task(self, task: Task):
+        if (instance := self.state.instances.get(task.instance_id)) is not None:
+            await self.runners[
+                instance.shard_assignments.node_to_runner[self.node_id]
+            ].start_task(task)

     async def _nack_request(self, since_idx: int) -> None:
         # We request all events after (and including) the missing index.
@@ -320,8 +328,6 @@ class Worker:
         for event in self.out_for_delivery.copy().values():
             await self.local_event_sender.send(event)

-    ## Op Executors
-
     def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
         """Creates and stores a new AssignedRunner with initial downloading status."""
         runner = RunnerSupervisor.create(
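The `Shutdown` change above is a small ordering fix worth spelling out: pop the runner from the table first, so nothing else routes work to it, then guarantee hard teardown in a `finally` even if the graceful-shutdown task times out. In miniature (illustrative names, not the exo classes):

```python
import anyio

async def shut_down_runner(runners: dict, runner_id: str, shutdown_task) -> None:
    runner = runners.pop(runner_id)  # unroute before stopping
    try:
        with anyio.fail_after(3):
            await runner.start_task(shutdown_task)  # ask it to exit gracefully
    except TimeoutError:
        pass  # the real code reports a TimedOut status upstream here
    finally:
        runner.shutdown()  # hard teardown happens either way
```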

@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence

 from exo.shared.types.common import CommandId, NodeId
 from exo.shared.types.tasks import (
+    CancelTask,
     ConnectToGroup,
     CreateRunner,
     DownloadModel,
@@ -53,13 +54,14 @@ def plan(
 ) -> Task | None:
     # Python short circuiting OR logic should evaluate these sequentially.
     return (
-        _kill_runner(runners, all_runners, instances)
+        _cancel_tasks(runners, tasks)
+        or _kill_runner(runners, all_runners, instances)
         or _create_runner(node_id, runners, instances)
         or _model_needs_download(node_id, runners, global_download_status)
         or _init_distributed_backend(runners, all_runners)
         or _load_model(runners, all_runners, global_download_status)
         or _ready_to_warmup(runners, all_runners)
-        or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
+        or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
     )


@@ -270,7 +272,7 @@ def _pending_tasks(
     runners: Mapping[RunnerId, RunnerSupervisor],
     tasks: Mapping[TaskId, Task],
     all_runners: Mapping[RunnerId, RunnerStatus],
-    input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
+    input_chunk_buffer: Mapping[CommandId, dict[int, str]],
 ) -> Task | None:
     for task in tasks.values():
         # for now, just forward chat completions
@@ -284,7 +286,7 @@
         if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
             cmd_id = task.command_id
             expected = task.task_params.total_input_chunks
-            received = len((input_chunk_buffer or {}).get(cmd_id, {}))
+            received = len(input_chunk_buffer.get(cmd_id, {}))
             if received < expected:
                 continue  # Wait for all chunks to arrive

@@ -292,16 +294,33 @@
             if task.instance_id != runner.bound_instance.instance.instance_id:
                 continue

-            # I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
-            # however, realistically the task status should be set to completed by the LAST runner, so this is a true race
-            # the actual solution is somewhat deeper than this bypass - TODO!
+            # the task status _should_ be set to completed by the LAST runner
+            # it is currently set by the first
+            # this is definitely a hack
             if task.task_id in runner.completed:
                 continue

-            # TODO: Check ordering aligns with MLX distributeds expectations.
-
             if isinstance(runner.status, RunnerReady) and all(
                 isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
                 for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
             ):
                 return task
+
+
+def _cancel_tasks(
+    runners: Mapping[RunnerId, RunnerSupervisor],
+    tasks: Mapping[TaskId, Task],
+) -> Task | None:
+    for task in tasks.values():
+        if task.task_status != TaskStatus.Cancelled:
+            continue
+        for runner_id, runner in runners.items():
+            if task.instance_id != runner.bound_instance.instance.instance_id:
+                continue
+            if task.task_id in runner.cancelled:
+                continue
+            return CancelTask(
+                instance_id=task.instance_id,
+                cancelled_task_id=task.task_id,
+                runner_id=runner_id,
+            )
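`plan` encodes a priority list with short-circuiting `or`: each step returns a task or None, and the first step with work wins — which is why `_cancel_tasks` is placed ahead of everything else. A toy version of the scheme:

```python
from typing import Optional

def cancel_step(state: dict) -> Optional[str]:
    return "cancel" if state.get("cancelled") else None

def kill_step(state: dict) -> Optional[str]:
    return "kill" if state.get("dead") else None

def load_step(state: dict) -> Optional[str]:
    return "load" if state.get("needs_load") else None

def plan(state: dict) -> Optional[str]:
    # Short-circuiting `or` evaluates the steps in priority order and stops
    # at the first one that returns a (truthy) task.
    return cancel_step(state) or kill_step(state) or load_step(state)

assert plan({"cancelled": True, "needs_load": True}) == "cancel"  # cancellation preempts loading
assert plan({"needs_load": True}) == "load"
```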

@@ -3,7 +3,7 @@ import os
 import loguru

 from exo.shared.types.events import Event, RunnerStatusUpdated
-from exo.shared.types.tasks import Task
+from exo.shared.types.tasks import Task, TaskId
 from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
 from exo.shared.types.worker.runners import RunnerFailed
 from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,6 +15,7 @@ def entrypoint(
     bound_instance: BoundInstance,
     event_sender: MpSender[Event],
     task_receiver: MpReceiver[Task],
+    cancel_receiver: MpReceiver[TaskId],
     _logger: "loguru.Logger",
 ) -> None:
     fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -38,7 +39,7 @@ def entrypoint(
     try:
         from exo.worker.runner.runner import main

-        main(bound_instance, event_sender, task_receiver)
+        main(bound_instance, event_sender, task_receiver, cancel_receiver)
     except ClosedResourceError:
         logger.warning("Runner communication closed unexpectedly")
     except Exception as e:
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import math
|
||||||
import resource
|
import resource
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
@@ -88,6 +89,7 @@ from exo.worker.engines.mlx.utils_mlx import (
|
|||||||
initialize_mlx,
|
initialize_mlx,
|
||||||
load_mlx_items,
|
load_mlx_items,
|
||||||
mlx_force_oom,
|
mlx_force_oom,
|
||||||
|
mx_any,
|
||||||
)
|
)
|
||||||
from exo.worker.runner.bootstrap import logger
|
from exo.worker.runner.bootstrap import logger
|
||||||
|
|
||||||
@@ -112,6 +114,7 @@ def main(
|
|||||||
bound_instance: BoundInstance,
|
bound_instance: BoundInstance,
|
||||||
event_sender: MpSender[Event],
|
event_sender: MpSender[Event],
|
||||||
task_receiver: MpReceiver[Task],
|
task_receiver: MpReceiver[Task],
|
||||||
|
cancel_receiver: MpReceiver[TaskId],
|
||||||
):
|
):
|
||||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
|
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
|
||||||
@@ -129,11 +132,15 @@ def main(
|
|||||||
time.sleep(timeout)
|
time.sleep(timeout)
|
||||||
|
|
||||||
setup_start_time = time.time()
|
setup_start_time = time.time()
|
||||||
|
cancelled_tasks = set[TaskId]()
|
||||||
|
|
||||||
model: Model | DistributedImageModel | None = None
|
# type checker was unhappy with me - splitting these fixed it
|
||||||
|
inference_model: Model | None = None
|
||||||
|
image_model: DistributedImageModel | None = None
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
group = None
|
group = None
|
||||||
kv_prefix_cache: KVPrefixCache | None = None
|
kv_prefix_cache: KVPrefixCache | None = None
|
||||||
|
check_for_cancel_every: int | None = None
|
||||||
|
|
||||||
current_status: RunnerStatus = RunnerIdle()
|
current_status: RunnerStatus = RunnerIdle()
|
||||||
logger.info("runner created")
|
logger.info("runner created")
|
||||||
@@ -146,6 +153,7 @@ def main(
|
|||||||
if task.task_id in seen:
|
if task.task_id in seen:
|
||||||
logger.warning("repeat task - potential error")
|
logger.warning("repeat task - potential error")
|
||||||
seen.add(task.task_id)
|
seen.add(task.task_id)
|
||||||
|
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
||||||
event_sender.send(
|
event_sender.send(
|
||||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||||
)
|
)
|
||||||
@@ -191,7 +199,7 @@ def main(
|
|||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
|
|
||||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||||
model, tokenizer = load_mlx_items(
|
inference_model, tokenizer = load_mlx_items(
|
||||||
bound_instance, group, on_timeout=on_model_load_timeout
|
bound_instance, group, on_timeout=on_model_load_timeout
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -203,7 +211,7 @@ def main(
|
|||||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||||
):
|
):
|
||||||
model = initialize_image_model(bound_instance)
|
image_model = initialize_image_model(bound_instance)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||||
@@ -211,8 +219,6 @@ def main(
|
|||||||
current_status = RunnerLoaded()
|
current_status = RunnerLoaded()
|
||||||
logger.info("runner loaded")
|
logger.info("runner loaded")
|
||||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||||
assert model
|
|
||||||
|
|
||||||
current_status = RunnerWarmingUp()
|
current_status = RunnerWarmingUp()
|
||||||
logger.info("runner warming up")
|
logger.info("runner warming up")
|
||||||
event_sender.send(
|
event_sender.send(
|
||||||
@@ -224,16 +230,31 @@ def main(
|
|||||||
|
|
||||||
logger.info(f"warming up inference for instance: {instance}")
|
logger.info(f"warming up inference for instance: {instance}")
|
||||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||||
assert not isinstance(model, DistributedImageModel)
|
assert inference_model
|
||||||
assert tokenizer
|
assert tokenizer
|
||||||
|
|
||||||
|
t = time.perf_counter()
|
||||||
toks = warmup_inference(
|
toks = warmup_inference(
|
||||||
model=model,
|
model=inference_model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
group=group,
|
group=group,
|
||||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
|
||||||
)
|
)
|
||||||
logger.info(f"warmed up by generating {toks} tokens")
|
logger.info(f"warmed up by generating {toks} tokens")
|
||||||
|
check_for_cancel_every = min(
|
||||||
|
math.ceil(toks / (time.perf_counter() - t)), 100
|
||||||
|
)
|
||||||
|
if group is not None:
|
||||||
|
check_for_cancel_every = int(
|
||||||
|
mx.max(
|
||||||
|
mx.distributed.all_gather(
|
||||||
|
mx.array([check_for_cancel_every]), group=group
|
||||||
|
)
|
||||||
|
).item()
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"runner checking for cancellation every {check_for_cancel_every} tokens"
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||||
)
|
)
|
||||||
@@ -241,8 +262,8 @@ def main(
|
|||||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||||
):
|
):
|
||||||
assert isinstance(model, DistributedImageModel)
|
assert image_model
|
||||||
image = warmup_image_generator(model=model)
|
image = warmup_image_generator(model=image_model)
|
||||||
if image is not None:
|
if image is not None:
|
||||||
logger.info(f"warmed up by generating {image.size} image")
|
logger.info(f"warmed up by generating {image.size} image")
|
||||||
else:
|
else:
|
||||||
@@ -262,9 +283,9 @@ def main(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||||
|
assert inference_model
|
||||||
assert model and not isinstance(model, DistributedImageModel)
|
|
||||||
assert tokenizer
|
assert tokenizer
|
||||||
|
assert check_for_cancel_every
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_check_for_debug_prompts(task_params)
|
_check_for_debug_prompts(task_params)
|
||||||
@@ -274,7 +295,7 @@ def main(
|
|||||||
|
|
||||||
# Generate responses using the actual MLX generation
|
# Generate responses using the actual MLX generation
|
||||||
mlx_generator = mlx_generate(
|
mlx_generator = mlx_generate(
|
||||||
model=model,
|
model=inference_model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
task=task_params,
|
task=task_params,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@@ -299,11 +320,11 @@ def main(
|
|||||||
patch_glm_tokenizer(tokenizer)
|
patch_glm_tokenizer(tokenizer)
|
||||||
|
|
||||||
# GPT-OSS specific parsing to match other model formats.
|
# GPT-OSS specific parsing to match other model formats.
|
||||||
elif isinstance(model, GptOssModel):
|
elif isinstance(inference_model, GptOssModel):
|
||||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||||
|
|
||||||
if tokenizer.has_tool_calling and not isinstance(
|
if tokenizer.has_tool_calling and not isinstance(
|
||||||
model, GptOssModel
|
inference_model, GptOssModel
|
||||||
):
|
):
|
||||||
assert tokenizer.tool_call_start
|
assert tokenizer.tool_call_start
|
||||||
assert tokenizer.tool_call_end
|
assert tokenizer.tool_call_end
|
||||||
@@ -316,7 +337,18 @@ def main(
|
|||||||
)
|
)
|
||||||
|
|
||||||
completion_tokens = 0
|
completion_tokens = 0
|
||||||
|
tokens_since_last_cancel_check = 0
|
||||||
for response in mlx_generator:
|
for response in mlx_generator:
|
||||||
|
tokens_since_last_cancel_check += 1
|
||||||
|
if tokens_since_last_cancel_check >= check_for_cancel_every:
|
||||||
|
tokens_since_last_cancel_check = 0
|
||||||
|
cancelled_tasks.update(cancel_receiver.collect())
|
||||||
|
want_to_cancel = (task.task_id in cancelled_tasks) or (
|
||||||
|
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||||
|
)
|
||||||
|
if mx_any(want_to_cancel, group):
|
||||||
|
break
|
||||||
|
|
||||||
match response:
|
match response:
|
||||||
case GenerationResponse():
|
case GenerationResponse():
|
||||||
completion_tokens += 1
|
completion_tokens += 1
|
||||||
@@ -388,7 +420,7 @@ def main(
                 case ImageGeneration(
                     task_params=task_params, command_id=command_id
                 ) if isinstance(current_status, RunnerReady):
-                    assert isinstance(model, DistributedImageModel)
+                    assert image_model
                     logger.info(f"received image generation request: {str(task)[:500]}")
                     current_status = RunnerRunning()
                     logger.info("runner running")
@@ -401,7 +433,9 @@ def main(
 
                     try:
                         image_index = 0
-                        for response in generate_image(model=model, task=task_params):
+                        for response in generate_image(
+                            model=image_model, task=task_params
+                        ):
                             is_primary_output = _is_primary_output_node(shard_metadata)
 
                             if is_primary_output:
@@ -451,7 +485,7 @@ def main(
                 case ImageEdits(task_params=task_params, command_id=command_id) if (
                     isinstance(current_status, RunnerReady)
                 ):
-                    assert isinstance(model, DistributedImageModel)
+                    assert image_model
                     logger.info(f"received image edits request: {str(task)[:500]}")
                     current_status = RunnerRunning()
                     logger.info("runner running")
@@ -464,7 +498,9 @@ def main(
 
                     try:
                         image_index = 0
-                        for response in generate_image(model=model, task=task_params):
+                        for response in generate_image(
+                            model=image_model, task=task_params
+                        ):
                             if _is_primary_output_node(shard_metadata):
                                 match response:
                                     case PartialImageResponse():
@@ -530,7 +566,7 @@ def main(
             RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
         )
         if isinstance(current_status, RunnerShutdown):
-            del model, tokenizer, group
+            del inference_model, image_model, tokenizer, group
             mx.clear_cache()
             import gc
 
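
With `model` split into `inference_model` and `image_model`, shutdown now has to drop both references before clearing MLX's cache. A sketch of that teardown order, assuming the intent is to return Metal buffers promptly (the `state` dict is a stand-in for the runner's local variables):

```python
# Sketch of the shutdown sequence, under the assumption that the goal is
# prompt release of GPU memory once the runner is told to shut down.
import gc

import mlx.core as mx


def release_runner_state(state: dict[str, object]) -> None:
    state.clear()     # drop the last references to models, tokenizer, group
    mx.clear_cache()  # hand cached MLX buffers back instead of reusing them
    gc.collect()      # break leftover reference cycles holding big tensors
```
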
@@ -47,9 +47,11 @@ class RunnerSupervisor:
     _ev_recv: MpReceiver[Event]
     _task_sender: MpSender[Task]
     _event_sender: Sender[Event]
+    _cancel_sender: MpSender[TaskId]
     status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
     pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
     completed: set[TaskId] = field(default_factory=set, init=False)
+    cancelled: set[TaskId] = field(default_factory=set, init=False)
 
     @classmethod
     def create(
@@ -60,8 +62,8 @@ class RunnerSupervisor:
         initialize_timeout: float = 400,
     ) -> Self:
         ev_send, ev_recv = mp_channel[Event]()
-        # A task is kind of a runner command
         task_sender, task_recv = mp_channel[Task]()
+        cancel_sender, cancel_recv = mp_channel[TaskId]()
 
         runner_process = Process(
             target=entrypoint,
@@ -69,6 +71,7 @@ class RunnerSupervisor:
                 bound_instance,
                 ev_send,
                 task_recv,
+                cancel_recv,
                 logger,
             ),
             daemon=True,
@@ -83,6 +86,7 @@ class RunnerSupervisor:
             initialize_timeout=initialize_timeout,
             _ev_recv=ev_recv,
             _task_sender=task_sender,
+            _cancel_sender=cancel_sender,
             _event_sender=event_sender,
         )
 
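
`create()` wires the new cancel channel exactly like the existing task channel: `mp_channel[T]()` yields a typed (sender, receiver) pair, the receiver rides into the spawned runner process, and the sender stays with the supervisor. exo's `mp_channel` is not shown in this diff; a hypothetical reconstruction matching how it is subscripted and called:

```python
# Hypothetical reconstruction of the mp_channel factory, inferred only from
# its call sites (mp_channel[Event](), mp_channel[TaskId](), ...). The real
# implementation lives elsewhere in exo and may differ.
import multiprocessing as mp
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class MpSender(Generic[T]):
    def __init__(self, q: "mp.queues.Queue") -> None:
        self._q = q

    def send(self, item: T) -> None:
        self._q.put(item)

    def close(self) -> None:
        self._q.close()


class MpReceiver(Generic[T]):
    def __init__(self, q: "mp.queues.Queue") -> None:
        self._q = q

    def recv(self) -> T:
        return self._q.get()


class mp_channel(Generic[T]):
    # Subscripting returns a factory, so mp_channel[TaskId]() builds a pair
    # sharing one queue: the parent keeps the sender, the child the receiver.
    def __class_getitem__(cls, _payload: object) -> Callable[[], tuple]:
        def make() -> tuple:
            q = mp.Queue()
            return MpSender(q), MpReceiver(q)

        return make
```
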
@@ -97,6 +101,7 @@ class RunnerSupervisor:
         self._ev_recv.close()
         self._task_sender.close()
         self._event_sender.close()
+        self._cancel_sender.close()
         self.runner_process.join(1)
         if not self.runner_process.is_alive():
             logger.info("Runner process succesfully terminated")
@@ -112,14 +117,6 @@ class RunnerSupervisor:
         logger.critical("Runner process didn't respond to SIGTERM, killing")
         self.runner_process.kill()
 
-        self.runner_process.join(1)
-        if not self.runner_process.is_alive():
-            return
-
-        logger.critical(
-            "Runner process didn't respond to SIGKILL. System resources may have leaked"
-        )
-
     async def start_task(self, task: Task):
         if task.task_id in self.pending:
             logger.warning(
@@ -141,6 +138,13 @@ class RunnerSupervisor:
                 return
             await event.wait()
 
+    async def cancel_task(self, task_id: TaskId):
+        if task_id in self.completed:
+            logger.info(f"Unable to cancel {task_id} as it has been completed")
+            return
+        self.cancelled.add(task_id)
+        await self._cancel_sender.send_async(task_id)
+
     async def _forward_events(self):
         with self._ev_recv as events:
             try:
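
`cancel_task` records the cancellation locally (so duplicate requests stay cheap and completed tasks are skipped) and pushes the `TaskId` over the dedicated channel; the runner drains that channel between tokens via `cancel_receiver.collect()`. The drain has to be non-blocking or it would stall generation. `collect()` is not shown in this diff; a sketch with the semantics the generation loop appears to rely on:

```python
# Sketch of a non-blocking drain matching how collect() is used above:
# grab everything queued so far, never wait for more.
import queue
from typing import TypeVar

T = TypeVar("T")


def collect_nowait(q: "queue.Queue[T]") -> set[T]:
    drained: set[T] = set()
    while True:
        try:
            drained.add(q.get_nowait())
        except queue.Empty:
            return drained
```
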
@@ -2,6 +2,7 @@
 from collections.abc import Iterable
 from typing import Callable
 
+import mlx.core as mx
 import pytest
 
 import exo.worker.runner.runner as mlx_runner
@@ -19,6 +20,7 @@ from exo.shared.types.tasks import (
     Shutdown,
     StartWarmup,
     Task,
+    TaskId,
     TaskStatus,
     TextGeneration,
 )
@@ -113,6 +115,8 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
     monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
     monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
+    monkeypatch.setattr(mx.distributed, "all_gather", make_nothin(mx.array([1])))
+    monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
     # Mock apply_chat_template since we're using a fake tokenizer (integer 1).
     # Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
     monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
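
The two new monkeypatches keep the single-process test harness from blocking on a collective: `all_gather` returns a constant array and `mx_any` always reports False, so the cancel branch stays dormant. Judging from call sites like `make_nothin(False)` and `make_nothin((1, MockTokenizer))`, `make_nothin` is a constant-function factory; a hypothetical equivalent:

```python
# Hypothetical reconstruction of the make_nothin test helper, inferred from
# its call sites: build a callable that ignores every argument and returns
# the given value unchanged.
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def make_nothin(value: T) -> Callable[..., T]:
    def nothin(*_args: Any, **_kwargs: Any) -> T:
        return value

    return nothin
```
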
@@ -163,6 +167,7 @@ def _run(tasks: Iterable[Task]):
     )
 
     task_sender, task_receiver = mp_channel[Task]()
+    _cancel_sender, cancel_receiver = mp_channel[TaskId]()
     event_sender = EventCollector()
 
     with task_sender:
@@ -174,7 +179,7 @@ def _run(tasks: Iterable[Task]):
         task_receiver.close = nothin
         task_receiver.join = nothin
 
-        mlx_runner.main(bound_instance, event_sender, task_receiver)  # type: ignore[arg-type]
+        mlx_runner.main(bound_instance, event_sender, task_receiver, cancel_receiver)  # pyright: ignore[reportArgumentType]
 
         return event_sender.events
 
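
With the cancel receiver threaded into `main`, the harness could also exercise cancellation end to end: pre-load the channel with the `TaskId("CANCEL_CURRENT_TASK")` sentinel from the runner hunk earlier in this diff. A hypothetical test along those lines, assuming `MpSender` exposes a synchronous `send()` alongside the `send_async()` used by `RunnerSupervisor.cancel_task`:

```python
# Hypothetical sketch, not part of this diff: queue the broadcast sentinel
# before main() starts consuming tasks, so its first periodic check trips.
# A real test would also monkeypatch mx_any back to a passthrough, since
# patch_out_mlx pins it to False.
def _run_with_cancel(tasks: Iterable[Task]) -> list[Event]:
    bound_instance = ...  # built the same way as in _run above
    task_sender, task_receiver = mp_channel[Task]()
    cancel_sender, cancel_receiver = mp_channel[TaskId]()
    event_sender = EventCollector()

    cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))  # assumed sync send

    with task_sender:
        for task in tasks:
            task_sender.send(task)
        task_receiver.close = nothin
        task_receiver.join = nothin
        mlx_runner.main(bound_instance, event_sender, task_receiver, cancel_receiver)  # pyright: ignore[reportArgumentType]

    return event_sender.events
```
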