mirror of https://github.com/exo-explore/exo.git
synced 2026-02-19 23:36:30 -05:00

Compare commits: gh-screens...meta-insta (6 commits)
- e5c31e50f3
- aa3f106fb9
- 2e29605194
- cacb456cb2
- 51021f6fc6
- 025ed9fd82
Cargo.lock (generated), 13 lines changed
@@ -890,7 +890,7 @@ dependencies = [
 "delegate",
 "env_logger",
 "extend",
 "futures",
 "futures-lite",
 "libp2p",
 "log",
 "networking",
@@ -914,6 +914,12 @@ dependencies = [
 "syn 2.0.111",
]

[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"

[[package]]
name = "ff"
version = "0.13.1"
@@ -1022,7 +1028,10 @@ version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
 "fastrand",
 "futures-core",
 "futures-io",
 "parking",
 "pin-project-lite",
]

@@ -2753,7 +2762,7 @@ dependencies = [
 "delegate",
 "either",
 "extend",
 "futures",
 "futures-lite",
 "futures-timer",
 "keccak-const",
 "libp2p",
@@ -29,14 +29,13 @@ util = { path = "rust/util" }

# Macro dependencies
extend = "1.2"
delegate = "0.13"
pin-project = "1"

# Utility dependencies
keccak-const = "0.2"

# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-lite = "2.6.1"
futures-timer = "3.0"

# Data structures
@@ -72,12 +72,19 @@ There are two ways to run exo:

### Run from Source (macOS)

If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly:

```bash
nix run .#exo
```

**Note:** To accept the Cachix binary cache (and avoid the Xcode Metal ToolChain), add to `/etc/nix/nix.conf`:
```
trusted-users = root (or your username)
experimental-features = nix-command flakes
```
Then restart the Nix daemon: `sudo launchctl kickstart -k system/org.nixos.nix-daemon`

**Prerequisites:**
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
@@ -38,6 +38,8 @@ class Scenario:
    expected_function: str | None = None
    required_arg_keys: list[str] | None = None
    tool_result: str | None = None
    nested_array_key: str | None = None
    required_item_keys: list[str] | None = None


def load_scenarios(path: Path) -> list[Scenario]:
@@ -105,6 +107,8 @@ def load_scenarios(path: Path) -> list[Scenario]:
                expected_function=s.get("expected_function"),
                required_arg_keys=s.get("required_arg_keys"),
                tool_result=tool_result,
                nested_array_key=s.get("nested_array_key"),
                required_item_keys=s.get("required_item_keys"),
            )
        )
@@ -147,6 +151,35 @@ def validate_args(args_str: str, required_keys: list[str]) -> tuple[bool, str |
    return True, None


def validate_nested_args(
    args_str: str,
    array_key: str,
    required_item_keys: list[str],
) -> tuple[bool, str | None]:
    """Check that args[array_key] is a list of objects with required keys."""
    try:
        args = json.loads(args_str)
    except (json.JSONDecodeError, TypeError) as exc:
        return False, f"Invalid JSON: {exc}"
    if not isinstance(args, dict):
        return False, f"Expected dict, got {type(args).__name__}"
    arr = args.get(array_key)
    if not isinstance(arr, list):
        return False, f"'{array_key}' is not an array (got {type(arr).__name__})"
    if len(arr) == 0:
        return False, f"'{array_key}' is empty"
    for i, item in enumerate(arr):
        if not isinstance(item, dict):
            return (
                False,
                f"'{array_key}[{i}]' is not an object (got {type(item).__name__})",
            )
        missing = [k for k in required_item_keys if k not in item]
        if missing:
            return False, f"'{array_key}[{i}]' missing keys: {missing}"
    return True, None

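A quick sanity check of the validator's contract (the argument strings below are hypothetical model outputs, not fixtures from this change):

```python
# Hypothetical inputs; validate_nested_args is the function defined above.
ok, err = validate_nested_args(
    '{"todos": [{"content": "Read a tutorial", "status": "pending", "priority": "high"}]}',
    "todos",
    ["content", "status", "priority"],
)
assert ok and err is None

ok, err = validate_nested_args(
    '{"todos": [{"content": "x"}]}', "todos", ["content", "status", "priority"]
)
assert not ok  # err == "'todos[0]' missing keys: ['status', 'priority']"
```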
def call_api(
    client: httpx.Client,
    host: str,
@@ -699,6 +732,15 @@ def run_scenario(
            checks["valid_arguments"] = ok
        else:
            checks["valid_arguments"] = True
        if scenario.nested_array_key and scenario.required_item_keys:
            ok, nested_err = validate_nested_args(
                parsed.tool_call["arguments"],
                scenario.nested_array_key,
                scenario.required_item_keys,
            )
            checks["valid_nested_structure"] = ok
            if not ok:
                args_err = nested_err
    else:
        checks["correct_function"] = False
        checks["valid_arguments"] = False
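For the nested-schema scenario added below, a passing run would therefore populate the checks roughly as follows (an illustration, not captured harness output):

```python
# Illustrative shape of the checks dict for a passing nested-schema scenario.
checks = {
    "correct_function": True,        # expected_function matched
    "valid_arguments": True,         # required_arg_keys were present
    "valid_nested_structure": True,  # every todos[i] had content/status/priority
}
```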
@@ -39,6 +39,30 @@ description = "Product category to filter by"
type = "number"
description = "Maximum price in USD"

[tools.create_todos]
description = "Create a structured todo list"
required = ["todos"]

[tools.create_todos.properties.todos]
type = "array"
description = "List of todo items"

[tools.create_todos.properties.todos.items]
type = "object"
required = ["content", "status", "priority"]

[tools.create_todos.properties.todos.items.properties.content]
type = "string"
description = "The todo item text"

[tools.create_todos.properties.todos.items.properties.status]
type = "string"
description = "Status: pending, in_progress, or completed"

[tools.create_todos.properties.todos.items.properties.priority]
type = "string"
description = "Priority: low, normal, or high"

# -- Should call a tool --

[[scenarios]]
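Assuming the harness translates these TOML tables into standard JSON-Schema tool definitions (the converter itself is not part of this diff), `create_todos` would render roughly as the following sketch; only the nested schema mirrors the TOML above, the wrapper keys are an assumption:

```python
# Rough JSON-Schema rendering of [tools.create_todos] (wrapper format assumed).
create_todos = {
    "name": "create_todos",
    "description": "Create a structured todo list",
    "parameters": {
        "type": "object",
        "required": ["todos"],
        "properties": {
            "todos": {
                "type": "array",
                "description": "List of todo items",
                "items": {
                    "type": "object",
                    "required": ["content", "status", "priority"],
                    "properties": {
                        "content": {"type": "string"},
                        "status": {"type": "string"},
                        "priority": {"type": "string"},
                    },
                },
            }
        },
    },
}
```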
@@ -219,6 +243,48 @@ role = "tool"
tool_call_id = "call_4"
content = '{"temperature": "18C", "condition": "cloudy"}'

# -- Nested object schema (regression for lossy chat template rendering) --

[[scenarios]]
name = "nested_schema_tool_call"
description = "Tool call with nested object array schema -> create_todos"
expect_tool_call = true
expected_function = "create_todos"
required_arg_keys = ["todos"]
nested_array_key = "todos"
required_item_keys = ["content", "status", "priority"]
tools = ["create_todos"]

[[scenarios.messages]]
role = "user"
content = "Create a todo list with 3 items to learn Python"

# -- Tool name integrity (regression for harmony token leaking into name) --

[tools.glob]
description = "Search for files matching a glob pattern in the codebase"
required = ["pattern"]

[tools.glob.properties.pattern]
type = "string"
description = "The glob pattern to match files against, e.g. '**/*.py'"

[tools.glob.properties.path]
type = "string"
description = "The directory to search in"

[[scenarios]]
name = "tool_name_integrity"
description = "Tool name must not contain harmony tokens like <|channel|>"
expect_tool_call = true
expected_function = "glob"
required_arg_keys = ["pattern"]
tools = ["glob"]

[[scenarios.messages]]
role = "user"
content = "Find all Python files in the src directory"

# -- Should NOT call a tool --

[[scenarios]]
@@ -14,6 +14,7 @@
    totalTokens,
    thinkingEnabled as thinkingEnabledStore,
    setConversationThinking,
    stopGeneration,
  } from "$lib/stores/app.svelte";
  import ChatAttachments from "./ChatAttachments.svelte";
  import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -653,86 +654,92 @@
          style="min-height: 28px; max-height: 150px;"
        ></textarea>

        <button
          type="submit"
          disabled={!canSend || loading || isEditOnlyWithoutImage}
          class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
          {!canSend || loading || isEditOnlyWithoutImage
            ? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
            : 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
          aria-label={shouldShowEditMode
            ? "Edit image"
            : isImageModel()
              ? "Generate image"
              : "Send message"}
        >
          {#if loading}
        {#if loading}
          <button
            type="button"
            onclick={() => stopGeneration()}
            class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-exo-medium-gray hover:text-white"
            aria-label="Stop generation"
          >
            <span class="inline-flex items-center gap-1 sm:gap-2">
              <span
                class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
              ></span>
              <span class="hidden sm:inline"
                >{shouldShowEditMode
                  ? "EDITING"
                  : isImageModel()
                    ? "GENERATING"
                    : "PROCESSING"}</span
              >
              <span class="sm:hidden">...</span>
            </span>
          {:else if shouldShowEditMode}
            <span class="inline-flex items-center gap-1.5">
              <svg
                class="w-3.5 h-3.5"
                fill="none"
                class="w-3 h-3 sm:w-3.5 sm:h-3.5"
                fill="currentColor"
                viewBox="0 0 24 24"
                stroke="currentColor"
                stroke-width="2"
              >
                <path
                  stroke-linecap="round"
                  stroke-linejoin="round"
                  d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
                />
                <rect x="6" y="6" width="12" height="12" rx="1" />
              </svg>
              <span>EDIT</span>
              <span class="hidden sm:inline">Cancel</span>
            </span>
          {:else if isEditOnlyWithoutImage}
            <span class="inline-flex items-center gap-1.5">
              <svg
                class="w-3.5 h-3.5"
                fill="none"
                viewBox="0 0 24 24"
                stroke="currentColor"
                stroke-width="2"
              >
                <path
                  stroke-linecap="round"
                  stroke-linejoin="round"
                  d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
                />
              </svg>
              <span>EDIT</span>
            </span>
          {:else if isImageModel()}
            <span class="inline-flex items-center gap-1.5">
              <svg
                class="w-3.5 h-3.5"
                fill="none"
                viewBox="0 0 24 24"
                stroke="currentColor"
                stroke-width="2"
              >
                <rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
                <circle cx="8.5" cy="8.5" r="1.5" />
                <polyline points="21 15 16 10 5 21" />
              </svg>
              <span>GENERATE</span>
            </span>
          {:else}
            SEND
          {/if}
        </button>
          </button>
        {:else}
          <button
            type="submit"
            disabled={!canSend || isEditOnlyWithoutImage}
            class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
            {!canSend || isEditOnlyWithoutImage
              ? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
              : 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
            aria-label={shouldShowEditMode
              ? "Edit image"
              : isImageModel()
                ? "Generate image"
                : "Send message"}
          >
            {#if shouldShowEditMode}
              <span class="inline-flex items-center gap-1.5">
                <svg
                  class="w-3.5 h-3.5"
                  fill="none"
                  viewBox="0 0 24 24"
                  stroke="currentColor"
                  stroke-width="2"
                >
                  <path
                    stroke-linecap="round"
                    stroke-linejoin="round"
                    d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
                  />
                </svg>
                <span>EDIT</span>
              </span>
            {:else if isEditOnlyWithoutImage}
              <span class="inline-flex items-center gap-1.5">
                <svg
                  class="w-3.5 h-3.5"
                  fill="none"
                  viewBox="0 0 24 24"
                  stroke="currentColor"
                  stroke-width="2"
                >
                  <path
                    stroke-linecap="round"
                    stroke-linejoin="round"
                    d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
                  />
                </svg>
                <span>EDIT</span>
              </span>
            {:else if isImageModel()}
              <span class="inline-flex items-center gap-1.5">
                <svg
                  class="w-3.5 h-3.5"
                  fill="none"
                  viewBox="0 0 24 24"
                  stroke="currentColor"
                  stroke-width="2"
                >
                  <rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
                  <circle cx="8.5" cy="8.5" r="1.5" />
                  <polyline points="21 15 16 10 5 21" />
                </svg>
                <span>GENERATE</span>
              </span>
            {:else}
              SEND
            {/if}
          </button>
        {/if}
      </div>

      <!-- Bottom accent line -->
@@ -3,16 +3,17 @@
    messages,
    currentResponse,
    isLoading,
    prefillProgress,
    deleteMessage,
    editAndRegenerate,
    regenerateLastResponse,
    regenerateFromToken,
    setEditingImage,
  } from "$lib/stores/app.svelte";
  import type { Message } from "$lib/stores/app.svelte";
  import type { MessageAttachment } from "$lib/stores/app.svelte";
  import MarkdownContent from "./MarkdownContent.svelte";
  import TokenHeatmap from "./TokenHeatmap.svelte";
  import PrefillProgressBar from "./PrefillProgressBar.svelte";
  import ImageLightbox from "./ImageLightbox.svelte";

  interface Props {
@@ -25,6 +26,7 @@
  const messageList = $derived(messages());
  const response = $derived(currentResponse());
  const loading = $derived(isLoading());
  const prefill = $derived(prefillProgress());

  // Scroll management - user controls scroll, show button when not at bottom
  const SCROLL_THRESHOLD = 100;
@@ -428,6 +430,9 @@
  {:else}
    <!-- Assistant message styling -->
    <div class="p-3 sm:p-4">
      {#if loading && isLastAssistantMessage(message.id) && prefill && !message.content}
        <PrefillProgressBar progress={prefill} class="mb-3" />
      {/if}
      {#if message.thinking && message.thinking.trim().length > 0}
        <div
          class="mb-3 rounded border border-exo-yellow/20 bg-exo-black/40"
@@ -26,7 +26,8 @@
    downloadedOnNodes = [],
  }: HuggingFaceResultItemProps = $props();

  function formatNumber(num: number): string {
  function formatNumber(num: number | undefined): string {
    if (num == null) return "0";
    if (num >= 1000000) {
      return `${(num / 1000000).toFixed(1)}M`;
    } else if (num >= 1000) {
dashboard/src/lib/components/PrefillProgressBar.svelte (new file, 52 lines)
@@ -0,0 +1,52 @@
<script lang="ts">
  import type { PrefillProgress } from "$lib/stores/app.svelte";

  interface Props {
    progress: PrefillProgress;
    class?: string;
  }

  let { progress, class: className = "" }: Props = $props();

  const percentage = $derived(
    progress.total > 0
      ? Math.round((progress.processed / progress.total) * 100)
      : 0,
  );

  function formatTokenCount(count: number | undefined): string {
    if (count == null) return "0";
    if (count >= 1000) {
      return `${(count / 1000).toFixed(1)}k`;
    }
    return count.toString();
  }
</script>

<div class="prefill-progress {className}">
  <div
    class="flex items-center justify-between text-xs text-exo-light-gray mb-1"
  >
    <span>Processing prompt</span>
    <span class="font-mono">
      {formatTokenCount(progress.processed)} / {formatTokenCount(
        progress.total,
      )} tokens
    </span>
  </div>
  <div class="h-1.5 bg-exo-black/60 rounded-full overflow-hidden">
    <div
      class="h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out"
      style="width: {percentage}%"
    ></div>
  </div>
  <div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
    {percentage}%
  </div>
</div>

<style>
  .prefill-progress {
    width: 100%;
  }
</style>
@@ -273,6 +273,11 @@ export interface TokenData {
  topLogprobs: TopLogprob[];
}

export interface PrefillProgress {
  processed: number;
  total: number;
}

export interface Message {
  id: string;
  role: "user" | "assistant" | "system";
@@ -520,6 +525,10 @@ class AppStore {
  ttftMs = $state<number | null>(null); // Time to first token in ms
  tps = $state<number | null>(null); // Tokens per second
  totalTokens = $state<number>(0); // Total tokens in current response
  prefillProgress = $state<PrefillProgress | null>(null);

  // Abort controller for stopping generation
  private currentAbortController: AbortController | null = null;

  // Topology state
  topologyData = $state<TopologyData | null>(null);
@@ -2005,6 +2014,7 @@
    reader: ReadableStreamDefaultReader<Uint8Array>,
    targetConversationId: string,
    onChunk: (parsed: T) => void,
    onEvent?: Record<string, (data: unknown) => void>,
  ): Promise<void> {
    const decoder = new TextDecoder();
    let buffer = "";
@@ -2025,6 +2035,24 @@
      const trimmed = line.trim();
      if (!trimmed) continue;

      // Handle SSE comments (": key json") for prefill progress etc.
      if (trimmed.startsWith(": ") && onEvent) {
        const comment = trimmed.slice(2);
        const spaceIdx = comment.indexOf(" ");
        if (spaceIdx > 0) {
          const key = comment.slice(0, spaceIdx);
          if (onEvent[key]) {
            try {
              const parsed = JSON.parse(comment.slice(spaceIdx + 1));
              onEvent[key](parsed);
            } catch {
              // Skip malformed JSON in comment
            }
          }
        }
        continue;
      }

      if (trimmed.startsWith("data: ")) {
        const data = trimmed.slice(6);
        if (data === "[DONE]") continue;
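The comment frames this handler consumes are one line each: a colon-space prefix, an event key, then a JSON payload. A minimal Python sketch of the same parse, with hypothetical payload values:

```python
import json

def parse_sse_comment(line: str) -> tuple[str, dict] | None:
    """Parse ': key json' comment frames; mirrors the TS handler above."""
    if not line.startswith(": "):
        return None
    key, _, payload = line[2:].partition(" ")
    if not payload:
        return None
    try:
        return key, json.loads(payload)
    except json.JSONDecodeError:
        return None  # malformed comments are skipped, as in the client

frame = ': prefill_progress {"processed_tokens": 512, "total_tokens": 2048}'
print(parse_sse_comment(frame))
# ('prefill_progress', {'processed_tokens': 512, 'total_tokens': 2048})
```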
@@ -2256,6 +2284,9 @@
    let firstTokenTime: number | null = null;
    let tokenCount = 0;

    const abortController = new AbortController();
    this.currentAbortController = abortController;

    const response = await fetch("/v1/chat/completions", {
      method: "POST",
      headers: {
@@ -2272,6 +2303,7 @@
        enable_thinking: enableThinking,
      }),
      }),
      signal: abortController.signal,
    });

    if (!response.ok) {
@@ -2309,6 +2341,11 @@
      reader,
      targetConversationId,
      (parsed) => {
        // Clear prefill progress when first token data arrives
        if (this.prefillProgress) {
          this.prefillProgress = null;
        }

        const choice = parsed.choices?.[0];
        const tokenContent = choice?.delta?.content;

@@ -2371,8 +2408,26 @@
          this.persistConversation(targetConversationId);
        }
      },
      {
        prefill_progress: (data) => {
          // TaggedModel wraps as {"PrefillProgressChunk": {...}}
          // model_dump_json() uses snake_case (by_alias defaults to False)
          const raw = data as Record<string, unknown>;
          const inner = (raw["PrefillProgressChunk"] ?? raw) as {
            processed_tokens: number;
            total_tokens: number;
          };
          this.prefillProgress = {
            processed: inner.processed_tokens,
            total: inner.total_tokens,
          };
        },
      },
    );

    // Clear prefill progress after stream ends
    this.prefillProgress = null;

    // Calculate final TPS
    if (firstTokenTime !== null && tokenCount > 1) {
      const totalGenerationTime = performance.now() - firstTokenTime;
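The unwrap in the `prefill_progress` handler tolerates both the tagged and the flat payload shape; a small sketch with hypothetical values:

```python
# Both shapes the prefill_progress handler accepts (values hypothetical).
wrapped = {"PrefillProgressChunk": {"processed_tokens": 512, "total_tokens": 2048}}
flat = {"processed_tokens": 512, "total_tokens": 2048}

def unwrap(raw: dict) -> dict:
    # Mirrors the TS expression: raw["PrefillProgressChunk"] ?? raw
    return raw.get("PrefillProgressChunk", raw)

assert unwrap(wrapped) == unwrap(flat) == flat
```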
@@ -2403,20 +2458,31 @@
      this.persistConversation(targetConversationId);
    }
    } catch (error) {
      console.error("Error sending message:", error);
      this.handleStreamingError(
        error,
        targetConversationId,
        assistantMessage.id,
        "Failed to get response",
      );
      if (error instanceof DOMException && error.name === "AbortError") {
        // User stopped generation — not an error
      } else {
        console.error("Error sending message:", error);
        this.handleStreamingError(
          error,
          targetConversationId,
          assistantMessage.id,
          "Failed to get response",
        );
      }
    } finally {
      this.currentAbortController = null;
      this.prefillProgress = null;
      this.isLoading = false;
      this.currentResponse = "";
      this.saveConversationsToStorage();
    }
  }

  stopGeneration(): void {
    this.currentAbortController?.abort();
    this.currentAbortController = null;
  }

  /**
   * Generate an image using the image generation API
   */
@@ -3043,6 +3109,7 @@ export const isLoading = () => appStore.isLoading;
export const ttftMs = () => appStore.ttftMs;
export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
export const prefillProgress = () => appStore.prefillProgress;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const runners = () => appStore.runners;
@@ -3060,6 +3127,7 @@ export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();

// Actions
export const stopGeneration = () => appStore.stopGeneration();
export const startChat = () => appStore.startChat();
export const sendMessage = (
  content: string,
@@ -932,13 +932,6 @@
    };
  }

  // Debug: Log downloads data when it changes
  $effect(() => {
    if (downloadsData && Object.keys(downloadsData).length > 0) {
      console.log("[Download Debug] Current downloads:", downloadsData);
    }
  });

  // Helper to get download status for an instance
  function getInstanceDownloadStatus(
    instanceId: string,
@@ -74,7 +74,6 @@
  perSystem =
    { config, self', inputs', pkgs, lib, system, ... }:
    let
      fenixToolchain = inputs'.fenix.packages.complete;
      # Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
      pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
    in
@@ -1,2 +0,0 @@
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
#allowed-duplicate-crates = ["hashbrown"]
@@ -27,7 +27,7 @@ networking = { workspace = true }

# interop
pyo3 = { version = "0.27.2", features = [
    # "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
    "nightly", # enables better-supported GIL integration
    # "nightly", # enables better-supported GIL integration
    "experimental-async", # async support in #[pyfunction] & #[pymethods]
    #"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
    #"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
@@ -45,11 +45,10 @@ pyo3-log = "0.13.2"

# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
pin-project = { workspace = true }

# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }
futures-lite = { workspace = true }

# utility dependencies
util = { workspace = true }
@@ -60,3 +59,4 @@ env_logger = "0.11"

# Networking
libp2p = { workspace = true, features = ["full"] }
pin-project = "1.1.10"
@@ -2,7 +2,6 @@
//!

use pin_project::pin_project;
use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
    future::Future,
@@ -26,8 +25,8 @@ where

impl<F> Future for AllowThreads<F>
where
    F: Future + Ungil,
    F::Output: Ungil,
    F: Future + Send,
    F::Output: Send,
{
    type Output = F::Output;
@@ -4,25 +4,12 @@
//!
//!

// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(tuple_trait)]
#![feature(unboxed_closures)]
// #![feature(stmt_expr_attributes)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]

extern crate core;
mod allow_threading;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
mod ident;
mod networking;

use crate::ident::ident_submodule;
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
@@ -32,14 +19,6 @@ pub(crate) mod r#const {
    pub const MPSC_CHANNEL_SIZE: usize = 1024;
}

/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
    use std::marker::Tuple;

    pub trait SendFn<Args: Tuple + Send + 'static, Output> =
        Fn<Args, Output = Output> + Send + 'static;
}

/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
    use crate::allow_threading::AllowThreads;
@@ -180,7 +159,6 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // work with maturin, where the types generate correctly, in the right folder, without
    // too many importing issues...
    ident_submodule(m)?;
    multiaddr_submodule(m)?;
    networking_submodule(m)?;

    // top-level constructs
@@ -8,8 +8,8 @@
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::ident::{PyKeypair, PyPeerId};
use crate::pyclass;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
@@ -1,8 +0,0 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!

pub mod ident;
pub mod multiaddr;
@@ -1,81 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;

/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
    /// Create a new, empty multiaddress.
    #[staticmethod]
    fn empty() -> Self {
        Self(Multiaddr::empty())
    }

    /// Create a new, empty multiaddress with the given capacity.
    #[staticmethod]
    fn with_capacity(n: usize) -> Self {
        Self(Multiaddr::with_capacity(n))
    }

    /// Parse a `Multiaddr` value from its byte slice representation.
    #[staticmethod]
    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
    }

    /// Parse a `Multiaddr` value from its string representation.
    #[staticmethod]
    fn from_string(string: String) -> PyResult<Self> {
        Ok(Self(Multiaddr::from_str(&string).pyerr()?))
    }

    /// Return the length in bytes of this multiaddress.
    fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns true if the length of this multiaddress is 0.
    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Return a copy of this [`Multiaddr`]'s byte representation.
    fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
        let bytes = self.0.to_vec();
        PyBytes::new(py, &bytes)
    }

    /// Convert a Multiaddr to a string.
    fn to_string(&self) -> String {
        self.0.to_string()
    }

    #[gen_stub(skip)]
    fn __repr__(&self) -> String {
        format!("Multiaddr({})", self.0)
    }

    #[gen_stub(skip)]
    fn __str__(&self) -> String {
        self.to_string()
    }
}

pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyMultiaddr>()?;

    Ok(())
}
@@ -22,7 +22,7 @@ delegate = { workspace = true }

# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-lite = { workspace = true }
futures-timer = { workspace = true }

# utility dependencies
@@ -1,4 +1,4 @@
use futures::stream::StreamExt as _;
use futures_lite::StreamExt;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
@@ -38,19 +38,19 @@ async fn main() {
                    println!("Publish error: {e:?}");
                }
            }
            event = swarm.select_next_some() => match event {
            event = swarm.next() => match event {
                // on gossipsub incoming
                SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
                Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
                    propagation_source: peer_id,
                    message_id: id,
                    message,
                })) => println!(
                }))) => println!(
                    "\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
                    String::from_utf8_lossy(&message.data),
                ),

                // on discovery
                SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
                Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e))) => match e {
                    discovery::Event::ConnectionEstablished {
                        peer_id, connection_id, remote_ip, remote_tcp_port
                    } => {
@@ -64,7 +64,7 @@ async fn main() {
                }

                // ignore outgoing errors: those are normal
                e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
                e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }

                // otherwise log any other event
                e => { log::info!("Other event {e:?}"); }
@@ -1,127 +0,0 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use futures::stream::StreamExt;
use libp2p::{
    gossipsub, mdns, noise,
    swarm::{NetworkBehaviour, SwarmEvent},
    tcp, yamux,
};
use std::error::Error;
use std::time::Duration;
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;

// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
    gossipsub: gossipsub::Behaviour,
    mdns: mdns::tokio::Behaviour,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let _ = tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .try_init();

    let mut swarm = libp2p::SwarmBuilder::with_new_identity()
        .with_tokio()
        .with_tcp(
            tcp::Config::default(),
            noise::Config::new,
            yamux::Config::default,
        )?
        .with_behaviour(|key| {
            // Set a custom gossipsub configuration
            let gossipsub_config = gossipsub::ConfigBuilder::default()
                .heartbeat_interval(Duration::from_secs(10))
                .validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
                .build()
                .map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.

            // build a gossipsub network behaviour
            let gossipsub = gossipsub::Behaviour::new(
                gossipsub::MessageAuthenticity::Signed(key.clone()),
                gossipsub_config,
            )?;

            let mdns =
                mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
            Ok(MyBehaviour { gossipsub, mdns })
        })?
        .build();

    println!("Running swarm with identity {}", swarm.local_peer_id());

    // Create a Gossipsub topic
    let topic = gossipsub::IdentTopic::new("test-net");
    // subscribes to our topic
    swarm.behaviour_mut().gossipsub.subscribe(&topic)?;

    // Read full lines from stdin
    let mut stdin = io::BufReader::new(io::stdin()).lines();

    // Listen on all interfaces and whatever port the OS assigns
    swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;

    println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");

    // Kick it off
    loop {
        select! {
            Ok(Some(line)) = stdin.next_line() => {
                if let Err(e) = swarm
                    .behaviour_mut().gossipsub
                    .publish(topic.clone(), line.as_bytes()) {
                    println!("Publish error: {e:?}");
                }
            }
            event = swarm.select_next_some() => match event {
                SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
                    for (peer_id, multiaddr) in list {
                        println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
                        swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
                    }
                },
                SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
                    for (peer_id, multiaddr) in list {
                        println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
                        swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
                    }
                },
                SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
                    propagation_source: peer_id,
                    message_id: id,
                    message,
                })) => println!(
                    "Got message: '{}' with id: {id} from peer: {peer_id}",
                    String::from_utf8_lossy(&message.data),
                ),
                SwarmEvent::NewListenAddr { address, .. } => {
                    println!("Local node is listening on {address}");
                }
                e => {
                    println!("Other swarm event: {:?}", e);
                }
            }
        }
    }
}
@@ -1,7 +1,7 @@
use crate::ext::MultiaddrExt;
use delegate::delegate;
use either::Either;
use futures::FutureExt;
use futures_lite::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
@@ -362,7 +362,7 @@ impl NetworkBehaviour for Behaviour {
        }

        // retry connecting to all mDNS peers periodically (fails safely if already connected)
        if self.retry_delay.poll_unpin(cx).is_ready() {
        if self.retry_delay.poll(cx).is_ready() {
            for (p, mas) in self.mdns_discovered.clone() {
                for ma in mas {
                    self.dial(p, ma)
@@ -31,7 +31,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
mod transport {
    use crate::alias;
    use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
    use futures::{AsyncRead, AsyncWrite};
    use futures_lite::{AsyncRead, AsyncWrite};
    use keccak_const::Sha3_256;
    use libp2p::core::muxing;
    use libp2p::core::transport::Boxed;
@@ -1,11 +1,10 @@
{ inputs, ... }:
{
  perSystem =
    { config, self', inputs', pkgs, lib, ... }:
    { inputs', pkgs, lib, ... }:
    let
      # Fenix nightly toolchain with all components
      fenixPkgs = inputs'.fenix.packages;
      rustToolchain = fenixPkgs.complete.withComponents [
      rustToolchain = inputs'.fenix.packages.stable.withComponents [
        "cargo"
        "rustc"
        "clippy"
@@ -1,2 +0,0 @@
[toolchain]
channel = "nightly"
@@ -19,7 +19,12 @@ from exo.shared.types.api import (
    ToolCall,
    Usage,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
    ErrorChunk,
    PrefillProgressChunk,
    TokenChunk,
    ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
@@ -123,67 +128,81 @@ def chunk_to_response(

async def generate_chat_stream(
    command_id: CommandId,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
    chunk_stream: AsyncGenerator[
        PrefillProgressChunk | ErrorChunk | ToolCallChunk | TokenChunk, None
    ],
) -> AsyncGenerator[str, None]:
    """Generate Chat Completions API streaming events from chunks."""
    last_usage: Usage | None = None

    async for chunk in chunk_stream:
        if isinstance(chunk, ErrorChunk):
            error_response = ErrorResponse(
                error=ErrorInfo(
                    message=chunk.error_message or "Internal server error",
                    type="InternalServerError",
                    code=500,
                )
            )
            yield f"data: {error_response.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"
            return
        match chunk:
            case PrefillProgressChunk():
                # Use SSE comment so third-party clients ignore it
                yield f": prefill_progress {chunk.model_dump_json()}\n\n"

        last_usage = chunk.usage or last_usage

        if isinstance(chunk, ToolCallChunk):
            tool_call_deltas = [
                ToolCall(
                    id=tool.id,
                    index=i,
                    function=tool,
                )
                for i, tool in enumerate(chunk.tool_calls)
            ]
            tool_response = ChatCompletionResponse(
                id=command_id,
                created=int(time.time()),
                model=chunk.model,
                choices=[
                    StreamingChoiceResponse(
                        index=0,
                        delta=ChatCompletionMessage(
                            role="assistant",
                            tool_calls=tool_call_deltas,
                        ),
                        finish_reason="tool_calls",
            case ErrorChunk():
                error_response = ErrorResponse(
                    error=ErrorInfo(
                        message=chunk.error_message or "Internal server error",
                        type="InternalServerError",
                        code=500,
                    )
                ],
                usage=last_usage,
            )
            yield f"data: {tool_response.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"
            return
                )
                yield f"data: {error_response.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"
                return

        chunk_response = chunk_to_response(chunk, command_id)
        if chunk.finish_reason is not None:
            chunk_response = chunk_response.model_copy(update={"usage": last_usage})
        yield f"data: {chunk_response.model_dump_json()}\n\n"
            case ToolCallChunk():
                last_usage = chunk.usage or last_usage

        if chunk.finish_reason is not None:
            yield "data: [DONE]\n\n"
                tool_call_deltas = [
                    ToolCall(
                        id=tool.id,
                        index=i,
                        function=tool,
                    )
                    for i, tool in enumerate(chunk.tool_calls)
                ]
                tool_response = ChatCompletionResponse(
                    id=command_id,
                    created=int(time.time()),
                    model=chunk.model,
                    choices=[
                        StreamingChoiceResponse(
                            index=0,
                            delta=ChatCompletionMessage(
                                role="assistant",
                                tool_calls=tool_call_deltas,
                            ),
                            finish_reason="tool_calls",
                        )
                    ],
                    usage=last_usage,
                )
                yield f"data: {tool_response.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"
                return

            case TokenChunk():
                last_usage = chunk.usage or last_usage

                chunk_response = chunk_to_response(chunk, command_id)
                if chunk.finish_reason is not None:
                    chunk_response = chunk_response.model_copy(
                        update={"usage": last_usage}
                    )
                yield f"data: {chunk_response.model_dump_json()}\n\n"

                if chunk.finish_reason is not None:
                    yield "data: [DONE]\n\n"


async def collect_chat_response(
    command_id: CommandId,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str]:
    # This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
    # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
@@ -197,38 +216,43 @@ async def collect_chat_response(
    last_usage: Usage | None = None

    async for chunk in chunk_stream:
        if isinstance(chunk, ErrorChunk):
            error_message = chunk.error_message or "Internal server error"
            break
        match chunk:
            case PrefillProgressChunk():
                continue

        if model is None:
            model = chunk.model
            case ErrorChunk():
                error_message = chunk.error_message or "Internal server error"
                break

        last_usage = chunk.usage or last_usage

        if isinstance(chunk, TokenChunk):
            text_parts.append(chunk.text)
            if chunk.logprob is not None:
                logprobs_content.append(
                    LogprobsContentItem(
                        token=chunk.text,
                        logprob=chunk.logprob,
                        top_logprobs=chunk.top_logprobs or [],
            case TokenChunk():
                if model is None:
                    model = chunk.model
                last_usage = chunk.usage or last_usage
                text_parts.append(chunk.text)
                if chunk.logprob is not None:
                    logprobs_content.append(
                        LogprobsContentItem(
                            token=chunk.text,
                            logprob=chunk.logprob,
                            top_logprobs=chunk.top_logprobs or [],
                        )
                    )
                )
                if chunk.finish_reason is not None:
                    finish_reason = chunk.finish_reason

        if isinstance(chunk, ToolCallChunk):
            tool_calls.extend(
                ToolCall(
                    id=tool.id,
                    index=i,
                    function=tool,
            case ToolCallChunk():
                if model is None:
                    model = chunk.model
                last_usage = chunk.usage or last_usage
                tool_calls.extend(
                    ToolCall(
                        id=tool.id,
                        index=i,
                        function=tool,
                    )
                    for i, tool in enumerate(chunk.tool_calls)
                )
                for i, tool in enumerate(chunk.tool_calls)
            )

        if chunk.finish_reason is not None:
            finish_reason = chunk.finish_reason
                    finish_reason = chunk.finish_reason

    if error_message is not None:
        raise ValueError(error_message)
@@ -5,7 +5,12 @@ from collections.abc import AsyncGenerator
from typing import Any

from exo.shared.types.api import FinishReason, Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
    ErrorChunk,
    PrefillProgressChunk,
    TokenChunk,
    ToolCallChunk,
)
from exo.shared.types.claude_api import (
    ClaudeContentBlock,
    ClaudeContentBlockDeltaEvent,
@@ -160,7 +165,9 @@ def claude_request_to_text_generation(
async def collect_claude_response(
    command_id: CommandId,
    model: str,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str]:
    # This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
    # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
@@ -172,6 +179,9 @@ async def collect_claude_response(
    error_message: str | None = None

    async for chunk in chunk_stream:
        if isinstance(chunk, PrefillProgressChunk):
            continue

        if isinstance(chunk, ErrorChunk):
            error_message = chunk.error_message or "Internal server error"
            break
@@ -230,7 +240,9 @@ async def collect_claude_response(
async def generate_claude_stream(
    command_id: CommandId,
    model: str,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str, None]:
    """Generate Claude Messages API streaming events from TokenChunks."""
    # Initial message_start event
@@ -256,6 +268,9 @@ async def generate_claude_stream(
    next_block_index = 1  # text block is 0, tool blocks start at 1

    async for chunk in chunk_stream:
        if isinstance(chunk, PrefillProgressChunk):
            continue

        if isinstance(chunk, ErrorChunk):
            # Close text block and bail
            break
@@ -5,7 +5,12 @@ from itertools import count
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.api import Usage
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
PrefillProgressChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.openai_responses import (
|
||||
FunctionCallInputItem,
|
||||
@@ -26,6 +31,7 @@ from exo.shared.types.openai_responses import (
|
||||
ResponseOutputText,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponsesStreamEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseUsage,
|
||||
@@ -33,6 +39,11 @@ from exo.shared.types.openai_responses import (
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
|
||||
|
||||
def _format_sse(event: ResponsesStreamEvent) -> str:
|
||||
"""Format a streaming event as an SSE message."""
|
||||
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
def _extract_content(content: str | list[ResponseContentPart]) -> str:
|
||||
"""Extract plain text from a content field that may be a string or list of parts."""
|
||||
if isinstance(content, str):
|
||||
@@ -121,7 +132,9 @@ def responses_request_to_text_generation(
|
||||
async def collect_responses_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
chunk_stream: AsyncGenerator[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str]:
|
||||
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
@@ -134,6 +147,9 @@ async def collect_responses_response(
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, PrefillProgressChunk):
|
||||
continue
|
||||
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
@@ -189,7 +205,9 @@ async def collect_responses_response(
|
||||
async def generate_responses_stream(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
chunk_stream: AsyncGenerator[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate OpenAI Responses API streaming events from TokenChunks."""
|
||||
response_id = f"resp_{command_id}"
|
||||
@@ -207,13 +225,13 @@ async def generate_responses_stream(
|
||||
created_event = ResponseCreatedEvent(
|
||||
sequence_number=next(seq), response=initial_response
|
||||
)
|
||||
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
|
||||
yield _format_sse(created_event)
|
||||
|
||||
# response.in_progress
|
||||
in_progress_event = ResponseInProgressEvent(
|
||||
sequence_number=next(seq), response=initial_response
|
||||
)
|
||||
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
|
||||
yield _format_sse(in_progress_event)
|
||||
|
||||
# response.output_item.added
|
||||
initial_item = ResponseMessageItem(
|
||||
@@ -224,7 +242,7 @@ async def generate_responses_stream(
|
||||
item_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq), output_index=0, item=initial_item
|
||||
)
|
||||
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
|
||||
yield _format_sse(item_added)
|
||||
|
||||
# response.content_part.added
|
||||
initial_part = ResponseOutputText(text="")
|
||||
@@ -235,7 +253,7 @@ async def generate_responses_stream(
|
||||
content_index=0,
|
||||
part=initial_part,
|
||||
)
|
||||
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
|
||||
yield _format_sse(part_added)
|
||||
|
||||
    accumulated_text = ""
    function_call_items: list[ResponseFunctionCallItem] = []
@@ -243,6 +261,9 @@ async def generate_responses_stream(
    next_output_index = 1  # message item is at 0

    async for chunk in chunk_stream:
        if isinstance(chunk, PrefillProgressChunk):
            continue

        if isinstance(chunk, ErrorChunk):
            break

@@ -266,7 +287,7 @@ async def generate_responses_stream(
                output_index=next_output_index,
                item=fc_item,
            )
            yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
            yield _format_sse(fc_added)

            # response.function_call_arguments.delta
            args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -275,7 +296,7 @@ async def generate_responses_stream(
                output_index=next_output_index,
                delta=tool.arguments,
            )
            yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
            yield _format_sse(args_delta)

            # response.function_call_arguments.done
            args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -285,7 +306,7 @@ async def generate_responses_stream(
                name=tool.name,
                arguments=tool.arguments,
            )
            yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
            yield _format_sse(args_done)

            # response.output_item.done
            fc_done_item = ResponseFunctionCallItem(
@@ -300,7 +321,7 @@ async def generate_responses_stream(
                output_index=next_output_index,
                item=fc_done_item,
            )
            yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
            yield _format_sse(fc_item_done)

            function_call_items.append(fc_done_item)
            next_output_index += 1
@@ -316,7 +337,7 @@ async def generate_responses_stream(
            content_index=0,
            delta=chunk.text,
        )
        yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
        yield _format_sse(delta_event)

    # response.output_text.done
    text_done = ResponseTextDoneEvent(
@@ -326,7 +347,7 @@ async def generate_responses_stream(
        content_index=0,
        text=accumulated_text,
    )
    yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
    yield _format_sse(text_done)

    # response.content_part.done
    final_part = ResponseOutputText(text=accumulated_text)
@@ -337,7 +358,7 @@ async def generate_responses_stream(
        content_index=0,
        part=final_part,
    )
    yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
    yield _format_sse(part_done)

    # response.output_item.done
    final_message_item = ResponseMessageItem(
@@ -348,7 +369,7 @@ async def generate_responses_stream(
    item_done = ResponseOutputItemDoneEvent(
        sequence_number=next(seq), output_index=0, item=final_message_item
    )
    yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
    yield _format_sse(item_done)

    # Create usage from usage data if available
    usage = None
@@ -373,4 +394,4 @@ async def generate_responses_stream(
    completed_event = ResponseCompletedEvent(
        sequence_number=next(seq), response=final_response
    )
    yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
    yield _format_sse(completed_event)
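The hunks above replace each inline `event: ...` f-string with a single `_format_sse` helper, which is defined outside this diff. A minimal sketch of what such a helper could look like, assuming each response event model carries its SSE event name in a `type` field (that field name is an assumption; the real helper may differ):

```python
from pydantic import BaseModel


def _format_sse_sketch(event: BaseModel) -> str:
    """Render one event as a Server-Sent Events frame.

    Assumes the model exposes its event name via a `type` field
    (hypothetical; falls back to the class name otherwise).
    """
    name = getattr(event, "type", None) or event.__class__.__name__
    return f"event: {name}\ndata: {event.model_dump_json()}\n\n"
```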
@@ -107,6 +107,7 @@ from exo.shared.types.chunks import (
    ErrorChunk,
    ImageChunk,
    InputImageChunk,
    PrefillProgressChunk,
    TokenChunk,
    ToolCallChunk,
)
@@ -137,6 +138,7 @@ from exo.shared.types.events import (
    Event,
    ForwarderEvent,
    IndexedEvent,
    PrefillProgress,
    TracesMerged,
)
from exo.shared.types.memory import Memory
@@ -221,7 +223,8 @@ class API:
        )

        self._text_generation_queues: dict[
            CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
            CommandId,
            Sender[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk],
        ] = {}
        self._image_generation_queues: dict[
            CommandId, Sender[ImageChunk | ErrorChunk]
@@ -527,19 +530,23 @@ class API:

    async def _token_chunk_stream(
        self, command_id: CommandId
    ) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
    ) -> AsyncGenerator[
        TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None
    ]:
        """Yield chunks for a given command until completion.

        This is the internal low-level stream used by all API adapters.
        """
        try:
            self._text_generation_queues[command_id], recv = channel[
                ErrorChunk | ToolCallChunk | TokenChunk
                TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk
            ]()

            with recv as token_chunks:
                async for chunk in token_chunks:
                    yield chunk
                    if isinstance(chunk, PrefillProgressChunk):
                        continue
                    if chunk.finish_reason is not None:
                        break

@@ -566,6 +573,9 @@ class API:
        stats: GenerationStats | None = None

        async for chunk in self._token_chunk_stream(command_id):
            if isinstance(chunk, PrefillProgressChunk):
                continue

            if chunk.finish_reason == "error":
                raise HTTPException(
                    status_code=500,
@@ -1446,6 +1456,21 @@ class API:
            except BrokenResourceError:
                self._text_generation_queues.pop(event.command_id, None)

        elif isinstance(event, PrefillProgress):
            if queue := self._text_generation_queues.get(
                event.command_id, None
            ):
                try:
                    await queue.send(
                        PrefillProgressChunk(
                            model=event.model,
                            processed_tokens=event.processed_tokens,
                            total_tokens=event.total_tokens,
                        )
                    )
                except BrokenResourceError:
                    self._text_generation_queues.pop(event.command_id, None)

        if isinstance(event, TracesMerged):
            self._save_merged_trace(event)

@@ -36,6 +36,8 @@ from exo.shared.types.events import (
    IndexedEvent,
    InputChunkReceived,
    InstanceDeleted,
    JacclSideChannelData,
    JacclSideChannelGathered,
    NodeGatheredInfo,
    NodeTimedOut,
    TaskCreated,
@@ -60,6 +62,7 @@ from exo.shared.types.tasks import (
    TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer

@@ -94,6 +97,7 @@ class Master:
        self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
        self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
        self._expected_ranks: dict[TaskId, set[int]] = {}
        self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {}

    async def run(self):
        logger.info("Starting Master")
@@ -407,6 +411,11 @@ class Master:
            self._event_log.append(event)
            await self._send_event(indexed)

            # After broadcasting JacclSideChannelData, accumulate and
            # emit gathered result when all runners have contributed.
            if isinstance(event, JacclSideChannelData):
                await self._handle_jaccl_side_channel(event)

    async def _loopback_processor(self) -> None:
        # this would ideally not be necessary.
        # this is WAY less hacky than how I was working around this before
@@ -460,3 +469,42 @@ class Master:
            del self._pending_traces[task_id]
        if task_id in self._expected_ranks:
            del self._expected_ranks[task_id]

    async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None:
        """Accumulate SideChannel contributions; when all runners for an instance
        have submitted for the same sequence, emit JacclSideChannelGathered."""
        iid = event.instance_id
        seq = event.sequence

        if iid not in self._jaccl_pending:
            self._jaccl_pending[iid] = {}
        if seq not in self._jaccl_pending[iid]:
            self._jaccl_pending[iid][seq] = {}
        self._jaccl_pending[iid][seq][event.runner_id] = event.data

        instance = self.state.instances.get(iid)
        if instance is None:
            logger.warning(f"JacclSideChannelData for unknown instance {iid}")
            return

        expected_runners = set(instance.shard_assignments.runner_to_shard.keys())
        submitted = set(self._jaccl_pending[iid][seq].keys())

        logger.info(
            f"JACCL side channel: instance={iid} seq={seq} "
            f"submitted={len(submitted)}/{len(expected_runners)}"
        )

        if submitted >= expected_runners:
            gathered = dict(self._jaccl_pending[iid][seq])
            del self._jaccl_pending[iid][seq]
            if not self._jaccl_pending[iid]:
                del self._jaccl_pending[iid]

            await self.event_sender.send(
                JacclSideChannelGathered(
                    instance_id=iid,
                    sequence=seq,
                    gathered_data=gathered,
                )
            )
@@ -12,9 +12,12 @@ from exo.shared.types.events import (
    InputChunkReceived,
    InstanceCreated,
    InstanceDeleted,
    JacclSideChannelData,
    JacclSideChannelGathered,
    NodeDownloadProgress,
    NodeGatheredInfo,
    NodeTimedOut,
    PrefillProgress,
    RunnerDeleted,
    RunnerStatusUpdated,
    TaskAcknowledged,
@@ -64,8 +67,11 @@ def event_apply(event: Event, state: State) -> State:
            | ChunkGenerated()
            | TaskAcknowledged()
            | InputChunkReceived()
            | PrefillProgress()
            | TracesCollected()
            | TracesMerged()
            | JacclSideChannelData()
            | JacclSideChannelGathered()
        ):  # Pass-through events that don't modify state
            return state
        case InstanceCreated():

@@ -76,4 +76,13 @@ class InputImageChunk(BaseChunk):
            yield name, value


GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
class PrefillProgressChunk(BaseChunk):
    """Data class for prefill progress events during streaming."""

    processed_tokens: int
    total_tokens: int


GenerationChunk = (
    TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk | PrefillProgressChunk
)

@@ -1,11 +1,13 @@
import base64
from collections.abc import Mapping
from datetime import datetime
from typing import final
from typing import Annotated, final

from pydantic import Field
from pydantic import BeforeValidator, Field, PlainSerializer

from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -14,6 +16,28 @@ from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel


def _decode_base64_bytes(v: bytes | str) -> bytes:
    if isinstance(v, bytes):
        return v
    return base64.b64decode(v)


def _encode_base64_bytes(v: bytes) -> str:
    return base64.b64encode(v).decode("ascii")


Base64Bytes = Annotated[
    bytes,
    BeforeValidator(_decode_base64_bytes),
    PlainSerializer(_encode_base64_bytes, return_type=str),
]
"""bytes that serialize to/from base64 strings in JSON.

Needed because TaggedModel's wrap validator converts JSON→Python validation
context, which breaks strict-mode bytes deserialization from JSON strings.
"""


class EventId(Id):
    """
    Newtype around `ID`
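A quick round-trip shows what the annotation buys: a bytes field dumps to a base64 string and validates back from one. A self-contained sketch using a throwaway `pydantic.BaseModel` (illustrative; the real events below hang off `TaggedModel` instead):

```python
import base64
from typing import Annotated

from pydantic import BaseModel, BeforeValidator, PlainSerializer


def _decode(v: bytes | str) -> bytes:
    return v if isinstance(v, bytes) else base64.b64decode(v)


Base64Bytes = Annotated[
    bytes,
    BeforeValidator(_decode),
    PlainSerializer(lambda v: base64.b64encode(v).decode("ascii"), return_type=str),
]


class Blob(BaseModel):
    data: Base64Bytes


blob = Blob(data=b"\x00\x01")
assert blob.model_dump_json() == '{"data":"AAE="}'
assert Blob.model_validate_json('{"data":"AAE="}').data == b"\x00\x01"
```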
@@ -102,6 +126,13 @@ class InputChunkReceived(BaseEvent):
    chunk: InputImageChunk


class PrefillProgress(BaseEvent):
    command_id: CommandId
    model: ModelId
    processed_tokens: int
    total_tokens: int


class TopologyEdgeCreated(BaseEvent):
    conn: Connection

@@ -132,6 +163,25 @@ class TracesMerged(BaseEvent):
    traces: list[TraceEventData]


@final
class JacclSideChannelData(BaseEvent):
    """A runner's local contribution to a JACCL SideChannel all_gather round."""

    instance_id: InstanceId
    runner_id: RunnerId
    sequence: int
    data: Base64Bytes


@final
class JacclSideChannelGathered(BaseEvent):
    """Gathered result of a JACCL SideChannel all_gather round."""

    instance_id: InstanceId
    sequence: int
    gathered_data: Mapping[RunnerId, Base64Bytes]


Event = (
    TestEvent
    | TaskCreated
@@ -148,10 +198,13 @@ Event = (
    | NodeDownloadProgress
    | ChunkGenerated
    | InputChunkReceived
    | PrefillProgress
    | TopologyEdgeCreated
    | TopologyEdgeDeleted
    | TracesCollected
    | TracesMerged
    | JacclSideChannelData
    | JacclSideChannelGathered
)

@@ -67,3 +67,8 @@ class ToolCallResponse(BaseRunnerResponse):

class FinishedResponse(BaseRunnerResponse):
    pass


class PrefillProgressResponse(BaseRunnerResponse):
    processed_tokens: int
    total_tokens: int
@@ -51,6 +51,10 @@ generation_stream = mx.new_stream(mx.default_device())
_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5


class PrefillCancelled(BaseException):
    """Raised when prefill is cancelled via the progress callback."""


def prefill(
    model: Model,
    tokenizer: TokenizerWrapper,
@@ -58,6 +62,7 @@ def prefill(
    prompt_tokens: mx.array,
    cache: KVCacheType,
    group: mx.distributed.Group | None,
    on_prefill_progress: Callable[[int, int], None] | None,
) -> tuple[float, int, list[CacheSnapshot]]:
    """Prefill the KV cache with prompt tokens.

@@ -65,7 +70,7 @@ def prefill(
    then trims off the extra generated token.

    Returns:
        tokens_per_sec
        (tokens_per_sec, num_tokens, snapshots)
    """
    num_tokens = len(prompt_tokens)
    if num_tokens == 0:
@@ -76,6 +81,7 @@ def prefill(
    has_ssm = has_non_kv_caches(cache)
    snapshots: list[CacheSnapshot] = []

    # TODO(evan): kill the callbacks/runner refactor
    def progress_callback(processed: int, total: int) -> None:
        elapsed = time.perf_counter() - start_time
        tok_per_sec = processed / elapsed if elapsed > 0 else 0
@@ -85,6 +91,9 @@ def prefill(
        if has_ssm:
            snapshots.append(snapshot_ssm_states(cache))

        if on_prefill_progress is not None:
            on_prefill_progress(processed, total)

    set_pipeline_prefill(model, is_prefill=True)

    mx_barrier(group)
@@ -92,19 +101,23 @@ def prefill(

    # Use max_tokens=1 because max_tokens=0 does not work.
    # We just throw away the generated token - we only care about filling the cache
    for _ in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt_tokens,
        max_tokens=1,
        sampler=sampler,
        prompt_cache=cache,
        prefill_step_size=8192,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
        prompt_progress_callback=progress_callback,
    ):
        break  # Stop after first iteration - cache is now filled
    try:
        for _ in stream_generate(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt_tokens,
            max_tokens=1,
            sampler=sampler,
            prompt_cache=cache,
            prefill_step_size=4096,
            kv_group_size=KV_GROUP_SIZE,
            kv_bits=KV_BITS,
            prompt_progress_callback=progress_callback,
        ):
            break  # Stop after first iteration - cache is now filled
    except PrefillCancelled:
        set_pipeline_prefill(model, is_prefill=False)
        raise

    set_pipeline_prefill(model, is_prefill=False)

@@ -257,6 +270,7 @@ def mlx_generate(
    prompt: str,
    kv_prefix_cache: KVPrefixCache | None,
    group: mx.distributed.Group | None,
    on_prefill_progress: Callable[[int, int], None] | None = None,
) -> Generator[GenerationResponse]:
    # Ensure that generation stats only contains peak memory for this generation
    mx.reset_peak_memory()
@@ -311,7 +325,13 @@ def mlx_generate(

    # Prefill cache with all tokens except the last one
    prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
        model, tokenizer, sampler, prompt_tokens[:-1], caches, group
        model,
        tokenizer,
        sampler,
        prompt_tokens[:-1],
        caches,
        group,
        on_prefill_progress,
    )
    cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None

@@ -1,5 +1,6 @@
import json
import os
import re
import sys
import time
from pathlib import Path
@@ -407,6 +408,56 @@ def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
            func["arguments"] = json.loads(args)


def _collect_nested_property_names(schema: dict[str, Any]) -> set[str]:
    names: set[str] = set()
    properties: dict[str, Any] = schema.get("properties", {})  # type: ignore[reportAny]
    for prop_spec in properties.values():  # pyright: ignore[reportAny]
        if not isinstance(prop_spec, dict):
            continue
        if prop_spec.get("type") == "array":  # type: ignore[reportAny]
            items: dict[str, Any] | None = prop_spec.get("items")  # type: ignore[reportAny]
            if isinstance(items, dict) and items.get("type") == "object":  # type: ignore[reportAny]
                inner_props: dict[str, Any] = items.get("properties", {})  # type: ignore[reportAny]
                for k in inner_props:  # pyright: ignore[reportUnknownVariableType]
                    names.add(str(k))  # pyright: ignore[reportUnknownArgumentType]
                names.update(_collect_nested_property_names(items))  # pyright: ignore[reportUnknownArgumentType]
    return names


def _schemas_lost_in_prompt(prompt: str, tools: list[dict[str, Any]]) -> bool:
    """Return True if nested property names from any tool schema are absent."""
    for tool in tools:
        fn: dict[str, Any] = tool.get("function", {})  # type: ignore
        params: dict[str, Any] = fn.get("parameters", {})  # type: ignore
        nested = _collect_nested_property_names(params)
        if nested and not all(name in prompt for name in nested):
            return True
    return False


_LOSSY_TEMPLATE_PATTERN = re.compile(
    r"""inner_type\s*==\s*["']object \| object["']\s*or\s*inner_type\|length\s*>\s*\d+""",
)


def _patch_lossy_chat_template(template: str) -> str | None:
    """Patch chat templates that collapse nested object schemas to ``any[]``.

    Some templates (e.g., GPT-OSS) have a guard like::

        inner_type == "object | object" or inner_type|length > 50

    The length check silently drops complex array-of-object schemas.
    We remove the length guard, keeping only the object-union check.
    Returns the patched template, or *None* if no patch was needed.
    """
    patched, n = _LOSSY_TEMPLATE_PATTERN.subn(
        lambda m: m.group(0).split(" or ")[0],  # keep only the object-union check
        template,
    )
    return patched if n > 0 else None


def apply_chat_template(
    tokenizer: TokenizerWrapper,
    task_params: TextGenerationTaskParams,
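To make the traversal concrete: given a hypothetical tool schema with one array-of-object parameter, `_collect_nested_property_names` gathers only the property names nested inside the array items, which is exactly what the lossy templates tend to drop:

```python
# Hypothetical schema, for illustration only.
schema = {
    "type": "object",
    "properties": {
        "line_items": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "sku": {"type": "string"},
                    "qty": {"type": "integer"},
                },
            },
        },
        "note": {"type": "string"},
    },
}

# _collect_nested_property_names(schema) == {"sku", "qty"}; the top-level
# "note" is ignored. _schemas_lost_in_prompt then flags any rendered prompt
# that fails to mention "sku" or "qty".
```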
@@ -453,14 +504,28 @@ def apply_chat_template(
        extra_kwargs["enable_thinking"] = task_params.enable_thinking
        extra_kwargs["thinking"] = task_params.enable_thinking

    patched_template: str | None = None
    if task_params.tools:
        original_template: str | None = getattr(tokenizer, "chat_template", None)
        if isinstance(original_template, str):
            patched_template = _patch_lossy_chat_template(original_template)
            if patched_template is not None:
                logger.info(
                    "Patched lossy chat template (removed inner_type length guard)"
                )

    prompt: str = tokenizer.apply_chat_template(
        formatted_messages,
        tokenize=False,
        add_generation_prompt=True,
        tools=task_params.tools,
        **({"chat_template": patched_template} if patched_template is not None else {}),
        **extra_kwargs,
    )

    if task_params.tools and _schemas_lost_in_prompt(prompt, task_params.tools):
        logger.warning("Chat template lost nested tool schemas even after patching")

    if partial_assistant_content:
        prompt += partial_assistant_content

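The patch is easiest to see on a synthetic template fragment (illustrative, not the actual GPT-OSS template); the length guard disappears while the object-union check survives:

```python
import re

# Same pattern as _LOSSY_TEMPLATE_PATTERN, repeated so the snippet stands alone.
pattern = re.compile(
    r"""inner_type\s*==\s*["']object \| object["']\s*or\s*inner_type\|length\s*>\s*\d+""",
)

fragment = (
    '{%- if inner_type == "object | object" or inner_type|length > 50 %}'
    "any[]{%- endif %}"
)
patched, n = pattern.subn(lambda m: m.group(0).split(" or ")[0], fragment)

assert n == 1
assert patched == '{%- if inner_type == "object | object" %}any[]{%- endif %}'
```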
@@ -578,6 +643,11 @@ def mlx_cleanup(


def mx_any(bool_: bool, group: Group | None) -> bool:
    """Synchronize a boolean across all distributed nodes.

    Returns True if any node has bool_=True. Uses all_sum so every
    node participates in the collective — preventing GPU deadlocks.
    """
    if group is None:
        return bool_
    num_true = mx.distributed.all_sum(
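The hunk is cut off mid-call by the diff context window; a sketch of how the `all_sum`-based boolean OR described in the docstring can be completed (an assumption about the remainder of the function, not the actual code):

```python
import mlx.core as mx


def mx_any_sketch(bool_: bool, group: mx.distributed.Group | None) -> bool:
    """All-reduce a local flag: True if any rank passed True."""
    if group is None:
        return bool_
    # Each rank contributes 0 or 1; every rank joins the collective, so
    # nobody deadlocks, and a nonzero sum means at least one True.
    num_true = mx.distributed.all_sum(
        mx.array([1 if bool_ else 0], dtype=mx.int32), group=group
    )
    return int(num_true.item()) > 0
```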
@@ -24,6 +24,7 @@ from exo.shared.types.events import (
    ForwarderEvent,
    IndexedEvent,
    InputChunkReceived,
    JacclSideChannelGathered,
    NodeGatheredInfo,
    TaskCreated,
    TaskStatusUpdated,
@@ -159,6 +160,15 @@ class Worker:
                for idx, event in indexed_events:
                    self.state = apply(self.state, IndexedEvent(idx=idx, event=event))

                    # Dispatch JACCL gathered events to the relevant RunnerSupervisor
                    if isinstance(event, JacclSideChannelGathered):
                        for runner in self.runners.values():
                            if (
                                runner.bound_instance.instance.instance_id
                                == event.instance_id
                            ):
                                runner.notify_gathered(event)

                    # Buffer input image chunks for image editing
                    if isinstance(event, InputChunkReceived):
                        cmd_id = event.command_id
@@ -241,6 +251,11 @@ class Worker:
                        cancelled_task_id=cancelled_task_id, runner_id=runner_id
                    ):
                        await self.runners[runner_id].cancel_task(cancelled_task_id)
                        await self.event_sender.send(
                            TaskStatusUpdated(
                                task_id=task.task_id, task_status=TaskStatus.Complete
                            )
                        )
                    case ImageEdits() if task.task_params.total_input_chunks > 0:
                        # Assemble image from chunks and inject into task
                        cmd_id = task.command_id

@@ -17,6 +17,7 @@ def entrypoint(
    task_receiver: MpReceiver[Task],
    cancel_receiver: MpReceiver[TaskId],
    _logger: "loguru.Logger",
    pipe_fifo_paths: tuple[str, str] | None = None,
) -> None:
    fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
    if fast_synch_override == "on" or (
@@ -30,6 +31,16 @@ def entrypoint(
    else:
        os.environ["MLX_METAL_FAST_SYNCH"] = "0"

    # Open JACCL FIFOs by path and set env vars for C++ SideChannel.
    # Named pipes (FIFOs) work across multiprocessing spawn (macOS default).
    if pipe_fifo_paths is not None:
        fifo_c2p, fifo_p2c = pipe_fifo_paths
        # C++ reads gathered data from p2c (PIPE_IN), writes local data to c2p (PIPE_OUT)
        pipe_in_fd = os.open(fifo_p2c, os.O_RDONLY)
        pipe_out_fd = os.open(fifo_c2p, os.O_WRONLY)
        os.environ["MLX_JACCL_PIPE_IN"] = str(pipe_in_fd)
        os.environ["MLX_JACCL_PIPE_OUT"] = str(pipe_out_fd)

    global logger
    logger = _logger

@@ -26,6 +26,7 @@ from exo.shared.types.common import CommandId
from exo.shared.types.events import (
    ChunkGenerated,
    Event,
    PrefillProgress,
    RunnerStatusUpdated,
    TaskAcknowledged,
    TaskStatusUpdated,
@@ -81,7 +82,11 @@ from exo.worker.engines.image import (
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.generator.generate import (
    PrefillCancelled,
    mlx_generate,
    warmup_inference,
)
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
    detect_thinking_prompt_suffix,
@@ -298,6 +303,32 @@ def main(
    assert tokenizer
    assert check_for_cancel_every

    # Define callback to send prefill progress events
    # and check for cancellation between prefill chunks.
    # TODO(evan): kill the callbacks/runner refactor
    # Specifically the part that this is literally duplicated code.
    def on_prefill_progress(
        processed: int,
        total: int,
        _task_id: TaskId = task.task_id,
        _group: mx.distributed.Group | None = group,
    ) -> None:
        if device_rank == 0:
            event_sender.send(
                PrefillProgress(
                    command_id=command_id,
                    model=shard_metadata.model_card.model_id,
                    processed_tokens=processed,
                    total_tokens=total,
                )
            )
        cancelled_tasks.update(cancel_receiver.collect())
        want_to_cancel = (_task_id in cancelled_tasks) or (
            TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
        )
        if mx_any(want_to_cancel, _group):
            raise PrefillCancelled()

    try:
        _check_for_debug_prompts(task_params)

@@ -311,6 +342,7 @@ def main(
            task=task_params,
            prompt=prompt,
            kv_prefix_cache=kv_prefix_cache,
            on_prefill_progress=on_prefill_progress,
            group=group,
        )

@@ -392,6 +424,8 @@ def main(
            )
        )

    except PrefillCancelled:
        logger.info(f"Prefill cancelled for task {task.task_id}")
    # can we make this more explicit?
    except Exception as e:
        if device_rank == 0:
@@ -599,11 +633,21 @@ def parse_gpt_oss(
    ch = stream.current_channel
    recipient = stream.current_recipient

    # Debug: log every token with state
    logger.debug(
        f"parse_gpt_oss token={response.token} text={response.text!r} "
        f"recipient={recipient!r} ch={ch!r} delta={delta!r} "
        f"state={stream.state} current_tool={current_tool_name!r}"
    )

    if recipient != current_tool_name:
        if current_tool_name is not None:
            prefix = "functions."
            if current_tool_name.startswith(prefix):
                current_tool_name = current_tool_name[len(prefix) :]
            logger.info(
                f"parse_gpt_oss yielding tool call: name={current_tool_name!r}"
            )
            yield ToolCallResponse(
                tool_calls=[
                    ToolCallItem(

@@ -1,6 +1,10 @@
import contextlib
import os
import signal
import struct
import tempfile
from dataclasses import dataclass, field
from functools import partial
from multiprocessing import Process
from typing import Self

@@ -14,12 +18,14 @@ from loguru import logger

from exo.shared.types.events import (
    Event,
    JacclSideChannelData,
    JacclSideChannelGathered,
    RunnerStatusUpdated,
    TaskAcknowledged,
    TaskStatusUpdated,
)
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import (
    RunnerConnecting,
    RunnerFailed,
@@ -34,6 +40,26 @@ from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.worker.runner.bootstrap import entrypoint


def _pipe_read_exact(fd: int, n: int) -> bytes | None:
    """Read exactly n bytes from a file descriptor. Returns None on EOF."""
    data = b""
    while len(data) < n:
        chunk = os.read(fd, n - len(data))
        if not chunk:
            return None
        data += chunk
    return data


def _pipe_write_all(fd: int, data: bytes) -> None:
    """Write all bytes to a file descriptor."""
    view = memoryview(data)
    while view:
        written = os.write(fd, view)
        view = view[written:]


PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5

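These helpers implement a simple length-prefixed frame over the FIFOs: `[uint32 little-endian size][size bytes]` in both directions. A sketch of one round as seen from the runner side of the pipe (hypothetical; in the real system the C++ SideChannel does this, the snippet only illustrates the wire format):

```python
import os
import struct


def runner_side_round(pipe_out_fd: int, pipe_in_fd: int, local: bytes) -> bytes:
    """One all_gather round over the relay pipes, runner's point of view.

    EOF handling is omitted for brevity.
    """
    # Send this rank's contribution: length prefix, then payload.
    os.write(pipe_out_fd, struct.pack("<I", len(local)) + local)

    # Block until the supervisor writes back the concatenated gather result.
    header = b""
    while len(header) < 4:
        header += os.read(pipe_in_fd, 4 - len(header))
    (size,) = struct.unpack("<I", header)
    gathered = b""
    while len(gathered) < size:
        gathered += os.read(pipe_in_fd, size - len(gathered))
    return gathered
```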
@@ -48,10 +74,19 @@ class RunnerSupervisor:
    _task_sender: MpSender[Task]
    _event_sender: Sender[Event]
    _cancel_sender: MpSender[TaskId]
    _pipe_read_fd: int | None = None  # Python reads runner's pipe output
    _pipe_write_fd: int | None = None  # Python writes gathered data to runner
    _child_pipe_fds: tuple[int, int] | None = None  # fds to close after fork
    _fifo_dir: str | None = None  # Temp dir for FIFO files (for cleanup)
    _fifo_c2p: str | None = None  # FIFO path: C++ writes → Python reads
    _fifo_p2c: str | None = None  # FIFO path: Python writes → C++ reads
    status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
    pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
    completed: set[TaskId] = field(default_factory=set, init=False)
    cancelled: set[TaskId] = field(default_factory=set, init=False)
    _gathered_waiters: dict[
        int, tuple[anyio.Event, JacclSideChannelGathered | None]
    ] = field(default_factory=dict, init=False)

    @classmethod
    def create(
@@ -65,6 +100,23 @@ class RunnerSupervisor:
        task_sender, task_recv = mp_channel[Task]()
        cancel_sender, cancel_recv = mp_channel[TaskId]()

        # For MlxJaccl instances, create named pipes (FIFOs) for SideChannel relay.
        # Named pipes work across multiprocessing.Process spawn (macOS default).
        # FIFO c2p: C++ writes local data → Python reads it
        # FIFO p2c: Python writes gathered data → C++ reads it
        fifo_dir: str | None = None
        fifo_c2p: str | None = None
        fifo_p2c: str | None = None
        pipe_fifo_paths: tuple[str, str] | None = None

        if isinstance(bound_instance.instance, MlxJacclInstance):
            fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_")
            fifo_c2p = os.path.join(fifo_dir, "c2p")  # C++ → Python
            fifo_p2c = os.path.join(fifo_dir, "p2c")  # Python → C++
            os.mkfifo(fifo_c2p)
            os.mkfifo(fifo_p2c)
            pipe_fifo_paths = (fifo_c2p, fifo_p2c)

        runner_process = Process(
            target=entrypoint,
            args=(
@@ -73,6 +125,7 @@ class RunnerSupervisor:
                task_recv,
                cancel_recv,
                logger,
                pipe_fifo_paths,
            ),
            daemon=True,
        )
@@ -88,22 +141,58 @@ class RunnerSupervisor:
            _task_sender=task_sender,
            _cancel_sender=cancel_sender,
            _event_sender=event_sender,
            _fifo_dir=fifo_dir,
            _fifo_c2p=fifo_c2p,
            _fifo_p2c=fifo_p2c,
        )

        return self

    async def run(self):
        self.runner_process.start()
        await self._forward_events()

        if self._fifo_c2p is not None and self._fifo_p2c is not None:
            # Open FIFOs from parent side. These block until child opens the other end,
            # so we run them in threads concurrently to avoid deadlock.
            fifo_c2p = self._fifo_c2p
            fifo_p2c = self._fifo_p2c

            async def open_read() -> None:
                self._pipe_read_fd = await to_thread.run_sync(
                    partial(os.open, fifo_c2p, os.O_RDONLY)
                )

            async def open_write() -> None:
                self._pipe_write_fd = await to_thread.run_sync(
                    partial(os.open, fifo_p2c, os.O_WRONLY)
                )

            async with anyio.create_task_group() as open_tg:
                open_tg.start_soon(open_read)
                open_tg.start_soon(open_write)

            logger.info(
                f"JACCL pipe relay: FIFOs opened (read_fd={self._pipe_read_fd}, write_fd={self._pipe_write_fd})"
            )

            async with anyio.create_task_group() as tg:
                tg.start_soon(self._pipe_relay)
                tg.start_soon(self._forward_events)
        else:
            await self._forward_events()

    def shutdown(self):
        logger.info("Runner supervisor shutting down")
        self._ev_recv.close()
        self._task_sender.close()
        try:
            self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
            self._cancel_sender.close()
        except ClosedResourceError:
            pass
        self._event_sender.close()
        self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
        self._cancel_sender.close()
        self.runner_process.join(5)
        self._close_pipe_fds()
        self.runner_process.join(1)
        if not self.runner_process.is_alive():
            logger.info("Runner process successfully terminated")
            return
@@ -140,6 +229,7 @@ class RunnerSupervisor:
            await event.wait()

    async def cancel_task(self, task_id: TaskId):
        """Send a cancellation signal to the runner process."""
        if task_id in self.completed:
            logger.info(f"Unable to cancel {task_id} as it has been completed")
            return
@@ -181,6 +271,110 @@ class RunnerSupervisor:
        for tid in self.pending:
            self.pending[tid].set()

    def _close_pipe_fds(self) -> None:
        if self._pipe_read_fd is not None:
            with contextlib.suppress(OSError):
                os.close(self._pipe_read_fd)
            self._pipe_read_fd = None
        if self._pipe_write_fd is not None:
            with contextlib.suppress(OSError):
                os.close(self._pipe_write_fd)
            self._pipe_write_fd = None
        if self._child_pipe_fds is not None:
            for fd in self._child_pipe_fds:
                with contextlib.suppress(OSError):
                    os.close(fd)
            self._child_pipe_fds = None
        # Clean up FIFO files
        if self._fifo_c2p is not None:
            with contextlib.suppress(OSError):
                os.unlink(self._fifo_c2p)
            self._fifo_c2p = None
        if self._fifo_p2c is not None:
            with contextlib.suppress(OSError):
                os.unlink(self._fifo_p2c)
            self._fifo_p2c = None
        if self._fifo_dir is not None:
            with contextlib.suppress(OSError):
                os.rmdir(self._fifo_dir)
            self._fifo_dir = None

    async def _pipe_relay(self) -> None:
        """Relay JACCL SideChannel all_gather rounds between runner pipes and exo events."""
        assert self._pipe_read_fd is not None
        assert self._pipe_write_fd is not None
        read_fd = self._pipe_read_fd
        write_fd = self._pipe_write_fd
        sequence = 0

        try:
            while True:
                # 1. Read local data from runner: [uint32 size][size bytes]
                header = await to_thread.run_sync(partial(_pipe_read_exact, read_fd, 4))
                if header is None:
                    logger.info("JACCL pipe relay: runner closed pipe (EOF)")
                    break
                data_size: int = struct.unpack("<I", header)[0]  # pyright: ignore[reportAny]
                local_data = await to_thread.run_sync(
                    partial(_pipe_read_exact, read_fd, data_size)
                )
                if local_data is None:
                    logger.warning("JACCL pipe relay: EOF reading data payload")
                    break

                logger.info(
                    f"JACCL pipe relay: read {data_size} bytes from runner, seq={sequence}"
                )

                # 2. Emit JacclSideChannelData event
                waiter = anyio.Event()
                self._gathered_waiters[sequence] = (waiter, None)
                await self._event_sender.send(
                    JacclSideChannelData(
                        instance_id=self.bound_instance.instance.instance_id,
                        runner_id=self.bound_instance.bound_runner_id,
                        sequence=sequence,
                        data=local_data,
                    )
                )

                # 3. Wait for gathered result
                await waiter.wait()
                _, gathered_event = self._gathered_waiters.pop(sequence)
                assert gathered_event is not None

                # 4. Order gathered data by runner rank and concatenate
                instance = self.bound_instance.instance
                assert isinstance(instance, MlxJacclInstance)
                runner_order = list(instance.shard_assignments.runner_to_shard.keys())
                ordered_data = b"".join(
                    gathered_event.gathered_data[rid] for rid in runner_order
                )

                # 5. Write gathered data to runner: [uint32 total_size][total_size bytes]
                total_size = len(ordered_data)
                response = struct.pack("<I", total_size) + ordered_data
                await to_thread.run_sync(partial(_pipe_write_all, write_fd, response))

                logger.info(
                    f"JACCL pipe relay: wrote {total_size} bytes to runner, seq={sequence}"
                )
                sequence += 1
        except OSError as e:
            logger.warning(f"JACCL pipe relay: OS error: {e}")
        except Exception as e:
            logger.opt(exception=e).error("JACCL pipe relay: unexpected error")

    def notify_gathered(self, event: JacclSideChannelGathered) -> None:
        """Called by the worker when a JacclSideChannelGathered event arrives."""
        seq = event.sequence
        if seq not in self._gathered_waiters:
            logger.warning(f"JACCL: received gathered event for unknown sequence {seq}")
            return
        waiter, _ = self._gathered_waiters[seq]
        self._gathered_waiters[seq] = (waiter, event)
        waiter.set()

    def __del__(self) -> None:
        if self.runner_process.is_alive():
            logger.warning("RunnerSupervisor was not stopped cleanly.")
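One detail of step 4 above is worth spelling out: every supervisor concatenates the gathered chunks in `runner_to_shard` key order, so each rank reconstructs an identical buffer regardless of the order in which contributions reached the master. A toy illustration with hypothetical runner ids:

```python
# Hypothetical runner ids; the real keys come from shard_assignments.
runner_order = ["runner-a", "runner-b", "runner-c"]
gathered_data = {"runner-c": b"\x03", "runner-a": b"\x01", "runner-b": b"\x02"}

ordered = b"".join(gathered_data[rid] for rid in runner_order)
assert ordered == b"\x01\x02\x03"  # same byte layout on every rank
```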