Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-18 23:06:23 -05:00)

Compare commits: move-messa ... main (3 commits)

Commits: 025ed9fd82, 19bc09550d, 7cadca4f27

@@ -38,6 +38,8 @@ class Scenario:
    expected_function: str | None = None
    required_arg_keys: list[str] | None = None
    tool_result: str | None = None
    nested_array_key: str | None = None
    required_item_keys: list[str] | None = None


def load_scenarios(path: Path) -> list[Scenario]:

@@ -105,6 +107,8 @@ def load_scenarios(path: Path) -> list[Scenario]:
            expected_function=s.get("expected_function"),
            required_arg_keys=s.get("required_arg_keys"),
            tool_result=tool_result,
            nested_array_key=s.get("nested_array_key"),
            required_item_keys=s.get("required_item_keys"),
        )
    )

@@ -147,6 +151,35 @@ def validate_args(args_str: str, required_keys: list[str]) -> tuple[bool, str |
    return True, None


def validate_nested_args(
    args_str: str,
    array_key: str,
    required_item_keys: list[str],
) -> tuple[bool, str | None]:
    """Check that args[array_key] is a list of objects with required keys."""
    try:
        args = json.loads(args_str)
    except (json.JSONDecodeError, TypeError) as exc:
        return False, f"Invalid JSON: {exc}"
    if not isinstance(args, dict):
        return False, f"Expected dict, got {type(args).__name__}"
    arr = args.get(array_key)
    if not isinstance(arr, list):
        return False, f"'{array_key}' is not an array (got {type(arr).__name__})"
    if len(arr) == 0:
        return False, f"'{array_key}' is empty"
    for i, item in enumerate(arr):
        if not isinstance(item, dict):
            return (
                False,
                f"'{array_key}[{i}]' is not an object (got {type(item).__name__})",
            )
        missing = [k for k in required_item_keys if k not in item]
        if missing:
            return False, f"'{array_key}[{i}]' missing keys: {missing}"
    return True, None
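
A quick illustration of the new check (not part of the diff; the payloads below are made up): validate_nested_args accepts an arguments string whose array items carry all required keys and reports the first missing set otherwise.

# Illustrative only: exercises validate_nested_args as added above.
ok, err = validate_nested_args(
    '{"todos": [{"content": "read docs", "status": "pending", "priority": "low"}]}',
    array_key="todos",
    required_item_keys=["content", "status", "priority"],
)
assert ok and err is None

ok, err = validate_nested_args(
    '{"todos": [{"content": "read docs"}]}',
    array_key="todos",
    required_item_keys=["content", "status", "priority"],
)
assert not ok  # err == "'todos[0]' missing keys: ['status', 'priority']"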


def call_api(
    client: httpx.Client,
    host: str,

@@ -699,6 +732,15 @@ def run_scenario(
            checks["valid_arguments"] = ok
        else:
            checks["valid_arguments"] = True
        if scenario.nested_array_key and scenario.required_item_keys:
            ok, nested_err = validate_nested_args(
                parsed.tool_call["arguments"],
                scenario.nested_array_key,
                scenario.required_item_keys,
            )
            checks["valid_nested_structure"] = ok
            if not ok:
                args_err = nested_err
    else:
        checks["correct_function"] = False
        checks["valid_arguments"] = False

@@ -39,6 +39,30 @@ description = "Product category to filter by"
type = "number"
description = "Maximum price in USD"

[tools.create_todos]
description = "Create a structured todo list"
required = ["todos"]

[tools.create_todos.properties.todos]
type = "array"
description = "List of todo items"

[tools.create_todos.properties.todos.items]
type = "object"
required = ["content", "status", "priority"]

[tools.create_todos.properties.todos.items.properties.content]
type = "string"
description = "The todo item text"

[tools.create_todos.properties.todos.items.properties.status]
type = "string"
description = "Status: pending, in_progress, or completed"

[tools.create_todos.properties.todos.items.properties.priority]
type = "string"
description = "Priority: low, normal, or high"

# -- Should call a tool --

[[scenarios]]

@@ -219,6 +243,48 @@ role = "tool"
tool_call_id = "call_4"
content = '{"temperature": "18C", "condition": "cloudy"}'

# -- Nested object schema (regression for lossy chat template rendering) --

[[scenarios]]
name = "nested_schema_tool_call"
description = "Tool call with nested object array schema -> create_todos"
expect_tool_call = true
expected_function = "create_todos"
required_arg_keys = ["todos"]
nested_array_key = "todos"
required_item_keys = ["content", "status", "priority"]
tools = ["create_todos"]

[[scenarios.messages]]
role = "user"
content = "Create a todo list with 3 items to learn Python"

# -- Tool name integrity (regression for harmony token leaking into name) --

[tools.glob]
description = "Search for files matching a glob pattern in the codebase"
required = ["pattern"]

[tools.glob.properties.pattern]
type = "string"
description = "The glob pattern to match files against, e.g. '**/*.py'"

[tools.glob.properties.path]
type = "string"
description = "The directory to search in"

[[scenarios]]
name = "tool_name_integrity"
description = "Tool name must not contain harmony tokens like <|channel|>"
expect_tool_call = true
expected_function = "glob"
required_arg_keys = ["pattern"]
tools = ["glob"]

[[scenarios.messages]]
role = "user"
content = "Find all Python files in the src directory"

# -- Should NOT call a tool --

[[scenarios]]

@@ -3,16 +3,17 @@
    messages,
    currentResponse,
    isLoading,
    prefillProgress,
    deleteMessage,
    editAndRegenerate,
    regenerateLastResponse,
    regenerateFromToken,
    setEditingImage,
  } from "$lib/stores/app.svelte";
  import type { Message } from "$lib/stores/app.svelte";
  import type { MessageAttachment } from "$lib/stores/app.svelte";
  import MarkdownContent from "./MarkdownContent.svelte";
  import TokenHeatmap from "./TokenHeatmap.svelte";
  import PrefillProgressBar from "./PrefillProgressBar.svelte";
  import ImageLightbox from "./ImageLightbox.svelte";

  interface Props {

@@ -25,6 +26,7 @@
  const messageList = $derived(messages());
  const response = $derived(currentResponse());
  const loading = $derived(isLoading());
  const prefill = $derived(prefillProgress());

  // Scroll management - user controls scroll, show button when not at bottom
  const SCROLL_THRESHOLD = 100;

@@ -428,6 +430,9 @@
        {:else}
          <!-- Assistant message styling -->
          <div class="p-3 sm:p-4">
            {#if loading && isLastAssistantMessage(message.id) && prefill && !message.content}
              <PrefillProgressBar progress={prefill} class="mb-3" />
            {/if}
            {#if message.thinking && message.thinking.trim().length > 0}
              <div
                class="mb-3 rounded border border-exo-yellow/20 bg-exo-black/40"

@@ -26,7 +26,8 @@
    downloadedOnNodes = [],
  }: HuggingFaceResultItemProps = $props();

  function formatNumber(num: number): string {
  function formatNumber(num: number | undefined): string {
    if (num == null) return "0";
    if (num >= 1000000) {
      return `${(num / 1000000).toFixed(1)}M`;
    } else if (num >= 1000) {

dashboard/src/lib/components/PrefillProgressBar.svelte (new file, 52 lines)
@@ -0,0 +1,52 @@
<script lang="ts">
  import type { PrefillProgress } from "$lib/stores/app.svelte";

  interface Props {
    progress: PrefillProgress;
    class?: string;
  }

  let { progress, class: className = "" }: Props = $props();

  const percentage = $derived(
    progress.total > 0
      ? Math.round((progress.processed / progress.total) * 100)
      : 0,
  );

  function formatTokenCount(count: number | undefined): string {
    if (count == null) return "0";
    if (count >= 1000) {
      return `${(count / 1000).toFixed(1)}k`;
    }
    return count.toString();
  }
</script>

<div class="prefill-progress {className}">
  <div
    class="flex items-center justify-between text-xs text-exo-light-gray mb-1"
  >
    <span>Processing prompt</span>
    <span class="font-mono">
      {formatTokenCount(progress.processed)} / {formatTokenCount(
        progress.total,
      )} tokens
    </span>
  </div>
  <div class="h-1.5 bg-exo-black/60 rounded-full overflow-hidden">
    <div
      class="h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out"
      style="width: {percentage}%"
    ></div>
  </div>
  <div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
    {percentage}%
  </div>
</div>

<style>
  .prefill-progress {
    width: 100%;
  }
</style>

@@ -273,6 +273,11 @@ export interface TokenData {
  topLogprobs: TopLogprob[];
}

export interface PrefillProgress {
  processed: number;
  total: number;
}

export interface Message {
  id: string;
  role: "user" | "assistant" | "system";

@@ -520,6 +525,7 @@ class AppStore {
  ttftMs = $state<number | null>(null); // Time to first token in ms
  tps = $state<number | null>(null); // Tokens per second
  totalTokens = $state<number>(0); // Total tokens in current response
  prefillProgress = $state<PrefillProgress | null>(null);

  // Topology state
  topologyData = $state<TopologyData | null>(null);

@@ -2005,6 +2011,7 @@ class AppStore {
    reader: ReadableStreamDefaultReader<Uint8Array>,
    targetConversationId: string,
    onChunk: (parsed: T) => void,
    onEvent?: Record<string, (data: unknown) => void>,
  ): Promise<void> {
    const decoder = new TextDecoder();
    let buffer = "";

@@ -2025,6 +2032,24 @@
      const trimmed = line.trim();
      if (!trimmed) continue;

      // Handle SSE comments (": key json") for prefill progress etc.
      if (trimmed.startsWith(": ") && onEvent) {
        const comment = trimmed.slice(2);
        const spaceIdx = comment.indexOf(" ");
        if (spaceIdx > 0) {
          const key = comment.slice(0, spaceIdx);
          if (onEvent[key]) {
            try {
              const parsed = JSON.parse(comment.slice(spaceIdx + 1));
              onEvent[key](parsed);
            } catch {
              // Skip malformed JSON in comment
            }
          }
        }
        continue;
      }

      if (trimmed.startsWith("data: ")) {
        const data = trimmed.slice(6);
        if (data === "[DONE]") continue;

@@ -2309,6 +2334,11 @@ class AppStore {
      reader,
      targetConversationId,
      (parsed) => {
        // Clear prefill progress when first token data arrives
        if (this.prefillProgress) {
          this.prefillProgress = null;
        }

        const choice = parsed.choices?.[0];
        const tokenContent = choice?.delta?.content;

@@ -2371,8 +2401,26 @@
          this.persistConversation(targetConversationId);
        }
      },
      {
        prefill_progress: (data) => {
          // TaggedModel wraps as {"PrefillProgressChunk": {...}}
          // model_dump_json() uses snake_case (by_alias defaults to False)
          const raw = data as Record<string, unknown>;
          const inner = (raw["PrefillProgressChunk"] ?? raw) as {
            processed_tokens: number;
            total_tokens: number;
          };
          this.prefillProgress = {
            processed: inner.processed_tokens,
            total: inner.total_tokens,
          };
        },
      },
    );

    // Clear prefill progress after stream ends
    this.prefillProgress = null;

    // Calculate final TPS
    if (firstTokenTime !== null && tokenCount > 1) {
      const totalGenerationTime = performance.now() - firstTokenTime;

@@ -3043,6 +3091,7 @@ export const isLoading = () => appStore.isLoading;
export const ttftMs = () => appStore.ttftMs;
export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
export const prefillProgress = () => appStore.prefillProgress;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const runners = () => appStore.runners;

@@ -932,13 +932,6 @@
    };
  }

  // Debug: Log downloads data when it changes
  $effect(() => {
    if (downloadsData && Object.keys(downloadsData).length > 0) {
      console.log("[Download Debug] Current downloads:", downloadsData);
    }
  });

  // Helper to get download status for an instance
  function getInstanceDownloadStatus(
    instanceId: string,

@@ -123,14 +123,17 @@ class DownloadCoordinator:
            tg.start_soon(self._check_internet_connection)

    def _test_internet_connection(self) -> None:
        try:
            socket.create_connection(("1.1.1.1", 443), timeout=3).close()
            self.shard_downloader.set_internet_connection(True)
        except OSError:
            self.shard_downloader.set_internet_connection(False)
        logger.debug(
            f"Internet connectivity: {self.shard_downloader.internet_connection}"
        )
        # Try multiple endpoints since some ISPs/networks block specific IPs
        for host in ("1.1.1.1", "8.8.8.8", "1.0.0.1"):
            try:
                socket.create_connection((host, 443), timeout=3).close()
                self.shard_downloader.set_internet_connection(True)
                logger.debug(f"Internet connectivity: True (via {host})")
                return
            except OSError:
                continue
        self.shard_downloader.set_internet_connection(False)
        logger.debug("Internet connectivity: False")

    async def _check_internet_connection(self) -> None:
        first_connection = True

@@ -19,7 +19,12 @@ from exo.shared.types.api import (
    ToolCall,
    Usage,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
    ErrorChunk,
    PrefillProgressChunk,
    TokenChunk,
    ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams

@@ -123,67 +128,81 @@ def chunk_to_response(

async def generate_chat_stream(
    command_id: CommandId,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
    chunk_stream: AsyncGenerator[
        PrefillProgressChunk | ErrorChunk | ToolCallChunk | TokenChunk, None
    ],
) -> AsyncGenerator[str, None]:
    """Generate Chat Completions API streaming events from chunks."""
    last_usage: Usage | None = None

    async for chunk in chunk_stream:
        if isinstance(chunk, ErrorChunk):
            error_response = ErrorResponse(
                error=ErrorInfo(
                    message=chunk.error_message or "Internal server error",
                    type="InternalServerError",
                    code=500,
                )
            )
            yield f"data: {error_response.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"
            return
        match chunk:
            case PrefillProgressChunk():
                # Use SSE comment so third-party clients ignore it
                yield f": prefill_progress {chunk.model_dump_json()}\n\n"

        last_usage = chunk.usage or last_usage

        if isinstance(chunk, ToolCallChunk):
            tool_call_deltas = [
                ToolCall(
                    id=tool.id,
                    index=i,
                    function=tool,
                )
                for i, tool in enumerate(chunk.tool_calls)
            ]
            tool_response = ChatCompletionResponse(
                id=command_id,
                created=int(time.time()),
                model=chunk.model,
                choices=[
                    StreamingChoiceResponse(
                        index=0,
                        delta=ChatCompletionMessage(
                            role="assistant",
                            tool_calls=tool_call_deltas,
                        ),
                        finish_reason="tool_calls",
            case ErrorChunk():
                error_response = ErrorResponse(
                    error=ErrorInfo(
                        message=chunk.error_message or "Internal server error",
                        type="InternalServerError",
                        code=500,
                    )
                ],
                usage=last_usage,
            )
            yield f"data: {tool_response.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"
            return
                )
                yield f"data: {error_response.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"
                return

        chunk_response = chunk_to_response(chunk, command_id)
        if chunk.finish_reason is not None:
            chunk_response = chunk_response.model_copy(update={"usage": last_usage})
        yield f"data: {chunk_response.model_dump_json()}\n\n"
            case ToolCallChunk():
                last_usage = chunk.usage or last_usage

        if chunk.finish_reason is not None:
            yield "data: [DONE]\n\n"
                tool_call_deltas = [
                    ToolCall(
                        id=tool.id,
                        index=i,
                        function=tool,
                    )
                    for i, tool in enumerate(chunk.tool_calls)
                ]
                tool_response = ChatCompletionResponse(
                    id=command_id,
                    created=int(time.time()),
                    model=chunk.model,
                    choices=[
                        StreamingChoiceResponse(
                            index=0,
                            delta=ChatCompletionMessage(
                                role="assistant",
                                tool_calls=tool_call_deltas,
                            ),
                            finish_reason="tool_calls",
                        )
                    ],
                    usage=last_usage,
                )
                yield f"data: {tool_response.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"
                return

            case TokenChunk():
                last_usage = chunk.usage or last_usage

                chunk_response = chunk_to_response(chunk, command_id)
                if chunk.finish_reason is not None:
                    chunk_response = chunk_response.model_copy(
                        update={"usage": last_usage}
                    )
                yield f"data: {chunk_response.model_dump_json()}\n\n"

                if chunk.finish_reason is not None:
                    yield "data: [DONE]\n\n"


async def collect_chat_response(
    command_id: CommandId,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str]:
    # This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
    # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason

@@ -197,38 +216,43 @@ async def collect_chat_response(
    last_usage: Usage | None = None

    async for chunk in chunk_stream:
        if isinstance(chunk, ErrorChunk):
            error_message = chunk.error_message or "Internal server error"
            break
        match chunk:
            case PrefillProgressChunk():
                continue

        if model is None:
            model = chunk.model
            case ErrorChunk():
                error_message = chunk.error_message or "Internal server error"
                break

        last_usage = chunk.usage or last_usage

        if isinstance(chunk, TokenChunk):
            text_parts.append(chunk.text)
            if chunk.logprob is not None:
                logprobs_content.append(
                    LogprobsContentItem(
                        token=chunk.text,
                        logprob=chunk.logprob,
                        top_logprobs=chunk.top_logprobs or [],
            case TokenChunk():
                if model is None:
                    model = chunk.model
                last_usage = chunk.usage or last_usage
                text_parts.append(chunk.text)
                if chunk.logprob is not None:
                    logprobs_content.append(
                        LogprobsContentItem(
                            token=chunk.text,
                            logprob=chunk.logprob,
                            top_logprobs=chunk.top_logprobs or [],
                        )
                    )
                )
                if chunk.finish_reason is not None:
                    finish_reason = chunk.finish_reason

        if isinstance(chunk, ToolCallChunk):
            tool_calls.extend(
                ToolCall(
                    id=tool.id,
                    index=i,
                    function=tool,
            case ToolCallChunk():
                if model is None:
                    model = chunk.model
                last_usage = chunk.usage or last_usage
                tool_calls.extend(
                    ToolCall(
                        id=tool.id,
                        index=i,
                        function=tool,
                    )
                    for i, tool in enumerate(chunk.tool_calls)
                )
                for i, tool in enumerate(chunk.tool_calls)
            )

        if chunk.finish_reason is not None:
            finish_reason = chunk.finish_reason
                finish_reason = chunk.finish_reason

    if error_message is not None:
        raise ValueError(error_message)

@@ -5,7 +5,12 @@ from collections.abc import AsyncGenerator
from typing import Any

from exo.shared.types.api import FinishReason, Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
    ErrorChunk,
    PrefillProgressChunk,
    TokenChunk,
    ToolCallChunk,
)
from exo.shared.types.claude_api import (
    ClaudeContentBlock,
    ClaudeContentBlockDeltaEvent,

@@ -160,7 +165,9 @@ def claude_request_to_text_generation(
async def collect_claude_response(
    command_id: CommandId,
    model: str,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str]:
    # This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
    # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason

@@ -172,6 +179,9 @@ async def collect_claude_response(
    error_message: str | None = None

    async for chunk in chunk_stream:
        if isinstance(chunk, PrefillProgressChunk):
            continue

        if isinstance(chunk, ErrorChunk):
            error_message = chunk.error_message or "Internal server error"
            break

@@ -230,7 +240,9 @@ async def collect_claude_response(
async def generate_claude_stream(
    command_id: CommandId,
    model: str,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str, None]:
    """Generate Claude Messages API streaming events from TokenChunks."""
    # Initial message_start event

@@ -256,6 +268,9 @@ async def generate_claude_stream(
    next_block_index = 1  # text block is 0, tool blocks start at 1

    async for chunk in chunk_stream:
        if isinstance(chunk, PrefillProgressChunk):
            continue

        if isinstance(chunk, ErrorChunk):
            # Close text block and bail
            break

@@ -5,7 +5,12 @@ from itertools import count
from typing import Any

from exo.shared.types.api import Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
    ErrorChunk,
    PrefillProgressChunk,
    TokenChunk,
    ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
    FunctionCallInputItem,

@@ -121,7 +126,9 @@ def responses_request_to_text_generation(
async def collect_responses_response(
    command_id: CommandId,
    model: str,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str]:
    # This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
    # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason

@@ -134,6 +141,9 @@ async def collect_responses_response(
    error_message: str | None = None

    async for chunk in chunk_stream:
        if isinstance(chunk, PrefillProgressChunk):
            continue

        if isinstance(chunk, ErrorChunk):
            error_message = chunk.error_message or "Internal server error"
            break

@@ -189,7 +199,9 @@ async def collect_responses_response(
async def generate_responses_stream(
    command_id: CommandId,
    model: str,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str, None]:
    """Generate OpenAI Responses API streaming events from TokenChunks."""
    response_id = f"resp_{command_id}"

@@ -243,6 +255,9 @@ async def generate_responses_stream(
    next_output_index = 1  # message item is at 0

    async for chunk in chunk_stream:
        if isinstance(chunk, PrefillProgressChunk):
            continue

        if isinstance(chunk, ErrorChunk):
            break

@@ -107,6 +107,7 @@ from exo.shared.types.chunks import (
    ErrorChunk,
    ImageChunk,
    InputImageChunk,
    PrefillProgressChunk,
    TokenChunk,
    ToolCallChunk,
)

@@ -137,6 +138,7 @@ from exo.shared.types.events import (
    Event,
    ForwarderEvent,
    IndexedEvent,
    PrefillProgress,
    TracesMerged,
)
from exo.shared.types.memory import Memory

@@ -145,6 +147,7 @@ from exo.shared.types.openai_responses import (
    ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner

@@ -220,7 +223,8 @@ class API:
        )

        self._text_generation_queues: dict[
            CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
            CommandId,
            Sender[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk],
        ] = {}
        self._image_generation_queues: dict[
            CommandId, Sender[ImageChunk | ErrorChunk]

@@ -526,19 +530,23 @@ class API:

    async def _token_chunk_stream(
        self, command_id: CommandId
    ) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
    ) -> AsyncGenerator[
        TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None
    ]:
        """Yield chunks for a given command until completion.

        This is the internal low-level stream used by all API adapters.
        """
        try:
            self._text_generation_queues[command_id], recv = channel[
                ErrorChunk | ToolCallChunk | TokenChunk
                TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk
            ]()

            with recv as token_chunks:
                async for chunk in token_chunks:
                    yield chunk
                    if isinstance(chunk, PrefillProgressChunk):
                        continue
                    if chunk.finish_reason is not None:
                        break

@@ -565,6 +573,9 @@ class API:
        stats: GenerationStats | None = None

        async for chunk in self._token_chunk_stream(command_id):
            if isinstance(chunk, PrefillProgressChunk):
                continue

            if chunk.finish_reason == "error":
                raise HTTPException(
                    status_code=500,

@@ -1292,8 +1303,18 @@ class API:

        return total_available

    async def get_models(self) -> ModelList:
        """Returns list of available models."""
    async def get_models(self, status: str | None = Query(default=None)) -> ModelList:
        """Returns list of available models, optionally filtered by being downloaded."""
        cards = await get_model_cards()

        if status == "downloaded":
            downloaded_model_ids: set[str] = set()
            for node_downloads in self.state.downloads.values():
                for dl in node_downloads:
                    if isinstance(dl, DownloadCompleted):
                        downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
            cards = [c for c in cards if c.model_id in downloaded_model_ids]

        return ModelList(
            data=[
                ModelListModel(

@@ -1311,7 +1332,7 @@ class API:
                    base_model=card.base_model,
                    capabilities=card.capabilities,
                )
                for card in await get_model_cards()
                for card in cards
            ]
        )

@@ -1435,6 +1456,21 @@ class API:
            except BrokenResourceError:
                self._text_generation_queues.pop(event.command_id, None)

        elif isinstance(event, PrefillProgress):
            if queue := self._text_generation_queues.get(
                event.command_id, None
            ):
                try:
                    await queue.send(
                        PrefillProgressChunk(
                            model=event.model,
                            processed_tokens=event.processed_tokens,
                            total_tokens=event.total_tokens,
                        )
                    )
                except BrokenResourceError:
                    self._text_generation_queues.pop(event.command_id, None)

        if isinstance(event, TracesMerged):
            self._save_merged_trace(event)

@@ -15,6 +15,7 @@ from exo.shared.types.events import (
    NodeDownloadProgress,
    NodeGatheredInfo,
    NodeTimedOut,
    PrefillProgress,
    RunnerDeleted,
    RunnerStatusUpdated,
    TaskAcknowledged,

@@ -64,6 +65,7 @@ def event_apply(event: Event, state: State) -> State:
            | ChunkGenerated()
            | TaskAcknowledged()
            | InputChunkReceived()
            | PrefillProgress()
            | TracesCollected()
            | TracesMerged()
        ):  # Pass-through events that don't modify state

@@ -76,4 +76,13 @@ class InputImageChunk(BaseChunk):
        yield name, value


GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
class PrefillProgressChunk(BaseChunk):
    """Data class for prefill progress events during streaming."""

    processed_tokens: int
    total_tokens: int


GenerationChunk = (
    TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk | PrefillProgressChunk
)
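
For context, a sketch of how this chunk reaches a client (not part of the diff; the model name and token counts are made up): generate_chat_stream serializes it into an SSE comment line, which OpenAI-compatible clients ignore and the dashboard's onEvent handler parses.

# Illustrative round-trip of the prefill-progress SSE comment. The emit side
# mirrors generate_chat_stream(); the parse side mirrors the dashboard handler.
import json

chunk = PrefillProgressChunk(model="some-model", processed_tokens=4096, total_tokens=16384)
sse_line = f": prefill_progress {chunk.model_dump_json()}\n\n"

comment = sse_line.strip()[2:]           # drop the leading ": "
key, _, payload = comment.partition(" ")
assert key == "prefill_progress"
data = json.loads(payload)               # may arrive wrapped as {"PrefillProgressChunk": {...}}
inner = data.get("PrefillProgressChunk", data)
print(inner["processed_tokens"], inner["total_tokens"])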

@@ -5,7 +5,7 @@ from pydantic import Field

from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId

@@ -102,6 +102,13 @@ class InputChunkReceived(BaseEvent):
    chunk: InputImageChunk


class PrefillProgress(BaseEvent):
    command_id: CommandId
    model: ModelId
    processed_tokens: int
    total_tokens: int


class TopologyEdgeCreated(BaseEvent):
    conn: Connection

@@ -148,6 +155,7 @@ Event = (
    | NodeDownloadProgress
    | ChunkGenerated
    | InputChunkReceived
    | PrefillProgress
    | TopologyEdgeCreated
    | TopologyEdgeDeleted
    | TracesCollected

@@ -67,3 +67,8 @@ class ToolCallResponse(BaseRunnerResponse):

class FinishedResponse(BaseRunnerResponse):
    pass


class PrefillProgressResponse(BaseRunnerResponse):
    processed_tokens: int
    total_tokens: int

@@ -58,6 +58,7 @@ def prefill(
    prompt_tokens: mx.array,
    cache: KVCacheType,
    group: mx.distributed.Group | None,
    on_prefill_progress: Callable[[int, int], None] | None,
) -> tuple[float, int, list[CacheSnapshot]]:
    """Prefill the KV cache with prompt tokens.

@@ -85,6 +86,9 @@ def prefill(
        if has_ssm:
            snapshots.append(snapshot_ssm_states(cache))

        if on_prefill_progress is not None:
            on_prefill_progress(processed, total)

    set_pipeline_prefill(model, is_prefill=True)

    mx_barrier(group)
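
A minimal sketch of the callback contract assumed here (not the actual prefill implementation; the function name and chunk size are illustrative): the prompt is consumed in fixed-size steps and (processed, total) is reported after each step.

# Illustrative only: fixed-step processing with a progress callback.
from typing import Callable

def chunked_prefill(
    prompt_tokens: list[int],
    step_size: int = 4096,
    on_prefill_progress: Callable[[int, int], None] | None = None,
) -> None:
    total = len(prompt_tokens)
    processed = 0
    while processed < total:
        step = prompt_tokens[processed : processed + step_size]
        # ... feed `step` through the model to fill the KV cache ...
        processed += len(step)
        if on_prefill_progress is not None:
            on_prefill_progress(processed, total)

chunked_prefill(list(range(10_000)), on_prefill_progress=lambda p, t: print(f"{p}/{t}"))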

@@ -99,7 +103,7 @@
        max_tokens=1,
        sampler=sampler,
        prompt_cache=cache,
        prefill_step_size=8192,
        prefill_step_size=4096,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
        prompt_progress_callback=progress_callback,

@@ -257,6 +261,7 @@ def mlx_generate(
    prompt: str,
    kv_prefix_cache: KVPrefixCache | None,
    group: mx.distributed.Group | None,
    on_prefill_progress: Callable[[int, int], None] | None = None,
) -> Generator[GenerationResponse]:
    # Ensure that generation stats only contains peak memory for this generation
    mx.reset_peak_memory()

@@ -311,7 +316,13 @@

    # Prefill cache with all tokens except the last one
    prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
        model, tokenizer, sampler, prompt_tokens[:-1], caches, group
        model,
        tokenizer,
        sampler,
        prompt_tokens[:-1],
        caches,
        group,
        on_prefill_progress,
    )
    cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None

@@ -1,5 +1,6 @@
import json
import os
import re
import sys
import time
from pathlib import Path

@@ -407,6 +408,56 @@ def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
            func["arguments"] = json.loads(args)


def _collect_nested_property_names(schema: dict[str, Any]) -> set[str]:
    names: set[str] = set()
    properties: dict[str, Any] = schema.get("properties", {})  # type: ignore[reportAny]
    for prop_spec in properties.values():  # pyright: ignore[reportAny]
        if not isinstance(prop_spec, dict):
            continue
        if prop_spec.get("type") == "array":  # type: ignore[reportAny]
            items: dict[str, Any] | None = prop_spec.get("items")  # type: ignore[reportAny]
            if isinstance(items, dict) and items.get("type") == "object":  # type: ignore[reportAny]
                inner_props: dict[str, Any] = items.get("properties", {})  # type: ignore[reportAny]
                for k in inner_props:  # pyright: ignore[reportUnknownVariableType]
                    names.add(str(k))  # pyright: ignore[reportUnknownArgumentType]
                names.update(_collect_nested_property_names(items))  # pyright: ignore[reportUnknownArgumentType]
    return names


def _schemas_lost_in_prompt(prompt: str, tools: list[dict[str, Any]]) -> bool:
    """Return True if nested property names from any tool schema are absent."""
    for tool in tools:
        fn: dict[str, Any] = tool.get("function", {})  # type: ignore
        params: dict[str, Any] = fn.get("parameters", {})  # type: ignore
        nested = _collect_nested_property_names(params)
        if nested and not all(name in prompt for name in nested):
            return True
    return False
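
An illustration of the detection on a create_todos-style schema (not part of the diff; the tool dict below follows the OpenAI function-tool shape these helpers expect):

# Illustrative only: nested property names are collected, and a rendered
# prompt that never mentions them is flagged as a lossy template rendering.
tool = {
    "type": "function",
    "function": {
        "name": "create_todos",
        "parameters": {
            "type": "object",
            "properties": {
                "todos": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "content": {"type": "string"},
                            "status": {"type": "string"},
                            "priority": {"type": "string"},
                        },
                    },
                },
            },
        },
    },
}

assert _collect_nested_property_names(tool["function"]["parameters"]) == {
    "content", "status", "priority",
}
assert _schemas_lost_in_prompt("...create_todos(todos: any[])...", [tool]) is True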


_LOSSY_TEMPLATE_PATTERN = re.compile(
    r"""inner_type\s*==\s*["']object \| object["']\s*or\s*inner_type\|length\s*>\s*\d+""",
)


def _patch_lossy_chat_template(template: str) -> str | None:
    """Patch chat templates that collapse nested object schemas to ``any[]``.

    Some templates (e.g., GPT-OSS) have a guard like::

        inner_type == "object | object" or inner_type|length > 50

    The length check silently drops complex array-of-object schemas.
    We remove the length guard, keeping only the object-union check.
    Returns the patched template, or *None* if no patch was needed.
    """
    patched, n = _LOSSY_TEMPLATE_PATTERN.subn(
        lambda m: m.group(0).split(" or ")[0],  # keep only the object-union check
        template,
    )
    return patched if n > 0 else None
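
An illustration of the patch on a condensed guard of the kind the docstring describes (not part of the diff; the real template is much longer):

# Illustrative only: the length guard is dropped, the object-union check stays.
snippet = (
    '{%- if inner_type == "object | object" or inner_type|length > 50 %}'
    "any[]"
    "{%- endif %}"
)
patched = _patch_lossy_chat_template(snippet)
assert patched == '{%- if inner_type == "object | object" %}any[]{%- endif %}'
assert _patch_lossy_chat_template("no guard here") is None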


def apply_chat_template(
    tokenizer: TokenizerWrapper,
    task_params: TextGenerationTaskParams,

@@ -453,14 +504,28 @@
        extra_kwargs["enable_thinking"] = task_params.enable_thinking
        extra_kwargs["thinking"] = task_params.enable_thinking

    patched_template: str | None = None
    if task_params.tools:
        original_template: str | None = getattr(tokenizer, "chat_template", None)
        if isinstance(original_template, str):
            patched_template = _patch_lossy_chat_template(original_template)
            if patched_template is not None:
                logger.info(
                    "Patched lossy chat template (removed inner_type length guard)"
                )

    prompt: str = tokenizer.apply_chat_template(
        formatted_messages,
        tokenize=False,
        add_generation_prompt=True,
        tools=task_params.tools,
        **({"chat_template": patched_template} if patched_template is not None else {}),
        **extra_kwargs,
    )

    if task_params.tools and _schemas_lost_in_prompt(prompt, task_params.tools):
        logger.warning("Chat template lost nested tool schemas even after patching")

    if partial_assistant_content:
        prompt += partial_assistant_content

@@ -26,6 +26,7 @@ from exo.shared.types.common import CommandId
from exo.shared.types.events import (
    ChunkGenerated,
    Event,
    PrefillProgress,
    RunnerStatusUpdated,
    TaskAcknowledged,
    TaskStatusUpdated,

@@ -298,6 +299,18 @@ def main(
    assert tokenizer
    assert check_for_cancel_every

    # Define callback to send prefill progress events directly
    def on_prefill_progress(processed: int, total: int) -> None:
        if device_rank == 0:
            event_sender.send(
                PrefillProgress(
                    command_id=command_id,
                    model=shard_metadata.model_card.model_id,
                    processed_tokens=processed,
                    total_tokens=total,
                )
            )

    try:
        _check_for_debug_prompts(task_params)

@@ -311,6 +324,7 @@ def main(
        task=task_params,
        prompt=prompt,
        kv_prefix_cache=kv_prefix_cache,
        on_prefill_progress=on_prefill_progress,
        group=group,
    )

@@ -599,11 +613,21 @@ def parse_gpt_oss(
        ch = stream.current_channel
        recipient = stream.current_recipient

        # Debug: log every token with state
        logger.debug(
            f"parse_gpt_oss token={response.token} text={response.text!r} "
            f"recipient={recipient!r} ch={ch!r} delta={delta!r} "
            f"state={stream.state} current_tool={current_tool_name!r}"
        )

        if recipient != current_tool_name:
            if current_tool_name is not None:
                prefix = "functions."
                if current_tool_name.startswith(prefix):
                    current_tool_name = current_tool_name[len(prefix) :]
                logger.info(
                    f"parse_gpt_oss yielding tool call: name={current_tool_name!r}"
                )
                yield ToolCallResponse(
                    tool_calls=[
                        ToolCallItem(